diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000000000..cb828eabaed39 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,18 @@ +FROM ubuntu:24.04 + +# Install necessary tools +RUN apt-get update && apt-get install -y \ + # Intellij IDEA dev container prerequisites + curl \ + git \ + unzip \ + # Java 8 and 17 jdk + openjdk-8-jdk \ + openjdk-17-jdk \ + # for documentation + python3 \ + python3-pip \ + python3.12-venv \ + fonts-freefont-otf \ + xindy + diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000000000..12ef523071cec --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,6 @@ +{ + "name": "PrestoDB Dev Container", + "build": { + "dockerfile": "Dockerfile" + } +} diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 46565491c504b..3ffe1c71c401c 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -20,9 +20,9 @@ env: concurrency: group: "${{github.workflow}}-${{github.ref}}" - # Cancel in-progress jobs for efficiency. Exclude the `release-0.293-clp-connector` branch so - # that each commit to release-0.293-clp-connector is checked completely. - cancel-in-progress: "${{github.ref != 'refs/heads/release-0.293-clp-connector'}}" + # Cancel in-progress jobs for efficiency. Exclude the `release-0.297-edge-10-clp-connector` branch so + # that each commit to release-0.297-edge-10-clp-connector is checked completely. 
+ cancel-in-progress: "${{github.ref != 'refs/heads/release-0.297-edge-10-clp-connector'}}" jobs: test: @@ -44,7 +44,7 @@ jobs: - uses: actions/setup-java@v4 with: distribution: 'temurin' - java-version: 8.0.442 + java-version: 17.0.13 cache: 'maven' - name: Maven Install run: | diff --git a/.github/workflows/maven-checks.yml b/.github/workflows/maven-checks.yml index 2295c413bef7b..92d3cac9f7d5c 100644 --- a/.github/workflows/maven-checks.yml +++ b/.github/workflows/maven-checks.yml @@ -14,16 +14,16 @@ env: concurrency: group: "${{github.workflow}}-${{github.ref}}" - # Cancel in-progress jobs for efficiency. Exclude the `release-0.293-clp-connector` branch so - # that each commit to release-0.293-clp-connector is checked completely. - cancel-in-progress: "${{github.ref != 'refs/heads/release-0.293-clp-connector'}}" + # Cancel in-progress jobs for efficiency. Exclude the `release-0.297-edge-10-clp-connector` branch so + # that each commit to release-0.297-edge-10-clp-connector is checked completely. 
+ cancel-in-progress: "${{github.ref != 'refs/heads/release-0.297-edge-10-clp-connector'}}" jobs: maven-checks: strategy: fail-fast: false matrix: - java: [ 8.0.442, 17.0.13 ] + java: [ 17.0.13 ] runs-on: ubuntu-latest timeout-minutes: 45 steps: @@ -47,24 +47,46 @@ jobs: export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" ./mvnw install -B -V -T 1C -DskipTests -Dmaven.javadoc.skip=true --no-transfer-progress -P ci -pl '!presto-test-coverage,!:presto-docs' - name: "Upload presto-server" - if: matrix.java == '8.0.442' uses: "actions/upload-artifact@v4" with: name: "presto-server" - path: "presto-server/target/presto-server-0.293.tar.gz" + path: "presto-server/target/presto-server-0.297-edge10.1-SNAPSHOT.tar.gz" if-no-files-found: "error" retention-days: 1 - name: "Upload presto-cli" - if: matrix.java == '8.0.442' uses: "actions/upload-artifact@v4" with: name: "presto-cli" - path: "presto-cli/target/presto-cli-0.293-executable.jar" + path: "presto-cli/target/presto-cli-0.297-edge10.1-SNAPSHOT-executable.jar" if-no-files-found: "error" retention-days: 1 - name: "Clean Maven output" run: "./mvnw clean -pl '!:presto-server,!:presto-cli,!presto-test-coverage'" + clp-connector-unit-tests: + name: "clp-connector-unit-tests" + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + with: + show-progress: false + - uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: 17.0.13 + cache: 'maven' + - name: Download nodejs to maven cache + run: .github/bin/download_nodejs + - name: Install presto-clp dependencies + run: | + export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" + ./mvnw install -B -V -T 1C -DskipTests -Dmaven.javadoc.skip=true --no-transfer-progress -am -pl 'presto-clp' + - name: Run presto-clp unit tests + run: | + export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" + ./mvnw test -B --no-transfer-progress -pl 'presto-clp' + presto-coordinator-image: name: "presto-coordinator-image" needs: "maven-checks" @@ -105,11 +127,11 @@ jobs: with: 
build-args: |- JMX_PROMETHEUS_JAVA_AGENT_VERSION=0.20.0 - PRESTO_VERSION=0.293 + PRESTO_VERSION=0.297-edge10.1-SNAPSHOT context: "./docker" file: "./docker/Dockerfile" push: >- ${{github.event_name != 'pull_request' - && github.ref == 'refs/heads/release-0.293-clp-connector'}} + && github.ref == 'refs/heads/release-0.297-edge-10-clp-connector'}} tags: "${{steps.meta.outputs.tags}}" labels: "${{steps.meta.outputs.labels}}" diff --git a/.github/workflows/pr-title-checks.yaml b/.github/workflows/pr-title-checks.yaml index 886249e6348c0..024c01fe72f7c 100644 --- a/.github/workflows/pr-title-checks.yaml +++ b/.github/workflows/pr-title-checks.yaml @@ -8,7 +8,7 @@ on: # pull request triggered by this event. # - Each job has `permissions` set to only those necessary. types: ["edited", "opened", "reopened"] - branches: ["release-0.293-clp-connector"] + branches: ["release-0.297-edge-10-clp-connector"] permissions: {} diff --git a/.github/workflows/prestissimo-worker-images-build.yml b/.github/workflows/prestissimo-worker-images-build.yml index b36dcb71949be..f80bb7493a300 100644 --- a/.github/workflows/prestissimo-worker-images-build.yml +++ b/.github/workflows/prestissimo-worker-images-build.yml @@ -34,7 +34,7 @@ jobs: file: "./presto-native-execution/scripts/dockerfiles/ubuntu-22.04-dependency.dockerfile" push: >- ${{github.event_name != 'pull_request' - && github.ref == 'refs/heads/release-0.293-clp-connector'}} + && github.ref == 'refs/heads/release-0.297-edge-10-clp-connector'}} tags: "${{steps.metadata-deps-image.outputs.tags}}" labels: "${{steps.metadata-deps-image.outputs.labels}}" @@ -56,15 +56,13 @@ jobs: build-args: |- BASE_IMAGE=ubuntu:22.04 DEPENDENCY_IMAGE=${{steps.metadata-deps-image.outputs.tags}} - EXTRA_CMAKE_FLAGS=-DPRESTO_ENABLE_TESTING=OFF \ - -DPRESTO_ENABLE_PARQUET=ON \ - -DPRESTO_ENABLE_S3=ON + EXTRA_CMAKE_FLAGS=-DPRESTO_ENABLE_TESTING=OFF -DPRESTO_ENABLE_PARQUET=ON -DPRESTO_ENABLE_S3=ON -DTREAT_WARNINGS_AS_ERRORS=0 
NUM_THREADS=${{steps.get-cores.outputs.num_cores}} OSNAME=ubuntu context: "./presto-native-execution" file: "./presto-native-execution/scripts/dockerfiles/prestissimo-runtime.dockerfile" push: >- ${{github.event_name != 'pull_request' - && github.ref == 'refs/heads/release-0.293-clp-connector'}} + && github.ref == 'refs/heads/release-0.297-edge-10-clp-connector'}} tags: "${{steps.metadata-runtime-image.outputs.tags}}" labels: "${{steps.metadata-runtime-image.outputs.labels}}" diff --git a/.github/workflows/prestocpp-format-and-header-check.yml b/.github/workflows/prestocpp-format-and-header-check.yml index c554ee8785786..08723f24563ad 100644 --- a/.github/workflows/prestocpp-format-and-header-check.yml +++ b/.github/workflows/prestocpp-format-and-header-check.yml @@ -14,9 +14,9 @@ on: concurrency: group: "${{github.workflow}}-${{github.ref}}" - # Cancel in-progress jobs for efficiency. Exclude the `release-0.293-clp-connector` branch so - # that each commit to release-0.293-clp-connector is checked completely. - cancel-in-progress: "${{github.ref != 'refs/heads/release-0.293-clp-connector'}}" + # Cancel in-progress jobs for efficiency. Exclude the `release-0.297-edge-10-clp-connector` branch so + # that each commit to release-0.297-edge-10-clp-connector is checked completely. + cancel-in-progress: "${{github.ref != 'refs/heads/release-0.297-edge-10-clp-connector'}}" jobs: prestocpp-format-and-header-check: diff --git a/.github/workflows/prestocpp-linux-build-and-unit-test.yml b/.github/workflows/prestocpp-linux-build-and-unit-test.yml index e26e330403ec2..7e3578a37a44e 100644 --- a/.github/workflows/prestocpp-linux-build-and-unit-test.yml +++ b/.github/workflows/prestocpp-linux-build-and-unit-test.yml @@ -14,15 +14,15 @@ on: concurrency: group: "${{github.workflow}}-${{github.ref}}" - # Cancel in-progress jobs for efficiency. Exclude the `release-0.293-clp-connector` branch so - # that each commit to release-0.293-clp-connector is checked completely. 
- cancel-in-progress: "${{github.ref != 'refs/heads/release-0.293-clp-connector'}}" + # Cancel in-progress jobs for efficiency. Exclude the `release-0.297-edge-10-clp-connector` branch so + # that each commit to release-0.297-edge-10-clp-connector is checked completely. + cancel-in-progress: "${{github.ref != 'refs/heads/release-0.297-edge-10-clp-connector'}}" jobs: prestocpp-linux-build-for-test: runs-on: ubuntu-22.04 container: - image: prestodb/presto-native-dependency:0.293-20250522140509-484b00e + image: prestodb/presto-native-dependency:0.297-202512180933-75d7d4ea env: CCACHE_DIR: "${{ github.workspace }}/ccache" steps: @@ -99,7 +99,7 @@ jobs: needs: prestocpp-linux-build-for-test runs-on: ubuntu-22.04 container: - image: prestodb/presto-native-dependency:0.293-20250522140509-484b00e + image: prestodb/presto-native-dependency:0.297-202512180933-75d7d4ea env: MAVEN_OPTS: "-Xmx4G -XX:+ExitOnOutOfMemoryError" MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all -Dmaven.javadoc.skip=true" @@ -112,6 +112,11 @@ jobs: # it doesn't work run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + - name: Update velox + run: | + cd presto-native-execution + make velox-submodule + - name: Download artifacts uses: actions/download-artifact@v4 with: @@ -130,7 +135,7 @@ jobs: uses: actions/setup-java@v4 with: distribution: 'temurin' - java-version: 8.0.442 + java-version: 17.0.13 cache: 'maven' - name: Download nodejs to maven cache run: .github/bin/download_nodejs @@ -145,17 +150,7 @@ jobs: - name: Run presto-native e2e tests run: | export PRESTO_SERVER_PATH="${GITHUB_WORKSPACE}/presto-native-execution/_build/release/presto_cpp/main/presto_server" - export TESTFILES=`find ./presto-native-execution/src/test -type f -name 'TestPrestoNative*.java'` - # Convert file paths to comma separated class names - export TESTCLASSES= - for test_file in $TESTFILES - do - tmp=${test_file##*/} - test_class=${tmp%%\.*} - export 
TESTCLASSES="${TESTCLASSES},$test_class" - done - export TESTCLASSES=${TESTCLASSES#,} - echo "TESTCLASSES = $TESTCLASSES" + export TESTCLASSES=TestPrestoNativeClpGeneralQueries # TODO: neeed to enable remote function tests with # "-Ppresto-native-execution-remote-functions" once # > https://github.com/facebookincubator/velox/discussions/6163 @@ -170,145 +165,3 @@ jobs: -Duser.timezone=America/Bahia_Banderas \ -T1C - prestocpp-linux-presto-native-tests: - needs: prestocpp-linux-build-for-test - runs-on: ubuntu-22.04 - strategy: - fail-fast: false - matrix: - storage-format: [ "PARQUET", "DWRF" ] - container: - image: prestodb/presto-native-dependency:0.293-20250522140509-484b00e - env: - MAVEN_OPTS: "-Xmx4G -XX:+ExitOnOutOfMemoryError" - MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all -Dmaven.javadoc.skip=true" - MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --fail-at-end" - steps: - - uses: actions/checkout@v4 - - - name: Fix git permissions - # Usually actions/checkout does this but as we run in a container - # it doesn't work - run: git config --global --add safe.directory ${GITHUB_WORKSPACE} - - - name: Download artifacts - uses: actions/download-artifact@v4 - with: - name: presto-native-build - path: presto-native-execution/_build/release - - # Permissions are lost when uploading. Details here: https://github.com/actions/upload-artifact/issues/38 - - name: Restore execute permissions and library path - run: | - chmod +x ${GITHUB_WORKSPACE}/presto-native-execution/_build/release/presto_cpp/main/presto_server - chmod +x ${GITHUB_WORKSPACE}/presto-native-execution/_build/release/velox/velox/functions/remote/server/velox_functions_remote_server_main - # Ensure transitive dependency libboost-iostreams is found. 
- ldconfig /usr/local/lib - - - name: Install OpenJDK8 - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '8.0.442' - cache: 'maven' - - name: Download nodejs to maven cache - run: .github/bin/download_nodejs - - - name: Maven install - env: - # Use different Maven options to install. - MAVEN_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" - run: | - for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-tests' -am && s=0 && break || s=$? && sleep 10; done; (exit $s) - - - name: Run presto-native tests - run: | - export PRESTO_SERVER_PATH="${GITHUB_WORKSPACE}/presto-native-execution/_build/release/presto_cpp/main/presto_server" - export TESTFILES=`find ./presto-native-tests/src/test -type f -name 'Test*.java'` - # Convert file paths to comma separated class names - export TESTCLASSES= - for test_file in $TESTFILES - do - tmp=${test_file##*/} - test_class=${tmp%%\.*} - export TESTCLASSES="${TESTCLASSES},$test_class" - done - export TESTCLASSES=${TESTCLASSES#,} - echo "TESTCLASSES = $TESTCLASSES" - - mvn test \ - ${MAVEN_TEST} \ - -pl 'presto-native-tests' \ - -DstorageFormat=${{ matrix.storage-format }} \ - -Dtest="${TESTCLASSES}" \ - -DPRESTO_SERVER=${PRESTO_SERVER_PATH} \ - -DDATA_DIR=${RUNNER_TEMP} \ - -Duser.timezone=America/Bahia_Banderas \ - -T1C - - prestocpp-linux-presto-sidecar-tests: - needs: prestocpp-linux-build-for-test - runs-on: ubuntu-22.04 - container: - image: prestodb/presto-native-dependency:0.293-20250522140509-484b00e - env: - MAVEN_OPTS: "-Xmx4G -XX:+ExitOnOutOfMemoryError" - MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all -Dmaven.javadoc.skip=true" - MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --fail-at-end" - steps: - - uses: actions/checkout@v4 - - name: Fix git permissions - # Usually actions/checkout does this but as we run in a container - # it doesn't work - run: git config --global --add 
safe.directory ${GITHUB_WORKSPACE} - - - name: Download artifacts - uses: actions/download-artifact@v4 - with: - name: presto-native-build - path: presto-native-execution/_build/release - - # Permissions are lost when uploading. Details here: https://github.com/actions/upload-artifact/issues/38 - - name: Restore execute permissions and library path - run: | - chmod +x ${GITHUB_WORKSPACE}/presto-native-execution/_build/release/presto_cpp/main/presto_server - chmod +x ${GITHUB_WORKSPACE}/presto-native-execution/_build/release/velox/velox/functions/remote/server/velox_functions_remote_server_main - # Ensure transitive dependency libboost-iostreams is found. - ldconfig /usr/local/lib - - name: Install OpenJDK8 - uses: actions/setup-java@v4 - with: - distribution: 'temurin' - java-version: '8.0.442' - cache: 'maven' - - name: Download nodejs to maven cache - run: .github/bin/download_nodejs - - - name: Maven install - env: - # Use different Maven options to install. - MAVEN_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" - run: | - for i in $(seq 1 3); do ./mvnw clean install $MAVEN_FAST_INSTALL -pl 'presto-native-execution' -am && s=0 && break || s=$? 
&& sleep 10; done; (exit $s) - - name: Run presto-native sidecar tests - run: | - export PRESTO_SERVER_PATH="${GITHUB_WORKSPACE}/presto-native-execution/_build/release/presto_cpp/main/presto_server" - export TESTFILES=`find ./presto-native-sidecar-plugin/src/test -type f -name 'Test*.java'` - # Convert file paths to comma separated class names - export TESTCLASSES= - for test_file in $TESTFILES - do - tmp=${test_file##*/} - test_class=${tmp%%\.*} - export TESTCLASSES="${TESTCLASSES},$test_class" - done - export TESTCLASSES=${TESTCLASSES#,} - echo "TESTCLASSES = $TESTCLASSES" - mvn test \ - ${MAVEN_TEST} \ - -pl 'presto-native-sidecar-plugin' \ - -Dtest="${TESTCLASSES}" \ - -DPRESTO_SERVER=${PRESTO_SERVER_PATH} \ - -DDATA_DIR=${RUNNER_TEMP} \ - -Duser.timezone=America/Bahia_Banderas \ - -T1C diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml deleted file mode 100644 index 6f829502fbbb8..0000000000000 --- a/.github/workflows/tests.yml +++ /dev/null @@ -1,84 +0,0 @@ -name: test - -on: - pull_request: - push: - -env: - # An envar that signals to tests we are executing in the CI environment - CONTINUOUS_INTEGRATION: true - MAVEN_OPTS: "-Xmx1024M -XX:+ExitOnOutOfMemoryError" - MAVEN_INSTALL_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" - MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all --no-transfer-progress -Dmaven.javadoc.skip=true" - MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --no-transfer-progress --fail-at-end" - RETRY: .github/bin/retry - -concurrency: - group: "${{github.workflow}}-${{github.ref}}" - - # Cancel in-progress jobs for efficiency. Exclude the `release-0.293-clp-connector` branch so - # that each commit to release-0.293-clp-connector is checked completely. 
- cancel-in-progress: "${{github.ref != 'refs/heads/release-0.293-clp-connector'}}" - -jobs: - changes: - runs-on: ubuntu-latest - # Required permissions - permissions: - pull-requests: read - # Set job outputs to values from filter step - outputs: - codechange: ${{ steps.filter.outputs.codechange }} - steps: - - uses: "actions/checkout@v4" - with: - submodules: "recursive" - - uses: dorny/paths-filter@v2 - id: filter - with: - filters: | - codechange: - - '!presto-docs/**' - - test: - runs-on: ubuntu-latest - needs: changes - strategy: - fail-fast: false - matrix: - java: [8.0.442, 17.0.13] - modules: - - ":presto-tests -P presto-tests-execution-memory" - - ":presto-tests -P presto-tests-general" - - ":presto-tests -P ci-only-distributed-non-hash-gen" - - ":presto-tests -P ci-only-tpch-distributed-queries" - - ":presto-tests -P ci-only-local-queries" - - ":presto-tests -P ci-only-distributed-queries" - - ":presto-tests -P ci-only-aggregation-queries" - - ":presto-tests -P ci-only-plan-determinism" - - ":presto-tests -P ci-only-resource-manager" - - ":presto-main-base" - - ":presto-main" - timeout-minutes: 80 - steps: - - uses: actions/checkout@v4 - if: needs.changes.outputs.codechange == 'true' - with: - show-progress: false - - uses: actions/setup-java@v4 - if: needs.changes.outputs.codechange == 'true' - with: - distribution: 'temurin' - java-version: ${{ matrix.java }} - cache: 'maven' - - name: Download nodejs to maven cache - if: needs.changes.outputs.codechange == 'true' - run: .github/bin/download_nodejs - - name: Maven Install - if: needs.changes.outputs.codechange == 'true' - run: | - export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" - ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl $(echo '${{ matrix.modules }}' | cut -d' ' -f1) - - name: Maven Tests - if: needs.changes.outputs.codechange == 'true' - run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} diff --git a/.gitignore b/.gitignore index a4512f9f794d9..e076be056cdbb 100644 --- a/.gitignore +++ 
b/.gitignore @@ -28,9 +28,11 @@ benchmark_outputs *.class .checkstyle .mvn/timing.properties +.mvn/maven.config .editorconfig node_modules presto-docs-venv/ +.m2/ #==============================================================================# # presto-native-execution @@ -66,3 +68,4 @@ presto-native-execution/deps-install # Compiled executables used for docker build /docker/presto-cli-*-executable.jar /docker/presto-server-*.tar.gz +/docker/presto-function-server-executable.jar diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000..4e9d20a35fc02 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,116 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# See https://pre-commit.com for more information +# General excludes, files can also be excluded on a hook level +files: ^presto-native-execution/.*|^.github/.*\.yml +exclude: "^(?:presto-native-execution/(?:velox|presto_cpp/external)/.*)|\ + ^presto-native-execution/presto_cpp/presto_protocol/.*\\.yml|\ + \\.(?:patch|header|sql)|\ + ThriftLibrary\\.cmake|\ + /data/|\ + build/(?:fb_code_builder|deps)|\ + cmake-build-debug|\ + NOTICE\\.txt|\ + scripts/git-clang-format" +default_install_hook_types: [pre-commit, pre-push] +repos: + - repo: meta + hooks: + - id: check-hooks-apply + - id: check-useless-excludes + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + - id: end-of-file-fixer + - id: check-added-large-files + - id: check-executables-have-shebangs + - id: check-shebang-scripts-are-executable + + - repo: local + hooks: + - id: clang-tidy + name: clang-tidy + description: Run clang-tidy on C/C++ files + stages: + - manual + entry: clang-tidy + language: python + types_or: [c++, c] + additional_dependencies: [clang-tidy==18.1.8] + require_serial: true + + - id: license-header + name: license-header + description: Add missing license headers. 
+ entry: presto-native-execution/scripts/license-header.py + args: [-i] + language: python + additional_dependencies: [regex] + require_serial: true + exclude: ^.github/.* + + # CMake + - repo: https://github.com/BlankSpruce/gersemi + rev: 0.21.0 + hooks: + - id: gersemi + name: CMake formatter + + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v20.1.8 + hooks: + - id: clang-format + # types_or: [c++, c, cuda, metal, objective-c] + files: \.(cpp|cc|c|h|hpp|inc|cu|cuh|clcpp|mm|metal)$ + + - repo: https://github.com/scop/pre-commit-shfmt + rev: v3.11.0-1 + hooks: + - id: shfmt + # w: write changes, s: simplify, i set indent to 2 spaces + args: [-w, -s, -i, '2'] + + # Python + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.3 + hooks: + - id: ruff-check + args: [ --fix, --no-unsafe-fixes ] + - id: ruff-format + + # The following checks mostly target GitHub Actions workflows. + - repo: https://github.com/adrienverge/yamllint.git + rev: v1.37.0 + hooks: + - id: yamllint + args: [ --format, parsable, --strict ] + exclude: .*\.clang-(tidy|format)|presto_cpp/main/thrift/.* + + - repo: https://github.com/google/yamlfmt + rev: v0.16.0 + hooks: + - id: yamlfmt + exclude: .*\.clang-(tidy|format)|presto_cpp/main/thrift/.* + + - repo: https://github.com/zizmorcore/zizmor-pre-commit + rev: v1.7.0 + hooks: + - id: zizmor + + - repo: https://github.com/mpalmer/action-validator + rev: v0.7.1 + hooks: + - id: action-validator diff --git a/.yamlfmt.yml b/.yamlfmt.yml new file mode 100644 index 0000000000000..46a2713a1c677 --- /dev/null +++ b/.yamlfmt.yml @@ -0,0 +1,20 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +match_type: doublestar +exclude: + - '**/.clang-format' + - '**/.clang-tidy' +formatter: + type: basic + retain_line_breaks_single: true + scan_folded_as_literal: true + indent: 2 diff --git a/.yamllint.yml b/.yamllint.yml new file mode 100644 index 0000000000000..866fb793d61ef --- /dev/null +++ b/.yamllint.yml @@ -0,0 +1,44 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+rules: + braces: + min-spaces-inside: 0 + max-spaces-inside: 1 + min-spaces-inside-empty: 0 + max-spaces-inside-empty: 0 + brackets: + min-spaces-inside: 0 + max-spaces-inside: 0 + min-spaces-inside-empty: 0 + max-spaces-inside-empty: 0 + comments: disable + comments-indentation: disable + document-end: disable + document-start: disable + empty-lines: disable + empty-values: + forbid-in-flow-mappings: true + forbid-in-block-sequences: true + float-values: + forbid-inf: true + forbid-nan: true + forbid-scientific-notation: true + require-numeral-before-decimal: true + indentation: disable + line-length: disable + octal-values: enable + quoted-strings: + required: only-when-needed + extra-allowed: ['.*\$\{\{.*\}\}.*'] + truthy: + allowed-values: ['true', 'false', 'on'] + level: warning diff --git a/CODEOWNERS b/CODEOWNERS index d3798352db118..e4a0e3bb56aed 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -50,6 +50,7 @@ /presto-lark-sheets @prestodb/committers /presto-local-file @prestodb/committers /presto-main @prestodb/committers +/presto-main-base @prestodb/committers /presto-matching @prestodb/committers /presto-memory @prestodb/committers /presto-memory-context @prestodb/committers @@ -75,7 +76,9 @@ /presto-resource-group-managers @prestodb/committers /presto-router @prestodb/committers /presto-server @prestodb/committers -/presto-session-property-managers @prestodb/committers +/presto-session-property-managers-common @prestodb/committers +/presto-file-session-property-manager @prestodb/committers +/presto-db-session-property-manager @prestodb/committers /presto-singlestore @prestodb/committers /presto-spi @prestodb/committers /presto-sqlserver @prestodb/committers @@ -100,10 +103,10 @@ CODEOWNERS @prestodb/team-tsc # Presto core # Presto analyzer and optimizer -/presto-main/src/*/java/com/facebook/presto/sql @jaystarshot @feilong-liu @ClarenceThreepwood @prestodb/committers +/presto-main-base/src/*/java/com/facebook/presto/sql @jaystarshot @feilong-liu 
@ClarenceThreepwood @prestodb/committers # Presto cost based optimizer framework -/presto-main/src/*/java/com/facebook/presto/cost @jaystarshot @feilong-liu @ClarenceThreepwood @prestodb/committers +/presto-main-base/src/*/java/com/facebook/presto/cost @jaystarshot @feilong-liu @ClarenceThreepwood @prestodb/committers # Testing module # Note: all code owners in Presto core should be included here as well @@ -111,14 +114,16 @@ CODEOWNERS @prestodb/team-tsc ##################################################################### # Prestissimo module -/presto-native-execution @prestodb/team-velox -/presto-native-sidecar-plugin @prestodb/team-velox -/presto-native-tests @prestodb/team-velox +/presto-native-execution @prestodb/team-velox @prestodb/committers +/presto-native-sidecar-plugin @pdabre12 @prestodb/team-velox @prestodb/committers +/presto-native-tests @prestodb/team-velox @prestodb/committers +/presto-main-base/src/main/java/com/facebook/presto/sessionpropertyproviders/NativeWorkerSessionPropertyProvider.java @prestodb/team-velox @prestodb/committers /.github/workflows/prestocpp-* @prestodb/team-velox @prestodb/committers ##################################################################### # Presto on Spark module /presto-spark* @shrinidhijoshi @prestodb/committers +/presto-native-execution/*/com/facebook/presto/spark/* @shrinidhijoshi @prestodb/committers ##################################################################### # Presto connectors and plugins @@ -160,6 +165,5 @@ CODEOWNERS @prestodb/team-tsc ##################################################################### # Presto CI and builds -/.github @czentgr @prestodb/committers -/docker @czentgr @prestodb/committers - +/.github @czentgr @unidevel @prestodb/committers +/docker @czentgr @unidevel @prestodb/committers diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 79f4af3092b0e..36987cca9c14b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,6 +63,17 @@ To commit code, you should: * Add or 
modify existing tests related to code changes being submitted * Run and ensure that local tests pass before submitting a merge request +## Release Process and Version Support + +* **Release Cadence**: ~2 months (volunteer dependent) +* **Version Support**: Latest release plus N-1 through N-4 receive critical fixes; N-5+ unsupported +* **Trunk**: Not stable - never use in production +* **Testing**: Extended RC periods with community testing + +Details: +* [Release Process Documentation](presto-docs/src/main/sphinx/develop/release-process.rst) - For developers +* [Version Support Guide](presto-docs/src/main/sphinx/admin/version-support.rst) - For administrators + ## Designing Your Code * Consider your code through 3 axes 1. Code Quality and Maintainability, for example: @@ -89,6 +100,7 @@ To commit code, you should: ## Code Style +### Java We recommend you use IntelliJ as your IDE. The code style template for the project can be found in the [codestyle](https://github.com/airlift/codestyle) repository along with our general programming and Java guidelines. In addition to those you should also adhere to the following: * **Naming** @@ -284,7 +296,7 @@ We recommend you use IntelliJ as your IDE. The code style template for the proje .map(OriginalExpressionUtils::castToExpression) ``` -* When appropriate use Java 8 Stream API +* When appropriate use Java Stream API * Categorize errors when throwing an exception * **Tests** * Avoid adding `Thread.sleep` in tests--these can fail due to environmental conditions, such as garbage collection or noisy neighbors in the CI environment. @@ -297,44 +309,109 @@ We recommend you use IntelliJ as your IDE. The code style template for the proje reinitialize them before each test in a `@BeforeMethod` method, and annotate the class with `@Test(singleThreaded = true)`. 
+### PrestoC++/Prestissimo (presto-native-execution) + +The project follows the [coding standards](https://github.com/facebookincubator/velox/blob/main/CONTRIBUTING.md#coding-best-practices) of the [Velox](https://github.com/facebookincubator/velox) project. + +For code formatting a pre-commit hook is used that is installed locally and can fix issues before changes are pushed to the repository. +Please install the [pre-commit](https://pre-commit.com/) tool. Once installed, run `pre-commit run` to scan and fix your staged changes manually, or optionally, install the hook in the local repository by running `pre-commit install` in the project root. +This results in the hook being automatically run on `git commit` executions. ## Commit Standards -* ### Commit Size - * Recommended lines of code should be no more than +1000 lines, and should focus on one single major topic.\ - If your commit is more than 1000 lines, consider breaking it down into multiple commits, each handling a specific small topic. 
-* ### Commit Message Style - * **Separate subject from body with a blank line** - * **Subject** - * Limit the subject line to 10 words or 50 characters - * If you cannot make the subject short, you may be committing too many changes at once - * Capitalize the subject line - * Do not end the subject line with a period - * Use the imperative mood in the subject line - * **Body** - * Wrap the body at 72 characters - * Use the body to explain what and why versus how - * Use the indicative mood in the body\ - For example, “If applied, this commit will ___________” - * Communicate only context (why) for the commit in the subject line - * Use the body for What and Why - * If your commit is complex or dense, share some of the How context - * Assume someone may need to revert your change during an emergency - * **Content** - * **Aim for smaller commits for easier review and simpler code maintenance** - * All bug fixes and new features must have associated tests - * Commits should pass tests on their own, not be dependent on other commits - * The following is recommended: - * Describe why a change is being made. - * How does it address the issue? - * What effects does the patch have? - * Do not assume the reviewer understands what the original problem was. - * Do not assume the code is self-evident or self-documenting. - * Read the commit message to see if it hints at improved code structure. - * The first commit line is the most important. - * Describe any limitations of the current code. - * Do not include patch set-specific comments. - -Details for each point and good commit message examples can be found on https://wiki.openstack.org/wiki/GitCommitMessages#Information_in_commit_messages + +### Conventional Commits +We follow the [Conventional Commits](https://www.conventionalcommits.org/) specification for our commit messages and PR titles. 
+
+**PR Title Format:**
+```
+<type>[(<scope>)]: <description>
+```
+
+**Types:** Defined in [.github/workflows/conventional-commit-check.yml](.github/workflows/conventional-commit-check.yml):
+* **feat**: New feature or functionality
+* **fix**: Bug fix
+* **docs**: Documentation only changes
+* **refactor**: Code refactoring without changing functionality
+* **perf**: Performance improvements
+* **test**: Adding or modifying tests
+* **build**: Build system or dependency changes
+* **ci**: CI/CD configuration changes
+* **chore**: Maintenance tasks
+* **revert**: Reverting a previous commit
+* **misc**: Miscellaneous changes
+
+Note: Each PR/commit should have a single primary type. If your changes span multiple categories, choose the most significant one or consider splitting into separate PRs.
+
+**Scope (optional):** The area of code affected. Valid scopes are defined in [.github/workflows/conventional-commit-check.yml](.github/workflows/conventional-commit-check.yml). Common scopes include:
+
+* `parser` - SQL parser and grammar
+* `analyzer` - Query analysis and validation
+* `planner` - Query planning, optimization, and rules (including CBO)
+* `spi` - Service Provider Interface changes
+* `scheduler` - Task scheduling and execution
+* `protocol` - Wire protocol and serialization
+* `connector` - Changes to broader connector functionality or connector SPI
+* `resource` - Resource management (memory manager, resource groups)
+* `security` - Authentication and authorization
+* `function` - Built-in functions and operators
+* `type` - Type system and type coercion
+* `expression` - Expression evaluation
+* `operator` - Query operators (join, aggregation, etc.)
+* `client` - Client libraries and protocols
+* `server` - Server configuration and management
+* `native` - Native execution engine
+* `testing` - Test framework and utilities
+* `docs` - Documentation
+* `build` - Build system and dependencies
+
+Additionally, any connector name (e.g., `hive`, `iceberg`, `delta`, `kafka`) or plugin name (e.g., `session-property-manager`, `access-control`, `event-listener`) can be used as a scope. These scopes should use the format `plugin-<name>` (e.g., `plugin-iceberg`, `plugin-password-authenticator`).
+
+**Description:**
+* Must start with a capital letter
+* Must not end with a period
+* Use imperative mood ("Add feature" not "Added feature")
+* Be concise but descriptive
+
+**Breaking Changes:**
+* Use `!` after the type/scope (e.g., `feat!: Remove deprecated API`)
+* AND include `BREAKING CHANGE:` in the commit description footer with a detailed explanation of the change
+* Use to indicate any change that is not backward compatible as defined in the [Backward Compatibility Guidelines](presto-docs/src/main/sphinx/develop/release-process.rst#backward-compatibility-guidelines)
+
+**Examples:**
+* `feat(connector): Add support for dynamic catalog registration` (new feature for connectors)
+* `fix: Resolve memory leak in query executor`
+* `docs(api): Update REST API documentation`
+* `feat!: Remove deprecated configuration options` (breaking change)
+* `feat(plugin-iceberg): Add support for Iceberg table properties` (new feature in Iceberg, **NOTE: connectors are plugins**)
+
+### Single Commit PRs
+* **All PRs must be merged as a single commit** using GitHub's "Squash and merge" feature
+* The PR title will become the commit message, so it must follow the conventional commit format
+* Multiple commits within a PR are allowed during development for easier review, but they will be squashed on merge
+* If you need to reference other commits or PRs, include them in the PR description or commit body, not as separate commits
+
+### Commit 
Message Guidelines +* **PR Title/First Line** + * Must follow conventional commit format + * Limit to 50-72 characters when possible + * If you cannot make it concise, you may be changing too much at once + +* **PR Description/Commit Body** + * Separate from title with a blank line + * Wrap at 72 characters + * Explain what and why, not how + * Include: + * Why the change is being made + * What issue it addresses + * Any side effects or limitations + * Breaking changes or migration notes if applicable + * Assume someone may need to revert your change during an emergency + +* **Content Requirements** + * All bug fixes and new features must have associated tests + * Changes should be focused on a single topic + * Code should pass all tests independently + * Include documentation updates with code changes * **Metadata** * If the commit was to solve a Github issue, refer to it at the end of a commit message in a rfc822 header line format.\ @@ -449,20 +526,13 @@ We use the [Fork and Pull model](https://docs.github.com/en/pull-requests/collab - Make sure your code follows the [code style guidelines](https://github.com/prestodb/presto/blob/master/CONTRIBUTING.md#code-style), [development guidelines](https://github.com/prestodb/presto/wiki/Presto-Development-Guidelines#development) and [formatting guidelines](https://github.com/prestodb/presto/wiki/Presto-Development-Guidelines#formatting) -- Make sure you follow the [review and commit guidelines](https://github.com/prestodb/presto/wiki/Review-and-Commit-guidelines), in particular: - - - Ensure that each commit is correct independently. Each commit should compile and pass tests. - - When possible, reduce the size of the commit for ease of review. Consider breaking a large PR into multiple commits, with each one addressing a particular issue. For example, if you are introducing a new feature that requires certain refactor, making a separate refactor commit before the real change helps the reviewer to isolate the changes. 
- - Do not send commits like addressing comments or fixing tests for previous commits in the same PR. Squash such commits to its corresponding base commit before the PR is rebased and merged. - - Make sure commit messages [follow these guidelines](https://chris.beams.io/posts/git-commit/). In particular (from the guidelines): +- Make sure you follow the [Commit Standards](#commit-standards) section above, which uses Conventional Commits format: - * Separate subject from body with a blank line - * Limit the subject line to 50 characters - * Capitalize the subject line - * Do not end the subject line with a period - * Use the imperative mood in the subject line - * Wrap the body at 72 characters - * Use the body to explain what and why vs. how + - PR titles must follow the conventional commit format (e.g., `feat: Add new feature`, `fix: Resolve bug`) + - All PRs will be squashed into a single commit on merge, so the PR title becomes the commit message + - While developing, you can have multiple commits in your PR for easier review + - Ensure each commit in your PR compiles and passes tests independently + - The PR description should explain what and why, not how. Keep lines wrapped at 72 characters for better readability. Include context about why the change is needed, what issue it addresses, any side effects or breaking changes, and enough detail that someone could understand whether to revert it during an emergency. 
- Ensure all code is peer reviewed within your own organization or peers before submitting - Implement and address existing feedback before requesting further review - Make a good faith effort to locate example or referential code before requesting someone else direct you towards it @@ -471,6 +541,9 @@ We use the [Fork and Pull model](https://docs.github.com/en/pull-requests/collab - Implement or modify relevant tests, otherwise provide clear explanation why test updates were not necessary - Tag your PR with affected code areas as best as you can, it’s okay to tag too many, better to cut down irrelevant tags than miss getting input from relevant subject matter experts - All tests shall pass before requesting a code review. If there are test failures, even it's from unrelated problems, try to address them by either sending a PR to fix it or creating a Github issue so it can be triaged and fixed soon. +- If adding new dependencies: + * OpenSSF Scorecard: Ensure they have an [OpenSSF Scorecard](https://securityscorecards.dev/#the-checks) score of 5.0 or higher. Dependencies with scores below 5.0 require explicit approval from the TSC. [The OpenSSF score can be checked here](https://scorecard.dev/viewer/). Automated checks will comment on the PR with scorecard scores for new dependencies. + * Vulnerabilities: Ensure new dependencies do not introduce known high or critical severity vulnerabilities. Automated checks will fail the build if such vulnerabilities are detected. In exceptional cases, this can be overridden by TSC vote, requiring an administrator to merge the PR. 
### What not to do for Pull Requests * Submit before getting peer review in your own organization diff --git a/Jenkinsfile b/Jenkinsfile index 0df505949a921..9a8e5c5afef8f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -133,6 +133,7 @@ pipeline { returnStdout: true).trim() env.PRESTO_PKG = "presto-server-${PRESTO_VERSION}.tar.gz" env.PRESTO_CLI_JAR = "presto-cli-${PRESTO_VERSION}-executable.jar" + env.PRESTO_FUNCTION_SERVER_JAR = "presto-function-server-${PRESTO_VERSION}-executable.jar" env.PRESTO_BUILD_VERSION = env.PRESTO_VERSION + '-' + sh(script: "git show -s --format=%cd --date=format:'%Y%m%d%H%M%S'", returnStdout: true).trim() + "-" + env.PRESTO_COMMIT_SHA.substring(0, 7) @@ -160,8 +161,9 @@ pipeline { echo "${PRESTO_BUILD_VERSION}" > index.txt git log -n 10 >> index.txt aws s3 cp index.txt ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/ --no-progress - aws s3 cp presto-server/target/${PRESTO_PKG} ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/ --no-progress - aws s3 cp presto-cli/target/${PRESTO_CLI_JAR} ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/ --no-progress + aws s3 cp presto-server/target/${PRESTO_PKG} ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/ --no-progress + aws s3 cp presto-cli/target/${PRESTO_CLI_JAR} ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/ --no-progress + aws s3 cp presto-function-server/target/${PRESTO_FUNCTION_SERVER_JAR} ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/ --no-progress ''' } } @@ -203,8 +205,9 @@ pipeline { secretKeyVariable: 'AWS_SECRET_ACCESS_KEY']]) { sh '''#!/bin/bash -ex cd docker/ - aws s3 cp ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/${PRESTO_PKG} . --no-progress - aws s3 cp ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/${PRESTO_CLI_JAR} . --no-progress + aws s3 cp ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/${PRESTO_PKG} . --no-progress + aws s3 cp ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/${PRESTO_CLI_JAR} . --no-progress + aws s3 cp ${AWS_S3_PREFIX}/${PRESTO_BUILD_VERSION}/${PRESTO_FUNCTION_SERVER_JAR} . 
--no-progress echo "Building ${DOCKER_IMAGE}" REG_ORG=${AWS_ECR} IMAGE_NAME=${IMG_NAME} TAG=${PRESTO_BUILD_VERSION} ./build.sh ${PRESTO_VERSION} diff --git a/NOTICES b/NOTICES index 9c15b9d6155f4..ab483fa8e8774 100644 --- a/NOTICES +++ b/NOTICES @@ -1,3 +1,21 @@ The code for the t-digest was originally authored by Ted Dunning Adrien Grand contributed the heart of the AVLTreeDigest (https://github.com/jpountz) + +This product includes software from the Apache DataSketches C++ project. +* https://github.com/apache/datasketches-cpp/tree/master/theta +which contains the following NOTICE file: +------- +Apache DataSketches C++ +Copyright 2025 The Apache Software Foundation + +Copyright 2015-2018 Yahoo Inc. +Copyright 2019-2020 Verizon Media +Copyright 2021- Yahoo Inc. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Prior to moving to ASF, the software for this project was developed at +Yahoo Inc. (https://developer.yahoo.com). +------- diff --git a/README.md b/README.md index 004cbe80ad6b3..4b3b1fcdedb39 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Presto +[![LFX Health Score](https://insights.linuxfoundation.org/api/badge/health-score?project=presto)](https://insights.linuxfoundation.org/project/presto) + Presto is a distributed SQL query engine for big data. See the [Presto installation documentation](https://prestodb.io/docs/current/installation.html) for deployment instructions. @@ -14,7 +16,7 @@ See [PrestoDB: Mission and Architecture](ARCHITECTURE.md). ## Requirements * Mac OS X or Linux -* Java 8 Update 151 or higher (8u151+), 64-bit. Both Oracle JDK and OpenJDK are supported. +* Java 17 64-bit. Both Oracle JDK and OpenJDK are supported. * Maven 3.6.3+ (for building) * Python 2.4+ (for running with the launcher script) @@ -29,6 +31,8 @@ Presto is a standard Maven project. 
Simply run the following command from the pr On the first build, Maven will download all the dependencies from the internet and cache them in the local repository (`~/.m2/repository`), which can take a considerable amount of time. Subsequent builds will be faster. +When building multiple Presto projects locally, each project may write updates to the user's global M2 cache, which could cause build issues. You can configure your local `.mvn/maven.config` to support a local cache specific to that project via `-Dmaven.repo.local=./.m2/repository`. + Presto has a comprehensive set of unit tests that can take several minutes to run. You can disable the tests when building: ./mvnw clean install -DskipTests @@ -38,13 +42,18 @@ After building Presto for the first time, you can load the project into your IDE After opening the project in IntelliJ, double check that the Java SDK is properly configured for the project: * Open the File menu and select Project Structure -* In the SDKs section, ensure that a 1.8 JDK is selected (create one if none exist) -* In the Project section, ensure the Project language level is set to 8.0 as Presto makes use of several Java 8 language features +* In the SDKs section, ensure that a distribution of JDK 17 is selected (create one if none exist) +* In the Project section, ensure the Project language level is set to at least 8.0. +* When using JDK 17, an [IntelliJ bug](https://youtrack.jetbrains.com/issue/IDEA-201168) requires you + to disable the `Use '--release' option for cross-compilation (Java 9 and later)` setting in + `Settings > Build, Execution, Deployment > Compiler > Java Compiler`. If this option remains enabled, + you may encounter errors such as: `package sun.misc does not exist` because IntelliJ fails to resolve + certain internal JDK classes. Presto comes with sample configuration that should work out-of-the-box for development. 
Use the following options to create a run configuration: * Main Class: `com.facebook.presto.server.PrestoServer` -* VM Options: `-ea -XX:+UseG1GC -XX:G1HeapRegionSize=32M -XX:+UseGCOverheadLimit -XX:+ExplicitGCInvokesConcurrent -Xmx2G -Dconfig=etc/config.properties -Dlog.levels-file=etc/log.properties` +* VM Options: `-ea -XX:+UseG1GC -XX:G1HeapRegionSize=32M -XX:+UseGCOverheadLimit -XX:+ExplicitGCInvokesConcurrent -Xmx2G -Dconfig=etc/config.properties -Dlog.levels-file=etc/log.properties -Djdk.attach.allowAttachSelf=true` * Working directory: `$MODULE_WORKING_DIR$` or `$MODULE_DIR$`(Depends your version of IntelliJ) * Use classpath of module: `presto-main` @@ -54,6 +63,34 @@ Additionally, the Hive plugin must be configured with location of your Hive meta -Dhive.metastore.uri=thrift://localhost:9083 +To modify the loaded plugins in IntelliJ, modify the `config.properties` located in `presto-main/etc`. You can modify `plugin.bundles` with the location of the plugin pom.xml + +### Additional configuration for Java 17 + +When running with Java 17, additional `--add-opens` flags are required to allow reflective access used by certain catalogs based on which catalogs are configured. 
+For the default set of catalogs loaded when starting the Presto server in IntelliJ without changes, add the following flags to the **VM Options**: + + --add-opens=java.base/java.io=ALL-UNNAMED + --add-opens=java.base/java.lang=ALL-UNNAMED + --add-opens=java.base/java.lang.ref=ALL-UNNAMED + --add-opens=java.base/java.lang.reflect=ALL-UNNAMED + --add-opens=java.base/java.net=ALL-UNNAMED + --add-opens=java.base/java.nio=ALL-UNNAMED + --add-opens=java.base/java.security=ALL-UNNAMED + --add-opens=java.base/javax.security.auth=ALL-UNNAMED + --add-opens=java.base/javax.security.auth.login=ALL-UNNAMED + --add-opens=java.base/java.text=ALL-UNNAMED + --add-opens=java.base/java.util=ALL-UNNAMED + --add-opens=java.base/java.util.concurrent=ALL-UNNAMED + --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED + --add-opens=java.base/java.util.regex=ALL-UNNAMED + --add-opens=java.base/jdk.internal.loader=ALL-UNNAMED + --add-opens=java.base/sun.security.action=ALL-UNNAMED + --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED + +These flags ensure that internal JDK modules are accessible at runtime for components used by Presto’s default configuration. +It is not a comprehensive list. Additional flags may need to be added, depending on the catalogs configured on the server. + ### Using SOCKS for Hive or HDFS If your Hive metastore or HDFS cluster is not directly accessible to your local machine, you can use SSH port forwarding to access it. Setup a dynamic SOCKS proxy with SSH listening on local port 1080: diff --git a/docker/Dockerfile b/docker/Dockerfile index 7bd81465b0178..bb7ffa0d86fe2 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -10,7 +10,7 @@ ENV PRESTO_HOME="/opt/presto-server" COPY $PRESTO_PKG . 
COPY $PRESTO_CLI_JAR /opt/presto-cli -RUN dnf install -y java-11-openjdk less procps python3 \ +RUN dnf install -y java-17-openjdk less procps python3 \ && ln -s $(which python3) /usr/bin/python \ # Download Presto and move \ && tar -zxf $PRESTO_PKG \ diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index d4df39e601fb9..a991cee68b7e0 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -1,5 +1,8 @@ #!/bin/sh set -e +trap 'kill -TERM $app 2>/dev/null' TERM -$PRESTO_HOME/bin/launcher run +$PRESTO_HOME/bin/launcher run & +app=$! +wait $app diff --git a/docker/etc/jvm.config.example b/docker/etc/jvm.config.example index 8436145168343..b6956996878c0 100644 --- a/docker/etc/jvm.config.example +++ b/docker/etc/jvm.config.example @@ -7,3 +7,20 @@ -XX:+HeapDumpOnOutOfMemoryError -XX:+ExitOnOutOfMemoryError -Djdk.attach.allowAttachSelf=true +--add-opens=java.base/java.io=ALL-UNNAMED +--add-opens=java.base/java.lang=ALL-UNNAMED +--add-opens=java.base/java.lang.ref=ALL-UNNAMED +--add-opens=java.base/java.lang.reflect=ALL-UNNAMED +--add-opens=java.base/java.net=ALL-UNNAMED +--add-opens=java.base/java.nio=ALL-UNNAMED +--add-opens=java.base/java.security=ALL-UNNAMED +--add-opens=java.base/javax.security.auth=ALL-UNNAMED +--add-opens=java.base/javax.security.auth.login=ALL-UNNAMED +--add-opens=java.base/java.text=ALL-UNNAMED +--add-opens=java.base/java.util=ALL-UNNAMED +--add-opens=java.base/java.util.concurrent=ALL-UNNAMED +--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED +--add-opens=java.base/java.util.regex=ALL-UNNAMED +--add-opens=java.base/jdk.internal.loader=ALL-UNNAMED +--add-opens=java.base/sun.security.action=ALL-UNNAMED +--add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED diff --git a/pom.xml b/pom.xml index 57dc0439a8401..6765d88cf1ef3 100644 --- a/pom.xml +++ b/pom.xml @@ -5,12 +5,12 @@ com.facebook.airlift airbase - 104 + 108 com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT pom presto-root @@ -30,7 +30,7 @@ 
scm:git:git://github.com/prestodb/presto.git https://github.com/prestodb/presto - 0.293 + HEAD @@ -40,60 +40,64 @@ true src/checkstyle/presto-checks.xml - 1.8.0-151 + 17 3.3.9 - 4.7.1 - 0.216 + 4.13.2 + 0.225 ${dep.airlift.version} 0.38 0.6 1.12.782 4.12.0 - 3.4.0 + 3.49.0 19.3.0.0 - 1.43 + ${dep.airlift.version} - 2.12.7 - 1.54 + 2.13.1 + 1.55 7.5 - 8.11.3 + 9.12.0 3.8.0 - 1.2.13 - 1.13.1 - 1.6.8 + 1.16.0 9.7.1 1.9.17 313 - 1.7.36 - 3.9.0 - 0.11.0 + 2.0.16 + 3.9.1 + 1.3.0 30.0.1 2.3.1 + 4.0.6 0.14.0 1.20.5 3.4.1 2.9.0 3.1.3 - 2.36.0 + 2.37.0 32.1.0-jre 2.15.4 3.0.0 - 1.11.4 - 1.26.2 + 1.12.0 + 1.27.1 4.29.0 - 4.1.119.Final - 2.0 - 9.4.56.v20240826 + 12.0.29 + 4.1.130.Final + 1.2.8 + 2.5 2.12.1 - 3.17.0 - 5.1.0 + 3.18.0 + 6.0.0 17.0.0 + 3.5.4 + 2.0.2-6 + 3.4.1-1 + 2.13.17 net.java.dev.jna jna - 5.13.0 + 5.18.1 org.postgresql postgresql - 42.6.1 + 42.7.9 @@ -1466,13 +1569,13 @@ com.microsoft.sqlserver mssql-jdbc - 12.8.1.jre8 + 13.2.1.jre11 org.fusesource.jansi jansi - 1.18 + 2.4.2 @@ -1540,10 +1643,22 @@ 0.11.5 + + io.projectreactor.netty + reactor-netty-core + ${dep.reactor-netty.version} + + + + io.projectreactor.netty + reactor-netty-http + ${dep.reactor-netty.version} + + org.apache.thrift libthrift - 0.14.1 + 0.18.1 org.apache.httpcomponents @@ -1617,9 +1732,9 @@ - com.facebook.airlift.discovery + com.facebook.airlift discovery-server - 1.33 + ${dep.airlift.version} @@ -1813,7 +1928,7 @@ com.google.auth google-auth-library-oauth2-http - 0.12.0 + 1.39.1 commons-logging @@ -1977,7 +2092,7 @@ commons-codec commons-codec - 1.17.0 + ${dep.commons.codec.version} @@ -2044,6 +2159,24 @@ 9.6.3-4 + + org.apache.pinot + pinot-spi + ${dep.pinot.version} + + + + org.apache.pinot + pinot-common + ${dep.pinot.version} + + + + org.apache.pinot + pinot-core + ${dep.pinot.version} + + org.apache.kafka kafka_2.12 @@ -2187,33 +2320,33 @@ - org.apache.pinot - presto-pinot-driver - ${dep.pinot.version} + org.apache.kafka + kafka-metadata + ${dep.kafka.version} org.xerial.snappy 
snappy-java - 1.1.10.4 + 1.1.10.7 com.github.luben zstd-jni - 1.5.2-3 + 1.5.7-6 org.roaringbitmap RoaringBitmap - 0.9.3 + 1.3.0 org.apache.zookeeper zookeeper - 3.9.3 + 3.9.4 jline @@ -2237,19 +2370,19 @@ org.checkerframework checker-qual - 3.37.0 + 3.52.0 org.jgrapht jgrapht-core - 1.3.1 + 1.5.2 redis.clients jedis - 2.6.2 + 7.0.0 @@ -2275,6 +2408,10 @@ com.google.inject.extensions guice-multibindings + + com.google.inject + guice + org.testng testng @@ -2295,6 +2432,10 @@ org.testng testng + + com.google.inject + guice + @@ -2315,6 +2456,10 @@ org.testng testng + + com.google.inject + guice + @@ -2331,6 +2476,10 @@ org.testng testng + + com.google.inject + guice + @@ -2354,7 +2503,7 @@ org.apache.lucene - lucene-analyzers-common + lucene-analysis-common ${dep.lucene.version} @@ -2367,13 +2516,13 @@ org.locationtech.jts jts-core - 1.19.0 + ${dep.jts.version} org.locationtech.jts.io jts-io-common - 1.19.0 + ${dep.jts.version} junit @@ -2385,7 +2534,7 @@ org.anarres.lzo lzo-hadoop - 1.0.5 + 1.0.6 org.apache.hadoop @@ -2420,24 +2569,11 @@ 2.1.0-3 - - com.facebook.presto - presto-clp - ${project.version} - - - - com.facebook.presto - presto-clp - ${project.version} - test-jar - - org.javassist javassist - 3.22.0-GA + 3.30.2-GA @@ -2449,38 +2585,31 @@ net.jodah failsafe - 2.0.1 - - - - com.facebook.presto.spark - spark-core - 2.0.2-6 - provided + 2.4.4 com.clearspring.analytics stream - 2.9.5 + 2.9.8 io.opentelemetry opentelemetry-api - 1.19.0 + ${dep.io.opentelemetry.version} io.opentelemetry opentelemetry-context - 1.19.0 + ${dep.io.opentelemetry.version} io.opentelemetry opentelemetry-exporter-otlp - 1.19.0 + ${dep.io.opentelemetry.version} com.squareup.okhttp3 @@ -2492,43 +2621,43 @@ io.opentelemetry opentelemetry-extension-trace-propagators - 1.19.0 + ${dep.io.opentelemetry.version} io.opentelemetry opentelemetry-sdk - 1.19.0 + ${dep.io.opentelemetry.version} io.opentelemetry opentelemetry-sdk-common - 1.19.0 + ${dep.io.opentelemetry.version} io.opentelemetry 
opentelemetry-sdk-trace - 1.19.0 + ${dep.io.opentelemetry.version} - io.opentelemetry + io.opentelemetry.semconv opentelemetry-semconv - 1.19.0-alpha + 1.37.0 org.apache.datasketches datasketches-memory - 2.2.0 + ${dep.datasketches-memory.version} org.apache.datasketches datasketches-java - 5.0.1 + ${dep.datasketches-java.version} @@ -2615,6 +2744,35 @@ ${dep.arrow.version} + + org.mariadb.jdbc + mariadb-java-client + ${dep.mariadb.version} + + + + com.nimbusds + nimbus-jose-jwt + 10.0.2 + + + + com.nimbusds + oauth2-oidc-sdk + 11.30.1 + + + org.aw2 + asm + + + + + + org.scala-lang + scala-library + ${scala.version} + @@ -2662,7 +2820,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-maven-plugin ${dep.drift.version} @@ -2688,7 +2846,7 @@ org.codehaus.mojo extra-enforcer-rules - 1.6.2 + 1.11.0 @@ -2722,6 +2880,7 @@ git.properties + mime.types @@ -2734,18 +2893,6 @@ - - org.sonatype.plugins - nexus-staging-maven-plugin - ${dep.nexus-staging-plugin.version} - - - ossrh - https://oss.sonatype.org/ - - - com.facebook.presto presto-maven-plugin @@ -2787,6 +2934,13 @@ org.apache.maven.plugins maven-surefire-plugin 3.0.0-M7 + + + org.apache.maven.surefire + surefire-testng + 3.5.1 + + **/Test*.java @@ -2951,22 +3105,14 @@ - org.sonatype.plugins - nexus-staging-maven-plugin - - - - default-deploy - deploy - - deploy - - - + org.sonatype.central + central-publishing-maven-plugin + 0.8.0 + true - ossrh - https://oss.sonatype.org/ + ossrh + ${release.autoPublish} + validated @@ -3083,5 +3229,149 @@ + + spark2 + + + true + + !spark-version + + + + + 2.0.2-6 + + + + presto-spark-classloader-spark2 + + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + 3.3.0 + + + org.codehaus.mojo + extra-enforcer-rules + 1.11.0 + + + + true + + + + + org.codehaus.plexus:plexus-utils + com.google.guava:guava + com.fasterxml.jackson.core:jackson-annotations + com.fasterxml.jackson.core:jackson-core + com.fasterxml.jackson.core:jackson-databind + + + + + + + + 
org.basepom.maven + duplicate-finder-maven-plugin + + true + + com.github.benmanes.caffeine.* + + META-INF.versions.9.module-info + + META-INF.versions.11.module-info + + META-INF.versions.9.org.apache.lucene.* + + + + + + + + + + + + spark3 + + + + spark-version + 3 + + + + + 3.4.1-1 + + + + presto-spark-classloader-spark3 + + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + 3.3.0 + + + org.codehaus.mojo + extra-enforcer-rules + 1.11.0 + + + + true + + + + + org.codehaus.plexus:plexus-utils + com.google.guava:guava + com.fasterxml.jackson.core:jackson-annotations + com.fasterxml.jackson.core:jackson-core + com.fasterxml.jackson.core:jackson-databind + + + + + + + + org.basepom.maven + duplicate-finder-maven-plugin + + true + + com.github.benmanes.caffeine.* + + META-INF.versions.9.module-info + + META-INF.versions.11.module-info + + META-INF.versions.9.org.apache.lucene.* + + + + + + + + + + diff --git a/presto-accumulo/pom.xml b/presto-accumulo/pom.xml index 39d0e1b5c7d23..db5bb8eeb5baf 100644 --- a/presto-accumulo/pom.xml +++ b/presto-accumulo/pom.xml @@ -5,10 +5,11 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-accumulo + presto-accumulo Presto - Accumulo Connector presto-plugin @@ -16,7 +17,8 @@ ${project.parent.basedir} 1.10.1 2.12.0 - 2.24.3 + 2.25.3 + true @@ -172,9 +174,11 @@ ${dep.curator.version} + com.facebook.presto.hadoop hadoop-apache2 + 2.7.4-12 @@ -209,13 +213,13 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api @@ -228,31 +232,19 @@ guice - - com.google.code.findbugs - jsr305 - true - - - - commons-lang - commons-lang - 2.6 - - org.apache.commons commons-lang3 - com.github.docker-java - docker-java-api + org.apache.zookeeper + zookeeper - org.apache.zookeeper - zookeeper + jakarta.inject + jakarta.inject-api @@ -279,7 +271,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api 
provided @@ -297,7 +289,7 @@ - io.airlift + com.facebook.airlift units provided @@ -347,10 +339,15 @@ test + + com.github.docker-java + docker-java-api + test + + org.jetbrains annotations - 19.0.0 test @@ -379,6 +376,20 @@ + + + + org.apache.maven.plugins + maven-dependency-plugin + + + commons-io:commons-io + + + + + + skip-accumulo-tests diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloClient.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloClient.java index e923d3d740bb0..3600665a2006b 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloClient.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloClient.java @@ -39,6 +39,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import jakarta.inject.Inject; import org.apache.accumulo.core.client.AccumuloException; import org.apache.accumulo.core.client.AccumuloSecurityException; import org.apache.accumulo.core.client.Connector; @@ -52,8 +53,6 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.io.Text; -import javax.inject.Inject; - import java.security.InvalidParameterException; import java.util.ArrayList; import java.util.Collection; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloConnector.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloConnector.java index f0a96454c5336..05a34eb9c50b5 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloConnector.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloConnector.java @@ -28,8 +28,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import 
java.util.List; import java.util.concurrent.ConcurrentHashMap; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloMetadata.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloMetadata.java index 1d968bae6c1b5..c1ea2bc36dcb5 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloMetadata.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloMetadata.java @@ -42,8 +42,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; @@ -57,6 +56,7 @@ import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -259,7 +259,7 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable } @Override - public List getTableLayouts( + public ConnectorTableLayoutResult getTableLayoutForConstraint( ConnectorSession session, ConnectorTableHandle table, Constraint constraint, @@ -267,7 +267,7 @@ public List getTableLayouts( { AccumuloTableHandle tableHandle = (AccumuloTableHandle) table; ConnectorTableLayout layout = new ConnectorTableLayout(new AccumuloTableLayoutHandle(tableHandle, constraint.getSummary())); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -282,7 +282,7 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect AccumuloTableHandle handle = (AccumuloTableHandle) table; 
checkArgument(handle.getConnectorId().equals(connectorId), "table is not for this connector"); SchemaTableName tableName = new SchemaTableName(handle.getSchema(), handle.getTable()); - ConnectorTableMetadata metadata = getTableMetadata(tableName); + ConnectorTableMetadata metadata = getTableMetadata(session, tableName); if (metadata == null) { throw new TableNotFoundException(tableName); } @@ -353,7 +353,7 @@ public Map> listTableColumns(ConnectorSess requireNonNull(prefix, "prefix is null"); ImmutableMap.Builder> columns = ImmutableMap.builder(); for (SchemaTableName tableName : listTables(session, prefix)) { - ConnectorTableMetadata tableMetadata = getTableMetadata(tableName); + ConnectorTableMetadata tableMetadata = getTableMetadata(session, tableName); // table can disappear during listing operation if (tableMetadata != null) { columns.put(tableName, tableMetadata.getColumns()); @@ -385,7 +385,7 @@ public void rollback() } } - private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) + private ConnectorTableMetadata getTableMetadata(ConnectorSession session, SchemaTableName tableName) { if (!client.getSchemaNames().contains(tableName.getSchemaName())) { return null; @@ -398,7 +398,13 @@ private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) return null; } - return new ConnectorTableMetadata(tableName, table.getColumnsMetadata()); + List columns = table.getColumnsMetadata().stream() + .map(column -> column.toBuilder() + .setName(normalizeIdentifier(session, column.getName())) + .build()) + .collect(toImmutableList()); + + return new ConnectorTableMetadata(tableName, columns); } return null; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloMetadataFactory.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloMetadataFactory.java index 89151dda058e0..031c93e1b167f 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloMetadataFactory.java +++ 
b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloMetadataFactory.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.accumulo; -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloModule.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloModule.java index 57168596ddeda..6b78ed0095605 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloModule.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloModule.java @@ -32,6 +32,7 @@ import com.google.inject.Binder; import com.google.inject.Module; import com.google.inject.Scopes; +import jakarta.inject.Inject; import org.apache.accumulo.core.client.AccumuloException; import org.apache.accumulo.core.client.AccumuloSecurityException; import org.apache.accumulo.core.client.Connector; @@ -42,7 +43,6 @@ import org.apache.log4j.Level; import org.apache.log4j.PatternLayout; -import javax.inject.Inject; import javax.inject.Provider; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloSplitManager.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloSplitManager.java index e9e2966e6bc12..f1b24c6c745ca 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloSplitManager.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloSplitManager.java @@ -32,8 +32,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git 
a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloTableManager.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloTableManager.java index 0a7099bf3383d..34292611853be 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloTableManager.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/AccumuloTableManager.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.spi.PrestoException; +import jakarta.inject.Inject; import org.apache.accumulo.core.client.AccumuloException; import org.apache.accumulo.core.client.AccumuloSecurityException; import org.apache.accumulo.core.client.Connector; @@ -25,8 +26,6 @@ import org.apache.accumulo.core.iterators.IteratorUtil.IteratorScope; import org.apache.hadoop.io.Text; -import javax.inject.Inject; - import java.util.EnumSet; import java.util.Map; import java.util.Set; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloConfig.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloConfig.java index f99f96e79ac03..64027249f3ae1 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloConfig.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloConfig.java @@ -16,10 +16,9 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.ConfigSecuritySensitive; -import io.airlift.units.Duration; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import com.facebook.airlift.units.Duration; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.util.concurrent.TimeUnit; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloSessionProperties.java 
b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloSessionProperties.java index 304ba6de591c9..a8dcab13cb72c 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloSessionProperties.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/conf/AccumuloSessionProperties.java @@ -13,12 +13,11 @@ */ package com.facebook.presto.accumulo.conf; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/ColumnCardinalityCache.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/ColumnCardinalityCache.java index 2e14a31ece8fe..7a46d56fcc77f 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/ColumnCardinalityCache.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/ColumnCardinalityCache.java @@ -15,6 +15,7 @@ import com.facebook.airlift.concurrent.BoundedExecutor; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.accumulo.conf.AccumuloConfig; import com.facebook.presto.accumulo.model.AccumuloColumnConstraint; import com.facebook.presto.spi.PrestoException; @@ -27,7 +28,8 @@ import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.MultimapBuilder; -import io.airlift.units.Duration; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.apache.accumulo.core.client.BatchScanner; import org.apache.accumulo.core.client.Connector; import org.apache.accumulo.core.client.TableNotFoundException; @@ -39,9 +41,6 @@ import org.apache.commons.lang3.tuple.Pair; import 
org.apache.hadoop.io.Text; -import javax.annotation.PreDestroy; -import javax.inject.Inject; - import java.util.Collection; import java.util.HashMap; import java.util.Map; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java index 1a766643547fb..5fad5fffd64ca 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/IndexLookup.java @@ -15,6 +15,7 @@ import com.facebook.airlift.concurrent.BoundedExecutor; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.accumulo.model.AccumuloColumnConstraint; import com.facebook.presto.accumulo.model.TabletSplitMetadata; import com.facebook.presto.accumulo.serializers.AccumuloRowSerializer; @@ -24,7 +25,8 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; -import io.airlift.units.Duration; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.apache.accumulo.core.client.BatchScanner; import org.apache.accumulo.core.client.Connector; import org.apache.accumulo.core.client.Scanner; @@ -35,9 +37,6 @@ import org.apache.accumulo.core.security.Authorizations; import org.apache.hadoop.io.Text; -import javax.annotation.PreDestroy; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/Indexer.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/Indexer.java index 02930eec39f08..3d3e0215fd35e 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/Indexer.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/index/Indexer.java @@ 
-13,6 +13,7 @@ */ package com.facebook.presto.accumulo.index; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.accumulo.Types; import com.facebook.presto.accumulo.iterators.MaxByteArrayCombiner; import com.facebook.presto.accumulo.iterators.MinByteArrayCombiner; @@ -45,12 +46,10 @@ import org.apache.accumulo.core.iterators.user.SummingCombiner; import org.apache.accumulo.core.security.Authorizations; import org.apache.accumulo.core.security.ColumnVisibility; -import org.apache.commons.lang.ArrayUtils; +import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.io.Text; -import javax.annotation.concurrent.NotThreadSafe; - import java.io.Closeable; import java.nio.ByteBuffer; import java.util.Collection; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloPageSinkProvider.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloPageSinkProvider.java index 055c48477d56a..6eeb81127eee6 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloPageSinkProvider.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloPageSinkProvider.java @@ -23,10 +23,9 @@ import com.facebook.presto.spi.PageSinkContext; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import jakarta.inject.Inject; import org.apache.accumulo.core.client.Connector; -import javax.inject.Inject; - import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloRecordCursor.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloRecordCursor.java index 39fbf440dc975..4ca2e9a250ea1 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloRecordCursor.java +++ 
b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloRecordCursor.java @@ -30,7 +30,7 @@ import org.apache.accumulo.core.data.Value; import org.apache.accumulo.core.iterators.FirstEntryInRowIterator; import org.apache.accumulo.core.iterators.user.WholeRowIterator; -import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.io.Text; import java.io.IOException; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloRecordSetProvider.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloRecordSetProvider.java index e1651443a479f..724fcfb7fa9c5 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloRecordSetProvider.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/io/AccumuloRecordSetProvider.java @@ -24,10 +24,9 @@ import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import org.apache.accumulo.core.client.Connector; -import javax.inject.Inject; - import java.util.List; import static com.google.common.base.Preconditions.checkArgument; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/metadata/ZooKeeperMetadataManager.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/metadata/ZooKeeperMetadataManager.java index a581839c16c67..b386a8745450e 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/metadata/ZooKeeperMetadataManager.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/metadata/ZooKeeperMetadataManager.java @@ -23,13 +23,12 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.apache.curator.framework.CuratorFramework; 
import org.apache.curator.framework.CuratorFrameworkFactory; import org.apache.curator.retry.RetryForever; import org.apache.zookeeper.KeeperException; -import javax.inject.Inject; - import java.io.IOException; import java.util.HashSet; import java.util.Locale; diff --git a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/model/Row.java b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/model/Row.java index 893ba72b64526..bcec391eb259b 100644 --- a/presto-accumulo/src/main/java/com/facebook/presto/accumulo/model/Row.java +++ b/presto-accumulo/src/main/java/com/facebook/presto/accumulo/model/Row.java @@ -21,7 +21,7 @@ import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.StringUtils; import java.sql.Date; import java.sql.Time; diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/AccumuloQueryRunner.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/AccumuloQueryRunner.java index f2d068f7944e0..998540e132658 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/AccumuloQueryRunner.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/AccumuloQueryRunner.java @@ -29,10 +29,10 @@ import java.util.Map; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; -import static io.airlift.units.Duration.nanosSince; import static java.lang.String.format; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java index 
d98681a3e8a3d..957b10f793dad 100644 --- a/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java +++ b/presto-accumulo/src/test/java/com/facebook/presto/accumulo/TestAccumuloDistributedQueries.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.intellij.lang.annotations.Language; +import org.testng.annotations.Optional; import org.testng.annotations.Test; import static com.facebook.presto.accumulo.AccumuloQueryRunner.createAccumuloQueryRunner; @@ -120,6 +121,12 @@ public void testUpdate() // Updates are not supported by the connector } + @Override + public void testNonAutoCommitTransactionWithRollback() + { + // This connector do not support rollback for insert actions + } + @Override public void testInsert() { @@ -324,7 +331,7 @@ public void testScalarSubquery() } @Override - public void testShowColumns() + public void testShowColumns(@Optional("PARQUET") String storageFormat) { // Override base class because table descriptions for Accumulo connector include comments MaterializedResult actual = computeActual("SHOW COLUMNS FROM orders"); diff --git a/presto-analyzer/pom.xml b/presto-analyzer/pom.xml index 1155e1aee43f0..8271b7311cd77 100644 --- a/presto-analyzer/pom.xml +++ b/presto-analyzer/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-analyzer @@ -13,6 +13,7 @@ ${project.parent.basedir} + true @@ -37,14 +38,20 @@ - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true - javax.inject - javax.inject + jakarta.annotation + jakarta.annotation-api + true + + + + jakarta.inject + jakarta.inject-api diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java index d9a75e912cebc..c3c1f43f8e142 100644 --- 
a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Analysis.java @@ -14,24 +14,35 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.SourceColumn; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.analyzer.AccessControlInfo; import com.facebook.presto.spi.analyzer.AccessControlInfoForTable; import com.facebook.presto.spi.analyzer.AccessControlReferences; import com.facebook.presto.spi.analyzer.AccessControlRole; +import com.facebook.presto.spi.analyzer.UpdateInfo; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.AccessControlContext; import com.facebook.presto.spi.security.AllowAllAccessControl; import com.facebook.presto.spi.security.Identity; import com.facebook.presto.sql.tree.ExistsPredicate; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FunctionCall; import 
com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.Identifier; @@ -43,6 +54,7 @@ import com.facebook.presto.sql.tree.Offset; import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.Parameter; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.QuantifiedComparisonExpression; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; @@ -51,6 +63,7 @@ import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.SubqueryExpression; import com.facebook.presto.sql.tree.Table; +import com.facebook.presto.sql.tree.TableFunctionInvocation; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.HashMultiset; import com.google.common.collect.ImmutableList; @@ -59,9 +72,8 @@ import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Multiset; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.util.ArrayDeque; import java.util.ArrayList; @@ -100,7 +112,7 @@ public class Analysis @Nullable private final Statement root; private final Map, Expression> parameters; - private String updateType; + private UpdateInfo updateInfo; private final Map, NamedQuery> namedQueries = new LinkedHashMap<>(); @@ -167,6 +179,13 @@ public class Analysis private final Multiset columnMaskScopes = HashMultiset.create(); private final Map, Map> columnMasks = new LinkedHashMap<>(); + // for call distributed procedure + private Optional distributedProcedureType = Optional.empty(); + private Optional procedureName = Optional.empty(); + private Optional procedureArguments = Optional.empty(); + private Optional callTarget = Optional.empty(); + private Optional targetQuery = Optional.empty(); + // for create table private Optional 
createTableDestination = Optional.empty(); private Map createTableProperties = ImmutableMap.of(); @@ -180,6 +199,7 @@ public class Analysis private Optional analyzeTarget = Optional.empty(); private Optional> updatedColumns = Optional.empty(); + private Optional mergeAnalysis = Optional.empty(); // for describe input and describe output private final boolean isDescribe; @@ -195,16 +215,41 @@ public class Analysis private final Map materializedViews = new LinkedHashMap<>(); + private final Map, MaterializedViewInfo> materializedViewInfoMap = new LinkedHashMap<>(); + private Optional expandedQuery = Optional.empty(); // Keeps track of the subquery we are visiting, so we have access to base query information when processing materialized view status private Optional currentQuerySpecification = Optional.empty(); - public Analysis(@Nullable Statement root, Map, Expression> parameters, boolean isDescribe) + // Track WHERE clause from the query accessing a view for subquery analysis such as materialized view + private Optional viewAccessorWhereClause = Optional.empty(); + + // Maps each output Field to its originating SourceColumn(s) for column-level lineage tracking. + private final Multimap originColumnDetails = ArrayListMultimap.create(); + + // Maps each analyzed Expression to the Field(s) it produces, supporting expression-level lineage. + private final Multimap, Field> fieldLineage = ArrayListMultimap.create(); + + private Optional> updatedSourceColumns = Optional.empty(); + + // names of tables and aliased relations. All names are resolved case-insensitive. + private final Map, QualifiedName> relationNames = new LinkedHashMap<>(); + private final Map, TableFunctionInvocationAnalysis> tableFunctionAnalyses = new LinkedHashMap<>(); + private final Set> aliasedRelations = new LinkedHashSet<>(); + private final Set> polymorphicTableFunctions = new LinkedHashSet<>(); + + // Row id field used for MERGE INTO command. 
+ private final Map, FieldReference> rowIdField = new LinkedHashMap<>(); + + private final ViewDefinitionReferences viewDefinitionReferences; + + public Analysis(@Nullable Statement root, Map, Expression> parameters, boolean isDescribe, ViewDefinitionReferences viewDefinitionReferences) { this.root = root; this.parameters = ImmutableMap.copyOf(requireNonNull(parameters, "parameterMap is null")); this.isDescribe = isDescribe; + this.viewDefinitionReferences = requireNonNull(viewDefinitionReferences, "viewDefinitionReferences is null"); } public Statement getStatement() @@ -212,14 +257,14 @@ public Statement getStatement() return root; } - public String getUpdateType() + public UpdateInfo getUpdateInfo() { - return updateType; + return updateInfo; } - public void setUpdateType(String updateType) + public void setUpdateInfo(UpdateInfo updateInfo) { - this.updateType = updateType; + this.updateInfo = updateInfo; } public boolean isCreateTableAsSelectWithData() @@ -412,6 +457,16 @@ public Expression getJoinCriteria(Join join) return joins.get(NodeRef.of(join)); } + public void setRowIdField(Table table, FieldReference field) + { + rowIdField.put(NodeRef.of(table), field); + } + + public FieldReference getRowIdField(Table table) + { + return rowIdField.get(NodeRef.of(table)); + } + public void recordSubqueries(Node node, ExpressionAnalysis expressionAnalysis) { NodeRef key = NodeRef.of(node); @@ -645,6 +700,46 @@ public Optional getCreateTableDestination() return createTableDestination; } + public Optional getProcedureName() + { + return procedureName; + } + + public void setProcedureName(Optional procedureName) + { + this.procedureName = procedureName; + } + + public Optional getDistributedProcedureType() + { + return distributedProcedureType; + } + + public void setDistributedProcedureType(Optional distributedProcedureType) + { + this.distributedProcedureType = distributedProcedureType; + } + + public Optional getProcedureArguments() + { + return procedureArguments; + 
} + + public void setProcedureArguments(Optional procedureArguments) + { + this.procedureArguments = procedureArguments; + } + + public Optional getCallTarget() + { + return callTarget; + } + + public void setCallTarget(TableHandle callTarget) + { + this.callTarget = Optional.of(callTarget); + } + public Optional getAnalyzeTarget() { return analyzeTarget; @@ -705,6 +800,16 @@ public Optional> getUpdatedColumns() return updatedColumns; } + public Optional getMergeAnalysis() + { + return mergeAnalysis; + } + + public void setMergeAnalysis(MergeAnalysis mergeAnalysis) + { + this.mergeAnalysis = Optional.of(mergeAnalysis); + } + public void setRefreshMaterializedViewAnalysis(RefreshMaterializedViewAnalysis refreshMaterializedViewAnalysis) { this.refreshMaterializedViewAnalysis = Optional.of(refreshMaterializedViewAnalysis); @@ -795,6 +900,19 @@ public boolean hasTableRegisteredForMaterializedView(Table view, Table table) return tablesForMaterializedView.containsEntry(NodeRef.of(view), table); } + public void setMaterializedViewInfo(Table table, MaterializedViewInfo materializedViewInfo) + { + requireNonNull(table, "table is null"); + requireNonNull(materializedViewInfo, "materializedViewInfo is null"); + materializedViewInfoMap.put(NodeRef.of(table), materializedViewInfo); + } + + public Optional getMaterializedViewInfo(Table table) + { + requireNonNull(table, "table is null"); + return Optional.ofNullable(materializedViewInfoMap.get(NodeRef.of(table))); + } + public void setSampleRatio(SampledRelation relation, double ratio) { sampleRatios.put(NodeRef.of(relation), ratio); @@ -843,9 +961,9 @@ public AccessControlReferences getAccessControlReferences() return accessControlReferences; } - public void addQueryAccessControlInfo(AccessControlInfo accessControlInfo) + public ViewDefinitionReferences getViewDefinitionReferences() { - accessControlReferences.setQueryAccessControlInfo(accessControlInfo); + return viewDefinitionReferences; } public void 
addAccessControlCheckForTable(AccessControlRole accessControlRole, AccessControlInfoForTable accessControlInfoForTable) @@ -893,12 +1011,12 @@ public Map>> getUtilized return ImmutableMap.copyOf(utilizedTableColumnReferences); } - public void populateTableColumnAndSubfieldReferencesForAccessControl(boolean checkAccessControlOnUtilizedColumnsOnly, boolean checkAccessControlWithSubfields) + public void populateTableColumnAndSubfieldReferencesForAccessControl(boolean checkAccessControlOnUtilizedColumnsOnly, boolean checkAccessControlWithSubfields, boolean isLegacyMaterializedViews) { - accessControlReferences.addTableColumnAndSubfieldReferencesForAccessControl(getTableColumnAndSubfieldReferencesForAccessControl(checkAccessControlOnUtilizedColumnsOnly, checkAccessControlWithSubfields)); + accessControlReferences.addTableColumnAndSubfieldReferencesForAccessControl(getTableColumnAndSubfieldReferencesForAccessControl(checkAccessControlOnUtilizedColumnsOnly, checkAccessControlWithSubfields, isLegacyMaterializedViews)); } - private Map>> getTableColumnAndSubfieldReferencesForAccessControl(boolean checkAccessControlOnUtilizedColumnsOnly, boolean checkAccessControlWithSubfields) + private Map>> getTableColumnAndSubfieldReferencesForAccessControl(boolean checkAccessControlOnUtilizedColumnsOnly, boolean checkAccessControlWithSubfields, boolean isLegacyMaterializedViews) { Map>> references; if (!checkAccessControlWithSubfields) { @@ -930,19 +1048,26 @@ else if (!checkAccessControlOnUtilizedColumnsOnly) { }) .collect(toImmutableSet()))))); } - return buildMaterializedViewAccessControl(references); + return buildMaterializedViewAccessControl(references, isLegacyMaterializedViews); } /** - * For a query on materialized view, only check the actual required access controls for its base tables. For the materialized view, - * will not check access control by replacing with AllowAllAccessControl. 
+ * For a query on materialized view: + * - When legacy_materialized_views=true: Only check access controls for base tables, bypass access control + * for the materialized view itself by replacing with AllowAllAccessControl. + * - When legacy_materialized_views=false: Check access control for both the materialized view itself + * and all base tables referenced in the view query. **/ - private Map>> buildMaterializedViewAccessControl(Map>> tableColumnReferences) + private Map>> buildMaterializedViewAccessControl(Map>> tableColumnReferences, boolean isLegacyMaterializedViews) { if (!(getStatement() instanceof Query) || materializedViews.isEmpty()) { return tableColumnReferences; } + if (!isLegacyMaterializedViews) { + return tableColumnReferences; + } + Map>> newTableColumnReferences = new LinkedHashMap<>(); tableColumnReferences.forEach((accessControlInfo, references) -> { @@ -994,11 +1119,37 @@ public void setCurrentSubquery(QuerySpecification currentSubQuery) { this.currentQuerySpecification = Optional.of(currentSubQuery); } + public Optional getCurrentQuerySpecification() { return currentQuerySpecification; } + public void setViewAccessorWhereClause(Expression whereClause) + { + this.viewAccessorWhereClause = Optional.of(whereClause); + } + + public void clearViewAccessorWhereClause() + { + this.viewAccessorWhereClause = Optional.empty(); + } + + public Optional getViewAccessorWhereClause() + { + return viewAccessorWhereClause; + } + + public void setTargetQuery(QuerySpecification targetQuery) + { + this.targetQuery = Optional.of(targetQuery); + } + + public Optional getTargetQuery() + { + return this.targetQuery; + } + public Map> getInvokedFunctions() { Map> functionMap = new HashMap<>(); @@ -1040,6 +1191,38 @@ public boolean hasColumnMask(QualifiedObjectName table, String column, String id return columnMaskScopes.contains(new ColumnMaskScopeEntry(table, column, identity)); } + public void addSourceColumns(Field field, Set sourceColumn) + { + 
originColumnDetails.putAll(field, sourceColumn); + } + + public Set getSourceColumns(Field field) + { + return ImmutableSet.copyOf(originColumnDetails.get(field)); + } + + public void addExpressionFields(Expression expression, Collection fields) + { + fieldLineage.putAll(NodeRef.of(expression), fields); + } + + public Set getExpressionSourceColumns(Expression expression) + { + return fieldLineage.get(NodeRef.of(expression)).stream() + .flatMap(field -> getSourceColumns(field).stream()) + .collect(toImmutableSet()); + } + + public void setUpdatedSourceColumns(Optional> targetColumns) + { + this.updatedSourceColumns = targetColumns; + } + + public Optional> getUpdatedSourceColumns() + { + return updatedSourceColumns; + } + public void registerTableForColumnMasking(QualifiedObjectName table, String column, String identity) { columnMaskScopes.add(new ColumnMaskScopeEntry(table, column, identity)); @@ -1062,6 +1245,46 @@ public Map getColumnMasks(Table table) return columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of()); } + public void setTableFunctionAnalysis(TableFunctionInvocation node, TableFunctionInvocationAnalysis analysis) + { + tableFunctionAnalyses.put(NodeRef.of(node), analysis); + } + + public TableFunctionInvocationAnalysis getTableFunctionAnalysis(TableFunctionInvocation node) + { + return tableFunctionAnalyses.get(NodeRef.of(node)); + } + + public void setRelationName(Relation relation, QualifiedName name) + { + relationNames.put(NodeRef.of(relation), name); + } + + public QualifiedName getRelationName(Relation relation) + { + return relationNames.get(NodeRef.of(relation)); + } + + public void addAliased(Relation relation) + { + aliasedRelations.add(NodeRef.of(relation)); + } + + public boolean isAliased(Relation relation) + { + return aliasedRelations.contains(NodeRef.of(relation)); + } + + public void addPolymorphicTableFunction(TableFunctionInvocation invocation) + { + polymorphicTableFunctions.add(NodeRef.of(invocation)); + } + + public 
boolean isPolymorphicTableFunction(TableFunctionInvocation invocation) + { + return polymorphicTableFunctions.contains(NodeRef.of(invocation)); + } + @Immutable public static final class Insert { @@ -1117,6 +1340,47 @@ public Query getQuery() } } + @Immutable + public static final class MaterializedViewInfo + { + private final QualifiedObjectName materializedViewName; + private final Table dataTable; + private final Query viewQuery; + private final MaterializedViewDefinition materializedViewDefinition; + + public MaterializedViewInfo( + QualifiedObjectName materializedViewName, + Table dataTable, + Query viewQuery, + MaterializedViewDefinition materializedViewDefinition) + { + this.materializedViewName = requireNonNull(materializedViewName, "materializedViewName is null"); + this.dataTable = requireNonNull(dataTable, "dataTable is null"); + this.viewQuery = requireNonNull(viewQuery, "viewQuery is null"); + this.materializedViewDefinition = requireNonNull(materializedViewDefinition, "materializedViewDefinition is null"); + } + + public QualifiedObjectName getMaterializedViewName() + { + return materializedViewName; + } + + public Table getDataTable() + { + return dataTable; + } + + public Query getViewQuery() + { + return viewQuery; + } + + public MaterializedViewDefinition getMaterializedViewDefinition() + { + return materializedViewDefinition; + } + } + public static final class JoinUsingAnalysis { private final List leftJoinFields; @@ -1312,4 +1576,365 @@ public int hashCode() return Objects.hash(table, column, identity); } } + + public static class TableArgumentAnalysis + { + private final String argumentName; + private final Optional name; + private final Relation relation; + private final Optional> partitionBy; // it is allowed to partition by empty list + private final Optional orderBy; + private final boolean pruneWhenEmpty; + private final boolean rowSemantics; + private final boolean passThroughColumns; + + private TableArgumentAnalysis( + String 
argumentName, + Optional name, + Relation relation, + Optional> partitionBy, + Optional orderBy, + boolean pruneWhenEmpty, + boolean rowSemantics, + boolean passThroughColumns) + { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); + this.name = requireNonNull(name, "name is null"); + this.relation = requireNonNull(relation, "relation is null"); + this.partitionBy = requireNonNull(partitionBy, "partitionBy is null").map(ImmutableList::copyOf); + this.orderBy = requireNonNull(orderBy, "orderBy is null"); + this.pruneWhenEmpty = pruneWhenEmpty; + this.rowSemantics = rowSemantics; + this.passThroughColumns = passThroughColumns; + } + + public String getArgumentName() + { + return argumentName; + } + + public Optional getName() + { + return name; + } + + public Relation getRelation() + { + return relation; + } + + public Optional> getPartitionBy() + { + return partitionBy; + } + + public Optional getOrderBy() + { + return orderBy; + } + + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean isRowSemantics() + { + return rowSemantics; + } + + public boolean isPassThroughColumns() + { + return passThroughColumns; + } + + public static Builder builder() + { + return new Builder(); + } + + public static final class Builder + { + private String argumentName; + private Optional name = Optional.empty(); + private Relation relation; + private Optional> partitionBy = Optional.empty(); + private Optional orderBy = Optional.empty(); + private boolean pruneWhenEmpty; + private boolean rowSemantics; + private boolean passThroughColumns; + + private Builder() {} + + public Builder withArgumentName(String argumentName) + { + this.argumentName = argumentName; + return this; + } + + public Builder withName(QualifiedName name) + { + this.name = Optional.of(name); + return this; + } + + public Builder withRelation(Relation relation) + { + this.relation = relation; + return this; + } + + public Builder withPartitionBy(List 
partitionBy) + { + this.partitionBy = Optional.of(partitionBy); + return this; + } + + public Builder withOrderBy(OrderBy orderBy) + { + this.orderBy = Optional.of(orderBy); + return this; + } + + public Builder withPruneWhenEmpty(boolean pruneWhenEmpty) + { + this.pruneWhenEmpty = pruneWhenEmpty; + return this; + } + + public Builder withRowSemantics(boolean rowSemantics) + { + this.rowSemantics = rowSemantics; + return this; + } + + public Builder withPassThroughColumns(boolean passThroughColumns) + { + this.passThroughColumns = passThroughColumns; + return this; + } + + public TableArgumentAnalysis build() + { + return new TableArgumentAnalysis(argumentName, name, relation, partitionBy, orderBy, pruneWhenEmpty, rowSemantics, passThroughColumns); + } + } + } + + /** + * Encapsulates the result of analyzing a table function invocation. + * Includes the connector ID, function name, argument bindings, and the + * connector-specific table function handle needed for planning and execution. 
+ * + * Example of a TableFunctionInvocationAnalysis for a table function + * with two table arguments, required columns, and co-partitioning + * implemented by {@link TestingTableFunctions.TwoTableArgumentsFunction} + * + * SQL: + * SELECT * FROM TABLE(system.two_table_arguments_function( + * input1 => TABLE(t1) PARTITION BY (a, b), + * input2 => TABLE(SELECT 1, 2) t1(x, y) PARTITION BY (x, y) + * COPARTITION(t1, s1.t1))) + * + * Table Function: + * super( + * SCHEMA_NAME, + * "two_table_arguments_function", + * ImmutableList.of( + * TableArgumentSpecification.builder() + * .name("INPUT1") + * .build(), + * TableArgumentSpecification.builder() + * .name("INPUT2") + * .build()), + * GENERIC_TABLE); + * + * analyze: + * return TableFunctionAnalysis.builder() + * .handle(HANDLE) + * .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + * .requiredColumns("INPUT1", ImmutableList.of(0)) + * .requiredColumns("INPUT2", ImmutableList.of(0)) + * .build(); + * + * Example values: + * connectorId = "tpch" + * functionName = "two_table_arguments_function" + * arguments = { + * "input1" -> row(a bigint,b bigint,c bigint,d bigint), + * "input2" -> row(x integer,y integer) + * } + * tableArgumentAnalyses = [ + * { argumentName="INPUT1", relation=Table{t1}, partitionBy=[a, b], orderBy=[], pruneWhenEmpty=false }, + * { argumentName="INPUT2", relation=AliasedRelation{SELECT 1, 2 AS x, y}, partitionBy=[x, y], orderBy=[], pruneWhenEmpty=false } + * ] + * requiredColumns = { + * "INPUT1" -> [0], + * "INPUT2" -> [0] + * } + * copartitioningLists = [ + * ["INPUT2", "INPUT1"] + * ] + * properColumnsCount = 1 + * connectorTableFunctionHandle = TestingTableFunctionPushdownHandle + * transactionHandle = AbstractAnalyzerTest$1$1 + * + */ + public static class TableFunctionInvocationAnalysis + { + private final ConnectorId connectorId; + private final String functionName; + private final Map arguments; + private final List 
tableArgumentAnalyses; + private final Map> requiredColumns; + private final List> copartitioningLists; + private final int properColumnsCount; + private final ConnectorTableFunctionHandle connectorTableFunctionHandle; + private final ConnectorTransactionHandle transactionHandle; + + public TableFunctionInvocationAnalysis( + ConnectorId connectorId, + String functionName, + Map arguments, + List tableArgumentAnalyses, + Map> requiredColumns, + List> copartitioningLists, + int properColumnsCount, + ConnectorTableFunctionHandle connectorTableFunctionHandle, + ConnectorTransactionHandle transactionHandle) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.functionName = requireNonNull(functionName, "functionName is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.connectorTableFunctionHandle = requireNonNull(connectorTableFunctionHandle, "connectorTableFunctionHandle is null"); + this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); + this.tableArgumentAnalyses = ImmutableList.copyOf(tableArgumentAnalyses); + this.requiredColumns = requiredColumns.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> ImmutableList.copyOf(entry.getValue()))); + this.copartitioningLists = ImmutableList.copyOf(copartitioningLists); + this.properColumnsCount = properColumnsCount; + } + + public ConnectorId getConnectorId() + { + return connectorId; + } + + public String getFunctionName() + { + return functionName; + } + + public Map getArguments() + { + return arguments; + } + + public List getTableArgumentAnalyses() + { + return tableArgumentAnalyses; + } + + public Map> getRequiredColumns() + { + return requiredColumns; + } + + public List> getCopartitioningLists() + { + return copartitioningLists; + } + + /** + * Proper columns are the columns produced by the table function, as opposed to pass-through columns from input tables. 
+ * Proper columns should be considered the actual result of the table function. + * @return the number of table function's proper columns + */ + public int getProperColumnsCount() + { + return properColumnsCount; + } + + public ConnectorTableFunctionHandle getConnectorTableFunctionHandle() + { + return connectorTableFunctionHandle; + } + + public ConnectorTransactionHandle getTransactionHandle() + { + return transactionHandle; + } + } + + public static class MergeAnalysis + { + private final Table targetTable; + private final List targetColumnsMetadata; + private final List targetColumnHandles; + private final List> mergeCaseColumnHandles; + private final Set nonNullableColumnHandles; + private final Map columnHandleFieldNumbers; + private final Scope targetTableScope; + private final Scope joinScope; + + public MergeAnalysis( + Table targetTable, + List targetColumnsMetadata, + List targetColumnHandles, + List> mergeCaseColumnHandles, + Set nonNullableTargetColumnHandles, + Map targetColumnHandleFieldNumbers, + Scope targetTableScope, + Scope joinScope) + { + this.targetTable = requireNonNull(targetTable, "targetTable is null"); + this.targetColumnsMetadata = requireNonNull(targetColumnsMetadata, "targetColumnsMetadata is null"); + this.targetColumnHandles = requireNonNull(targetColumnHandles, "targetColumnHandles is null"); + this.mergeCaseColumnHandles = requireNonNull(mergeCaseColumnHandles, "mergeCaseColumnHandles is null"); + this.nonNullableColumnHandles = requireNonNull(nonNullableTargetColumnHandles, "nonNullableTargetColumnHandles is null"); + this.columnHandleFieldNumbers = requireNonNull(targetColumnHandleFieldNumbers, "targetColumnHandleFieldNumbers is null"); + this.targetTableScope = requireNonNull(targetTableScope, "targetTableScope is null"); + this.joinScope = requireNonNull(joinScope, "joinScope is null"); + } + + public Table getTargetTable() + { + return targetTable; + } + + public List getTargetColumnsMetadata() + { + return 
targetColumnsMetadata; + } + + public List getTargetColumnHandles() + { + return targetColumnHandles; + } + + public List> getMergeCaseColumnHandles() + { + return mergeCaseColumnHandles; + } + + public Set getNonNullableColumnHandles() + { + return nonNullableColumnHandles; + } + + public Map getColumnHandleFieldNumbers() + { + return columnHandleFieldNumbers; + } + + public Scope getJoinScope() + { + return joinScope; + } + + public Scope getTargetTableScope() + { + return targetTableScope; + } + } } diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalysis.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalysis.java index a5d2b6e4d5427..3a70ce4e66b0b 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalysis.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalysis.java @@ -17,6 +17,8 @@ import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.analyzer.AccessControlReferences; import com.facebook.presto.spi.analyzer.QueryAnalysis; +import com.facebook.presto.spi.analyzer.UpdateInfo; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.sql.tree.Explain; import com.google.common.collect.ImmutableSet; @@ -41,9 +43,9 @@ public Analysis getAnalysis() } @Override - public String getUpdateType() + public UpdateInfo getUpdateInfo() { - return analysis.getUpdateType(); + return analysis.getUpdateInfo(); } @Override @@ -64,6 +66,12 @@ public AccessControlReferences getAccessControlReferences() return analysis.getAccessControlReferences(); } + @Override + public ViewDefinitionReferences getViewDefinitionReferences() + { + return analysis.getViewDefinitionReferences(); + } + @Override public boolean isExplainAnalyzeQuery() { diff --git 
a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparer.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparer.java index 02181308de7b9..5ad696f46edfc 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparer.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparer.java @@ -13,30 +13,36 @@ */ package com.facebook.presto.sql.analyzer; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.analyzer.PreparedQuery; import com.facebook.presto.common.resourceGroups.QueryType; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.PrestoWarning; +import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AnalyzerOptions; import com.facebook.presto.spi.analyzer.QueryPreparer; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.sql.analyzer.utils.StatementUtils; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.tree.Call; import com.facebook.presto.sql.tree.Execute; import com.facebook.presto.sql.tree.Explain; import com.facebook.presto.sql.tree.ExplainType; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.Statement; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; import java.util.Optional; import static com.facebook.presto.common.WarningHandlingLevel.AS_ERROR; +import static com.facebook.presto.common.resourceGroups.QueryType.CALL_DISTRIBUTED_PROCEDURE; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static 
com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.WARNING_AS_ERROR; @@ -44,6 +50,7 @@ import static com.facebook.presto.sql.analyzer.ConstantExpressionVerifier.verifyExpressionIsConstant; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; import static com.facebook.presto.sql.analyzer.utils.AnalyzerUtil.createParsingOptions; +import static com.facebook.presto.sql.analyzer.utils.MetadataUtils.createQualifiedObjectName; import static com.facebook.presto.sql.analyzer.utils.ParameterExtractor.getParameterCount; import static com.facebook.presto.sql.tree.ExplainType.Type.VALIDATE; import static java.lang.String.format; @@ -57,11 +64,15 @@ public class BuiltInQueryPreparer implements QueryPreparer { private final SqlParser sqlParser; + private final ProcedureRegistry procedureRegistry; @Inject - public BuiltInQueryPreparer(SqlParser sqlParser) + public BuiltInQueryPreparer( + SqlParser sqlParser, + ProcedureRegistry procedureRegistry) { this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); } @Override @@ -88,6 +99,18 @@ public BuiltInPreparedQuery prepareQuery(AnalyzerOptions analyzerOptions, Statem statement = sqlParser.createStatement(query, createParsingOptions(analyzerOptions)); } + Optional distributedProcedureName = Optional.empty(); + if (statement instanceof Call) { + QualifiedName qualifiedName = ((Call) statement).getName(); + QualifiedObjectName qualifiedObjectName = createQualifiedObjectName(analyzerOptions.getSessionCatalogName(), analyzerOptions.getSessionSchemaName(), + statement, qualifiedName, (catalogName, objectName) -> objectName); + if (procedureRegistry.isDistributedProcedure( + new ConnectorId(qualifiedObjectName.getCatalogName()), + new SchemaTableName(qualifiedObjectName.getSchemaName(), qualifiedObjectName.getObjectName()))) { + 
distributedProcedureName = Optional.of(qualifiedObjectName); + } + } + if (statement instanceof Explain && ((Explain) statement).isAnalyze()) { Statement innerStatement = ((Explain) statement).getStatement(); Optional innerQueryType = StatementUtils.getQueryType(innerStatement.getClass()); @@ -104,7 +127,7 @@ public BuiltInPreparedQuery prepareQuery(AnalyzerOptions analyzerOptions, Statem if (analyzerOptions.isLogFormattedQueryEnabled()) { formattedQuery = Optional.of(getFormattedQuery(statement, parameters)); } - return new BuiltInPreparedQuery(wrappedStatement, statement, parameters, formattedQuery, prepareSql); + return new BuiltInPreparedQuery(wrappedStatement, statement, parameters, formattedQuery, prepareSql, distributedProcedureName); } private static String getFormattedQuery(Statement statement, List parameters) @@ -132,13 +155,19 @@ public static class BuiltInPreparedQuery private final Statement statement; private final Statement wrappedStatement; private final List parameters; + private final Optional distributedProcedureName; - public BuiltInPreparedQuery(Statement wrappedStatement, Statement statement, List parameters, Optional formattedQuery, Optional prepareSql) + public BuiltInPreparedQuery( + Statement wrappedStatement, + Statement statement, List parameters, + Optional formattedQuery, Optional prepareSql, + Optional distributedProcedureName) { super(formattedQuery, prepareSql); this.wrappedStatement = requireNonNull(wrappedStatement, "wrappedStatement is null"); this.statement = requireNonNull(statement, "statement is null"); this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + this.distributedProcedureName = requireNonNull(distributedProcedureName, "distributedProcedureName is null"); } public Statement getStatement() @@ -158,9 +187,17 @@ public List getParameters() public Optional getQueryType() { + if (getDistributedProcedureName().isPresent()) { + return Optional.of(CALL_DISTRIBUTED_PROCEDURE); + } return 
StatementUtils.getQueryType(statement.getClass()); } + public Optional getDistributedProcedureName() + { + return this.distributedProcedureName; + } + public boolean isTransactionControlStatement() { return StatementUtils.isTransactionControlStatement(getStatement()); diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/FunctionAndTypeResolver.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/FunctionAndTypeResolver.java index 652d06dc77a1c..03616d8690d51 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/FunctionAndTypeResolver.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/FunctionAndTypeResolver.java @@ -52,4 +52,13 @@ FunctionHandle resolveFunction( FunctionHandle lookupCast(String castType, Type fromType, Type toType); QualifiedObjectName qualifyObjectName(QualifiedName name); + + /** + * Validate a function call during analysis phase on the coordinator. + * Delegates to the FunctionNamespaceManager for custom validation logic. 
+ * + * @param functionHandle The function handle being validated + * @param arguments Raw argument expressions (not yet evaluated) + */ + void validateFunctionCall(FunctionHandle functionHandle, List arguments); } diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/RelationId.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/RelationId.java index b4ff1ad5f3daf..da12259cac66d 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/RelationId.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/RelationId.java @@ -14,8 +14,7 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.sql.tree.Node; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import static com.google.common.base.MoreObjects.toStringHelper; import static java.lang.String.format; diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/RelationType.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/RelationType.java index 5c3d346bcc941..ad1daf917edb8 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/RelationType.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/RelationType.java @@ -17,9 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Collection; import java.util.List; @@ -52,7 +50,7 @@ public RelationType(List fields) { requireNonNull(fields, "fields is null"); this.allFields = ImmutableList.copyOf(fields); - this.visibleFields = ImmutableList.copyOf(Iterables.filter(fields, not(Field::isHidden))); + this.visibleFields = ImmutableList.copyOf(fields.stream().filter(not(Field::isHidden)).iterator()); int index = 0; 
ImmutableMap.Builder builder = ImmutableMap.builder(); diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/ResolvedField.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/ResolvedField.java index 5cacdb0fd6ada..76fa38c82dd64 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/ResolvedField.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/ResolvedField.java @@ -14,8 +14,7 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.common.type.Type; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import static java.util.Objects.requireNonNull; diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Scope.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Scope.java index 8d91bfdd3657c..7442c6ba64657 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Scope.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Scope.java @@ -19,8 +19,7 @@ import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.WithQuery; import com.google.common.collect.ImmutableMap; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.HashMap; import java.util.List; @@ -31,7 +30,7 @@ import static com.facebook.presto.sql.analyzer.SemanticExceptions.missingAttributeException; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.MoreCollectors.onlyElement; import static java.util.Objects.requireNonNull; @Immutable @@ -137,7 +136,7 @@ private Optional resolveField(Expression node, QualifiedName name throw ambiguousAttributeException(node, name); } else if (matches.size() == 1) { - return 
Optional.of(asResolvedField(getOnlyElement(matches), fieldIndexOffset, local)); + return Optional.of(asResolvedField(matches.stream().collect(onlyElement()), fieldIndexOffset, local)); } else { if (isColumnReference(name, relation)) { diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java index 656b6dbaab079..957764db853d9 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/SemanticErrorCode.java @@ -57,7 +57,6 @@ public enum SemanticErrorCode INVALID_FUNCTION_NAME, DUPLICATE_PARAMETER_NAME, EXCEPTIONS_WHEN_RESOLVING_FUNCTIONS, - ORDER_BY_MUST_BE_IN_SELECT, ORDER_BY_MUST_BE_IN_AGGREGATE, REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_GROUPING, @@ -91,6 +90,7 @@ public enum SemanticErrorCode SAMPLE_PERCENTAGE_OUT_OF_RANGE, + PROCEDURE_NOT_FOUND, INVALID_PROCEDURE_ARGUMENTS, INVALID_SESSION_PROPERTY, @@ -112,4 +112,16 @@ public enum SemanticErrorCode TOO_MANY_GROUPING_SETS, INVALID_OFFSET_ROW_COUNT, + + TABLE_FUNCTION_MISSING_RETURN_TYPE, + TABLE_FUNCTION_AMBIGUOUS_RETURN_TYPE, + TABLE_FUNCTION_MISSING_ARGUMENT, + TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + TABLE_FUNCTION_INVALID_ARGUMENTS, + TABLE_FUNCTION_IMPLEMENTATION_ERROR, + TABLE_FUNCTION_INVALID_COPARTITIONING, + TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION, + TABLE_FUNCTION_DUPLICATE_RANGE_VARIABLE, + TABLE_FUNCTION_INVALID_COLUMN_REFERENCE, + TABLE_FUNCTION_COLUMN_NOT_FOUND } diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/MetadataUtils.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/MetadataUtils.java new file mode 100644 index 0000000000000..9d241ccd22992 --- /dev/null +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/MetadataUtils.java @@ -0,0 +1,62 @@ +/* + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.analyzer.utils; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.Node; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Optional; +import java.util.function.BiFunction; + +import static com.facebook.presto.spi.StandardErrorCode.SYNTAX_ERROR; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SCHEMA_NOT_SPECIFIED; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +public class MetadataUtils +{ + private MetadataUtils() + {} + + public static QualifiedObjectName createQualifiedObjectName(Optional sessionCatalogName, Optional sessionSchemaName, Node node, QualifiedName name, + BiFunction normalizer) + { + requireNonNull(sessionCatalogName, "sessionCatalogName is null"); + requireNonNull(sessionSchemaName, "sessionSchemaName is null"); + requireNonNull(name, "name is null"); + if (name.getParts().size() > 3) { + throw new PrestoException(SYNTAX_ERROR, format("Too many dots in table name: %s", name)); + } + + List parts = 
Lists.reverse(name.getOriginalParts()); + String objectName = parts.get(0).getValue(); + String schemaName = (parts.size() > 1) ? parts.get(1).getValue() : sessionSchemaName.orElseThrow(() -> + new SemanticException(SCHEMA_NOT_SPECIFIED, node, "Schema must be specified when session schema is not set")); + String catalogName = (parts.size() > 2) ? parts.get(2).getValue() : sessionCatalogName.orElseThrow(() -> + new SemanticException(CATALOG_NOT_SPECIFIED, node, "Catalog must be specified when session catalog is not set")); + + catalogName = catalogName.toLowerCase(ENGLISH); + schemaName = normalizer.apply(catalogName, schemaName); + objectName = normalizer.apply(catalogName, objectName); + return new QualifiedObjectName(catalogName, schemaName, objectName); + } +} diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/StatementUtils.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/StatementUtils.java index fa84d900eef57..e777e219eae6e 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/StatementUtils.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/utils/StatementUtils.java @@ -33,6 +33,7 @@ import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.DescribeInput; import com.facebook.presto.sql.tree.DescribeOutput; +import com.facebook.presto.sql.tree.DropBranch; import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropConstraint; import com.facebook.presto.sql.tree.DropFunction; @@ -40,11 +41,13 @@ import com.facebook.presto.sql.tree.DropRole; import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; +import com.facebook.presto.sql.tree.DropTag; import com.facebook.presto.sql.tree.DropView; import com.facebook.presto.sql.tree.Explain; import com.facebook.presto.sql.tree.Grant; import com.facebook.presto.sql.tree.GrantRoles; import com.facebook.presto.sql.tree.Insert; 
+import com.facebook.presto.sql.tree.Merge; import com.facebook.presto.sql.tree.Prepare; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.RefreshMaterializedView; @@ -105,6 +108,7 @@ private StatementUtils() {} builder.put(Delete.class, QueryType.DELETE); builder.put(Update.class, QueryType.UPDATE); + builder.put(Merge.class, QueryType.MERGE); builder.put(ShowCatalogs.class, QueryType.DESCRIBE); builder.put(ShowCreate.class, QueryType.DESCRIBE); @@ -131,6 +135,8 @@ private StatementUtils() {} builder.put(RenameColumn.class, QueryType.DATA_DEFINITION); builder.put(DropColumn.class, QueryType.DATA_DEFINITION); builder.put(DropTable.class, QueryType.DATA_DEFINITION); + builder.put(DropBranch.class, QueryType.DATA_DEFINITION); + builder.put(DropTag.class, QueryType.DATA_DEFINITION); builder.put(DropConstraint.class, QueryType.DATA_DEFINITION); builder.put(AddConstraint.class, QueryType.DATA_DEFINITION); builder.put(AlterColumnNotNull.class, QueryType.DATA_DEFINITION); diff --git a/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestBuiltInQueryPreparer.java b/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestBuiltInQueryPreparer.java index bdc99102d1909..c37a54fca368d 100644 --- a/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestBuiltInQueryPreparer.java +++ b/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestBuiltInQueryPreparer.java @@ -13,47 +13,114 @@ */ package com.facebook.presto.sql.analyzer; +import com.facebook.presto.common.resourceGroups.QueryType; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AnalyzerOptions; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.Procedure; +import 
com.facebook.presto.spi.procedure.Procedure.Argument; +import com.facebook.presto.spi.procedure.ProcedureRegistry; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; import com.facebook.presto.sql.analyzer.BuiltInQueryPreparer.BuiltInPreparedQuery; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.AllColumns; +import com.facebook.presto.sql.tree.Call; +import com.facebook.presto.sql.tree.CallArgument; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.StringLiteral; import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.Optional; +import static com.facebook.presto.common.type.StandardTypes.VARCHAR; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.SCHEMA; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.TABLE_NAME; import static com.facebook.presto.sql.QueryUtil.selectList; import static com.facebook.presto.sql.QueryUtil.simpleQuery; import static com.facebook.presto.sql.QueryUtil.table; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_PARAMETER_USAGE; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; public class TestBuiltInQueryPreparer { private static final SqlParser SQL_PARSER = new SqlParser(); - private static final BuiltInQueryPreparer QUERY_PREPARER = new BuiltInQueryPreparer(SQL_PARSER); private static final Map emptyPreparedStatements = ImmutableMap.of(); private static final AnalyzerOptions testAnalyzerOptions = AnalyzerOptions.builder().build(); + private static ProcedureRegistry procedureRegistry; + private static BuiltInQueryPreparer queryPreparer; + + 
@BeforeClass + public void setup() + { + procedureRegistry = new TestProcedureRegistry(); + List arguments = new ArrayList<>(); + arguments.add(new Argument(SCHEMA, VARCHAR)); + arguments.add(new Argument(TABLE_NAME, VARCHAR)); + + List distributedArguments = new ArrayList<>(); + distributedArguments.add(new DistributedProcedure.Argument(SCHEMA, VARCHAR)); + distributedArguments.add(new DistributedProcedure.Argument(TABLE_NAME, VARCHAR)); + List> procedures = new ArrayList<>(); + procedures.add(new Procedure("system", "fun", arguments)); + procedures.add(new TableDataRewriteDistributedProcedure("system", "distributed_fun", + distributedArguments, + (session, transactionContext, procedureHandle, fragments, sortOrderIndex) -> null, + (session, transactionContext, procedureHandle, fragments) -> {}, + ignored -> new TestProcedureRegistry.TestProcedureContext())); + procedureRegistry.addProcedures(new ConnectorId("test"), procedures); + queryPreparer = new BuiltInQueryPreparer(SQL_PARSER, procedureRegistry); + } @Test public void testSelectStatement() { - BuiltInPreparedQuery preparedQuery = QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "SELECT * FROM foo", emptyPreparedStatements, WarningCollector.NOOP); + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery(testAnalyzerOptions, "SELECT * FROM foo", emptyPreparedStatements, WarningCollector.NOOP); assertEquals(preparedQuery.getStatement(), simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("foo")))); } + @Test + public void testCallProcedureStatement() + { + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery(testAnalyzerOptions, "call test.system.fun('a', 'b')", emptyPreparedStatements, WarningCollector.NOOP); + List arguments = new ArrayList<>(); + arguments.add(new CallArgument(new StringLiteral("a"))); + arguments.add(new CallArgument(new StringLiteral("b"))); + assertEquals(preparedQuery.getStatement(), + new Call(QualifiedName.of("test", "system", "fun"), arguments)); + 
assertTrue(preparedQuery.getQueryType().isPresent()); + assertEquals(preparedQuery.getQueryType().get(), QueryType.DATA_DEFINITION); + } + + @Test + public void testCallDistributedProcedureStatement() + { + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery(testAnalyzerOptions, "call test.system.distributed_fun('a', 'b')", emptyPreparedStatements, WarningCollector.NOOP); + List arguments = new ArrayList<>(); + arguments.add(new CallArgument(new StringLiteral("a"))); + arguments.add(new CallArgument(new StringLiteral("b"))); + assertEquals(preparedQuery.getStatement(), + new Call(QualifiedName.of("test", "system", "distributed_fun"), arguments)); + assertTrue(preparedQuery.getQueryType().isPresent()); + assertEquals(preparedQuery.getQueryType().get(), QueryType.CALL_DISTRIBUTED_PROCEDURE); + } + @Test public void testExecuteStatement() { Map preparedStatements = ImmutableMap.of("my_query", "SELECT * FROM foo"); - BuiltInPreparedQuery preparedQuery = QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "EXECUTE my_query", preparedStatements, WarningCollector.NOOP); + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery(testAnalyzerOptions, "EXECUTE my_query", preparedStatements, WarningCollector.NOOP); assertEquals(preparedQuery.getStatement(), simpleQuery(selectList(new AllColumns()), table(QualifiedName.of("foo")))); } @@ -62,7 +129,7 @@ public void testExecuteStatement() public void testExecuteStatementDoesNotExist() { try { - QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "execute my_query", emptyPreparedStatements, WarningCollector.NOOP); + queryPreparer.prepareQuery(testAnalyzerOptions, "execute my_query", emptyPreparedStatements, WarningCollector.NOOP); fail("expected exception"); } catch (PrestoException e) { @@ -75,7 +142,7 @@ public void testTooManyParameters() { try { Map preparedStatements = ImmutableMap.of("my_query", "SELECT * FROM foo where col1 = ?"); - QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "EXECUTE my_query USING 
1,2", preparedStatements, WarningCollector.NOOP); + queryPreparer.prepareQuery(testAnalyzerOptions, "EXECUTE my_query USING 1,2", preparedStatements, WarningCollector.NOOP); fail("expected exception"); } catch (SemanticException e) { @@ -88,7 +155,7 @@ public void testTooFewParameters() { try { Map preparedStatements = ImmutableMap.of("my_query", "SELECT ? FROM foo where col1 = ?"); - QUERY_PREPARER.prepareQuery(testAnalyzerOptions, "EXECUTE my_query USING 1", preparedStatements, WarningCollector.NOOP); + queryPreparer.prepareQuery(testAnalyzerOptions, "EXECUTE my_query USING 1", preparedStatements, WarningCollector.NOOP); fail("expected exception"); } catch (SemanticException e) { @@ -100,7 +167,7 @@ public void testTooFewParameters() public void testFormattedQuery() { AnalyzerOptions analyzerOptions = AnalyzerOptions.builder().setLogFormattedQueryEnabled(true).build(); - BuiltInPreparedQuery preparedQuery = QUERY_PREPARER.prepareQuery( + BuiltInPreparedQuery preparedQuery = queryPreparer.prepareQuery( analyzerOptions, "PREPARE test FROM SELECT * FROM foo where col1 = ?", emptyPreparedStatements, @@ -112,7 +179,7 @@ public void testFormattedQuery() " foo\n" + " WHERE (col1 = ?)\n")); - preparedQuery = QUERY_PREPARER.prepareQuery( + preparedQuery = queryPreparer.prepareQuery( analyzerOptions, "PREPARE test FROM SELECT * FROM foo", emptyPreparedStatements, diff --git a/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestProcedureRegistry.java b/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestProcedureRegistry.java new file mode 100644 index 0000000000000..6c0fcd7f80c4c --- /dev/null +++ b/presto-analyzer/src/test/java/com/facebook/presto/sql/analyzer/TestProcedureRegistry.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.analyzer; + +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.connector.ConnectorProcedureContext; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; + +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.facebook.presto.spi.StandardErrorCode.PROCEDURE_NOT_FOUND; +import static java.util.Objects.requireNonNull; + +public class TestProcedureRegistry + implements ProcedureRegistry +{ + private final Map>> connectorProcedures = new ConcurrentHashMap<>(); + + @Override + public void addProcedures(ConnectorId connectorId, Collection> procedures) + { + requireNonNull(connectorId, "connectorId is null"); + requireNonNull(procedures, "procedures is null"); + + Map> proceduresByName = procedures.stream().collect(Collectors.toMap( + procedure -> new SchemaTableName(procedure.getSchema(), procedure.getName()), + Function.identity())); + if (connectorProcedures.putIfAbsent(connectorId, proceduresByName) != null) { + throw new IllegalStateException("Procedures already registered for connector: " + connectorId); + } + } + + @Override + public void removeProcedures(ConnectorId connectorId) + { + 
connectorProcedures.remove(connectorId); + } + + @Override + public BaseProcedure resolve(ConnectorId connectorId, SchemaTableName name) + { + Map> procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + BaseProcedure procedure = procedures.get(name); + if (procedure != null) { + return procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Procedure not registered: " + name); + } + + @Override + public DistributedProcedure resolveDistributed(ConnectorId connectorId, SchemaTableName name) + { + Map> procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + BaseProcedure procedure = procedures.get(name); + if (procedure instanceof DistributedProcedure) { + return (DistributedProcedure) procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Distributed procedure not registered: " + name); + } + + @Override + public boolean isDistributedProcedure(ConnectorId connectorId, SchemaTableName name) + { + Map> procedures = connectorProcedures.get(connectorId); + return procedures != null && + procedures.containsKey(name) && + procedures.get(name) instanceof DistributedProcedure; + } + + public static class TestProcedureContext + implements ConnectorProcedureContext + {} +} diff --git a/presto-atop/pom.xml b/presto-atop/pom.xml index 614294b443447..6b4efc25ac01a 100644 --- a/presto-atop/pom.xml +++ b/presto-atop/pom.xml @@ -5,15 +5,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-atop + presto-atop Presto - Atop Connector presto-plugin ${project.parent.basedir} + true @@ -58,24 +60,24 @@ - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api - javax.validation - validation-api + jakarta.validation + jakarta.validation-api - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -99,7 +101,7 @@ - com.facebook.drift + 
com.facebook.airlift.drift drift-api provided @@ -111,7 +113,7 @@ - io.airlift + com.facebook.airlift units provided diff --git a/presto-atop/src/main/java/com/facebook/presto/atop/AtopConnector.java b/presto-atop/src/main/java/com/facebook/presto/atop/AtopConnector.java index 2898dbec0dcad..c758635db645d 100644 --- a/presto-atop/src/main/java/com/facebook/presto/atop/AtopConnector.java +++ b/presto-atop/src/main/java/com/facebook/presto/atop/AtopConnector.java @@ -22,8 +22,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-atop/src/main/java/com/facebook/presto/atop/AtopConnectorConfig.java b/presto-atop/src/main/java/com/facebook/presto/atop/AtopConnectorConfig.java index 8ca4a2446b6fc..bc8d957f2dcfd 100644 --- a/presto-atop/src/main/java/com/facebook/presto/atop/AtopConnectorConfig.java +++ b/presto-atop/src/main/java/com/facebook/presto/atop/AtopConnectorConfig.java @@ -15,11 +15,10 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.time.ZoneId; diff --git a/presto-atop/src/main/java/com/facebook/presto/atop/AtopMetadata.java b/presto-atop/src/main/java/com/facebook/presto/atop/AtopMetadata.java index 34bd2171e22fa..16710c146f604 100644 --- a/presto-atop/src/main/java/com/facebook/presto/atop/AtopMetadata.java +++ 
b/presto-atop/src/main/java/com/facebook/presto/atop/AtopMetadata.java @@ -31,8 +31,7 @@ import com.facebook.presto.spi.connector.ConnectorMetadata; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -86,7 +85,8 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable } @Override - public List getTableLayouts(ConnectorSession session, + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) @@ -105,7 +105,7 @@ public List getTableLayouts(ConnectorSession session } AtopTableLayoutHandle layoutHandle = new AtopTableLayoutHandle(tableHandle, startTimeDomain, endTimeDomain); ConnectorTableLayout tableLayout = getTableLayout(session, layoutHandle); - return ImmutableList.of(new ConnectorTableLayoutResult(tableLayout, constraint.getSummary())); + return new ConnectorTableLayoutResult(tableLayout, constraint.getSummary()); } @Override @@ -121,7 +121,7 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect ImmutableList.Builder columns = ImmutableList.builder(); for (AtopColumn column : atopTableHandle.getTable().getColumns()) { columns.add(ColumnMetadata.builder() - .setName(column.getName()) + .setName(normalizeIdentifier(session, column.getName())) .setType(typeManager.getType(column.getType())) .build()); } diff --git a/presto-atop/src/main/java/com/facebook/presto/atop/AtopPageSource.java b/presto-atop/src/main/java/com/facebook/presto/atop/AtopPageSource.java index a86debc42dd6e..06fdc089a9401 100644 --- a/presto-atop/src/main/java/com/facebook/presto/atop/AtopPageSource.java +++ b/presto-atop/src/main/java/com/facebook/presto/atop/AtopPageSource.java @@ -22,8 +22,7 @@ import com.google.common.base.Splitter; import 
com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.time.ZonedDateTime; import java.util.List; diff --git a/presto-atop/src/main/java/com/facebook/presto/atop/AtopPageSourceProvider.java b/presto-atop/src/main/java/com/facebook/presto/atop/AtopPageSourceProvider.java index 3f5a887f99b59..6ebf0da4e88dd 100644 --- a/presto-atop/src/main/java/com/facebook/presto/atop/AtopPageSourceProvider.java +++ b/presto-atop/src/main/java/com/facebook/presto/atop/AtopPageSourceProvider.java @@ -24,8 +24,7 @@ import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.time.ZonedDateTime; import java.util.List; diff --git a/presto-atop/src/main/java/com/facebook/presto/atop/AtopProcessFactory.java b/presto-atop/src/main/java/com/facebook/presto/atop/AtopProcessFactory.java index 4cd5a2478fe6a..0fe8dfc4182c3 100644 --- a/presto-atop/src/main/java/com/facebook/presto/atop/AtopProcessFactory.java +++ b/presto-atop/src/main/java/com/facebook/presto/atop/AtopProcessFactory.java @@ -13,14 +13,13 @@ */ package com.facebook.presto.atop; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.PrestoException; import com.google.common.util.concurrent.SimpleTimeLimiter; import com.google.common.util.concurrent.TimeLimiter; import com.google.common.util.concurrent.UncheckedTimeoutException; -import io.airlift.units.Duration; - -import javax.annotation.PreDestroy; -import javax.inject.Inject; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.io.BufferedReader; import java.io.IOException; diff --git a/presto-atop/src/main/java/com/facebook/presto/atop/AtopSplitManager.java 
b/presto-atop/src/main/java/com/facebook/presto/atop/AtopSplitManager.java index 7e31d72ad62fd..1d0d9b89b1341 100644 --- a/presto-atop/src/main/java/com/facebook/presto/atop/AtopSplitManager.java +++ b/presto-atop/src/main/java/com/facebook/presto/atop/AtopSplitManager.java @@ -25,8 +25,7 @@ import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.time.ZoneId; import java.time.ZonedDateTime; diff --git a/presto-atop/src/main/java/com/facebook/presto/atop/AtopTableLayoutHandle.java b/presto-atop/src/main/java/com/facebook/presto/atop/AtopTableLayoutHandle.java index bef9cf5ad8096..ec036f500fec4 100644 --- a/presto-atop/src/main/java/com/facebook/presto/atop/AtopTableLayoutHandle.java +++ b/presto-atop/src/main/java/com/facebook/presto/atop/AtopTableLayoutHandle.java @@ -17,8 +17,7 @@ import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; diff --git a/presto-atop/src/test/java/com/facebook/presto/atop/TestAtopConnectorConfig.java b/presto-atop/src/test/java/com/facebook/presto/atop/TestAtopConnectorConfig.java index a7b439df37a27..bd1dfbe27bc1c 100644 --- a/presto-atop/src/test/java/com/facebook/presto/atop/TestAtopConnectorConfig.java +++ b/presto-atop/src/test/java/com/facebook/presto/atop/TestAtopConnectorConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.atop; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-base-arrow-flight/pom.xml b/presto-base-arrow-flight/pom.xml index 
9428c4d767ce4..8c8c0092589f9 100644 --- a/presto-base-arrow-flight/pom.xml +++ b/presto-base-arrow-flight/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-base-arrow-flight @@ -15,8 +15,19 @@ ${project.parent.basedir} -Xss10M + true + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + org.apache.arrow @@ -83,6 +94,11 @@ guava + + jakarta.inject + jakarta.inject-api + + javax.inject javax.inject @@ -119,9 +135,8 @@ - com.google.code.findbugs - jsr305 - true + jakarta.annotation + jakarta.annotation-api @@ -134,17 +149,19 @@ configuration + joda-time joda-time + test org.jdbi jdbi3-core + test - org.testng testng @@ -228,6 +245,13 @@ org.apache.maven.plugins maven-dependency-plugin + + + com.fasterxml.jackson.core:jackson-databind + com.facebook.airlift:log-manager + javax.inject:javax.inject + + org.basepom.maven diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java index 3a5ce28cd76fd..ed703f1b2444f 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java @@ -39,6 +39,7 @@ import com.google.common.base.CharMatcher; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import jakarta.inject.Inject; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.DateDayVector; @@ -69,8 +70,6 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; -import javax.inject.Inject; - import java.math.BigDecimal; import java.nio.charset.StandardCharsets; import java.time.Duration; diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java 
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java index 6b217e4f56cdf..2da078e05e870 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java @@ -27,7 +27,9 @@ public enum ArrowErrorCode ARROW_INTERNAL_ERROR(1, INTERNAL_ERROR), ARROW_FLIGHT_CLIENT_ERROR(2, EXTERNAL), ARROW_FLIGHT_METADATA_ERROR(3, EXTERNAL), - ARROW_FLIGHT_TYPE_ERROR(4, EXTERNAL); + ARROW_FLIGHT_TYPE_ERROR(4, EXTERNAL), + ARROW_FLIGHT_INVALID_KEY_ERROR(5, INTERNAL_ERROR), + ARROW_FLIGHT_INVALID_CERT_ERROR(6, INTERNAL_ERROR); private final ErrorCode errorCode; diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java index d01333860e878..292d0304582c7 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java @@ -14,14 +14,18 @@ package com.facebook.plugin.arrow; import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigDescription; public class ArrowFlightConfig { private String server; private boolean verifyServer = true; private String flightServerSSLCertificate; + private String flightClientSSLCertificate; + private String flightClientSSLKey; private boolean arrowFlightServerSslEnabled; private Integer arrowFlightPort; + private boolean caseSensitiveNameMatchingEnabled; public String getFlightServerName() { @@ -82,4 +86,44 @@ public ArrowFlightConfig setArrowFlightServerSslEnabled(boolean arrowFlightServe this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled; return this; } + + public String getFlightClientSSLCertificate() + { + return flightClientSSLCertificate; + } + + @ConfigDescription("Path to the client SSL 
certificate used for mTLS authentication with Flight server") + @Config("arrow-flight.client-ssl-certificate") + public ArrowFlightConfig setFlightClientSSLCertificate(String flightClientSSLCertificate) + { + this.flightClientSSLCertificate = flightClientSSLCertificate; + return this; + } + + public String getFlightClientSSLKey() + { + return flightClientSSLKey; + } + + @ConfigDescription("Path to the client SSL key used for mTLS authentication with Flight server") + @Config("arrow-flight.client-ssl-key") + public ArrowFlightConfig setFlightClientSSLKey(String flightClientSSLKey) + { + this.flightClientSSLKey = flightClientSSLKey; + return this; + } + + public boolean isCaseSensitiveNameMatching() + { + return caseSensitiveNameMatchingEnabled; + } + + @Config("case-sensitive-name-matching") + @ConfigDescription("Enable case-sensitive matching of schema, table names across the connector. " + + "When disabled, names are matched case-insensitively using lowercase normalization.") + public ArrowFlightConfig setCaseSensitiveNameMatching(boolean caseSensitiveNameMatchingEnabled) + { + this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled; + return this; + } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java index d3ef4a8a3b8a7..24eae282d1557 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java @@ -29,11 +29,10 @@ import com.facebook.presto.spi.connector.ConnectorMetadata; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import jakarta.inject.Inject; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.Collections; import 
java.util.HashMap; @@ -44,6 +43,7 @@ import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR; import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; public class ArrowMetadata @@ -51,12 +51,14 @@ public class ArrowMetadata { private final BaseArrowFlightClientHandler clientHandler; private final ArrowBlockBuilder arrowBlockBuilder; + private final ArrowFlightConfig arrowFlightConfig; @Inject - public ArrowMetadata(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder) + public ArrowMetadata(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder, ArrowFlightConfig arrowFlightConfig) { this.clientHandler = requireNonNull(clientHandler, "clientHandler is null"); this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null"); + this.arrowFlightConfig = requireNonNull(arrowFlightConfig, "arrowFlightConfig is null"); } @Override @@ -84,10 +86,10 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable return new ArrowTableHandle(tableName.getSchemaName(), tableName.getTableName()); } - public List getColumnsList(String schema, String table, ConnectorSession connectorSession) + public List getColumnsList(ConnectorSession connectorSession, String schema, String table) { try { - Schema flightSchema = clientHandler.getSchemaForTable(schema, table, connectorSession); + Schema flightSchema = clientHandler.getSchemaForTable(connectorSession, schema, table); return flightSchema.getFields(); } catch (Exception e) { @@ -102,7 +104,7 @@ public Map getColumnHandles(ConnectorSession session, Conn String schemaValue = ((ArrowTableHandle) tableHandle).getSchema(); String tableValue = ((ArrowTableHandle) tableHandle).getTable(); - List columnList = getColumnsList(schemaValue, tableValue, session); + List columnList = getColumnsList(session, schemaValue, 
tableValue); for (Field field : columnList) { String columnName = field.getName(); @@ -113,7 +115,11 @@ public Map getColumnHandles(ConnectorSession session, Conn } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { checkArgument(table instanceof ArrowTableHandle, "Invalid table handle: Expected an instance of ArrowTableHandle but received %s", @@ -130,7 +136,7 @@ public List getTableLayouts(ConnectorSession session } ConnectorTableLayout layout = new ConnectorTableLayout(new ArrowTableLayoutHandle(tableHandle, columns, constraint.getSummary())); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -143,12 +149,12 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table) { List meta = new ArrayList<>(); - List columnList = getColumnsList(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable(), session); + List columnList = getColumnsList(session, ((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable()); for (Field field : columnList) { String columnName = field.getName(); Type fieldType = getPrestoTypeFromArrowField(field); - meta.add(ColumnMetadata.builder().setName(columnName).setType(fieldType).build()); + meta.add(ColumnMetadata.builder().setName(normalizeIdentifier(session, columnName)).setType(fieldType).build()); } return new ConnectorTableMetadata(new SchemaTableName(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable()), meta); } @@ -169,7 +175,7 @@ public Map> 
listTableColumns(ConnectorSess tables = ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName())); } else { - tables = listTables(session, Optional.of(prefix.getSchemaName())); + tables = listTables(session, Optional.ofNullable(prefix.getSchemaName())); } for (SchemaTableName tableName : tables) { @@ -189,6 +195,12 @@ public Map> listTableColumns(ConnectorSess return columns.build(); } + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return arrowFlightConfig.isCaseSensitiveNameMatching() ? identifier : identifier.toLowerCase(ROOT); + } + private Type getPrestoTypeFromArrowField(Field field) { return arrowBlockBuilder.getPrestoTypeFromArrowField(field); diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java index fcacb204da7d4..b67c7dd2ec701 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.ConnectorSession; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; import java.util.ArrayList; import java.util.List; @@ -48,7 +49,7 @@ public ArrowPageSource( this.columnHandles = requireNonNull(columnHandles, "columnHandles is null"); requireNonNull(clientHandler, "clientHandler is null"); this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null"); - this.flightStreamAndClient = clientHandler.getFlightStream(split, connectorSession); + this.flightStreamAndClient = clientHandler.getFlightStream(connectorSession, split); } @Override @@ -97,16 +98,19 @@ public Page getNextPage() // Create blocks from the loaded Arrow record batch List blocks = new ArrayList<>(); - 
List vectors = flightStreamAndClient.getRoot().getFieldVectors(); - for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) { - FieldVector vector = vectors.get(columnIndex); - Type type = columnHandles.get(columnIndex).getColumnType(); + VectorSchemaRoot vectorSchemaRoot = flightStreamAndClient.getRoot(); + for (ArrowColumnHandle columnHandle : columnHandles) { + // In scenarios where the user query contains a Table Valued Function, the output columns could be in a + // different order or could be a subset of the columns in the flight stream. So we are fetching the requested + // field vector by matching the column name instead of fetching by column index. + FieldVector vector = requireNonNull(vectorSchemaRoot.getVector(columnHandle.getColumnName()), "No field named " + columnHandle.getColumnName() + " in the list of vectors from flight stream"); + Type type = columnHandle.getColumnType(); Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, type, flightStreamAndClient.getDictionaryProvider()); blocks.add(block); } if (logger.isDebugEnabled()) { - logger.debug("Read Arrow record batch with rows: %s, columns: %s", flightStreamAndClient.getRoot().getRowCount(), vectors.size()); + logger.debug("Read Arrow record batch with rows: %s, columns: %s", flightStreamAndClient.getRoot().getRowCount(), vectorSchemaRoot.getFieldVectors().size()); } return new Page(flightStreamAndClient.getRoot().getRowCount(), blocks.toArray(new Block[0])); diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java index 216939831a2c8..c4c1571c90657 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java @@ -21,8 +21,7 @@ import 
com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java index 9308bd60aa934..497aa42bdf845 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java @@ -19,8 +19,7 @@ import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Collections; import java.util.List; diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java index 129a5e11f6355..4821e072ab7cf 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java @@ -19,10 +19,9 @@ import com.facebook.presto.spi.FixedSplitSource; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import jakarta.inject.Inject; import org.apache.arrow.flight.FlightInfo; -import javax.inject.Inject; - import java.util.List; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -44,7 +43,7 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHand { ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout; ArrowTableHandle tableHandle 
= tableLayoutHandle.getTable(); - FlightInfo flightInfo = clientHandler.getFlightInfoForTableScan(tableLayoutHandle, session); + FlightInfo flightInfo = clientHandler.getFlightInfoForTableScan(session, tableLayoutHandle); List splits = flightInfo.getEndpoints() .stream() .map(info -> new ArrowSplit( diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java index e8fde1ec39521..d559768fcd6c8 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java @@ -13,6 +13,7 @@ */ package com.facebook.plugin.arrow; +import com.facebook.airlift.log.Logger; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; import org.apache.arrow.flight.CallOption; @@ -30,11 +31,15 @@ import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.nio.file.Paths; +import java.security.InvalidKeyException; +import java.security.cert.CertificateException; import java.util.List; import java.util.Optional; import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR; import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INFO_ERROR; +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INVALID_CERT_ERROR; +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INVALID_KEY_ERROR; import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR; import static java.nio.file.Files.newInputStream; import static java.util.Objects.requireNonNull; @@ -43,6 +48,7 @@ public abstract class BaseArrowFlightClientHandler { private final ArrowFlightConfig config; private final BufferAllocator allocator; + private static final Logger logger = 
Logger.get(BaseArrowFlightClientHandler.class); public BaseArrowFlightClientHandler(BufferAllocator allocator, ArrowFlightConfig config) { @@ -64,30 +70,66 @@ protected FlightClient createFlightClient() protected FlightClient createFlightClient(Location location) { + Optional trustedCertificate = Optional.empty(); + Optional clientCertificate = Optional.empty(); + Optional clientKey = Optional.empty(); try { - Optional trustedCertificate = Optional.empty(); FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); flightClientBuilder.verifyServer(config.getVerifyServer()); if (config.getFlightServerSSLCertificate() != null) { trustedCertificate = Optional.of(newInputStream(Paths.get(config.getFlightServerSSLCertificate()))); flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls(); } - - FlightClient flightClient = flightClientBuilder.build(); - if (trustedCertificate.isPresent()) { - trustedCertificate.get().close(); + if (config.getFlightClientSSLCertificate() != null && config.getFlightClientSSLKey() != null) { + clientCertificate = Optional.of(newInputStream(Paths.get(config.getFlightClientSSLCertificate()))); + clientKey = Optional.of(newInputStream(Paths.get(config.getFlightClientSSLKey()))); + flightClientBuilder.clientCertificate(clientCertificate.get(), clientKey.get()).useTls(); } - return flightClient; + return flightClientBuilder.build(); } catch (Exception e) { - throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, "Error creating flight client: " + e.getMessage(), e); + if (e.getCause() instanceof InvalidKeyException) { + throw new ArrowException(ARROW_FLIGHT_INVALID_KEY_ERROR, "Error creating flight client, invalid key file: " + e.getMessage(), e); + } + else if (e.getCause() instanceof CertificateException) { + throw new ArrowException(ARROW_FLIGHT_INVALID_CERT_ERROR, "Error creating flight client, invalid certificate file: " + e.getMessage(), e); + } + else { + throw new 
ArrowException(ARROW_FLIGHT_CLIENT_ERROR, "Error creating flight client: " + e.getMessage(), e); + } + } + finally { + if (trustedCertificate.isPresent()) { + try { + trustedCertificate.get().close(); + } + catch (IOException e) { + logger.error("Error closing input stream for server certificate", e); + } + } + if (clientCertificate.isPresent()) { + try { + clientCertificate.get().close(); + } + catch (IOException e) { + logger.error("Error closing input stream for client certificate", e); + } + } + if (clientKey.isPresent()) { + try { + clientKey.get().close(); + } + catch (IOException e) { + logger.error("Error closing input stream for client key", e); + } + } } } public abstract CallOption[] getCallOptions(ConnectorSession connectorSession); - protected FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + protected FlightInfo getFlightInfo(ConnectorSession connectorSession, FlightDescriptor flightDescriptor) { try (FlightClient client = createFlightClient()) { CallOption[] callOptions = getCallOptions(connectorSession); @@ -98,7 +140,7 @@ protected FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorS } } - protected ClientClosingFlightStream getFlightStream(ArrowSplit split, ConnectorSession connectorSession) + protected ClientClosingFlightStream getFlightStream(ConnectorSession connectorSession, ArrowSplit split) { ByteBuffer endpointBytes = ByteBuffer.wrap(split.getFlightEndpointBytes()); try { @@ -116,7 +158,7 @@ protected ClientClosingFlightStream getFlightStream(ArrowSplit split, ConnectorS } } - public Schema getSchema(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + public Schema getSchema(ConnectorSession connectorSession, FlightDescriptor flightDescriptor) { try (FlightClient client = createFlightClient()) { CallOption[] callOptions = this.getCallOptions(connectorSession); @@ -131,19 +173,19 @@ public Schema getSchema(FlightDescriptor flightDescriptor, ConnectorSession 
conn public abstract List listTables(ConnectorSession session, Optional schemaName); - protected abstract FlightDescriptor getFlightDescriptorForSchema(String schemaName, String tableName); + protected abstract FlightDescriptor getFlightDescriptorForSchema(ConnectorSession session, String schemaName, String tableName); - protected abstract FlightDescriptor getFlightDescriptorForTableScan(ArrowTableLayoutHandle tableLayoutHandle); + protected abstract FlightDescriptor getFlightDescriptorForTableScan(ConnectorSession session, ArrowTableLayoutHandle tableLayoutHandle); - public Schema getSchemaForTable(String schemaName, String tableName, ConnectorSession connectorSession) + public Schema getSchemaForTable(ConnectorSession connectorSession, String schemaName, String tableName) { - FlightDescriptor flightDescriptor = getFlightDescriptorForSchema(schemaName, tableName); - return getSchema(flightDescriptor, connectorSession); + FlightDescriptor flightDescriptor = getFlightDescriptorForSchema(connectorSession, schemaName, tableName); + return getSchema(connectorSession, flightDescriptor); } - public FlightInfo getFlightInfoForTableScan(ArrowTableLayoutHandle tableLayoutHandle, ConnectorSession session) + public FlightInfo getFlightInfoForTableScan(ConnectorSession session, ArrowTableLayoutHandle tableLayoutHandle) { - FlightDescriptor flightDescriptor = getFlightDescriptorForTableScan(tableLayoutHandle); - return getFlightInfo(flightDescriptor, session); + FlightDescriptor flightDescriptor = getFlightDescriptorForTableScan(session, tableLayoutHandle); + return getFlightInfo(session, flightDescriptor); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java index eea7ff598da72..72ade51e19829 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java +++ 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java @@ -27,6 +27,7 @@ import java.io.File; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.URI; import java.util.Map; @@ -39,6 +40,8 @@ public class ArrowFlightQueryRunner { + private static final Logger log = Logger.get(ArrowFlightQueryRunner.class); + private ArrowFlightQueryRunner() { throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); @@ -47,31 +50,24 @@ private ArrowFlightQueryRunner() public static int findUnusedPort() throws IOException { - try (ServerSocket socket = new ServerSocket(0)) { + try (ServerSocket socket = new ServerSocket()) { + socket.setReuseAddress(false); + socket.bind(new InetSocketAddress(0)); return socket.getLocalPort(); } } public static DistributedQueryRunner createQueryRunner(int flightServerPort) throws Exception { - return createQueryRunner(flightServerPort, ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); + return createQueryRunner(flightServerPort, ImmutableMap.of(), ImmutableMap.of(), Optional.empty(), Optional.empty()); } public static DistributedQueryRunner createQueryRunner( int flightServerPort, Map extraProperties, Map coordinatorProperties, - Optional> externalWorkerLauncher) - throws Exception - { - return createQueryRunner(extraProperties, ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort)), coordinatorProperties, externalWorkerLauncher); - } - - private static DistributedQueryRunner createQueryRunner( - Map extraProperties, - Map catalogProperties, - Map coordinatorProperties, - Optional> externalWorkerLauncher) + Optional> externalWorkerLauncher, + Optional mTLSEnabled) throws Exception { Session session = testSessionBuilder() @@ -86,17 +82,24 @@ private static DistributedQueryRunner createQueryRunner( DistributedQueryRunner queryRunner = queryRunnerBuilder .setExtraProperties(extraProperties) 
.setCoordinatorProperties(coordinatorProperties) - .setExternalWorkerLauncher(externalWorkerLauncher).build(); + .setExternalWorkerLauncher(externalWorkerLauncher) + .build(); try { boolean nativeExecution = externalWorkerLauncher.isPresent(); queryRunner.installPlugin(new TestingArrowFlightPlugin(nativeExecution)); + Map catalogProperties = ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort)); ImmutableMap.Builder properties = ImmutableMap.builder() .putAll(catalogProperties) .put("arrow-flight.server", "localhost") .put("arrow-flight.server-ssl-enabled", "true") - .put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt"); + .put("arrow-flight.server-ssl-certificate", "src/test/resources/certs/ca.crt"); + + if (mTLSEnabled.orElse(false)) { + properties.put("arrow-flight.client-ssl-certificate", "src/test/resources/certs/client.crt"); + properties.put("arrow-flight.client-ssl-key", "src/test/resources/certs/client.key"); + } queryRunner.createCatalog(ARROW_FLIGHT_CATALOG, ARROW_FLIGHT_CONNECTOR, properties.build()); @@ -125,25 +128,33 @@ public static void main(String[] args) { Logging.initialize(); + boolean mTLSenabled = Boolean.parseBoolean(System.getProperty("flight.mtls.enabled", "false")); + RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); Location serverLocation = Location.forGrpcTls("localhost", 9443); - File certChainFile = new File("src/test/resources/server.crt"); - File privateKeyFile = new File("src/test/resources/server.key"); - FlightServer server = FlightServer.builder(allocator, serverLocation, new TestingArrowProducer(allocator)) - .useTls(certChainFile, privateKeyFile) - .build(); + FlightServer.Builder serverBuilder = FlightServer.builder(allocator, serverLocation, new TestingArrowProducer(allocator, false)); - server.start(); + File serverCert = new File("src/test/resources/certs/server.crt"); + File serverKey = new File("src/test/resources/certs/server.key"); + 
serverBuilder.useTls(serverCert, serverKey); - Logger log = Logger.get(ArrowFlightQueryRunner.class); + if (mTLSenabled) { + File caCert = new File("src/test/resources/certs/ca.crt"); + serverBuilder.useMTlsClientVerification(caCert); + } + + FlightServer server = serverBuilder.build(); + server.start(); log.info("Server listening on port " + server.getPort()); DistributedQueryRunner queryRunner = createQueryRunner( + server.getPort(), ImmutableMap.of("http-server.http.port", "8080"), - ImmutableMap.of("arrow-flight.server.port", String.valueOf(9443)), ImmutableMap.of(), - Optional.empty()); + Optional.empty(), + Optional.of(mTLSenabled)); + Thread.sleep(10); log.info("======== SERVER STARTED ========"); log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java index 1d9c490180abc..27ded280221c5 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java @@ -17,8 +17,6 @@ import com.facebook.presto.spi.ColumnMetadata; import org.testng.annotations.Test; -import java.util.Locale; - import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -67,7 +65,7 @@ public void testGetColumnMetadata() // Then assertNotNull(columnMetadata, "ColumnMetadata should not be null"); - assertEquals(columnMetadata.getName(), columnName.toLowerCase(Locale.ENGLISH), "ColumnMetadata name should match the column name"); + assertEquals(columnMetadata.getName(), columnName, "ColumnMetadata name should match the column name"); assertEquals(columnMetadata.getType(), IntegerType.INTEGER, "ColumnMetadata type should match the column type"); } diff --git 
a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java index 8715d413f687f..7457e19bed877 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java @@ -17,10 +17,12 @@ import com.facebook.airlift.log.Logger; import com.facebook.plugin.arrow.testingServer.TestingArrowFlightRequest; import com.facebook.plugin.arrow.testingServer.TestingArrowFlightResponse; +import com.facebook.presto.Session; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.common.type.Type; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; @@ -47,8 +49,11 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeStampMilliVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.ListVector; @@ -83,6 +88,7 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; +import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -92,19 +98,25 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.TimeZone; import 
java.util.TreeSet; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import static com.facebook.airlift.json.JsonCodec.jsonCodec; +import static com.facebook.presto.SystemSessionProperties.LEGACY_TIMESTAMP; import static com.facebook.presto.common.block.MethodHandleUtil.compose; import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DateType.DATE; import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingEnvironment.getOperatorMethodHandle; import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static com.facebook.presto.util.DateTimeUtils.parseTimestampWithoutTimeZone; import static java.lang.String.format; import static java.nio.channels.Channels.newChannel; @@ -113,26 +125,20 @@ public class TestArrowFlightEchoQueries { private static final Logger logger = Logger.get(TestArrowFlightEchoQueries.class); private static final CallOption CALL_OPTIONS = CallOptions.timeout(300, TimeUnit.SECONDS); - private final int serverPort; + private int serverPort; private RootAllocator allocator; private FlightServer server; private DistributedQueryRunner arrowFlightQueryRunner; private JsonCodec requestCodec; private JsonCodec responseCodec; - public TestArrowFlightEchoQueries() - throws IOException - { - this.serverPort = ArrowFlightQueryRunner.findUnusedPort(); - } - @BeforeClass public void setup() throws Exception { arrowFlightQueryRunner = getDistributedQueryRunner(); - File certChainFile = new File("src/test/resources/server.crt"); - File 
privateKeyFile = new File("src/test/resources/server.key"); + File certChainFile = new File("src/test/resources/certs/server.crt"); + File privateKeyFile = new File("src/test/resources/certs/server.key"); allocator = new RootAllocator(Long.MAX_VALUE); @@ -161,9 +167,61 @@ public void close() protected QueryRunner createQueryRunner() throws Exception { + serverPort = ArrowFlightQueryRunner.findUnusedPort(); return ArrowFlightQueryRunner.createQueryRunner(serverPort); } + @Test + public void testDateTimeVectors() throws Exception + { + try (BufferAllocator bufferAllocator = allocator.newChildAllocator("echo-test-client", 0, Long.MAX_VALUE); + IntVector intVector = new IntVector("id", bufferAllocator); + DateDayVector dateVector = new DateDayVector("date", bufferAllocator); + TimeMilliVector timeVector = new TimeMilliVector("time", bufferAllocator); + TimeStampMilliVector timestampVector = new TimeStampMilliVector("timestamp", bufferAllocator); + VectorSchemaRoot root = new VectorSchemaRoot(Arrays.asList(intVector, dateVector, timeVector, timestampVector)); + FlightClient client = createFlightClient(bufferAllocator, serverPort)) { + MaterializedResult.Builder expectedBuilder = resultBuilder(getSession(), INTEGER, DATE, TIME, TIMESTAMP); + + List values = ImmutableList.of( + "1970-01-01T00:00:00", + "2024-01-01T01:01:01", + "2024-01-02T12:00:00", + "2112-12-31T23:58:00", + "1968-07-05T08:15:12.345"); + + for (int i = 0; i < values.size(); i++) { + intVector.setSafe(i, i); + LocalDateTime dateTime = LocalDateTime.parse(values.get(i)); + // First vector value is explicitly set to 0 to ensure no issues with parsing + dateVector.setSafe(i, i == 0 ? 0 : (int) dateTime.toLocalDate().toEpochDay()); + timeVector.setSafe(i, i == 0 ? 0 : (int) TimeUnit.NANOSECONDS.toMillis(dateTime.toLocalTime().toNanoOfDay())); + timestampVector.setSafe(i, i == 0 ? 
0 : parseTimestampWithoutTimeZone(values.get(i).replace("T", " "))); + expectedBuilder.row(i, dateTime.toLocalDate(), dateTime.toLocalTime(), dateTime); + } + + root.setRowCount(values.size()); + + String tableName = "datetime"; + addTableToServer(client, root, tableName); + + for (String timeZoneId : ImmutableList.of(TimeZone.getDefault().getID(), "UTC", "America/New_York", "Asia/Tokyo")) { + Session sessionWithTimezone = Session.builder(getSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey(timeZoneId)) + .setSystemProperty(LEGACY_TIMESTAMP, "false") + .build(); + + MaterializedResult actual = computeActual(sessionWithTimezone, format("SELECT * FROM %s", tableName)); + MaterializedResult expected = expectedBuilder.build(); + + assertEquals(actual.getRowCount(), root.getRowCount()); + assertEquals(actual, expected); + } + + removeTableFromServer(client, tableName); + } + } + @Test public void testVarCharVector() throws Exception { @@ -407,7 +465,7 @@ private static MapType createMapType(Type keyType, Type valueType) private static FlightClient createFlightClient(BufferAllocator allocator, int serverPort) throws IOException { - InputStream trustedCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/server.crt"))); + InputStream trustedCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/certs/server.crt"))); Location location = Location.forGrpcTls("localhost", serverPort); return FlightClient.builder(allocator, location).useTls().trustedCertificates(trustedCertificate).build(); } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationMixedCase.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationMixedCase.java new file mode 100644 index 0000000000000..af4a2dde0f02f --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationMixedCase.java @@ -0,0 
+1,148 @@ + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.plugin.arrow.testingServer.TestingArrowProducer; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin.ARROW_FLIGHT_CONNECTOR; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.tests.QueryAssertions.assertContains; +import static org.testng.Assert.assertTrue; + +@Test +public class TestArrowFlightIntegrationMixedCase + extends AbstractTestQueryFramework +{ + private static final Logger logger = Logger.get(TestArrowFlightIntegrationMixedCase.class); + private static 
final String ARROW_FLIGHT_MIXED_CATALOG = "arrow_mixed_catalog"; + private int serverPort; + private RootAllocator allocator; + private FlightServer server; + private DistributedQueryRunner arrowFlightQueryRunner; + + @BeforeClass + public void setup() + throws Exception + { + arrowFlightQueryRunner = getDistributedQueryRunner(); + arrowFlightQueryRunner.createCatalog(ARROW_FLIGHT_MIXED_CATALOG, ARROW_FLIGHT_CONNECTOR, getCatalogProperties()); + File certChainFile = new File("src/test/resources/certs/server.crt"); + File privateKeyFile = new File("src/test/resources/certs/server.key"); + + allocator = new RootAllocator(Long.MAX_VALUE); + Location location = Location.forGrpcTls("localhost", serverPort); + server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, true)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + logger.info("Server listening on port %s", server.getPort()); + } + + private Map getCatalogProperties() + { + ImmutableMap.Builder catalogProperties = ImmutableMap.builder() + .put("arrow-flight.server.port", String.valueOf(serverPort)) + .put("arrow-flight.server", "localhost") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server-ssl-certificate", "src/test/resources/certs/server.crt") + .put("case-sensitive-name-matching", "true"); + return catalogProperties.build(); + } + + @AfterClass(alwaysRun = true) + public void close() + throws InterruptedException + { + arrowFlightQueryRunner.close(); + server.close(); + allocator.close(); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + serverPort = ArrowFlightQueryRunner.findUnusedPort(); + return ArrowFlightQueryRunner.createQueryRunner(serverPort, ImmutableMap.of(), ImmutableMap.of(), Optional.empty(), Optional.empty()); + } + + @Test + public void testShowSchemas() + { + MaterializedResult actualRow = computeActual("SHOW schemas FROM arrow_mixed_catalog"); + MaterializedResult expectedRow 
= resultBuilder(getSession(), createVarcharType(50)) + .row("Tpch_Mx") + .row("tpch_mx") + .build(); + + assertContains(actualRow, expectedRow); + } + + @Test + public void testShowTables() + { + MaterializedResult actualRow = computeActual("SHOW TABLES FROM arrow_mixed_catalog.tpch_mx"); + MaterializedResult expectedRow = resultBuilder(getSession(), createVarcharType(50)) + .row("MXTEST") + .row("mxtest") + .build(); + + assertContains(actualRow, expectedRow); + } + + @Test + public void testShowColumns() + { + MaterializedResult actualRow = computeActual("SHOW columns FROM arrow_mixed_catalog.tpch_mx.mxtest"); + MaterializedResult expectedRow = resultBuilder(getSession(), createVarcharType(50)) + .row("ID", "integer", "", "", Long.valueOf(10), null, null) + .row("NAME", "varchar(50)", "", "", null, null, Long.valueOf(50)) + .row("name", "varchar(50)", "", "", null, null, Long.valueOf(50)) + .row("Address", "varchar(50)", "", "", null, null, Long.valueOf(50)) + .build(); + + assertContains(actualRow, expectedRow); + } + + @Test + public void testSelect() + { + MaterializedResult actualRow = computeActual("SELECT * from arrow_mixed_catalog.tpch_mx.mxtest"); + MaterializedResult expectedRow = resultBuilder(getSession(), INTEGER, createVarcharType(50), createVarcharType(50), createVarcharType(50)) + .row(1, "TOM", "test", "kochi") + .row(2, "MARY", "test", "kochi") + .build(); + assertTrue(actualRow.equals(expectedRow)); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java index 1bca76c23a9f0..025d119a938f3 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java @@ -26,34 +26,27 @@ import 
org.testng.annotations.BeforeClass; import java.io.File; -import java.io.IOException; public class TestArrowFlightIntegrationSmokeTest extends AbstractTestIntegrationSmokeTest { private static final Logger logger = Logger.get(TestArrowFlightIntegrationSmokeTest.class); - private final int serverPort; + private int serverPort; private RootAllocator allocator; private FlightServer server; private DistributedQueryRunner arrowFlightQueryRunner; - public TestArrowFlightIntegrationSmokeTest() - throws IOException - { - this.serverPort = ArrowFlightQueryRunner.findUnusedPort(); - } - @BeforeClass public void setup() throws Exception { arrowFlightQueryRunner = getDistributedQueryRunner(); - File certChainFile = new File("src/test/resources/server.crt"); - File privateKeyFile = new File("src/test/resources/server.key"); + File certChainFile = new File("src/test/resources/certs/server.crt"); + File privateKeyFile = new File("src/test/resources/certs/server.key"); allocator = new RootAllocator(Long.MAX_VALUE); Location location = Location.forGrpcTls("127.0.0.1", serverPort); - server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator)) + server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, false)) .useTls(certChainFile, privateKeyFile) .build(); @@ -65,6 +58,7 @@ public void setup() protected QueryRunner createQueryRunner() throws Exception { + serverPort = ArrowFlightQueryRunner.findUnusedPort(); return ArrowFlightQueryRunner.createQueryRunner(serverPort); } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightMtls.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightMtls.java new file mode 100644 index 0000000000000..4d02a94bc4b7b --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightMtls.java @@ -0,0 +1,140 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.plugin.arrow.testingServer.TestingArrowProducer; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin.ARROW_FLIGHT_CONNECTOR; + +public class TestArrowFlightMtls + extends AbstractTestQueryFramework +{ + private static final Logger logger = Logger.get(TestArrowFlightMtls.class); + private int serverPort; + private RootAllocator allocator; + private FlightServer server; + private DistributedQueryRunner arrowFlightQueryRunner; + private static final String ARROW_FLIGHT_CATALOG_WITH_INVALID_CERT = "arrow_catalog_with_invalid_cert"; + private static final String ARROW_FLIGHT_CATALOG_WITH_NO_MTLS_CERTS = "arrow_catalog_with_no_mtls_certs"; + private static final String ARROW_FLIGHT_CATALOG_WITH_MTLS_CERTS = "arrow_catalog_with_mtls_certs"; + + @BeforeClass + private void setup() + throws Exception + { + 
arrowFlightQueryRunner = getDistributedQueryRunner(); + arrowFlightQueryRunner.createCatalog(ARROW_FLIGHT_CATALOG_WITH_INVALID_CERT, ARROW_FLIGHT_CONNECTOR, getInvalidCertCatalogProperties()); + arrowFlightQueryRunner.createCatalog(ARROW_FLIGHT_CATALOG_WITH_NO_MTLS_CERTS, ARROW_FLIGHT_CONNECTOR, getNoMtlsCatalogProperties()); + arrowFlightQueryRunner.createCatalog(ARROW_FLIGHT_CATALOG_WITH_MTLS_CERTS, ARROW_FLIGHT_CONNECTOR, getMtlsCatalogProperties()); + + File certChainFile = new File("src/test/resources/certs/server.crt"); + File privateKeyFile = new File("src/test/resources/certs/server.key"); + File caCertFile = new File("src/test/resources/certs/ca.crt"); + + allocator = new RootAllocator(Long.MAX_VALUE); + + Location location = Location.forGrpcTls("localhost", serverPort); + server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, false)) + .useTls(certChainFile, privateKeyFile) + .useMTlsClientVerification(caCertFile) + .build(); + + server.start(); + logger.info("Server listening on port %s", server.getPort()); + } + + @AfterClass(alwaysRun = true) + private void tearDown() + throws InterruptedException + { + arrowFlightQueryRunner.close(); + server.close(); + allocator.close(); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + serverPort = ArrowFlightQueryRunner.findUnusedPort(); + return ArrowFlightQueryRunner.createQueryRunner(serverPort, ImmutableMap.of(), ImmutableMap.of(), Optional.empty(), Optional.empty()); + } + + private Map getInvalidCertCatalogProperties() + { + ImmutableMap.Builder catalogProperties = ImmutableMap.builder() + .put("arrow-flight.server.port", String.valueOf(serverPort)) + .put("arrow-flight.server", "localhost") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server-ssl-certificate", "src/test/resources/certs/server.crt") + .put("arrow-flight.client-ssl-certificate", "src/test/resources/certs/invalid_cert.crt") + 
.put("arrow-flight.client-ssl-key", "src/test/resources/certs/client.key"); + return catalogProperties.build(); + } + + private Map getNoMtlsCatalogProperties() + { + ImmutableMap.Builder catalogProperties = ImmutableMap.builder() + .put("arrow-flight.server.port", String.valueOf(serverPort)) + .put("arrow-flight.server", "localhost") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server-ssl-certificate", "src/test/resources/certs/server.crt"); + return catalogProperties.build(); + } + + private Map getMtlsCatalogProperties() + { + ImmutableMap.Builder catalogProperties = ImmutableMap.builder() + .put("arrow-flight.server.port", String.valueOf(serverPort)) + .put("arrow-flight.server", "localhost") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server-ssl-certificate", "src/test/resources/certs/server.crt") + .put("arrow-flight.client-ssl-certificate", "src/test/resources/certs/client.crt") + .put("arrow-flight.client-ssl-key", "src/test/resources/certs/client.key"); + return catalogProperties.build(); + } + + @Test + public void testMtlsInvalidCert() + { + assertQueryFails("SELECT COUNT(*) FROM " + ARROW_FLIGHT_CATALOG_WITH_INVALID_CERT + ".tpch.orders", ".*invalid certificate file.*"); + } + + @Test + public void testMtlsFailure() + { + assertQueryFails("SELECT COUNT(*) FROM " + ARROW_FLIGHT_CATALOG_WITH_NO_MTLS_CERTS + ".tpch.orders", "ssl exception"); + } + + @Test + public void testMtls() + { + assertQuery("SELECT COUNT(*) FROM " + ARROW_FLIGHT_CATALOG_WITH_MTLS_CERTS + ".tpch.orders", "SELECT COUNT(*) FROM orders"); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueries.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueries.java index 21ee7f53ed78a..9e2225d09b6a3 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueries.java +++ 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueries.java @@ -49,15 +49,14 @@ public class TestArrowFlightNativeQueries extends AbstractTestQueryFramework { private static final Logger log = Logger.get(TestArrowFlightNativeQueries.class); - private final int serverPort; + private int serverPort; private RootAllocator allocator; private FlightServer server; private DistributedQueryRunner arrowFlightQueryRunner; - public TestArrowFlightNativeQueries() - throws IOException + protected boolean ismTLSEnabled() { - this.serverPort = ArrowFlightQueryRunner.findUnusedPort(); + return false; } @BeforeClass @@ -65,18 +64,22 @@ public void setup() throws Exception { arrowFlightQueryRunner = getDistributedQueryRunner(); - allocator = new RootAllocator(Long.MAX_VALUE); Location location = Location.forGrpcTls("localhost", serverPort); - File certChainFile = new File("src/test/resources/server.crt"); - File privateKeyFile = new File("src/test/resources/server.key"); + FlightServer.Builder serverBuilder = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, false)); - server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator)) - .useTls(certChainFile, privateKeyFile) - .build(); + File serverCert = new File("src/test/resources/certs/server.crt"); + File serverKey = new File("src/test/resources/certs/server.key"); + serverBuilder.useTls(serverCert, serverKey); + + if (ismTLSEnabled()) { + File caCert = new File("src/test/resources/certs/ca.crt"); + serverBuilder.useMTlsClientVerification(caCert); + } + server = serverBuilder.build(); server.start(); - log.info("Server listening on port %s", server.getPort()); + log.info("Server listening on port %s (%s)", server.getPort(), ismTLSEnabled() ? 
"mTLS" : "TLS"); } @AfterClass(alwaysRun = true) @@ -99,9 +102,14 @@ protected QueryRunner createQueryRunner() log.info("Using PRESTO_SERVER binary at %s", prestoServerPath); ImmutableMap coordinatorProperties = ImmutableMap.of("native-execution-enabled", "true"); - String flightCertPath = Paths.get("src/test/resources/server.crt").toAbsolutePath().toString(); - return ArrowFlightQueryRunner.createQueryRunner(serverPort, getNativeWorkerSystemProperties(), coordinatorProperties, getExternalWorkerLauncher(prestoServerPath.toString(), serverPort, flightCertPath)); + serverPort = ArrowFlightQueryRunner.findUnusedPort(); + return ArrowFlightQueryRunner.createQueryRunner( + serverPort, + getNativeWorkerSystemProperties(), + coordinatorProperties, + getExternalWorkerLauncher(prestoServerPath.toString(), serverPort, ismTLSEnabled()), + Optional.of(ismTLSEnabled())); } @Override @@ -110,6 +118,18 @@ protected FeaturesConfig createFeaturesConfig() return new FeaturesConfig().setNativeExecutionEnabled(true); } + @Test + public void testQueryFunctionWithRestrictedColumns() + { + assertQuery("SELECT NAME FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME FROM nation WHERE NATIONKEY = 4"); + } + + @Test + public void testQueryFunctionWithoutRestrictedColumns() throws InterruptedException + { + assertQuery("SELECT NAME, NATIONKEY FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME, NATIONKEY FROM nation WHERE NATIONKEY = 4"); + } + @Test public void testFiltersAndProjections1() { @@ -338,50 +358,63 @@ public static Map getNativeWorkerSystemProperties() .build(); } - public static Optional> getExternalWorkerLauncher(String prestoServerPath, int flightServerPort, String flightCertPath) + public static Optional> getExternalWorkerLauncher(String prestoServerPath, int flightServerPort, boolean 
ismTLSEnabled) { - return - Optional.of((workerIndex, discoveryUri) -> { - try { - Path dir = Paths.get("/tmp", TestArrowFlightNativeQueries.class.getSimpleName()); - Files.createDirectories(dir); - Path tempDirectoryPath = Files.createTempDirectory(dir, "worker"); - log.info("Temp directory for Worker #%d: %s", workerIndex, tempDirectoryPath.toString()); - - // Write config file - use an ephemeral port for the worker. - String configProperties = format("discovery.uri=%s%n" + - "presto.version=testversion%n" + - "system-memory-gb=4%n" + - "http-server.http.port=0%n", discoveryUri); - - Files.write(tempDirectoryPath.resolve("config.properties"), configProperties.getBytes()); - Files.write(tempDirectoryPath.resolve("node.properties"), - format("node.id=%s%n" + - "node.internal-address=127.0.0.1%n" + - "node.environment=testing%n" + - "node.location=test-location", UUID.randomUUID()).getBytes()); - - Path catalogDirectoryPath = tempDirectoryPath.resolve("catalog"); - Files.createDirectory(catalogDirectoryPath); - - Files.write(catalogDirectoryPath.resolve(format("%s.properties", ARROW_FLIGHT_CATALOG)), - format("connector.name=%s\n" + - "arrow-flight.server=localhost\n" + - "arrow-flight.server.port=%d\n" + - "arrow-flight.server-ssl-enabled=true\n" + - "arrow-flight.server-ssl-certificate=%s", ARROW_FLIGHT_CONNECTOR, flightServerPort, flightCertPath).getBytes()); - - // Disable stack trace capturing as some queries (using TRY) generate a lot of exceptions. - return new ProcessBuilder(prestoServerPath, "--logtostderr=1", "--v=1") - .directory(tempDirectoryPath.toFile()) - .redirectErrorStream(true) - .redirectOutput(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("worker." + workerIndex + ".out").toFile())) - .redirectError(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("worker." 
+ workerIndex + ".out").toFile())) - .start(); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - }); + return Optional.of((workerIndex, discoveryUri) -> { + try { + Path dir = Paths.get("/tmp", TestArrowFlightNativeQueries.class.getSimpleName()); + Files.createDirectories(dir); + Path tempDirectoryPath = Files.createTempDirectory(dir, "worker"); + log.info("Temp directory for Worker #%d: %s", workerIndex, tempDirectoryPath.toString()); + + // Write config file - use an ephemeral port for the worker. + String configProperties = format("discovery.uri=%s%n" + + "presto.version=testversion%n" + + "system-memory-gb=4%n" + + "http-server.http.port=0%n", discoveryUri); + + Files.write(tempDirectoryPath.resolve("config.properties"), configProperties.getBytes()); + Files.write(tempDirectoryPath.resolve("node.properties"), + format("node.id=%s%n" + + "node.internal-address=127.0.0.1%n" + + "node.environment=testing%n" + + "node.location=test-location", UUID.randomUUID()).getBytes()); + + Path catalogDirectoryPath = tempDirectoryPath.resolve("catalog"); + Files.createDirectory(catalogDirectoryPath); + + String caCertPath = Paths.get("src/test/resources/certs/ca.crt").toAbsolutePath().toString(); + + StringBuilder catalogBuilder = new StringBuilder(); + catalogBuilder.append(format( + "connector.name=%s\n" + + "arrow-flight.server=localhost\n" + + "arrow-flight.server.port=%d\n" + + "arrow-flight.server-ssl-enabled=true\n" + + "arrow-flight.server-ssl-certificate=%s\n", + ARROW_FLIGHT_CONNECTOR, flightServerPort, caCertPath)); + + if (ismTLSEnabled) { + String clientCertPath = Paths.get("src/test/resources/certs/client.crt").toAbsolutePath().toString(); + String clientKeyPath = Paths.get("src/test/resources/certs/client.key").toAbsolutePath().toString(); + catalogBuilder.append(format("arrow-flight.client-ssl-certificate=%s\n", clientCertPath)); + catalogBuilder.append(format("arrow-flight.client-ssl-key=%s\n", clientKeyPath)); + } + + Files.write( + 
catalogDirectoryPath.resolve(format("%s.properties", ARROW_FLIGHT_CATALOG)), + catalogBuilder.toString().getBytes()); + + return new ProcessBuilder(prestoServerPath, "--logtostderr=1", "--v=1") + .directory(tempDirectoryPath.toFile()) + .redirectErrorStream(true) + .redirectOutput(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("worker." + workerIndex + ".out").toFile())) + .redirectError(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("worker." + workerIndex + ".out").toFile())) + .start(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + }); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueriesMtls.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueriesMtls.java new file mode 100644 index 0000000000000..25079773e7c02 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightNativeQueriesMtls.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.plugin.arrow; + +import java.io.IOException; + +public class TestArrowFlightNativeQueriesMtls + extends TestArrowFlightNativeQueries +{ + public TestArrowFlightNativeQueriesMtls() + throws IOException + { + super(); + } + + @Override + protected boolean ismTLSEnabled() + { + return true; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java index a0ba11ccde53f..2bdf8508b8a1b 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java @@ -29,7 +29,6 @@ import org.testng.annotations.Test; import java.io.File; -import java.io.IOException; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; @@ -37,6 +36,7 @@ import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.CharType.createCharType; import static com.facebook.presto.common.type.DateType.DATE; import static com.facebook.presto.common.type.IntegerType.INTEGER; @@ -53,28 +53,22 @@ public class TestArrowFlightQueries extends AbstractTestQueries { private static final Logger logger = Logger.get(TestArrowFlightQueries.class); - private final int serverPort; + private int serverPort; private RootAllocator allocator; private FlightServer server; private DistributedQueryRunner arrowFlightQueryRunner; - public TestArrowFlightQueries() - throws IOException - { - this.serverPort = ArrowFlightQueryRunner.findUnusedPort(); - } - @BeforeClass public void setup() throws Exception { arrowFlightQueryRunner = getDistributedQueryRunner(); - File certChainFile = new File("src/test/resources/server.crt"); - File privateKeyFile = new 
File("src/test/resources/server.key"); + File certChainFile = new File("src/test/resources/certs/server.crt"); + File privateKeyFile = new File("src/test/resources/certs/server.key"); allocator = new RootAllocator(Long.MAX_VALUE); Location location = Location.forGrpcTls("localhost", serverPort); - server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator)) + server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, false)) .useTls(certChainFile, privateKeyFile) .build(); @@ -95,6 +89,7 @@ public void close() protected QueryRunner createQueryRunner() throws Exception { + serverPort = ArrowFlightQueryRunner.findUnusedPort(); return ArrowFlightQueryRunner.createQueryRunner(serverPort); } @@ -103,18 +98,18 @@ public void testShowCharColumns() { MaterializedResult actual = computeActual("SHOW COLUMNS FROM member"); - MaterializedResult expectedUnparametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("id", "integer", "", "") - .row("name", "varchar", "", "") - .row("sex", "char", "", "") - .row("state", "char", "", "") + MaterializedResult expectedUnparametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("id", "integer", "", "", 10L, null, null) + .row("name", "varchar", "", "", null, null, 2147483647L) + .row("sex", "char", "", "", null, null, 2147483647L) + .row("state", "char", "", "", null, null, 2147483647L) .build(); - MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("id", "integer", "", "") - .row("name", "varchar(50)", "", "") - .row("sex", "char(1)", "", "") - .row("state", "char(5)", "", "") + MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("id", "integer", "", "", 10L, null, null) + .row("name", "varchar(50)", "", "", null, null, 50L) + 
.row("sex", "char(1)", "", "", null, null, 1L) + .row("state", "char(5)", "", "", null, null, 5L) .build(); assertTrue(actual.equals(expectedParametrizedVarchar) || actual.equals(expectedUnparametrizedVarchar), @@ -145,16 +140,56 @@ public void testSelectTime() assertTrue(actualRow.equals(expectedRow)); } + @Test + public void testSystemJdbcColumns() + { + MaterializedResult actualRow = computeActual("SELECT * from system.jdbc.columns"); + assertTrue(actualRow.getRowCount() > 0); + } + + @Test + public void testSystemJdbcTables() + { + MaterializedResult actualRow = computeActual("SELECT * from system.jdbc.tables"); + assertTrue(actualRow.getRowCount() > 0); + } + @Test public void testDescribeUnknownTable() { MaterializedResult actualRows = computeActual("DESCRIBE information_schema.enabled_roles"); - MaterializedResult expectedRows = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("role_name", "varchar", "", "") + MaterializedResult expectedRows = resultBuilder(getSession(), + VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("role_name", "varchar", "", "", null, null, 2147483647L) .build(); + assertEquals(actualRows, expectedRows); } + @Test + public void testQueryFunctionWithRestrictedColumns() + { + assertQuery("SELECT NAME FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NAME FROM nation WHERE NATIONKEY = 4"); + } + + @Test + public void testQueryFunctionWithoutRestrictedColumns() + { + assertQuery("SELECT NATIONKEY, NAME FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "SELECT NATIONKEY, NAME FROM nation WHERE NATIONKEY = 4"); + } + + @Test + public void testQueryFunctionWithDifferentColumnOrder() + { + assertQuery("SELECT NAME, NATIONKEY FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY 
BIGINT, NAME VARCHAR'))", "SELECT NAME, NATIONKEY FROM nation WHERE NATIONKEY = 4"); + } + + @Test + public void testQueryFunctionWithInvalidColumn() + { + assertQueryFails("SELECT NAME, NATIONKEY, INVALID_COLUMN FROM TABLE(system.query_function('SELECT NATIONKEY, NAME FROM tpch.nation WHERE NATIONKEY = 4','NATIONKEY BIGINT, NAME VARCHAR'))", "Column 'invalid_column' cannot be resolved", true); + } + private LocalDate getDate(String dateString) { DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd"); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/PrimitiveToPrestoTypeMappings.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/PrimitiveToPrestoTypeMappings.java new file mode 100644 index 0000000000000..a82a7df382aad --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/PrimitiveToPrestoTypeMappings.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow.testingConnector; + +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; + +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; + +public final class PrimitiveToPrestoTypeMappings +{ + private PrimitiveToPrestoTypeMappings() + { + throw new UnsupportedOperationException(); + } + + public static Type fromPrimitiveToPrestoType(String dataType) + { + switch (dataType) { + case "INTEGER": + return IntegerType.INTEGER; + case "VARCHAR": + return createUnboundedVarcharType(); + case "DOUBLE": + return DoubleType.DOUBLE; + case "SMALLINT": + return SmallintType.SMALLINT; + case "BOOLEAN": + return BooleanType.BOOLEAN; + case "TIMESTAMP": + return TimestampType.TIMESTAMP; + case "TIME": + return TimeType.TIME; + case "REAL": + return RealType.REAL; + case "DATE": + return DateType.DATE; + case "BIGINT": + return BigintType.BIGINT; + } + throw new PrestoException(NOT_SUPPORTED, "Unsupported datatype '" + dataType + "' in the selected table."); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowBlockBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowBlockBuilder.java index a8752e0690f5e..f42941a96387c 100644 --- 
a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowBlockBuilder.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowBlockBuilder.java @@ -19,10 +19,9 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.common.type.VarcharType; +import jakarta.inject.Inject; import org.apache.arrow.vector.types.pojo.Field; -import javax.inject.Inject; - import java.util.Optional; public class TestingArrowBlockBuilder diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowConnector.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowConnector.java new file mode 100644 index 0000000000000..5099a596f1f54 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowConnector.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow.testingConnector; + +import com.facebook.plugin.arrow.ArrowConnector; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.google.common.collect.ImmutableSet; +import com.google.inject.Inject; +import org.apache.arrow.memory.BufferAllocator; + +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +public class TestingArrowConnector + extends ArrowConnector +{ + private final Set connectorTableFunctions; + + @Inject + public TestingArrowConnector(ConnectorMetadata metadata, ConnectorSplitManager splitManager, ConnectorPageSourceProvider pageSourceProvider, Set connectorTableFunctions, BufferAllocator allocator) + { + super(metadata, splitManager, pageSourceProvider, allocator); + this.connectorTableFunctions = ImmutableSet.copyOf(requireNonNull(connectorTableFunctions, "connectorTableFunctions is null")); + } + + @Override + public Set getTableFunctions() + { + return connectorTableFunctions; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightClientHandler.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightClientHandler.java index 4d35cab2cc1bb..27a09d2efe679 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightClientHandler.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.SchemaTableName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import jakarta.inject.Inject; import org.apache.arrow.flight.Action; import 
org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.CallOptions; @@ -34,8 +35,6 @@ import org.apache.arrow.flight.grpc.CredentialCallOption; import org.apache.arrow.memory.BufferAllocator; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -43,7 +42,7 @@ import java.util.concurrent.TimeUnit; import static com.facebook.presto.common.Utils.checkArgument; -import static java.util.Locale.ENGLISH; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; public class TestingArrowFlightClientHandler @@ -51,6 +50,7 @@ public class TestingArrowFlightClientHandler { private final JsonCodec requestCodec; private final JsonCodec responseCodec; + private boolean caseSensitiveNameMatchingEnabled; @Inject public TestingArrowFlightClientHandler( @@ -62,6 +62,7 @@ public TestingArrowFlightClientHandler( super(allocator, config); this.requestCodec = requireNonNull(requestCodec, "requestCodec is null"); this.responseCodec = requireNonNull(responseCodec, "responseCodec is null"); + this.caseSensitiveNameMatchingEnabled = config.isCaseSensitiveNameMatching(); } @Override @@ -74,7 +75,7 @@ public CallOption[] getCallOptions(ConnectorSession connectorSession) } @Override - public FlightDescriptor getFlightDescriptorForSchema(String schemaName, String tableName) + public FlightDescriptor getFlightDescriptorForSchema(ConnectorSession session, String schemaName, String tableName) { TestingArrowFlightRequest request = TestingArrowFlightRequest.createDescribeTableRequest(schemaName, tableName); return FlightDescriptor.command(requestCodec.toBytes(request)); @@ -103,7 +104,7 @@ public List listSchemaNames(ConnectorSession session) List listSchemas = res; List names = new ArrayList<>(); for (String value : listSchemas) { - names.add(value.toLowerCase(ENGLISH)); + names.add(normalizeIdentifier(value)); } return ImmutableList.copyOf(names); } @@ -132,22 +133,40 @@ public List 
listTables(ConnectorSession session, Optional listTables = res; List tables = new ArrayList<>(); for (String value : listTables) { - tables.add(new SchemaTableName(schemaValue.toLowerCase(ENGLISH), value.toLowerCase(ENGLISH))); + tables.add(new SchemaTableName(normalizeIdentifier(schemaValue), normalizeIdentifier(value))); } return tables; } @Override - public FlightDescriptor getFlightDescriptorForTableScan(ArrowTableLayoutHandle tableLayoutHandle) + public FlightDescriptor getFlightDescriptorForTableScan(ConnectorSession session, ArrowTableLayoutHandle tableLayoutHandle) { ArrowTableHandle tableHandle = tableLayoutHandle.getTable(); - String query = new TestingArrowQueryBuilder().buildSql( - tableHandle.getSchema(), - tableHandle.getTable(), - tableLayoutHandle.getColumnHandles(), ImmutableMap.of(), - tableLayoutHandle.getTupleDomain()); - TestingArrowFlightRequest request = TestingArrowFlightRequest.createQueryRequest(tableHandle.getSchema(), tableHandle.getTable(), query); + + String query; + String table; + + if (tableHandle instanceof TestingQueryArrowTableHandle) { + TestingQueryArrowTableHandle testingQueryArrowTableHandle = (TestingQueryArrowTableHandle) tableHandle; + query = testingQueryArrowTableHandle.getQuery(); + table = null; + } + else { + query = new TestingArrowQueryBuilder().buildSql( + tableHandle.getSchema(), + tableHandle.getTable(), + tableLayoutHandle.getColumnHandles(), ImmutableMap.of(), + tableLayoutHandle.getTupleDomain()); + table = tableHandle.getTable(); + } + + TestingArrowFlightRequest request = TestingArrowFlightRequest.createQueryRequest(tableHandle.getSchema(), table, query); return FlightDescriptor.command(requestCodec.toBytes(request)); } + + private String normalizeIdentifier(String identifier) + { + return caseSensitiveNameMatchingEnabled ? 
identifier : identifier.toLowerCase(ROOT); + } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowMetadata.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowMetadata.java new file mode 100644 index 0000000000000..a7eee75abf487 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowMetadata.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow.testingConnector; + +import com.facebook.plugin.arrow.ArrowBlockBuilder; +import com.facebook.plugin.arrow.ArrowColumnHandle; +import com.facebook.plugin.arrow.ArrowFlightConfig; +import com.facebook.plugin.arrow.ArrowMetadata; +import com.facebook.plugin.arrow.BaseArrowFlightClientHandler; +import com.facebook.plugin.arrow.testingConnector.tvf.QueryFunctionProvider; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import jakarta.inject.Inject; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +public class TestingArrowMetadata + extends ArrowMetadata +{ + @Inject + public TestingArrowMetadata(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder, ArrowFlightConfig config) + { + super(clientHandler, arrowBlockBuilder, config); + } + + @Override + public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) + { + if (handle instanceof QueryFunctionProvider.QueryFunctionHandle) { + QueryFunctionProvider.QueryFunctionHandle functionHandle = (QueryFunctionProvider.QueryFunctionHandle) handle; + return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), new ArrayList<>(functionHandle.getTableHandle().getColumns()))); + } + return Optional.empty(); + } + + @Override + public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) + { + if (tableHandle instanceof TestingQueryArrowTableHandle) { + TestingQueryArrowTableHandle 
queryArrowTableHandle = (TestingQueryArrowTableHandle) tableHandle; + return queryArrowTableHandle.getColumns().stream().collect(Collectors.toMap(c -> normalizeIdentifier(session, c.getColumnName()), c -> c)); + } + else { + return super.getColumnHandles(session, tableHandle); + } + } + + @Override + public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) + { + if (tableHandle instanceof TestingQueryArrowTableHandle) { + TestingQueryArrowTableHandle queryArrowTableHandle = (TestingQueryArrowTableHandle) tableHandle; + + List meta = new ArrayList<>(); + for (ArrowColumnHandle columnHandle : queryArrowTableHandle.getColumns()) { + meta.add(ColumnMetadata.builder().setName(normalizeIdentifier(session, columnHandle.getColumnName())).setType(columnHandle.getColumnType()).build()); + } + return new ConnectorTableMetadata(new SchemaTableName(queryArrowTableHandle.getSchema(), queryArrowTableHandle.getTable()), meta); + } + else { + return super.getTableMetadata(session, tableHandle); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java index 04e1ec34f4d2a..4856327633609 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java @@ -14,14 +14,19 @@ package com.facebook.plugin.arrow.testingConnector; import com.facebook.plugin.arrow.ArrowBlockBuilder; +import com.facebook.plugin.arrow.ArrowConnector; import com.facebook.plugin.arrow.BaseArrowFlightClientHandler; +import com.facebook.plugin.arrow.testingConnector.tvf.QueryFunctionProvider; import com.facebook.plugin.arrow.testingServer.TestingArrowFlightRequest; import 
com.facebook.plugin.arrow.testingServer.TestingArrowFlightResponse; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.google.inject.Binder; import com.google.inject.Module; import com.google.inject.Scopes; import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.google.inject.multibindings.Multibinder.newSetBinder; public class TestingArrowModule implements Module @@ -36,6 +41,9 @@ public TestingArrowModule(boolean nativeExecution) @Override public void configure(Binder binder) { + binder.bind(ConnectorMetadata.class).to(TestingArrowMetadata.class).in(Scopes.SINGLETON); + newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(QueryFunctionProvider.class).in(Scopes.SINGLETON); + binder.bind(ArrowConnector.class).to(TestingArrowConnector.class).in(Scopes.SINGLETON); // Concrete implementation of the BaseFlightClientHandler binder.bind(BaseArrowFlightClientHandler.class).to(TestingArrowFlightClientHandler.class).in(Scopes.SINGLETON); // Override the ArrowBlockBuilder with an implementation that handles h2 types, skip for native diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingQueryArrowTableHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingQueryArrowTableHandle.java new file mode 100644 index 0000000000000..0590936563d3a --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingQueryArrowTableHandle.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow.testingConnector; + +import com.facebook.plugin.arrow.ArrowColumnHandle; +import com.facebook.plugin.arrow.ArrowTableHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Collections; +import java.util.List; +import java.util.UUID; + +import static java.util.Objects.requireNonNull; +public class TestingQueryArrowTableHandle + extends ArrowTableHandle +{ + private final String query; + private final List columns; + + @JsonCreator + public TestingQueryArrowTableHandle(String query, List columns) + { + super("schema-" + UUID.randomUUID(), "table-" + UUID.randomUUID()); + this.columns = Collections.unmodifiableList(requireNonNull(columns)); + this.query = requireNonNull(query); + } + + @JsonProperty + public String getQuery() + { + return query; + } + + @JsonProperty + public List getColumns() + { + return columns; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/tvf/QueryFunctionProvider.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/tvf/QueryFunctionProvider.java new file mode 100644 index 0000000000000..e3c930926895b --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/tvf/QueryFunctionProvider.java @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow.testingConnector.tvf; + +import com.facebook.plugin.arrow.ArrowColumnHandle; +import com.facebook.plugin.arrow.testingConnector.PrimitiveToPrestoTypeMappings; +import com.facebook.plugin.arrow.testingConnector.TestingQueryArrowTableHandle; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; + +import javax.inject.Provider; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static com.facebook.presto.common.type.VarcharType.VARCHAR; 
+import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static java.util.Objects.requireNonNull; + +public class QueryFunctionProvider + implements Provider +{ + private static final String SYSTEM = "system"; + private static final String QUERY_FUNCTION = "query_function"; + private static final String QUERY = "QUERY"; + private static final String DATATYPES = "DATATYPES"; + + @Override + public ConnectorTableFunction get() + { + return new QueryFunction(); + } + + public static class QueryFunction + extends AbstractConnectorTableFunction + { + public QueryFunction() + { + super( + SYSTEM, + QUERY_FUNCTION, + Arrays.asList( + ScalarArgumentSpecification.builder() + .name(QUERY) + .type(VARCHAR) + .build(), + ScalarArgumentSpecification.builder() + .name(DATATYPES) + .type(VARCHAR) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze( + ConnectorSession session, + ConnectorTransactionHandle transaction, + Map arguments) + { + Slice dataTypes = (Slice) ((ScalarArgument) arguments.get(DATATYPES)).getValue(); + List columnHandles = ImmutableList.copyOf(extractColumnParameters(dataTypes.toStringUtf8())); + + // preparing descriptor from column handles + Descriptor returnedType = new Descriptor(columnHandles.stream() + .map(column -> new Descriptor.Field(column.getColumnName(), Optional.of(column.getColumnType()))) + .collect(Collectors.toList())); + + Slice query = (Slice) ((ScalarArgument) arguments.get(QUERY)).getValue(); + + TestingQueryArrowTableHandle queryArrowTableHandle = new TestingQueryArrowTableHandle(query.toStringUtf8(), columnHandles); + QueryFunctionHandle handle = new QueryFunctionHandle(queryArrowTableHandle); + + return TableFunctionAnalysis.builder() + .returnedType(returnedType) + .handle(handle) + .build(); + } + } + + public static class QueryFunctionHandle + implements 
ConnectorTableFunctionHandle + { + private final TestingQueryArrowTableHandle tableHandle; + + @JsonCreator + public QueryFunctionHandle(@JsonProperty("tableHandle") TestingQueryArrowTableHandle tableHandle) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + } + + @JsonProperty + public TestingQueryArrowTableHandle getTableHandle() + { + return tableHandle; + } + } + + private static List extractColumnParameters(String input) + { + String regex = "\\s*([\\w]+)\\s+([\\w ]+)(?:\\((\\d+)(?:,(\\d+))?\\))?\\s*"; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(input); + + List columnHandles = new ArrayList<>(); + + while (matcher.find()) { + //map columnType to presto type + requireNonNull(matcher.group(2), "Column data type is null"); + Type prestoType = PrimitiveToPrestoTypeMappings.fromPrimitiveToPrestoType(matcher.group(2).toUpperCase()); + if (prestoType == null) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Unsupported data type: " + matcher.group(2)); + } + columnHandles.add(new ArrowColumnHandle(matcher.group(1), prestoType)); + } + + return columnHandles; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightRequest.java index 31d6bf0d4e20d..98681a2dad9a1 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightRequest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightRequest.java @@ -52,7 +52,7 @@ public static TestingArrowFlightRequest createDescribeTableRequest(String schema public static TestingArrowFlightRequest createQueryRequest(String schema, String table, String query) { - return new TestingArrowFlightRequest(Optional.of(schema), Optional.of(table), Optional.of(query)); + return new 
TestingArrowFlightRequest(Optional.ofNullable(schema), Optional.ofNullable(table), Optional.ofNullable(query)); } @JsonProperty diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowProducer.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowProducer.java index a29ba4aaf8724..3a0c1548cf450 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowProducer.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowProducer.java @@ -73,8 +73,9 @@ public class TestingArrowProducer private static final Logger logger = Logger.get(TestingArrowProducer.class); private final JsonCodec requestCodec; private final JsonCodec responseCodec; + private boolean caseSensitiveNameMatchingEnabled; - public TestingArrowProducer(BufferAllocator allocator) throws Exception + public TestingArrowProducer(BufferAllocator allocator, boolean caseSensitiveNameMatchingEnabled) throws Exception { this.allocator = allocator; String h2JdbcUrl = "jdbc:h2:mem:testdb" + System.nanoTime() + "_" + ThreadLocalRandom.current().nextInt() + ";DB_CLOSE_DELAY=-1"; @@ -82,6 +83,7 @@ public TestingArrowProducer(BufferAllocator allocator) throws Exception this.connection = DriverManager.getConnection(h2JdbcUrl, "sa", ""); this.requestCodec = jsonCodec(TestingArrowFlightRequest.class); this.responseCodec = jsonCodec(TestingArrowFlightResponse.class); + this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled; } @Override @@ -100,7 +102,7 @@ public void getStream(CallContext callContext, Ticket ticket, ServerStreamListen logger.debug("Executing query: %s", query); - try (ResultSet resultSet = stmt.executeQuery(query.toUpperCase())) { + try (ResultSet resultSet = stmt.executeQuery(normalizeIdentifier(query))) { JdbcToArrowConfig config = new 
JdbcToArrowConfigBuilder().setAllocator(allocator).setTargetBatchSize(2048) .setCalendar(Calendar.getInstance(TimeZone.getDefault())).build(); Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), config); @@ -158,8 +160,8 @@ public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flight List fields = new ArrayList<>(); if (tableName.isPresent()) { String query = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS " + - "WHERE TABLE_SCHEMA='" + schemaName.toUpperCase() + "' " + - "AND TABLE_NAME='" + tableName.get().toUpperCase() + "'"; + "WHERE TABLE_SCHEMA='" + normalizeIdentifier(schemaName) + "' " + + "AND TABLE_NAME='" + normalizeIdentifier(tableName.get()) + "'"; try (ResultSet rs = connection.createStatement().executeQuery(query)) { while (rs.next()) { @@ -182,7 +184,7 @@ public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flight } } else if (selectStatement != null) { - selectStatement = selectStatement.toUpperCase(); + selectStatement = normalizeIdentifier(selectStatement); logger.debug("Executing SELECT query: %s", selectStatement); try (ResultSet rs = connection.createStatement().executeQuery(selectStatement)) { ResultSetMetaData metaData = rs.getMetaData(); @@ -232,7 +234,7 @@ public void doAction(CallContext callContext, Action action, StreamListener names = new ArrayList<>(); @@ -306,4 +308,9 @@ private ArrowType convertSqlTypeToArrowType(String sqlType, int precision, int s throw new IllegalArgumentException("Unsupported SQL type: " + sqlType); } } + + private String normalizeIdentifier(String identifier) + { + return caseSensitiveNameMatchingEnabled ? 
identifier : identifier.toUpperCase(); + } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingH2DatabaseSetup.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingH2DatabaseSetup.java index 0b8f8b5db9ef5..48bf90441200a 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingH2DatabaseSetup.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingH2DatabaseSetup.java @@ -193,6 +193,25 @@ public static void setup(String h2JdbcUrl) throws Exception " ) "); insertRows(tpchMetadata, SUPPLIER, handle); + stmt.execute("CREATE SCHEMA IF NOT EXISTS \"Tpch_Mx\""); + stmt.execute("CREATE SCHEMA IF NOT EXISTS \"tpch_mx\""); + + stmt.execute("CREATE TABLE \"tpch_mx\".\"mxtest\" (" + + " ID INTEGER PRIMARY KEY," + + " \"NAME\" VARCHAR(50)," + + " \"name\" VARCHAR(50)," + + " \"Address\" VARCHAR(50)" + + ")"); + + stmt.execute("INSERT INTO \"tpch_mx\".\"mxtest\" VALUES(1, 'TOM','test', 'kochi'),(2, 'MARY', 'test', 'kochi')"); + + stmt.execute("CREATE TABLE \"tpch_mx\".\"MXTEST\" (" + + " ID INTEGER PRIMARY KEY," + + " \"NAME\" VARCHAR(50)," + + " \"name\" VARCHAR(50)," + + " \"Address\" VARCHAR(50)" + + ")"); + ResultSet resultSet1 = stmt.executeQuery("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'TPCH'"); List tables = new ArrayList<>(); while (resultSet1.next()) { diff --git a/presto-base-arrow-flight/src/test/resources/certs/ca.crt b/presto-base-arrow-flight/src/test/resources/certs/ca.crt new file mode 100644 index 0000000000000..3d3f516aa0148 --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/certs/ca.crt @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIFozCCA4ugAwIBAgIUTnZhzGzxAzaTJ5DDQnmBp8jvBUMwDQYJKoZIhvcNAQEL +BQAwYDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJv 
+b3RDQTAgFw0yNTA1MjIxNjA2MDNaGA8yMTI1MDQyODE2MDYwM1owYDELMAkGA1UE +BhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MQ4wDAYDVQQKDAVN +eU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJvb3RDQTCCAiIwDQYJ +KoZIhvcNAQEBBQADggIPADCCAgoCggIBALmNRxPZt9vvFvIQcdCM3xcAOiTepahR +zsK86AwzXbwcWwOw+rQP0iXiMGjP9wJy0C1hU+j7F+gu5+d5j3DK9/iTlIqByjN+ +gffsZDSBBZBozJxARBWKaEFKRKbhCmr+v92ZN1KMYhhwc5W3SRp6y3kvsWYM4/JE +pRzYUy0505g3p7PnpQzZZSSzV33+9fn5ggafVB9FSJqOuBKUBTITt+Rf1pNV2ZGC +f2X39KOIIRz2Hz1L6hdqeMQN9V5KQ7C6oPfl+ho97JHaZirtHr1XZJ6YQRy7Mldp +jOPj4MMKqi8m+HhyCSznIDNiMp0OACX3CZ5RmKRYfnunOuw+fvI28HyOXMgfWKa2 +o/2/YlAzNrD50hJPwQMTlKWJm2gY2n5x2FWT6/8aXeCnsJALK6Dj6Ax0wxkAFfIo +76REHlXX2fIiz0cciYwYtvwhjp5efqX22B7LDkhu7fJ42yUd6g3crbmGM8OOoXgL +w3MWx30FatTyDT8un2ZvVDJEADW3+WWtyrZWHFMVFrVnN7Di4MAuWsRIVtuO6PEV +6pPS55NBmvKzWAoYBmpH103GlIxvZ6CCTeKqFbcrI77smrIO5CLmzl5yjb/urmy4 +GLYHM+EPB5pJKQO8g6fyM369mhEjBxt/RJGv9E0Jw7KmnHn5V2qSn4Yi41dUp7Cp +pQVIaYlJGn65AgMBAAGjUzBRMB0GA1UdDgQWBBRjt9KoJO8CO/Cy/xTJztrbdUlv +9DAfBgNVHSMEGDAWgBRjt9KoJO8CO/Cy/xTJztrbdUlv9DAPBgNVHRMBAf8EBTAD +AQH/MA0GCSqGSIb3DQEBCwUAA4ICAQApizBpSaZDYjBhOxD6+B5em18OIf5ZIjHk +iAhuppR+HaqyAHJgCO707dZVmFz9ECUZI9mvKcmj+h0Wh/mK4cSiDunFB9yUr67U +wV5F2/u/JAAq6VsbdrDiZPUwET8U9ai14LMEgPPF+Zif+wnopRav5lbiPoJVUjqr +wVoP2AHijIP46YCWwXqOTJMC79ccUMBZeDwF4bOquIADLmEnANp6fiMI+eE6OLFs +fDtjFqybRUZqzewv2lpzH2ZYEYk4bIk76TGkOYtrwJ+iQj77ZZFSBW5zkry/zaG3 +/5Ufjv65T9Zr1jmIMigcmCHwNsCLOYzIKjRaiuLGs4B9s8SGEauTRhP0dG5ndWPw +50NeSNJr37MHdKky44WAFlAk9BAKlghOaC5m2RyMof8DwYKPEe5epe1wBotiPqSX +doaZvch6wkuo8xvFKqH6rBTWJLMwuFt7m3XrGqGYlE+1gvuEfn7ZGzG00sl218mZ +MfsLqJfft92ARC1/qJvUFr5mM6SV4eQeTl1tAtv6Xfczr+3/iqc5gbeG25dXQclO +y1qIKthAoXFq6rAZ+bvfASiVV1OQS76nWSiYYS1dDPQJ/g4aawOkUYr9OjS5HNQ+ +rNAcLB+I0oaZQzZ85098qAVAJ76eFATb4ieDK6m0j6Fq5ddwrlYzxEEr7TscfHW2 +zPCtinUe0w== +-----END CERTIFICATE----- diff --git a/presto-base-arrow-flight/src/test/resources/certs/client.crt b/presto-base-arrow-flight/src/test/resources/certs/client.crt new file mode 100644 index 
0000000000000..0e2bc9ff7f662 --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/certs/client.crt @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEkDCCAnigAwIBAgIUUIByH9V7DhUf5n3Qd7tPxpPixQYwDQYJKoZIhvcNAQEL +BQAwYDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJv +b3RDQTAgFw0yNTA1MjIxNjA4MzFaGA8yMTI1MDQyODE2MDgzMVowXjELMAkGA1UE +BhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MQ4wDAYDVQQKDAVN +eU9yZzEPMA0GA1UECwwGTXlVbml0MQ8wDQYDVQQDDAZjbGllbnQwggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQC8DK/RdBF6I+k+DMGrjhMMBCnpPNwJtzJU +uXYcHFYEdBnHY/rpjk/fi+7jD8bppynCZPakrDX+5VIMzS4HBU/CHY26eR2ItiWq +DoDkPAlCdgeKIGNYYEvVSuUW5YQX6fuD8PfCpCP5zK7DJC2xTTsyEjBzD+MnIB7T +ja9/22Djo2Ib2l9BEBOD+k79caPFtSqDQVdS5JLJ/P7BeqGuFS8bEgtLwwCzRxPK +kG64rXb9F0IErGwjXi/70BA24EW0uGAzzeY5Pnlx5MulEIjuxII/dTuoo+/uirN0 +wxURKSzTMyzzoJqGX9Ng9+Z+VqWFrnqckcmFcD4T0sPoHpWZsYIHAgMBAAGjQjBA +MB0GA1UdDgQWBBRW/6j5rNwnwTBAXY5igXm3ETzv/DAfBgNVHSMEGDAWgBRjt9Ko +JO8CO/Cy/xTJztrbdUlv9DANBgkqhkiG9w0BAQsFAAOCAgEAVLpfZDgkL1Dz/+Fq +vSl2IoxOFNd2DTa6yM8/1wpvMVTA024lp0ttyoz0o1621hTRexcXTimZqUNAtPV6 +Gwmb2ACLN4XtLk9QT9XjDtWKzPxCJ+ze7rrhj1jYqv9yUebdkJoMKfcbwYi0gtpt +HlaJqNoKgzZxOCGhTtdS3ypb9nDCyx3fmFk5mIYfzEszoMmqNL006ANlJ0IKkFZj +vUkkFyMGLerInmTDRjDLkUCkNKaJUYjZhf/FNwVtc1A9a/bDJMEYVog+CY0dpXKb +1IGaXzB4ewhuQKuhb/LCZT1pNm/cGCY2cRGFy9EVAuZ5FV0ajh1HYwdpGDzucoUD +UaTHIK40E2/kZorJa37Xyn7Lekgun6YpfOudBkKg5mlV6qoL9W/lTZPjMHs2ufvW +/A5S4okR4JmhC44TMgAv90MU9yEP90OkzW6egatBShWySJ3Bn5W+ebQSwJ38wgTy +e6j5jWh7xiiPC4TJbSXVMGEfJw/c2hx4R/83MqBhVLPfoapaUCDUWniv6n7zl5ML +k7WIZzXSK212/H+eVXFJ1Gq6zztOPkN9QGgr+dbsCzdLWPJzZDmI7+lgT2EKxIV/ +VMtHVOe2bLkePTNA2+vXQhs0p7JDtzyATAyMdhJwPljt8X+HJxEoAP4Dk6PcK3Bd +E92yOuW2jl6FzgqKqtVQvxIiBgc= +-----END CERTIFICATE----- diff --git a/presto-base-arrow-flight/src/test/resources/certs/client.key b/presto-base-arrow-flight/src/test/resources/certs/client.key new file mode 100644 index 0000000000000..72434bc294aff --- 
/dev/null +++ b/presto-base-arrow-flight/src/test/resources/certs/client.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC8DK/RdBF6I+k+ +DMGrjhMMBCnpPNwJtzJUuXYcHFYEdBnHY/rpjk/fi+7jD8bppynCZPakrDX+5VIM +zS4HBU/CHY26eR2ItiWqDoDkPAlCdgeKIGNYYEvVSuUW5YQX6fuD8PfCpCP5zK7D +JC2xTTsyEjBzD+MnIB7Tja9/22Djo2Ib2l9BEBOD+k79caPFtSqDQVdS5JLJ/P7B +eqGuFS8bEgtLwwCzRxPKkG64rXb9F0IErGwjXi/70BA24EW0uGAzzeY5Pnlx5Mul +EIjuxII/dTuoo+/uirN0wxURKSzTMyzzoJqGX9Ng9+Z+VqWFrnqckcmFcD4T0sPo +HpWZsYIHAgMBAAECggEAViw6JXFa0O3D5HtUBJmGgOsniYoqCwm4NrsGNLuHb2ME +rSpTwNNGJtqpDcQdEtVXfY1muO9xjuznPJaJkQ4ODpYcbGcz8YIGoHck+XHJjHsp +2VIeNFFsbsFzWZqzfYHrj/rMjpVJJx90tlfN2IHbroZHTXLqVPOTLL6wvZZ6P9XF +zqpWABKOaEDbenhSFFZeF5KR3NG9HSTm4YLuekumkH+QgrveDfDwXG4hAHqg836o +OF3NPaij6VlSR18nuyW0wMs/Ceu13P+GALqHmz98pFyVgHWQFryL9IccvJQDyEnt +saeG4IAVlJbZDGTnRgANLhpwBr7XhMG1aK+wmOMRgQKBgQDkcatiATlr9L8gfnHb +6pmX//AZLdXuQLXfuTvu638Brhm770noLgfIC+HIp5kCHxT2Xj5Vn+MSnYD6R6Wh +chApRKJUdsuz1iOq23YJjvsSLWCGpl9IxR7WY27uGOPIjQcOd1PRbkCq9AgUJwyn +ryca3sbYh/XQOWGLbJNIQs/S/QKBgQDSu6PVeMaS3276KblvGIvvaSAQDQWxXcC+ +sA4CBmvjzx3xx5GAox/w7tcKmK/KQxNhaYy6N7xLc1YUJ9FbnT2PZQJhtP2d2Gat +Zre/+Qa+u84cR5hj9EI+B8FjW7D/psEj16KjHCds/SET6ngPM+RdB4N9daVFCurt +p0f717yiUwKBgBTJDun06I+dDkLbnmp/FwiQff0cgYmTE7lOdliPzteNSsQhypy4 +i3a1Ng72yOI7h8G+43cQ/C02bYTYPgbJhRTsLMT4piIvysECBORrwQZvYIf/3U2W +ue6Rz4cUdq1Jv6meS98TZAjp+U40G1+qfSlhub/75u7SOcDg2SnLAnPVAoGBAIOO +EmRE5qpwA+b2P0Ykq89E8Hg0uPYWEiq427XV7mqkNQxoSuRkcZ9Ga0a5NRzurN2m +N+1UuB7eHMGubdtkmTa4lzkJ9T4iB09/DX0x6E0QD0bGR1M2/FefHdJ6PlAK+Q34 +Ixbyj4ZRq+G0AUl0Wr7c3vBmjktA2pKMWLrW3nLzAoGBAKTl7qX6CD42gAJuT5Hp +rrXqlppVIyRvuXzXtX/Xq81IUHlBgS/t9HPyqDzmTKfxD8540kI+15bWPDHSJxiQ +ccqPaKyXhBXstDwGmlPKVzJUxk0dz5NHs+8gItUDOg78pM3siXN7vW9XBCH7mCDA +4zet/C0YCAiFVT+ipMoXy8Nc +-----END PRIVATE KEY----- diff --git a/presto-base-arrow-flight/src/test/resources/certs/invalid_cert.crt b/presto-base-arrow-flight/src/test/resources/certs/invalid_cert.crt new file mode 100644 
index 0000000000000..e69de29bb2d1d diff --git a/presto-base-arrow-flight/src/test/resources/certs/server.crt b/presto-base-arrow-flight/src/test/resources/certs/server.crt new file mode 100644 index 0000000000000..8fb80c1bba51d --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/certs/server.crt @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEkzCCAnugAwIBAgIUUIByH9V7DhUf5n3Qd7tPxpPixQgwDQYJKoZIhvcNAQEL +BQAwYDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJv +b3RDQTAgFw0yNTA1MjMwMTQzNTBaGA8yMTI1MDQyOTAxNDM1MFowYTELMAkGA1UE +BhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MQ4wDAYDVQQKDAVN +eU9yZzEPMA0GA1UECwwGTXlVbml0MRIwEAYDVQQDDAlsb2NhbGhvc3QwggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCtkJ+r1F8+YOVuwWLxbGVsJKw3BESh +tCsU+IXHJeaNJdBr59B/4h5WM37wOnnecmyEZTh47FXXkb5h0xVlHES7eTAD+NPl +WHufGJ9PR1kvQyZ0fyNRFXLzUID/dl7atHBtlrqE5Bhg7xqAyPZjUjhkAZPgrT1/ +8+gYmbWPbw3Ba3+XRupq3Kn+EVVJi7wk4cj8jf6g1aex6sMOkSYNsanb+JdEryev +goju+EtHgCHL6cB0eJs8PfMWiibgWLE2pkI0bdbGjTNVDDygZoO8Qr/YrGvXXYqt +0D7IKSUiO8bnrvZh6ITPEcQ3ePQRGEpqh8ggKaVq3RVkC3t3QMRWGpsvAgMBAAGj +QjBAMB0GA1UdDgQWBBRf0XdDhjNSIZQsvSFbe+hdsydllzAfBgNVHSMEGDAWgBRj +t9KoJO8CO/Cy/xTJztrbdUlv9DANBgkqhkiG9w0BAQsFAAOCAgEAec7y1Odyg3x/ +Uj0jfZWYNE0BuR114UVwEYhxFi9tRAGxjTlsl6ATCSYBWU+fwUkC29C2r3bu59fp +/KvYPRrwGPyOtXwHR2cmwJ7QrUhlIwPipWO4Kal3/EKWvrV9rzdnOd2QYqEMS6f0 +UixGLcT5p5KmEsH3W8Y9Uk94g/z12ZgdGeKKyY7hWnu1d47b2TS4oiRx+d6AacAD +1BzJRUhDjS2Vfe2cnpOqBHJWyCT1BxsfxKAc3rLa6JznbulHQPCE5WWBolHb8Tob +Yf32sIJydOcWU+zJ9VsGuglQ8dQInMemW8k5y48ACqHb00lAoucGJ3Izy9tJiBeU +C8TcmRxQSjoMGhFGNIjBAvXak0UzobUKE3YyABBbUdWLofs0N325K1aSga4vXmlO +OPzP4FWMLZyDicVMUGLP9jcyeb/gFMzjoU2En59gRNVDNNo01Lyj9MhGHF/jvV7p +Zf782GvXMT/NSPtXqmABSV/Svy70vXogeQxTii9YOZePKWQnFhcEwin/7d9bf4d/ +nUqtzDFb5FiDeFA+H9FyeNaeub3OtvsZUacAVDCT1t9/8uShjJo34v8WbJegepcY +Tpdkm4x0DWvv/QT3JBqC0wprsmBVzuqTJ/jBxYye02bds1bM7xePAuVOwuSb6Azg +gLrweiDMakgmZSkGgUonnjblYbHGTHw= +-----END CERTIFICATE----- 
diff --git a/presto-base-arrow-flight/src/test/resources/certs/server.key b/presto-base-arrow-flight/src/test/resources/certs/server.key new file mode 100644 index 0000000000000..0e49bebd60beb --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCtkJ+r1F8+YOVu +wWLxbGVsJKw3BEShtCsU+IXHJeaNJdBr59B/4h5WM37wOnnecmyEZTh47FXXkb5h +0xVlHES7eTAD+NPlWHufGJ9PR1kvQyZ0fyNRFXLzUID/dl7atHBtlrqE5Bhg7xqA +yPZjUjhkAZPgrT1/8+gYmbWPbw3Ba3+XRupq3Kn+EVVJi7wk4cj8jf6g1aex6sMO +kSYNsanb+JdEryevgoju+EtHgCHL6cB0eJs8PfMWiibgWLE2pkI0bdbGjTNVDDyg +ZoO8Qr/YrGvXXYqt0D7IKSUiO8bnrvZh6ITPEcQ3ePQRGEpqh8ggKaVq3RVkC3t3 +QMRWGpsvAgMBAAECggEAIK9C+lNAallJ62z8inU8tjxDuAqUOBVbJZRVcPbIr1zn +HmLlpyd4Sghhh7CjYYoPuHDtTQxIcBNwlDBxb3x+zwUXzy+tC5v5j7DN01qex2Ew +XTDSAEN3Ra2r1S+/1hSztVd0oXDozFxKk+UETRjfKKoJZH6LPcy7MOLFR5EEuJ8L +0kvGdEtuNCmZ1vPBwqR3IKQS9NsB1IdTtK0g2LdtVzM3U6F173CrAx51qNeAL30j +Np+I0rfm7vYVco6nDQXJB86hzwwBnLMzmZR2E0z+JStQCjQtEJN9wp+NBnViMb8C +mZl0K/PH3ZKNEs1Aw/TsRpPu6Fc+sN6iIs2oOGiKfQKBgQDa0Wpfj+SHflLfmrRU +PplGNjWdJiyuXROqX18iNE8nAD0eRqAFdzj9yU1IW49KCzuHInEl2pP9yrDZTWXB +Bht4C+Vk13mrBE3Sc1LDrks5EhDLaaolLgx1B+JN1X2DpfuzO8WHrXR11PCzFTAp +yDSVd451CFFXMseS1V9UxCy3lQKBgQDLDrLX/0hGhG+a5RUaAE+hZk+tU9RyjYm6 +/5lIoDjDwA9Yst69JCTHDApkdZ6IrjPDZrxkAQR6QwsGo+zRGkHV2wCoqR/RxcT5 +RBcbe/8xL86ZKwnhAheP6ssgZeK5zOG1iLol319kXXuo6NueN+YlocmsppRvAOq7 +/qMnhzXGswKBgQCpke2wHo9HnNJWK8ohGt2mtm232ZR4jvKlbgEIPac1Hw89/hcW +BT0qFqyILUQOakP4Re2PGyLiYwfHbh4zhisVTYq4Ke9EYzJ3qxzxPYlXsbNIHxtW +cqf+rVxnWtFIiwFR9TjvGrEMezcIYJwRVO/DAIJqGUcHnvdfx3B3/Qp2PQKBgQCk +y7UR37kEog8BotHRXFdEIgigHtzYa05QWYhJjN8E3yaVUfW7g03lzTvR9DNJsjeI +aiSS9NBxeV/Fb9yOh8TOjwKl3zxXvy3xLvWh9KxTev0tCeTmnBALWP6puIadTE4S +Snjoq7R7e/MUToeOjMdX20oVuMvWmuPm1u4K8o0OSQKBgQCec+QLllYXk22A8/e/ +f5HhSYr161lEFFmuzKhuuy+esyCQU/KZmxQH0UqnsL3Ww4ofq42lteqyUJnriHsx +QP5FTIMKH8W+Xels1i6jCC+MVXAXraAF27dOlmKxWMN7mnElZ/7lQKmBq64wil35 +sfcJA4FDxVM2Amv4KRo/w1C/zQ== 
+-----END PRIVATE KEY----- diff --git a/presto-base-jdbc/pom.xml b/presto-base-jdbc/pom.xml index ccfb62a26b53d..a8479e72fd9cb 100644 --- a/presto-base-jdbc/pom.xml +++ b/presto-base-jdbc/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-base-jdbc @@ -14,6 +14,7 @@ ${project.parent.basedir} + true @@ -38,7 +39,7 @@ - io.airlift + com.facebook.airlift units @@ -53,24 +54,23 @@ - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api - javax.validation - validation-api + jakarta.validation + jakarta.validation-api - com.google.code.findbugs - jsr305 - true + joda-time + joda-time - joda-time - joda-time + jakarta.inject + jakarta.inject-api @@ -78,6 +78,11 @@ javax.inject + + org.locationtech.jts + jts-core + + com.facebook.presto @@ -110,6 +115,17 @@ jmxutils + + com.esri.geometry + esri-geometry-api + + + + com.facebook.presto + presto-geospatial-toolkit + provided + + @@ -173,6 +189,12 @@ test-jar + + com.facebook.presto + presto-analyzer + test + + com.facebook.presto presto-tpch @@ -210,6 +232,15 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + javax.inject:javax.inject + + + diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java index bb704f1057bef..ba1642e18ca37 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcClient.java @@ -20,6 +20,8 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.UuidType; import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.plugin.jdbc.mapping.ReadMapping; +import com.facebook.presto.plugin.jdbc.mapping.WriteMapping; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import 
com.facebook.presto.spi.ConnectorSession; @@ -27,19 +29,18 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.FixedSplitSource; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableNotFoundException; import com.facebook.presto.spi.statistics.TableStatistics; -import com.google.common.base.CharMatcher; import com.google.common.base.Joiner; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; - -import javax.annotation.Nullable; -import javax.annotation.PreDestroy; +import jakarta.annotation.Nullable; +import jakarta.annotation.PreDestroy; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -70,14 +71,17 @@ import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.common.type.Varchars.isVarcharType; import static com.facebook.presto.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; -import static com.facebook.presto.plugin.jdbc.StandardReadMappings.jdbcTypeToPrestoType; +import static com.facebook.presto.plugin.jdbc.JdbcWarningCode.USE_OF_DEPRECATED_CONFIGURATION_PROPERTY; +import static com.facebook.presto.plugin.jdbc.mapping.StandardColumnMappings.jdbcTypeToReadMapping; +import static com.facebook.presto.plugin.jdbc.mapping.StandardColumnMappings.prestoTypeToWriteMapping; +import static com.facebook.presto.plugin.jdbc.mapping.StandardColumnMappings.timestampReadMapping; +import static com.facebook.presto.plugin.jdbc.mapping.StandardColumnMappings.timestampReadMappingLegacy; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.MoreObjects.firstNonNull; import static 
com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.emptyToNull; import static com.google.common.base.Strings.isNullOrEmpty; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; @@ -118,6 +122,7 @@ public class BaseJdbcClient protected final Cache> remoteSchemaNames; protected final Cache> remoteTableNames; protected final Set listSchemasIgnoredSchemas; + protected final boolean caseSensitiveNameMatchingEnabled; public BaseJdbcClient(JdbcConnectorId connectorId, BaseJdbcConfig config, String identifierQuote, ConnectionFactory connectionFactory) { @@ -132,6 +137,7 @@ public BaseJdbcClient(JdbcConnectorId connectorId, BaseJdbcConfig config, String this.remoteSchemaNames = remoteNamesCacheBuilder.build(); this.remoteTableNames = remoteNamesCacheBuilder.build(); this.listSchemasIgnoredSchemas = config.getlistSchemasIgnoredSchemas(); + this.caseSensitiveNameMatchingEnabled = config.isCaseSensitiveNameMatching(); } @PreDestroy @@ -152,7 +158,7 @@ public final Set getSchemaNames(ConnectorSession session, JdbcIdentity i { try (Connection connection = connectionFactory.openConnection(identity)) { return listSchemas(connection).stream() - .map(schemaName -> schemaName.toLowerCase(ENGLISH)) + .map(schemaName -> normalizeIdentifier(session, schemaName)) .collect(toImmutableSet()); } catch (SQLException e) { @@ -182,13 +188,14 @@ protected Collection listSchemas(Connection connection) public List getTableNames(ConnectorSession session, JdbcIdentity identity, Optional schema) { try (Connection connection = connectionFactory.openConnection(identity)) { - Optional remoteSchema = schema.map(schemaName -> toRemoteSchemaName(identity, connection, schemaName)); + Optional remoteSchema = schema.map(schemaName -> toRemoteSchemaName(session, 
identity, connection, schemaName)); try (ResultSet resultSet = getTables(connection, remoteSchema, Optional.empty())) { ImmutableList.Builder list = ImmutableList.builder(); while (resultSet.next()) { String tableSchema = getTableSchemaName(resultSet); String tableName = resultSet.getString("TABLE_NAME"); - list.add(new SchemaTableName(tableSchema.toLowerCase(ENGLISH), tableName.toLowerCase(ENGLISH))); + list.add(new SchemaTableName(normalizeIdentifier(session, tableSchema), + normalizeIdentifier(session, tableName))); } return list.build(); } @@ -203,8 +210,8 @@ public List getTableNames(ConnectorSession session, JdbcIdentit public JdbcTableHandle getTableHandle(ConnectorSession session, JdbcIdentity identity, SchemaTableName schemaTableName) { try (Connection connection = connectionFactory.openConnection(identity)) { - String remoteSchema = toRemoteSchemaName(identity, connection, schemaTableName.getSchemaName()); - String remoteTable = toRemoteTableName(identity, connection, remoteSchema, schemaTableName.getTableName()); + String remoteSchema = toRemoteSchemaName(session, identity, connection, schemaTableName.getSchemaName()); + String remoteTable = toRemoteTableName(session, identity, connection, remoteSchema, schemaTableName.getTableName()); try (ResultSet resultSet = getTables(connection, Optional.of(remoteSchema), Optional.of(remoteTable))) { List tableHandles = new ArrayList<>(); while (resultSet.next()) { @@ -243,13 +250,13 @@ public List getColumns(ConnectorSession session, JdbcTableHand resultSet.getString("TYPE_NAME"), resultSet.getInt("COLUMN_SIZE"), resultSet.getInt("DECIMAL_DIGITS")); - Optional columnMapping = toPrestoType(session, typeHandle); + Optional readMapping = toPrestoType(session, typeHandle); // skip unsupported column types - if (columnMapping.isPresent()) { + if (readMapping.isPresent()) { String columnName = resultSet.getString("COLUMN_NAME"); boolean nullable = columnNullable == resultSet.getInt("NULLABLE"); Optional comment = 
Optional.ofNullable(emptyToNull(resultSet.getString("REMARKS"))); - columns.add(new JdbcColumnHandle(connectorId, columnName, typeHandle, columnMapping.get().getType(), nullable, comment)); + columns.add(new JdbcColumnHandle(connectorId, columnName, typeHandle, readMapping.get().getType(), nullable, comment)); } } if (columns.isEmpty()) { @@ -271,7 +278,11 @@ public List getColumns(ConnectorSession session, JdbcTableHand @Override public Optional toPrestoType(ConnectorSession session, JdbcTypeHandle typeHandle) { - return jdbcTypeToPrestoType(typeHandle); + if (typeHandle.getJdbcType() == java.sql.Types.TIMESTAMP) { + boolean legacyTimestamp = session.getSqlFunctionProperties().isLegacyTimestamp(); + return Optional.of(legacyTimestamp ? timestampReadMappingLegacy() : timestampReadMapping()); + } + return jdbcTypeToReadMapping(typeHandle); } @Override @@ -363,9 +374,9 @@ protected JdbcOutputTableHandle createTable(ConnectorTableMetadata tableMetadata try (Connection connection = connectionFactory.openConnection(identity)) { boolean uppercase = connection.getMetaData().storesUpperCaseIdentifiers(); - String remoteSchema = toRemoteSchemaName(identity, connection, schemaTableName.getSchemaName()); - String remoteTable = toRemoteTableName(identity, connection, remoteSchema, schemaTableName.getTableName()); - if (uppercase) { + String remoteSchema = toRemoteSchemaName(session, identity, connection, schemaTableName.getSchemaName()); + String remoteTable = toRemoteTableName(session, identity, connection, remoteSchema, schemaTableName.getTableName()); + if (uppercase && !caseSensitiveNameMatchingEnabled) { tableName = tableName.toUpperCase(ENGLISH); } String catalog = connection.getCatalog(); @@ -375,7 +386,7 @@ protected JdbcOutputTableHandle createTable(ConnectorTableMetadata tableMetadata ImmutableList.Builder columnList = ImmutableList.builder(); for (ColumnMetadata column : tableMetadata.getColumns()) { String columnName = column.getName(); - if (uppercase) { + if 
(uppercase && !caseSensitiveNameMatchingEnabled) { columnName = columnName.toUpperCase(ENGLISH); } columnNames.add(columnName); @@ -442,7 +453,7 @@ protected void renameTable(JdbcIdentity identity, String catalogName, SchemaTabl String tableName = oldTable.getTableName(); String newSchemaName = newTable.getSchemaName(); String newTableName = newTable.getTableName(); - if (metadata.storesUpperCaseIdentifiers()) { + if (metadata.storesUpperCaseIdentifiers() && !caseSensitiveNameMatchingEnabled) { schemaName = schemaName.toUpperCase(ENGLISH); tableName = tableName.toUpperCase(ENGLISH); newSchemaName = newSchemaName.toUpperCase(ENGLISH); @@ -490,7 +501,7 @@ public void addColumn(ConnectorSession session, JdbcIdentity identity, JdbcTable String table = handle.getTableName(); String columnName = column.getName(); DatabaseMetaData metadata = connection.getMetaData(); - if (metadata.storesUpperCaseIdentifiers()) { + if (metadata.storesUpperCaseIdentifiers() && !caseSensitiveNameMatchingEnabled) { schema = schema != null ? schema.toUpperCase(ENGLISH) : null; table = table.toUpperCase(ENGLISH); columnName = columnName.toUpperCase(ENGLISH); @@ -511,14 +522,20 @@ public void renameColumn(ConnectorSession session, JdbcIdentity identity, JdbcTa { try (Connection connection = connectionFactory.openConnection(identity)) { DatabaseMetaData metadata = connection.getMetaData(); - if (metadata.storesUpperCaseIdentifiers()) { + String schema = handle.getSchemaName(); + String table = handle.getTableName(); + String columnName = jdbcColumn.getColumnName(); + if (metadata.storesUpperCaseIdentifiers() && !caseSensitiveNameMatchingEnabled) { + schema = schema != null ? 
schema.toUpperCase(ENGLISH) : null; + table = table.toUpperCase(ENGLISH); + columnName = columnName.toUpperCase(ENGLISH); newColumnName = newColumnName.toUpperCase(ENGLISH); } String sql = format( "ALTER TABLE %s RENAME COLUMN %s TO %s", - quoted(handle.getCatalogName(), handle.getSchemaName(), handle.getTableName()), - jdbcColumn.getColumnName(), - newColumnName); + quoted(handle.getCatalogName(), schema, table), + quoted(columnName), + quoted(newColumnName)); execute(connection, sql); } catch (SQLException e) { @@ -530,10 +547,19 @@ public void renameColumn(ConnectorSession session, JdbcIdentity identity, JdbcTa public void dropColumn(ConnectorSession session, JdbcIdentity identity, JdbcTableHandle handle, JdbcColumnHandle column) { try (Connection connection = connectionFactory.openConnection(identity)) { + DatabaseMetaData metadata = connection.getMetaData(); + String schema = handle.getSchemaName(); + String table = handle.getTableName(); + String columnName = column.getColumnName(); + if (metadata.storesUpperCaseIdentifiers() && !caseSensitiveNameMatchingEnabled) { + schema = schema != null ? 
schema.toUpperCase(ENGLISH) : null; + table = table.toUpperCase(ENGLISH); + columnName = columnName.toUpperCase(ENGLISH); + } String sql = format( "ALTER TABLE %s DROP COLUMN %s", - quoted(handle.getCatalogName(), handle.getSchemaName(), handle.getTableName()), - column.getColumnName()); + quoted(handle.getCatalogName(), schema, table), + quoted(columnName)); execute(connection, sql); } catch (SQLException e) { @@ -607,6 +633,12 @@ public PreparedStatement getPreparedStatement(ConnectorSession session, Connecti return connection.prepareStatement(sql); } + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return identifier.toLowerCase(ENGLISH); + } + protected ResultSet getTables(Connection connection, Optional schemaName, Optional tableName) throws SQLException { @@ -625,12 +657,14 @@ protected String getTableSchemaName(ResultSet resultSet) return resultSet.getString("TABLE_SCHEM"); } - protected String toRemoteSchemaName(JdbcIdentity identity, Connection connection, String schemaName) + protected String toRemoteSchemaName(ConnectorSession session, JdbcIdentity identity, Connection connection, String schemaName) { requireNonNull(schemaName, "schemaName is null"); - verify(CharMatcher.forPredicate(Character::isUpperCase).matchesNoneOf(schemaName), "Expected schema name from internal metadata to be lowercase: %s", schemaName); if (caseInsensitiveNameMatching) { + session.getWarningCollector().add(new PrestoWarning(USE_OF_DEPRECATED_CONFIGURATION_PROPERTY, + "'case-insensitive-name-matching' is deprecated. Use of this configuration value may lead to query failures. 
" + + "Please switch to using 'case-sensitive-name-matching' for proper case sensitivity behavior.")); try { Map mapping = remoteSchemaNames.getIfPresent(identity); if (mapping != null && !mapping.containsKey(schemaName)) { @@ -653,7 +687,7 @@ protected String toRemoteSchemaName(JdbcIdentity identity, Connection connection try { DatabaseMetaData metadata = connection.getMetaData(); - if (metadata.storesUpperCaseIdentifiers()) { + if (metadata.storesUpperCaseIdentifiers() && !caseSensitiveNameMatchingEnabled) { return schemaName.toUpperCase(ENGLISH); } return schemaName; @@ -669,13 +703,15 @@ protected Map listSchemasByLowerCase(Connection connection) .collect(toImmutableMap(schemaName -> schemaName.toLowerCase(ENGLISH), schemaName -> schemaName)); } - protected String toRemoteTableName(JdbcIdentity identity, Connection connection, String remoteSchema, String tableName) + protected String toRemoteTableName(ConnectorSession session, JdbcIdentity identity, Connection connection, String remoteSchema, String tableName) { requireNonNull(remoteSchema, "remoteSchema is null"); requireNonNull(tableName, "tableName is null"); - verify(CharMatcher.forPredicate(Character::isUpperCase).matchesNoneOf(tableName), "Expected table name from internal metadata to be lowercase: %s", tableName); if (caseInsensitiveNameMatching) { + session.getWarningCollector().add(new PrestoWarning(USE_OF_DEPRECATED_CONFIGURATION_PROPERTY, + "'case-insensitive-name-matching' is deprecated. Use of this configuration value may lead to query failures. 
" + + "Please switch to using 'case-sensitive-name-matching' for proper case sensitivity behavior.")); try { RemoteTableNameCacheKey cacheKey = new RemoteTableNameCacheKey(identity, remoteSchema); Map mapping = remoteTableNames.getIfPresent(cacheKey); @@ -699,7 +735,7 @@ protected String toRemoteTableName(JdbcIdentity identity, Connection connection, try { DatabaseMetaData metadata = connection.getMetaData(); - if (metadata.storesUpperCaseIdentifiers()) { + if (metadata.storesUpperCaseIdentifiers() && !caseSensitiveNameMatchingEnabled) { return tableName.toUpperCase(ENGLISH); } return tableName; @@ -738,7 +774,6 @@ protected void execute(Connection connection, String query) statement.execute(query); } } - protected String toSqlType(Type type) { if (isVarcharType(type)) { @@ -765,6 +800,15 @@ protected String toSqlType(Type type) throw new PrestoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); } + public WriteMapping toWriteMapping(ConnectorSession session, Type type) + { + Optional writeMapping = prestoTypeToWriteMapping(session, type); + if (writeMapping.isPresent()) { + return writeMapping.get(); + } + throw new PrestoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); + } + protected String quoted(String name) { name = name.replace(identifierQuote, identifierQuote + identifierQuote); diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcConfig.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcConfig.java index fc4f8ed2c3d0d..000dcbdf199d0 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcConfig.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BaseJdbcConfig.java @@ -14,20 +14,35 @@ package com.facebook.presto.plugin.jdbc; import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigDescription; import 
com.facebook.airlift.configuration.ConfigSecuritySensitive; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.annotation.Nullable; -import javax.validation.constraints.NotNull; +import com.google.inject.ConfigurationException; +import com.google.inject.spi.Message; +import jakarta.annotation.Nullable; +import jakarta.annotation.PostConstruct; +import jakarta.validation.constraints.NotNull; import java.util.Set; import static java.util.Locale.ENGLISH; import static java.util.concurrent.TimeUnit.MINUTES; +/** + * Base configuration class for JDBC connectors. + * + * This class is provided for convenience and contains common configuration properties + * that many JDBC connectors may need. However, core JDBC functionality should not + * depend on this class, as JDBC connectors may choose to use their own mechanisms + * for connection management, authentication, and other configuration needs. + * + * Connectors are free to implement their own configuration classes and connection + * strategies without extending or using this base configuration. 
+ */ public class BaseJdbcConfig { private String connectionUrl; @@ -38,6 +53,7 @@ public class BaseJdbcConfig private boolean caseInsensitiveNameMatching; private Duration caseInsensitiveNameMatchingCacheTtl = new Duration(1, MINUTES); private Set listSchemasIgnoredSchemas = ImmutableSet.of("information_schema"); + private boolean caseSensitiveNameMatchingEnabled; @NotNull public String getConnectionUrl() @@ -105,12 +121,18 @@ public BaseJdbcConfig setPasswordCredentialName(String passwordCredentialName) return this; } + @Deprecated public boolean isCaseInsensitiveNameMatching() { return caseInsensitiveNameMatching; } + @Deprecated @Config("case-insensitive-name-matching") + @ConfigDescription("Deprecated: This will be removed in future releases. Use 'case-sensitive-name-matching=true' instead for mysql. " + + "This configuration setting converts all schema/table names to lowercase. " + + "If your source database contains names differing only by case (e.g., 'Testdb' and 'testdb'), " + + "this setting can lead to conflicts and query failures.") public BaseJdbcConfig setCaseInsensitiveNameMatching(boolean caseInsensitiveNameMatching) { this.caseInsensitiveNameMatching = caseInsensitiveNameMatching; @@ -142,4 +164,31 @@ public BaseJdbcConfig setlistSchemasIgnoredSchemas(String listSchemasIgnoredSche this.listSchemasIgnoredSchemas = ImmutableSet.copyOf(Splitter.on(",").trimResults().omitEmptyStrings().split(listSchemasIgnoredSchemas.toLowerCase(ENGLISH))); return this; } + + public boolean isCaseSensitiveNameMatching() + { + return caseSensitiveNameMatchingEnabled; + } + + @Config("case-sensitive-name-matching") + @ConfigDescription("Enable case-sensitive matching of schema, table names across the connector. 
" + + "When disabled, names are matched case-insensitively using lowercase normalization.") + public BaseJdbcConfig setCaseSensitiveNameMatching(boolean caseSensitiveNameMatchingEnabled) + { + this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled; + return this; + } + + @PostConstruct + public void validateConfig() + { + if (isCaseInsensitiveNameMatching() && isCaseSensitiveNameMatching()) { + throw new ConfigurationException(ImmutableList.of(new Message("Only one of 'case-insensitive-name-matching=true' or 'case-sensitive-name-matching=true' can be set. " + + "These options are mutually exclusive."))); + } + + if (connectionUrl == null) { + throw new ConfigurationException(ImmutableList.of(new Message("connection-url is required but was not provided"))); + } + } } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/DefaultTableLocationProvider.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/DefaultTableLocationProvider.java new file mode 100644 index 0000000000000..fc3d9a7fe7abe --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/DefaultTableLocationProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc; + +import javax.inject.Inject; + +import static java.util.Objects.requireNonNull; + +/** + * Default implementation of TableLocationProvider that uses the connection URL + * from BaseJdbcConfig as the table location. + */ +public class DefaultTableLocationProvider + implements TableLocationProvider +{ + private final String connectionUrl; + + @Inject + public DefaultTableLocationProvider(BaseJdbcConfig baseJdbcConfig) + { + requireNonNull(baseJdbcConfig, "baseJdbcConfig is null"); + this.connectionUrl = requireNonNull(baseJdbcConfig.getConnectionUrl(), "connection-url is null"); + } + + @Override + public String getTableLocation() + { + return connectionUrl; + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/GeometryUtils.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/GeometryUtils.java new file mode 100644 index 0000000000000..8e7f2fc5d8134 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/GeometryUtils.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc; + +import com.esri.core.geometry.ogc.OGCGeometry; +import com.facebook.presto.spi.PrestoException; +import io.airlift.slice.Slice; + +import static com.esri.core.geometry.ogc.OGCGeometry.fromBinary; +import static com.facebook.presto.geospatial.GeometryUtils.wktFromJtsGeometry; +import static com.facebook.presto.geospatial.serde.EsriGeometrySerde.serialize; +import static com.facebook.presto.geospatial.serde.JtsGeometrySerde.deserialize; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static io.airlift.slice.Slices.utf8Slice; +import static java.util.Objects.requireNonNull; + +public class GeometryUtils +{ + private GeometryUtils() {} + + public static Slice getAsText(Slice input) + { + return utf8Slice(wktFromJtsGeometry(deserialize(input))); + } + + public static Slice stGeomFromBinary(Slice input) + { + requireNonNull(input, "input is null"); + OGCGeometry geometry; + try { + geometry = fromBinary(input.toByteBuffer().slice()); + } + catch (IllegalArgumentException | IndexOutOfBoundsException e) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Invalid Well-Known Binary (WKB)", e); + } + geometry.setSpatialReference(null); + return serialize(geometry); + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java index 722841fe05c3e..c307189dc0d08 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcClient.java @@ -14,6 +14,9 @@ package com.facebook.presto.plugin.jdbc; import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.plugin.jdbc.mapping.ReadMapping; +import com.facebook.presto.plugin.jdbc.mapping.WriteMapping; import com.facebook.presto.spi.ColumnHandle; 
import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorSession; @@ -21,8 +24,7 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.statistics.TableStatistics; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.sql.Connection; import java.sql.PreparedStatement; @@ -51,6 +53,8 @@ default boolean schemaExists(ConnectorSession session, JdbcIdentity identity, St Optional toPrestoType(ConnectorSession session, JdbcTypeHandle typeHandle); + WriteMapping toWriteMapping(ConnectorSession session, Type type); + ConnectorSplitSource getSplits(ConnectorSession session, JdbcIdentity identity, JdbcTableLayoutHandle layoutHandle); Connection getConnection(ConnectorSession session, JdbcIdentity identity, JdbcSplit split) @@ -98,4 +102,6 @@ PreparedStatement getPreparedStatement(ConnectorSession session, Connection conn throws SQLException; TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, List columnHandles, TupleDomain tupleDomain); + + String normalizeIdentifier(ConnectorSession session, String identifier); } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcColumnHandle.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcColumnHandle.java index 3be953c156a89..ec7e1d1ac764b 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcColumnHandle.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcColumnHandle.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorSession; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -98,6 +99,16 @@ public ColumnMetadata getColumnMetadata() .build(); } + public 
ColumnMetadata getColumnMetadata(ConnectorSession session, JdbcClient jdbcClient) + { + return ColumnMetadata.builder() + .setName(jdbcClient.normalizeIdentifier(session, columnName)) + .setType(columnType) + .setNullable(nullable) + .setComment(comment.orElse(null)) + .build(); + } + @Override public boolean equals(Object obj) { diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java index 511899a2e03f9..ade0d2c16ff4a 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcConnector.java @@ -35,8 +35,7 @@ import com.facebook.presto.spi.transaction.IsolationLevel; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcInputInfo.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcInputInfo.java new file mode 100644 index 0000000000000..1b7f0a6e7b37b --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcInputInfo.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.plugin.jdbc; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import static java.util.Objects.requireNonNull; + +public class JdbcInputInfo +{ + private final String tableLocation; + + public JdbcInputInfo( + @JsonProperty("tableLocation") String tableLocation) + + { + this.tableLocation = requireNonNull(tableLocation, "tableLocation is null"); + } + + @JsonProperty + public String getTableLocation() + { + return tableLocation; + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java index fd92e8730af74..9c37bd8a3d731 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadata.java @@ -55,14 +55,15 @@ public class JdbcMetadata private final JdbcMetadataCache jdbcMetadataCache; private final JdbcClient jdbcClient; private final boolean allowDropTable; - + private final String url; private final AtomicReference rollbackAction = new AtomicReference<>(); - public JdbcMetadata(JdbcMetadataCache jdbcMetadataCache, JdbcClient jdbcClient, boolean allowDropTable) + public JdbcMetadata(JdbcMetadataCache jdbcMetadataCache, JdbcClient jdbcClient, boolean allowDropTable, TableLocationProvider tableLocationProvider) { this.jdbcMetadataCache = requireNonNull(jdbcMetadataCache, "jdbcMetadataCache is null"); this.jdbcClient = requireNonNull(jdbcClient, "client is null"); this.allowDropTable = allowDropTable; + this.url = requireNonNull(tableLocationProvider, "tableLocationProvider is null").getTableLocation(); } @Override @@ -84,11 +85,15 @@ public JdbcTableHandle getTableHandle(ConnectorSession session, SchemaTableName } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult 
getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { JdbcTableHandle tableHandle = (JdbcTableHandle) table; ConnectorTableLayout layout = new ConnectorTableLayout(new JdbcTableLayoutHandle(session.getSqlFunctionProperties(), tableHandle, constraint.getSummary(), Optional.empty())); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -104,7 +109,7 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect ImmutableList.Builder columnMetadata = ImmutableList.builder(); for (JdbcColumnHandle column : jdbcMetadataCache.getColumns(session, handle)) { - columnMetadata.add(column.getColumnMetadata()); + columnMetadata.add(column.getColumnMetadata(session, jdbcClient)); } return new ConnectorTableMetadata(handle.getSchemaTableName(), columnMetadata.build()); } @@ -122,7 +127,7 @@ public Map getColumnHandles(ConnectorSession session, Conn ImmutableMap.Builder columnHandles = ImmutableMap.builder(); for (JdbcColumnHandle column : jdbcMetadataCache.getColumns(session, jdbcTableHandle)) { - columnHandles.put(column.getColumnMetadata().getName(), column); + columnHandles.put(column.getColumnMetadata(session, jdbcClient).getName(), column); } return columnHandles.build(); } @@ -189,7 +194,7 @@ public Optional finishCreateTable(ConnectorSession sess JdbcOutputTableHandle handle = (JdbcOutputTableHandle) tableHandle; jdbcClient.commitCreateTable(session, JdbcIdentity.from(session), handle); clearRollback(); - return Optional.empty(); + return Optional.of(new JdbcOutputMetadata(url)); } private void setRollback(Runnable action) @@ -267,4 +272,16 @@ public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTab List columns = columnHandles.stream().map(JdbcColumnHandle.class::cast).collect(Collectors.toList()); return 
jdbcClient.getTableStatistics(session, handle, columns, constraint.getSummary()); } + + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return jdbcClient.normalizeIdentifier(session, identifier); + } + + @Override + public Optional getInfo(ConnectorTableLayoutHandle tableHandle) + { + return Optional.of(new JdbcInputInfo(url)); + } } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataCache.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataCache.java index 8ef7d92a0a1e2..f98cf11639f96 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataCache.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataCache.java @@ -20,8 +20,7 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.util.concurrent.UncheckedExecutionException; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Objects; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataConfig.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataConfig.java index d2bd555f2c53b..63db7581ca0a1 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataConfig.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataConfig.java @@ -15,11 +15,10 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import 
java.util.concurrent.TimeUnit; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataFactory.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataFactory.java index affa47e1cb186..34031cb093373 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataFactory.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcMetadataFactory.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.plugin.jdbc; -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; @@ -22,18 +22,20 @@ public class JdbcMetadataFactory private final JdbcMetadataCache jdbcMetadataCache; private final JdbcClient jdbcClient; private final boolean allowDropTable; + private final TableLocationProvider tableLocationProvider; @Inject - public JdbcMetadataFactory(JdbcMetadataCache jdbcMetadataCache, JdbcClient jdbcClient, JdbcMetadataConfig config) + public JdbcMetadataFactory(JdbcMetadataCache jdbcMetadataCache, JdbcClient jdbcClient, JdbcMetadataConfig config, TableLocationProvider tableLocationProvider) { this.jdbcMetadataCache = requireNonNull(jdbcMetadataCache, "jdbcMetadataCache is null"); this.jdbcClient = requireNonNull(jdbcClient, "jdbcClient is null"); requireNonNull(config, "config is null"); this.allowDropTable = config.isAllowDropTable(); + this.tableLocationProvider = requireNonNull(tableLocationProvider, "tableLocationProvider is null"); } public JdbcMetadata create() { - return new JdbcMetadata(jdbcMetadataCache, jdbcClient, allowDropTable); + return new JdbcMetadata(jdbcMetadataCache, jdbcClient, allowDropTable, tableLocationProvider); } } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcModule.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcModule.java index b3b67ec041cd1..7951a1c6658e1 100644 --- 
a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcModule.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcModule.java @@ -54,5 +54,7 @@ public void configure(Binder binder) newOptionalBinder(binder, JdbcSessionPropertiesProvider.class); binder.bind(JdbcConnector.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(JdbcMetadataConfig.class); + configBinder(binder).bindConfig(BaseJdbcConfig.class); + binder.bind(TableLocationProvider.class).to(DefaultTableLocationProvider.class).in(Scopes.SINGLETON); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergWrittenPartitions.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcOutputMetadata.java similarity index 65% rename from presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergWrittenPartitions.java rename to presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcOutputMetadata.java index 0e1917c739ceb..2c8bc0d276be2 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergWrittenPartitions.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcOutputMetadata.java @@ -11,31 +11,30 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.iceberg; + +package com.facebook.presto.plugin.jdbc; import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.collect.ImmutableList; - -import java.util.List; import static java.util.Objects.requireNonNull; -public class IcebergWrittenPartitions +public class JdbcOutputMetadata implements ConnectorOutputMetadata { - private final List partitionNames; + private final String tableLocation; @JsonCreator - public IcebergWrittenPartitions(@JsonProperty("partitionNames") List partitionNames) + public JdbcOutputMetadata(@JsonProperty("tableLocation") String tableLocation) { - this.partitionNames = ImmutableList.copyOf(requireNonNull(partitionNames, "partitionNames is null")); + this.tableLocation = requireNonNull(tableLocation, "tableLocation is null"); } @JsonProperty - public List getInfo() + @Override + public String getInfo() { - return partitionNames; + return tableLocation; } } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcOutputTableHandle.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcOutputTableHandle.java index 39b18c1778256..4ce7cbc61693e 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcOutputTableHandle.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcOutputTableHandle.java @@ -19,8 +19,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Objects; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcPageSink.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcPageSink.java index 
50d6ed03082fe..79f1872c94d27 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcPageSink.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcPageSink.java @@ -16,51 +16,34 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; -import com.facebook.presto.common.type.DecimalType; -import com.facebook.presto.common.type.TimestampType; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.UuidType; +import com.facebook.presto.plugin.jdbc.mapping.WriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.BooleanWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.DoubleWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.LongWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.ObjectWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.SliceWriteFunction; import com.facebook.presto.spi.ConnectorPageSink; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; -import com.google.common.primitives.Shorts; -import com.google.common.primitives.SignedBytes; import io.airlift.slice.Slice; -import org.joda.time.DateTimeZone; import java.sql.Connection; -import java.sql.Date; import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.SQLNonTransientException; -import java.sql.Timestamp; -import java.time.Instant; import java.util.Collection; import java.util.List; import java.util.concurrent.CompletableFuture; -import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.common.type.BooleanType.BOOLEAN; -import static com.facebook.presto.common.type.Chars.isCharType; -import static com.facebook.presto.common.type.DateType.DATE; -import static 
com.facebook.presto.common.type.Decimals.readBigDecimal; -import static com.facebook.presto.common.type.DoubleType.DOUBLE; -import static com.facebook.presto.common.type.IntegerType.INTEGER; -import static com.facebook.presto.common.type.RealType.REAL; -import static com.facebook.presto.common.type.SmallintType.SMALLINT; -import static com.facebook.presto.common.type.TinyintType.TINYINT; -import static com.facebook.presto.common.type.UuidType.prestoUuidToJavaUuid; -import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; -import static com.facebook.presto.common.type.Varchars.isVarcharType; import static com.facebook.presto.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static com.facebook.presto.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; -import static java.lang.Float.intBitsToFloat; -import static java.lang.Math.toIntExact; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; import static java.util.concurrent.CompletableFuture.completedFuture; -import static java.util.concurrent.TimeUnit.DAYS; -import static org.joda.time.chrono.ISOChronology.getInstanceUTC; public class JdbcPageSink implements ConnectorPageSink @@ -71,6 +54,7 @@ public class JdbcPageSink private final PreparedStatement statement; private final List columnTypes; + private final List columnWriters; private int batchSize; public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, JdbcClient jdbcClient) @@ -92,6 +76,12 @@ public JdbcPageSink(ConnectorSession session, JdbcOutputTableHandle handle, Jdbc } columnTypes = handle.getColumnTypes(); + columnWriters = columnTypes.stream().map(type -> { + WriteFunction writeFunction = jdbcClient.toWriteMapping(session, type).getWriteFunction(); + verify(type.getJavaType() == writeFunction.getJavaType(), + format("Presto type %s is not 
compatible with write function %s accepting %s", type, writeFunction, writeFunction.getJavaType())); + return writeFunction; + }).collect(toImmutableList()); } @Override @@ -132,55 +122,27 @@ private void appendColumn(Page page, int position, int channel) } Type type = columnTypes.get(channel); - if (BOOLEAN.equals(type)) { - statement.setBoolean(parameter, type.getBoolean(block, position)); + Class javaType = type.getJavaType(); + WriteFunction writeFunction = columnWriters.get(channel); + if (javaType == boolean.class) { + ((BooleanWriteFunction) writeFunction).set(statement, parameter, type.getBoolean(block, position)); } - else if (BIGINT.equals(type)) { - statement.setLong(parameter, type.getLong(block, position)); + else if (javaType == long.class) { + ((LongWriteFunction) writeFunction).set(statement, parameter, type.getLong(block, position)); } - else if (INTEGER.equals(type)) { - statement.setInt(parameter, toIntExact(type.getLong(block, position))); + else if (javaType == double.class) { + ((DoubleWriteFunction) writeFunction).set(statement, parameter, type.getDouble(block, position)); } - else if (SMALLINT.equals(type)) { - statement.setShort(parameter, Shorts.checkedCast(type.getLong(block, position))); - } - else if (TINYINT.equals(type)) { - statement.setByte(parameter, SignedBytes.checkedCast(type.getLong(block, position))); - } - else if (DOUBLE.equals(type)) { - statement.setDouble(parameter, type.getDouble(block, position)); - } - else if (REAL.equals(type)) { - statement.setFloat(parameter, intBitsToFloat(toIntExact(type.getLong(block, position)))); - } - else if (type instanceof DecimalType) { - statement.setBigDecimal(parameter, readBigDecimal((DecimalType) type, block, position)); - } - else if (isVarcharType(type) || isCharType(type)) { - statement.setString(parameter, type.getSlice(block, position).toStringUtf8()); - } - else if (VARBINARY.equals(type)) { - statement.setBytes(parameter, type.getSlice(block, position).getBytes()); - } - else 
if (DATE.equals(type)) { - // convert to midnight in default time zone - long utcMillis = DAYS.toMillis(type.getLong(block, position)); - long localMillis = getInstanceUTC().getZone().getMillisKeepLocal(DateTimeZone.getDefault(), utcMillis); - statement.setDate(parameter, new Date(localMillis)); - } - else if (type instanceof TimestampType) { - long timestampValue = type.getLong(block, position); - statement.setTimestamp(parameter, - Timestamp.from(Instant.ofEpochSecond( - ((TimestampType) type).getEpochSecond(timestampValue), - ((TimestampType) type).getNanos(timestampValue)))); - } - else if (UuidType.UUID.equals(type)) { - Slice slice = type.getSlice(block, position); - statement.setObject(parameter, prestoUuidToJavaUuid(slice)); + else if (javaType == Slice.class) { + ((SliceWriteFunction) writeFunction).set(statement, parameter, type.getSlice(block, position)); } else { - throw new PrestoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName()); + try { + ((ObjectWriteFunction) writeFunction).set(statement, parameter, type.getObject(block, position)); + } + catch (SQLException e) { + throw new PrestoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName(), e); + } } } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcPageSinkProvider.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcPageSinkProvider.java index b5f8e964d02ab..c0484f57a8718 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcPageSinkProvider.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcPageSinkProvider.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.PageSinkContext; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.google.common.base.Preconditions.checkArgument; import static 
java.util.Objects.requireNonNull; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcRecordCursor.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcRecordCursor.java index 2c8f8b2328d8f..e43e0a515fd28 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcRecordCursor.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcRecordCursor.java @@ -15,6 +15,13 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.common.type.Type; +import com.facebook.presto.plugin.jdbc.mapping.ReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.ReadMapping; +import com.facebook.presto.plugin.jdbc.mapping.functions.BooleanReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.DoubleReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.LongReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.ObjectReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.SliceReadFunction; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.RecordCursor; @@ -31,7 +38,6 @@ import static com.facebook.presto.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class JdbcRecordCursor @@ -44,6 +50,7 @@ public class JdbcRecordCursor private final DoubleReadFunction[] doubleReadFunctions; private final LongReadFunction[] longReadFunctions; private final SliceReadFunction[] sliceReadFunctions; + private final ObjectReadFunction[] objectReadFunctions; private final JdbcClient jdbcClient; private final Connection connection; @@ -61,6 +68,7 @@ public JdbcRecordCursor(JdbcClient jdbcClient, ConnectorSession session, JdbcSpl doubleReadFunctions = 
new DoubleReadFunction[columnHandles.size()]; longReadFunctions = new LongReadFunction[columnHandles.size()]; sliceReadFunctions = new SliceReadFunction[columnHandles.size()]; + objectReadFunctions = new ObjectReadFunction[columnHandles.size()]; for (int i = 0; i < this.columnHandles.length; i++) { ReadMapping readMapping = jdbcClient.toPrestoType(session, columnHandles.get(i).getJdbcTypeHandle()) @@ -81,7 +89,12 @@ else if (javaType == Slice.class) { sliceReadFunctions[i] = (SliceReadFunction) readFunction; } else { - throw new IllegalStateException(format("Unsupported java type %s", javaType)); + try { + objectReadFunctions[i] = (ObjectReadFunction) readFunction; + } + catch (ClassCastException e) { + throw new UnsupportedOperationException(); + } } } @@ -180,7 +193,13 @@ public Slice getSlice(int field) @Override public Object getObject(int field) { - throw new UnsupportedOperationException(); + checkState(!closed, "cursor is closed"); + try { + return objectReadFunctions[field].readObject(resultSet, field + 1); + } + catch (SQLException | RuntimeException e) { + throw handleSqlException(e); + } } @Override diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcRecordSetProvider.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcRecordSetProvider.java index f304e3fbb0874..17ccfa149d5cb 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcRecordSetProvider.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcRecordSetProvider.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java 
b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java index 729717481ce3e..f9377aa4837a9 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplit.java @@ -23,8 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Optional; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplitManager.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplitManager.java index 25900ec568938..0473a63ca37e4 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplitManager.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcSplitManager.java @@ -18,8 +18,7 @@ import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableHandle.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableHandle.java index c9ec95841b14f..438c49a6141ff 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableHandle.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcTableHandle.java @@ -18,8 +18,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Joiner; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Objects; diff --git 
a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcWarningCode.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcWarningCode.java new file mode 100644 index 0000000000000..32cd545adeed7 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/JdbcWarningCode.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc; + +import com.facebook.presto.spi.WarningCode; +import com.facebook.presto.spi.WarningCodeSupplier; + +public enum JdbcWarningCode + implements WarningCodeSupplier +{ + USE_OF_DEPRECATED_CONFIGURATION_PROPERTY(1), + /**/; + private final WarningCode warningCode; + + public static final int WARNING_CODE_MASK = 0x0300_0000; + + JdbcWarningCode(int code) + { + warningCode = new WarningCode(code + WARNING_CODE_MASK, name()); + } + + @Override + public WarningCode toWarningCode() + { + return warningCode; + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java index 65c9c703a3d2c..46eee4d510e0f 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/QueryBuilder.java @@ -31,38 +31,39 @@ import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; import 
com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.plugin.jdbc.mapping.WriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.BooleanWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.DoubleWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.LongWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.ObjectWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.SliceWriteFunction; import com.facebook.presto.plugin.jdbc.optimization.JdbcExpression; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; -import org.joda.time.DateTimeZone; import java.sql.Connection; -import java.sql.Date; import java.sql.PreparedStatement; import java.sql.SQLException; -import java.sql.Time; -import java.sql.Timestamp; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; -import static com.facebook.presto.common.type.DateTimeEncoding.unpackMillisUtc; +import static com.facebook.presto.plugin.jdbc.mapping.StandardColumnMappings.getWriteMappingForAccumulators; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.Iterables.getOnlyElement; -import static java.lang.Float.intBitsToFloat; import static java.lang.String.format; import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.DAYS; import static java.util.stream.Collectors.joining; -import static 
org.joda.time.DateTimeZone.UTC; public class QueryBuilder { @@ -142,59 +143,39 @@ public PreparedStatement buildSql( sql.append(" WHERE ") .append(Joiner.on(" AND ").join(clauses)); } - sql.append(String.format("/* %s : %s */", session.getUser(), session.getQueryId())); + sql.append(format("/* %s : %s */", session.getUser(), session.getQueryId())); PreparedStatement statement = client.getPreparedStatement(session, connection, sql.toString()); for (int i = 0; i < accumulator.size(); i++) { TypeAndValue typeAndValue = accumulator.get(i); - if (typeAndValue.getType().equals(BigintType.BIGINT)) { - statement.setLong(i + 1, (long) typeAndValue.getValue()); + int parameterIndex = i + 1; + Type type = typeAndValue.getType(); + WriteFunction writeFunction = getWriteMappingForAccumulators(type) + .orElseThrow(() -> new PrestoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName())) + .getWriteFunction(); + Class javaType = type.getJavaType(); + Object value = typeAndValue.getValue(); + if (javaType == boolean.class) { + ((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, (boolean) value); } - else if (typeAndValue.getType().equals(IntegerType.INTEGER)) { - statement.setInt(i + 1, ((Number) typeAndValue.getValue()).intValue()); + else if (javaType == double.class) { + ((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, (double) value); } - else if (typeAndValue.getType().equals(SmallintType.SMALLINT)) { - statement.setShort(i + 1, ((Number) typeAndValue.getValue()).shortValue()); + else if (javaType == long.class) { + ((LongWriteFunction) writeFunction).set(statement, parameterIndex, (long) value); } - else if (typeAndValue.getType().equals(TinyintType.TINYINT)) { - statement.setByte(i + 1, ((Number) typeAndValue.getValue()).byteValue()); - } - else if (typeAndValue.getType().equals(DoubleType.DOUBLE)) { - statement.setDouble(i + 1, (double) typeAndValue.getValue()); - } - else if 
(typeAndValue.getType().equals(RealType.REAL)) { - statement.setFloat(i + 1, intBitsToFloat(((Number) typeAndValue.getValue()).intValue())); - } - else if (typeAndValue.getType().equals(BooleanType.BOOLEAN)) { - statement.setBoolean(i + 1, (boolean) typeAndValue.getValue()); - } - else if (typeAndValue.getType().equals(DateType.DATE)) { - long millis = DAYS.toMillis((long) typeAndValue.getValue()); - statement.setDate(i + 1, new Date(UTC.getMillisKeepLocal(DateTimeZone.getDefault(), millis))); - } - else if (typeAndValue.getType().equals(TimeType.TIME)) { - statement.setTime(i + 1, new Time((long) typeAndValue.getValue())); - } - else if (typeAndValue.getType().equals(TimeWithTimeZoneType.TIME_WITH_TIME_ZONE)) { - statement.setTime(i + 1, new Time(unpackMillisUtc((long) typeAndValue.getValue()))); - } - else if (typeAndValue.getType().equals(TimestampType.TIMESTAMP)) { - statement.setTimestamp(i + 1, new Timestamp((long) typeAndValue.getValue())); - } - else if (typeAndValue.getType().equals(TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE)) { - statement.setTimestamp(i + 1, new Timestamp(unpackMillisUtc((long) typeAndValue.getValue()))); - } - else if (typeAndValue.getType() instanceof VarcharType) { - statement.setString(i + 1, ((Slice) typeAndValue.getValue()).toStringUtf8()); - } - else if (typeAndValue.getType() instanceof CharType) { - statement.setString(i + 1, ((Slice) typeAndValue.getValue()).toStringUtf8()); + else if (javaType == Slice.class) { + ((SliceWriteFunction) writeFunction).set(statement, parameterIndex, (Slice) value); } else { - throw new UnsupportedOperationException("Can't handle type: " + typeAndValue.getType()); + try { + ((ObjectWriteFunction) writeFunction).set(statement, parameterIndex, value); + } + catch (SQLException e) { + throw new PrestoException(NOT_SUPPORTED, "Unsupported column type: " + type.getDisplayName(), e); + } } } - return statement; } @@ -232,7 +213,7 @@ private String addColumns(List columns, Map co .collect(joining(", 
")); } - private static boolean isAcceptedType(Type type) + protected boolean isAcceptedType(Type type) { Type validType = requireNonNull(type, "type is null"); return validType.equals(BigintType.BIGINT) || @@ -259,14 +240,14 @@ private List toConjuncts(List columns, TupleDomain accumulator) + private String toPredicate(String columnName, Domain domain, JdbcColumnHandle columnHandle, List accumulator) { checkArgument(domain.getType().isOrderable(), "Domain type must be orderable"); @@ -288,10 +269,10 @@ private String toPredicate(String columnName, Domain domain, Type type, List rangeConjuncts = new ArrayList<>(); if (!range.isLowUnbounded()) { - rangeConjuncts.add(toPredicate(columnName, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), type, accumulator)); + rangeConjuncts.add(toPredicate(columnName, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), columnHandle, accumulator)); } if (!range.isHighUnbounded()) { - rangeConjuncts.add(toPredicate(columnName, range.isHighInclusive() ? "<=" : "<", range.getHighBoundedValue(), type, accumulator)); + rangeConjuncts.add(toPredicate(columnName, range.isHighInclusive() ? 
"<=" : "<", range.getHighBoundedValue(), columnHandle, accumulator)); } // If rangeConjuncts is null, then the range was ALL, which should already have been checked for checkState(!rangeConjuncts.isEmpty()); @@ -301,11 +282,11 @@ private String toPredicate(String columnName, Domain domain, Type type, List 1) { for (Object value : singleValues) { - bindValue(value, type, accumulator); + bindValue(value, columnHandle, accumulator); } String values = Joiner.on(",").join(nCopies(singleValues.size(), "?")); disjuncts.add(quote(columnName) + " IN (" + values + ")"); @@ -320,9 +301,9 @@ else if (singleValues.size() > 1) { return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; } - private String toPredicate(String columnName, String operator, Object value, Type type, List accumulator) + private String toPredicate(String columnName, String operator, Object value, JdbcColumnHandle columnHandle, List accumulator) { - bindValue(value, type, accumulator); + bindValue(value, columnHandle, accumulator); return quote(columnName) + " " + operator + " ?"; } @@ -337,9 +318,9 @@ public static String quote(String identifierQuote, String name) return identifierQuote + name + identifierQuote; } - private static void bindValue(Object value, Type type, List accumulator) + private static void bindValue(Object value, JdbcColumnHandle columnHandle, List accumulator) { - checkArgument(isAcceptedType(type), "Can't handle type: %s", type); + Type type = columnHandle.getColumnType(); accumulator.add(new TypeAndValue(type, value)); } } diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/StandardReadMappings.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/StandardReadMappings.java deleted file mode 100644 index 182997e37036c..0000000000000 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/StandardReadMappings.java +++ /dev/null @@ -1,237 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use 
this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.plugin.jdbc; - -import com.facebook.presto.common.type.CharType; -import com.facebook.presto.common.type.DecimalType; -import com.facebook.presto.common.type.Decimals; -import com.facebook.presto.common.type.VarcharType; -import com.google.common.base.CharMatcher; -import org.joda.time.chrono.ISOChronology; - -import java.sql.ResultSet; -import java.sql.Time; -import java.sql.Timestamp; -import java.sql.Types; -import java.util.Optional; - -import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.common.type.BooleanType.BOOLEAN; -import static com.facebook.presto.common.type.CharType.createCharType; -import static com.facebook.presto.common.type.DateType.DATE; -import static com.facebook.presto.common.type.DecimalType.createDecimalType; -import static com.facebook.presto.common.type.Decimals.encodeScaledValue; -import static com.facebook.presto.common.type.Decimals.encodeShortScaledValue; -import static com.facebook.presto.common.type.DoubleType.DOUBLE; -import static com.facebook.presto.common.type.IntegerType.INTEGER; -import static com.facebook.presto.common.type.RealType.REAL; -import static com.facebook.presto.common.type.SmallintType.SMALLINT; -import static com.facebook.presto.common.type.TimeType.TIME; -import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; -import static com.facebook.presto.common.type.TinyintType.TINYINT; -import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; 
-import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; -import static com.facebook.presto.common.type.VarcharType.createVarcharType; -import static com.facebook.presto.plugin.jdbc.ReadMapping.longReadMapping; -import static com.facebook.presto.plugin.jdbc.ReadMapping.sliceReadMapping; -import static io.airlift.slice.Slices.utf8Slice; -import static io.airlift.slice.Slices.wrappedBuffer; -import static java.lang.Float.floatToRawIntBits; -import static java.lang.Math.max; -import static java.lang.Math.min; -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static org.joda.time.DateTimeZone.UTC; - -public final class StandardReadMappings -{ - private StandardReadMappings() {} - - private static final ISOChronology UTC_CHRONOLOGY = ISOChronology.getInstanceUTC(); - - public static ReadMapping booleanReadMapping() - { - return ReadMapping.booleanReadMapping(BOOLEAN, ResultSet::getBoolean); - } - - public static ReadMapping tinyintReadMapping() - { - return longReadMapping(TINYINT, ResultSet::getByte); - } - - public static ReadMapping smallintReadMapping() - { - return longReadMapping(SMALLINT, ResultSet::getShort); - } - - public static ReadMapping integerReadMapping() - { - return longReadMapping(INTEGER, ResultSet::getInt); - } - - public static ReadMapping bigintReadMapping() - { - return longReadMapping(BIGINT, ResultSet::getLong); - } - - public static ReadMapping realReadMapping() - { - return longReadMapping(REAL, (resultSet, columnIndex) -> floatToRawIntBits(resultSet.getFloat(columnIndex))); - } - - public static ReadMapping doubleReadMapping() - { - return ReadMapping.doubleReadMapping(DOUBLE, ResultSet::getDouble); - } - - public static ReadMapping decimalReadMapping(DecimalType decimalType) - { - // JDBC driver can return BigDecimal with lower scale than column's scale when there are trailing zeroes - int scale = decimalType.getScale(); - if 
(decimalType.isShort()) { - return longReadMapping(decimalType, (resultSet, columnIndex) -> encodeShortScaledValue(resultSet.getBigDecimal(columnIndex), scale)); - } - return sliceReadMapping(decimalType, (resultSet, columnIndex) -> encodeScaledValue(resultSet.getBigDecimal(columnIndex), scale)); - } - - public static ReadMapping charReadMapping(CharType charType) - { - requireNonNull(charType, "charType is null"); - return sliceReadMapping(charType, (resultSet, columnIndex) -> utf8Slice(CharMatcher.is(' ').trimTrailingFrom(resultSet.getString(columnIndex)))); - } - - public static ReadMapping varcharReadMapping(VarcharType varcharType) - { - return sliceReadMapping(varcharType, (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex))); - } - - public static ReadMapping varbinaryReadMapping() - { - return sliceReadMapping(VARBINARY, (resultSet, columnIndex) -> wrappedBuffer(resultSet.getBytes(columnIndex))); - } - - public static ReadMapping dateReadMapping() - { - return longReadMapping(DATE, (resultSet, columnIndex) -> { - /* - * JDBC returns a date using a timestamp at midnight in the JVM timezone, or earliest time after that if there was no midnight. - * This works correctly for all dates and zones except when the missing local times 'gap' is 24h. I.e. this fails when JVM time - * zone is Pacific/Apia and date to be returned is 2011-12-30. - * - * `return resultSet.getObject(columnIndex, LocalDate.class).toEpochDay()` avoids these problems but - * is currently known not to work with Redshift (old Postgres connector) and SQL Server. - */ - long localMillis = resultSet.getDate(columnIndex).getTime(); - // Convert it to a ~midnight in UTC. 
- long utcMillis = ISOChronology.getInstance().getZone().getMillisKeepLocal(UTC, localMillis); - // convert to days - return MILLISECONDS.toDays(utcMillis); - }); - } - - public static ReadMapping timeReadMapping() - { - return longReadMapping(TIME, (resultSet, columnIndex) -> { - /* - * TODO `resultSet.getTime(columnIndex)` returns wrong value if JVM's zone had forward offset change during 1970-01-01 - * and the time value being retrieved was not present in local time (a 'gap'), e.g. time retrieved is 00:10:00 and JVM zone is America/Hermosillo - * The problem can be averted by using `resultSet.getObject(columnIndex, LocalTime.class)` -- but this is not universally supported by JDBC drivers. - */ - Time time = resultSet.getTime(columnIndex); - return UTC_CHRONOLOGY.millisOfDay().get(time.getTime()); - }); - } - - public static ReadMapping timestampReadMapping() - { - return longReadMapping(TIMESTAMP, (resultSet, columnIndex) -> { - /* - * TODO `resultSet.getTimestamp(columnIndex)` returns wrong value if JVM's zone had forward offset change and the local time - * corresponding to timestamp value being retrieved was not present (a 'gap'), this includes regular DST changes (e.g. Europe/Warsaw) - * and one-time policy changes (Asia/Kathmandu's shift by 15 minutes on January 1, 1986, 00:00:00). - * The problem can be averted by using `resultSet.getObject(columnIndex, LocalDateTime.class)` -- but this is not universally supported by JDBC drivers. 
- */ - Timestamp timestamp = resultSet.getTimestamp(columnIndex); - return timestamp.getTime(); - }); - } - - public static Optional jdbcTypeToPrestoType(JdbcTypeHandle type) - { - int columnSize = type.getColumnSize(); - switch (type.getJdbcType()) { - case Types.BIT: - case Types.BOOLEAN: - return Optional.of(booleanReadMapping()); - - case Types.TINYINT: - return Optional.of(tinyintReadMapping()); - - case Types.SMALLINT: - return Optional.of(smallintReadMapping()); - - case Types.INTEGER: - return Optional.of(integerReadMapping()); - - case Types.BIGINT: - return Optional.of(bigintReadMapping()); - - case Types.REAL: - return Optional.of(realReadMapping()); - - case Types.FLOAT: - case Types.DOUBLE: - return Optional.of(doubleReadMapping()); - - case Types.NUMERIC: - case Types.DECIMAL: - int decimalDigits = type.getDecimalDigits(); - int precision = columnSize + max(-decimalDigits, 0); // Map decimal(p, -s) (negative scale) to decimal(p+s, 0). - if (precision > Decimals.MAX_PRECISION) { - return Optional.empty(); - } - return Optional.of(decimalReadMapping(createDecimalType(precision, max(decimalDigits, 0)))); - - case Types.CHAR: - case Types.NCHAR: - // TODO this is wrong, we're going to construct malformed Slice representation if source > charLength - int charLength = min(columnSize, CharType.MAX_LENGTH); - return Optional.of(charReadMapping(createCharType(charLength))); - - case Types.VARCHAR: - case Types.NVARCHAR: - case Types.LONGVARCHAR: - case Types.LONGNVARCHAR: - if (columnSize > VarcharType.MAX_LENGTH) { - return Optional.of(varcharReadMapping(createUnboundedVarcharType())); - } - return Optional.of(varcharReadMapping(createVarcharType(columnSize))); - - case Types.BINARY: - case Types.VARBINARY: - case Types.LONGVARBINARY: - return Optional.of(varbinaryReadMapping()); - - case Types.DATE: - return Optional.of(dateReadMapping()); - - case Types.TIME: - return Optional.of(timeReadMapping()); - - case Types.TIMESTAMP: - return 
Optional.of(timestampReadMapping()); - } - return Optional.empty(); - } -} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/TableLocationProvider.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/TableLocationProvider.java new file mode 100644 index 0000000000000..718889b1518fd --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/TableLocationProvider.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc; + +/** + * Provides table location information for JDBC connectors. + * Different implementations can provide location information based on + * different sources (e.g., connection URL, service discovery, etc.) + */ +public interface TableLocationProvider +{ + /** + * Returns the location/URL for the table. + * This could be a connection URL, service endpoint, or any other + * location identifier relevant to the specific connector. 
+ */ + String getTableLocation(); +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/ReadFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/ReadFunction.java similarity index 94% rename from presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/ReadFunction.java rename to presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/ReadFunction.java index 80a9c78f765e5..6c576da5c5d7a 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/ReadFunction.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/ReadFunction.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.plugin.jdbc; +package com.facebook.presto.plugin.jdbc.mapping; public interface ReadFunction { diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/ReadMapping.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/ReadMapping.java similarity index 57% rename from presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/ReadMapping.java rename to presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/ReadMapping.java index 4751a84765c06..d9c6d68d59bea 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/ReadMapping.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/ReadMapping.java @@ -11,32 +11,47 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.plugin.jdbc; +package com.facebook.presto.plugin.jdbc.mapping; import com.facebook.presto.common.type.Type; +import com.facebook.presto.plugin.jdbc.mapping.functions.BooleanReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.DoubleReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.LongReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.ObjectReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.SliceReadFunction; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.Slices.wrappedBuffer; import static java.util.Objects.requireNonNull; +/* + * JDBC based connectors can control how data should be read from a ResultSet via ReadMapping definitions. + */ public final class ReadMapping { - public static ReadMapping booleanReadMapping(Type prestoType, BooleanReadFunction readFunction) + public static ReadMapping createBooleanReadMapping(Type prestoType, BooleanReadFunction readFunction) { return new ReadMapping(prestoType, readFunction); } - public static ReadMapping longReadMapping(Type prestoType, LongReadFunction readFunction) + public static ReadMapping createLongReadMapping(Type prestoType, LongReadFunction readFunction) { return new ReadMapping(prestoType, readFunction); } - public static ReadMapping doubleReadMapping(Type prestoType, DoubleReadFunction readFunction) + public static ReadMapping createDoubleReadMapping(Type prestoType, DoubleReadFunction readFunction) { return new ReadMapping(prestoType, readFunction); } - public static ReadMapping sliceReadMapping(Type prestoType, SliceReadFunction readFunction) + public static ReadMapping createSliceReadMapping(Type prestoType, SliceReadFunction readFunction) + { + return new ReadMapping(prestoType, readFunction); + } + + public 
static ReadMapping createObjectReadMapping(Type prestoType, ObjectReadFunction readFunction) { return new ReadMapping(prestoType, readFunction); } @@ -50,7 +65,7 @@ private ReadMapping(Type type, ReadFunction readFunction) this.readFunction = requireNonNull(readFunction, "readFunction is null"); checkArgument( type.getJavaType() == readFunction.getJavaType(), - "Presto type %s is not compatible with read function %s returning %s", + "Presto type %s is not compatible with read function %s using %s", type, readFunction, readFunction.getJavaType()); @@ -66,6 +81,11 @@ public ReadFunction getReadFunction() return readFunction; } + public static ReadMapping varbinaryReadMapping() + { + return createSliceReadMapping(VARBINARY, (resultSet, columnIndex) -> wrappedBuffer(resultSet.getBytes(columnIndex))); + } + @Override public String toString() { diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/StandardColumnMappings.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/StandardColumnMappings.java new file mode 100644 index 0000000000000..258f763614e35 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/StandardColumnMappings.java @@ -0,0 +1,488 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc.mapping; + +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.UuidType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.plugin.jdbc.JdbcTypeHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.google.common.base.CharMatcher; +import com.google.common.primitives.Shorts; +import com.google.common.primitives.SignedBytes; +import org.joda.time.DateTimeZone; +import org.joda.time.chrono.ISOChronology; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; +import java.sql.Date; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.Instant; +import java.util.Calendar; +import java.util.Optional; +import java.util.TimeZone; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.CharType.createCharType; +import static com.facebook.presto.common.type.DateTimeEncoding.unpackMillisUtc; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DecimalType.createDecimalType; +import static com.facebook.presto.common.type.Decimals.decodeUnscaledValue; +import static com.facebook.presto.common.type.Decimals.encodeScaledValue; +import static com.facebook.presto.common.type.Decimals.encodeShortScaledValue; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static 
com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.common.type.SmallintType.SMALLINT; +import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static com.facebook.presto.common.type.UuidType.UUID; +import static com.facebook.presto.common.type.UuidType.prestoUuidToJavaUuid; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.plugin.jdbc.GeometryUtils.getAsText; +import static com.facebook.presto.plugin.jdbc.GeometryUtils.stGeomFromBinary; +import static com.facebook.presto.plugin.jdbc.mapping.ReadMapping.createBooleanReadMapping; +import static com.facebook.presto.plugin.jdbc.mapping.ReadMapping.createDoubleReadMapping; +import static com.facebook.presto.plugin.jdbc.mapping.ReadMapping.createLongReadMapping; +import static com.facebook.presto.plugin.jdbc.mapping.ReadMapping.createSliceReadMapping; +import static com.facebook.presto.plugin.jdbc.mapping.WriteMapping.createBooleanWriteMapping; +import static com.facebook.presto.plugin.jdbc.mapping.WriteMapping.createDoubleWriteMapping; +import static com.facebook.presto.plugin.jdbc.mapping.WriteMapping.createLongWriteMapping; +import static com.facebook.presto.plugin.jdbc.mapping.WriteMapping.createSliceWriteMapping; +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static java.lang.Float.floatToRawIntBits; 
+import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.DAYS; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.joda.time.DateTimeZone.UTC; + +public final class StandardColumnMappings +{ + private StandardColumnMappings() {} + + private static final ISOChronology UTC_CHRONOLOGY = ISOChronology.getInstanceUTC(); + private static final Calendar UTC_CALENDAR = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + + public static ReadMapping booleanReadMapping() + { + return createBooleanReadMapping(BOOLEAN, ResultSet::getBoolean); + } + + public static WriteMapping booleanWriteMapping() + { + return createBooleanWriteMapping(PreparedStatement::setBoolean); + } + + public static ReadMapping tinyintReadMapping() + { + return createLongReadMapping(TINYINT, ResultSet::getByte); + } + + public static WriteMapping tinyintWriteMapping() + { + return createLongWriteMapping(((statement, index, value) -> statement.setByte(index, SignedBytes.checkedCast(value)))); + } + + public static ReadMapping smallintReadMapping() + { + return createLongReadMapping(SMALLINT, ResultSet::getShort); + } + + public static WriteMapping smallintWriteMapping() + { + return createLongWriteMapping(((statement, index, value) -> statement.setShort(index, Shorts.checkedCast(value)))); + } + + public static ReadMapping integerReadMapping() + { + return createLongReadMapping(INTEGER, ResultSet::getInt); + } + + public static WriteMapping integerWriteMapping() + { + return createLongWriteMapping((((statement, index, value) -> statement.setInt(index, toIntExact(value))))); + } + + public static ReadMapping bigintReadMapping() + { + return createLongReadMapping(BIGINT, ResultSet::getLong); + } + + public static WriteMapping bigintWriteMapping() + { + return 
createLongWriteMapping(PreparedStatement::setLong); + } + + public static ReadMapping realReadMapping() + { + return createLongReadMapping(REAL, (resultSet, columnIndex) -> floatToRawIntBits(resultSet.getFloat(columnIndex))); + } + public static WriteMapping realWriteMapping() + { + return createLongWriteMapping((statement, index, value) -> statement.setFloat(index, intBitsToFloat(toIntExact(value)))); + } + + public static ReadMapping doubleReadMapping() + { + return createDoubleReadMapping(DOUBLE, ResultSet::getDouble); + } + + public static WriteMapping doubleWriteMapping() + { + return createDoubleWriteMapping(PreparedStatement::setDouble); + } + + public static ReadMapping decimalReadMapping(DecimalType decimalType) + { + // JDBC driver can return BigDecimal with lower scale than column's scale when there are trailing zeroes + int scale = decimalType.getScale(); + if (decimalType.isShort()) { + return createLongReadMapping(decimalType, (resultSet, columnIndex) -> encodeShortScaledValue(resultSet.getBigDecimal(columnIndex), scale)); + } + return createSliceReadMapping(decimalType, (resultSet, columnIndex) -> encodeScaledValue(resultSet.getBigDecimal(columnIndex), scale)); + } + + public static WriteMapping decimalWriteMapping(DecimalType decimalType) + { + // JDBC driver can return BigDecimal with lower scale than column's scale when there are trailing zeroes + int scale = decimalType.getScale(); + if (decimalType.isShort()) { + return createLongWriteMapping(((statement, index, value) -> { + BigInteger unscaledValue = BigInteger.valueOf(value); + BigDecimal bigDecimal = new BigDecimal(unscaledValue, decimalType.getScale(), new MathContext(decimalType.getPrecision())); + statement.setBigDecimal(index, bigDecimal); + })); + } + return createSliceWriteMapping(((statement, index, value) -> { + BigInteger unscaledValue = decodeUnscaledValue(value); + BigDecimal bigDecimal = new BigDecimal(unscaledValue, decimalType.getScale(), new 
MathContext(decimalType.getPrecision())); + statement.setBigDecimal(index, bigDecimal); + })); + } + + public static ReadMapping charReadMapping(CharType charType) + { + requireNonNull(charType, "charType is null"); + return createSliceReadMapping(charType, (resultSet, columnIndex) -> utf8Slice(CharMatcher.is(' ').trimTrailingFrom(resultSet.getString(columnIndex)))); + } + + public static WriteMapping charWriteMapping() + { + return createSliceWriteMapping(((statement, index, value) -> statement.setString(index, value.toStringUtf8()))); + } + + public static ReadMapping varcharReadMapping(VarcharType varcharType) + { + return createSliceReadMapping(varcharType, (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex))); + } + + public static ReadMapping varbinaryReadMapping() + { + return createSliceReadMapping(VARBINARY, (resultSet, columnIndex) -> wrappedBuffer(resultSet.getBytes(columnIndex))); + } + + public static WriteMapping varbinaryWriteMapping() + { + return createSliceWriteMapping(((statement, index, value) -> statement.setBytes(index, value.getBytes()))); + } + + public static ReadMapping dateReadMapping() + { + return createLongReadMapping(DATE, (resultSet, columnIndex) -> { + /* + * JDBC returns a date using a timestamp at midnight in the JVM timezone, or earliest time after that if there was no midnight. + * This works correctly for all dates and zones except when the missing local times 'gap' is 24h. I.e. this fails when JVM time + * zone is Pacific/Apia and date to be returned is 2011-12-30. + * + * `return resultSet.getObject(columnIndex, LocalDate.class).toEpochDay()` avoids these problems but + * is currently known not to work with Redshift (old Postgres connector) and SQL Server. + */ + long localMillis = resultSet.getDate(columnIndex).getTime(); + // Convert it to a ~midnight in UTC. 
+ long utcMillis = ISOChronology.getInstance().getZone().getMillisKeepLocal(UTC, localMillis); + // convert to days + return MILLISECONDS.toDays(utcMillis); + }); + } + public static WriteMapping dateWriteMapping() + { + return createLongWriteMapping(((statement, index, value) -> statement.setDate(index, new Date(UTC.getMillisKeepLocal + (DateTimeZone.getDefault(), DAYS.toMillis(value)))))); + } + + public static ReadMapping timeReadMapping() + { + return createLongReadMapping(TIME, (resultSet, columnIndex) -> { + /* + * TODO `resultSet.getTime(columnIndex)` returns wrong value if JVM's zone had forward offset change during 1970-01-01 + * and the time value being retrieved was not present in local time (a 'gap'), e.g. time retrieved is 00:10:00 and JVM zone is America/Hermosillo + * The problem can be averted by using `resultSet.getObject(columnIndex, LocalTime.class)` -- but this is not universally supported by JDBC drivers. + */ + Time time = resultSet.getTime(columnIndex); + return UTC_CHRONOLOGY.millisOfDay().get(time.getTime()); + }); + } + + public static WriteMapping timeWriteMapping() + { + return createLongWriteMapping(((statement, index, value) -> statement.setTime(index, new Time(value)))); + } + + public static ReadMapping timestampReadMapping() + { + return createLongReadMapping(TIMESTAMP, (resultSet, columnIndex) -> { + Timestamp timestamp = resultSet.getTimestamp(columnIndex, UTC_CALENDAR); + return timestamp.getTime(); + }); + } + + @Deprecated + public static ReadMapping timestampReadMappingLegacy() + { + return createLongReadMapping(TIMESTAMP, (resultSet, columnIndex) -> { + Timestamp timestamp = resultSet.getTimestamp(columnIndex); + return timestamp.getTime(); + }); + } + + public static WriteMapping timestampWriteMapping(TimestampType timestampType) + { + return createLongWriteMapping((statement, index, value) -> { + statement.setTimestamp(index, Timestamp.from(Instant.ofEpochSecond( + timestampType.getEpochSecond(value), + 
timestampType.getNanos(value))), UTC_CALENDAR); + }); + } + + @Deprecated + public static WriteMapping timestampWriteMappingLegacy(TimestampType timestampType) + { + return createLongWriteMapping((statement, index, value) -> { + statement.setTimestamp(index, Timestamp.from(Instant.ofEpochSecond( + timestampType.getEpochSecond(value), + timestampType.getNanos(value)))); + }); + } + public static WriteMapping uuidWriteMapping() + { + return createSliceWriteMapping(((statement, index, value) -> statement.setObject(index, prestoUuidToJavaUuid(value)))); + } + + public static WriteMapping timeWithTimeZoneWriteMapping() + { + return createLongWriteMapping((((statement, index, value) -> statement.setTime(index, new Time(unpackMillisUtc(value)))))); + } + + public static WriteMapping timestampWithTimeZoneWriteMapping() + { + return createLongWriteMapping(((statement, index, value) -> statement.setTimestamp(index, new Timestamp(unpackMillisUtc(value))))); + } + + public static Optional jdbcTypeToReadMapping(JdbcTypeHandle type) + { + int columnSize = type.getColumnSize(); + switch (type.getJdbcType()) { + case Types.BIT: + case Types.BOOLEAN: + return Optional.of(booleanReadMapping()); + + case Types.TINYINT: + return Optional.of(tinyintReadMapping()); + + case Types.SMALLINT: + return Optional.of(smallintReadMapping()); + + case Types.INTEGER: + return Optional.of(integerReadMapping()); + + case Types.BIGINT: + return Optional.of(bigintReadMapping()); + + case Types.REAL: + return Optional.of(realReadMapping()); + + case Types.FLOAT: + case Types.DOUBLE: + return Optional.of(doubleReadMapping()); + + case Types.NUMERIC: + case Types.DECIMAL: + int decimalDigits = type.getDecimalDigits(); + int precision = columnSize + max(-decimalDigits, 0); // Map decimal(p, -s) (negative scale) to decimal(p+s, 0). 
+ if (precision > Decimals.MAX_PRECISION) { + return Optional.empty(); + } + return Optional.of(decimalReadMapping(createDecimalType(precision, max(decimalDigits, 0)))); + + case Types.CHAR: + case Types.NCHAR: + // TODO this is wrong, we're going to construct malformed Slice representation if source > charLength + int charLength = min(columnSize, CharType.MAX_LENGTH); + return Optional.of(charReadMapping(createCharType(charLength))); + + case Types.VARCHAR: + case Types.NVARCHAR: + case Types.LONGVARCHAR: + case Types.LONGNVARCHAR: + if (columnSize > VarcharType.MAX_LENGTH) { + return Optional.of(varcharReadMapping(createUnboundedVarcharType())); + } + return Optional.of(varcharReadMapping(createVarcharType(columnSize))); + + case Types.BINARY: + case Types.VARBINARY: + case Types.LONGVARBINARY: + return Optional.of(varbinaryReadMapping()); + + case Types.DATE: + return Optional.of(dateReadMapping()); + case Types.TIME: + return Optional.of(timeReadMapping()); + case Types.TIMESTAMP: + return Optional.of(timestampReadMapping()); + } + return Optional.empty(); + } + + public static Optional prestoTypeToWriteMapping(ConnectorSession session, Type type) + { + if (type.equals(BOOLEAN)) { + return Optional.of(booleanWriteMapping()); + } + else if (type.equals(TINYINT)) { + return Optional.of(tinyintWriteMapping()); + } + else if (type.equals(SMALLINT)) { + return Optional.of(smallintWriteMapping()); + } + else if (type.equals(BIGINT)) { + return Optional.of(bigintWriteMapping()); + } + else if (type.equals(DOUBLE)) { + return Optional.of(doubleWriteMapping()); + } + else if (type.equals(INTEGER)) { + return Optional.of(integerWriteMapping()); + } + else if (type.equals(REAL)) { + return Optional.of(realWriteMapping()); + } + else if (type instanceof DecimalType) { + return Optional.of(decimalWriteMapping((DecimalType) type)); + } + else if (type instanceof CharType || type instanceof VarcharType) { + return Optional.of(charWriteMapping()); + } + else if 
(type.equals(VARBINARY)) { + return Optional.of(varbinaryWriteMapping()); + } + else if (type instanceof DateType) { + return Optional.of(dateWriteMapping()); + } + else if (type instanceof TimestampType) { + boolean legacyTimestamp = session.getSqlFunctionProperties().isLegacyTimestamp(); + return Optional.of(legacyTimestamp ? timestampWriteMappingLegacy((TimestampType) type) : timestampWriteMapping((TimestampType) type)); + } + else if (type.equals(TIME)) { + return Optional.of(timeWriteMapping()); + } + else if (type.equals(TIME_WITH_TIME_ZONE)) { + return Optional.of(timeWithTimeZoneWriteMapping()); + } + else if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) { + return Optional.of(timestampWithTimeZoneWriteMapping()); + } + else if (type.equals(UUID)) { + return Optional.of(uuidWriteMapping()); + } + return Optional.empty(); + } + + public static Optional getWriteMappingForAccumulators(Type type) + { + if (type.equals(BOOLEAN)) { + return Optional.of(booleanWriteMapping()); + } + else if (type.equals(TINYINT)) { + return Optional.of(tinyintWriteMapping()); + } + else if (type.equals(SMALLINT)) { + return Optional.of(smallintWriteMapping()); + } + else if (type.equals(INTEGER)) { + return Optional.of(integerWriteMapping()); + } + else if (type.equals(BIGINT)) { + return Optional.of(bigintWriteMapping()); + } + else if (type.equals(REAL)) { + return Optional.of(realWriteMapping()); + } + else if (type.equals(DOUBLE)) { + return Optional.of(doubleWriteMapping()); + } + else if (type instanceof CharType || type instanceof VarcharType) { + return Optional.of(charWriteMapping()); + } + else if (type instanceof DecimalType) { + return Optional.of(decimalWriteMapping((DecimalType) type)); + } + else if (type.equals(DateType.DATE)) { + return Optional.of(dateWriteMapping()); + } + else if (type.equals(TIME)) { + return Optional.of(timeWriteMapping()); + } + else if (type.equals(TIMESTAMP)) { + return Optional.of(timestampWriteMapping((TimestampType) type)); + } + else if 
(type.equals(TIME_WITH_TIME_ZONE)) { + return Optional.of(timeWithTimeZoneWriteMapping()); + } + else if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) { + return Optional.of(timestampWithTimeZoneWriteMapping()); + } + else if (type instanceof UuidType) { + return Optional.of(uuidWriteMapping()); + } + return Optional.empty(); + } + public static ReadMapping geometryReadMapping() + { + return createSliceReadMapping(VARCHAR, + (resultSet, columnIndex) -> getAsText(stGeomFromBinary(wrappedBuffer(resultSet.getBytes(columnIndex))))); + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/WriteFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/WriteFunction.java new file mode 100644 index 0000000000000..48ae87cff0be4 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/WriteFunction.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc.mapping; + +public interface WriteFunction +{ + Class getJavaType(); +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/WriteMapping.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/WriteMapping.java new file mode 100644 index 0000000000000..afe56fd45d214 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/WriteMapping.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc.mapping; + +import com.facebook.presto.plugin.jdbc.mapping.functions.BooleanWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.DoubleWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.LongWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.ObjectWriteFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.SliceWriteFunction; + +import static java.util.Objects.requireNonNull; + +/* + * JDBC based connectors can control to define how data should be written back to data source by WriteFunctions. 
+ */ +public final class WriteMapping +{ + private final WriteFunction writeFunction; + + private WriteMapping(WriteFunction writeFunction) + { + this.writeFunction = requireNonNull(writeFunction, "writeFunction is null"); + } + + public static WriteMapping createBooleanWriteMapping(BooleanWriteFunction writeFunction) + { + return new WriteMapping(writeFunction); + } + + public static WriteMapping createLongWriteMapping(LongWriteFunction writeFunction) + { + return new WriteMapping(writeFunction); + } + + public static WriteMapping createDoubleWriteMapping(DoubleWriteFunction writeFunction) + { + return new WriteMapping(writeFunction); + } + + public static WriteMapping createSliceWriteMapping(SliceWriteFunction writeFunction) + { + return new WriteMapping(writeFunction); + } + + public static WriteMapping createObjectWriteMapping(Class javaType, ObjectWriteFunction.ObjectWriteFunctionImplementation writeFunctionImplementation) + { + return createObjectWriteMapping(ObjectWriteFunction.of(javaType, writeFunctionImplementation)); + } + + public static WriteMapping createObjectWriteMapping(ObjectWriteFunction writeFunction) + { + return new WriteMapping(writeFunction); + } + + public WriteFunction getWriteFunction() + { + return writeFunction; + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BooleanReadFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/BooleanReadFunction.java similarity index 88% rename from presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BooleanReadFunction.java rename to presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/BooleanReadFunction.java index 3ef463fcc3bd6..288358aa9f547 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/BooleanReadFunction.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/BooleanReadFunction.java @@ -11,7 +11,9 @@ * See the License 
for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.plugin.jdbc; +package com.facebook.presto.plugin.jdbc.mapping.functions; + +import com.facebook.presto.plugin.jdbc.mapping.ReadFunction; import java.sql.ResultSet; import java.sql.SQLException; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/BooleanWriteFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/BooleanWriteFunction.java new file mode 100644 index 0000000000000..161d3e22ba04d --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/BooleanWriteFunction.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc.mapping.functions; + +import com.facebook.presto.plugin.jdbc.mapping.WriteFunction; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public interface BooleanWriteFunction + extends WriteFunction +{ + default Class getJavaType() + { + return boolean.class; + } + + void set(PreparedStatement statement, int index, boolean value) throws SQLException; +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/DoubleReadFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/DoubleReadFunction.java similarity index 88% rename from presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/DoubleReadFunction.java rename to presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/DoubleReadFunction.java index 223cbe2f6ef5d..281d319821e80 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/DoubleReadFunction.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/DoubleReadFunction.java @@ -11,7 +11,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.plugin.jdbc; +package com.facebook.presto.plugin.jdbc.mapping.functions; + +import com.facebook.presto.plugin.jdbc.mapping.ReadFunction; import java.sql.ResultSet; import java.sql.SQLException; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/DoubleWriteFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/DoubleWriteFunction.java new file mode 100644 index 0000000000000..11b17da63ce36 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/DoubleWriteFunction.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc.mapping.functions; + +import com.facebook.presto.plugin.jdbc.mapping.WriteFunction; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public interface DoubleWriteFunction + extends WriteFunction +{ + default Class getJavaType() + { + return double.class; + } + + void set(PreparedStatement statement, int index, double value) throws SQLException; +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/LongReadFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/LongReadFunction.java similarity index 87% rename from presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/LongReadFunction.java rename to presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/LongReadFunction.java index f126ce9ce2dcd..042e0443d20e5 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/LongReadFunction.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/LongReadFunction.java @@ -11,7 +11,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.plugin.jdbc; +package com.facebook.presto.plugin.jdbc.mapping.functions; + +import com.facebook.presto.plugin.jdbc.mapping.ReadFunction; import java.sql.ResultSet; import java.sql.SQLException; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/LongWriteFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/LongWriteFunction.java new file mode 100644 index 0000000000000..5bc7c1707ba30 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/LongWriteFunction.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc.mapping.functions; + +import com.facebook.presto.plugin.jdbc.mapping.WriteFunction; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public interface LongWriteFunction + extends WriteFunction +{ + default Class getJavaType() + { + return long.class; + } + + void set(PreparedStatement statement, int index, long value) throws SQLException; +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/ObjectReadFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/ObjectReadFunction.java new file mode 100644 index 0000000000000..f7a7149d948b3 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/ObjectReadFunction.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc.mapping.functions; + +import com.facebook.presto.plugin.jdbc.mapping.ReadFunction; + +import java.sql.ResultSet; +import java.sql.SQLException; + +import static java.util.Objects.requireNonNull; + +public interface ObjectReadFunction + extends ReadFunction +{ + @Override + Class getJavaType(); + + Object readObject(ResultSet resultSet, int columnIndex) throws SQLException; + + static ObjectReadFunction of(Class javaType, ObjectReadFunctionImplementation implementation) + { + requireNonNull(javaType, "javaType is null"); + requireNonNull(implementation, "object read implementation is null"); + return new ObjectReadFunction() { + @Override + public Class getJavaType() + { + return javaType; + } + + @Override + public Object readObject(ResultSet resultSet, int columnIndex) throws SQLException + { + return implementation.read(resultSet, columnIndex); + } + }; + } + + @FunctionalInterface + interface ObjectReadFunctionImplementation + { + T read(ResultSet resultSet, int columnIndex) throws SQLException; + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/ObjectWriteFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/ObjectWriteFunction.java new file mode 100644 index 0000000000000..071392c9797be --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/ObjectWriteFunction.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc.mapping.functions; + +import com.facebook.presto.plugin.jdbc.mapping.WriteFunction; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +import static java.util.Objects.requireNonNull; + +public interface ObjectWriteFunction + extends WriteFunction +{ + @Override + Class getJavaType(); + + void set(PreparedStatement statement, int index, Object value) throws SQLException; + + static ObjectWriteFunction of(Class javaType, ObjectWriteFunctionImplementation implementation) + { + requireNonNull(javaType, "javaType is null"); + requireNonNull(implementation, "implementation is null"); + + return new ObjectWriteFunction() + { + @Override + public Class getJavaType() + { + return javaType; + } + + @Override + @SuppressWarnings("unchecked") + public void set(PreparedStatement statement, int index, Object value) + throws SQLException + { + implementation.set(statement, index, (T) value); + } + }; + } + + @FunctionalInterface + interface ObjectWriteFunctionImplementation + { + void set(PreparedStatement statement, int index, T value) throws SQLException; + } +} diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/SliceReadFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/SliceReadFunction.java similarity index 88% rename from presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/SliceReadFunction.java rename to presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/SliceReadFunction.java index a71f442ab8972..1e8d7ff107774 100644 --- a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/SliceReadFunction.java +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/SliceReadFunction.java @@ -11,8 +11,9 @@ * See the License for the specific language governing 
permissions and * limitations under the License. */ -package com.facebook.presto.plugin.jdbc; +package com.facebook.presto.plugin.jdbc.mapping.functions; +import com.facebook.presto.plugin.jdbc.mapping.ReadFunction; import io.airlift.slice.Slice; import java.sql.ResultSet; diff --git a/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/SliceWriteFunction.java b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/SliceWriteFunction.java new file mode 100644 index 0000000000000..088247f261104 --- /dev/null +++ b/presto-base-jdbc/src/main/java/com/facebook/presto/plugin/jdbc/mapping/functions/SliceWriteFunction.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc.mapping.functions; + +import com.facebook.presto.plugin.jdbc.mapping.WriteFunction; +import io.airlift.slice.Slice; + +import java.sql.PreparedStatement; +import java.sql.SQLException; + +public interface SliceWriteFunction + extends WriteFunction +{ + default Class getJavaType() + { + return Slice.class; + } + + void set(PreparedStatement statement, int index, Slice value) throws SQLException; +} diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/SimpleTestJdbcConfig.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/SimpleTestJdbcConfig.java new file mode 100644 index 0000000000000..7132404ea8280 --- /dev/null +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/SimpleTestJdbcConfig.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc; + +import com.facebook.airlift.configuration.Config; +import jakarta.validation.constraints.NotNull; + +/** + * Simple JDBC configuration that doesn't extend BaseJdbcConfig. + * This demonstrates that JDBC connectors can implement their own + * configuration mechanisms without depending on BaseJdbcConfig. 
+ */ +public class SimpleTestJdbcConfig +{ + private String jdbcUrl; + private String username; + private String password; + + @NotNull + public String getJdbcUrl() + { + return jdbcUrl; + } + + @Config("jdbc-url") + public SimpleTestJdbcConfig setJdbcUrl(String jdbcUrl) + { + this.jdbcUrl = jdbcUrl; + return this; + } + + public String getUsername() + { + return username; + } + + @Config("username") + public SimpleTestJdbcConfig setUsername(String username) + { + this.username = username; + return this; + } + + public String getPassword() + { + return password; + } + + @Config("password") + public SimpleTestJdbcConfig setPassword(String password) + { + this.password = password; + return this; + } +} diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/SimpleTestTableLocationProvider.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/SimpleTestTableLocationProvider.java new file mode 100644 index 0000000000000..c9af08e1eb385 --- /dev/null +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/SimpleTestTableLocationProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.jdbc; + +import javax.inject.Inject; + +import static java.util.Objects.requireNonNull; + +/** + * Simple TableLocationProvider implementation that doesn't depend on BaseJdbcConfig. 
+ * This demonstrates that TableLocationProvider implementations can use their own + * configuration mechanisms without depending on BaseJdbcConfig. + */ +public class SimpleTestTableLocationProvider + implements TableLocationProvider +{ + private final String jdbcUrl; + + @Inject + public SimpleTestTableLocationProvider(SimpleTestJdbcConfig config) + { + requireNonNull(config, "config is null"); + this.jdbcUrl = config.getJdbcUrl(); + } + + @Override + public String getTableLocation() + { + return jdbcUrl; + } +} diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestBaseJdbcConfig.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestBaseJdbcConfig.java index 2c3d151bb19f6..0f95e7f41c614 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestBaseJdbcConfig.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestBaseJdbcConfig.java @@ -14,14 +14,17 @@ package com.facebook.presto.plugin.jdbc; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; +import com.google.inject.ConfigurationException; import org.testng.annotations.Test; import java.util.Map; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.expectThrows; public class TestBaseJdbcConfig { @@ -36,7 +39,8 @@ public void testDefaults() .setPasswordCredentialName(null) .setCaseInsensitiveNameMatching(false) .setCaseInsensitiveNameMatchingCacheTtl(new Duration(1, MINUTES)) - .setlistSchemasIgnoredSchemas("information_schema")); + .setlistSchemasIgnoredSchemas("information_schema") + .setCaseSensitiveNameMatching(false)); } @Test @@ -51,6 +55,7 @@ public void testExplicitPropertyMappings() .put("case-insensitive-name-matching", "true") 
.put("case-insensitive-name-matching.cache-ttl", "1s") .put("list-schemas-ignored-schemas", "test,test2") + .put("case-sensitive-name-matching", "true") .build(); BaseJdbcConfig expected = new BaseJdbcConfig() @@ -61,8 +66,74 @@ public void testExplicitPropertyMappings() .setPasswordCredentialName("bar") .setCaseInsensitiveNameMatching(true) .setlistSchemasIgnoredSchemas("test,test2") - .setCaseInsensitiveNameMatchingCacheTtl(new Duration(1, SECONDS)); + .setCaseInsensitiveNameMatchingCacheTtl(new Duration(1, SECONDS)) + .setCaseSensitiveNameMatching(true); ConfigAssertions.assertFullMapping(properties, expected); } + + @Test + public void testValidConfigValidation() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + config.setConnectionUrl("jdbc:mysql://localhost:3306/test"); + + // Should not throw any exception + config.validateConfig(); + + assertEquals(config.getConnectionUrl(), "jdbc:mysql://localhost:3306/test"); + } + + @Test + public void testNullConnectionUrlValidation() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + // connectionUrl is null by default + + ConfigurationException exception = expectThrows( + ConfigurationException.class, + config::validateConfig); + assertEquals(exception.getErrorMessages().iterator().next().getMessage(), + "connection-url is required but was not provided"); + } + + @Test + public void testMutuallyExclusiveNameMatchingOptions() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + config.setConnectionUrl("jdbc:mysql://localhost:3306/test"); + config.setCaseInsensitiveNameMatching(true); + config.setCaseSensitiveNameMatching(true); + + ConfigurationException exception = expectThrows( + ConfigurationException.class, + config::validateConfig); + assertEquals(exception.getErrorMessages().iterator().next().getMessage(), + "Only one of 'case-insensitive-name-matching=true' or 'case-sensitive-name-matching=true' can be set. 
" + + "These options are mutually exclusive."); + } + + @Test + public void testCaseInsensitiveNameMatchingOnly() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + config.setConnectionUrl("jdbc:mysql://localhost:3306/test"); + config.setCaseInsensitiveNameMatching(true); + config.setCaseSensitiveNameMatching(false); + + // Should not throw any exception + config.validateConfig(); + } + + @Test + public void testCaseSensitiveNameMatchingOnly() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + config.setConnectionUrl("jdbc:mysql://localhost:3306/test"); + config.setCaseInsensitiveNameMatching(false); + config.setCaseSensitiveNameMatching(true); + + // Should not throw any exception + config.validateConfig(); + } } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestDefaultTableLocationProvider.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestDefaultTableLocationProvider.java new file mode 100644 index 0000000000000..85874cc1eede6 --- /dev/null +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestDefaultTableLocationProvider.java @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc; + +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.expectThrows; + +public class TestDefaultTableLocationProvider +{ + @Test + public void testValidConnectionUrl() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + config.setConnectionUrl("jdbc:mysql://localhost:3306/test"); + + DefaultTableLocationProvider provider = new DefaultTableLocationProvider(config); + assertEquals(provider.getTableLocation(), "jdbc:mysql://localhost:3306/test"); + } + + @Test + public void testNullBaseJdbcConfig() + { + NullPointerException exception = expectThrows( + NullPointerException.class, + () -> new DefaultTableLocationProvider(null)); + assertEquals(exception.getMessage(), "baseJdbcConfig is null"); + } + + @Test + public void testNullConnectionUrl() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + // connectionUrl is null by default + + NullPointerException exception = expectThrows( + NullPointerException.class, + () -> new DefaultTableLocationProvider(config)); + assertEquals(exception.getMessage(), "connection-url is null"); + } + + @Test + public void testEmptyConnectionUrl() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + config.setConnectionUrl(""); + + DefaultTableLocationProvider provider = new DefaultTableLocationProvider(config); + assertEquals(provider.getTableLocation(), ""); + } + + @Test + public void testWhitespaceConnectionUrl() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + config.setConnectionUrl(" "); + + DefaultTableLocationProvider provider = new DefaultTableLocationProvider(config); + assertEquals(provider.getTableLocation(), " "); + } + + @Test + public void testComplexConnectionUrl() + { + BaseJdbcConfig config = new BaseJdbcConfig(); + String complexUrl = "jdbc:mysql://user:password@host:3306/database?useSSL=true&serverTimezone=UTC"; + config.setConnectionUrl(complexUrl); + + DefaultTableLocationProvider provider = new 
DefaultTableLocationProvider(config); + assertEquals(provider.getTableLocation(), complexUrl); + } +} diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcDistributedQueries.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcDistributedQueries.java index 2d2962742a6e3..8b73ea66dc137 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcDistributedQueries.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcDistributedQueries.java @@ -13,11 +13,14 @@ */ package com.facebook.presto.plugin.jdbc; +import com.facebook.presto.Session; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueries; import io.airlift.tpch.TpchTable; +import org.testng.annotations.Test; import static com.facebook.presto.plugin.jdbc.JdbcQueryRunner.createJdbcQueryRunner; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; public class TestJdbcDistributedQueries extends AbstractTestQueries @@ -33,4 +36,15 @@ protected QueryRunner createQueryRunner() public void testLargeIn() { } + + @Test + public void testNativeQueryParameters() + { + Session session = testSessionBuilder() + .addPreparedStatement("my_query_simple", "SELECT * FROM TABLE(system.query(query => ?))") + .addPreparedStatement("my_query", "SELECT * FROM TABLE(system.query(query => format('SELECT %s FROM %s', ?, ?)))") + .build(); + assertQueryFails(session, "EXECUTE my_query_simple USING 'SELECT 1 a'", "line 1:21: Table function system.query not registered"); + assertQueryFails(session, "EXECUTE my_query USING 'a', '(SELECT 2 a) t'", "line 1:21: Table function system.query not registered"); + } } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadata.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadata.java index aa75c5ccc2345..58b3f9c4f3d1d 100644 --- 
a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadata.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadata.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.jdbc; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.PrestoException; @@ -66,7 +67,11 @@ public void setUp() database = new TestingDatabase(); ListeningExecutorService executor = listeningDecorator(newCachedThreadPool(daemonThreadsNamed("test-%s"))); jdbcMetadataCache = new JdbcMetadataCache(executor, database.getJdbcClient(), new JdbcMetadataCacheStats(), OptionalLong.of(0), OptionalLong.of(0), 100); - metadata = new JdbcMetadata(jdbcMetadataCache, database.getJdbcClient(), false); + + BaseJdbcConfig baseConfig = new BaseJdbcConfig(); + baseConfig.setConnectionUrl("jdbc:h2:mem:test"); + + metadata = new JdbcMetadata(jdbcMetadataCache, database.getJdbcClient(), false, new DefaultTableLocationProvider(baseConfig)); tableHandle = metadata.getTableHandle(SESSION, new SchemaTableName("example", "numbers")); } @@ -261,7 +266,10 @@ public void testDropTableTable() assertEquals(e.getErrorCode(), PERMISSION_DENIED.toErrorCode()); } - metadata = new JdbcMetadata(jdbcMetadataCache, database.getJdbcClient(), true); + // Create BaseJdbcConfig with connection URL for drop table test + BaseJdbcConfig dropConfig = new BaseJdbcConfig(); + dropConfig.setConnectionUrl("jdbc:h2:mem:test"); + metadata = new JdbcMetadata(jdbcMetadataCache, database.getJdbcClient(), true, new DefaultTableLocationProvider(dropConfig)); metadata.dropTable(SESSION, tableHandle); try { @@ -272,4 +280,46 @@ public void testDropTableTable() assertEquals(e.getErrorCode(), NOT_FOUND.toErrorCode()); } } + + @Test + public void testCustomTableLocationProvider() + { + // Create a custom TableLocationProvider that returns a specific location + 
TableLocationProvider customProvider = new TableLocationProvider() + { + @Override + public String getTableLocation() + { + return "custom://test-location:8080/database"; + } + }; + + // Create JdbcMetadata with the custom provider + JdbcMetadata customMetadata = new JdbcMetadata(jdbcMetadataCache, database.getJdbcClient(), false, customProvider); + + // Verify that the metadata can be created and basic operations work + JdbcTableHandle customTableHandle = customMetadata.getTableHandle(SESSION, new SchemaTableName("example", "numbers")); + assertEquals(customTableHandle, tableHandle); + + // Verify table metadata can be retrieved + ConnectorTableMetadata tableMetadata = customMetadata.getTableMetadata(SESSION, customTableHandle); + assertEquals(tableMetadata.getTable(), new SchemaTableName("example", "numbers")); + assertEquals(tableMetadata.getColumns().size(), 3); + + // Verify column handles work + Map columnHandles = customMetadata.getColumnHandles(SESSION, customTableHandle); + assertEquals(columnHandles.size(), 3); + assertTrue(columnHandles.containsKey("text")); + assertTrue(columnHandles.containsKey("text_short")); + assertTrue(columnHandles.containsKey("value")); + + // Verify schema listing works + assertTrue(customMetadata.listSchemaNames(SESSION).containsAll(ImmutableSet.of("example", "tpch"))); + + // Verify table listing works + assertEquals(ImmutableSet.copyOf(customMetadata.listTables(SESSION, Optional.of("example"))), ImmutableSet.of( + new SchemaTableName("example", "numbers"), + new SchemaTableName("example", "view_source"), + new SchemaTableName("example", "view"))); + } } diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadataConfig.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadataConfig.java index 6477239814c8a..24e066d92e2ec 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadataConfig.java +++ 
b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcMetadataConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.plugin.jdbc; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcQueryBuilder.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcQueryBuilder.java index 10b22e9d40acf..010b47e8e5a3f 100644 --- a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcQueryBuilder.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestJdbcQueryBuilder.java @@ -37,9 +37,12 @@ import java.sql.Timestamp; import java.time.LocalDate; import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.util.Calendar; import java.util.List; import java.util.Locale; import java.util.Optional; +import java.util.TimeZone; import static com.facebook.airlift.testing.Assertions.assertContains; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -74,6 +77,7 @@ @Test(singleThreaded = true) public class TestJdbcQueryBuilder { + private static final Calendar UTC_CALENDAR = Calendar.getInstance(TimeZone.getTimeZone("UTC")); private TestingDatabase database; private JdbcClient jdbcClient; private ConnectorSession session; @@ -348,7 +352,7 @@ public void testBuildSqlWithTimestamp() ResultSet resultSet = preparedStatement.executeQuery()) { ImmutableSet.Builder builder = ImmutableSet.builder(); while (resultSet.next()) { - builder.add((Timestamp) resultSet.getObject("col_6")); + builder.add(resultSet.getTimestamp("col_6", UTC_CALENDAR)); } assertEquals(builder.build(), ImmutableSet.of( toTimestamp(2016, 6, 3, 0, 23, 37), @@ -379,7 +383,7 @@ public void testEmptyBuildSql() private static Timestamp toTimestamp(int year, int month, int day, int hour, int minute, int 
second) { - return Timestamp.valueOf(LocalDateTime.of(year, month, day, hour, minute, second)); + return Timestamp.from(LocalDateTime.of(year, month, day, hour, minute, second).toInstant(ZoneOffset.UTC)); } private static long toDays(int year, int month, int day) diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestSimpleJdbcConnectorCompatibility.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestSimpleJdbcConnectorCompatibility.java new file mode 100644 index 0000000000000..c83648d1609a2 --- /dev/null +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestSimpleJdbcConnectorCompatibility.java @@ -0,0 +1,139 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.jdbc; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.SchemaTableName; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.ListeningExecutorService; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; + +import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.presto.testing.TestingConnectorSession.SESSION; +import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +/** + * Test to ensure that JDBC connectors can work without depending on BaseJdbcConfig. + * This verifies continued compatibility for connectors that implement their own + * configuration mechanisms. 
+ */ +@Test(singleThreaded = true) +public class TestSimpleJdbcConnectorCompatibility +{ + private TestingDatabase database; + private JdbcMetadata metadata; + private JdbcMetadataCache jdbcMetadataCache; + + @BeforeMethod + public void setUp() + throws Exception + { + database = new TestingDatabase(); + ListeningExecutorService executor = listeningDecorator(newCachedThreadPool(daemonThreadsNamed("test-%s"))); + jdbcMetadataCache = new JdbcMetadataCache(executor, database.getJdbcClient(), new JdbcMetadataCacheStats(), OptionalLong.of(0), OptionalLong.of(0), 100); + + // Create a simple config that doesn't extend BaseJdbcConfig + SimpleTestJdbcConfig simpleConfig = new SimpleTestJdbcConfig() + .setJdbcUrl("jdbc:h2:mem:test") + .setUsername("test") + .setPassword("test"); + + // Create a TableLocationProvider that uses the simple config + SimpleTestTableLocationProvider locationProvider = new SimpleTestTableLocationProvider(simpleConfig); + + // Create JdbcMetadata with the simple provider (not using BaseJdbcConfig) + metadata = new JdbcMetadata(jdbcMetadataCache, database.getJdbcClient(), false, locationProvider); + } + + @AfterMethod(alwaysRun = true) + public void tearDown() + throws Exception + { + database.close(); + } + + @Test + public void testSimpleConnectorBasicOperations() + { + // Verify that basic metadata operations work with a connector that doesn't use BaseJdbcConfig + + // Test schema listing + assertTrue(metadata.listSchemaNames(SESSION).containsAll(ImmutableSet.of("example", "tpch"))); + + // Test table handle retrieval + JdbcTableHandle tableHandle = metadata.getTableHandle(SESSION, new SchemaTableName("example", "numbers")); + assertNotNull(tableHandle); + + // Test table metadata retrieval + ConnectorTableMetadata tableMetadata = metadata.getTableMetadata(SESSION, tableHandle); + assertEquals(tableMetadata.getTable(), new SchemaTableName("example", "numbers")); + assertEquals(tableMetadata.getColumns().size(), 3); + + // Test column handles + 
Map columnHandles = metadata.getColumnHandles(SESSION, tableHandle); + assertEquals(columnHandles.size(), 3); + assertTrue(columnHandles.containsKey("text")); + assertTrue(columnHandles.containsKey("text_short")); + assertTrue(columnHandles.containsKey("value")); + + // Test table listing + assertEquals(ImmutableSet.copyOf(metadata.listTables(SESSION, Optional.of("example"))), ImmutableSet.of( + new SchemaTableName("example", "numbers"), + new SchemaTableName("example", "view_source"), + new SchemaTableName("example", "view"))); + } + + @Test + public void testSimpleTableLocationProvider() + { + // Create a simple config + SimpleTestJdbcConfig config = new SimpleTestJdbcConfig() + .setJdbcUrl("jdbc:test://custom-location:9999/testdb"); + + // Create the location provider + SimpleTestTableLocationProvider provider = new SimpleTestTableLocationProvider(config); + + // Verify it returns the expected location + assertEquals(provider.getTableLocation(), "jdbc:test://custom-location:9999/testdb"); + } + + @Test + public void testConfigurationIndependence() + { + // Verify that SimpleTestJdbcConfig works independently of BaseJdbcConfig + SimpleTestJdbcConfig config = new SimpleTestJdbcConfig() + .setJdbcUrl("jdbc:postgresql://localhost:5432/testdb") + .setUsername("testuser") + .setPassword("testpass"); + + assertEquals(config.getJdbcUrl(), "jdbc:postgresql://localhost:5432/testdb"); + assertEquals(config.getUsername(), "testuser"); + assertEquals(config.getPassword(), "testpass"); + + // Verify that this config can be used to create a working TableLocationProvider + SimpleTestTableLocationProvider provider = new SimpleTestTableLocationProvider(config); + assertEquals(provider.getTableLocation(), "jdbc:postgresql://localhost:5432/testdb"); + } +} diff --git a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java index ad0b4163f8556..46fb38385f8c9 100644 --- 
a/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java +++ b/presto-base-jdbc/src/test/java/com/facebook/presto/plugin/jdbc/TestingDatabase.java @@ -13,12 +13,12 @@ */ package com.facebook.presto.plugin.jdbc; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.SchemaTableName; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.h2.Driver; import java.sql.Connection; @@ -117,7 +117,7 @@ public Map getColumnHandles(String schemaName, String ImmutableMap.Builder columnHandles = ImmutableMap.builder(); for (JdbcColumnHandle column : columns) { - columnHandles.put(column.getColumnMetadata().getName(), column); + columnHandles.put(column.getColumnMetadata(session, jdbcClient).getName(), column); } return columnHandles.build(); } diff --git a/presto-benchmark-driver/pom.xml b/presto-benchmark-driver/pom.xml index 964510e12d51d..07254bcb4ed88 100644 --- a/presto-benchmark-driver/pom.xml +++ b/presto-benchmark-driver/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-benchmark-driver @@ -14,6 +14,8 @@ ${project.parent.basedir} com.facebook.presto.benchmark.driver.PrestoBenchmarkDriver + 17 + true @@ -23,7 +25,7 @@ - io.airlift + com.facebook.airlift airline @@ -43,7 +45,7 @@ - io.airlift + com.facebook.airlift units @@ -58,8 +60,8 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -91,48 +93,58 @@ - - - - org.apache.maven.plugins - maven-shade-plugin - - - package - - shade - + + + executable-jar + + + !skipExecutableJar + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + package + + shade + + + true + executable + + + + ${main-class} + + + + + + + + + + org.skife.maven + really-executable-jar-maven-plugin - true - executable - - - - 
${main-class} - - - + -Xmx1G + executable - - - - - - org.skife.maven - really-executable-jar-maven-plugin - - -Xmx1G - executable - - - - package - - really-executable-jar - - - - - - + + + package + + really-executable-jar + + + + + + + + diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java index 08e3b79b6b3a6..d2fb1da39fe8f 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkDriverOptions.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.benchmark.driver; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.ClientSession; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.net.HostAndPort; import io.airlift.airline.Option; -import io.airlift.units.Duration; import java.net.URI; import java.net.URISyntaxException; diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryResult.java b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryResult.java index ca4b7e818b275..db8e509b5afb9 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryResult.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryResult.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.benchmark.driver; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; import java.util.Optional; diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java 
b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java index 2f478e4388280..6d2a69d2dc954 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/BenchmarkQueryRunner.java @@ -20,6 +20,7 @@ import com.facebook.airlift.http.client.JsonResponseHandler; import com.facebook.airlift.http.client.Request; import com.facebook.airlift.http.client.jetty.JettyHttpClient; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.ClientSession; import com.facebook.presto.client.QueryData; import com.facebook.presto.client.QueryError; @@ -27,7 +28,6 @@ import com.facebook.presto.client.StatementStats; import com.google.common.collect.ImmutableList; import com.google.common.net.HostAndPort; -import io.airlift.units.Duration; import okhttp3.OkHttpClient; import java.io.Closeable; diff --git a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/PrestoBenchmarkDriver.java b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/PrestoBenchmarkDriver.java index e69dd85139a16..890fe58bb0b99 100644 --- a/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/PrestoBenchmarkDriver.java +++ b/presto-benchmark-driver/src/main/java/com/facebook/presto/benchmark/driver/PrestoBenchmarkDriver.java @@ -20,8 +20,7 @@ import com.google.common.collect.ImmutableList; import io.airlift.airline.Command; import io.airlift.airline.HelpOption; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.File; import java.io.IOException; diff --git a/presto-benchmark-runner/pom.xml b/presto-benchmark-runner/pom.xml index 63343a5251a2d..7e4b13c9e5571 100644 --- a/presto-benchmark-runner/pom.xml +++ b/presto-benchmark-runner/pom.xml @@ -5,7 +5,7 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT 
presto-benchmark-runner @@ -14,6 +14,8 @@ ${project.parent.basedir} com.facebook.presto.benchmark.PrestoBenchmarkRunner + 17 + true @@ -53,7 +55,7 @@ - io.airlift + com.facebook.airlift units @@ -63,8 +65,8 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -82,16 +84,6 @@ presto-parser - - com.facebook.presto - presto-main - - - - com.facebook.presto - presto-main-base - - com.facebook.presto presto-thrift-connector @@ -128,32 +120,43 @@ - com.google.code.findbugs - jsr305 - true + jakarta.inject + jakarta.inject-api - javax.inject - javax.inject + com.fasterxml.jackson.core + jackson-core - com.fasterxml.jackson.core - jackson-core + jakarta.annotation + jakarta.annotation-api - javax.annotation - javax.annotation-api + com.google.errorprone + error_prone_annotations - io.airlift + com.facebook.airlift airline + + com.facebook.presto + presto-main + test + + + + com.facebook.presto + presto-main-base + test + + com.facebook.presto presto-memory @@ -190,4 +193,21 @@ test + + + + org.apache.maven.plugins + maven-dependency-plugin + + + com.facebook.airlift:log-manager + + + + com.google.errorprone:error_prone_annotations + + + + + diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkPhaseEvent.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkPhaseEvent.java index a372cffd6ba3e..8eedfe200d37f 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkPhaseEvent.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkPhaseEvent.java @@ -15,8 +15,7 @@ import com.facebook.airlift.event.client.EventField; import com.facebook.airlift.event.client.EventType; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Optional; diff --git 
a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkQueryEvent.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkQueryEvent.java index c83887256ab5d..e61429c4090b8 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkQueryEvent.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkQueryEvent.java @@ -15,10 +15,9 @@ import com.facebook.airlift.event.client.EventField; import com.facebook.airlift.event.client.EventType; +import com.facebook.airlift.units.Duration; import com.facebook.presto.jdbc.QueryStats; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Optional; diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkSuiteEvent.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkSuiteEvent.java index c8c09253d9271..23c9918da42c5 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkSuiteEvent.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/BenchmarkSuiteEvent.java @@ -15,8 +15,7 @@ import com.facebook.airlift.event.client.EventField; import com.facebook.airlift.event.client.EventType; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import static com.facebook.presto.benchmark.event.BenchmarkSuiteEvent.Status.COMPLETED_WITH_FAILURES; import static com.facebook.presto.benchmark.event.BenchmarkSuiteEvent.Status.FAILED; diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/JsonEventClient.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/JsonEventClient.java index 370387a1fc09f..2dc3c744fd993 100644 --- 
a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/JsonEventClient.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/event/JsonEventClient.java @@ -18,8 +18,7 @@ import com.facebook.presto.benchmark.framework.BenchmarkRunnerConfig; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonGenerator; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.ByteArrayOutputStream; import java.io.FileNotFoundException; diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/framework/BenchmarkRunner.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/framework/BenchmarkRunner.java index e2c9be2063ecc..f741dad12d94f 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/framework/BenchmarkRunner.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/framework/BenchmarkRunner.java @@ -19,8 +19,7 @@ import com.facebook.presto.benchmark.source.BenchmarkSuiteSupplier; import com.google.common.collect.ImmutableSet; import com.google.inject.Inject; - -import javax.annotation.PostConstruct; +import jakarta.annotation.PostConstruct; import java.util.Set; diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/framework/BenchmarkRunnerConfig.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/framework/BenchmarkRunnerConfig.java index 66f9b1ae4d034..e06c53c41550c 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/framework/BenchmarkRunnerConfig.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/framework/BenchmarkRunnerConfig.java @@ -17,8 +17,7 @@ import com.facebook.airlift.configuration.ConfigDescription; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableSet; - -import javax.validation.constraints.NotNull; +import 
jakarta.validation.constraints.NotNull; import java.util.Optional; import java.util.Set; diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/prestoaction/JdbcPrestoAction.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/prestoaction/JdbcPrestoAction.java index 77757a73c94e1..13ebfdb560b04 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/prestoaction/JdbcPrestoAction.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/prestoaction/JdbcPrestoAction.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.benchmark.prestoaction; +import com.facebook.airlift.units.Duration; import com.facebook.presto.benchmark.framework.BenchmarkQuery; import com.facebook.presto.benchmark.framework.QueryException; import com.facebook.presto.benchmark.framework.QueryResult; @@ -24,7 +25,6 @@ import com.facebook.presto.sql.tree.Statement; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import java.sql.DriverManager; import java.sql.ResultSet; diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/prestoaction/PrestoClusterConfig.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/prestoaction/PrestoClusterConfig.java index 31a77ba9a1140..27fc3c691f224 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/prestoaction/PrestoClusterConfig.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/prestoaction/PrestoClusterConfig.java @@ -14,10 +14,9 @@ package com.facebook.presto.benchmark.prestoaction; import com.facebook.airlift.configuration.Config; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.NotNull; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import 
jakarta.validation.constraints.NotNull; import static java.util.concurrent.TimeUnit.MINUTES; diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/retry/RetryConfig.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/retry/RetryConfig.java index 477a781fc27e3..173e127ad3c05 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/retry/RetryConfig.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/retry/RetryConfig.java @@ -14,10 +14,9 @@ package com.facebook.presto.benchmark.retry; import com.facebook.airlift.configuration.Config; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.Min; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.NANOSECONDS; diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/retry/RetryDriver.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/retry/RetryDriver.java index 2dade9dc55073..bd7d4633c3dc0 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/retry/RetryDriver.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/retry/RetryDriver.java @@ -14,8 +14,8 @@ package com.facebook.presto.benchmark.retry; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.benchmark.framework.QueryException; -import io.airlift.units.Duration; import java.util.concurrent.ThreadLocalRandom; import java.util.function.Predicate; diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/source/BenchmarkSuiteConfig.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/source/BenchmarkSuiteConfig.java index 
71502fb785104..59c41cd6aa399 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/source/BenchmarkSuiteConfig.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/source/BenchmarkSuiteConfig.java @@ -14,8 +14,7 @@ package com.facebook.presto.benchmark.source; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class BenchmarkSuiteConfig { diff --git a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/source/MySqlBenchmarkSuiteConfig.java b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/source/MySqlBenchmarkSuiteConfig.java index c1f02f6752efa..0214b5c80ec4d 100644 --- a/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/source/MySqlBenchmarkSuiteConfig.java +++ b/presto-benchmark-runner/src/main/java/com/facebook/presto/benchmark/source/MySqlBenchmarkSuiteConfig.java @@ -14,8 +14,7 @@ package com.facebook.presto.benchmark.source; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class MySqlBenchmarkSuiteConfig { diff --git a/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/prestoaction/TestPrestoClusterConfig.java b/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/prestoaction/TestPrestoClusterConfig.java index 1fc7079aac156..8edd893be3c8b 100644 --- a/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/prestoaction/TestPrestoClusterConfig.java +++ b/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/prestoaction/TestPrestoClusterConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.benchmark.prestoaction; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import 
java.util.Map; diff --git a/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/retry/TestRetryConfig.java b/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/retry/TestRetryConfig.java index 9c65f19514241..3ce8d38368f4d 100644 --- a/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/retry/TestRetryConfig.java +++ b/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/retry/TestRetryConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.benchmark.retry; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/retry/TestRetryDriver.java b/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/retry/TestRetryDriver.java index ddd11b71d633c..4c5eb1d4e15ce 100644 --- a/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/retry/TestRetryDriver.java +++ b/presto-benchmark-runner/src/test/java/com/facebook/presto/benchmark/retry/TestRetryDriver.java @@ -14,8 +14,8 @@ package com.facebook.presto.benchmark.retry; import com.facebook.airlift.log.Logging; +import com.facebook.airlift.units.Duration; import com.facebook.presto.benchmark.framework.QueryException; -import io.airlift.units.Duration; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; diff --git a/presto-benchmark/pom.xml b/presto-benchmark/pom.xml index 8d15ed8f76347..50aedc634e0be 100644 --- a/presto-benchmark/pom.xml +++ b/presto-benchmark/pom.xml @@ -5,7 +5,7 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT presto-benchmark @@ -13,6 +13,8 @@ ${project.parent.basedir} + 17 + true @@ -42,14 +44,8 @@ - com.facebook.presto - presto-memory - - - - com.google.code.findbugs - jsr305 - true + jakarta.annotation + jakarta.annotation-api @@ -68,7 +64,7 @@ - io.airlift + 
com.facebook.airlift units @@ -106,8 +102,8 @@ - javax.ws.rs - javax.ws.rs-api + jakarta.ws.rs + jakarta.ws.rs-api runtime @@ -118,6 +114,12 @@ test + + com.facebook.presto + presto-memory + test + + org.openjdk.jmh jmh-core diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractBenchmark.java index f931743a55bd8..e1c6d43d82dcb 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractBenchmark.java @@ -13,19 +13,18 @@ */ package com.facebook.presto.benchmark; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.annotation.Nullable; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import jakarta.annotation.Nullable; import java.util.Map; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.benchmark.FormatUtils.formatCount; import static com.facebook.presto.benchmark.FormatUtils.formatCountRate; import static com.facebook.presto.benchmark.FormatUtils.formatDataRate; import static com.facebook.presto.benchmark.FormatUtils.formatDataSize; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java index e7462a95b2054..dd2fae3546b79 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java +++ 
b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/AbstractOperatorBenchmark.java @@ -15,6 +15,7 @@ import com.facebook.airlift.stats.CpuTimer; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.predicate.TupleDomain; @@ -57,7 +58,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; -import io.airlift.units.DataSize; import java.util.ArrayList; import java.util.Arrays; @@ -68,6 +68,9 @@ import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; import static com.facebook.airlift.stats.CpuTimer.CpuDuration; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SystemSessionProperties.getFilterAndProjectMinOutputPageRowCount; import static com.facebook.presto.SystemSessionProperties.getFilterAndProjectMinOutputPageSize; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -80,9 +83,6 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.NANOSECONDS; diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/FormatUtils.java 
b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/FormatUtils.java index 752bec0a15dc7..5f1035262b5b6 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/FormatUtils.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/FormatUtils.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.benchmark; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import java.math.RoundingMode; import java.text.DecimalFormat; -import static io.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static java.lang.String.format; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java index 1f9697e301041..22d3e69a02066 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery1.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.benchmark; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.benchmark.HandTpchQuery1.TpchQuery1Operator.TpchQuery1OperatorFactory; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; @@ -31,11 +32,11 @@ import com.facebook.presto.util.DateTimeUtils; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; import java.util.List; import java.util.Optional; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DateType.DATE; @@ -44,7 
+45,6 @@ import static com.facebook.presto.operator.aggregation.GenericAccumulatorFactory.generateAccumulatorFactory; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.google.common.base.Preconditions.checkState; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; public class HandTpchQuery1 diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery6.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery6.java index 22d7863794023..2982320011e6b 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery6.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HandTpchQuery6.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.benchmark; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.function.SqlFunctionProperties; @@ -32,12 +33,12 @@ import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.util.DateTimeUtils; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.util.List; import java.util.Optional; import java.util.function.Supplier; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DateType.DATE; @@ -46,7 +47,6 @@ import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.relational.Expressions.field; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.Unit.BYTE; public class HandTpchQuery6 extends AbstractSimpleOperatorBenchmark diff --git 
a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java index a1445e1205946..19f1d43a5d5c8 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/HashAggregationBenchmark.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.benchmark; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.operator.HashAggregationOperator.HashAggregationOperatorFactory; @@ -23,16 +24,15 @@ import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; import java.util.List; import java.util.Optional; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.operator.aggregation.GenericAccumulatorFactory.generateAccumulatorFactory; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.airlift.units.DataSize.Unit.MEGABYTE; public class HashAggregationBenchmark extends AbstractSimpleOperatorBenchmark diff --git a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/PredicateFilterBenchmark.java b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/PredicateFilterBenchmark.java index d24f92f3d4aba..a41a43a23df9d 100644 --- a/presto-benchmark/src/main/java/com/facebook/presto/benchmark/PredicateFilterBenchmark.java +++ b/presto-benchmark/src/main/java/com/facebook/presto/benchmark/PredicateFilterBenchmark.java @@ -13,6 +13,7 @@ */ package 
com.facebook.presto.benchmark; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.operator.FilterAndProjectOperator; import com.facebook.presto.operator.OperatorFactory; @@ -23,12 +24,12 @@ import com.facebook.presto.sql.gen.PageFunctionCompiler; import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.util.List; import java.util.Optional; import java.util.function.Supplier; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.benchmark.BenchmarkQueryRunner.createLocalQueryRunner; import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; @@ -37,7 +38,6 @@ import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.field; -import static io.airlift.units.DataSize.Unit.BYTE; public class PredicateFilterBenchmark extends AbstractSimpleOperatorBenchmark diff --git a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java index aa52d3f28d7cf..3bcc983759952 100644 --- a/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java +++ b/presto-benchmark/src/test/java/com/facebook/presto/benchmark/MemoryLocalQueryRunner.java @@ -14,6 +14,7 @@ package com.facebook.presto.benchmark; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.Page; import com.facebook.presto.common.QualifiedObjectName; @@ -36,7 +37,6 @@ import com.facebook.presto.tpch.TpchConnectorFactory; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.intellij.lang.annotations.Language; import java.util.List; @@ -44,8 +44,8 @@ import java.util.Optional; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; -import static io.airlift.units.DataSize.Unit.GIGABYTE; import static org.testng.Assert.assertTrue; public class MemoryLocalQueryRunner diff --git a/presto-benchto-benchmarks/pom.xml b/presto-benchto-benchmarks/pom.xml index 7010115890d13..8eddcee5a4983 100644 --- a/presto-benchto-benchmarks/pom.xml +++ b/presto-benchto-benchmarks/pom.xml @@ -4,7 +4,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-benchto-benchmarks @@ -14,6 +14,7 @@ ${project.parent.basedir} false false + true diff --git a/presto-benchto-benchmarks/src/main/java/com/facebook/presto/benchto/benchmarks/Dummy.java b/presto-benchto-benchmarks/src/main/java/com/facebook/presto/benchto/benchmarks/Dummy.java new file mode 100644 index 0000000000000..e5bc3b155b6b8 --- /dev/null +++ b/presto-benchto-benchmarks/src/main/java/com/facebook/presto/benchto/benchmarks/Dummy.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.benchto.benchmarks; + +/** + * This class exists to force the creation of a jar for the presto-benchto-benchmarks module. This is needed to deploy the presto-benchto-benchmarks module to nexus. + */ +public class Dummy +{ +} diff --git a/presto-bigquery/pom.xml b/presto-bigquery/pom.xml index ba846fdd82244..f7b712b6669d9 100644 --- a/presto-bigquery/pom.xml +++ b/presto-bigquery/pom.xml @@ -5,18 +5,20 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-bigquery + presto-bigquery Presto - Bigquery Connector presto-plugin ${project.parent.basedir} - 1.33.1 + 1.39.1 1.11.0 - 1.46.3 + 2.0.0 + true @@ -24,7 +26,7 @@ com.google.cloud libraries-bom - 16.3.0 + 26.68.0 pom import @@ -32,7 +34,7 @@ org.threeten threetenbp - 1.5.1 + 1.7.2 @@ -44,13 +46,7 @@ com.google.api-client google-api-client - 1.31.1 - - - - com.google.auth - google-auth-library-oauth2-http - ${google.auth.library.version} + 2.8.0 @@ -126,6 +122,12 @@ opencensus-contrib-http-util 0.31.1 + + + javax.annotation + javax.annotation-api + 1.3.2 + @@ -236,9 +238,11 @@ + com.google.protobuf protobuf-java + 3.25.8 @@ -252,13 +256,13 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -280,7 +284,7 @@ - io.airlift + com.facebook.airlift units provided @@ -332,7 +336,7 @@ org.objenesis objenesis - 2.6 + 3.4 test @@ -389,9 +393,11 @@ com.fasterxml.jackson.core:jackson-core - javax.annotation:javax.annotation-api + jakarta.annotation:jakarta.annotation-api com.fasterxml.jackson.core:jackson-databind com.google.api.grpc:proto-google-common-protos + + com.google.protobuf:protobuf-java @@ -408,6 +414,16 @@ + + + org.apache.maven.plugins + maven-dependency-plugin + + + com.facebook.airlift:log-manager + + + diff --git a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryClient.java 
b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryClient.java index 59215dc81fe79..221ece8ae3b74 100644 --- a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryClient.java +++ b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryClient.java @@ -50,6 +50,7 @@ public class BigQueryClient private final Optional viewMaterializationProject; private final Optional viewMaterializationDataset; private final String tablePrefix = "_pbc_"; + private final boolean caseSensitiveNameMatching; // presto converts the dataset and table names to lower case, while BigQuery is case sensitive private final ConcurrentMap tableIds = new ConcurrentHashMap<>(); @@ -60,6 +61,7 @@ public class BigQueryClient this.bigQuery = requireNonNull(bigQuery, "bigQuery is null"); this.viewMaterializationProject = requireNonNull(config.getViewMaterializationProject(), "viewMaterializationProject is null"); this.viewMaterializationDataset = requireNonNull(config.getViewMaterializationDataset(), "viewMaterializationDataset is null"); + this.caseSensitiveNameMatching = config.isCaseSensitiveNameMatching(); } public TableInfo getTable(TableId tableId) @@ -108,7 +110,7 @@ private void addTableMappingIfNeeded(DatasetId datasetID, Table table) private Dataset addDataSetMappingIfNeeded(Dataset dataset) { DatasetId bigQueryDatasetId = dataset.getDatasetId(); - DatasetId prestoDatasetId = DatasetId.of(bigQueryDatasetId.getProject(), bigQueryDatasetId.getDataset().toLowerCase(ENGLISH)); + DatasetId prestoDatasetId = DatasetId.of(bigQueryDatasetId.getProject(), bigQueryDatasetId.getDataset()); datasetIds.putIfAbsent(prestoDatasetId, bigQueryDatasetId); return dataset; } @@ -123,7 +125,8 @@ protected TableId createDestinationTable(TableId tableId) private String createTableName() { - return format(tablePrefix + "%s", randomUUID().toString().toLowerCase(ENGLISH).replace("-", "")); + String uuid = randomUUID().toString().replace("-", ""); + 
return caseSensitiveNameMatching ? format("%s%s", tablePrefix, uuid) : format("%s%s", tablePrefix, uuid).toLowerCase(ENGLISH); } private DatasetId mapIfNeeded(String project, String dataset) diff --git a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryConfig.java b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryConfig.java index 819bce2a97de1..1247f385e1514 100644 --- a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryConfig.java +++ b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryConfig.java @@ -16,9 +16,8 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; import com.google.auth.oauth2.GoogleCredentials; - -import javax.validation.constraints.AssertTrue; -import javax.validation.constraints.Min; +import jakarta.validation.constraints.AssertTrue; +import jakarta.validation.constraints.Min; import java.io.IOException; import java.util.Optional; @@ -36,6 +35,7 @@ public class BigQueryConfig private Optional parentProjectId = Optional.empty(); private OptionalInt parallelism = OptionalInt.empty(); private boolean viewsEnabled; + private boolean caseSensitiveNameMatching; private Optional viewMaterializationProject = Optional.empty(); private Optional viewMaterializationDataset = Optional.empty(); private int maxReadRowsRetries = DEFAULT_MAX_READ_ROWS_RETRIES; @@ -182,6 +182,22 @@ public BigQueryConfig setMaxReadRowsRetries(int maxReadRowsRetries) return this; } + public boolean isCaseSensitiveNameMatching() + { + return caseSensitiveNameMatching; + } + + @Config("case-sensitive-name-matching") + @ConfigDescription( + "Case sensitivity for schema and table name matching. 
" + + "true = preserve case and require exact matches; " + + "false (default) = normalize to lower case and match case-insensitively.") + public BigQueryConfig setCaseSensitiveNameMatching(boolean caseSensitiveNameMatching) + { + this.caseSensitiveNameMatching = caseSensitiveNameMatching; + return this; + } + ReadSessionCreatorConfig createReadSessionCreatorConfig() { return new ReadSessionCreatorConfig( diff --git a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryConnector.java b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryConnector.java index 76aa273c04626..9fe60eb1d87c2 100644 --- a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryConnector.java +++ b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryConnector.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.spi.transaction.IsolationLevel.READ_COMMITTED; import static com.facebook.presto.spi.transaction.IsolationLevel.checkConnectorSupports; diff --git a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryMetadata.java b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryMetadata.java index 8952eb3bbf87d..45dfb474efb81 100644 --- a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryMetadata.java +++ b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryMetadata.java @@ -39,8 +39,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -48,9 +47,11 @@ import java.util.Set; 
import static com.facebook.presto.plugin.bigquery.BigQueryErrorCode.BIGQUERY_TABLE_DISAPPEAR_DURING_LIST; +import static com.facebook.presto.plugin.bigquery.Conversions.toColumnMetadata; import static com.google.cloud.bigquery.TableDefinition.Type.TABLE; import static com.google.cloud.bigquery.TableDefinition.Type.VIEW; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toMap; @@ -63,12 +64,14 @@ public class BigQueryMetadata private static final Logger log = Logger.get(BigQueryMetadata.class); private final BigQueryClient bigQueryClient; private final String projectId; + private final boolean caseSensitiveNameMatching; @Inject public BigQueryMetadata(BigQueryClient bigQueryClient, BigQueryConfig config) { this.bigQueryClient = bigQueryClient; this.projectId = config.getProjectId().orElse(bigQueryClient.getProjectId()); + this.caseSensitiveNameMatching = config.isCaseSensitiveNameMatching(); } @Override @@ -124,7 +127,7 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable } @Override - public List getTableLayouts( + public ConnectorTableLayoutResult getTableLayoutForConstraint( ConnectorSession session, ConnectorTableHandle table, Constraint constraint, @@ -136,7 +139,7 @@ public List getTableLayouts( bigQueryTableHandle = bigQueryTableHandle.withProjectedColumns(ImmutableList.copyOf(desiredColumns.get())); } BigQueryTableLayoutHandle bigQueryTableLayoutHandle = new BigQueryTableLayoutHandle(bigQueryTableHandle); - return ImmutableList.of(new ConnectorTableLayoutResult(new ConnectorTableLayout(bigQueryTableLayoutHandle), constraint.getSummary())); + return new ConnectorTableLayoutResult(new ConnectorTableLayout(bigQueryTableLayoutHandle), constraint.getSummary()); } @Override @@ -179,7 +182,7 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect List 
columns = schema == null ? ImmutableList.of() : schema.getFields().stream() - .map(Conversions::toColumnMetadata) + .map(field -> toColumnMetadata(field, normalizeIdentifier(session, field.getName()))) .collect(toImmutableList()); return new ConnectorTableMetadata(schemaTableName, columns); } @@ -233,4 +236,10 @@ private List listTables(ConnectorSession session, SchemaTablePr ImmutableList.of(tableName) : ImmutableList.of(); // table does not exist } + + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return caseSensitiveNameMatching ? identifier : identifier.toLowerCase(ENGLISH); + } } diff --git a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryPageSourceProvider.java b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryPageSourceProvider.java index 65dd0886560b1..4de0ecbb32266 100644 --- a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryPageSourceProvider.java +++ b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryPageSourceProvider.java @@ -24,8 +24,7 @@ import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryStorageClientFactory.java b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryStorageClientFactory.java index 3e5b94870d5d1..d937ff1594da9 100644 --- a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryStorageClientFactory.java +++ b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/BigQueryStorageClientFactory.java @@ -18,8 +18,7 @@ import com.google.auth.Credentials; import com.google.cloud.bigquery.storage.v1beta1.BigQueryStorageClient; import 
com.google.cloud.bigquery.storage.v1beta1.BigQueryStorageSettings; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.IOException; import java.io.UncheckedIOException; diff --git a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/Conversions.java b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/Conversions.java index ed2e88e3d9cba..5fa2f3ae785af 100644 --- a/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/Conversions.java +++ b/presto-bigquery/src/main/java/com/facebook/presto/plugin/bigquery/Conversions.java @@ -47,10 +47,10 @@ static BigQueryColumnHandle toColumnHandle(Field field) field.getDescription()); } - static ColumnMetadata toColumnMetadata(Field field) + static ColumnMetadata toColumnMetadata(Field field, String name) { return ColumnMetadata.builder() - .setName(field.getName()) + .setName(name) .setType(adapt(field).getPrestoType()) .setNullable(getMode(field) == NULLABLE) .setComment(field.getDescription()) diff --git a/presto-bigquery/src/test/java/com/facebook/presto/plugin/bigquery/TestBigQueryConfig.java b/presto-bigquery/src/test/java/com/facebook/presto/plugin/bigquery/TestBigQueryConfig.java index 053f83600eb85..8829d86cd6e23 100644 --- a/presto-bigquery/src/test/java/com/facebook/presto/plugin/bigquery/TestBigQueryConfig.java +++ b/presto-bigquery/src/test/java/com/facebook/presto/plugin/bigquery/TestBigQueryConfig.java @@ -36,7 +36,8 @@ public void testDefaults() .setParallelism(20) .setViewMaterializationProject("vmproject") .setViewMaterializationDataset("vmdataset") - .setMaxReadRowsRetries(10); + .setMaxReadRowsRetries(10) + .setCaseSensitiveNameMatching(false); assertEquals(config.getCredentialsKey(), Optional.of("ckey")); assertEquals(config.getCredentialsFile(), Optional.of("cfile")); @@ -46,6 +47,7 @@ public void testDefaults() assertEquals(config.getViewMaterializationProject(), Optional.of("vmproject")); 
assertEquals(config.getViewMaterializationDataset(), Optional.of("vmdataset")); assertEquals(config.getMaxReadRowsRetries(), 10); + assertEquals(config.isCaseSensitiveNameMatching(), false); } @Test @@ -59,6 +61,7 @@ public void testExplicitPropertyMappingsWithCredentialsKey() .put("bigquery.view-materialization-project", "vmproject") .put("bigquery.view-materialization-dataset", "vmdataset") .put("bigquery.max-read-rows-retries", "10") + .put("case-sensitive-name-matching", "true") .build(); ConfigurationFactory configurationFactory = new ConfigurationFactory(properties); @@ -71,6 +74,7 @@ public void testExplicitPropertyMappingsWithCredentialsKey() assertEquals(config.getViewMaterializationProject(), Optional.of("vmproject")); assertEquals(config.getViewMaterializationDataset(), Optional.of("vmdataset")); assertEquals(config.getMaxReadRowsRetries(), 10); + assertEquals(config.isCaseSensitiveNameMatching(), true); } @Test diff --git a/presto-bigquery/src/test/java/com/facebook/presto/plugin/bigquery/TestTypeConversions.java b/presto-bigquery/src/test/java/com/facebook/presto/plugin/bigquery/TestTypeConversions.java index 4e6c6b3fe5111..e0e9698e5ac52 100644 --- a/presto-bigquery/src/test/java/com/facebook/presto/plugin/bigquery/TestTypeConversions.java +++ b/presto-bigquery/src/test/java/com/facebook/presto/plugin/bigquery/TestTypeConversions.java @@ -134,7 +134,7 @@ public void testConvertOneLevelRecordField() RECORD, Field.of("sub_s", STRING), Field.of("sub_i", INTEGER)); - ColumnMetadata metadata = Conversions.toColumnMetadata(field); + ColumnMetadata metadata = Conversions.toColumnMetadata(field, field.getName()); RowType targetType = RowType.from(ImmutableList.of( RowType.field("sub_s", VarcharType.VARCHAR), RowType.field("sub_i", BigintType.BIGINT))); @@ -152,7 +152,7 @@ public void testConvertTwoLevelsRecordField() Field.of("sub_sub_i", INTEGER)), Field.of("sub_s", STRING), Field.of("sub_i", INTEGER)); - ColumnMetadata metadata = 
Conversions.toColumnMetadata(field); + ColumnMetadata metadata = Conversions.toColumnMetadata(field, field.getName()); RowType targetType = RowType.from(ImmutableList.of( RowType.field("sub_rec", RowType.from(ImmutableList.of( RowType.field("sub_sub_s", VarcharType.VARCHAR), @@ -168,13 +168,13 @@ public void testConvertStringArrayField() Field field = Field.newBuilder("test", STRING) .setMode(REPEATED) .build(); - ColumnMetadata metadata = Conversions.toColumnMetadata(field); + ColumnMetadata metadata = Conversions.toColumnMetadata(field, field.getName()); assertThat(metadata.getType()).isEqualTo(new ArrayType(VarcharType.VARCHAR)); } void assertSimpleFieldTypeConversion(LegacySQLTypeName from, Type to) { - ColumnMetadata metadata = Conversions.toColumnMetadata(createField(from)); + ColumnMetadata metadata = Conversions.toColumnMetadata(createField(from), createField(from).getName()); assertThat(metadata.getType()).isEqualTo(to); } diff --git a/presto-blackhole/pom.xml b/presto-blackhole/pom.xml index d6d30f2278f11..fc2193f88fbb9 100644 --- a/presto-blackhole/pom.xml +++ b/presto-blackhole/pom.xml @@ -5,15 +5,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-blackhole + presto-blackhole Presto - Black Hole Connector presto-plugin ${project.parent.basedir} + true @@ -48,7 +50,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -60,7 +62,7 @@ - io.airlift + com.facebook.airlift units provided diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleColumnHandle.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleColumnHandle.java index f861b873a63ed..03b2dc317d847 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleColumnHandle.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleColumnHandle.java @@ -61,6 +61,14 @@ public ColumnMetadata toColumnMetadata() .build(); } + public 
ColumnMetadata toColumnMetadata(String name) + { + return ColumnMetadata.builder() + .setName(name) + .setType(columnType) + .build(); + } + @Override public int hashCode() { diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleConnector.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleConnector.java index eb9bd54b12548..f5c3b08be741d 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleConnector.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleConnector.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.common.type.TypeSignatureParameter; import com.facebook.presto.spi.connector.Connector; @@ -25,7 +26,6 @@ import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; import java.util.List; import java.util.concurrent.ExecutorService; diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleInsertTableHandle.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleInsertTableHandle.java index c09634aef0d66..6d241d1e216c8 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleInsertTableHandle.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleInsertTableHandle.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.Duration; import static 
java.util.Objects.requireNonNull; diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleMetadata.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleMetadata.java index 51f0db43fc997..afe0f1efd8dc4 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleMetadata.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleMetadata.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorInsertTableHandle; @@ -36,7 +37,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import io.airlift.slice.Slice; -import io.airlift.units.Duration; import java.util.ArrayList; import java.util.Collection; @@ -54,6 +54,7 @@ import static com.facebook.presto.plugin.blackhole.BlackHoleConnector.SPLIT_COUNT_PROPERTY; import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; import static com.facebook.presto.spi.StandardErrorCode.INVALID_TABLE_PROPERTY; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; @@ -97,7 +98,16 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { BlackHoleTableHandle blackHoleTableHandle = (BlackHoleTableHandle) tableHandle; - return blackHoleTableHandle.toTableMetadata(); + return toTableMetadata(blackHoleTableHandle, session); + } + + public ConnectorTableMetadata toTableMetadata(BlackHoleTableHandle blackHoleTableHandle, ConnectorSession session) + { + List columns = 
blackHoleTableHandle.getColumnHandles().stream() + .map(column -> column.toColumnMetadata(normalizeIdentifier(session, column.getName()))) + .collect(toImmutableList()); + + return new ConnectorTableMetadata(blackHoleTableHandle.toSchemaTableName(), columns); } @Override @@ -129,7 +139,7 @@ public Map> listTableColumns(ConnectorSess { return tables.values().stream() .filter(table -> prefix.matches(table.toSchemaTableName())) - .collect(toMap(BlackHoleTableHandle::toSchemaTableName, handle -> handle.toTableMetadata().getColumns())); + .collect(toMap(BlackHoleTableHandle::toSchemaTableName, handle -> toTableMetadata(handle, session).getColumns())); } @Override @@ -243,7 +253,7 @@ public Optional finishInsert(ConnectorSession session, } @Override - public List getTableLayouts( + public ConnectorTableLayoutResult getTableLayoutForConstraint( ConnectorSession session, ConnectorTableHandle handle, Constraint constraint, @@ -256,7 +266,7 @@ public List getTableLayouts( blackHoleHandle.getRowsPerPage(), blackHoleHandle.getFieldsLength(), blackHoleHandle.getPageProcessingDelay()); - return ImmutableList.of(new ConnectorTableLayoutResult(getTableLayout(session, layoutHandle), constraint.getSummary())); + return new ConnectorTableLayoutResult(getTableLayout(session, layoutHandle), constraint.getSummary()); } @Override diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleOutputTableHandle.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleOutputTableHandle.java index 202bda6cd5a63..041d6537766ac 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleOutputTableHandle.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleOutputTableHandle.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.ConnectorOutputTableHandle; import 
com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.Duration; import static java.util.Objects.requireNonNull; diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSink.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSink.java index 7052af20ff47a..662c8bca593e0 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSink.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSink.java @@ -13,12 +13,12 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.Page; import com.facebook.presto.spi.ConnectorPageSink; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListeningScheduledExecutorService; import io.airlift.slice.Slice; -import io.airlift.units.Duration; import java.util.Collection; import java.util.concurrent.CompletableFuture; diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSource.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSource.java index 836a15ae8d500..f0c35add5670e 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSource.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHolePageSource.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.Page; import com.facebook.presto.spi.ConnectorPageSource; import com.google.common.util.concurrent.ListeningScheduledExecutorService; -import io.airlift.units.Duration; import java.util.concurrent.CompletableFuture; diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleSplit.java 
b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleSplit.java index 92a126d8b10a8..869f0f69c76e5 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleSplit.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleSplit.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.NodeProvider; @@ -20,7 +21,6 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; import java.util.List; import java.util.Objects; diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleTableHandle.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleTableHandle.java index 8a83a209940f2..4e58106c30932 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleTableHandle.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleTableHandle.java @@ -13,12 +13,12 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.SchemaTableName; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.Duration; import java.util.List; import java.util.Objects; @@ -126,13 +126,6 @@ public Duration getPageProcessingDelay() return pageProcessingDelay; } - public ConnectorTableMetadata toTableMetadata() - { - return new ConnectorTableMetadata( - toSchemaTableName(), - 
columnHandles.stream().map(BlackHoleColumnHandle::toColumnMetadata).collect(toList())); - } - public SchemaTableName toSchemaTableName() { return new SchemaTableName(schemaName, tableName); diff --git a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleTableLayoutHandle.java b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleTableLayoutHandle.java index 4c6366db788bb..ae944d7b2d413 100644 --- a/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleTableLayoutHandle.java +++ b/presto-blackhole/src/main/java/com/facebook/presto/plugin/blackhole/BlackHoleTableLayoutHandle.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.Duration; import java.util.Objects; diff --git a/presto-blackhole/src/test/java/com/facebook/presto/plugin/blackhole/TestBlackHoleMetadata.java b/presto-blackhole/src/test/java/com/facebook/presto/plugin/blackhole/TestBlackHoleMetadata.java index 09529dc14c7a4..2c1439f61c12b 100644 --- a/presto-blackhole/src/test/java/com/facebook/presto/plugin/blackhole/TestBlackHoleMetadata.java +++ b/presto-blackhole/src/test/java/com/facebook/presto/plugin/blackhole/TestBlackHoleMetadata.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.plugin.blackhole; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.List; diff --git 
a/presto-built-in-worker-function-tools/pom.xml b/presto-built-in-worker-function-tools/pom.xml new file mode 100644 index 0000000000000..9ad5ca7c89b30 --- /dev/null +++ b/presto-built-in-worker-function-tools/pom.xml @@ -0,0 +1,62 @@ + + + + presto-root + com.facebook.presto + 0.297-edge10.1-SNAPSHOT + + 4.0.0 + + presto-built-in-worker-function-tools + presto-built-in-worker-function-tools + + + ${project.parent.basedir} + 17 + true + + + + + com.facebook.presto + presto-spi + + + com.facebook.presto + presto-function-namespace-managers-common + + + com.facebook.airlift + http-client + + + com.google.inject + guice + + + com.google.guava + guava + + + com.facebook.airlift + json + + + com.facebook.presto + presto-common + + + com.facebook.airlift + log + + + com.facebook.airlift + configuration + + + org.testng + testng + test + + + diff --git a/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/ForNativeFunctionRegistryInfo.java b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/ForNativeFunctionRegistryInfo.java new file mode 100644 index 0000000000000..000a6b2e6d0e1 --- /dev/null +++ b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/ForNativeFunctionRegistryInfo.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.builtin.tools; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Retention; + +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@BindingAnnotation +public @interface ForNativeFunctionRegistryInfo +{ +} diff --git a/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/NativeSidecarFunctionRegistryTool.java b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/NativeSidecarFunctionRegistryTool.java new file mode 100644 index 0000000000000..1990918df633f --- /dev/null +++ b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/NativeSidecarFunctionRegistryTool.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.builtin.tools; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.HttpUriBuilder; +import com.facebook.airlift.http.client.JsonResponseHandler; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.log.Logger; +import com.facebook.presto.functionNamespace.JsonBasedUdfFunctionMetadata; +import com.facebook.presto.functionNamespace.UdfFunctionSignatureMap; +import com.facebook.presto.spi.Node; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.StandardErrorCode; +import com.facebook.presto.spi.function.SqlFunction; +import com.google.common.collect.ImmutableMap; + +import java.net.URI; +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class NativeSidecarFunctionRegistryTool + implements WorkerFunctionRegistryTool +{ + private final int maxRetries; + private final long retryDelayMs; + private static final Logger log = Logger.get(NativeSidecarFunctionRegistryTool.class); + private final JsonCodec>> nativeFunctionSignatureMapJsonCodec; + private final NodeManager nodeManager; + private final HttpClient httpClient; + private static final String FUNCTION_SIGNATURES_ENDPOINT = "/v1/functions"; + + public NativeSidecarFunctionRegistryTool( + HttpClient httpClient, + JsonCodec>> nativeFunctionSignatureMapJsonCodec, + NodeManager nodeManager, + int nativeSidecarRegistryToolNumRetries, + long nativeSidecarRegistryToolRetryDelayMs) + { + this.nativeFunctionSignatureMapJsonCodec = + requireNonNull(nativeFunctionSignatureMapJsonCodec, "nativeFunctionSignatureMapJsonCodec is null"); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.httpClient = requireNonNull(httpClient, "typeManager is null"); + 
this.maxRetries = nativeSidecarRegistryToolNumRetries; + this.retryDelayMs = nativeSidecarRegistryToolRetryDelayMs; + } + + @Override + public List getWorkerFunctions() + { + return getNativeFunctionSignatureMap() + .getUDFSignatureMap() + .entrySet() + .stream() + .flatMap(entry -> entry.getValue().stream() + .map(metaInfo -> WorkerFunctionUtil.createSqlInvokedFunction(entry.getKey(), metaInfo, "presto"))) + .collect(toImmutableList()); + } + + private UdfFunctionSignatureMap getNativeFunctionSignatureMap() + { + try { + Request request = Request.Builder.prepareGet().setUri(getSidecarLocationOnStartup(nodeManager, maxRetries, retryDelayMs)).build(); + Map> nativeFunctionSignatureMap = httpClient.execute(request, JsonResponseHandler.createJsonResponseHandler(nativeFunctionSignatureMapJsonCodec)); + return new UdfFunctionSignatureMap(ImmutableMap.copyOf(nativeFunctionSignatureMap)); + } + catch (Exception e) { + throw new PrestoException(StandardErrorCode.INVALID_ARGUMENTS, "Failed to get functions from sidecar.", e); + } + } + + public static URI getSidecarLocationOnStartup(NodeManager nodeManager, int maxRetries, long retryDelayMs) + { + Node sidecarNode = null; + for (int attempt = 1; attempt <= maxRetries; attempt++) { + try { + sidecarNode = nodeManager.getSidecarNode(); + if (sidecarNode != null) { + break; + } + } + catch (Exception e) { + log.error("Error getting sidecar node (attempt " + attempt + "): " + e.getMessage()); + if (attempt == maxRetries) { + throw new RuntimeException("Failed to get sidecar node", e); + } + else { + try { + Thread.sleep(retryDelayMs); + } + catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Retry fetching sidecar function registry interrupted", ie); + } + } + } + } + + return HttpUriBuilder + .uriBuilderFrom(sidecarNode.getHttpUri()) + .appendPath(FUNCTION_SIGNATURES_ENDPOINT) + .build(); + } +} diff --git 
a/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/NativeSidecarRegistryToolConfig.java b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/NativeSidecarRegistryToolConfig.java new file mode 100644 index 0000000000000..4ce944572406d --- /dev/null +++ b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/NativeSidecarRegistryToolConfig.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.builtin.tools; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigDescription; + +import java.time.Duration; + +public class NativeSidecarRegistryToolConfig +{ + private int nativeSidecarRegistryToolNumRetries = 8; + private long nativeSidecarRegistryToolRetryDelayMs = Duration.ofMinutes(1).toMillis(); + + public int getNativeSidecarRegistryToolNumRetries() + { + return nativeSidecarRegistryToolNumRetries; + } + + @Config("native-sidecar-registry-tool.num-retries") + @ConfigDescription("Max times to retry fetching sidecar node") + public NativeSidecarRegistryToolConfig setNativeSidecarRegistryToolNumRetries(int nativeSidecarRegistryToolNumRetries) + { + this.nativeSidecarRegistryToolNumRetries = nativeSidecarRegistryToolNumRetries; + return this; + } + + public long getNativeSidecarRegistryToolRetryDelayMs() + { + return nativeSidecarRegistryToolRetryDelayMs; + } + + @Config("native-sidecar-registry-tool.retry-delay-ms") + @ConfigDescription("Cooldown period to retry when fetching sidecar node fails") + public NativeSidecarRegistryToolConfig setNativeSidecarRegistryToolRetryDelayMs(long nativeSidecarRegistryToolRetryDelayMs) + { + this.nativeSidecarRegistryToolRetryDelayMs = nativeSidecarRegistryToolRetryDelayMs; + return this; + } +} diff --git a/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/WorkerFunctionRegistryTool.java b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/WorkerFunctionRegistryTool.java new file mode 100644 index 0000000000000..99c7cfe13bc78 --- /dev/null +++ b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/WorkerFunctionRegistryTool.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.builtin.tools; + +import com.facebook.presto.spi.function.SqlFunction; + +import java.util.List; + +public interface WorkerFunctionRegistryTool +{ + List getWorkerFunctions(); +} diff --git a/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/WorkerFunctionUtil.java b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/WorkerFunctionUtil.java new file mode 100644 index 0000000000000..3ecd353d21327 --- /dev/null +++ b/presto-built-in-worker-function-tools/src/main/java/com/facebook/presto/builtin/tools/WorkerFunctionUtil.java @@ -0,0 +1,178 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.builtin.tools; + +import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.NamedTypeSignature; +import com.facebook.presto.common.type.RowFieldName; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.common.type.TypeSignatureParameter; +import com.facebook.presto.functionNamespace.JsonBasedUdfFunctionMetadata; +import com.facebook.presto.spi.function.AggregationFunctionMetadata; +import com.facebook.presto.spi.function.LongVariableConstraint; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.function.TypeVariableConstraint; +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; + +public class WorkerFunctionUtil +{ + private WorkerFunctionUtil() {} + + public static synchronized SqlInvokedFunction createSqlInvokedFunction(String functionName, JsonBasedUdfFunctionMetadata jsonBasedUdfFunctionMetaData, String catalogName) + { + checkState(jsonBasedUdfFunctionMetaData.getRoutineCharacteristics().getLanguage().equals(RoutineCharacteristics.Language.CPP), "WorkerFunctionUtil only supports CPP UDF"); + QualifiedObjectName qualifiedFunctionName = QualifiedObjectName.valueOf(new CatalogSchemaName(catalogName, jsonBasedUdfFunctionMetaData.getSchema()), functionName); + List parameterNameList = jsonBasedUdfFunctionMetaData.getParamNames(); + List parameterTypeList = 
convertApplicableTypeToVariable(jsonBasedUdfFunctionMetaData.getParamTypes()); + List typeVariableConstraintsList = jsonBasedUdfFunctionMetaData.getTypeVariableConstraints().isPresent() ? + jsonBasedUdfFunctionMetaData.getTypeVariableConstraints().get() : ImmutableList.of(); + List longVariableConstraintList = jsonBasedUdfFunctionMetaData.getLongVariableConstraints().isPresent() ? + jsonBasedUdfFunctionMetaData.getLongVariableConstraints().get() : ImmutableList.of(); + + TypeSignature outputType = convertApplicableTypeToVariable(jsonBasedUdfFunctionMetaData.getOutputType()); + ImmutableList.Builder parameterBuilder = ImmutableList.builder(); + for (int i = 0; i < parameterNameList.size(); i++) { + parameterBuilder.add(new Parameter(parameterNameList.get(i), parameterTypeList.get(i))); + } + + Optional aggregationFunctionMetadata = + jsonBasedUdfFunctionMetaData.getAggregateMetadata() + .map(metadata -> new AggregationFunctionMetadata( + convertApplicableTypeToVariable(metadata.getIntermediateType()), + metadata.isOrderSensitive())); + + return new SqlInvokedFunction( + qualifiedFunctionName, + parameterBuilder.build(), + typeVariableConstraintsList, + longVariableConstraintList, + outputType, + jsonBasedUdfFunctionMetaData.getDocString(), + jsonBasedUdfFunctionMetaData.getRoutineCharacteristics(), + "", + jsonBasedUdfFunctionMetaData.getVariableArity(), + notVersioned(), + jsonBasedUdfFunctionMetaData.getFunctionKind(), + aggregationFunctionMetadata); + } + + // Todo: Improve the handling of parameter type differentiation in native execution. + // HACK: Currently, we lack support for correctly identifying the parameterKind, specifically between TYPE and VARIABLE, + // in native execution. The following utility functions help bridge this gap by parsing the type signature and verifying whether its base + // and parameters are of a supported type. The valid types list are non - parametric types that Presto supports. 
+ public static List convertApplicableTypeToVariable(List typeSignatures) + { + List newTypeSignaturesList = new ArrayList<>(); + for (TypeSignature typeSignature : typeSignatures) { + if (!typeSignature.getParameters().isEmpty()) { + TypeSignature newTypeSignature = + new TypeSignature( + typeSignature.getBase(), + getTypeSignatureParameters( + typeSignature, + typeSignature.getParameters())); + newTypeSignaturesList.add(newTypeSignature); + } + else { + newTypeSignaturesList.add(typeSignature); + } + } + return newTypeSignaturesList; + } + + public static TypeSignature convertApplicableTypeToVariable(TypeSignature typeSignature) + { + List typeSignaturesList = convertApplicableTypeToVariable(ImmutableList.of(typeSignature)); + checkArgument(!typeSignaturesList.isEmpty(), "Type signature list is empty for : " + typeSignature); + return typeSignaturesList.get(0); + } + + private static List getTypeSignatureParameters( + TypeSignature typeSignature, + List typeSignatureParameterList) + { + List newParameterTypeList = new ArrayList<>(); + for (TypeSignatureParameter parameter : typeSignatureParameterList) { + if (parameter.isLongLiteral()) { + newParameterTypeList.add(parameter); + continue; + } + + boolean isNamedTypeSignature = parameter.isNamedTypeSignature(); + TypeSignature parameterTypeSignature; + // If it's a named type signatures only in the case of row signature types. 
+ if (isNamedTypeSignature) { + parameterTypeSignature = parameter.getNamedTypeSignature().getTypeSignature(); + } + else { + parameterTypeSignature = parameter.getTypeSignature(); + } + + if (parameterTypeSignature.getParameters().isEmpty()) { + boolean changeTypeToVariable = isDecimalTypeBase(typeSignature.getBase()); + if (changeTypeToVariable) { + newParameterTypeList.add( + TypeSignatureParameter.of(parameterTypeSignature.getBase())); + } + else { + if (isNamedTypeSignature) { + newParameterTypeList.add(TypeSignatureParameter.of(parameter.getNamedTypeSignature())); + } + else { + newParameterTypeList.add(TypeSignatureParameter.of(parameterTypeSignature)); + } + } + } + else { + TypeSignature newTypeSignature = + new TypeSignature( + parameterTypeSignature.getBase(), + getTypeSignatureParameters( + parameterTypeSignature.getStandardTypeSignature(), + parameterTypeSignature.getParameters())); + if (isNamedTypeSignature) { + // Preserve the original field name if present, otherwise use Optional.empty() + Optional fieldName = parameter.getNamedTypeSignature().getFieldName(); + newParameterTypeList.add( + TypeSignatureParameter.of( + new NamedTypeSignature( + fieldName, + newTypeSignature))); + } + else { + newParameterTypeList.add(TypeSignatureParameter.of(newTypeSignature)); + } + } + } + return newParameterTypeList; + } + + private static boolean isDecimalTypeBase(String typeBase) + { + return typeBase.equals(StandardTypes.DECIMAL); + } +} diff --git a/presto-built-in-worker-function-tools/src/test/java/com/facebook/presto/builtin/tools/TestNativeSidecarRegistryToolConfig.java b/presto-built-in-worker-function-tools/src/test/java/com/facebook/presto/builtin/tools/TestNativeSidecarRegistryToolConfig.java new file mode 100644 index 0000000000000..34df3f89b88f3 --- /dev/null +++ b/presto-built-in-worker-function-tools/src/test/java/com/facebook/presto/builtin/tools/TestNativeSidecarRegistryToolConfig.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.builtin.tools; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestNativeSidecarRegistryToolConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(NativeSidecarRegistryToolConfig.class) + .setNativeSidecarRegistryToolNumRetries(8) + .setNativeSidecarRegistryToolRetryDelayMs(60_000L)); + } + + @Test + public void testExplicitPropertyMappings() + throws Exception + { + Map properties = new ImmutableMap.Builder() + .put("native-sidecar-registry-tool.num-retries", "15") + .put("native-sidecar-registry-tool.retry-delay-ms", "11115") + .build(); + + NativeSidecarRegistryToolConfig expected = new NativeSidecarRegistryToolConfig() + .setNativeSidecarRegistryToolNumRetries(15) + .setNativeSidecarRegistryToolRetryDelayMs(11_115L); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-bytecode/pom.xml b/presto-bytecode/pom.xml index 9601020ae77c8..6441ba8dc14a1 100644 --- a/presto-bytecode/pom.xml +++ b/presto-bytecode/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT 
presto-bytecode @@ -13,6 +13,8 @@ ${project.parent.basedir} + 8 + true @@ -39,9 +41,13 @@ - com.google.code.findbugs - jsr305 - true + com.google.errorprone + error_prone_annotations + + + + jakarta.annotation + jakarta.annotation-api diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ArrayOpCode.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ArrayOpCode.java index f0191fa266541..4f0b4e379f946 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ArrayOpCode.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ArrayOpCode.java @@ -14,8 +14,7 @@ package com.facebook.presto.bytecode; import com.google.common.collect.ImmutableMap; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Map; diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/BytecodeBlock.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/BytecodeBlock.java index 0542796879bba..af8287dfc0375 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/BytecodeBlock.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/BytecodeBlock.java @@ -23,8 +23,6 @@ import com.google.common.collect.ImmutableList; import org.objectweb.asm.MethodVisitor; -import javax.annotation.concurrent.NotThreadSafe; - import java.lang.invoke.MethodType; import java.lang.reflect.Constructor; import java.lang.reflect.Field; @@ -53,7 +51,6 @@ import static com.google.common.base.Preconditions.checkArgument; @SuppressWarnings("UnusedDeclaration") -@NotThreadSafe public class BytecodeBlock implements BytecodeNode { diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassDefinition.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassDefinition.java index 1e9c3f32c2b3a..d996ecdc335d4 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassDefinition.java +++ 
b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassDefinition.java @@ -18,24 +18,22 @@ import com.google.common.collect.ImmutableSet; import org.objectweb.asm.ClassVisitor; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.ArrayList; import java.util.EnumSet; import java.util.List; import java.util.Set; +import java.util.stream.Stream; import static com.facebook.presto.bytecode.Access.INTERFACE; import static com.facebook.presto.bytecode.Access.STATIC; import static com.facebook.presto.bytecode.Access.a; import static com.facebook.presto.bytecode.Access.toAccessModifier; -import static com.google.common.collect.Iterables.any; -import static com.google.common.collect.Iterables.concat; +import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; +import static java.util.stream.Stream.concat; import static org.objectweb.asm.Opcodes.ACC_SUPER; import static org.objectweb.asm.Opcodes.V1_8; -@NotThreadSafe public class ClassDefinition { private final EnumSet access; @@ -131,7 +129,7 @@ public void visit(ClassVisitor visitor) { // Generic signature if super class or any interface is generic String signature = null; - if (superClass.isGeneric() || any(interfaces, ParameterizedType::isGeneric)) { + if (superClass.isGeneric() || interfaces.stream().anyMatch(ParameterizedType::isGeneric)) { signature = genericClassSignature(superClass, interfaces); } @@ -293,14 +291,14 @@ public static String genericClassSignature( ParameterizedType... 
interfaceTypes) { return Joiner.on("").join( - concat(ImmutableList.of(classType), ImmutableList.copyOf(interfaceTypes))); + concat(Stream.of(classType), stream(interfaceTypes)).iterator()); } public static String genericClassSignature( ParameterizedType classType, List interfaceTypes) { - return Joiner.on("").join(concat(ImmutableList.of(classType), interfaceTypes)); + return Joiner.on("").join(concat(Stream.of(classType), interfaceTypes.stream()).iterator()); } @Override diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassGenerator.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassGenerator.java index 3edcd220aa79a..344dd593ba82c 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassGenerator.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassGenerator.java @@ -39,7 +39,7 @@ import static com.facebook.presto.bytecode.ClassInfoLoader.createClassInfoLoader; import static com.facebook.presto.bytecode.ParameterizedType.typeFromJavaClassName; -import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.MoreCollectors.onlyElement; import static com.google.common.io.CharStreams.nullWriter; import static java.nio.file.Files.createDirectories; import static java.util.Objects.requireNonNull; @@ -117,7 +117,7 @@ public ClassGenerator dumpClassFilesTo(Optional dumpClassPath) public Class defineClass(ClassDefinition classDefinition, Class superType) { Map> classes = defineClasses(ImmutableList.of(classDefinition)); - return getOnlyElement(classes.values()).asSubclass(superType); + return classes.values().stream().collect(onlyElement()).asSubclass(superType); } public Map> defineClasses(List classDefinitions) diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassInfo.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassInfo.java index 57f376b364c12..fd4ed87392042 100644 --- 
a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassInfo.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ClassInfo.java @@ -39,8 +39,8 @@ import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.facebook.presto.bytecode.ParameterizedType.typeFromPathName; -import static com.google.common.collect.Iterables.transform; -import static java.util.Arrays.asList; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; /** @@ -61,8 +61,8 @@ public ClassInfo(ClassInfoLoader loader, ClassNode classNode) typeFromPathName(classNode.name), classNode.access, classNode.superName == null ? null : typeFromPathName(classNode.superName), - transform((List) classNode.interfaces, ParameterizedType::typeFromPathName), - (List) classNode.methods); + classNode.interfaces.stream().map(ParameterizedType::typeFromPathName).collect(toImmutableList()), + classNode.methods); } public ClassInfo(ClassInfoLoader loader, Class aClass) @@ -71,7 +71,7 @@ public ClassInfo(ClassInfoLoader loader, Class aClass) type(aClass), aClass.getModifiers(), aClass.getSuperclass() == null ? 
null : type(aClass.getSuperclass()), - transform(asList(aClass.getInterfaces()), ParameterizedType::type), + stream(aClass.getInterfaces()).map(ParameterizedType::type).collect(toImmutableList()), null); } diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/FieldDefinition.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/FieldDefinition.java index e83c696dffbe3..cd9aa7ff9f200 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/FieldDefinition.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/FieldDefinition.java @@ -16,11 +16,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; +import com.google.errorprone.annotations.Immutable; import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.FieldVisitor; -import javax.annotation.concurrent.Immutable; - import java.util.ArrayList; import java.util.EnumSet; import java.util.List; diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java index 7da5638f7d8cf..dd7100f857c75 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/MethodDefinition.java @@ -16,13 +16,10 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.tree.InsnNode; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.ArrayList; import java.util.EnumSet; import java.util.List; @@ -32,11 +29,11 @@ import static com.facebook.presto.bytecode.Access.toAccessModifier; import static com.facebook.presto.bytecode.ParameterizedType.type; 
import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.transform; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Streams.stream; import static org.objectweb.asm.Opcodes.RETURN; @SuppressWarnings("UnusedDeclaration") -@NotThreadSafe public class MethodDefinition { private final Scope scope; @@ -84,8 +81,8 @@ public MethodDefinition( this.returnType = type(void.class); } this.parameters = ImmutableList.copyOf(parameters); - this.parameterTypes = Lists.transform(this.parameters, Parameter::getType); - this.parameterAnnotations = ImmutableList.copyOf(transform(parameters, input -> new ArrayList<>())); + this.parameterTypes = this.parameters.stream().map(Parameter::getType).collect(toImmutableList()); + this.parameterAnnotations = stream(parameters).>map(input -> new ArrayList<>()).collect(toImmutableList()); Optional thisType = Optional.empty(); if (!declaringClass.isInterface() && !access.contains(STATIC)) { thisType = Optional.of(declaringClass.getType()); @@ -264,7 +261,7 @@ public String toSourceString() Joiner.on(' ').appendTo(sb, access).append(' '); sb.append(returnType.getJavaClassName()).append(' '); sb.append(name).append('('); - Joiner.on(", ").appendTo(sb, transform(parameters, Parameter::getSourceString)).append(')'); + Joiner.on(", ").appendTo(sb, parameters.stream().map(Parameter::getSourceString).collect(toImmutableList())).append(')'); return sb.toString(); } @@ -283,7 +280,7 @@ public static String methodDescription(Class returnType, List> param { return methodDescription( type(returnType), - Lists.transform(parameterTypes, ParameterizedType::type)); + parameterTypes.stream().map(ParameterizedType::type).collect(toImmutableList())); } public static String methodDescription( @@ -299,7 +296,7 @@ public static String methodDescription( { StringBuilder sb = new StringBuilder(); sb.append("("); - Joiner.on("").appendTo(sb, 
transform(parameterTypes, ParameterizedType::getType)); + Joiner.on("").appendTo(sb, parameterTypes.stream().map(ParameterizedType::getType).iterator()); sb.append(")"); sb.append(returnType.getType()); return sb.toString(); diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/Parameter.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/Parameter.java index 9ec15b0f9e62d..070d6ffc219cc 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/Parameter.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/Parameter.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.bytecode; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; @Immutable public class Parameter diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ParameterizedType.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ParameterizedType.java index 8ebc6fa89e818..830adb653dbe3 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ParameterizedType.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/ParameterizedType.java @@ -14,11 +14,10 @@ package com.facebook.presto.bytecode; import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import org.objectweb.asm.Type; -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; - import java.util.List; import static com.google.common.base.Preconditions.checkArgument; diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/control/CaseStatement.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/control/CaseStatement.java index d6ab57183ce83..e14b06ef94178 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/control/CaseStatement.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/control/CaseStatement.java @@ 
-15,8 +15,7 @@ import com.facebook.presto.bytecode.BytecodeNode; import com.facebook.presto.bytecode.instruction.LabelNode; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/BytecodeExpression.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/BytecodeExpression.java index 12bcf15b403ac..de0e173524d94 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/BytecodeExpression.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/BytecodeExpression.java @@ -20,6 +20,7 @@ import com.facebook.presto.bytecode.MethodGenerationContext; import com.facebook.presto.bytecode.ParameterizedType; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; import org.objectweb.asm.MethodVisitor; import java.lang.reflect.Field; @@ -30,7 +31,6 @@ import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.transform; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; @@ -162,13 +162,17 @@ public final BytecodeExpression invoke(String methodName, ParameterizedType retu return invoke(methodName, returnType, - ImmutableList.copyOf(transform(parameters, BytecodeExpression::getType)), + Streams.stream(parameters).map(BytecodeExpression::getType).collect(toImmutableList()), parameters); } public final BytecodeExpression invoke(String methodName, Class returnType, Iterable> parameterTypes, BytecodeExpression... 
parameters) { - return invoke(methodName, type(returnType), transform(parameterTypes, ParameterizedType::type), ImmutableList.copyOf(requireNonNull(parameters, "parameters is null"))); + return invoke( + methodName, + type(returnType), + Streams.stream(parameterTypes).map(ParameterizedType::type).collect(toImmutableList()), + ImmutableList.copyOf(requireNonNull(parameters, "parameters is null"))); } public final BytecodeExpression invoke(String methodName, ParameterizedType returnType, Iterable parameterTypes, BytecodeExpression... parameters) diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/BytecodeExpressions.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/BytecodeExpressions.java index c1f0faddd9b2c..5e760b7a25166 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/BytecodeExpressions.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/BytecodeExpressions.java @@ -18,6 +18,7 @@ import com.facebook.presto.bytecode.OpCode; import com.facebook.presto.bytecode.ParameterizedType; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; import java.lang.invoke.MethodType; import java.lang.reflect.Constructor; @@ -36,7 +37,6 @@ import static com.facebook.presto.bytecode.instruction.Constant.loadString; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.transform; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; @@ -259,13 +259,16 @@ public static BytecodeExpression newInstance(ParameterizedType returnType, Itera return newInstance( returnType, - ImmutableList.copyOf(transform(parameters, BytecodeExpression::getType)), + Streams.stream(parameters).map(BytecodeExpression::getType).collect(toImmutableList()), parameters); } public static 
BytecodeExpression newInstance(Class returnType, Iterable> parameterTypes, BytecodeExpression... parameters) { - return newInstance(type(returnType), transform(parameterTypes, ParameterizedType::type), ImmutableList.copyOf(requireNonNull(parameters, "parameters is null"))); + return newInstance( + type(returnType), + Streams.stream(parameterTypes).map(ParameterizedType::type).collect(toImmutableList()), + ImmutableList.copyOf(requireNonNull(parameters, "parameters is null"))); } public static BytecodeExpression newInstance(ParameterizedType returnType, Iterable parameterTypes, BytecodeExpression... parameters) @@ -378,7 +381,7 @@ public static BytecodeExpression invokeStatic( methodTargetType, methodName, returnType, - ImmutableList.copyOf(transform(parameters, BytecodeExpression::getType)), + Streams.stream(parameters).map(BytecodeExpression::getType).collect(toImmutableList()), parameters); } @@ -398,7 +401,7 @@ public static BytecodeExpression invokeStatic( type(methodTargetType), methodName, type(returnType), - transform(parameterTypes, ParameterizedType::type), + Streams.stream(parameterTypes).map(ParameterizedType::type).collect(toImmutableList()), ImmutableList.copyOf(parameters)); } @@ -457,7 +460,7 @@ public static BytecodeExpression invokeDynamic( bootstrapArgs, methodName, type(returnType), - ImmutableList.copyOf(transform(parameters, BytecodeExpression::getType)), + Streams.stream(parameters).map(BytecodeExpression::getType).collect(toImmutableList()), parameters); } @@ -486,7 +489,7 @@ public static BytecodeExpression invokeDynamic( bootstrapArgs, methodName, returnType, - ImmutableList.copyOf(transform(parameters, BytecodeExpression::getType)), + Streams.stream(parameters).map(BytecodeExpression::getType).collect(toImmutableList()), parameters); } @@ -515,7 +518,7 @@ public static BytecodeExpression invokeDynamic( bootstrapArgs, methodName, type(methodType.returnType()), - transform(methodType.parameterList(), ParameterizedType::type), + 
methodType.parameterList().stream().map(ParameterizedType::type).collect(toImmutableList()), ImmutableList.copyOf(requireNonNull(parameters, "parameters is null"))); } diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/GetFieldBytecodeExpression.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/GetFieldBytecodeExpression.java index d1eec06e19b19..57240b68bd15b 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/GetFieldBytecodeExpression.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/GetFieldBytecodeExpression.java @@ -19,8 +19,7 @@ import com.facebook.presto.bytecode.MethodGenerationContext; import com.facebook.presto.bytecode.ParameterizedType; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.lang.reflect.Field; import java.lang.reflect.Modifier; diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/InvokeBytecodeExpression.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/InvokeBytecodeExpression.java index 829d7caa61c1f..58140113bd3f0 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/InvokeBytecodeExpression.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/InvokeBytecodeExpression.java @@ -19,8 +19,7 @@ import com.facebook.presto.bytecode.ParameterizedType; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/InvokeDynamicBytecodeExpression.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/InvokeDynamicBytecodeExpression.java index 93937b44f1c7d..570f2c2eacc4f 100644 --- 
a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/InvokeDynamicBytecodeExpression.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/InvokeDynamicBytecodeExpression.java @@ -23,7 +23,6 @@ import java.lang.reflect.Method; import java.util.List; -import static com.google.common.collect.Iterables.transform; import static java.util.Objects.requireNonNull; class InvokeDynamicBytecodeExpression @@ -75,7 +74,7 @@ protected String formatOneLine() StringBuilder builder = new StringBuilder(); builder.append("[").append(bootstrapMethod.getName()); if (!bootstrapArgs.isEmpty()) { - builder.append("(").append(Joiner.on(", ").join(transform(bootstrapArgs, ConstantBytecodeExpression::renderConstant))).append(")"); + builder.append("(").append(Joiner.on(", ").join(bootstrapArgs.stream().map(ConstantBytecodeExpression::renderConstant).iterator())).append(")"); } builder.append("]=>"); diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/NewArrayBytecodeExpression.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/NewArrayBytecodeExpression.java index e6cb9030e472a..986576ee5da3d 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/NewArrayBytecodeExpression.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/NewArrayBytecodeExpression.java @@ -20,8 +20,7 @@ import com.facebook.presto.bytecode.instruction.TypeInstruction; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/SetFieldBytecodeExpression.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/SetFieldBytecodeExpression.java index 8f62930020f73..7522deef9d02a 100644 --- 
a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/SetFieldBytecodeExpression.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/expression/SetFieldBytecodeExpression.java @@ -19,8 +19,7 @@ import com.facebook.presto.bytecode.MethodGenerationContext; import com.facebook.presto.bytecode.ParameterizedType; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.lang.reflect.Field; import java.lang.reflect.Modifier; diff --git a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/instruction/InvokeInstruction.java b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/instruction/InvokeInstruction.java index b393e567fc629..e5944b4d70ef8 100644 --- a/presto-bytecode/src/main/java/com/facebook/presto/bytecode/instruction/InvokeInstruction.java +++ b/presto-bytecode/src/main/java/com/facebook/presto/bytecode/instruction/InvokeInstruction.java @@ -21,6 +21,7 @@ import com.facebook.presto.bytecode.ParameterizedType; import com.google.common.base.CharMatcher; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Streams; import org.objectweb.asm.Handle; import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; @@ -38,7 +39,8 @@ import static com.facebook.presto.bytecode.OpCode.INVOKEVIRTUAL; import static com.facebook.presto.bytecode.ParameterizedType.type; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.transform; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; @SuppressWarnings("UnusedDeclaration") @@ -158,12 +160,12 @@ public static InstructionNode invokeConstructor(Constructor constructor) public static InstructionNode invokeConstructor(Class target, Class... 
parameterTypes) { - return invokeConstructor(type(target), transform(ImmutableList.copyOf(parameterTypes), ParameterizedType::type)); + return invokeConstructor(type(target), stream(parameterTypes).map(ParameterizedType::type).collect(toImmutableList())); } public static InstructionNode invokeConstructor(Class target, Iterable> parameterTypes) { - return invokeConstructor(type(target), transform(parameterTypes, ParameterizedType::type)); + return invokeConstructor(type(target), Streams.stream(parameterTypes).map(ParameterizedType::type).collect(toImmutableList())); } public static InstructionNode invokeConstructor(ParameterizedType target, ParameterizedType... parameterTypes) @@ -220,7 +222,7 @@ private static InstructionNode invoke(OpCode invocationType, Method method) type(method.getDeclaringClass()), method.getName(), type(method.getReturnType()), - transform(ImmutableList.copyOf(method.getParameterTypes()), ParameterizedType::type)); + stream(method.getParameterTypes()).map(ParameterizedType::type).collect(toImmutableList())); } private static InstructionNode invoke(OpCode invocationType, MethodDefinition method) @@ -247,7 +249,7 @@ private static InstructionNode invoke(OpCode invocationType, Class target, St type(target), name, type(returnType), - transform(parameterTypes, ParameterizedType::type)); + Streams.stream(parameterTypes).map(ParameterizedType::type).collect(toImmutableList())); } // @@ -287,7 +289,7 @@ public static InstructionNode invokeDynamic(String name, { return new InvokeDynamicInstruction(name, type(methodType.returnType()), - transform(methodType.parameterList(), ParameterizedType::type), + methodType.parameterList().stream().map(ParameterizedType::type).collect(toImmutableList()), bootstrapMethod, ImmutableList.copyOf(bootstrapArguments)); } @@ -299,7 +301,7 @@ public static InstructionNode invokeDynamic(String name, { return new InvokeDynamicInstruction(name, type(methodType.returnType()), - transform(methodType.parameterList(), 
ParameterizedType::type), + methodType.parameterList().stream().map(ParameterizedType::type).collect(toImmutableList()), bootstrapMethod, ImmutableList.copyOf(bootstrapArguments)); } diff --git a/presto-cache/pom.xml b/presto-cache/pom.xml index 31ccb5b946c68..897e4c0a4eca2 100644 --- a/presto-cache/pom.xml +++ b/presto-cache/pom.xml @@ -5,20 +5,22 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-cache + presto-cache Presto cache library ${project.parent.basedir} + true com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache @@ -27,8 +29,8 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -57,18 +59,18 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true @@ -90,6 +92,18 @@ org.alluxio alluxio-core-common + + + jakarta.servlet + jakarta.servlet-api + + + + + + io.dropwizard.metrics + metrics-core + 4.2.37 @@ -106,7 +120,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -118,7 +132,7 @@ - io.airlift + com.facebook.airlift units provided @@ -157,10 +171,11 @@ org.apache.maven.plugins maven-dependency-plugin - + - io.dropwizard.metrics:metrics-core - + io.dropwizard.metrics:metrics-core + org.alluxio:alluxio-core-client-fs + diff --git a/presto-cache/src/main/java/com/facebook/presto/cache/CacheConfig.java b/presto-cache/src/main/java/com/facebook/presto/cache/CacheConfig.java index 7c223b0ab749e..abbbc81c3d1bd 100644 --- a/presto-cache/src/main/java/com/facebook/presto/cache/CacheConfig.java +++ b/presto-cache/src/main/java/com/facebook/presto/cache/CacheConfig.java @@ -15,10 +15,9 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.units.DataSize; import 
com.facebook.presto.hive.CacheQuotaScope; -import io.airlift.units.DataSize; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.net.URI; import java.util.Optional; diff --git a/presto-cache/src/main/java/com/facebook/presto/cache/CacheManager.java b/presto-cache/src/main/java/com/facebook/presto/cache/CacheManager.java index 4b4b3e26ed989..44f24c8cf2715 100644 --- a/presto-cache/src/main/java/com/facebook/presto/cache/CacheManager.java +++ b/presto-cache/src/main/java/com/facebook/presto/cache/CacheManager.java @@ -14,10 +14,9 @@ package com.facebook.presto.cache; import com.facebook.presto.hive.CacheQuota; +import com.google.errorprone.annotations.ThreadSafe; import io.airlift.slice.Slice; -import javax.annotation.concurrent.ThreadSafe; - @ThreadSafe public interface CacheManager { diff --git a/presto-cache/src/main/java/com/facebook/presto/cache/CacheStats.java b/presto-cache/src/main/java/com/facebook/presto/cache/CacheStats.java index 52ba6a4bb5d8d..d2648a6286a63 100644 --- a/presto-cache/src/main/java/com/facebook/presto/cache/CacheStats.java +++ b/presto-cache/src/main/java/com/facebook/presto/cache/CacheStats.java @@ -13,10 +13,9 @@ */ package com.facebook.presto.cache; +import com.google.errorprone.annotations.ThreadSafe; import org.weakref.jmx.Managed; -import javax.annotation.concurrent.ThreadSafe; - import java.util.concurrent.atomic.AtomicLong; @ThreadSafe diff --git a/presto-cache/src/main/java/com/facebook/presto/cache/CachingModule.java b/presto-cache/src/main/java/com/facebook/presto/cache/CachingModule.java index 75863b1372606..c58cd597e3f78 100644 --- a/presto-cache/src/main/java/com/facebook/presto/cache/CachingModule.java +++ b/presto-cache/src/main/java/com/facebook/presto/cache/CachingModule.java @@ -22,8 +22,7 @@ import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; - -import javax.inject.Singleton; +import jakarta.inject.Singleton; import static 
com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; diff --git a/presto-cache/src/main/java/com/facebook/presto/cache/ForCachingFileSystem.java b/presto-cache/src/main/java/com/facebook/presto/cache/ForCachingFileSystem.java index 54781e30b3527..5ce88872721c7 100644 --- a/presto-cache/src/main/java/com/facebook/presto/cache/ForCachingFileSystem.java +++ b/presto-cache/src/main/java/com/facebook/presto/cache/ForCachingFileSystem.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cache; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-cache/src/main/java/com/facebook/presto/cache/alluxio/AlluxioCacheConfig.java b/presto-cache/src/main/java/com/facebook/presto/cache/alluxio/AlluxioCacheConfig.java index a9b02ffe0e6b7..4a44b7b82e89f 100644 --- a/presto-cache/src/main/java/com/facebook/presto/cache/alluxio/AlluxioCacheConfig.java +++ b/presto-cache/src/main/java/com/facebook/presto/cache/alluxio/AlluxioCacheConfig.java @@ -15,10 +15,10 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; -import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static java.util.concurrent.TimeUnit.DAYS; import static java.util.concurrent.TimeUnit.HOURS; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-cache/src/main/java/com/facebook/presto/cache/alluxio/AlluxioCachingConfigurationProvider.java b/presto-cache/src/main/java/com/facebook/presto/cache/alluxio/AlluxioCachingConfigurationProvider.java index 4006715a001f1..780228d01fc6f 100644 --- 
a/presto-cache/src/main/java/com/facebook/presto/cache/alluxio/AlluxioCachingConfigurationProvider.java +++ b/presto-cache/src/main/java/com/facebook/presto/cache/alluxio/AlluxioCachingConfigurationProvider.java @@ -16,10 +16,9 @@ import com.facebook.presto.cache.CacheConfig; import com.facebook.presto.hive.DynamicConfigurationProvider; import com.facebook.presto.hive.HdfsContext; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; -import javax.inject.Inject; - import java.net.URI; import static com.facebook.presto.cache.CacheType.ALLUXIO; diff --git a/presto-cache/src/main/java/com/facebook/presto/cache/filemerge/FileMergeCacheConfig.java b/presto-cache/src/main/java/com/facebook/presto/cache/filemerge/FileMergeCacheConfig.java index 063590283c62f..f487022e90ab2 100644 --- a/presto-cache/src/main/java/com/facebook/presto/cache/filemerge/FileMergeCacheConfig.java +++ b/presto-cache/src/main/java/com/facebook/presto/cache/filemerge/FileMergeCacheConfig.java @@ -15,13 +15,12 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.Min; -import javax.validation.constraints.Min; - -import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static java.util.concurrent.TimeUnit.DAYS; public class FileMergeCacheConfig diff --git a/presto-cache/src/main/java/com/facebook/presto/cache/filemerge/FileMergeCacheManager.java b/presto-cache/src/main/java/com/facebook/presto/cache/filemerge/FileMergeCacheManager.java index 246715560e843..bb99d84098ee2 100644 --- a/presto-cache/src/main/java/com/facebook/presto/cache/filemerge/FileMergeCacheManager.java 
+++ b/presto-cache/src/main/java/com/facebook/presto/cache/filemerge/FileMergeCacheManager.java @@ -15,6 +15,7 @@ import alluxio.collections.ConcurrentHashSet; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.cache.CacheConfig; import com.facebook.presto.cache.CacheManager; import com.facebook.presto.cache.CacheResult; @@ -31,12 +32,10 @@ import com.google.common.collect.RangeMap; import com.google.common.collect.TreeRangeMap; import io.airlift.slice.Slice; -import io.airlift.units.DataSize; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.apache.hadoop.fs.Path; -import javax.annotation.PreDestroy; -import javax.inject.Inject; - import java.io.File; import java.io.FileInputStream; import java.io.IOException; @@ -59,10 +58,10 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.stream.Collectors; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterators.getOnlyElement; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.StrictMath.toIntExact; import static java.nio.file.StandardOpenOption.APPEND; import static java.nio.file.StandardOpenOption.CREATE_NEW; diff --git a/presto-cache/src/test/java/com/facebook/presto/cache/TestCacheConfig.java b/presto-cache/src/test/java/com/facebook/presto/cache/TestCacheConfig.java index 11312cf785a67..ae0fd164f52af 100644 --- a/presto-cache/src/test/java/com/facebook/presto/cache/TestCacheConfig.java +++ b/presto-cache/src/test/java/com/facebook/presto/cache/TestCacheConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.cache; +import com.facebook.airlift.units.DataSize; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import 
org.testng.annotations.Test; import java.net.URI; diff --git a/presto-cache/src/test/java/com/facebook/presto/cache/alluxio/TestAlluxioCacheConfig.java b/presto-cache/src/test/java/com/facebook/presto/cache/alluxio/TestAlluxioCacheConfig.java index 863320a4e3a96..2b030285e8cc3 100644 --- a/presto-cache/src/test/java/com/facebook/presto/cache/alluxio/TestAlluxioCacheConfig.java +++ b/presto-cache/src/test/java/com/facebook/presto/cache/alluxio/TestAlluxioCacheConfig.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.cache.alluxio; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; @@ -23,8 +23,8 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.TimeUnit.DAYS; import static java.util.concurrent.TimeUnit.HOURS; import static java.util.concurrent.TimeUnit.MINUTES; diff --git a/presto-cache/src/test/java/com/facebook/presto/cache/alluxio/TestAlluxioCachingFileSystem.java b/presto-cache/src/test/java/com/facebook/presto/cache/alluxio/TestAlluxioCachingFileSystem.java index 23230af340be8..cbe2b4b41cc57 100644 --- a/presto-cache/src/test/java/com/facebook/presto/cache/alluxio/TestAlluxioCachingFileSystem.java +++ b/presto-cache/src/test/java/com/facebook/presto/cache/alluxio/TestAlluxioCachingFileSystem.java @@ -17,11 +17,11 @@ import 
alluxio.metrics.MetricKey; import alluxio.metrics.MetricsSystem; import alluxio.util.io.FileUtils; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.cache.CacheConfig; import com.facebook.presto.hive.CacheQuota; import com.facebook.presto.hive.HiveFileContext; import com.facebook.presto.hive.filesystem.ExtendedFileSystem; -import io.airlift.units.DataSize; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FSDataOutputStream; @@ -48,14 +48,14 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicReference; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.cache.CacheType.ALLUXIO; import static com.facebook.presto.cache.TestingCacheUtils.stressTest; import static com.facebook.presto.cache.TestingCacheUtils.validateBuffer; import static com.facebook.presto.hive.CacheQuota.NO_CACHE_CONSTRAINTS; import static com.facebook.presto.hive.CacheQuotaScope.TABLE; import static com.google.common.base.Preconditions.checkState; -import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.nio.file.Files.createTempDirectory; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; diff --git a/presto-cache/src/test/java/com/facebook/presto/cache/filemerge/TestFileMergeCacheConfig.java b/presto-cache/src/test/java/com/facebook/presto/cache/filemerge/TestFileMergeCacheConfig.java index a6d8257783eba..d63432b905e7b 100644 --- a/presto-cache/src/test/java/com/facebook/presto/cache/filemerge/TestFileMergeCacheConfig.java +++ b/presto-cache/src/test/java/com/facebook/presto/cache/filemerge/TestFileMergeCacheConfig.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.cache.filemerge; +import com.facebook.airlift.units.DataSize; +import 
com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; @@ -23,8 +23,8 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.TimeUnit.DAYS; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-cache/src/test/java/com/facebook/presto/cache/filemerge/TestFileMergeCacheManager.java b/presto-cache/src/test/java/com/facebook/presto/cache/filemerge/TestFileMergeCacheManager.java index 29d52b90b0fbd..d3c37cb9df433 100644 --- a/presto-cache/src/test/java/com/facebook/presto/cache/filemerge/TestFileMergeCacheManager.java +++ b/presto-cache/src/test/java/com/facebook/presto/cache/filemerge/TestFileMergeCacheManager.java @@ -13,14 +13,14 @@ */ package com.facebook.presto.cache.filemerge; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.cache.CacheConfig; import com.facebook.presto.cache.CacheManager; import com.facebook.presto.cache.CacheStats; import com.facebook.presto.cache.FileReadRequest; import com.facebook.presto.hive.CacheQuota; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.apache.hadoop.fs.Path; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -38,12 +38,12 @@ import 
java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; import static com.facebook.presto.cache.TestingCacheUtils.stressTest; import static com.facebook.presto.cache.TestingCacheUtils.validateBuffer; import static com.facebook.presto.hive.CacheQuota.NO_CACHE_CONSTRAINTS; import static com.google.common.base.Preconditions.checkState; import static io.airlift.slice.Slices.wrappedBuffer; -import static io.airlift.units.DataSize.Unit.KILOBYTE; import static java.nio.file.Files.createTempDirectory; import static java.nio.file.StandardOpenOption.CREATE_NEW; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-cassandra/pom.xml b/presto-cassandra/pom.xml index 1b5f1cdf95d8b..b8053164da016 100644 --- a/presto-cassandra/pom.xml +++ b/presto-cassandra/pom.xml @@ -4,15 +4,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-cassandra + presto-cassandra Presto - Cassandra Connector presto-plugin ${project.parent.basedir} + true @@ -27,24 +29,24 @@ - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true - com.facebook.airlift - bootstrap + jakarta.annotation + jakarta.annotation-api com.facebook.airlift - json + bootstrap com.facebook.airlift - concurrent + json @@ -68,13 +70,13 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -101,7 +103,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -113,7 +115,7 @@ - io.airlift + com.facebook.airlift units provided @@ -124,6 +126,11 @@ provided + + com.facebook.presto + presto-plugin-toolkit + + com.facebook.presto @@ -143,6 +150,12 @@ test + + com.facebook.airlift + concurrent + test + + com.facebook.presto presto-tpch @@ -179,10 +192,6 @@ test - - com.facebook.airlift - security - 
@@ -191,7 +200,7 @@ org.slf4j slf4j-api - 1.7.36 + 2.0.16 @@ -232,6 +241,15 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + com.fasterxml.jackson.core:jackson-databind + + + @@ -255,18 +273,14 @@ + - test-cassandra-integration-smoke-test + ci-full-tests org.apache.maven.plugins maven-surefire-plugin - - - **/TestCassandraIntegrationSmokeTest.java - - diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientConfig.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientConfig.java index 59b095bf212ab..97c08fcc086f1 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientConfig.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientConfig.java @@ -20,15 +20,14 @@ import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.ConfigSecuritySensitive; import com.facebook.airlift.configuration.DefunctConfig; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MaxDuration; +import com.facebook.airlift.units.MinDuration; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; -import io.airlift.units.MaxDuration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; -import javax.validation.constraints.Size; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; import java.io.File; import java.util.Arrays; @@ -76,6 +75,7 @@ public class CassandraClientConfig private String truststorePassword; private File keystorePath; private String keystorePassword; + private boolean caseSensitiveNameMatchingEnabled; @NotNull @Size(min = 1) @@ -476,4 +476,16 @@ public CassandraClientConfig setTruststorePassword(String truststorePassword) this.truststorePassword = 
truststorePassword; return this; } + + public boolean isCaseSensitiveNameMatchingEnabled() + { + return caseSensitiveNameMatchingEnabled; + } + + @Config("case-sensitive-name-matching") + public CassandraClientConfig setCaseSensitiveNameMatchingEnabled(boolean caseSensitiveNameMatchingEnabled) + { + this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled; + return this; + } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java index 364779296f020..c6a3242f0b7ff 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraClientModule.java @@ -25,13 +25,12 @@ import com.datastax.driver.core.policies.TokenAwarePolicy; import com.datastax.driver.core.policies.WhiteListPolicy; import com.facebook.airlift.json.JsonCodec; -import com.facebook.presto.cassandra.util.SslContextProvider; +import com.facebook.presto.plugin.base.security.SslContextProvider; import com.google.inject.Binder; import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; - -import javax.inject.Singleton; +import jakarta.inject.Singleton; import java.net.InetSocketAddress; import java.util.ArrayList; @@ -156,6 +155,6 @@ public static CassandraSession createCassandraSession( contactPoints.forEach(clusterBuilder::addContactPoint); return clusterBuilder.build(); }), - config.getNoHostAvailableRetryTimeout()); + config.getNoHostAvailableRetryTimeout(), config.isCaseSensitiveNameMatchingEnabled()); } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraColumnHandle.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraColumnHandle.java index a5ca33a1a0ff1..5675804b7331b 100644 --- 
a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraColumnHandle.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraColumnHandle.java @@ -20,8 +20,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.MoreObjects.ToStringHelper; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Objects; @@ -138,6 +137,15 @@ public ColumnMetadata getColumnMetadata() .build(); } + public ColumnMetadata getColumnMetadata(String name) + { + return ColumnMetadata.builder() + .setName(name) + .setType(cassandraType.getNativeType()) + .setHidden(hidden) + .build(); + } + public Type getType() { return cassandraType.getNativeType(); diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnector.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnector.java index f787974d4b371..016f2c8022465 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnector.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraConnector.java @@ -23,8 +23,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java index c8e07887d091a..da86b2293fdcc 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraMetadata.java @@ -15,7 +15,6 @@ import com.datastax.driver.core.ProtocolVersion; import 
com.facebook.airlift.json.JsonCodec; -import com.facebook.presto.cassandra.util.CassandraCqlUtils; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ColumnHandle; @@ -42,8 +41,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; @@ -53,16 +51,15 @@ import java.util.stream.Collectors; import static com.facebook.presto.cassandra.CassandraType.toCassandraType; -import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validSchemaName; -import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validTableName; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.cqlNameToSqlName; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validColumnName; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.PERMISSION_DENIED; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.util.Locale.ENGLISH; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; public class CassandraMetadata implements ConnectorMetadata @@ -72,6 +69,7 @@ public class CassandraMetadata private final CassandraPartitionManager partitionManager; private final boolean allowDropTable; private final ProtocolVersion protocolVersion; + private boolean caseSensitiveNameMatchingEnabled; private final JsonCodec> extraColumnMetadataCodec; @@ -89,13 +87,14 @@ public CassandraMetadata( this.allowDropTable = requireNonNull(config, "config is null").getAllowDropTable(); 
this.extraColumnMetadataCodec = requireNonNull(extraColumnMetadataCodec, "extraColumnMetadataCodec is null"); this.protocolVersion = requireNonNull(config, "config is null").getProtocolVersion(); + this.caseSensitiveNameMatchingEnabled = requireNonNull(config, "config is null").isCaseSensitiveNameMatchingEnabled(); } @Override public List listSchemaNames(ConnectorSession session) { return cassandraSession.getCaseSensitiveSchemaNames().stream() - .map(name -> name.toLowerCase(ENGLISH)) + .map(name -> normalizeIdentifier(session, name)) .collect(toImmutableList()); } @@ -121,15 +120,15 @@ private static SchemaTableName getTableName(ConnectorTableHandle tableHandle) public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { requireNonNull(tableHandle, "tableHandle is null"); - return getTableMetadata(getTableName(tableHandle)); + return getTableMetadata(session, getTableName(tableHandle)); } - private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) + private ConnectorTableMetadata getTableMetadata(ConnectorSession session, SchemaTableName tableName) { CassandraTable table = cassandraSession.getTable(tableName); List columns = table.getColumns().stream() - .map(CassandraColumnHandle::getColumnMetadata) - .collect(toList()); + .map(column -> column.getColumnMetadata(normalizeIdentifier(session, cqlNameToSqlName(column.getName())))) + .collect(toImmutableList()); return new ConnectorTableMetadata(tableName, columns); } @@ -140,7 +139,8 @@ public List listTables(ConnectorSession session, String schemaN for (String schemaName : listSchemas(session, schemaNameOrNull)) { try { for (String tableName : cassandraSession.getCaseSensitiveTableNames(schemaName)) { - tableNames.add(new SchemaTableName(schemaName, tableName.toLowerCase(ENGLISH))); + String normalizedTableName = normalizeIdentifier(session, tableName); + tableNames.add(new SchemaTableName(schemaName, normalizedTableName)); } } catch 
(SchemaNotFoundException e) { @@ -166,7 +166,9 @@ public Map getColumnHandles(ConnectorSession session, Conn CassandraTable table = cassandraSession.getTable(getTableName(tableHandle)); ImmutableMap.Builder columnHandles = ImmutableMap.builder(); for (CassandraColumnHandle columnHandle : table.getColumns()) { - columnHandles.put(CassandraCqlUtils.cqlNameToSqlName(columnHandle.getName()).toLowerCase(ENGLISH), columnHandle); + String columnName = cqlNameToSqlName(columnHandle.getName()); + String normalizedColumnName = normalizeIdentifier(session, columnName); + columnHandles.put(normalizedColumnName, columnHandle); } return columnHandles.build(); } @@ -178,7 +180,7 @@ public Map> listTableColumns(ConnectorSess ImmutableMap.Builder> columns = ImmutableMap.builder(); for (SchemaTableName tableName : listTables(session, prefix)) { try { - columns.put(tableName, getTableMetadata(tableName).getColumns()); + columns.put(tableName, getTableMetadata(session, tableName).getColumns()); } catch (NotFoundException e) { // table disappeared during listing operation @@ -202,7 +204,11 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { CassandraTableHandle handle = (CassandraTableHandle) table; CassandraPartitionResult partitionResult = partitionManager.getPartitions(handle, constraint.getSummary()); @@ -225,7 +231,7 @@ public List getTableLayouts(ConnectorSession session handle, partitionResult.getPartitions(), clusteringKeyPredicates)); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, unenforcedConstraint)); + return new ConnectorTableLayoutResult(layout, unenforcedConstraint); } @Override @@ -245,7 +251,7 @@ 
public String toString() @Override public void createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, boolean ignoreExisting) { - createTable(tableMetadata); + createTable(session, tableMetadata); } @Override @@ -274,10 +280,10 @@ public void renameTable(ConnectorSession session, ConnectorTableHandle tableHand @Override public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, Optional layout) { - return createTable(tableMetadata); + return createTable(session, tableMetadata); } - private CassandraOutputTableHandle createTable(ConnectorTableMetadata tableMetadata) + private CassandraOutputTableHandle createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata) { ImmutableList.Builder columnNames = ImmutableList.builder(); ImmutableList.Builder columnTypes = ImmutableList.builder(); @@ -297,12 +303,13 @@ private CassandraOutputTableHandle createTable(ConnectorTableMetadata tableMetad List types = columnTypes.build(); StringBuilder queryBuilder = new StringBuilder(String.format("CREATE TABLE \"%s\".\"%s\"(id uuid primary key", schemaName, tableName)); for (int i = 0; i < columns.size(); i++) { - String name = columns.get(i); + String columnName = columns.get(i); + String finalColumnName = validColumnName(normalizeIdentifier(session, columnName)); Type type = types.get(i); queryBuilder.append(", ") - .append(name) + .append(finalColumnName) .append(" ") - .append(toCassandraType(type, protocolVersion).name().toLowerCase(ENGLISH)); + .append(toCassandraType(type, protocolVersion).name().toLowerCase(ROOT)); } queryBuilder.append(") "); @@ -336,13 +343,13 @@ public ConnectorInsertTableHandle beginInsert(ConnectorSession session, Connecto SchemaTableName schemaTableName = new SchemaTableName(table.getSchemaName(), table.getTableName()); List columns = cassandraSession.getTable(schemaTableName).getColumns(); - List columnNames = 
columns.stream().map(CassandraColumnHandle::getName).map(CassandraCqlUtils::validColumnName).collect(Collectors.toList()); + List columnNames = columns.stream().map(CassandraColumnHandle::getName).collect(Collectors.toList()); List columnTypes = columns.stream().map(CassandraColumnHandle::getType).collect(Collectors.toList()); return new CassandraInsertTableHandle( connectorId, - validSchemaName(table.getSchemaName()), - validTableName(table.getTableName()), + table.getSchemaName(), + table.getTableName(), columnNames, columnTypes); } @@ -352,4 +359,10 @@ public Optional finishInsert(ConnectorSession session, { return Optional.empty(); } + + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return caseSensitiveNameMatchingEnabled ? identifier : identifier.toLowerCase(ROOT); + } } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPageSink.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPageSink.java index 60011f48f861a..ee51ec707a8e6 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPageSink.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPageSink.java @@ -39,6 +39,9 @@ import static com.datastax.driver.core.querybuilder.QueryBuilder.bindMarker; import static com.datastax.driver.core.querybuilder.QueryBuilder.insertInto; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validColumnName; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validSchemaName; +import static com.facebook.presto.cassandra.util.CassandraCqlUtils.validTableName; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DateType.DATE; @@ -92,14 +95,14 @@ public CassandraPageSink( toCassandraDate = value -> LocalDate.fromDaysSinceEpoch(toIntExact(value)); } - Insert 
insert = insertInto(schemaName, tableName); + Insert insert = insertInto(validSchemaName(schemaName), validTableName(tableName)); if (generateUUID) { insert.value("id", bindMarker()); } for (int i = 0; i < columnNames.size(); i++) { String columnName = columnNames.get(i); checkArgument(columnName != null, "columnName is null at position: %d", i); - insert.value(columnName, bindMarker()); + insert.value(validColumnName(columnName), bindMarker()); } this.insert = cassandraSession.prepare(insert); } diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPageSinkProvider.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPageSinkProvider.java index 8518647539a7e..37c4d9e0710cc 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPageSinkProvider.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPageSinkProvider.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.PageSinkContext; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPartitionManager.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPartitionManager.java index 7c68969c8f0fb..20034e925a8c1 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPartitionManager.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraPartitionManager.java @@ -24,8 +24,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.common.collect.Sets; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import 
java.util.List; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraRecordSetProvider.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraRecordSetProvider.java index f8d450bc76d70..93d4a519e8934 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraRecordSetProvider.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraRecordSetProvider.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.RecordSet; import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSessionProperties.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSessionProperties.java index 24e9e123e4b4e..72715ff119c24 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSessionProperties.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSessionProperties.java @@ -16,8 +16,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java index e36ef9d0bb6f2..e94faf013066d 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraSplitManager.java @@ -25,8 +25,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.HashMap; import java.util.HashSet; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTokenSplitManager.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTokenSplitManager.java index 16cbba7071678..27b536a23cde3 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTokenSplitManager.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraTokenSplitManager.java @@ -18,8 +18,7 @@ import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.List; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraType.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraType.java index 6b68da673e8ab..6ddee8fd53524 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraType.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/CassandraType.java @@ -17,8 +17,10 @@ import com.datastax.driver.core.LocalDate; import com.datastax.driver.core.ProtocolVersion; import com.datastax.driver.core.Row; +import com.datastax.driver.core.TupleValue; import com.datastax.driver.core.utils.Bytes; import com.facebook.presto.cassandra.util.CassandraCqlUtils; +import com.facebook.presto.common.NotSupportedException; import com.facebook.presto.common.predicate.NullableValue; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.common.type.BooleanType; @@ -54,6 +56,7 @@ import static io.airlift.slice.Slices.wrappedBuffer; import static java.lang.Float.floatToRawIntBits; import static java.lang.Float.intBitsToFloat; +import 
static java.lang.String.format; import static java.util.Objects.requireNonNull; public enum CassandraType @@ -81,7 +84,8 @@ public enum CassandraType VARINT(createUnboundedVarcharType(), BigInteger.class), LIST(createUnboundedVarcharType(), null), MAP(createUnboundedVarcharType(), null), - SET(createUnboundedVarcharType(), null); + SET(createUnboundedVarcharType(), null), + TUPLE(createUnboundedVarcharType(), null); private static class Constants { @@ -162,6 +166,8 @@ public static CassandraType getCassandraType(DataType.Name name) return TIMEUUID; case TINYINT: return TINYINT; + case TUPLE: + return TUPLE; case UUID: return UUID; case VARCHAR: @@ -169,7 +175,7 @@ public static CassandraType getCassandraType(DataType.Name name) case VARINT: return VARINT; default: - return null; + throw new NotSupportedException(format("Unsupported Cassandra type: %s", name)); } } @@ -231,6 +237,9 @@ public static NullableValue getColumnValue(Row row, int position, CassandraType case MAP: checkTypeArguments(cassandraType, 2, typeArguments); return NullableValue.of(nativeType, utf8Slice(buildMapValue(row, position, typeArguments.get(0), typeArguments.get(1)))); + case TUPLE: + TupleValue tupleValue = row.getTupleValue(position); + return NullableValue.of(nativeType, utf8Slice(tupleValue.toString())); default: throw new IllegalStateException("Handling of type " + cassandraType + " is not implemented"); diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java index 06ec9e47ef564..44c9b91b1f478 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/NativeCassandraSession.java @@ -28,6 +28,7 @@ import com.datastax.driver.core.Session; import com.datastax.driver.core.Statement; import com.datastax.driver.core.TableMetadata; +import 
com.datastax.driver.core.TableOptionsMetadata; import com.datastax.driver.core.TokenRange; import com.datastax.driver.core.VersionNumber; import com.datastax.driver.core.exceptions.NoHostAvailableException; @@ -38,6 +39,7 @@ import com.datastax.driver.core.querybuilder.Select; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.cassandra.util.CassandraCqlUtils; import com.facebook.presto.common.predicate.NullableValue; import com.facebook.presto.common.predicate.TupleDomain; @@ -56,7 +58,6 @@ import com.google.common.collect.Ordering; import com.google.common.collect.Sets; import com.google.common.util.concurrent.UncheckedExecutionException; -import io.airlift.units.Duration; import java.nio.Buffer; import java.nio.ByteBuffer; @@ -65,6 +66,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.stream.Stream; @@ -84,7 +86,7 @@ import static com.google.common.collect.Iterables.transform; import static java.lang.String.format; import static java.util.Comparator.comparing; -import static java.util.Locale.ENGLISH; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.stream.Collectors.joining; @@ -102,7 +104,7 @@ public class NativeCassandraSession public KeyspaceMetadata load(String key) throws Exception { - return getKeyspaceByCaseInsensitiveName0(key); + return getKeyspaceByCaseSensitiveName0(key); } }); @@ -116,14 +118,16 @@ public KeyspaceMetadata load(String key) private final Cluster cluster; private final Supplier session; private final Duration noHostAvailableRetryTimeout; + private static boolean caseSensitiveNameMatchingEnabled; - public NativeCassandraSession(String connectorId, JsonCodec> extraColumnMetadataCodec, 
Cluster cluster, Duration noHostAvailableRetryTimeout) + public NativeCassandraSession(String connectorId, JsonCodec> extraColumnMetadataCodec, Cluster cluster, Duration noHostAvailableRetryTimeout, boolean caseSensitiveNameMatchingEnabled) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.extraColumnMetadataCodec = requireNonNull(extraColumnMetadataCodec, "extraColumnMetadataCodec is null"); this.cluster = requireNonNull(cluster, "cluster is null"); this.noHostAvailableRetryTimeout = requireNonNull(noHostAvailableRetryTimeout, "noHostAvailableRetryTimeout is null"); this.session = memoize(cluster::connect); + this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled; } @Override @@ -170,9 +174,9 @@ public Set getReplicas(String caseSensitiveSchemaName, ByteBuffer partitio } @Override - public String getCaseSensitiveSchemaName(String caseInsensitiveSchemaName) + public String getCaseSensitiveSchemaName(String caseSensitiveSchemaName) { - return getKeyspaceByCaseInsensitiveName(caseInsensitiveSchemaName).getName(); + return getKeyspaceByCaseSensitiveName(caseSensitiveSchemaName).getName(); } @Override @@ -187,10 +191,10 @@ public List getCaseSensitiveSchemaNames() } @Override - public List getCaseSensitiveTableNames(String caseInsensitiveSchemaName) + public List getCaseSensitiveTableNames(String caseSensitiveSchemaName) throws SchemaNotFoundException { - KeyspaceMetadata keyspace = getKeyspaceByCaseInsensitiveName(caseInsensitiveSchemaName); + KeyspaceMetadata keyspace = getKeyspaceByCaseSensitiveName(caseSensitiveSchemaName); ImmutableList.Builder builder = ImmutableList.builder(); for (TableMetadata table : keyspace.getTables()) { builder.add(table.getName()); @@ -205,21 +209,24 @@ public List getCaseSensitiveTableNames(String caseInsensitiveSchemaName) public CassandraTable getTable(SchemaTableName schemaTableName) throws TableNotFoundException { - KeyspaceMetadata keyspace = 
getKeyspaceByCaseInsensitiveName(schemaTableName.getSchemaName()); + KeyspaceMetadata keyspace = getKeyspaceByCaseSensitiveName(schemaTableName.getSchemaName()); AbstractTableMetadata tableMeta = getTableMetadata(keyspace, schemaTableName.getTableName()); List columnNames = new ArrayList<>(); List columns = tableMeta.getColumns(); - checkColumnNames(columns); + if (!caseSensitiveNameMatchingEnabled) { + checkColumnNames(columns); + } + for (ColumnMetadata columnMetadata : columns) { columnNames.add(columnMetadata.getName()); } // check if there is a comment to establish column ordering - String comment = tableMeta.getOptions().getComment(); + Optional comment = Optional.ofNullable(tableMeta.getOptions()).map(TableOptionsMetadata::getComment); Set hiddenColumns = ImmutableSet.of(); - if (comment != null && comment.startsWith(PRESTO_COMMENT_METADATA)) { - String columnOrderingString = comment.substring(PRESTO_COMMENT_METADATA.length()); + if (comment.isPresent() && comment.get().startsWith(PRESTO_COMMENT_METADATA)) { + String columnOrderingString = comment.get().substring(PRESTO_COMMENT_METADATA.length()); // column ordering List extras = extraColumnMetadataCodec.fromJson(columnOrderingString); @@ -269,11 +276,11 @@ public CassandraTable getTable(SchemaTableName schemaTableName) return new CassandraTable(tableHandle, sortedColumnHandles); } - private KeyspaceMetadata getKeyspaceByCaseInsensitiveName(String caseInsensitiveSchemaName) + private KeyspaceMetadata getKeyspaceByCaseSensitiveName(String caseSensitiveSchemaName) throws SchemaNotFoundException { try { - return keyspaceCache.get(caseInsensitiveSchemaName); + return keyspaceCache.get(caseSensitiveSchemaName); } catch (UncheckedExecutionException | ExecutionException e) { Throwable cause = e.getCause(); @@ -289,7 +296,7 @@ private KeyspaceMetadata getKeyspaceByCaseInsensitiveName(String caseInsensitive } } - private KeyspaceMetadata getKeyspaceByCaseInsensitiveName0(String caseInsensitiveSchemaName) + private 
KeyspaceMetadata getKeyspaceByCaseSensitiveName0(String caseSensitiveSchemaName) throws SchemaNotFoundException { List keyspaces = executeWithSession(session -> session.getCluster().getMetadata().getKeyspaces()); @@ -297,31 +304,43 @@ private KeyspaceMetadata getKeyspaceByCaseInsensitiveName0(String caseInsensitiv // Ensure that the error message is deterministic List sortedKeyspaces = Ordering.from(comparing(KeyspaceMetadata::getName)).immutableSortedCopy(keyspaces); for (KeyspaceMetadata keyspace : sortedKeyspaces) { - if (keyspace.getName().equalsIgnoreCase(caseInsensitiveSchemaName)) { + if (namesMatch(keyspace.getName(), caseSensitiveSchemaName, caseSensitiveNameMatchingEnabled)) { + if (caseSensitiveNameMatchingEnabled) { + result = keyspace; + break; + } if (result != null) { throw new PrestoException( NOT_SUPPORTED, - format("More than one keyspace has been found for the case insensitive schema name: %s -> (%s, %s)", - caseInsensitiveSchemaName, result.getName(), keyspace.getName())); + format("More than one keyspace has been found for the schema name: %s -> (%s, %s)", + caseSensitiveSchemaName.toLowerCase(ROOT), result.getName(), keyspace.getName())); } result = keyspace; } } + if (result == null) { - throw new SchemaNotFoundException(caseInsensitiveSchemaName); + throw new SchemaNotFoundException(caseSensitiveSchemaName); } return result; } - private static AbstractTableMetadata getTableMetadata(KeyspaceMetadata keyspace, String caseInsensitiveTableName) + private static boolean namesMatch(String actualName, String expectedName, boolean caseSensitive) + { + return caseSensitive + ? 
actualName.equals(expectedName) + : actualName.equalsIgnoreCase(expectedName); + } + + private static AbstractTableMetadata getTableMetadata(KeyspaceMetadata keyspace, String caseSensitiveTableName) { List tables = Stream.concat( keyspace.getTables().stream(), keyspace.getMaterializedViews().stream()) - .filter(table -> table.getName().equalsIgnoreCase(caseInsensitiveTableName)) + .filter(table -> namesMatch(table.getName(), caseSensitiveTableName, caseSensitiveNameMatchingEnabled)) .collect(toImmutableList()); if (tables.size() == 0) { - throw new TableNotFoundException(new SchemaTableName(keyspace.getName(), caseInsensitiveTableName)); + throw new TableNotFoundException(new SchemaTableName(keyspace.getName(), caseSensitiveTableName)); } else if (tables.size() == 1) { return tables.get(0); @@ -333,12 +352,12 @@ else if (tables.size() == 1) { throw new PrestoException( NOT_SUPPORTED, format("More than one table has been found for the case insensitive table name: %s -> (%s)", - caseInsensitiveTableName, tableNames)); + caseSensitiveTableName.toLowerCase(ROOT), tableNames)); } public boolean isMaterializedView(SchemaTableName schemaTableName) { - KeyspaceMetadata keyspace = getKeyspaceByCaseInsensitiveName(schemaTableName.getSchemaName()); + KeyspaceMetadata keyspace = getKeyspaceByCaseSensitiveName(schemaTableName.getSchemaName()); return keyspace.getMaterializedView(schemaTableName.getTableName()) != null; } @@ -346,14 +365,14 @@ private static void checkColumnNames(List columns) { Map lowercaseNameToColumnMap = new HashMap<>(); for (ColumnMetadata column : columns) { - String lowercaseName = column.getName().toLowerCase(ENGLISH); - if (lowercaseNameToColumnMap.containsKey(lowercaseName)) { + String columnNameKey = column.getName().toLowerCase(ROOT); + if (lowercaseNameToColumnMap.containsKey(columnNameKey)) { throw new PrestoException( NOT_SUPPORTED, format("More than one column has been found for the case insensitive column name: %s -> (%s, %s)", - lowercaseName, 
lowercaseNameToColumnMap.get(lowercaseName).getName(), column.getName())); + columnNameKey, lowercaseNameToColumnMap.get(columnNameKey).getName(), column.getName())); } - lowercaseNameToColumnMap.put(lowercaseName, column); + lowercaseNameToColumnMap.put(columnNameKey, column); } } @@ -361,7 +380,7 @@ private CassandraColumnHandle buildColumnHandle(AbstractTableMetadata tableMetad { CassandraType cassandraType = CassandraType.getCassandraType(columnMeta.getType().getName()); List typeArguments = null; - if (cassandraType != null && cassandraType.getTypeArgumentSize() > 0) { + if (cassandraType.getTypeArgumentSize() > 0) { List typeArgs = columnMeta.getType().getTypeArguments(); switch (cassandraType.getTypeArgumentSize()) { case 1: diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/RebindSafeMBeanServer.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/RebindSafeMBeanServer.java index 49dc3c6642368..9525145e4e510 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/RebindSafeMBeanServer.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/RebindSafeMBeanServer.java @@ -14,8 +14,8 @@ package com.facebook.presto.cassandra; import com.facebook.airlift.log.Logger; +import com.google.errorprone.annotations.ThreadSafe; -import javax.annotation.concurrent.ThreadSafe; import javax.management.Attribute; import javax.management.AttributeList; import javax.management.AttributeNotFoundException; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/ReopeningCluster.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/ReopeningCluster.java index 0ebdb248ddf70..4518701f67ef5 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/ReopeningCluster.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/ReopeningCluster.java @@ -17,9 +17,8 @@ import com.datastax.driver.core.Cluster; import com.datastax.driver.core.DelegatingCluster; 
import com.facebook.airlift.log.Logger; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.function.Supplier; diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/util/CassandraCqlUtils.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/util/CassandraCqlUtils.java index b1ec56df43d60..4aa9631d48b27 100644 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/util/CassandraCqlUtils.java +++ b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/util/CassandraCqlUtils.java @@ -28,8 +28,6 @@ import java.util.List; import java.util.Set; -import static java.util.Locale.ENGLISH; - public final class CassandraCqlUtils { private CassandraCqlUtils() @@ -53,12 +51,12 @@ private CassandraCqlUtils() public static String validSchemaName(String identifier) { - return validIdentifier(identifier); + return quoteIdentifier(identifier); } public static String validTableName(String identifier) { - return validIdentifier(identifier); + return quoteIdentifier(identifier); } public static String validColumnName(String identifier) @@ -67,19 +65,7 @@ public static String validColumnName(String identifier) return "\"\""; } - return validIdentifier(identifier); - } - - private static String validIdentifier(String identifier) - { - if (!identifier.equals(identifier.toLowerCase(ENGLISH))) { - return quoteIdentifier(identifier); - } - - if (keywords.contains(identifier.toUpperCase(ENGLISH))) { - return quoteIdentifier(identifier); - } - return identifier; + return quoteIdentifier(identifier); } private static String quoteIdentifier(String identifier) diff --git a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/util/SslContextProvider.java b/presto-cassandra/src/main/java/com/facebook/presto/cassandra/util/SslContextProvider.java deleted file 
mode 100644 index 9dbbcacea1bed..0000000000000 --- a/presto-cassandra/src/main/java/com/facebook/presto/cassandra/util/SslContextProvider.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.cassandra.util; - -import com.facebook.presto.spi.PrestoException; - -import javax.net.ssl.KeyManager; -import javax.net.ssl.KeyManagerFactory; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; -import javax.net.ssl.X509TrustManager; -import javax.security.auth.x500.X500Principal; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.security.GeneralSecurityException; -import java.security.KeyStore; -import java.security.cert.Certificate; -import java.security.cert.CertificateExpiredException; -import java.security.cert.CertificateNotYetValidException; -import java.security.cert.X509Certificate; -import java.util.Arrays; -import java.util.List; -import java.util.Optional; - -import static com.facebook.airlift.security.pem.PemReader.loadKeyStore; -import static com.facebook.airlift.security.pem.PemReader.readCertificateChain; -import static com.facebook.presto.cassandra.CassandraErrorCode.CASSANDRA_SSL_INITIALIZATION_FAILURE; -import static java.util.Collections.list; - -public class SslContextProvider -{ - private final Optional keystorePath; - private final Optional 
keystorePassword; - private final Optional truststorePath; - private final Optional truststorePassword; - - public SslContextProvider( - Optional keystorePath, - Optional keystorePassword, - Optional truststorePath, - Optional truststorePassword) - { - this.keystorePath = keystorePath; - this.keystorePassword = keystorePassword; - this.truststorePath = truststorePath; - this.truststorePassword = truststorePassword; - } - - public Optional buildSslContext() - { - if (!keystorePath.isPresent() && !truststorePath.isPresent()) { - return Optional.empty(); - } - try { - // load KeyStore if configured and get KeyManagers - KeyStore keystore = null; - KeyManager[] keyManagers = null; - if (keystorePath.isPresent()) { - char[] keyManagerPassword; - try { - // attempt to read the key store as a PEM file - keystore = loadKeyStore(keystorePath.get(), keystorePath.get(), keystorePassword); - // for PEM encoded keys, the password is used to decrypt the specific key (and does not - // protect the keystore itself) - keyManagerPassword = new char[0]; - } - catch (IOException | GeneralSecurityException ignored) { - keyManagerPassword = keystorePassword.map(String::toCharArray).orElse(null); - keystore = KeyStore.getInstance(KeyStore.getDefaultType()); - try (InputStream in = new FileInputStream(keystorePath.get())) { - keystore.load(in, keyManagerPassword); - } - } - validateCertificates(keystore); - KeyManagerFactory keyManagerFactory = - KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); - keyManagerFactory.init(keystore, keyManagerPassword); - keyManagers = keyManagerFactory.getKeyManagers(); - } - // load TrustStore if configured, otherwise use KeyStore - KeyStore truststore = keystore; - if (truststorePath.isPresent()) { - truststore = loadTrustStore(truststorePath.get(), truststorePassword); - } - - // create TrustManagerFactory - TrustManagerFactory trustManagerFactory = - TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); - 
trustManagerFactory.init(truststore); - - // get X509TrustManager - TrustManager[] trustManagers = trustManagerFactory.getTrustManagers(); - if ((trustManagers.length != 1) || !(trustManagers[0] instanceof X509TrustManager)) { - throw new RuntimeException("Unexpected default trust managers:" + Arrays.toString(trustManagers)); - } - - X509TrustManager trustManager = (X509TrustManager) trustManagers[0]; - // create SSLContext - SSLContext result = SSLContext.getInstance("TLS"); - result.init(keyManagers, new TrustManager[] {trustManager}, null); - return Optional.of(result); - } - catch (GeneralSecurityException | IOException e) { - throw new PrestoException(CASSANDRA_SSL_INITIALIZATION_FAILURE, e); - } - } - - public KeyStore loadTrustStore(File trustStorePath, Optional trustStorePassword) - throws IOException, GeneralSecurityException - { - KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType()); - try { - // attempt to read the trust store as a PEM file - List certificateChain = readCertificateChain(trustStorePath); - if (!certificateChain.isEmpty()) { - trustStore.load(null, null); - for (X509Certificate certificate : certificateChain) { - X500Principal principal = certificate.getSubjectX500Principal(); - trustStore.setCertificateEntry(principal.getName(), certificate); - } - return trustStore; - } - } - catch (IOException | GeneralSecurityException ignored) { - } - try (InputStream inputStream = new FileInputStream(trustStorePath)) { - trustStore.load(inputStream, trustStorePassword.map(String::toCharArray).orElse(null)); - } - return trustStore; - } - - public void validateCertificates(KeyStore keyStore) throws GeneralSecurityException - { - for (String alias : list(keyStore.aliases())) { - if (!keyStore.isKeyEntry(alias)) { - continue; - } - Certificate certificate = keyStore.getCertificate(alias); - if (!(certificate instanceof X509Certificate)) { - continue; - } - try { - ((X509Certificate) certificate).checkValidity(); - } - catch 
(CertificateExpiredException e) { - throw new CertificateExpiredException("KeyStore certificate is expired: " + e.getMessage()); - } - catch (CertificateNotYetValidException e) { - throw new CertificateNotYetValidException("KeyStore certificate is not yet valid: " + e.getMessage()); - } - } - } -} diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraQueryRunner.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraQueryRunner.java index 202f893dbc000..54c008d646b43 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraQueryRunner.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraQueryRunner.java @@ -19,7 +19,9 @@ import com.google.common.collect.ImmutableMap; import io.airlift.tpch.TpchTable; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static com.facebook.presto.cassandra.CassandraTestingUtils.createKeyspace; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -34,7 +36,7 @@ private CassandraQueryRunner() private static boolean tpchLoaded; - public static DistributedQueryRunner createCassandraQueryRunner(CassandraServer server) + public static DistributedQueryRunner createCassandraQueryRunner(CassandraServer server, Map connectorProperties) throws Exception { DistributedQueryRunner queryRunner = new DistributedQueryRunner(createCassandraSession("tpch"), 4); @@ -42,11 +44,13 @@ public static DistributedQueryRunner createCassandraQueryRunner(CassandraServer queryRunner.installPlugin(new TpchPlugin()); queryRunner.createCatalog("tpch", "tpch"); + connectorProperties = new HashMap<>(ImmutableMap.copyOf(connectorProperties)); + connectorProperties.putIfAbsent("cassandra.contact-points", server.getHost()); + connectorProperties.putIfAbsent("cassandra.native-protocol-port", Integer.toString(server.getPort())); + connectorProperties.putIfAbsent("cassandra.allow-drop-table", "true"); + 
queryRunner.installPlugin(new CassandraPlugin()); - queryRunner.createCatalog("cassandra", "cassandra", ImmutableMap.of( - "cassandra.contact-points", server.getHost(), - "cassandra.native-protocol-port", Integer.toString(server.getPort()), - "cassandra.allow-drop-table", "true")); + queryRunner.createCatalog("cassandra", "cassandra", connectorProperties); createKeyspace(server.getSession(), "tpch"); List> tables = TpchTable.getTables(); diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraServer.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraServer.java index ca04ae6b116ef..c388f6222e84a 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraServer.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraServer.java @@ -14,13 +14,14 @@ package com.facebook.presto.cassandra; import com.datastax.driver.core.Cluster; +import com.datastax.driver.core.Metadata; import com.datastax.driver.core.ResultSet; import com.datastax.driver.core.Row; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableList; import com.google.common.io.Resources; -import io.airlift.units.Duration; import org.testcontainers.containers.GenericContainer; import java.io.Closeable; @@ -58,6 +59,7 @@ public class CassandraServer private final GenericContainer dockerContainer; private final CassandraSession session; + private final Metadata metadata; public CassandraServer() throws Exception @@ -81,7 +83,9 @@ public CassandraServer() "EmbeddedCassandra", JsonCodec.listJsonCodec(ExtraColumnMetadata.class), cluster, - new Duration(1, MINUTES)); + new Duration(1, MINUTES), + false); + this.metadata = cluster.getMetadata(); try { checkConnectivity(session); @@ -117,6 +121,11 @@ public CassandraSession getSession() return requireNonNull(session, "cluster is null"); } + public 
Metadata getMetadata() + { + return metadata; + } + public String getHost() { return dockerContainer.getContainerIpAddress(); diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java index b8616f7bb03c1..e8dcd09bb63c0 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/CassandraTestingUtils.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.cassandra; +import com.datastax.driver.core.DataType; +import com.datastax.driver.core.Metadata; +import com.datastax.driver.core.TupleType; import com.datastax.driver.core.querybuilder.Insert; import com.datastax.driver.core.querybuilder.QueryBuilder; import com.facebook.presto.spi.SchemaTableName; @@ -42,12 +45,12 @@ public class CassandraTestingUtils private CassandraTestingUtils() {} - public static void createTestTables(CassandraSession cassandraSession, String keyspace, Date date) + public static void createTestTables(CassandraSession cassandraSession, Metadata metadata, String keyspace, Date date) { createKeyspace(cassandraSession, keyspace); - createTableAllTypes(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES), date, 9); - createTableAllTypes(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES_INSERT), date, 0); - createTableAllTypesPartitionKey(cassandraSession, new SchemaTableName(keyspace, TABLE_ALL_TYPES_PARTITION_KEY), date); + createTableAllTypes(cassandraSession, metadata, new SchemaTableName(keyspace, TABLE_ALL_TYPES), date, 9); + createTableAllTypes(cassandraSession, metadata, new SchemaTableName(keyspace, TABLE_ALL_TYPES_INSERT), date, 0); + createTableAllTypesPartitionKey(cassandraSession, metadata, new SchemaTableName(keyspace, TABLE_ALL_TYPES_PARTITION_KEY), date); createTableClusteringKeys(cassandraSession, new 
SchemaTableName(keyspace, TABLE_CLUSTERING_KEYS), 9); createTableClusteringKeys(cassandraSession, new SchemaTableName(keyspace, TABLE_CLUSTERING_KEYS_LARGE), 1000); createTableMultiPartitionClusteringKeys(cassandraSession, new SchemaTableName(keyspace, TABLE_MULTI_PARTITION_CLUSTERING_KEYS)); @@ -142,7 +145,7 @@ public static void insertIntoTableClusteringKeysInequality(CassandraSession sess assertEquals(session.execute("SELECT COUNT(*) FROM " + table).all().get(0).getLong(0), rowsCount); } - public static void createTableAllTypes(CassandraSession session, SchemaTableName table, Date date, int rowsCount) + public static void createTableAllTypes(CassandraSession session, Metadata metadata, SchemaTableName table, Date date, int rowsCount) { session.execute("DROP TABLE IF EXISTS " + table); session.execute("CREATE TABLE " + table + " (" + @@ -164,11 +167,12 @@ public static void createTableAllTypes(CassandraSession session, SchemaTableName " typelist list, " + " typemap map, " + " typeset set, " + + " typetuple tuple, " + ")"); - insertTestData(session, table, date, rowsCount); + insertTestData(session, metadata, table, date, rowsCount); } - public static void createTableAllTypesPartitionKey(CassandraSession session, SchemaTableName table, Date date) + public static void createTableAllTypesPartitionKey(CassandraSession session, Metadata metadata, SchemaTableName table, Date date) { session.execute("DROP TABLE IF EXISTS " + table); @@ -191,6 +195,7 @@ public static void createTableAllTypesPartitionKey(CassandraSession session, Sch " typelist frozen >, " + " typemap frozen >, " + " typeset frozen >, " + + " typetuple frozen >, " + " PRIMARY KEY ((" + " key, " + " typeuuid, " + @@ -213,15 +218,18 @@ public static void createTableAllTypesPartitionKey(CassandraSession session, Sch // TODO: NOT YET SUPPORTED AS A PARTITION KEY " typelist, " + " typemap, " + - " typeset" + + " typeset," + + " typetuple" + " ))" + ")"); - insertTestData(session, table, date, 9); + 
insertTestData(session, metadata, table, date, 9); } - private static void insertTestData(CassandraSession session, SchemaTableName table, Date date, int rowsCount) + private static void insertTestData(CassandraSession session, Metadata metadata, SchemaTableName table, Date date, int rowsCount) { + TupleType tupleType = metadata.newTupleType(DataType.bigint(), DataType.varchar()); + for (int rowNumber = 1; rowNumber <= rowsCount; rowNumber++) { Insert insert = QueryBuilder.insertInto(table.getSchemaName(), table.getTableName()) .value("key", "key " + rowNumber) @@ -241,7 +249,8 @@ private static void insertTestData(CassandraSession session, SchemaTableName tab .value("typetimeuuid", UUID.fromString(String.format("d2177dd0-eaa2-11de-a572-001b779c76e%d", rowNumber))) .value("typelist", ImmutableList.of("list-value-1" + rowNumber, "list-value-2" + rowNumber)) .value("typemap", ImmutableMap.of(rowNumber, rowNumber + 1L, rowNumber + 2, rowNumber + 3L)) - .value("typeset", ImmutableSet.of(false, true)); + .value("typeset", ImmutableSet.of(false, true)) + .value("typetuple", tupleType.newValue((long) rowNumber, "row=" + rowNumber)); session.execute(insert); } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraClientConfig.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraClientConfig.java index 7b6ec6e476abb..3bb4da6fa6f0e 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraClientConfig.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraClientConfig.java @@ -16,8 +16,8 @@ import com.datastax.driver.core.ConsistencyLevel; import com.datastax.driver.core.SocketOptions; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.io.File; @@ -65,7 +65,8 @@ public 
void testDefaults() .setKeystorePassword(null) .setTruststorePath(null) .setTruststorePassword(null) - .setTlsEnabled(false)); + .setTlsEnabled(false) + .setCaseSensitiveNameMatchingEnabled(false)); } @Test @@ -103,6 +104,7 @@ public void testExplicitPropertyMappings() .put("cassandra.tls.keystore-password", "keystore-password") .put("cassandra.tls.truststore-path", "/tmp/truststore") .put("cassandra.tls.truststore-password", "truststore-password") + .put("case-sensitive-name-matching", "true") .build(); CassandraClientConfig expected = new CassandraClientConfig() @@ -136,7 +138,8 @@ public void testExplicitPropertyMappings() .setKeystorePath(new File("/tmp/keystore")) .setKeystorePassword("keystore-password") .setTruststorePath(new File("/tmp/truststore")) - .setTruststorePassword("truststore-password"); + .setTruststorePassword("truststore-password") + .setCaseSensitiveNameMatchingEnabled(true); ConfigAssertions.assertFullMapping(properties, expected); } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraConnector.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraConnector.java index 3fe4f25c6cca0..49a8a2bdc247d 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraConnector.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraConnector.java @@ -67,8 +67,8 @@ import static com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING; import static com.facebook.presto.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Locale.ENGLISH; +import static java.util.Locale.ROOT; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; @@ -110,7 +110,7 @@ public 
void setup() this.server = new CassandraServer(); String keyspace = "test_connector"; - createTestTables(server.getSession(), keyspace, DATE); + createTestTables(server.getSession(), server.getMetadata(), keyspace, DATE); String connectorId = "cassandra-test"; CassandraConnectorFactory connectorFactory = new CassandraConnectorFactory(connectorId); @@ -130,7 +130,7 @@ public void setup() assertInstanceOf(recordSetProvider, CassandraRecordSetProvider.class); database = keyspace; - table = new SchemaTableName(database, TABLE_ALL_TYPES.toLowerCase(ENGLISH)); + table = new SchemaTableName(database, TABLE_ALL_TYPES.toLowerCase(ROOT)); tableUnpartitioned = new SchemaTableName(database, "presto_test_unpartitioned"); invalidTable = new SchemaTableName(database, "totally_invalid_table_name"); } @@ -150,7 +150,7 @@ public void tearDown() public void testGetDatabaseNames() { List databases = metadata.listSchemaNames(SESSION); - assertTrue(databases.contains(database.toLowerCase(ENGLISH))); + assertTrue(databases.contains(database.toLowerCase(ROOT))); } @Test @@ -185,8 +185,8 @@ public void testGetRecords() ConnectorTransactionHandle transaction = CassandraTransactionHandle.INSTANCE; - List layouts = metadata.getTableLayouts(SESSION, tableHandle, Constraint.alwaysTrue(), Optional.empty()); - ConnectorTableLayoutHandle layout = getOnlyElement(layouts).getTableLayout().getHandle(); + ConnectorTableLayoutResult layoutResult = metadata.getTableLayoutForConstraint(SESSION, tableHandle, Constraint.alwaysTrue(), Optional.empty()); + ConnectorTableLayoutHandle layout = layoutResult.getTableLayout().getHandle(); List splits = getAllSplits(splitManager.getSplits(transaction, SESSION, layout, new SplitSchedulingContext(UNGROUPED_SCHEDULING, false, WarningCollector.NOOP))); long rowNumber = 0; diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java 
index cded9f30da5d1..c16a78720bb70 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraDistributed.java @@ -16,8 +16,11 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestDistributedQueries; +import com.google.common.collect.ImmutableMap; import org.testng.annotations.AfterClass; +import org.testng.annotations.Optional; +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.assertions.Assert.assertEquals; @@ -34,7 +37,7 @@ protected QueryRunner createQueryRunner() throws Exception { this.server = new CassandraServer(); - return CassandraQueryRunner.createCassandraQueryRunner(server); + return CassandraQueryRunner.createCassandraQueryRunner(server, ImmutableMap.of()); } @AfterClass(alwaysRun = true) @@ -114,20 +117,20 @@ public void testUpdate() } @Override - public void testShowColumns() + public void testShowColumns(@Optional("PARQUET") String storageFormat) { MaterializedResult actual = computeActual("SHOW COLUMNS FROM orders"); - MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("orderkey", "bigint", "", "") - .row("custkey", "bigint", "", "") - .row("orderstatus", "varchar", "", "") - .row("totalprice", "double", "", "") - .row("orderdate", "varchar", "", "") - .row("orderpriority", "varchar", "", "") - .row("clerk", "varchar", "", "") - .row("shippriority", "integer", "", "") - .row("comment", "varchar", "", "") + MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("orderkey", 
"bigint", "", "", Long.valueOf(19), null, null) + .row("custkey", "bigint", "", "", Long.valueOf(19), null, null) + .row("orderstatus", "varchar", "", "", null, null, Long.valueOf(2147483647)) + .row("totalprice", "double", "", "", Long.valueOf(53), null, null) + .row("orderdate", "varchar", "", "", null, null, Long.valueOf(2147483647)) + .row("orderpriority", "varchar", "", "", null, null, Long.valueOf(2147483647)) + .row("clerk", "varchar", "", "", null, null, Long.valueOf(2147483647)) + .row("shippriority", "integer", "", "", Long.valueOf(10), null, null) + .row("comment", "varchar", "", "", null, null, Long.valueOf(2147483647)) .build(); assertEquals(actual, expectedParametrizedVarchar); @@ -150,4 +153,51 @@ public void testWrittenStats() { // TODO Cassandra connector supports CTAS and inserts, but the test would fail } + + @Override + public void testPayloadJoinApplicability() + { + // no op -- test not supported due to lack of support for array types. + } + + @Override + public void testPayloadJoinCorrectness() + { + // no op -- test not supported due to lack of support for array types. + } + @Override + public void testNonAutoCommitTransactionWithCommit() + { + // Connector only supports writes using ctas + } + + @Override + public void testNonAutoCommitTransactionWithRollback() + { + // Connector only supports writes using ctas + } + + @Override + public void testRemoveRedundantCastToVarcharInJoinClause() + { + // no op -- test not supported due to lack of support for array types. + } + + @Override + public void testStringFilters() + { + // no op -- test not supported due to lack of support for char type. + } + + @Override + public void testSubfieldAccessControl() + { + // no op -- test not supported due to lack of support for array types. 
+ } + + @Override + protected String getDateExpression(String storageFormat, String columnExpression) + { + return "cast(" + columnExpression + " as DATE)"; + } } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java index 508e8d1cc66de..c7f9a7ecdb510 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntegrationSmokeTest.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.cassandra; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.type.Type; import com.facebook.presto.testing.MaterializedResult; @@ -20,7 +21,7 @@ import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; +import com.google.common.collect.ImmutableMap; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; @@ -96,8 +97,8 @@ protected QueryRunner createQueryRunner() CassandraServer server = new CassandraServer(); this.server = server; this.session = server.getSession(); - createTestTables(session, KEYSPACE, DATE_TIME_LOCAL); - return createCassandraQueryRunner(server); + createTestTables(session, server.getMetadata(), KEYSPACE, DATE_TIME_LOCAL); + return createCassandraQueryRunner(server, ImmutableMap.of()); } @Test @@ -139,10 +140,14 @@ public void testSelect() @Test public void testCreateTableAs() { - execute("DROP TABLE IF EXISTS table_all_types_copy"); - execute("CREATE TABLE table_all_types_copy AS SELECT * FROM " + TABLE_ALL_TYPES); - assertSelect("table_all_types_copy", true); - execute("DROP TABLE table_all_types_copy"); + try { + execute("DROP TABLE IF EXISTS 
table_all_types_copy"); + execute("CREATE TABLE table_all_types_copy AS SELECT * FROM " + TABLE_ALL_TYPES); + assertSelect("table_all_types_copy", true); + } + finally { + execute("DROP TABLE IF EXISTS table_all_types_copy"); + } } @Test(enabled = false) @@ -264,145 +269,163 @@ public void testClusteringKeyPushdownInequality() @Test public void testUpperCaseNameUnescapedInCassandra() { - /* - * If an identifier is not escaped with double quotes it is stored as lowercase in the Cassandra metadata - * - * http://docs.datastax.com/en/cql/3.1/cql/cql_reference/ucase-lcase_r.html - */ - session.execute("CREATE KEYSPACE KEYSPACE_1 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); - assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) - .row("keyspace_1") - .build(), new Duration(1, MINUTES)); - - session.execute("CREATE TABLE KEYSPACE_1.TABLE_1 (COLUMN_1 bigint PRIMARY KEY)"); - assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_1"), resultBuilder(getSession(), createUnboundedVarcharType()) - .row("table_1") - .build(), new Duration(1, MINUTES)); - assertContains(execute("SHOW COLUMNS FROM cassandra.keyspace_1.table_1"), resultBuilder(getSession(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType()) - .row("column_1", "bigint", "", "") - .build()); - - execute("INSERT INTO keyspace_1.table_1 (column_1) VALUES (1)"); - - assertEquals(execute("SELECT column_1 FROM cassandra.keyspace_1.table_1").getRowCount(), 1); - assertUpdate("DROP TABLE cassandra.keyspace_1.table_1"); - - // when an identifier is unquoted the lowercase and uppercase spelling may be used interchangeable - session.execute("DROP KEYSPACE keyspace_1"); + try { + /* + * If an identifier is not escaped with double quotes it is stored as lowercase in the Cassandra metadata + * + * 
http://docs.datastax.com/en/cql/3.1/cql/cql_reference/ucase-lcase_r.html + */ + session.execute("CREATE KEYSPACE KEYSPACE_1 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_1") + .build(), new Duration(1, MINUTES)); + + session.execute("CREATE TABLE KEYSPACE_1.TABLE_1 (COLUMN_1 bigint PRIMARY KEY)"); + assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_1"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("table_1") + .build(), new Duration(1, MINUTES)); + assertContains(execute("SHOW COLUMNS FROM cassandra.keyspace_1.table_1"), resultBuilder(getSession(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType(), BIGINT, BIGINT, BIGINT) + .row("column_1", "bigint", "", "", 19L, null, null) + .build()); + + execute("INSERT INTO keyspace_1.table_1 (column_1) VALUES (1)"); + + assertEquals(execute("SELECT column_1 FROM cassandra.keyspace_1.table_1").getRowCount(), 1); + } + finally { + assertUpdate("DROP TABLE IF EXISTS cassandra.keyspace_1.table_1"); + + // when an identifier is unquoted the lowercase and uppercase spelling may be used interchangeable + session.execute("DROP KEYSPACE keyspace_1"); + } } @Test public void testUppercaseNameEscaped() { - /* - * If an identifier is escaped with double quotes it is stored verbatim - * - * http://docs.datastax.com/en/cql/3.1/cql/cql_reference/ucase-lcase_r.html - */ - session.execute("CREATE KEYSPACE \"KEYSPACE_2\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); - assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) - .row("keyspace_2") - .build(), new Duration(1, MINUTES)); - - session.execute("CREATE TABLE \"KEYSPACE_2\".\"TABLE_2\" (\"COLUMN_2\" 
bigint PRIMARY KEY)"); - assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_2"), resultBuilder(getSession(), createUnboundedVarcharType()) - .row("table_2") - .build(), new Duration(1, MINUTES)); - assertContains(execute("SHOW COLUMNS FROM cassandra.keyspace_2.table_2"), resultBuilder(getSession(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType()) - .row("column_2", "bigint", "", "") - .build()); - - execute("INSERT INTO \"KEYSPACE_2\".\"TABLE_2\" (\"COLUMN_2\") VALUES (1)"); - - assertEquals(execute("SELECT column_2 FROM cassandra.keyspace_2.table_2").getRowCount(), 1); - assertUpdate("DROP TABLE cassandra.keyspace_2.table_2"); - - // when an identifier is unquoted the lowercase and uppercase spelling may be used interchangeable - session.execute("DROP KEYSPACE \"KEYSPACE_2\""); + try { + /* + * If an identifier is escaped with double quotes it is stored verbatim + * + * http://docs.datastax.com/en/cql/3.1/cql/cql_reference/ucase-lcase_r.html + */ + session.execute("CREATE KEYSPACE \"KEYSPACE_2\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_2") + .build(), new Duration(1, MINUTES)); + + session.execute("CREATE TABLE \"KEYSPACE_2\".\"TABLE_2\" (\"COLUMN_2\" bigint PRIMARY KEY)"); + assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_2"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("table_2") + .build(), new Duration(1, MINUTES)); + assertContains(execute("SHOW COLUMNS FROM cassandra.keyspace_2.table_2"), resultBuilder(getSession(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType(), createUnboundedVarcharType(), BIGINT, BIGINT, BIGINT) + .row("column_2", "bigint", "", "", 19L, null, null) + .build()); + + 
execute("INSERT INTO \"KEYSPACE_2\".\"TABLE_2\" (\"COLUMN_2\") VALUES (1)"); + + assertEquals(execute("SELECT column_2 FROM cassandra.keyspace_2.table_2").getRowCount(), 1); + } + finally { + assertUpdate("DROP TABLE IF EXISTS cassandra.keyspace_2.table_2"); + + // when an identifier is unquoted the lowercase and uppercase spelling may be used interchangeable + session.execute("DROP KEYSPACE \"KEYSPACE_2\""); + } } @Test public void testKeyspaceNameAmbiguity() { - // Identifiers enclosed in double quotes are stored in Cassandra verbatim. It is possible to create 2 keyspaces with names - // that have differences only in letters case. - session.execute("CREATE KEYSPACE \"KeYsPaCe_3\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); - session.execute("CREATE KEYSPACE \"kEySpAcE_3\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); - - // Although in Presto all the schema and table names are always displayed as lowercase - assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) - .row("keyspace_3") - .row("keyspace_3") - .build(), new Duration(1, MINUTES)); - - // There is no way to figure out what the exactly keyspace we want to retrieve tables from - assertQueryFailsEventually( - "SHOW TABLES FROM cassandra.keyspace_3", - "More than one keyspace has been found for the case insensitive schema name: keyspace_3 -> \\(KeYsPaCe_3, kEySpAcE_3\\)", - new Duration(1, MINUTES)); - - session.execute("DROP KEYSPACE \"KeYsPaCe_3\""); - session.execute("DROP KEYSPACE \"kEySpAcE_3\""); + try { + // Identifiers enclosed in double quotes are stored in Cassandra verbatim. It is possible to create 2 keyspaces with names + // that have differences only in letters case. 
+ session.execute("CREATE KEYSPACE \"KeYsPaCe_3\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + session.execute("CREATE KEYSPACE \"kEySpAcE_3\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + + // Although in Presto all the schema and table names are always displayed as lowercase + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_3") + .row("keyspace_3") + .build(), new Duration(1, MINUTES)); + + // There is no way to figure out what the exactly keyspace we want to retrieve tables from + assertQueryFailsEventually( + "SHOW TABLES FROM cassandra.keyspace_3", + "More than one keyspace has been found for the schema name: keyspace_3 -> \\(KeYsPaCe_3, kEySpAcE_3\\)", + new Duration(1, MINUTES)); + } + finally { + session.execute("DROP KEYSPACE \"KeYsPaCe_3\""); + session.execute("DROP KEYSPACE \"kEySpAcE_3\""); + } } @Test public void testTableNameAmbiguity() throws Exception { - session.execute("CREATE KEYSPACE keyspace_4 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); - assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) - .row("keyspace_4") - .build(), new Duration(1, MINUTES)); - - // Identifiers enclosed in double quotes are stored in Cassandra verbatim. It is possible to create 2 tables with names - // that have differences only in letters case. - session.execute("CREATE TABLE keyspace_4.\"TaBlE_4\" (column_4 bigint PRIMARY KEY)"); - session.execute("CREATE TABLE keyspace_4.\"tAbLe_4\" (column_4 bigint PRIMARY KEY)"); - - // This is added for Cassandra to refresh its metadata so that we don't encounter a race condition in the forthcoming steps and achieve eventual consistency. 
- Thread.sleep(1000); - - // Although in Presto all the schema and table names are always displayed as lowercase - assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_4"), resultBuilder(getSession(), createUnboundedVarcharType()) - .row("table_4") - .row("table_4") - .build(), new Duration(1, MINUTES)); - - // There is no way to figure out what the exactly table is being queried - assertQueryFailsEventually( - "SHOW COLUMNS FROM cassandra.keyspace_4.table_4", - "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)", - new Duration(1, MINUTES)); - assertQueryFailsEventually( - "SELECT * FROM cassandra.keyspace_4.table_4", - "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)", - new Duration(1, MINUTES)); - session.execute("DROP KEYSPACE keyspace_4"); + try { + session.execute("CREATE KEYSPACE keyspace_4 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_4") + .build(), new Duration(1, MINUTES)); + + // Identifiers enclosed in double quotes are stored in Cassandra verbatim. It is possible to create 2 tables with names + // that have differences only in letters case. + session.execute("CREATE TABLE keyspace_4.\"TaBlE_4\" (column_4 bigint PRIMARY KEY)"); + session.execute("CREATE TABLE keyspace_4.\"tAbLe_4\" (column_4 bigint PRIMARY KEY)"); + + // This is added for Cassandra to refresh its metadata so that we don't encounter a race condition in the forthcoming steps and achieve eventual consistency. 
+ Thread.sleep(1000); + + // Although in Presto all the schema and table names are always displayed as lowercase + assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_4"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("table_4") + .row("table_4") + .build(), new Duration(1, MINUTES)); + + // There is no way to figure out what the exactly table is being queried + assertQueryFailsEventually( + "SHOW COLUMNS FROM cassandra.keyspace_4.table_4", + "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)", + new Duration(1, MINUTES)); + assertQueryFailsEventually( + "SELECT * FROM cassandra.keyspace_4.table_4", + "More than one table has been found for the case insensitive table name: table_4 -> \\(TaBlE_4, tAbLe_4\\)", + new Duration(1, MINUTES)); + } + finally { + session.execute("DROP KEYSPACE keyspace_4"); + } } @Test public void testColumnNameAmbiguity() { - session.execute("CREATE KEYSPACE keyspace_5 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); - assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) - .row("keyspace_5") - .build(), new Duration(1, MINUTES)); - - session.execute("CREATE TABLE keyspace_5.table_5 (\"CoLuMn_5\" bigint PRIMARY KEY, \"cOlUmN_5\" bigint)"); - assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_5"), resultBuilder(getSession(), createUnboundedVarcharType()) - .row("table_5") - .build(), new Duration(1, MINUTES)); - - assertQueryFailsEventually( - "SHOW COLUMNS FROM cassandra.keyspace_5.table_5", - ".*More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, cOlUmN_5\\)", - new Duration(1, MINUTES)); - assertQueryFailsEventually( - "SELECT * FROM cassandra.keyspace_5.table_5", - ".*More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, 
cOlUmN_5\\)", - new Duration(1, MINUTES)); - - session.execute("DROP KEYSPACE keyspace_5"); + try { + session.execute("CREATE KEYSPACE keyspace_5 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_5") + .build(), new Duration(1, MINUTES)); + + session.execute("CREATE TABLE keyspace_5.table_5 (\"CoLuMn_5\" bigint PRIMARY KEY, \"cOlUmN_5\" bigint)"); + assertContainsEventually(() -> execute("SHOW TABLES FROM cassandra.keyspace_5"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("table_5") + .build(), new Duration(1, MINUTES)); + + assertQueryFailsEventually( + "SHOW COLUMNS FROM cassandra.keyspace_5.table_5", + ".*More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, cOlUmN_5\\)", + new Duration(1, MINUTES)); + assertQueryFailsEventually( + "SELECT * FROM cassandra.keyspace_5.table_5", + ".*More than one column has been found for the case insensitive column name: column_5 -> \\(CoLuMn_5, cOlUmN_5\\)", + new Duration(1, MINUTES)); + } + finally { + session.execute("DROP KEYSPACE keyspace_5"); + } } @Test diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntergrationMixedCase.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntergrationMixedCase.java new file mode 100644 index 0000000000000..93456c40d252e --- /dev/null +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraIntergrationMixedCase.java @@ -0,0 +1,219 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.cassandra; + +import com.facebook.airlift.units.Duration; +import com.facebook.presto.Session; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.Test; + +import static com.facebook.presto.cassandra.CassandraTestingUtils.createKeyspace; +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tests.QueryAssertions.assertContains; +import static com.facebook.presto.tests.QueryAssertions.assertContainsEventually; +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +@Test +public class TestCassandraIntergrationMixedCase + extends AbstractTestQueryFramework +{ + private CassandraServer server; + private CassandraSession session; + private static final String KEYSPACE = "test_connector"; + + @Override + protected QueryRunner createQueryRunner() throws Exception + { + this.server = new CassandraServer(); + this.session = server.getSession(); + createKeyspace(session, KEYSPACE); + return CassandraQueryRunner.createCassandraQueryRunner(server, ImmutableMap.of("case-sensitive-name-matching", "true")); + } + + @AfterClass(alwaysRun = 
true) + public void tearDown() + { + server.close(); + } + + @Test + public void testCreateTable() + { + Session session = testSessionBuilder() + .setCatalog("cassandra") + .setSchema(KEYSPACE) + .build(); + try { + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATE(name VARCHAR(50), rollNum int)"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATE")); + + getQueryRunner().execute(session, "CREATE TABLE test_create(name VARCHAR(50), rollNum int)"); + assertTrue(getQueryRunner().tableExists(session, "test_create")); + + assertQueryFails(session, "CREATE TABLE TEST_CREATE (name VARCHAR(50), rollNum int)", "line 1:1: Table 'cassandra.test_connector.TEST_CREATE' already exists"); + assertFalse(getQueryRunner().tableExists(session, "Test")); + } + finally { + assertUpdate(session, "DROP TABLE IF EXISTS TEST_CREATE"); + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATE")); + + assertUpdate(session, "DROP TABLE IF EXISTS test_create"); + assertFalse(getQueryRunner().tableExists(session, "test_create")); + } + } + + @Test + public void testCreateTableAs() + { + Session session = testSessionBuilder() + .setCatalog("cassandra") + .setSchema(KEYSPACE) + .build(); + try { + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATEAS AS SELECT * FROM tpch.region"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATEAS")); + + getQueryRunner().execute(session, "CREATE TABLE IF NOT EXISTS test_createas AS SELECT * FROM tpch.region"); + assertTrue(getQueryRunner().tableExists(session, "test_createas")); + + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATEAS_Join AS SELECT c.custkey, o.orderkey FROM " + + "tpch.customer c INNER JOIN tpch.orders o ON c.custkey = o.custkey WHERE c.mktsegment = 'BUILDING'"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATEAS_Join")); + + assertQueryFails("CREATE TABLE test_connector.TEST_CREATEAS_FAIL_Join AS SELECT c.custkey, o.orderkey FROM " + + "tpch.customer c INNER JOIN 
tpch.ORDERS o ON c.custkey = o.custkey WHERE c.mktsegment = 'BUILDING'", "Table cassandra.tpch.ORDERS does not exist"); //failure scenario since tpch.ORDERS doesn't exist + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATEAS_FAIL_Join")); + + getQueryRunner().execute(session, "CREATE TABLE Test_CreateAs_Mixed_Join AS SELECT Cus.custkey, Ord.orderkey FROM " + + "tpch.customer Cus INNER JOIN tpch.orders Ord ON Cus.custkey = Ord.custkey WHERE Cus.mktsegment = 'BUILDING'"); + assertTrue(getQueryRunner().tableExists(session, "Test_CreateAs_Mixed_Join")); + } + finally { + assertUpdate(session, "DROP TABLE IF EXISTS TEST_CREATEAS"); + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATEAS")); + + assertUpdate(session, "DROP TABLE IF EXISTS test_createas"); + assertFalse(getQueryRunner().tableExists(session, "test_createas")); + + assertUpdate(session, "DROP TABLE IF EXISTS TEST_CREATEAS_Join"); + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATEAS_Join")); + + assertUpdate(session, "DROP TABLE IF EXISTS Test_CreateAs_Mixed_Join"); + assertFalse(getQueryRunner().tableExists(session, "Test_CreateAs_Mixed_Join")); + } + } + + @Test + public void testDuplicatedColumNameCreateTable() + { + Session session = testSessionBuilder() + .setCatalog("cassandra") + .setSchema(KEYSPACE) + .build(); + try { + getQueryRunner().execute(session, "CREATE TABLE test (a integer, A integer)"); + assertTrue(getQueryRunner().tableExists(session, "test")); + + getQueryRunner().execute(session, "CREATE TABLE TEST (a integer, A integer)"); + assertTrue(getQueryRunner().tableExists(session, "TEST")); + + assertQueryFails("CREATE TABLE Test (a integer, a integer)", "line 1:31: Column name 'a' specified more than once"); + assertFalse(getQueryRunner().tableExists(session, "Test")); + } + finally { + assertUpdate(session, "DROP TABLE IF EXISTS test"); + assertFalse(getQueryRunner().tableExists(session, "test")); + + assertUpdate(session, "DROP TABLE IF EXISTS TEST"); 
+ assertFalse(getQueryRunner().tableExists(session, "TEST")); + } + } + + @Test + public void testSelect() + { + Session session = testSessionBuilder() + .setCatalog("cassandra") + .setSchema(KEYSPACE) + .build(); + try { + getQueryRunner().execute(session, "CREATE TABLE Test_Select AS SELECT * FROM tpch.region where regionkey=3"); + assertTrue(getQueryRunner().tableExists(session, "Test_Select")); + assertQuery("SELECT * from test_connector.Test_Select", "VALUES (3, 'EUROPE', 'ly final courts cajole furiously final excuse')"); + + getQueryRunner().execute(session, "CREATE TABLE test_select AS SELECT * FROM tpch.region LIMIT 2"); + assertQuery("SELECT COUNT(*) FROM test_connector.test_select", "VALUES 2"); + } + finally { + getQueryRunner().execute(session, "DROP TABLE IF EXISTS test_select"); + getQueryRunner().execute(session, "DROP TABLE IF EXISTS Test_Select"); + } + } + + @Test + public void testShowSchemas() + { + try { + session.execute("CREATE KEYSPACE \"test_keyspace\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + session.execute("CREATE KEYSPACE \"Test_Keyspace\" WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> getQueryRunner().execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("test_keyspace") + .row("Test_Keyspace") + .build(), new Duration(1, MINUTES)); + } + finally { + session.execute("DROP KEYSPACE \"test_keyspace\""); + session.execute("DROP KEYSPACE \"Test_Keyspace\""); + } + } + + @Test + public void testUnicodeColumns() + { + try { + session.execute("CREATE KEYSPACE keyspace_1 WITH REPLICATION = {'class':'SimpleStrategy', 'replication_factor': 1}"); + assertContainsEventually(() -> getQueryRunner().execute("SHOW SCHEMAS FROM cassandra"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("keyspace_1") + .build(), new Duration(1, MINUTES)); + + session.execute("CREATE TABLE keyspace_1.table_1 
(COLUMN_1 bigint PRIMARY KEY,\"Test用户表\" bigint)"); + assertContainsEventually(() -> getQueryRunner().execute("SHOW TABLES FROM cassandra.keyspace_1"), resultBuilder(getSession(), createUnboundedVarcharType()) + .row("table_1") + .build(), new Duration(1, MINUTES)); + assertContains( + getQueryRunner().execute("SHOW COLUMNS FROM cassandra.keyspace_1.table_1"), + resultBuilder(getSession(), + createUnboundedVarcharType(), + createUnboundedVarcharType(), + createUnboundedVarcharType(), + createUnboundedVarcharType(), + createUnboundedVarcharType(), + createUnboundedVarcharType(), + createUnboundedVarcharType()) + .row("column_1", "bigint", "", "", Long.valueOf(19), null, null) + .row("Test用户表", "bigint", "", "", Long.valueOf(19), null, null) + .build()); + } + finally { + session.execute("DROP KEYSPACE keyspace_1"); + } + } +} diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraTokenSplitManager.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraTokenSplitManager.java index e18ce7eae042e..8500bab247a52 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraTokenSplitManager.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/TestCassandraTokenSplitManager.java @@ -76,12 +76,16 @@ public void testEmptyTable() throws Exception { String tableName = "empty_table"; - session.execute(format("CREATE TABLE %s.%s (key text PRIMARY KEY)", KEYSPACE, tableName)); - server.refreshSizeEstimates(KEYSPACE, tableName); - List splits = splitManager.getSplits(KEYSPACE, tableName, Optional.empty()); - // even for the empty table at least one split must be produced, in case the statistics are inaccurate - assertEquals(splits.size(), 1); - session.execute(format("DROP TABLE %s.%s", KEYSPACE, tableName)); + try { + session.execute(format("CREATE TABLE %s.%s (key text PRIMARY KEY)", KEYSPACE, tableName)); + server.refreshSizeEstimates(KEYSPACE, tableName); + List splits = 
splitManager.getSplits(KEYSPACE, tableName, Optional.empty()); + // even for the empty table at least one split must be produced, in case the statistics are inaccurate + assertEquals(splits.size(), 1); + } + finally { + session.execute(format("DROP TABLE IF EXISTS %s.%s", KEYSPACE, tableName)); + } } @Test @@ -89,13 +93,17 @@ public void testNonEmptyTable() throws Exception { String tableName = "non_empty_table"; - session.execute(format("CREATE TABLE %s.%s (key text PRIMARY KEY)", KEYSPACE, tableName)); - for (int i = 0; i < PARTITION_COUNT; i++) { - session.execute(format("INSERT INTO %s.%s (key) VALUES ('%s')", KEYSPACE, tableName, "value" + i)); + try { + session.execute(format("CREATE TABLE %s.%s (key text PRIMARY KEY)", KEYSPACE, tableName)); + for (int i = 0; i < PARTITION_COUNT; i++) { + session.execute(format("INSERT INTO %s.%s (key) VALUES ('%s')", KEYSPACE, tableName, "value" + i)); + } + server.refreshSizeEstimates(KEYSPACE, tableName); + List splits = splitManager.getSplits(KEYSPACE, tableName, Optional.empty()); + assertEquals(splits.size(), PARTITION_COUNT / SPLIT_SIZE); + } + finally { + session.execute(format("DROP TABLE IF EXISTS %s.%s", KEYSPACE, tableName)); } - server.refreshSizeEstimates(KEYSPACE, tableName); - List splits = splitManager.getSplits(KEYSPACE, tableName, Optional.empty()); - assertEquals(splits.size(), PARTITION_COUNT / SPLIT_SIZE); - session.execute(format("DROP TABLE %s.%s", KEYSPACE, tableName)); } } diff --git a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraCqlUtils.java b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraCqlUtils.java index ea3b0a2ddaf55..503d14f4e1c14 100644 --- a/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraCqlUtils.java +++ b/presto-cassandra/src/test/java/com/facebook/presto/cassandra/util/TestCassandraCqlUtils.java @@ -32,14 +32,14 @@ public class TestCassandraCqlUtils @Test public void testValidSchemaName() 
{ - assertEquals("foo", validSchemaName("foo")); + assertEquals("\"foo\"", validSchemaName("foo")); assertEquals("\"select\"", validSchemaName("select")); } @Test public void testValidTableName() { - assertEquals("foo", validTableName("foo")); + assertEquals("\"foo\"", validTableName("foo")); assertEquals("\"Foo\"", validTableName("Foo")); assertEquals("\"select\"", validTableName("select")); } @@ -47,7 +47,7 @@ public void testValidTableName() @Test public void testValidColumnName() { - assertEquals("foo", validColumnName("foo")); + assertEquals("\"foo\"", validColumnName("foo")); assertEquals("\"\"", validColumnName(CassandraCqlUtils.EMPTY_COLUMN_NAME)); assertEquals("\"\"", validColumnName("")); assertEquals("\"select\"", validColumnName("select")); @@ -80,6 +80,6 @@ public void testAppendSelectColumns() CassandraCqlUtils.appendSelectColumns(sb, columns); String str = sb.toString(); - assertEquals("foo,bar,\"table\"", str); + assertEquals("\"foo\",\"bar\",\"table\"", str); } } diff --git a/presto-cli/pom.xml b/presto-cli/pom.xml index fbce3e3fc30ac..a88fa78a2c41a 100644 --- a/presto-cli/pom.xml +++ b/presto-cli/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-cli @@ -14,6 +14,7 @@ ${project.parent.basedir} com.facebook.presto.cli.Presto + true @@ -38,7 +39,7 @@ - io.airlift + com.facebook.airlift airline @@ -58,19 +59,18 @@ - io.airlift + com.facebook.airlift units - com.google.code.findbugs - jsr305 - true + jakarta.annotation + jakarta.annotation-api - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -108,6 +108,12 @@ jackson-core + + com.google.errorprone + error_prone_annotations + true + + org.testng @@ -124,30 +130,6 @@ - - org.apache.maven.plugins - maven-shade-plugin - - - package - - shade - - - true - executable - - - - ${main-class} - - - - - - - - org.basepom.maven duplicate-finder-maven-plugin @@ -160,23 +142,71 @@ - - org.skife.maven - really-executable-jar-maven-plugin + 
org.apache.maven.plugins + maven-dependency-plugin - -Xmx1G - executable + + com.facebook.airlift:json + - - - package - - really-executable-jar - - - + + + + executable-jar + + + !skipExecutableJar + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + package + + shade + + + true + executable + false + + + + ${main-class} + + + + + + + + + + org.skife.maven + really-executable-jar-maven-plugin + + -Xmx1G + executable + + + + package + + really-executable-jar + + + + + + + + diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java b/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java index 6a05483e1e511..f3f15e9324953 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/ClientOptions.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.cli; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.ClientSession; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.CharMatcher; @@ -21,7 +22,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.net.HostAndPort; import io.airlift.airline.Option; -import io.airlift.units.Duration; import java.net.URI; import java.net.URISyntaxException; diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/Console.java b/presto-cli/src/main/java/com/facebook/presto/cli/Console.java index 0deff11d0a1fb..95a0dc76d4078 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/Console.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/Console.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logging; import com.facebook.airlift.log.LoggingConfiguration; +import com.facebook.airlift.units.Duration; import com.facebook.presto.cli.ClientOptions.OutputFormat; import com.facebook.presto.client.ClientSession; import com.facebook.presto.spi.security.SelectedRole; @@ -25,14 +26,12 @@ import 
com.google.common.io.Files; import io.airlift.airline.Command; import io.airlift.airline.HelpOption; -import io.airlift.units.Duration; +import jakarta.inject.Inject; import jline.console.history.FileHistory; import jline.console.history.History; import jline.console.history.MemoryHistory; import org.fusesource.jansi.AnsiConsole; -import javax.inject.Inject; - import java.io.File; import java.io.IOException; import java.io.PrintStream; diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/FormatUtils.java b/presto-cli/src/main/java/com/facebook/presto/cli/FormatUtils.java index 877c9dfa30722..6a06a7b9c6f9f 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/FormatUtils.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/FormatUtils.java @@ -13,16 +13,16 @@ */ package com.facebook.presto.cli; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.math.RoundingMode; import java.text.DecimalFormat; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.repeat; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.String.format; diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/KeyReader.java b/presto-cli/src/main/java/com/facebook/presto/cli/KeyReader.java index 7a75fa24603f3..a242a98511c92 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/KeyReader.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/KeyReader.java @@ -18,11 +18,12 @@ import java.io.IOException; import java.io.InputStream; -import static org.fusesource.jansi.internal.CLibrary.STDIN_FILENO; import static org.fusesource.jansi.internal.CLibrary.isatty; public final class KeyReader { + 
private static final int STDIN_FILENO = 0; + private KeyReader() {} @SuppressWarnings("resource") diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/OutputHandler.java b/presto-cli/src/main/java/com/facebook/presto/cli/OutputHandler.java index 81f41ce32be9f..6985c0741d19c 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/OutputHandler.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/OutputHandler.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.cli; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.StatementClient; -import io.airlift.units.Duration; import java.io.Closeable; import java.io.IOException; @@ -22,7 +22,7 @@ import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; -import static io.airlift.units.Duration.nanosSince; +import static com.facebook.airlift.units.Duration.nanosSince; import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/Pager.java b/presto-cli/src/main/java/com/facebook/presto/cli/Pager.java index 28e282a1ac347..15ef6ddbc93bb 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/Pager.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/Pager.java @@ -14,8 +14,7 @@ package com.facebook.presto.cli; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.io.FilterOutputStream; import java.io.IOException; diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessor.java b/presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessor.java index cf6adcd1028ee..286892cc6e659 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessor.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/QueryPreprocessor.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.cli; 
+import com.facebook.airlift.units.Duration; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.io.CharStreams; -import io.airlift.units.Duration; import sun.misc.Signal; import sun.misc.SignalHandler; diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java b/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java index f18ac4088868f..0c49ff4a38e36 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/StatusPrinter.java @@ -14,6 +14,8 @@ package com.facebook.presto.cli; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.QueryStatusInfo; import com.facebook.presto.client.StageStats; import com.facebook.presto.client.StatementClient; @@ -22,8 +24,6 @@ import com.facebook.presto.common.RuntimeUnit; import com.google.common.base.Strings; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.io.PrintStream; import java.util.Comparator; @@ -32,6 +32,10 @@ import java.util.OptionalInt; import java.util.concurrent.atomic.AtomicInteger; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.Duration.nanosSince; +import static com.facebook.airlift.units.Duration.succinctDuration; +import static com.facebook.airlift.units.Duration.succinctNanos; import static com.facebook.presto.cli.FormatUtils.formatCount; import static com.facebook.presto.cli.FormatUtils.formatCountRate; import static com.facebook.presto.cli.FormatUtils.formatDataRate; @@ -41,10 +45,6 @@ import static com.facebook.presto.cli.FormatUtils.pluralize; import static com.facebook.presto.cli.KeyReader.readKey; import static com.google.common.base.Verify.verify; -import static 
io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.Duration.nanosSince; -import static io.airlift.units.Duration.succinctDuration; -import static io.airlift.units.Duration.succinctNanos; import static java.lang.Character.toUpperCase; import static java.lang.Math.max; import static java.lang.Math.min; @@ -179,9 +179,6 @@ public void printFinalInfo(Optional clientStopTimestamp) Duration serverSideWallTime = succinctDuration(stats.getElapsedTimeMillis(), MILLISECONDS); int nodes = stats.getNodes(); - if ((nodes == 0) || (stats.getTotalSplits() == 0)) { - return; - } // blank line out.println(); diff --git a/presto-cli/src/main/java/com/facebook/presto/cli/ThreadInterruptor.java b/presto-cli/src/main/java/com/facebook/presto/cli/ThreadInterruptor.java index 3dd03f252e705..31cb5105aa8e2 100644 --- a/presto-cli/src/main/java/com/facebook/presto/cli/ThreadInterruptor.java +++ b/presto-cli/src/main/java/com/facebook/presto/cli/ThreadInterruptor.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cli; -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.io.Closeable; diff --git a/presto-cli/src/test/java/com/facebook/presto/cli/AbstractCliTest.java b/presto-cli/src/test/java/com/facebook/presto/cli/AbstractCliTest.java index eb2267e1bd797..97e8804910d04 100644 --- a/presto-cli/src/test/java/com/facebook/presto/cli/AbstractCliTest.java +++ b/presto-cli/src/test/java/com/facebook/presto/cli/AbstractCliTest.java @@ -14,6 +14,7 @@ package com.facebook.presto.cli; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.ClientSession; import com.facebook.presto.client.Column; import com.facebook.presto.client.QueryResults; @@ -22,7 +23,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import 
okhttp3.Headers; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; diff --git a/presto-cli/src/test/java/com/facebook/presto/cli/TestFormatUtils.java b/presto-cli/src/test/java/com/facebook/presto/cli/TestFormatUtils.java index 16b642071781f..2c38bb3d4ca90 100644 --- a/presto-cli/src/test/java/com/facebook/presto/cli/TestFormatUtils.java +++ b/presto-cli/src/test/java/com/facebook/presto/cli/TestFormatUtils.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.cli; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; import org.testng.annotations.Test; import java.util.concurrent.TimeUnit; diff --git a/presto-clickhouse/pom.xml b/presto-clickhouse/pom.xml old mode 100755 new mode 100644 index 36968b6cd7251..98b1757308542 --- a/presto-clickhouse/pom.xml +++ b/presto-clickhouse/pom.xml @@ -4,15 +4,17 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT presto-clickhouse + presto-clickhouse Presto - Clickhouse Connector presto-plugin ${project.parent.basedir} + true @@ -45,8 +47,8 @@ - com.google.code.findbugs - jsr305 + jakarta.annotation + jakarta.annotation-api @@ -60,8 +62,8 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -78,7 +80,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -90,7 +92,7 @@ - io.airlift + com.facebook.airlift units provided @@ -116,8 +118,8 @@ log - javax.validation - validation-api + jakarta.validation + jakarta.validation-api com.facebook.presto @@ -219,6 +221,10 @@ + + com.facebook.presto + presto-base-jdbc + @@ -259,4 +265,4 @@ - \ No newline at end of file + diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseClient.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseClient.java index 8359b748f5ef1..1592dff1fde01 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseClient.java +++ 
b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseClient.java @@ -28,19 +28,18 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.FixedSplitSource; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableNotFoundException; import com.facebook.presto.spi.statistics.TableStatistics; -import com.google.common.base.CharMatcher; import com.google.common.base.Joiner; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; - -import javax.annotation.Nullable; -import javax.inject.Inject; +import jakarta.annotation.Nullable; +import jakarta.inject.Inject; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -63,23 +62,18 @@ import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.RealType.REAL; import static com.facebook.presto.common.type.SmallintType.SMALLINT; -import static com.facebook.presto.common.type.TimeType.TIME; -import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; -import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; -import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.common.type.TinyintType.TINYINT; -import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.plugin.clickhouse.ClickHouseEngineType.MERGETREE; import static com.facebook.presto.plugin.clickhouse.ClickHouseErrorCode.JDBC_ERROR; import static com.facebook.presto.plugin.clickhouse.ClickhouseDXLKeyWords.ORDER_BY_PROPERTY; import static 
com.facebook.presto.plugin.clickhouse.StandardReadMappings.jdbcTypeToPrestoType; +import static com.facebook.presto.plugin.jdbc.JdbcWarningCode.USE_OF_DEPRECATED_CONFIGURATION_PROPERTY; import static com.facebook.presto.spi.StandardErrorCode.INVALID_TABLE_PROPERTY; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.isNullOrEmpty; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; @@ -88,27 +82,13 @@ import static java.sql.ResultSetMetaData.columnNullable; import static java.util.Collections.nCopies; import static java.util.Locale.ENGLISH; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; public class ClickHouseClient { private static final Logger log = Logger.get(ClickHouseClient.class); - private static final Map SQL_TYPES = ImmutableMap.builder() - .put(BOOLEAN, "boolean") - .put(BIGINT, "bigint") - .put(INTEGER, "integer") - .put(SMALLINT, "smallint") - .put(TINYINT, "tinyint") - .put(DOUBLE, "double precision") - .put(REAL, "real") - .put(VARBINARY, "varbinary") - .put(DATE, "Date") - .put(TIME, "time") - .put(TIME_WITH_TIME_ZONE, "time with timezone") - .put(TIMESTAMP, "timestamp") - .put(TIMESTAMP_WITH_TIME_ZONE, "timestamp with timezone") - .build(); private static final String tempTableNamePrefix = "tmp_presto_"; protected static final String identifierQuote = "\""; protected final String connectorId; @@ -117,6 +97,7 @@ public class ClickHouseClient protected final int commitBatchSize; protected final 
Cache> remoteSchemaNames; protected final Cache> remoteTableNames; + protected final boolean caseSensitiveNameMatchingEnabled; private final boolean mapStringAsVarchar; @@ -133,6 +114,7 @@ public ClickHouseClient(ClickHouseConnectorId connectorId, ClickHouseConfig conf .expireAfterWrite(config.getCaseInsensitiveNameMatchingCacheTtl().toMillis(), MILLISECONDS); this.remoteSchemaNames = remoteNamesCacheBuilder.build(); this.remoteTableNames = remoteNamesCacheBuilder.build(); + this.caseSensitiveNameMatchingEnabled = config.isCaseSensitiveNameMatching(); } public int getCommitBatchSize() @@ -140,16 +122,16 @@ public int getCommitBatchSize() return commitBatchSize; } - public List getTableNames(ClickHouseIdentity identity, Optional schema) + public List getTableNames(ConnectorSession session, ClickHouseIdentity identity, Optional schema) { try (Connection connection = connectionFactory.openConnection(identity)) { - Optional remoteSchema = schema.map(schemaName -> toRemoteSchemaName(identity, connection, schemaName)); + Optional remoteSchema = schema.map(schemaName -> toRemoteSchemaName(session, identity, connection, schemaName)); try (ResultSet resultSet = getTables(connection, remoteSchema, Optional.empty())) { ImmutableList.Builder list = ImmutableList.builder(); while (resultSet.next()) { String tableSchema = getTableSchemaName(resultSet); String tableName = resultSet.getString("TABLE_NAME"); - list.add(new SchemaTableName(tableSchema.toLowerCase(ENGLISH), tableName.toLowerCase(ENGLISH))); + list.add(new SchemaTableName(normalizeIdentifier(tableSchema), normalizeIdentifier(tableName))); } return list.build(); } @@ -175,7 +157,7 @@ public final Set getSchemaNames(ClickHouseIdentity identity) { try (Connection connection = connectionFactory.openConnection(identity)) { return listSchemas(connection).stream() - .map(schemaName -> schemaName.toLowerCase(ENGLISH)) + .map(this::normalizeIdentifier) .collect(toImmutableSet()); } catch (SQLException e) { @@ -183,7 +165,7 @@ 
public final Set getSchemaNames(ClickHouseIdentity identity) } } - public ConnectorSplitSource getSplits(ClickHouseIdentity identity, ClickHouseTableLayoutHandle layoutHandle) + public ConnectorSplitSource getSplits(ClickHouseTableLayoutHandle layoutHandle) { ClickHouseTableHandle tableHandle = layoutHandle.getTable(); ClickHouseSplit clickHouseSplit = new ClickHouseSplit( @@ -211,7 +193,7 @@ public List getColumns(ConnectorSession session, ClickHo resultSet.getInt("DECIMAL_DIGITS"), Optional.empty(), Optional.empty()); - Optional columnMapping = toPrestoType(session, typeHandle); + Optional columnMapping = toPrestoType(typeHandle); // skip unsupported column types if (columnMapping.isPresent()) { String columnName = resultSet.getString("COLUMN_NAME"); @@ -233,7 +215,7 @@ public List getColumns(ConnectorSession session, ClickHo } } - public Optional toPrestoType(ConnectorSession session, ClickHouseTypeHandle typeHandle) + public Optional toPrestoType(ClickHouseTypeHandle typeHandle) { return jdbcTypeToPrestoType(typeHandle, mapStringAsVarchar); } @@ -291,7 +273,7 @@ public String buildInsertSql(ClickHouseOutputTableHandle handle) String columns = Joiner.on(',').join(nCopies(handle.getColumnNames().size(), "?")); return new StringBuilder() .append("INSERT INTO ") - .append(quoted(handle.getCatalogName(), handle.getSchemaName(), handle.getTemporaryTableName())) + .append(quoted(handle.getSchemaName(), handle.getTemporaryTableName())) .append(" VALUES (").append(columns).append(")") .toString(); } @@ -314,11 +296,11 @@ protected Collection listSchemas(Connection connection) } } - public ClickHouseTableHandle getTableHandle(ClickHouseIdentity identity, SchemaTableName schemaTableName) + public ClickHouseTableHandle getTableHandle(ConnectorSession session, ClickHouseIdentity identity, SchemaTableName schemaTableName) { try (Connection connection = connectionFactory.openConnection(identity)) { - String remoteSchema = toRemoteSchemaName(identity, connection, 
schemaTableName.getSchemaName()); - String remoteTable = toRemoteTableName(identity, connection, remoteSchema, schemaTableName.getTableName()); + String remoteSchema = toRemoteSchemaName(session, identity, connection, schemaTableName.getSchemaName()); + String remoteTable = toRemoteTableName(session, identity, connection, remoteSchema, schemaTableName.getTableName()); try (ResultSet resultSet = getTables(connection, Optional.of(remoteSchema), Optional.of(remoteTable))) { List tableHandles = new ArrayList<>(); while (resultSet.next()) { @@ -387,7 +369,7 @@ private static String escapeNamePattern(String name, String escape) return name; } - protected String quoted(@Nullable String catalog, @Nullable String schema, String table) + protected String quoted(@Nullable String schema, String table) { StringBuilder builder = new StringBuilder(); if (!isNullOrEmpty(schema)) { @@ -404,16 +386,10 @@ public void addColumn(ClickHouseIdentity identity, ClickHouseTableHandle handle, String columnName = column.getName(); String sql = format( "ALTER TABLE %s ADD COLUMN %s", - quoted(handle.getCatalogName(), schema, table), + quoted(schema, table), getColumnDefinitionSql(column, columnName)); try (Connection connection = connectionFactory.openConnection(identity)) { - DatabaseMetaData metadata = connection.getMetaData(); - if (metadata.storesUpperCaseIdentifiers()) { - schema = schema != null ? 
schema.toUpperCase(ENGLISH) : null; - table = table.toUpperCase(ENGLISH); - columnName = columnName.toUpperCase(ENGLISH); - } execute(connection, sql); } catch (SQLException e) { @@ -448,8 +424,8 @@ public void dropColumn(ClickHouseIdentity identity, ClickHouseTableHandle handle try (Connection connection = connectionFactory.openConnection(identity)) { String sql = format( "ALTER TABLE %s DROP COLUMN %s", - quoted(handle.getCatalogName(), handle.getSchemaName(), handle.getTableName()), - column.getColumnName()); + quoted(handle.getSchemaName(), handle.getTableName()), + quoted(column.getColumnName())); execute(connection, sql); } catch (SQLException e) { @@ -459,8 +435,8 @@ public void dropColumn(ClickHouseIdentity identity, ClickHouseTableHandle handle public void finishInsertTable(ClickHouseIdentity identity, ClickHouseOutputTableHandle handle) { - String temporaryTable = quoted(handle.getCatalogName(), handle.getSchemaName(), handle.getTemporaryTableName()); - String targetTable = quoted(handle.getCatalogName(), handle.getSchemaName(), handle.getTableName()); + String temporaryTable = quoted(handle.getSchemaName(), handle.getTemporaryTableName()); + String targetTable = quoted(handle.getSchemaName(), handle.getTableName()); String insertSql = format("INSERT INTO %s SELECT * FROM %s", targetTable, temporaryTable); String cleanupSql = "DROP TABLE " + temporaryTable; @@ -531,15 +507,11 @@ public void renameColumn(ClickHouseIdentity identity, ClickHouseTableHandle hand { String sql = format( "ALTER TABLE %s RENAME COLUMN %s TO %s", - quoted(handle.getCatalogName(), handle.getSchemaName(), handle.getTableName()), - clickHouseColumn.getColumnName(), - newColumnName); + quoted(handle.getSchemaName(), handle.getTableName()), + quoted(clickHouseColumn.getColumnName()), + quoted(newColumnName)); try (Connection connection = connectionFactory.openConnection(identity)) { - DatabaseMetaData metadata = connection.getMetaData(); - if (metadata.storesUpperCaseIdentifiers()) { - 
newColumnName = newColumnName.toUpperCase(ENGLISH); - } execute(connection, sql); } catch (SQLException e) { @@ -559,8 +531,8 @@ public ClickHouseOutputTableHandle beginInsertTable(ConnectorTableMetadata table try (Connection connection = connectionFactory.openConnection(identity)) { boolean uppercase = connection.getMetaData().storesUpperCaseIdentifiers(); - String remoteSchema = toRemoteSchemaName(identity, connection, schemaTableName.getSchemaName()); - String remoteTable = toRemoteTableName(identity, connection, remoteSchema, schemaTableName.getTableName()); + String remoteSchema = toRemoteSchemaName(session, identity, connection, schemaTableName.getSchemaName()); + String remoteTable = toRemoteTableName(session, identity, connection, remoteSchema, schemaTableName.getTableName()); if (uppercase) { tableName = tableName.toUpperCase(ENGLISH); } @@ -604,9 +576,9 @@ public ClickHouseOutputTableHandle createTable(ConnectorTableMetadata tableMetad try (Connection connection = connectionFactory.openConnection(identity)) { boolean uppercase = connection.getMetaData().storesUpperCaseIdentifiers(); - String remoteSchema = toRemoteSchemaName(identity, connection, schemaTableName.getSchemaName()); - String remoteTable = toRemoteTableName(identity, connection, remoteSchema, schemaTableName.getTableName()); - if (uppercase) { + String remoteSchema = toRemoteSchemaName(session, identity, connection, schemaTableName.getSchemaName()); + String remoteTable = toRemoteTableName(session, identity, connection, remoteSchema, schemaTableName.getTableName()); + if (uppercase && !caseSensitiveNameMatchingEnabled) { tableName = tableName.toUpperCase(ENGLISH); } String catalog = connection.getCatalog(); @@ -616,7 +588,7 @@ public ClickHouseOutputTableHandle createTable(ConnectorTableMetadata tableMetad ImmutableList.Builder columnList = ImmutableList.builder(); for (ColumnMetadata column : tableMetadata.getColumns()) { String columnName = column.getName(); - if (uppercase) { + if 
(uppercase && !caseSensitiveNameMatchingEnabled) { columnName = columnName.toUpperCase(ENGLISH); } columnNames.add(columnName); @@ -639,13 +611,15 @@ public ClickHouseOutputTableHandle createTable(ConnectorTableMetadata tableMetad } } - protected String toRemoteTableName(ClickHouseIdentity identity, Connection connection, String remoteSchema, String tableName) + protected String toRemoteTableName(ConnectorSession session, ClickHouseIdentity identity, Connection connection, String remoteSchema, String tableName) { requireNonNull(remoteSchema, "remoteSchema is null"); requireNonNull(tableName, "tableName is null"); - verify(CharMatcher.forPredicate(Character::isUpperCase).matchesNoneOf(tableName), "Expected table name from internal metadata to be lowercase: %s", tableName); if (caseSensitiveEnabled) { + session.getWarningCollector().add(new PrestoWarning(USE_OF_DEPRECATED_CONFIGURATION_PROPERTY, + "'clickhouse.case-insensitive' is deprecated. Use of this configuration value may lead to query failures. 
" + + "Please switch to using 'case-sensitive-name-matching' for proper case sensitivity behavior.")); try { com.facebook.presto.plugin.clickhouse.RemoteTableNameCacheKey cacheKey = new com.facebook.presto.plugin.clickhouse.RemoteTableNameCacheKey(identity, remoteSchema); Map mapping = remoteTableNames.getIfPresent(cacheKey); @@ -669,7 +643,7 @@ protected String toRemoteTableName(ClickHouseIdentity identity, Connection conne try { DatabaseMetaData metadata = connection.getMetaData(); - if (metadata.storesUpperCaseIdentifiers()) { + if (metadata.storesUpperCaseIdentifiers() && !caseSensitiveNameMatchingEnabled) { return tableName.toUpperCase(ENGLISH); } return tableName; @@ -708,7 +682,7 @@ public void dropTable(ClickHouseIdentity identity, ClickHouseTableHandle handle) { StringBuilder sql = new StringBuilder() .append("DROP TABLE ") - .append(quoted(handle.getCatalogName(), handle.getSchemaName(), handle.getTableName())); + .append(quoted(handle.getSchemaName(), handle.getTableName())); try (Connection connection = connectionFactory.openConnection(identity)) { execute(connection, sql.toString()); @@ -735,7 +709,7 @@ public void renameTable(ClickHouseIdentity identity, ClickHouseTableHandle handl renameTable(identity, handle.getCatalogName(), handle.getSchemaTableName(), newTable); } - public void createSchema(ClickHouseIdentity identity, String schemaName, Map properties) + public void createSchema(ClickHouseIdentity identity, String schemaName) { try (Connection connection = connectionFactory.openConnection(identity)) { execute(connection, "CREATE DATABASE " + quoted(schemaName)); @@ -761,20 +735,11 @@ protected void renameTable(ClickHouseIdentity identity, String catalogName, Sche String tableName = oldTable.getTableName(); String newSchemaName = newTable.getSchemaName(); String newTableName = newTable.getTableName(); - String sql = format("RENAME TABLE %s.%s TO %s.%s", - quoted(schemaName), - quoted(tableName), - quoted(newTable.getSchemaName()), - 
quoted(newTable.getTableName())); + String sql = format("RENAME TABLE %s TO %s", + quoted(schemaName, tableName), + quoted(newSchemaName, newTableName)); try (Connection connection = connectionFactory.openConnection(identity)) { - DatabaseMetaData metadata = connection.getMetaData(); - if (metadata.storesUpperCaseIdentifiers()) { - schemaName = schemaName.toUpperCase(ENGLISH); - tableName = tableName.toUpperCase(ENGLISH); - newSchemaName = newSchemaName.toUpperCase(ENGLISH); - newTableName = newTableName.toUpperCase(ENGLISH); - } execute(connection, sql); } catch (SQLException e) { @@ -892,8 +857,8 @@ protected void copyTableSchema(ClickHouseIdentity identity, String catalogName, String newCreateTableName = newTableName.getTableName(); String sql = format( "CREATE TABLE %s AS %s ", - quoted(null, schemaName, newCreateTableName), - quoted(null, schemaName, oldCreateTableName)); + quoted(schemaName, newCreateTableName), + quoted(schemaName, oldCreateTableName)); try (Connection connection = connectionFactory.openConnection(identity)) { execute(connection, sql); @@ -908,17 +873,18 @@ protected void copyTableSchema(ClickHouseIdentity identity, String catalogName, private String quoted(RemoteTableName remoteTableName) { return quoted( - remoteTableName.getCatalogName().orElse(null), remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); } - protected String toRemoteSchemaName(ClickHouseIdentity identity, Connection connection, String schemaName) + protected String toRemoteSchemaName(ConnectorSession session, ClickHouseIdentity identity, Connection connection, String schemaName) { requireNonNull(schemaName, "schemaName is null"); - verify(CharMatcher.forPredicate(Character::isUpperCase).matchesNoneOf(schemaName), "Expected schema name from internal metadata to be lowercase: %s", schemaName); if (caseSensitiveEnabled) { + session.getWarningCollector().add(new PrestoWarning(USE_OF_DEPRECATED_CONFIGURATION_PROPERTY, + "'clickhouse.case-insensitive' 
is deprecated. Use of this configuration value may lead to query failures. " + + "Please switch to using 'case-sensitive-name-matching' for proper case sensitivity behavior.")); try { Map mapping = remoteSchemaNames.getIfPresent(identity); if (mapping != null && !mapping.containsKey(schemaName)) { @@ -941,7 +907,7 @@ protected String toRemoteSchemaName(ClickHouseIdentity identity, Connection conn try { DatabaseMetaData metadata = connection.getMetaData(); - if (metadata.storesUpperCaseIdentifiers()) { + if (metadata.storesUpperCaseIdentifiers() && !caseSensitiveNameMatchingEnabled) { return schemaName.toUpperCase(ENGLISH); } return schemaName; @@ -956,4 +922,9 @@ protected Map listSchemasByLowerCase(Connection connection) return listSchemas(connection).stream() .collect(toImmutableMap(schemaName -> schemaName.toLowerCase(ENGLISH), schemaName -> schemaName)); } + + public String normalizeIdentifier(String identifier) + { + return caseSensitiveNameMatchingEnabled ? identifier : identifier.toLowerCase(ROOT); + } } diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseColumnHandle.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseColumnHandle.java index 616715b7764dd..5cbfebc09e1b9 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseColumnHandle.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseColumnHandle.java @@ -142,6 +142,17 @@ public ColumnMetadata getColumnMetadata() .build(); } + public ColumnMetadata getColumnMetadata(String name) + { + return ColumnMetadata.builder() + .setName(name) + .setType(columnType) + .setNullable(nullable) + .setHidden(false) + .setProperties(emptyMap()) + .build(); + } + @Override public boolean equals(Object obj) { diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseConfig.java 
b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseConfig.java index d8f867cf9674c..2df52939d6155 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseConfig.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseConfig.java @@ -16,11 +16,10 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.ConfigSecuritySensitive; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.annotation.Nullable; -import javax.validation.constraints.NotNull; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.NotNull; import static java.util.concurrent.TimeUnit.MINUTES; @@ -36,6 +35,7 @@ public class ClickHouseConfig private boolean mapStringAsVarchar; private boolean allowDropTable; private int commitBatchSize; + private boolean caseSensitiveNameMatchingEnabled; @NotNull public String getConnectionUrl() @@ -103,12 +103,18 @@ public ClickHouseConfig setPasswordCredential(String passwordCredential) return this; } + @Deprecated public boolean isCaseInsensitiveNameMatching() { return caseInsensitiveNameMatching; } + @Deprecated @Config("clickhouse.case-insensitive") + @ConfigDescription("Deprecated: This will be removed in future releases. Use 'case-sensitive-name-matching=true' instead for clickhouse. " + + "This configuration setting converts all schema/table names to lowercase. 
" + + "If your source database contains names differing only by case (e.g., 'Testdb' and 'testdb'), " + + "this setting can lead to conflicts and query failures.") public ClickHouseConfig setCaseInsensitiveNameMatching(boolean caseInsensitiveNameMatching) { this.caseInsensitiveNameMatching = caseInsensitiveNameMatching; @@ -168,4 +174,18 @@ public ClickHouseConfig setCommitBatchSize(int commitBatchSize) this.commitBatchSize = commitBatchSize; return this; } + + public boolean isCaseSensitiveNameMatching() + { + return caseSensitiveNameMatchingEnabled; + } + + @Config("case-sensitive-name-matching") + @ConfigDescription("Enable case-sensitive matching of schema, table names across the connector. " + + "When disabled, names are matched case-insensitively using lowercase normalization.") + public ClickHouseConfig setCaseSensitiveNameMatching(boolean caseSensitiveNameMatchingEnabled) + { + this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled; + return this; + } } diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseConnector.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseConnector.java index a3ced140b87da..f98389b84eabb 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseConnector.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseConnector.java @@ -34,8 +34,7 @@ import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseMetadata.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseMetadata.java index 07b2181812186..0b4649115056d 100755 --- 
a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseMetadata.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseMetadata.java @@ -37,6 +37,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; @@ -48,6 +49,7 @@ import static com.facebook.presto.spi.StandardErrorCode.PERMISSION_DENIED; import static com.google.common.base.Preconditions.checkState; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; public class ClickHouseMetadata @@ -58,11 +60,14 @@ public class ClickHouseMetadata private final boolean allowDropTable; private final AtomicReference rollbackAction = new AtomicReference<>(); + private final ClickHouseConfig clickHouseConfig; - public ClickHouseMetadata(ClickHouseClient clickHouseClient, boolean allowDropTable) + @Inject + public ClickHouseMetadata(ClickHouseClient clickHouseClient, boolean allowDropTable, ClickHouseConfig clickHouseConfig) { this.clickHouseClient = requireNonNull(clickHouseClient, "client is null"); this.allowDropTable = allowDropTable; + this.clickHouseConfig = clickHouseConfig; } @Override @@ -80,15 +85,19 @@ public List listSchemaNames(ConnectorSession session) @Override public ClickHouseTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) { - return clickHouseClient.getTableHandle(ClickHouseIdentity.from(session), tableName); + return clickHouseClient.getTableHandle(session, ClickHouseIdentity.from(session), tableName); } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { 
ClickHouseTableHandle tableHandle = (ClickHouseTableHandle) table; ConnectorTableLayout layout = new ConnectorTableLayout(new ClickHouseTableLayoutHandle(tableHandle, constraint.getSummary(), Optional.empty(), Optional.empty(), Optional.empty())); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -96,15 +105,15 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa { return new ConnectorTableLayout(handle); } - @Override public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table) + { ClickHouseTableHandle handle = (ClickHouseTableHandle) table; ImmutableList.Builder columnMetadata = ImmutableList.builder(); for (ClickHouseColumnHandle column : clickHouseClient.getColumns(session, handle)) { - columnMetadata.add(column.getColumnMetadata()); + columnMetadata.add(column.getColumnMetadata(normalizeIdentifier(session, column.getColumnName()))); } return new ConnectorTableMetadata(handle.getSchemaTableName(), columnMetadata.build()); } @@ -112,7 +121,7 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect @Override public List listTables(ConnectorSession session, Optional schemaName) { - return clickHouseClient.getTableNames(ClickHouseIdentity.from(session), schemaName); + return clickHouseClient.getTableNames(session, ClickHouseIdentity.from(session), schemaName); } @Override @@ -122,7 +131,7 @@ public Map getColumnHandles(ConnectorSession session, Conn ImmutableMap.Builder columnHandles = ImmutableMap.builder(); for (ClickHouseColumnHandle column : clickHouseClient.getColumns(session, clickHouseTableHandle)) { - columnHandles.put(column.getColumnMetadata().getName(), column); + columnHandles.put(normalizeIdentifier(session, column.getColumnMetadata(column.getColumnName()).getName()), column); } return columnHandles.build(); } @@ -140,7 +149,7 
@@ public Map> listTableColumns(ConnectorSess } for (SchemaTableName tableName : tables) { try { - ClickHouseTableHandle tableHandle = clickHouseClient.getTableHandle(ClickHouseIdentity.from(session), tableName); + ClickHouseTableHandle tableHandle = clickHouseClient.getTableHandle(session, ClickHouseIdentity.from(session), tableName); if (tableHandle == null) { continue; } @@ -264,7 +273,7 @@ public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTab @Override public void createSchema(ConnectorSession session, String schemaName, Map properties) { - clickHouseClient.createSchema(ClickHouseIdentity.from(session), schemaName, properties); + clickHouseClient.createSchema(ClickHouseIdentity.from(session), schemaName); } @Override @@ -272,4 +281,10 @@ public void dropSchema(ConnectorSession session, String schemaName) { clickHouseClient.dropSchema(ClickHouseIdentity.from(session), schemaName); } + + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return clickHouseConfig.isCaseSensitiveNameMatching() ? 
identifier : identifier.toLowerCase(ROOT); + } } diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseMetadataFactory.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseMetadataFactory.java index 4c2d715398dc2..c19b41b407861 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseMetadataFactory.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseMetadataFactory.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.plugin.clickhouse; -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; @@ -21,17 +21,19 @@ public class ClickHouseMetadataFactory { private final ClickHouseClient clickHouseClient; private final boolean allowDropTable; + private final ClickHouseConfig clickHouseConfig; @Inject - public ClickHouseMetadataFactory(ClickHouseClient clickHouseClient, ClickHouseConfig config) + public ClickHouseMetadataFactory(ClickHouseClient clickHouseClient, ClickHouseConfig config, ClickHouseConfig clickHouseConfig) { this.clickHouseClient = requireNonNull(clickHouseClient, "clickHouseClient is null"); + this.clickHouseConfig = clickHouseConfig; requireNonNull(config, "config is null"); this.allowDropTable = config.isAllowDropTable(); } public ClickHouseMetadata create() { - return new ClickHouseMetadata(clickHouseClient, allowDropTable); + return new ClickHouseMetadata(clickHouseClient, allowDropTable, clickHouseConfig); } } diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseOutputTableHandle.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseOutputTableHandle.java index 805745b1c2b03..ecff7a83eadcc 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseOutputTableHandle.java +++ 
b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseOutputTableHandle.java @@ -19,8 +19,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Objects; diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHousePageSinkProvider.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHousePageSinkProvider.java index 8c4e9059567a6..4f0edc22687da 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHousePageSinkProvider.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHousePageSinkProvider.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.PageSinkContext; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseRecordCursor.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseRecordCursor.java index b194e28190253..b2e8d48f700ad 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseRecordCursor.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseRecordCursor.java @@ -63,7 +63,7 @@ public ClickHouseRecordCursor(ClickHouseClient clickHouseClient, ConnectorSessio sliceReadFunctions = new SliceReadFunction[columnHandles.size()]; for (int i = 0; i < this.columnHandles.length; i++) { - ReadMapping readMapping = 
clickHouseClient.toPrestoType(session, columnHandles.get(i).getClickHouseTypeHandle()) + ReadMapping readMapping = clickHouseClient.toPrestoType(columnHandles.get(i).getClickHouseTypeHandle()) .orElseThrow(() -> new VerifyException("Unsupported column type")); Class javaType = readMapping.getType().getJavaType(); ReadFunction readFunction = readMapping.getReadFunction(); diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseRecordSetProvider.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseRecordSetProvider.java index 305657fc5eb1d..95db8aa1c7700 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseRecordSetProvider.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseRecordSetProvider.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseSplit.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseSplit.java index 2490f86063744..0855b987e28e9 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseSplit.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseSplit.java @@ -26,8 +26,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Optional; diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseSplitManager.java 
b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseSplitManager.java index 9fc449b0353eb..4300e2287d0e9 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseSplitManager.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseSplitManager.java @@ -18,8 +18,7 @@ import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; @@ -42,6 +41,6 @@ public ConnectorSplitSource getSplits( SplitSchedulingContext splitSchedulingContext) { ClickHouseTableLayoutHandle layoutHandle = (ClickHouseTableLayoutHandle) layout; - return clickHouseClient.getSplits(ClickHouseIdentity.from(session), layoutHandle); + return clickHouseClient.getSplits(layoutHandle); } } diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseTableHandle.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseTableHandle.java index eb761b55a64b7..fb987ae9250e9 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseTableHandle.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseTableHandle.java @@ -18,8 +18,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Joiner; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Objects; diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseTableProperties.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseTableProperties.java index ebd873463da12..24a23eb78f459 100755 --- 
a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseTableProperties.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ClickHouseTableProperties.java @@ -16,8 +16,7 @@ import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.EnumSet; import java.util.List; diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ForBaseJdbc.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ForBaseJdbc.java index cf8259f07fbe0..253929ff657d3 100755 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ForBaseJdbc.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/ForBaseJdbc.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.plugin.clickhouse; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/optimization/ClickHouseQueryGenerator.java b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/optimization/ClickHouseQueryGenerator.java index aa533b6a3cc0c..a9a20fb3bf7b4 100644 --- a/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/optimization/ClickHouseQueryGenerator.java +++ b/presto-clickhouse/src/main/java/com/facebook/presto/plugin/clickhouse/optimization/ClickHouseQueryGenerator.java @@ -39,8 +39,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.HashSet; import java.util.LinkedHashMap; diff --git 
a/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickHouseConfig.java b/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickHouseConfig.java index 8ed3bc685a04d..934e4947aec03 100755 --- a/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickHouseConfig.java +++ b/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickHouseConfig.java @@ -14,8 +14,8 @@ package com.facebook.presto.plugin.clickhouse; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; @@ -35,6 +35,7 @@ public class TestClickHouseConfig private static final String mapStringAsVarchar = "clickhouse.map-string-as-varchar"; private static final String allowDropTable = "clickhouse.allow-drop-table"; private static final String commitBatchSize = "clickhouse.commitBatchSize"; + private static final String caseSensitiveNameMatching = "case-sensitive-name-matching"; @Test public void testDefaults() @@ -49,7 +50,8 @@ public void testDefaults() .setAllowDropTable(false) .setCaseInsensitiveNameMatchingCacheTtl(new Duration(1, MINUTES)) .setMapStringAsVarchar(false) - .setCommitBatchSize(0)); + .setCommitBatchSize(0) + .setCaseSensitiveNameMatching(false)); } @Test @@ -66,6 +68,7 @@ public void testExplicitPropertyMappings() .put(mapStringAsVarchar, "true") .put(allowDropTable, "true") .put(commitBatchSize, "1000") + .put(caseSensitiveNameMatching, "true") .build(); ClickHouseConfig expected = new ClickHouseConfig() @@ -78,6 +81,7 @@ public void testExplicitPropertyMappings() .setAllowDropTable(true) .setCaseInsensitiveNameMatchingCacheTtl(new Duration(1, SECONDS)) .setMapStringAsVarchar(true) + .setCaseSensitiveNameMatching(true) .setCommitBatchSize(1000); ConfigAssertions.assertFullMapping(properties, expected); 
diff --git a/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickHouseDistributedQueries.java b/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickHouseDistributedQueries.java index 56e6d924e5fb6..3a44b568a6a23 100755 --- a/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickHouseDistributedQueries.java +++ b/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickHouseDistributedQueries.java @@ -17,7 +17,6 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestDistributedQueries; -import com.google.common.collect.ImmutableMap; import io.airlift.tpch.TpchTable; import org.intellij.lang.annotations.Language; import org.testng.SkipException; @@ -61,8 +60,6 @@ protected QueryRunner createQueryRunner() { this.clickhouseServer = new TestingClickHouseServer(); return createClickHouseQueryRunner(clickhouseServer, - ImmutableMap.of("http-server.http.port", "8080"), - ImmutableMap.of(), TpchTable.getTables()); } @@ -321,8 +318,8 @@ public void testDropColumn() // the columns are referenced by order_by/order_by property can not be dropped assertUpdate("CREATE TABLE " + tableName + "(x int NOT NULL, y int, a int NOT NULL) WITH " + "(engine = 'MergeTree', order_by = ARRAY['x'], partition_by = ARRAY['a'])"); - assertQueryFails("ALTER TABLE " + tableName + " DROP COLUMN x", "ClickHouse exception, code: 47,.*\\n"); - assertQueryFails("ALTER TABLE " + tableName + " DROP COLUMN a", "ClickHouse exception, code: 47,.*\\n"); + assertQueryFails("ALTER TABLE " + tableName + " DROP COLUMN x", "(?s).* Missing columns: 'x' while processing query: 'x', required columns: 'x' 'x'.*\\n"); + assertQueryFails("ALTER TABLE " + tableName + " DROP COLUMN a", "(?s).* Missing columns: 'a' while processing query: 'a', required columns: 'a' 'a'.*\\n"); } @Override @@ -362,6 +359,22 @@ public void 
testAddColumn() assertFalse(getQueryRunner().tableExists(getSession(), tableName)); } + @Override + public void testRenameTable() + { + String tableName = "test_rename_table_" + randomTableSuffix(); + String newTableName = "test_rename_table_new_" + randomTableSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (id int NOT NULL, x VARCHAR) WITH (engine = 'MergeTree', order_by = ARRAY['id'])"); + assertUpdate("INSERT INTO " + tableName + " (id, x) VALUES(1, 'first')", 1); + + assertUpdate("ALTER TABLE " + tableName + " RENAME TO " + newTableName); + assertFalse(getQueryRunner().tableExists(getSession(), tableName)); + assertTrue(getQueryRunner().tableExists(getSession(), newTableName)); + assertUpdate("DROP TABLE " + newTableName); + + assertFalse(getQueryRunner().tableExists(getSession(), newTableName)); + } + @Test public void testShowCreateTable() { @@ -382,16 +395,16 @@ public void testShowCreateTable() @Override public void testDescribeOutput() { - MaterializedResult expectedColumns = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("orderkey", "bigint", "", "") - .row("custkey", "bigint", "", "") - .row("orderstatus", "varchar", "", "") - .row("totalprice", "double", "", "") - .row("orderdate", "date", "", "") - .row("orderpriority", "varchar", "", "") - .row("clerk", "varchar", "", "") - .row("shippriority", "integer", "", "") - .row("comment", "varchar", "", "") + MaterializedResult expectedColumns = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("orderkey", "bigint", "", "", 19L, null, null) + .row("custkey", "bigint", "", "", 19L, null, null) + .row("orderstatus", "varchar", "", "", null, null, 2147483647L) + .row("totalprice", "double", "", "", 53L, null, null) + .row("orderdate", "date", "", "", null, null, null) + .row("orderpriority", "varchar", "", "", null, null, 2147483647L) + .row("clerk", "varchar", "", "", null, null, 2147483647L) + .row("shippriority", "integer", "", "", 
10L, null, null) + .row("comment", "varchar", "", "", null, null, 2147483647L) .build(); MaterializedResult actualColumns = computeActual("DESCRIBE orders"); assertEquals(actualColumns, expectedColumns); @@ -469,7 +482,7 @@ public void testTableProperty() assertUpdate("DROP TABLE " + tableName); // the column refers by order by must be not null - assertQueryFails("CREATE TABLE " + tableName + " (id int NOT NULL, x VARCHAR) WITH (engine = 'MergeTree', order_by = ARRAY['id', 'x'])", ".* Sorting key cannot contain nullable columns.*\\n"); + assertQueryFails("CREATE TABLE " + tableName + " (id int NOT NULL, x VARCHAR) WITH (engine = 'MergeTree', order_by = ARRAY['id', 'x'])", ".*Sorting key contains nullable columns, but merge tree setting `allow_nullable_key` is disabled.*\\n.*"); assertUpdate("CREATE TABLE " + tableName + " (id int NOT NULL, x VARCHAR) WITH (engine = 'MergeTree', order_by = ARRAY['id'], primary_key = ARRAY['id'])"); assertTrue(getQueryRunner().tableExists(getSession(), tableName)); @@ -479,7 +492,7 @@ public void testTableProperty() assertTrue(getQueryRunner().tableExists(getSession(), tableName)); assertUpdate("DROP TABLE " + tableName); - assertUpdate("CREATE TABLE " + tableName + " (id int NOT NULL, x VARCHAR NOT NULL, y VARCHAR NOT NULL) WITH (engine = 'MergeTree', order_by = ARRAY['id', 'x'], primary_key = ARRAY['id','x'], sample_by = 'x' )"); + assertUpdate("CREATE TABLE " + tableName + " (id int NOT NULL, x boolean NOT NULL, y VARCHAR NOT NULL) WITH (engine = 'MergeTree', order_by = ARRAY['id', 'x'], primary_key = ARRAY['id','x'], sample_by = 'x' )"); assertTrue(getQueryRunner().tableExists(getSession(), tableName)); assertUpdate("DROP TABLE " + tableName); @@ -496,6 +509,75 @@ public void testTableProperty() "Invalid value for table property 'partition_by': .*"); } + @Override + public void testStringFilters() + { + assertUpdate("CREATE TABLE test_varcharn_filter (shipmode VARCHAR(85))"); + 
assertTrue(getQueryRunner().tableExists(getSession(), "test_varcharn_filter")); + assertTableColumnNames("test_varcharn_filter", "shipmode"); + assertUpdate("INSERT INTO test_varcharn_filter SELECT shipmode FROM lineitem", 60175); + + assertQuery("SELECT count(*) FROM test_varcharn_filter WHERE shipmode = 'AIR'", "VALUES (8491)"); + assertQuery("SELECT count(*) FROM test_varcharn_filter WHERE shipmode = 'AIR '", "VALUES (0)"); + assertQuery("SELECT count(*) FROM test_varcharn_filter WHERE shipmode = 'AIR '", "VALUES (0)"); + assertQuery("SELECT count(*) FROM test_varcharn_filter WHERE shipmode = 'AIR '", "VALUES (0)"); + assertQuery("SELECT count(*) FROM test_varcharn_filter WHERE shipmode = 'NONEXIST'", "VALUES (0)"); + } + + @Override + public void testNonAutoCommitTransactionWithCommit() + { + // not supported + } + + @Override + public void testNonAutoCommitTransactionWithRollback() + { + // not supported + } + + @Override + public void testPayloadJoinApplicability() + { + // not supported + } + + @Override + public void testPayloadJoinCorrectness() + { + // test not supported + } + + @Override + public void testRemoveRedundantCastToVarcharInJoinClause() + { + // test not supported + } + + @Override + public void testSubfieldAccessControl() + { + // test not supported + } + + @Override + public void testPreProcessMetastoreCalls() + { + //not supported + } + + @Override + public void testCorrelatedExistsSubqueries() + { + //not supported + } + + @Override + public void testCorrelatedScalarSubqueriesWithScalarAggregation() + { + //not supported + } + private static String randomTableSuffix() { SecureRandom random = new SecureRandom(); diff --git a/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickhouseIntegrationMixedCase.java b/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickhouseIntegrationMixedCase.java new file mode 100644 index 0000000000000..0d787e97c2c9d --- /dev/null +++ 
b/presto-clickhouse/src/test/java/com/facebook/presto/plugin/clickhouse/TestClickhouseIntegrationMixedCase.java @@ -0,0 +1,208 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.clickhouse; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import io.airlift.tpch.TpchTable; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static com.facebook.presto.plugin.clickhouse.ClickHouseQueryRunner.createClickHouseQueryRunner; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestClickhouseIntegrationMixedCase + extends AbstractTestQueryFramework +{ + private TestingClickHouseServer clickhouseServer; + private Session session; + + @Override + protected QueryRunner createQueryRunner() throws Exception + { + this.clickhouseServer = new TestingClickHouseServer(); + return createClickHouseQueryRunner(clickhouseServer, + ImmutableMap.of("http-server.http.port", "8080"), + ImmutableMap.of("case-sensitive-name-matching", "true"), + TpchTable.getTables()); + } + + @AfterClass(alwaysRun = true) + public final void destroy() + { + if (clickhouseServer != null) { + 
clickhouseServer.close(); + } + } + + @BeforeClass(alwaysRun = true) + public final void setUp() + { + session = testSessionBuilder() + .setCatalog("clickhouse") + .setSchema("default") + .build(); + } + + @Test + public void testCreateTable() + { + try { + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATE(name VARCHAR(50), rollNum int)"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATE")); + + getQueryRunner().execute(session, "CREATE TABLE test_create(name VARCHAR(50), rollNum int)"); + assertTrue(getQueryRunner().tableExists(session, "test_create")); + + assertQueryFails(session, "CREATE TABLE TEST_CREATE (name VARCHAR(50), rollNum int)", "line 1:1: Table 'clickhouse.default.TEST_CREATE' already exists"); + assertFalse(getQueryRunner().tableExists(session, "Test")); + } + finally { + assertUpdate(session, "DROP TABLE IF EXISTS TEST_CREATE"); + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATE")); + + assertUpdate(session, "DROP TABLE IF EXISTS test_create"); + assertFalse(getQueryRunner().tableExists(session, "test_create")); + } + } + + @Test + public void testCreateTableAs() + { + try { + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATEAS AS SELECT * FROM tpch.region"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATEAS")); + + getQueryRunner().execute(session, "CREATE TABLE IF NOT EXISTS test_createas AS SELECT * FROM tpch.region"); + assertTrue(getQueryRunner().tableExists(session, "test_createas")); + + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATEAS_Join AS SELECT c.custkey, o.orderkey FROM " + + "tpch.customer c INNER JOIN tpch.orders o ON c.custkey = o.custkey WHERE c.mktsegment = 'BUILDING'"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATEAS_Join")); + + assertQueryFails("CREATE TABLE test_connector.TEST_CREATEAS_FAIL_Join AS SELECT c.custkey, o.orderkey FROM " + + "tpch.customer c INNER JOIN tpch.ORDERS1 o ON c.custkey = o.custkey WHERE c.mktsegment = 
'BUILDING'", "Table clickhouse.tpch.ORDERS1 does not exist"); //failure scenario since tpch.ORDERS1 doesn't exist + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATEAS_FAIL_Join")); + + getQueryRunner().execute(session, "CREATE TABLE Test_CreateAs_Mixed_Join AS SELECT Cus.custkey, Ord.orderkey FROM " + + "tpch.customer Cus INNER JOIN tpch.orders Ord ON Cus.custkey = Ord.custkey WHERE Cus.mktsegment = 'BUILDING'"); + assertTrue(getQueryRunner().tableExists(session, "Test_CreateAs_Mixed_Join")); + } + finally { + assertUpdate(session, "DROP TABLE IF EXISTS TEST_CREATEAS"); + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATEAS")); + + assertUpdate(session, "DROP TABLE IF EXISTS test_createas"); + assertFalse(getQueryRunner().tableExists(session, "test_createas")); + + assertUpdate(session, "DROP TABLE IF EXISTS Test_CreateAs_Mixed_Join"); + assertFalse(getQueryRunner().tableExists(session, "Test_CreateAs_Mixed_Join")); + } + } + + @Test + public void testDuplicatedColumNameCreateTable() + { + try { + getQueryRunner().execute(session, "CREATE TABLE test (a integer, A integer)"); + assertTrue(getQueryRunner().tableExists(session, "test")); + + getQueryRunner().execute(session, "CREATE TABLE TEST (a integer, A integer)"); + assertTrue(getQueryRunner().tableExists(session, "TEST")); + + assertQueryFails("CREATE TABLE Test (a integer, a integer)", "line 1:31: Column name 'a' specified more than once"); + assertFalse(getQueryRunner().tableExists(session, "Test")); + } + finally { + assertUpdate(session, "DROP TABLE IF EXISTS test"); + assertFalse(getQueryRunner().tableExists(session, "test")); + + assertUpdate(session, "DROP TABLE IF EXISTS TEST"); + assertFalse(getQueryRunner().tableExists(session, "TEST")); + } + } + + @Test + public void testSelect() + { + try { + getQueryRunner().execute(session, "CREATE TABLE Test_Select AS SELECT * FROM tpch.region where regionkey=3"); + assertTrue(getQueryRunner().tableExists(session, "Test_Select")); + 
assertQuery("SELECT * from default.Test_Select", "VALUES (3, 'EUROPE', 'ly final courts cajole furiously final excuse')"); + + getQueryRunner().execute(session, "CREATE TABLE test_select AS SELECT * FROM tpch.region LIMIT 2"); + assertQuery("SELECT COUNT(*) FROM default.test_select", "VALUES 2"); + } + finally { + getQueryRunner().execute(session, "DROP TABLE IF EXISTS TEST_SELECT"); + getQueryRunner().execute(session, "DROP TABLE IF EXISTS Test_Select"); + } + } + + @Test + public void testAlterTable() + { + try { + assertUpdate(" CREATE TABLE Test_Alter (x int NOT NULL, y int, a int) WITH (engine = 'MergeTree', order_by = ARRAY['x'])"); + assertTrue(getQueryRunner().tableExists(getSession(), "Test_Alter")); + + assertQueryFails("ALTER TABLE Test_Alter RENAME COLUMN Y to YYY", "line 1:1: Column 'Y' does not exist"); + assertUpdate("ALTER TABLE Test_Alter RENAME COLUMN y to YYY"); + + assertQueryFails("ALTER TABLE IF EXISTS Test_Alter DROP COLUMN notExistColumn", ".*Column 'notExistColumn' does not exist.*"); + assertUpdate("ALTER TABLE Test_Alter DROP COLUMN YYY"); + } + finally { + getQueryRunner().execute(session, "DROP TABLE IF EXISTS Test_Alter"); + } + } + + @Test + public void testInsert() + { + try { + getQueryRunner().execute(session, "CREATE TABLE TEST_INSERT(name VARCHAR(50), rollNum int)"); + assertTrue(getQueryRunner().tableExists(session, "TEST_INSERT")); + + assertQueryFails("INSERT INTO Test_Insert VALUES (123, 'test')", ".*Table clickhouse.tpch.Test_Insert does not exist.*"); + getQueryRunner().execute(session, "INSERT INTO TEST_INSERT VALUES ('test', 123)"); + } + finally { + getQueryRunner().execute(session, "DROP TABLE IF EXISTS TEST_INSERT"); + } + } + + @Test + public void testUnicodedColumns() + { + try { + getQueryRunner().execute(session, "CREATE TABLE Test_Unicoded(name VARCHAR(50), \"Col用户表\" int)"); + assertTrue(getQueryRunner().tableExists(session, "Test_Unicoded")); + + getQueryRunner().execute(session, "INSERT INTO Test_Unicoded 
VALUES ('test', 123)"); + + getQueryRunner().execute(session, "CREATE TABLE \"Test_ÆØÅ\" (name VARCHAR(50), \"Col用户表\" int)"); + assertTrue(getQueryRunner().tableExists(session, "Test_ÆØÅ")); + } + finally { + getQueryRunner().execute(session, "DROP TABLE IF EXISTS Test_Unicoded"); + getQueryRunner().execute(session, "DROP TABLE IF EXISTS \"Test_ÆØÅ\" "); + } + } +} diff --git a/presto-client/pom.xml b/presto-client/pom.xml index 7662acd12632b..24b461b54a146 100644 --- a/presto-client/pom.xml +++ b/presto-client/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-client @@ -13,6 +13,8 @@ ${project.parent.basedir} + 8 + true @@ -20,14 +22,21 @@ com.facebook.presto presto-spi + com.facebook.presto presto-common - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations + true + + + + jakarta.annotation + jakarta.annotation-api true @@ -57,12 +66,12 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api - io.airlift + com.facebook.airlift units @@ -91,19 +100,41 @@ okio-jvm + + com.google.code.findbugs + jsr305 + + org.jetbrains.kotlin kotlin-stdlib-jdk8 + + net.jodah + failsafe + + + + com.facebook.airlift + concurrent + test + + com.facebook.presto presto-testng-services test + + org.assertj + assertj-core + test + + org.testng testng @@ -111,22 +142,33 @@ - com.facebook.drift + javax.inject + javax.inject + + + + com.facebook.airlift.drift drift-protocol test - com.facebook.drift + com.facebook.airlift.drift drift-codec test - com.facebook.drift + com.facebook.airlift.drift drift-codec-utils test + + + com.squareup.okhttp3 + mockwebserver + test + @@ -140,6 +182,9 @@ com.squareup.okio:okio-jvm org.jetbrains.kotlin:kotlin-stdlib-jdk8 + + javax.inject:javax.inject + diff --git a/presto-client/src/main/java/com/facebook/presto/client/ClientSession.java b/presto-client/src/main/java/com/facebook/presto/client/ClientSession.java index 5896a3c6f2d3c..5b217e41182e1 100644 --- 
a/presto-client/src/main/java/com/facebook/presto/client/ClientSession.java +++ b/presto-client/src/main/java/com/facebook/presto/client/ClientSession.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.client; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.spi.security.SelectedRole; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import java.net.URI; import java.nio.charset.CharsetEncoder; diff --git a/presto-client/src/main/java/com/facebook/presto/client/ClientTypeSignature.java b/presto-client/src/main/java/com/facebook/presto/client/ClientTypeSignature.java index 00193d4ac37db..2b8daad7e90ec 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/ClientTypeSignature.java +++ b/presto-client/src/main/java/com/facebook/presto/client/ClientTypeSignature.java @@ -23,8 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.ArrayList; import java.util.List; diff --git a/presto-client/src/main/java/com/facebook/presto/client/ClientTypeSignatureParameter.java b/presto-client/src/main/java/com/facebook/presto/client/ClientTypeSignatureParameter.java index ed0ea3ee3dedb..d7ce430fe8ed6 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/ClientTypeSignatureParameter.java +++ b/presto-client/src/main/java/com/facebook/presto/client/ClientTypeSignatureParameter.java @@ -28,8 +28,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.io.IOException; 
import java.util.Objects; diff --git a/presto-client/src/main/java/com/facebook/presto/client/Column.java b/presto-client/src/main/java/com/facebook/presto/client/Column.java index b208722bf364b..e77244cbb1027 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/Column.java +++ b/presto-client/src/main/java/com/facebook/presto/client/Column.java @@ -17,8 +17,7 @@ import com.facebook.presto.common.type.TypeSignature; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import static java.util.Objects.requireNonNull; diff --git a/presto-client/src/main/java/com/facebook/presto/client/ErrorLocation.java b/presto-client/src/main/java/com/facebook/presto/client/ErrorLocation.java index cf0710106bad8..749c99fbd5dfb 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/ErrorLocation.java +++ b/presto-client/src/main/java/com/facebook/presto/client/ErrorLocation.java @@ -18,8 +18,7 @@ import com.facebook.drift.annotations.ThriftStruct; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; diff --git a/presto-client/src/main/java/com/facebook/presto/client/FailureInfo.java b/presto-client/src/main/java/com/facebook/presto/client/FailureInfo.java index 3db250458bd38..1a3e6bd6dd6a0 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/FailureInfo.java +++ b/presto-client/src/main/java/com/facebook/presto/client/FailureInfo.java @@ -16,9 +16,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - 
-import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.regex.Matcher; diff --git a/presto-client/src/main/java/com/facebook/presto/client/GCSOAuthInterceptor.java b/presto-client/src/main/java/com/facebook/presto/client/GCSOAuthInterceptor.java index eb8572e094e6e..1e2dcd0a0ff37 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/GCSOAuthInterceptor.java +++ b/presto-client/src/main/java/com/facebook/presto/client/GCSOAuthInterceptor.java @@ -20,8 +20,9 @@ import okhttp3.Request; import okhttp3.Response; -import java.io.FileInputStream; import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Paths; import java.util.Collection; import java.util.Optional; import java.util.stream.Collectors; @@ -29,6 +30,7 @@ import static com.facebook.presto.client.GCSOAuthScope.DEVSTORAGE_READ_ONLY; import static com.facebook.presto.client.PrestoHeaders.PRESTO_EXTRA_CREDENTIAL; +import static java.nio.file.Files.newInputStream; import static java.util.Objects.requireNonNull; public class GCSOAuthInterceptor @@ -82,8 +84,8 @@ private synchronized GoogleCredentials getCredentials() private GoogleCredentials createCredentials() { - try { - return GoogleCredentials.fromStream(new FileInputStream(credentialsFilePath)).createScoped(gcsOAuthScopeURLs); + try (InputStream is = newInputStream(Paths.get(credentialsFilePath))) { + return GoogleCredentials.fromStream(is).createScoped(gcsOAuthScopeURLs); } catch (IOException e) { throw new ClientException("Google credential loading error", e); diff --git a/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java b/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java index ac8075e6df206..032a12560588c 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java +++ 
b/presto-client/src/main/java/com/facebook/presto/client/JsonResponse.java @@ -14,6 +14,7 @@ package com.facebook.presto.client; import com.facebook.airlift.json.JsonCodec; +import jakarta.annotation.Nullable; import okhttp3.Headers; import okhttp3.MediaType; import okhttp3.OkHttpClient; @@ -21,8 +22,6 @@ import okhttp3.Response; import okhttp3.ResponseBody; -import javax.annotation.Nullable; - import java.io.IOException; import java.io.UncheckedIOException; diff --git a/presto-client/src/main/java/com/facebook/presto/client/NodeVersion.java b/presto-client/src/main/java/com/facebook/presto/client/NodeVersion.java index 40133333464f9..5fbc9c729fa0f 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/NodeVersion.java +++ b/presto-client/src/main/java/com/facebook/presto/client/NodeVersion.java @@ -18,8 +18,7 @@ import com.facebook.drift.annotations.ThriftStruct; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; diff --git a/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java b/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java index 91e1427249dd9..7de067264a076 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java +++ b/presto-client/src/main/java/com/facebook/presto/client/OkHttpUtil.java @@ -34,12 +34,12 @@ import javax.security.auth.x500.X500Principal; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.net.CookieManager; import java.net.InetSocketAddress; import java.net.Proxy; +import java.nio.file.Paths; import java.security.GeneralSecurityException; import java.security.KeyStore; import java.security.SecureRandom; @@ -57,6 +57,7 @@ import static com.google.common.net.HttpHeaders.USER_AGENT; import static 
java.net.Proxy.Type.HTTP; import static java.net.Proxy.Type.SOCKS; +import static java.nio.file.Files.newInputStream; import static java.util.Collections.list; import static java.util.Objects.requireNonNull; @@ -197,7 +198,7 @@ public static void setupSsl( char[] keyManagerPassword; try { // attempt to read the key store as a PEM file - keyStore = PemReader.loadKeyStore(new File(keyStorePath.get()), new File(keyStorePath.get()), keyStorePassword); + keyStore = PemReader.loadKeyStore(Paths.get(keyStorePath.get()).toFile(), Paths.get(keyStorePath.get()).toFile(), keyStorePassword); // for PEM encoded keys, the password is used to decrypt the specific key (and does not protect the keystore itself) keyManagerPassword = new char[0]; } @@ -205,7 +206,7 @@ public static void setupSsl( keyManagerPassword = keyStorePassword.map(String::toCharArray).orElse(null); keyStore = KeyStore.getInstance(keystoreType.get()); - try (InputStream in = new FileInputStream(keyStorePath.get())) { + try (InputStream in = newInputStream(Paths.get(keyStorePath.get()))) { keyStore.load(in, keyManagerPassword); } } @@ -219,7 +220,7 @@ public static void setupSsl( KeyStore trustStore = keyStore; if (trustStorePath.isPresent()) { checkArgument(trustStoreType.isPresent(), "truststore type is not present"); - trustStore = loadTrustStore(new File(trustStorePath.get()), trustStorePassword, trustStoreType.get()); + trustStore = loadTrustStore(Paths.get(trustStorePath.get()).toFile(), trustStorePassword, trustStoreType.get()); } // create TrustManagerFactory @@ -288,7 +289,7 @@ private static KeyStore loadTrustStore(File trustStorePath, Optional tru catch (IOException | GeneralSecurityException ignored) { } - try (InputStream in = new FileInputStream(trustStorePath)) { + try (InputStream in = newInputStream(trustStorePath.toPath())) { trustStore.load(in, trustStorePassword.map(String::toCharArray).orElse(null)); } return trustStore; diff --git 
a/presto-client/src/main/java/com/facebook/presto/client/PrestoHeaders.java b/presto-client/src/main/java/com/facebook/presto/client/PrestoHeaders.java index 3be0b241ee0b2..5ec4c88254853 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/PrestoHeaders.java +++ b/presto-client/src/main/java/com/facebook/presto/client/PrestoHeaders.java @@ -42,6 +42,7 @@ public final class PrestoHeaders public static final String PRESTO_SESSION_FUNCTION = "X-Presto-Session-Function"; public static final String PRESTO_ADDED_SESSION_FUNCTION = "X-Presto-Added-Session-Functions"; public static final String PRESTO_REMOVED_SESSION_FUNCTION = "X-Presto-Removed-Session-Function"; + public static final String PRESTO_RETRY_QUERY = "X-Presto-Retry-Query"; public static final String PRESTO_CURRENT_STATE = "X-Presto-Current-State"; public static final String PRESTO_MAX_WAIT = "X-Presto-Max-Wait"; diff --git a/presto-client/src/main/java/com/facebook/presto/client/QueryError.java b/presto-client/src/main/java/com/facebook/presto/client/QueryError.java index 733eeb171ec7f..a2b06482651a0 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/QueryError.java +++ b/presto-client/src/main/java/com/facebook/presto/client/QueryError.java @@ -15,9 +15,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import static com.google.common.base.MoreObjects.toStringHelper; diff --git a/presto-client/src/main/java/com/facebook/presto/client/QueryResults.java b/presto-client/src/main/java/com/facebook/presto/client/QueryResults.java index 58ae6f50dbfdb..d086d2bbf662a 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/QueryResults.java +++ b/presto-client/src/main/java/com/facebook/presto/client/QueryResults.java @@ -17,9 +17,8 @@ import 
com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.net.URI; import java.util.List; diff --git a/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java b/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java index ec199a7c4bda3..6de636a82404e 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java +++ b/presto-client/src/main/java/com/facebook/presto/client/ServerInfo.java @@ -13,14 +13,13 @@ */ package com.facebook.presto.client; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; import java.util.Optional; diff --git a/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java b/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java index 85f144a6dfca3..c3475791a3876 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java +++ b/presto-client/src/main/java/com/facebook/presto/client/SpnegoHandler.java @@ -13,10 +13,11 @@ */ package com.facebook.presto.client; +import com.facebook.airlift.units.Duration; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.sun.security.auth.module.Krb5LoginModule; -import io.airlift.units.Duration; import 
okhttp3.Authenticator; import okhttp3.Interceptor; import okhttp3.Request; @@ -28,7 +29,6 @@ import org.ietf.jgss.GSSManager; import org.ietf.jgss.Oid; -import javax.annotation.concurrent.GuardedBy; import javax.security.auth.Subject; import javax.security.auth.login.AppConfigurationEntry; import javax.security.auth.login.Configuration; diff --git a/presto-client/src/main/java/com/facebook/presto/client/StageStats.java b/presto-client/src/main/java/com/facebook/presto/client/StageStats.java index 7c8334047efae..b761dd7ed02f2 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/StageStats.java +++ b/presto-client/src/main/java/com/facebook/presto/client/StageStats.java @@ -16,8 +16,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; diff --git a/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java b/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java index 9fce8d205c19f..8ad990a6bd6af 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java +++ b/presto-client/src/main/java/com/facebook/presto/client/StatementClient.java @@ -15,10 +15,10 @@ import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.spi.security.SelectedRole; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.io.Closeable; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -60,6 +60,8 @@ public interface StatementClient Set getDeallocatedPreparedStatements(); + Map> getResponseHeaders(); + @Nullable String getStartedTransactionId(); diff --git a/presto-client/src/main/java/com/facebook/presto/client/StatementClientV1.java 
b/presto-client/src/main/java/com/facebook/presto/client/StatementClientV1.java index 9268195710af3..4d44006a9e3a4 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/StatementClientV1.java +++ b/presto-client/src/main/java/com/facebook/presto/client/StatementClientV1.java @@ -14,14 +14,17 @@ package com.facebook.presto.client; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.OkHttpUtil.NullCallback; import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.spi.security.SelectedRole; import com.google.common.base.Joiner; import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.Nullable; import okhttp3.Headers; import okhttp3.HttpUrl; import okhttp3.MediaType; @@ -29,9 +32,6 @@ import okhttp3.Request; import okhttp3.RequestBody; -import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; - import java.io.UnsupportedEncodingException; import java.net.URI; import java.net.URLDecoder; @@ -115,7 +115,7 @@ class StatementClientV1 private final Map addedSessionFunctions = new ConcurrentHashMap<>(); private final Set removedSessionFunctions = newConcurrentHashSet(); private final boolean validateNextUriSource; - + private final Map> responseHeaders; private final AtomicReference state = new AtomicReference<>(State.RUNNING); public StatementClientV1(OkHttpClient httpClient, ClientSession session, String query) @@ -141,6 +141,7 @@ public StatementClientV1(OkHttpClient httpClient, ClientSession session, String } processResponse(response.getHeaders(), response.getValue()); + this.responseHeaders = toHeaderMap(response.getHeaders()); } private Request buildQueryRequest(ClientSession session, String query) @@ 
-217,6 +218,11 @@ private Request buildQueryRequest(ClientSession session, String query) return builder.build(); } + public Map> getResponseHeaders() + { + return responseHeaders; + } + @Override public String getQuery() { @@ -439,6 +445,15 @@ private void validateNextUriSource(final URI nextUri, final URI infoUri) throw new RuntimeException(format("Next URI host and port %s are different than current %s", nextUri.getHost(), infoUri.getHost())); } + private static Map> toHeaderMap(Headers headers) + { + ImmutableMap.Builder> builder = ImmutableMap.builder(); + for (String name : headers.names()) { + builder.put(name, ImmutableList.copyOf(headers.values(name))); + } + return builder.build(); + } + private void processResponse(Headers headers, QueryResults results) { setCatalog.set(headers.get(PRESTO_SET_CATALOG)); @@ -499,7 +514,7 @@ private RuntimeException requestFailedException(String task, Request request, Js if (!response.hasValue()) { if (response.getStatusCode() == HTTP_UNAUTHORIZED) { return new ClientException("Authentication failed" + - Optional.ofNullable(response.getStatusMessage()) + Optional.ofNullable(response.getResponseBody()) .map(message -> ": " + message) .orElse("")); } diff --git a/presto-client/src/main/java/com/facebook/presto/client/StatementStats.java b/presto-client/src/main/java/com/facebook/presto/client/StatementStats.java index b958ce32d7d92..ba7da1ebde446 100644 --- a/presto-client/src/main/java/com/facebook/presto/client/StatementStats.java +++ b/presto-client/src/main/java/com/facebook/presto/client/StatementStats.java @@ -16,9 +16,8 @@ import com.facebook.presto.common.RuntimeStats; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.util.OptionalDouble; diff --git 
a/presto-client/src/main/java/com/facebook/presto/client/auth/external/CompositeRedirectHandler.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/CompositeRedirectHandler.java new file mode 100644 index 0000000000000..bc1ebb0355a3a --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/CompositeRedirectHandler.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import java.net.URI; +import java.util.List; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class CompositeRedirectHandler + implements RedirectHandler +{ + private final List handlers; + + public CompositeRedirectHandler(List strategies) + { + this.handlers = requireNonNull(strategies, "strategies is null") + .stream() + .map(ExternalRedirectStrategy::getHandler) + .collect(toImmutableList()); + checkState(!handlers.isEmpty(), "Expected at least one external redirect handler"); + } + + @Override + public void redirectTo(URI uri) throws RedirectException + { + RedirectException redirectException = new RedirectException("Could not redirect to " + uri); + for (RedirectHandler handler : handlers) { + try { + handler.redirectTo(uri); + return; + } + catch (RedirectException e) { + redirectException.addSuppressed(e); + } + 
} + + throw redirectException; + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/DesktopBrowserRedirectHandler.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/DesktopBrowserRedirectHandler.java new file mode 100644 index 0000000000000..6918cba955c4b --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/DesktopBrowserRedirectHandler.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import java.io.IOException; +import java.net.URI; + +import static java.awt.Desktop.Action.BROWSE; +import static java.awt.Desktop.getDesktop; +import static java.awt.Desktop.isDesktopSupported; + +public final class DesktopBrowserRedirectHandler + implements RedirectHandler +{ + @Override + public void redirectTo(URI uri) + throws RedirectException + { + if (!isDesktopSupported() || !getDesktop().isSupported(BROWSE)) { + throw new RedirectException("Desktop Browser is not available. 
Make sure your Java process is not in headless mode (-Djava.awt.headless=false)"); + } + + try { + getDesktop().browse(uri); + } + catch (IOException e) { + throw new RedirectException("Failed to redirect", e); + } + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/ExternalAuthentication.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/ExternalAuthentication.java new file mode 100644 index 0000000000000..afee9bd899b42 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/ExternalAuthentication.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.client.auth.external; + +import com.facebook.presto.client.ClientException; +import com.google.common.annotations.VisibleForTesting; + +import java.net.URI; +import java.time.Duration; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +class ExternalAuthentication +{ + private final URI tokenUri; + private final Optional redirectUri; + + public ExternalAuthentication(URI tokenUri, Optional redirectUri) + { + this.tokenUri = requireNonNull(tokenUri, "tokenUri is null"); + this.redirectUri = requireNonNull(redirectUri, "redirectUri is null"); + } + + public Optional obtainToken(Duration timeout, RedirectHandler handler, TokenPoller poller) + { + redirectUri.ifPresent(handler::redirectTo); + + URI currentUri = tokenUri; + + long start = System.nanoTime(); + long timeoutNanos = timeout.toNanos(); + + while (true) { + long remaining = timeoutNanos - (System.nanoTime() - start); + if (remaining < 0) { + return Optional.empty(); + } + + TokenPollResult result = poller.pollForToken(currentUri, Duration.ofNanos(remaining)); + + if (result.isFailed()) { + throw new ClientException(result.getError()); + } + + if (result.isPending()) { + currentUri = result.getNextTokenUri(); + continue; + } + poller.tokenReceived(currentUri); + return Optional.of(result.getToken()); + } + } + + @VisibleForTesting + Optional getRedirectUri() + { + return redirectUri; + } + + @VisibleForTesting + URI getTokenUri() + { + return tokenUri; + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/ExternalAuthenticator.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/ExternalAuthenticator.java new file mode 100644 index 0000000000000..60f568dc7c173 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/ExternalAuthenticator.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import com.facebook.presto.client.ClientException; +import com.google.common.annotations.VisibleForTesting; +import jakarta.annotation.Nullable; +import okhttp3.Authenticator; +import okhttp3.Challenge; +import okhttp3.Interceptor; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.Route; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Duration; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.net.HttpHeaders.AUTHORIZATION; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ExternalAuthenticator + implements Authenticator, Interceptor +{ + public static final String TOKEN_URI_FIELD = "x_token_server"; + public static final String REDIRECT_URI_FIELD = "x_redirect_server"; + + private final TokenPoller tokenPoller; + private final RedirectHandler redirectHandler; + private final Duration timeout; + private final KnownToken knownToken; + + public ExternalAuthenticator(RedirectHandler redirect, TokenPoller tokenPoller, KnownToken knownToken, Duration timeout) + { + this.tokenPoller = requireNonNull(tokenPoller, "tokenPoller is null"); + this.redirectHandler = requireNonNull(redirect, "redirect is null"); + this.knownToken = requireNonNull(knownToken, "knownToken is null"); + this.timeout = requireNonNull(timeout, "timeout is null"); + } + + @Nullable + @Override + public Request 
authenticate(Route route, Response response) + { + knownToken.setupToken(() -> { + Optional authentication = toAuthentication(response); + if (!authentication.isPresent()) { + return Optional.empty(); + } + + return authentication.get().obtainToken(timeout, redirectHandler, tokenPoller); + }); + + return knownToken.getToken() + .map(token -> withBearerToken(response.request(), token)) + .orElse(null); + } + + @Override + public Response intercept(Chain chain) + throws IOException + { + Optional token = knownToken.getToken(); + if (token.isPresent()) { + return chain.proceed(withBearerToken(chain.request(), token.get())); + } + + return chain.proceed(chain.request()); + } + + private static Request withBearerToken(Request request, Token token) + { + return request.newBuilder() + .header(AUTHORIZATION, "Bearer " + token.token()) + .build(); + } + + @VisibleForTesting + static Optional toAuthentication(Response response) + { + for (Challenge challenge : response.challenges()) { + if (challenge.scheme().equalsIgnoreCase("Bearer")) { + Optional tokenUri = parseField(challenge.authParams(), TOKEN_URI_FIELD); + Optional redirectUri = parseField(challenge.authParams(), REDIRECT_URI_FIELD); + if (tokenUri.isPresent()) { + return Optional.of(new ExternalAuthentication(tokenUri.get(), redirectUri)); + } + } + } + + return Optional.empty(); + } + + private static Optional parseField(Map fields, String key) + { + return Optional.ofNullable(fields.get(key)).map(value -> { + try { + return new URI(value); + } + catch (URISyntaxException e) { + throw new ClientException(format("Failed to parse URI for field '%s'", key), e); + } + }); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/ExternalRedirectStrategy.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/ExternalRedirectStrategy.java new file mode 100644 index 0000000000000..07b004f0c15fe --- /dev/null +++ 
b/presto-client/src/main/java/com/facebook/presto/client/auth/external/ExternalRedirectStrategy.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import com.google.common.collect.ImmutableList; + +import static java.util.Objects.requireNonNull; + +public enum ExternalRedirectStrategy +{ + DESKTOP_OPEN(new DesktopBrowserRedirectHandler()), + SYSTEM_OPEN(new SystemOpenRedirectHandler()), + PRINT(new SystemOutPrintRedirectHandler()), + OPEN(new CompositeRedirectHandler(ImmutableList.of(SYSTEM_OPEN, DESKTOP_OPEN))), + ALL(new CompositeRedirectHandler(ImmutableList.of(OPEN, PRINT))) + /**/; + + private final RedirectHandler handler; + + ExternalRedirectStrategy(RedirectHandler handler) + { + this.handler = requireNonNull(handler, "handler is null"); + } + + public RedirectHandler getHandler() + { + return handler; + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/HttpTokenPoller.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/HttpTokenPoller.java new file mode 100644 index 0000000000000..b06b224ee80ac --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/HttpTokenPoller.java @@ -0,0 +1,193 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.client.JsonResponse; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import net.jodah.failsafe.Failsafe; +import net.jodah.failsafe.FailsafeException; +import net.jodah.failsafe.RetryPolicy; +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.time.Duration; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static com.facebook.airlift.json.JsonCodec.jsonCodec; +import static com.facebook.presto.client.JsonResponse.execute; +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.net.HttpHeaders.USER_AGENT; +import static java.lang.String.format; +import static java.net.HttpURLConnection.HTTP_INTERNAL_ERROR; +import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_UNAVAILABLE; +import static java.time.temporal.ChronoUnit.MILLIS; +import static java.util.Objects.requireNonNull; + +public class HttpTokenPoller + implements TokenPoller +{ + private static final JsonCodec TOKEN_POLL_CODEC = jsonCodec(TokenPollRepresentation.class); + private static final String USER_AGENT_VALUE = "PrestoTokenPoller/" + + 
firstNonNull(HttpTokenPoller.class.getPackage().getImplementationVersion(), "unknown"); + + private final Supplier client; + + public HttpTokenPoller(OkHttpClient client) + { + requireNonNull(client, "client is null"); + this.client = () -> client; + } + + public HttpTokenPoller(OkHttpClient client, Consumer refreshableClientConfig) + { + requireNonNull(client, "client is null"); + requireNonNull(refreshableClientConfig, "refreshableClientConfig is null"); + + this.client = () -> { + OkHttpClient.Builder builder = client.newBuilder(); + refreshableClientConfig.accept(builder); + return builder.build(); + }; + } + + @Override + public TokenPollResult pollForToken(URI tokenUri, Duration timeout) + { + try { + return Failsafe.with(new RetryPolicy() + .withMaxAttempts(-1) + .withMaxDuration(timeout) + .withBackoff(100, 500, MILLIS) + .handle(IOException.class)) + .get(() -> executePoll(prepareRequestBuilder(tokenUri).build())); + } + catch (FailsafeException e) { + if (e.getCause() instanceof IOException) { + throw new UncheckedIOException((IOException) e.getCause()); + } + throw e; + } + } + + @Override + public void tokenReceived(URI tokenUri) + { + try { + Failsafe.with(new RetryPolicy() + .withMaxAttempts(-1) + .withMaxDuration(Duration.ofSeconds(4)) + .withBackoff(100, 500, MILLIS) + .handleResultIf(code -> code >= HTTP_INTERNAL_ERROR)) + .get(() -> { + Request request = prepareRequestBuilder(tokenUri) + .delete() + .build(); + try (Response response = client.get().newCall(request) + .execute()) { + return response.code(); + } + }); + } + catch (FailsafeException e) { + if (e.getCause() instanceof IOException) { + throw new UncheckedIOException((IOException) e.getCause()); + } + throw e; + } + } + + private static Request.Builder prepareRequestBuilder(URI tokenUri) + { + HttpUrl url = HttpUrl.get(tokenUri); + if (url == null) { + throw new IllegalArgumentException("Invalid token URI: " + tokenUri); + } + + return new Request.Builder() + .url(url) + 
.addHeader(USER_AGENT, USER_AGENT_VALUE); + } + + private TokenPollResult executePoll(Request request) + throws IOException + { + JsonResponse response = executeRequest(request); + + if ((response.getStatusCode() == HTTP_OK) && response.hasValue()) { + return response.getValue().toResult(); + } + + Optional responseBody = Optional.ofNullable(response.getResponseBody()); + String message = format("Request to %s failed: %s [Error: %s]", request.url(), response, responseBody.orElse("")); + + if (response.getStatusCode() == HTTP_UNAVAILABLE) { + throw new IOException(message); + } + + return TokenPollResult.failed(message); + } + + private JsonResponse executeRequest(Request request) + throws IOException + { + try { + return execute(TOKEN_POLL_CODEC, client.get(), request); + } + catch (UncheckedIOException e) { + throw e.getCause(); + } + } + + public static class TokenPollRepresentation + { + private final String token; + private final URI nextUri; + private final String error; + + @JsonCreator + public TokenPollRepresentation( + @JsonProperty("token") String token, + @JsonProperty("nextUri") URI nextUri, + @JsonProperty("error") String error) + { + this.token = token; + this.nextUri = nextUri; + this.error = error; + } + + TokenPollResult toResult() + { + if (token != null) { + return TokenPollResult.successful(new Token(token)); + } + if (error != null) { + return TokenPollResult.failed(error); + } + if (nextUri != null) { + return TokenPollResult.pending(nextUri); + } + return TokenPollResult.failed("Failed to poll for token. 
No fields set in response."); + } + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/KnownToken.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/KnownToken.java new file mode 100644 index 0000000000000..ae57de609c02c --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/KnownToken.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import java.util.Optional; +import java.util.function.Supplier; + +public interface KnownToken +{ + Optional getToken(); + + void setupToken(Supplier> tokenSource); + + static KnownToken local() + { + return new LocalKnownToken(); + } + + static KnownToken memoryCached() + { + return MemoryCachedKnownToken.INSTANCE; + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/LocalKnownToken.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/LocalKnownToken.java new file mode 100644 index 0000000000000..31ff4bd89d525 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/LocalKnownToken.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import javax.annotation.concurrent.NotThreadSafe; + +import java.util.Optional; +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; + +/** + * LocalKnownToken class keeps the token on its field + * and it's designed to use it in fully serialized manner. + */ +@NotThreadSafe +class LocalKnownToken + implements KnownToken +{ + private Optional knownToken = Optional.empty(); + + @Override + public Optional getToken() + { + return knownToken; + } + + @Override + public void setupToken(Supplier> tokenSource) + { + requireNonNull(tokenSource, "tokenSource is null"); + + knownToken = tokenSource.get(); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/MemoryCachedKnownToken.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/MemoryCachedKnownToken.java new file mode 100644 index 0000000000000..0b1253a6deaa6 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/MemoryCachedKnownToken.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import javax.annotation.concurrent.ThreadSafe; + +import java.util.Optional; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Supplier; + +/** + * This KnownToken instance forces all Connections to reuse same token. + * Every time an existing token is considered to be invalid each Connection + * will try to obtain a new token, but only the first one will actually do the job, + * where every other connection will be waiting on readLock + * until obtaining new token finishes. + *

+ * In general the game is to reuse same token and obtain it only once, no matter how + * many Connections will be actively using it. It's very important as obtaining the new token + * will take minutes, as it mostly requires user thinking time. + */ +@ThreadSafe +public class MemoryCachedKnownToken + implements KnownToken +{ + public static final MemoryCachedKnownToken INSTANCE = new MemoryCachedKnownToken(); + + private final ReadWriteLock lock = new ReentrantReadWriteLock(); + private final Lock readLock = lock.readLock(); + private final Lock writeLock = lock.writeLock(); + private Optional knownToken = Optional.empty(); + + private MemoryCachedKnownToken() + { + } + + @Override + public Optional getToken() + { + try { + readLock.lockInterruptibly(); + return knownToken; + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + finally { + readLock.unlock(); + } + } + + @Override + public void setupToken(Supplier> tokenSource) + { + // Try to lock and generate new token. If some other thread (Connection) has + // already obtained writeLock and is generating new token, then skip this + // to block on getToken() + if (writeLock.tryLock()) { + try { + // Clear knownToken before obtaining new token, as it might fail leaving old invalid token. + knownToken = Optional.empty(); + knownToken = tokenSource.get(); + } + finally { + writeLock.unlock(); + } + } + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/RedirectException.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/RedirectException.java new file mode 100644 index 0000000000000..0c5a1f5e149fa --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/RedirectException.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +public class RedirectException + extends RuntimeException +{ + public RedirectException(String message, Throwable throwable) + { + super(message, throwable); + } + + public RedirectException(String message) + { + super(message); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/CoordinatorLocation.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/RedirectHandler.java similarity index 77% rename from presto-main-base/src/main/java/com/facebook/presto/dispatcher/CoordinatorLocation.java rename to presto-client/src/main/java/com/facebook/presto/client/auth/external/RedirectHandler.java index 83c606210ce46..710be10238884 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/CoordinatorLocation.java +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/RedirectHandler.java @@ -11,13 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.dispatcher; - -import javax.ws.rs.core.UriInfo; +package com.facebook.presto.client.auth.external; import java.net.URI; -public interface CoordinatorLocation +public interface RedirectHandler { - URI getUri(UriInfo uriInfo, String xForwardedProto); + void redirectTo(URI uri) + throws RedirectException; } diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/SystemOpenRedirectHandler.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/SystemOpenRedirectHandler.java new file mode 100644 index 0000000000000..d63841b26d830 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/SystemOpenRedirectHandler.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.client.auth.external; + +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.net.URI; +import java.nio.file.Paths; +import java.util.List; +import java.util.Optional; + +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; + +public class SystemOpenRedirectHandler + implements RedirectHandler +{ + private static final List LINUX_BROWSERS = ImmutableList.of( + "xdg-open", + "gnome-open", + "kde-open", + "chromium", + "google", + "google-chrome", + "firefox", + "mozilla", + "opera", + "epiphany", + "konqueror"); + + private static final String MACOS_OPEN_COMMAND = "open"; + private static final String WINDOWS_OPEN_COMMAND = "rundll32 url.dll,FileProtocolHandler"; + + private static final Splitter SPLITTER = Splitter.on(":") + .omitEmptyStrings() + .trimResults(); + + @Override + public void redirectTo(URI uri) + throws RedirectException + { + String operatingSystem = System.getProperty("os.name").toLowerCase(ENGLISH); + + try { + if (operatingSystem.contains("mac")) { + exec(uri, MACOS_OPEN_COMMAND); + } + else if (operatingSystem.contains("windows")) { + exec(uri, WINDOWS_OPEN_COMMAND); + } + else { + String executablePath = findLinuxBrowser() + .orElseThrow(() -> new FileNotFoundException("Could not find any known linux browser in $PATH")); + exec(uri, executablePath); + } + } + catch (IOException e) { + throw new RedirectException(format("Could not open uri %s", uri), e); + } + } + + private static Optional findLinuxBrowser() + { + List paths = SPLITTER.splitToList(System.getenv("PATH")); + for (String path : paths) { + File[] found = Paths.get(path) + .toFile() + .listFiles((dir, name) -> LINUX_BROWSERS.contains(name)); + + if (found == null) { + continue; + } + + if (found.length > 0) { + return Optional.of(found[0].getPath()); + } + } + + return 
Optional.empty(); + } + + private static void exec(URI uri, String openCommand) + throws IOException + { + Runtime.getRuntime().exec(openCommand + " " + uri.toString()); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/SystemOutPrintRedirectHandler.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/SystemOutPrintRedirectHandler.java new file mode 100644 index 0000000000000..d0f3c7954dbd7 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/SystemOutPrintRedirectHandler.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import java.net.URI; + +public class SystemOutPrintRedirectHandler + implements RedirectHandler +{ + @Override + public void redirectTo(URI uri) throws RedirectException + { + System.out.println("External authentication required. 
Please go to:"); + System.out.println(uri.toString()); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/Token.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/Token.java new file mode 100644 index 0000000000000..82890f3f6b2df --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/Token.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import static java.util.Objects.requireNonNull; + +class Token +{ + private final String token; + + public Token(String token) + { + this.token = requireNonNull(token, "token is null"); + } + + public String token() + { + return token; + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/TokenPollResult.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/TokenPollResult.java new file mode 100644 index 0000000000000..3cd13b5dd3710 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/TokenPollResult.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import java.net.URI; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class TokenPollResult +{ + private enum State + { + PENDING, SUCCESSFUL, FAILED + } + + private final State state; + private final Optional errorMessage; + private final Optional nextTokenUri; + private final Optional token; + + public static TokenPollResult failed(String error) + { + return new TokenPollResult(State.FAILED, null, error, null); + } + + public static TokenPollResult pending(URI uri) + { + return new TokenPollResult(State.PENDING, null, null, uri); + } + + public static TokenPollResult successful(Token token) + { + return new TokenPollResult(State.SUCCESSFUL, token, null, null); + } + + private TokenPollResult(State state, Token token, String error, URI nextTokenUri) + { + this.state = requireNonNull(state, "state is null"); + this.token = Optional.ofNullable(token); + this.errorMessage = Optional.ofNullable(error); + this.nextTokenUri = Optional.ofNullable(nextTokenUri); + } + + public boolean isFailed() + { + return state == State.FAILED; + } + + public boolean isPending() + { + return state == State.PENDING; + } + + public Token getToken() + { + return token.orElseThrow(() -> new IllegalStateException("state is " + state)); + } + + public String getError() + { + return errorMessage.orElseThrow(() -> new IllegalStateException("state is " + state)); + } + + public URI getNextTokenUri() + { + return nextTokenUri.orElseThrow(() -> new IllegalStateException("state is " 
+ state)); + } +} diff --git a/presto-client/src/main/java/com/facebook/presto/client/auth/external/TokenPoller.java b/presto-client/src/main/java/com/facebook/presto/client/auth/external/TokenPoller.java new file mode 100644 index 0000000000000..ea228137ab688 --- /dev/null +++ b/presto-client/src/main/java/com/facebook/presto/client/auth/external/TokenPoller.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import java.net.URI; +import java.time.Duration; + +public interface TokenPoller +{ + TokenPollResult pollForToken(URI tokenUri, Duration timeout); + + void tokenReceived(URI tokenUri); +} diff --git a/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java b/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java index 8573ddd64d3c1..ba9ab0df881e1 100644 --- a/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java +++ b/presto-client/src/test/java/com/facebook/presto/client/TestServerInfo.java @@ -14,6 +14,7 @@ package com.facebook.presto.client; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.Duration; import com.facebook.drift.codec.ThriftCodec; import com.facebook.drift.codec.ThriftCodecManager; import com.facebook.drift.codec.internal.compiler.CompilerThriftCodecFactory; @@ -26,7 +27,6 @@ import com.facebook.drift.protocol.TMemoryBuffer; import com.facebook.drift.protocol.TProtocol; 
import com.facebook.drift.protocol.TTransport; -import io.airlift.units.Duration; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; diff --git a/presto-client/src/test/java/com/facebook/presto/client/auth/external/MockRedirectHandler.java b/presto-client/src/test/java/com/facebook/presto/client/auth/external/MockRedirectHandler.java new file mode 100644 index 0000000000000..f4b8c6cd07633 --- /dev/null +++ b/presto-client/src/test/java/com/facebook/presto/client/auth/external/MockRedirectHandler.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.client.auth.external; + +import java.net.URI; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; + +public class MockRedirectHandler + implements RedirectHandler +{ + private URI redirectedTo; + private AtomicInteger redirectionCount = new AtomicInteger(0); + private Duration redirectTime; + + @Override + public void redirectTo(URI uri) + throws RedirectException + { + redirectedTo = uri; + redirectionCount.incrementAndGet(); + try { + if (redirectTime != null) { + Thread.sleep(redirectTime.toMillis()); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + public URI redirectedTo() + { + return redirectedTo; + } + + public int getRedirectionCount() + { + return redirectionCount.get(); + } + + public MockRedirectHandler sleepOnRedirect(Duration redirectTime) + { + this.redirectTime = redirectTime; + return this; + } +} diff --git a/presto-client/src/test/java/com/facebook/presto/client/auth/external/MockTokenPoller.java b/presto-client/src/test/java/com/facebook/presto/client/auth/external/MockTokenPoller.java new file mode 100644 index 0000000000000..6733a76ebb64d --- /dev/null +++ b/presto-client/src/test/java/com/facebook/presto/client/auth/external/MockTokenPoller.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.client.auth.external; + +import com.google.common.collect.ImmutableList; + +import java.net.URI; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.function.Function; + +public final class MockTokenPoller + implements TokenPoller +{ + private final Map> results = new ConcurrentHashMap<>(); + private URI tokenReceivedUri; + + public static TokenPoller onPoll(Function pollingStrategy) + { + return new TokenPoller() + { + @Override + public TokenPollResult pollForToken(URI tokenUri, Duration timeout) + { + return pollingStrategy.apply(tokenUri); + } + + @Override + public void tokenReceived(URI tokenUri) + { + } + }; + } + + public MockTokenPoller withResult(URI tokenUri, TokenPollResult result) + { + results.compute(tokenUri, (uri, queue) -> { + if (queue == null) { + return new LinkedBlockingDeque<>(ImmutableList.of(result)); + } + queue.add(result); + return queue; + }); + return this; + } + + @Override + public TokenPollResult pollForToken(URI tokenUri, Duration ignored) + { + BlockingDeque queue = results.get(tokenUri); + if (queue == null) { + throw new IllegalArgumentException("Unknown token URI: " + tokenUri); + } + return queue.remove(); + } + + @Override + public void tokenReceived(URI tokenUri) + { + this.tokenReceivedUri = tokenUri; + } + + public URI tokenReceivedUri() + { + return tokenReceivedUri; + } +} diff --git a/presto-client/src/test/java/com/facebook/presto/client/auth/external/TestExternalAuthentication.java b/presto-client/src/test/java/com/facebook/presto/client/auth/external/TestExternalAuthentication.java new file mode 100644 index 0000000000000..e7434c5ac242e --- /dev/null +++ b/presto-client/src/test/java/com/facebook/presto/client/auth/external/TestExternalAuthentication.java @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.client.auth.external; + +import com.facebook.presto.client.ClientException; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; +import static java.net.URI.create; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestExternalAuthentication +{ + private static final String AUTH_TOKEN = "authToken"; + private static final URI REDIRECT_URI = create("https://redirect.uri"); + private static final URI TOKEN_URI = create("https://token.uri"); + private static final Duration TIMEOUT = Duration.ofSeconds(1); + + @Test + public void testObtainTokenWhenTokenAlreadyExists() + { + MockRedirectHandler redirectHandler = new MockRedirectHandler(); + + MockTokenPoller poller = new MockTokenPoller() + .withResult(TOKEN_URI, TokenPollResult.successful(new Token(AUTH_TOKEN))); + + Optional token = new ExternalAuthentication(TOKEN_URI, Optional.of(REDIRECT_URI)) + .obtainToken(TIMEOUT, redirectHandler, poller); + + assertThat(redirectHandler.redirectedTo()).isEqualTo(REDIRECT_URI); + assertThat(token).map(Token::token).hasValue(AUTH_TOKEN); + assertThat(poller.tokenReceivedUri()).isEqualTo(TOKEN_URI); + 
} + + @Test + public void testObtainTokenWhenTokenIsReadyAtSecondAttempt() + { + RedirectHandler redirectHandler = new MockRedirectHandler(); + + URI nextTokenUri = TOKEN_URI.resolve("/next"); + MockTokenPoller poller = new MockTokenPoller() + .withResult(TOKEN_URI, TokenPollResult.pending(nextTokenUri)) + .withResult(nextTokenUri, TokenPollResult.successful(new Token(AUTH_TOKEN))); + + Optional token = new ExternalAuthentication(TOKEN_URI, Optional.of(REDIRECT_URI)) + .obtainToken(TIMEOUT, redirectHandler, poller); + + assertThat(token).map(Token::token).hasValue(AUTH_TOKEN); + assertThat(poller.tokenReceivedUri()).isEqualTo(nextTokenUri); + } + + @Test + public void testObtainTokenWhenTokenIsNeverAvailable() + { + RedirectHandler redirectHandler = new MockRedirectHandler(); + + TokenPoller poller = MockTokenPoller.onPoll(tokenUri -> { + sleepUninterruptibly(20, TimeUnit.MILLISECONDS); + return TokenPollResult.pending(TOKEN_URI); + }); + + Optional token = new ExternalAuthentication(TOKEN_URI, Optional.of(REDIRECT_URI)) + .obtainToken(TIMEOUT, redirectHandler, poller); + + assertThat(token).isEmpty(); + } + + @Test + public void testObtainTokenWhenPollingFails() + { + RedirectHandler redirectHandler = new MockRedirectHandler(); + + TokenPoller poller = new MockTokenPoller() + .withResult(TOKEN_URI, TokenPollResult.failed("error")); + + assertThatThrownBy(() -> new ExternalAuthentication(TOKEN_URI, Optional.of(REDIRECT_URI)) + .obtainToken(TIMEOUT, redirectHandler, poller)) + .isInstanceOf(ClientException.class) + .hasMessage("error"); + } + + @Test + public void testObtainTokenWhenPollingFailsWithException() + { + RedirectHandler redirectHandler = new MockRedirectHandler(); + + TokenPoller poller = MockTokenPoller.onPoll(uri -> { + throw new UncheckedIOException(new IOException("polling error")); + }); + + assertThatThrownBy(() -> new ExternalAuthentication(TOKEN_URI, Optional.of(REDIRECT_URI)) + .obtainToken(TIMEOUT, redirectHandler, poller)) + 
.isInstanceOf(UncheckedIOException.class) + .hasRootCauseInstanceOf(IOException.class); + } + + @Test + public void testObtainTokenWhenNoRedirectUriHasBeenProvided() + { + MockRedirectHandler redirectHandler = new MockRedirectHandler(); + + TokenPoller poller = new MockTokenPoller() + .withResult(TOKEN_URI, TokenPollResult.successful(new Token(AUTH_TOKEN))); + + Optional token = new ExternalAuthentication(TOKEN_URI, Optional.empty()) + .obtainToken(TIMEOUT, redirectHandler, poller); + + assertThat(redirectHandler.redirectedTo()).isNull(); + assertThat(token).map(Token::token).hasValue(AUTH_TOKEN); + } +} diff --git a/presto-client/src/test/java/com/facebook/presto/client/auth/external/TestExternalAuthenticator.java b/presto-client/src/test/java/com/facebook/presto/client/auth/external/TestExternalAuthenticator.java new file mode 100644 index 0000000000000..86cd89fc6dfa9 --- /dev/null +++ b/presto-client/src/test/java/com/facebook/presto/client/auth/external/TestExternalAuthenticator.java @@ -0,0 +1,370 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.client.auth.external; + +import com.facebook.presto.client.ClientException; +import com.google.common.collect.ImmutableList; +import okhttp3.HttpUrl; +import okhttp3.Protocol; +import okhttp3.Request; +import okhttp3.Response; +import org.assertj.core.api.ListAssert; +import org.assertj.core.api.ThrowableAssert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.Test; + +import java.net.URI; +import java.net.URISyntaxException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.Callable; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.stream.Stream; + +import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.presto.client.auth.external.ExternalAuthenticator.TOKEN_URI_FIELD; +import static com.facebook.presto.client.auth.external.ExternalAuthenticator.toAuthentication; +import static com.facebook.presto.client.auth.external.MockTokenPoller.onPoll; +import static com.facebook.presto.client.auth.external.TokenPollResult.successful; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.net.HttpHeaders.AUTHORIZATION; +import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE; +import static java.lang.String.format; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; +import static java.net.URI.create; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@Test(singleThreaded = true) +public class TestExternalAuthenticator +{ + private static final ExecutorService 
executor = newCachedThreadPool(daemonThreadsNamed(TestExternalAuthenticator.class.getName() + "-%d")); + + @AfterClass(alwaysRun = true) + public void shutDownThreadPool() + { + executor.shutdownNow(); + } + + @Test + public void testChallengeWithOnlyTokenServerUri() + { + assertThat(buildAuthentication("Bearer x_token_server=\"http://token.uri\"")) + .hasValueSatisfying(authentication -> { + assertThat(authentication.getRedirectUri()).isEmpty(); + assertThat(authentication.getTokenUri()).isEqualTo(create("http://token.uri")); + }); + } + + @Test + public void testChallengeWithBothUri() + { + assertThat(buildAuthentication("Bearer x_redirect_server=\"http://redirect.uri\", x_token_server=\"http://token.uri\"")) + .hasValueSatisfying(authentication -> { + assertThat(authentication.getRedirectUri()).hasValue(create("http://redirect.uri")); + assertThat(authentication.getTokenUri()).isEqualTo(create("http://token.uri")); + }); + } + + @Test + public void testChallengeWithValuesWithoutQuotes() + { + // this is legal according to RFC 7235 + assertThat(buildAuthentication("Bearer x_redirect_server=http://redirect.uri, x_token_server=http://token.uri")) + .hasValueSatisfying(authentication -> { + assertThat(authentication.getRedirectUri()).hasValue(create("http://redirect.uri")); + assertThat(authentication.getTokenUri()).isEqualTo(create("http://token.uri")); + }); + } + + @Test + public void testChallengeWithAdditionalFields() + { + assertThat(buildAuthentication("Bearer type=\"token\", x_redirect_server=\"http://redirect.uri\", x_token_server=\"http://token.uri\", description=\"oauth challenge\"")) + .hasValueSatisfying(authentication -> { + assertThat(authentication.getRedirectUri()).hasValue(create("http://redirect.uri")); + assertThat(authentication.getTokenUri()).isEqualTo(create("http://token.uri")); + }); + } + + @Test + public void testInvalidChallenges() + { + // no authentication parameters + assertThat(buildAuthentication("Bearer")).isEmpty(); + + // no 
Bearer scheme prefix + assertThat(buildAuthentication("x_redirect_server=\"http://redirect.uri\", x_token_server=\"http://token.uri\"")).isEmpty(); + + // space instead of comma + assertThat(buildAuthentication("Bearer x_redirect_server=\"http://redirect.uri\" x_token_server=\"http://token.uri\"")).isEmpty(); + + // equals sign instead of comma + assertThat(buildAuthentication("Bearer x_redirect_server=\"http://redirect.uri\"=x_token_server=\"http://token.uri\"")).isEmpty(); + } + + @Test + public void testChallengeWithMalformedUri() + { + assertThatThrownBy(() -> buildAuthentication("Bearer x_token_server=\"http://[1.1.1.1]\"")) + .isInstanceOf(ClientException.class) + .hasMessageContaining(format("Failed to parse URI for field '%s'", TOKEN_URI_FIELD)) + .hasRootCauseInstanceOf(URISyntaxException.class); + } + + @Test + public void testAuthentication() + { + MockTokenPoller tokenPoller = new MockTokenPoller() + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token"))); + ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, KnownToken.local(), Duration.ofSeconds(1)); + + Request authenticated = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\"")); + + assertThat(authenticated.headers(AUTHORIZATION)) + .containsExactly("Bearer valid-token"); + } + + @Test + public void testReAuthenticationAfterRejectingToken() + { + MockTokenPoller tokenPoller = new MockTokenPoller() + .withResult(URI.create("http://token.uri"), successful(new Token("first-token"))) + .withResult(URI.create("http://token.uri"), successful(new Token("second-token"))); + ExternalAuthenticator authenticator = new ExternalAuthenticator(uri -> {}, tokenPoller, KnownToken.local(), Duration.ofSeconds(1)); + + Request request = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\"")); + Request reAuthenticated = authenticator.authenticate(null, 
getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\"", request)); + + assertThat(reAuthenticated.headers(AUTHORIZATION)) + .containsExactly("Bearer second-token"); + } + + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithLocallyStoredToken() + { + MockTokenPoller tokenPoller = new MockTokenPoller() + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-1"))) + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-2"))) + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-3"))) + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token-4"))); + MockRedirectHandler redirectHandler = new MockRedirectHandler(); + + List> requests = times( + 4, + () -> new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.local(), Duration.ofSeconds(1)) + .authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))) + .map(executor::submit) + .collect(toImmutableList()); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests); + assertion.requests() + .extracting(Request::headers) + .extracting(headers -> headers.get(AUTHORIZATION)) + .contains("Bearer valid-token-1", "Bearer valid-token-2", "Bearer valid-token-3", "Bearer valid-token-4"); + assertion.assertThatNoExceptionsHasBeenThrown(); + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(4); + } + + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithCachedToken() + { + MockTokenPoller tokenPoller = new MockTokenPoller() + .withResult(URI.create("http://token.uri"), successful(new Token("valid-token"))); + MockRedirectHandler redirectHandler = new MockRedirectHandler() + .sleepOnRedirect(Duration.ofSeconds(1)); + + List> requests = times( + 2, + () -> new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.memoryCached(), 
Duration.ofSeconds(1)) + .authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))) + .map(executor::submit) + .collect(toImmutableList()); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests); + assertion.requests() + .extracting(Request::headers) + .extracting(headers -> headers.get(AUTHORIZATION)) + .containsOnly("Bearer valid-token"); + assertion.assertThatNoExceptionsHasBeenThrown(); + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1); + } + + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateFails() + { + MockTokenPoller tokenPoller = new MockTokenPoller() + .withResult(URI.create("http://token.uri"), TokenPollResult.successful(new Token("first-token"))) + .withResult(URI.create("http://token.uri"), TokenPollResult.failed("external authentication error")); + MockRedirectHandler redirectHandler = new MockRedirectHandler() + .sleepOnRedirect(Duration.ofMillis(500)); + + ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.memoryCached(), Duration.ofSeconds(1)); + Request firstRequest = authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"")); + + List> requests = times( + 4, + () -> new ExternalAuthenticator(redirectHandler, tokenPoller, KnownToken.memoryCached(), Duration.ofSeconds(1)) + .authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\"", firstRequest))) + .map(executor::submit) + .collect(toImmutableList()); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests); + assertion.firstException().hasMessage("external authentication error") + .isInstanceOf(ClientException.class); + + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(2); 
+ } + + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateTimesOut() + { + MockRedirectHandler redirectHandler = new MockRedirectHandler() + .sleepOnRedirect(Duration.ofSeconds(1)); + + List> requests = times( + 2, + () -> new ExternalAuthenticator(redirectHandler, onPoll(TokenPollResult::pending), KnownToken.memoryCached(), Duration.ofMillis(1)) + .authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))) + .map(executor::submit) + .collect(toImmutableList()); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(requests); + assertion.requests() + .containsExactly(null, null); + assertion.assertThatNoExceptionsHasBeenThrown(); + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1); + } + + @Test(timeOut = 2000) + public void testAuthenticationFromMultipleThreadsWithCachedTokenAfterAuthenticateIsInterrupted() + throws Exception + { + ExecutorService interruptableThreadPool = newCachedThreadPool(daemonThreadsNamed(this.getClass().getName() + "-interruptable-%d")); + MockRedirectHandler redirectHandler = new MockRedirectHandler() + .sleepOnRedirect(Duration.ofMinutes(1)); + + ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, onPoll(TokenPollResult::pending), KnownToken.memoryCached(), Duration.ofMillis(1)); + Future interruptedAuthentication = interruptableThreadPool.submit( + () -> authenticator.authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))); + Thread.sleep(100); //It's here to make sure that authentication will start before the other threads. 
+ List> requests = times( + 2, + () -> new ExternalAuthenticator(redirectHandler, onPoll(TokenPollResult::pending), KnownToken.memoryCached(), Duration.ofMillis(1)) + .authenticate(null, getUnauthorizedResponse("Bearer x_token_server=\"http://token.uri\", x_redirect_server=\"http://redirect.uri\""))) + .map(executor::submit) + .collect(toImmutableList()); + + Thread.sleep(100); + interruptableThreadPool.shutdownNow(); + + ConcurrentRequestAssertion assertion = new ConcurrentRequestAssertion(ImmutableList.>builder() + .addAll(requests) + .add(interruptedAuthentication) + .build()); + assertion.requests().containsExactly(null, null); + assertion.firstException().hasRootCauseInstanceOf(InterruptedException.class); + + assertThat(redirectHandler.getRedirectionCount()).isEqualTo(1); + } + + private static Stream> times(int times, Callable request) + { + return Stream.generate(() -> request) + .limit(times); + } + + private static Optional buildAuthentication(String challengeHeader) + { + return toAuthentication(getUnauthorizedResponse(challengeHeader)); + } + + private static Response getUnauthorizedResponse(String challengeHeader) + { + return getUnauthorizedResponse(challengeHeader, + new Request.Builder() + .url(HttpUrl.get("http://example.com")) + .build()); + } + + private static Response getUnauthorizedResponse(String challengeHeader, Request request) + { + return new Response.Builder() + .request(request) + .protocol(Protocol.HTTP_1_1) + .code(HTTP_UNAUTHORIZED) + .message("Unauthorized") + .header(WWW_AUTHENTICATE, challengeHeader) + .build(); + } + + static class ConcurrentRequestAssertion + { + private final List exceptions = new ArrayList<>(); + private final List requests = new ArrayList<>(); + + public ConcurrentRequestAssertion(List> requests) + { + for (Future request : requests) { + try { + this.requests.add(request.get()); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + catch 
(CancellationException ex) { + exceptions.add(ex); + } + catch (ExecutionException ex) { + checkState(ex.getCause() != null, "Missing cause on ExecutionException " + ex.getMessage()); + + exceptions.add(ex.getCause()); + } + } + } + + ThrowableAssert firstException() + { + return exceptions.stream() + .findFirst() + .map(ThrowableAssert::new) + .orElseGet(() -> new ThrowableAssert(() -> null)); + } + + void assertThatNoExceptionsHasBeenThrown() + { + if (!exceptions.isEmpty()) { + Throwable firstException = exceptions.get(0); + AssertionError assertionError = new AssertionError("Expected no exceptions, but some exceptions has been thrown", firstException); + for (int i = 1; i < exceptions.size(); i++) { + assertionError.addSuppressed(exceptions.get(i)); + } + throw assertionError; + } + } + + ListAssert requests() + { + return assertThat(requests); + } + } +} diff --git a/presto-client/src/test/java/com/facebook/presto/client/auth/external/TestHttpTokenPoller.java b/presto-client/src/test/java/com/facebook/presto/client/auth/external/TestHttpTokenPoller.java new file mode 100644 index 0000000000000..69f110d4ec917 --- /dev/null +++ b/presto-client/src/test/java/com/facebook/presto/client/auth/external/TestHttpTokenPoller.java @@ -0,0 +1,224 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.client.auth.external; + +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.time.Duration; + +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static java.lang.String.format; +import static java.net.HttpURLConnection.HTTP_GONE; +import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_UNAVAILABLE; +import static java.net.URI.create; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@Test(singleThreaded = true) +public class TestHttpTokenPoller +{ + private static final String TOKEN_PATH = "/v1/authentications/sso/test/token"; + private static final Duration ONE_SECOND = Duration.ofSeconds(1); + + private TokenPoller tokenPoller; + private MockWebServer server; + + @BeforeMethod(alwaysRun = true) + public void setup() + throws Exception + { + server = new MockWebServer(); + server.start(); + + tokenPoller = new HttpTokenPoller(new OkHttpClient.Builder() + .callTimeout(Duration.ofMillis(500)) + .build()); + } + + @AfterMethod(alwaysRun = true) + public void teardown() + throws IOException + { + server.close(); + server = null; + } + + @Test + public void testTokenReady() + { + server.enqueue(statusAndBody(HTTP_OK, jsonPair("token", "token"))); + + TokenPollResult result = tokenPoller.pollForToken(tokenUri(), ONE_SECOND); + + assertThat(result.getToken().token()).isEqualTo("token"); + 
assertThat(server.getRequestCount()).isEqualTo(1); + } + + @Test + public void testTokenNotReady() + { + server.enqueue(statusAndBody(HTTP_OK, jsonPair("nextUri", tokenUri()))); + + TokenPollResult result = tokenPoller.pollForToken(tokenUri(), ONE_SECOND); + + assertThat(result.isPending()).isTrue(); + assertThat(server.getRequestCount()).isEqualTo(1); + } + + @Test + public void testErrorResponse() + { + server.enqueue(statusAndBody(HTTP_OK, jsonPair("error", "test failure"))); + + TokenPollResult result = tokenPoller.pollForToken(tokenUri(), ONE_SECOND); + + assertThat(result.isFailed()).isTrue(); + assertThat(result.getError()).contains("test failure"); + } + + @Test + public void testBadHttpStatus() + { + server.enqueue(new MockResponse().setResponseCode(HTTP_GONE)); + + TokenPollResult result = tokenPoller.pollForToken(tokenUri(), ONE_SECOND); + + assertThat(result.isFailed()).isTrue(); + assertThat(result.getError()) + .matches("Request to http://.* failed: JsonResponse\\{statusCode=410, .*"); + } + + @Test + public void testInvalidJsonBody() + { + server.enqueue(statusAndBody(HTTP_OK, jsonPair("foo", "bar"))); + + TokenPollResult result = tokenPoller.pollForToken(tokenUri(), ONE_SECOND); + + assertThat(result.isFailed()).isTrue(); + assertThat(result.getError()) + .isEqualTo("Failed to poll for token. 
No fields set in response."); + } + + @Test + public void testInvalidNextUri() + { + server.enqueue(statusAndBody(HTTP_OK, jsonPair("nextUri", ":::"))); + + TokenPollResult result = tokenPoller.pollForToken(tokenUri(), ONE_SECOND); + + assertThat(result.isFailed()).isTrue(); + assertThat(result.getError()) + .matches("Request to http://.* failed: JsonResponse\\{statusCode=200, .*, hasValue=false} .*"); + } + + @Test + public void testHttpStatus503() + { + for (int i = 1; i <= 100; i++) { + server.enqueue(statusAndBody(HTTP_UNAVAILABLE, "Server failure #" + i)); + } + + assertThatThrownBy(() -> tokenPoller.pollForToken(tokenUri(), ONE_SECOND)) + .isInstanceOf(UncheckedIOException.class) + .hasRootCauseExactlyInstanceOf(IOException.class); + + assertThat(server.getRequestCount()).isGreaterThan(1); + } + + @Test + public void testHttpTimeout() + { + // force request to timeout by not enqueuing response + + assertThatThrownBy(() -> tokenPoller.pollForToken(tokenUri(), ONE_SECOND)) + .isInstanceOf(UncheckedIOException.class) + .hasMessageEndingWith(": timeout"); + } + + @Test + public void testTokenReceived() + throws InterruptedException + { + server.enqueue(status(HTTP_OK)); + + tokenPoller.tokenReceived(tokenUri()); + + RecordedRequest request = server.takeRequest(1, MILLISECONDS); + assertThat(request.getMethod()).isEqualTo("DELETE"); + assertThat(request.getRequestUrl()).isEqualTo(HttpUrl.get(tokenUri())); + } + + @Test + public void testTokenReceivedRetriesUntilNotErrorReturned() + { + server.enqueue(status(HTTP_UNAVAILABLE)); + server.enqueue(status(HTTP_UNAVAILABLE)); + server.enqueue(status(HTTP_UNAVAILABLE)); + server.enqueue(status(202)); + + tokenPoller.tokenReceived(tokenUri()); + + assertThat(server.getRequestCount()).isEqualTo(4); + } + + @Test + public void testTokenReceivedDoesNotRetriesIndefinitely() + { + for (int i = 1; i <= 100; i++) { + server.enqueue(status(HTTP_UNAVAILABLE)); + } + + tokenPoller.tokenReceived(tokenUri()); + + 
assertThat(server.getRequestCount()).isLessThan(100); + } + + private URI tokenUri() + { + return create("http://" + server.getHostName() + ":" + server.getPort() + TOKEN_PATH); + } + + private static String jsonPair(String key, Object value) + { + return format("{\"%s\": \"%s\"}", key, value); + } + + private static MockResponse statusAndBody(int status, String body) + { + return new MockResponse() + .setResponseCode(status) + .addHeader(CONTENT_TYPE, JSON_UTF_8) + .setBody(body); + } + + private static MockResponse status(int status) + { + return new MockResponse() + .setResponseCode(status); + } +} diff --git a/presto-clp/pom.xml b/presto-clp/pom.xml index d13b59968aff5..f15165d13ffec 100644 --- a/presto-clp/pom.xml +++ b/presto-clp/pom.xml @@ -6,7 +6,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-clp @@ -88,7 +88,7 @@ - io.airlift + com.facebook.airlift units provided diff --git a/presto-clp/src/main/java/com/facebook/presto/plugin/clp/ClpMetadata.java b/presto-clp/src/main/java/com/facebook/presto/plugin/clp/ClpMetadata.java index 1f9962a3456d9..1dac82c4d29e0 100644 --- a/presto-clp/src/main/java/com/facebook/presto/plugin/clp/ClpMetadata.java +++ b/presto-clp/src/main/java/com/facebook/presto/plugin/clp/ClpMetadata.java @@ -203,4 +203,14 @@ private List listTables(String schemaName) { return tableHandleCache.getUnchecked(schemaName); } + + // Presto 0.297 added normalizeIdentifier() to MetadataManager.getColumnHandles(), which + // lowercases column names by default. CLP schema field names are case-sensitive, so we + // override this to return the identifier unchanged and prevent the case mismatch between + // getColumnHandles() and getTableMetadata().getColumns(). 
+ @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return identifier; + } } diff --git a/presto-clp/src/main/java/com/facebook/presto/plugin/clp/split/filter/ClpSplitFilterProvider.java b/presto-clp/src/main/java/com/facebook/presto/plugin/clp/split/filter/ClpSplitFilterProvider.java index 0609843aaf22f..6d6a1e907625c 100644 --- a/presto-clp/src/main/java/com/facebook/presto/plugin/clp/split/filter/ClpSplitFilterProvider.java +++ b/presto-clp/src/main/java/com/facebook/presto/plugin/clp/split/filter/ClpSplitFilterProvider.java @@ -22,8 +22,8 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import java.io.File; import java.io.IOException; +import java.nio.file.Paths; import java.util.List; import java.util.Map; import java.util.Set; @@ -67,7 +67,7 @@ public ClpSplitFilterProvider(ClpConfig config) mapper.registerModule(module); try { filterMap = mapper.readValue( - new File(config.getSplitFilterConfig()), + Paths.get(config.getSplitFilterConfig()).toFile(), new TypeReference>>() {}); } catch (IOException e) { diff --git a/presto-clp/src/test/java/com/facebook/presto/plugin/clp/ClpMetadataDbSetUp.java b/presto-clp/src/test/java/com/facebook/presto/plugin/clp/ClpMetadataDbSetUp.java index d1d0ee6964c8e..275926af67e4c 100644 --- a/presto-clp/src/test/java/com/facebook/presto/plugin/clp/ClpMetadataDbSetUp.java +++ b/presto-clp/src/test/java/com/facebook/presto/plugin/clp/ClpMetadataDbSetUp.java @@ -23,6 +23,7 @@ import java.io.File; import java.io.IOException; +import java.nio.file.Paths; import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; @@ -180,7 +181,7 @@ public static ClpMySqlSplitProvider setupSplit(DbHandle dbHandle, Map presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT 4.0.0 presto-cluster-ttl-providers + presto-cluster-ttl-providers Presto - Cluster Ttl Providers presto-plugin ${project.parent.basedir} + true 
@@ -44,7 +46,7 @@ - io.airlift + com.facebook.airlift units provided @@ -62,4 +64,4 @@ - \ No newline at end of file + diff --git a/presto-common/pom.xml b/presto-common/pom.xml index 6f1a57adf6535..f4000918fee30 100644 --- a/presto-common/pom.xml +++ b/presto-common/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-common @@ -13,27 +13,47 @@ ${project.parent.basedir} + 11 + true - + com.fasterxml.jackson.core jackson-annotations + + jakarta.inject + jakarta.inject-api + test + + + + javax.inject + javax.inject + test + + io.airlift slice - com.facebook.drift + com.facebook.airlift.drift drift-api - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations + true + + + + jakarta.annotation + jakarta.annotation-api true @@ -104,13 +124,13 @@ - com.facebook.drift + com.facebook.airlift.drift drift-codec test - com.facebook.drift + com.facebook.airlift.drift drift-protocol test diff --git a/presto-common/src/main/java/com/facebook/presto/common/CatalogSchemaName.java b/presto-common/src/main/java/com/facebook/presto/common/CatalogSchemaName.java index ed85a12904a5c..3b63bc89cb385 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/CatalogSchemaName.java +++ b/presto-common/src/main/java/com/facebook/presto/common/CatalogSchemaName.java @@ -34,7 +34,7 @@ public final class CatalogSchemaName public CatalogSchemaName(String catalogName, String schemaName) { this.catalogName = requireNonNull(catalogName, "catalogName is null").toLowerCase(ENGLISH); - this.schemaName = requireNonNull(schemaName, "schemaName is null").toLowerCase(ENGLISH); + this.schemaName = requireNonNull(schemaName, "schemaName is null"); } @ThriftField(1) diff --git a/presto-common/src/main/java/com/facebook/presto/common/ErrorCode.java b/presto-common/src/main/java/com/facebook/presto/common/ErrorCode.java index cb810d5d543ff..f6803c686f0c4 100644 --- 
a/presto-common/src/main/java/com/facebook/presto/common/ErrorCode.java +++ b/presto-common/src/main/java/com/facebook/presto/common/ErrorCode.java @@ -30,6 +30,7 @@ public final class ErrorCode private final String name; private final ErrorType type; private final boolean retriable; + private final boolean catchableByTry; @JsonCreator @ThriftConstructor @@ -37,7 +38,8 @@ public ErrorCode( @JsonProperty("code") int code, @JsonProperty("name") String name, @JsonProperty("type") ErrorType type, - @JsonProperty("retriable") boolean retriable) + @JsonProperty("retriable") boolean retriable, + @JsonProperty("catchableByTry") boolean catchableByTry) { if (code < 0) { throw new IllegalArgumentException("code is negative"); @@ -46,11 +48,17 @@ public ErrorCode( this.name = requireNonNull(name, "name is null"); this.type = requireNonNull(type, "type is null"); this.retriable = retriable; + this.catchableByTry = catchableByTry; } public ErrorCode(int code, String name, ErrorType type) { - this(code, name, type, false); + this(code, name, type, false, false); + } + + public ErrorCode(int code, String name, ErrorType type, boolean retriable) + { + this(code, name, type, retriable, false); } @JsonProperty @@ -81,6 +89,13 @@ public boolean isRetriable() return retriable; } + @JsonProperty + @ThriftField(5) + public boolean isCatchableByTry() + { + return catchableByTry; + } + @Override public String toString() { diff --git a/presto-common/src/main/java/com/facebook/presto/common/Page.java b/presto-common/src/main/java/com/facebook/presto/common/Page.java index 2e941461dca9f..8eb4fe7550cc5 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/Page.java +++ b/presto-common/src/main/java/com/facebook/presto/common/Page.java @@ -471,7 +471,7 @@ public Page replaceColumn(int channelIndex, Block column) Block[] newBlocks = Arrays.copyOf(blocks, blocks.length); newBlocks[channelIndex] = column; - return Page.wrapBlocksWithoutCopy(newBlocks.length, newBlocks); + return 
Page.wrapBlocksWithoutCopy(positionCount, newBlocks); } private static class DictionaryBlockIndexes diff --git a/presto-common/src/main/java/com/facebook/presto/common/QualifiedObjectName.java b/presto-common/src/main/java/com/facebook/presto/common/QualifiedObjectName.java index 70d724822bfa3..a6a7207f94709 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/QualifiedObjectName.java +++ b/presto-common/src/main/java/com/facebook/presto/common/QualifiedObjectName.java @@ -17,9 +17,9 @@ import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; @@ -58,12 +58,14 @@ public static QualifiedObjectName valueOf(String catalogName, String schemaName, return new QualifiedObjectName(catalogName, schemaName, objectName.toLowerCase(ENGLISH)); } + @JsonCreator @ThriftConstructor - public QualifiedObjectName(String catalogName, String schemaName, String objectName) + public QualifiedObjectName( + @JsonProperty("catalogName") String catalogName, + @JsonProperty("schemaName") String schemaName, + @JsonProperty("objectName") String objectName) { checkLowerCase(catalogName, "catalogName"); - checkLowerCase(schemaName, "schemaName"); - checkLowerCase(objectName, "objectName"); this.catalogName = catalogName; this.schemaName = schemaName; this.objectName = objectName; @@ -75,18 +77,21 @@ public CatalogSchemaName getCatalogSchemaName() } @ThriftField(1) + @JsonProperty("catalogName") public String getCatalogName() { return catalogName; } @ThriftField(2) + @JsonProperty("schemaName") public String getSchemaName() { return schemaName; } @ThriftField(3) + @JsonProperty("objectName") public String getObjectName() { return objectName; diff --git 
a/presto-common/src/main/java/com/facebook/presto/common/RuntimeMetricName.java b/presto-common/src/main/java/com/facebook/presto/common/RuntimeMetricName.java index 321576bb5eb7b..7e982a98489b5 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/RuntimeMetricName.java +++ b/presto-common/src/main/java/com/facebook/presto/common/RuntimeMetricName.java @@ -55,9 +55,12 @@ private RuntimeMetricName() public static final String CREATE_SCHEDULER_TIME_NANOS = "createSchedulerTimeNanos"; public static final String LOGICAL_PLANNER_TIME_NANOS = "logicalPlannerTimeNanos"; public static final String OPTIMIZER_TIME_NANOS = "optimizerTimeNanos"; + public static final String VALIDATE_FINAL_PLAN_TIME_NANOS = "validateFinalPlanTimeNanos"; + public static final String VALIDATE_INTERMEDIATE_PLAN_TIME_NANOS = "validateIntermediatePlanTimeNanos"; public static final String GET_CANONICAL_INFO_TIME_NANOS = "getCanonicalInfoTimeNanos"; public static final String FRAGMENT_PLAN_TIME_NANOS = "fragmentPlanTimeNanos"; public static final String GET_LAYOUT_TIME_NANOS = "getLayoutTimeNanos"; + public static final String GET_IDENTIFIER_NORMALIZATION_TIME_NANOS = "getIdentifierNormalizationTimeNanos"; public static final String REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED = "rewriteAggregationIfToFilterApplied"; // Time between task creation and start. public static final String TASK_QUEUED_TIME_NANOS = "taskQueuedTimeNanos"; @@ -66,6 +69,9 @@ private RuntimeMetricName() // Blocked time for the operators due to waiting for inputs. 
public static final String TASK_BLOCKED_TIME_NANOS = "taskBlockedTimeNanos"; public static final String TASK_UPDATE_DELIVERED_WALL_TIME_NANOS = "taskUpdateDeliveredWallTimeNanos"; + public static final String TASK_START_WAIT_FOR_EVENT_LOOP = "taskStartWaitForEventLoop"; + public static final String TASK_UPDATE_DELIVERED_UPDATES = "taskUpdateDeliveredUpdates"; + public static final String TASK_UPDATE_ROUND_TRIP_TIME = "taskUpdateRoundTripTime"; public static final String TASK_UPDATE_SERIALIZED_CPU_TIME_NANOS = "taskUpdateSerializedCpuNanos"; public static final String TASK_PLAN_SERIALIZED_CPU_TIME_NANOS = "taskPlanSerializedCpuNanos"; // Time for event loop to execute a method diff --git a/presto-common/src/main/java/com/facebook/presto/common/SourceColumn.java b/presto-common/src/main/java/com/facebook/presto/common/SourceColumn.java new file mode 100644 index 0000000000000..9125e3e120c73 --- /dev/null +++ b/presto-common/src/main/java/com/facebook/presto/common/SourceColumn.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.common; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class SourceColumn +{ + private final QualifiedObjectName tableName; + private final String columnName; + + @JsonCreator + public SourceColumn( + @JsonProperty("tableName") QualifiedObjectName tableName, + @JsonProperty("columnName") String columnName) + { + this.tableName = requireNonNull(tableName, "tableName is null"); + this.columnName = requireNonNull(columnName, "columnName is null"); + } + + @JsonProperty + public QualifiedObjectName getTableName() + { + return tableName; + } + + @JsonProperty + public String getColumnName() + { + return columnName; + } + + @Override + public int hashCode() + { + return Objects.hash(tableName, columnName); + } + + @Override + public boolean equals(Object obj) + { + if (obj == this) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + SourceColumn entry = (SourceColumn) obj; + return Objects.equals(tableName, entry.tableName) && + Objects.equals(columnName, entry.columnName); + } + + @Override + public String toString() + { + return "SourceColumn{" + + "tableName=" + tableName + + ", columnName='" + columnName + '\'' + + '}'; + } +} diff --git a/presto-common/src/main/java/com/facebook/presto/common/Subfield.java b/presto-common/src/main/java/com/facebook/presto/common/Subfield.java index 5854fff8a8731..4d49c557fcaab 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/Subfield.java +++ b/presto-common/src/main/java/com/facebook/presto/common/Subfield.java @@ -81,6 +81,31 @@ public String toString() } } + public static final class StructureOnly + implements PathElement + { + private static final StructureOnly STRUCTURE_ONLY = new StructureOnly(); + + private StructureOnly() {} + + public static StructureOnly getInstance() + { 
+ return STRUCTURE_ONLY; + } + + @Override + public boolean isSubscript() + { + return true; + } + + @Override + public String toString() + { + return "[$]"; + } + } + public static final class NestedField implements PathElement { @@ -238,6 +263,11 @@ public static PathElement noSubfield() return NoSubfield.getInstance(); } + public static PathElement structureOnly() + { + return StructureOnly.getInstance(); + } + @JsonCreator public Subfield(String path) { diff --git a/presto-common/src/main/java/com/facebook/presto/common/SubfieldTokenizer.java b/presto-common/src/main/java/com/facebook/presto/common/SubfieldTokenizer.java index 5e89f0d6fe288..562a99e1c151a 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/SubfieldTokenizer.java +++ b/presto-common/src/main/java/com/facebook/presto/common/SubfieldTokenizer.java @@ -106,7 +106,7 @@ private Subfield.PathElement computeNext() } if (tryMatch(OPEN_BRACKET)) { - Subfield.PathElement token = tryMatch(QUOTE) ? matchQuotedSubscript() : tryMatch(WILDCARD) ? matchWildcardSubscript() : matchUnquotedSubscript(); + Subfield.PathElement token = tryMatch(QUOTE) ? matchQuotedSubscript() : tryMatch(WILDCARD) ? matchWildcardSubscript() : tryMatch(DOLLAR) ? 
matchStructureOnlySubscript() : matchUnquotedSubscript(); match(CLOSE_BRACKET); firstSegment = false; @@ -151,9 +151,14 @@ private Subfield.PathElement matchDollarPathElement() return Subfield.noSubfield(); } + private Subfield.PathElement matchStructureOnlySubscript() + { + return Subfield.structureOnly(); + } + private static boolean isUnquotedPathCharacter(char c) { - return c == ':' || c == '$' || c == '-' || c == '/' || c == '@' || c == '|' || c == '#' || c == ' ' || isUnquotedSubscriptCharacter(c); + return c == ':' || c == '$' || c == '-' || c == '/' || c == '@' || c == '|' || c == '#' || c == ' ' || c == '<' || c == '>' || isUnquotedSubscriptCharacter(c); } private Subfield.PathElement matchUnquotedSubscript() diff --git a/presto-common/src/main/java/com/facebook/presto/common/Utils.java b/presto-common/src/main/java/com/facebook/presto/common/Utils.java index dc31b1778252f..374a55d0c9b30 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/Utils.java +++ b/presto-common/src/main/java/com/facebook/presto/common/Utils.java @@ -17,8 +17,7 @@ import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.predicate.Primitives; import com.facebook.presto.common.type.Type; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.function.Supplier; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/AbstractArrayBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/AbstractArrayBlock.java index ce8993ac09902..cd7c10f1a7f12 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/AbstractArrayBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/AbstractArrayBlock.java @@ -14,8 +14,7 @@ package com.facebook.presto.common.block; import io.airlift.slice.SliceOutput; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.OptionalInt; diff --git 
a/presto-common/src/main/java/com/facebook/presto/common/block/AbstractMapBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/AbstractMapBlock.java index 804ad94761474..d1342501e9a15 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/AbstractMapBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/AbstractMapBlock.java @@ -15,10 +15,9 @@ package com.facebook.presto.common.block; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.lang.invoke.MethodHandle; import java.util.Arrays; import java.util.Optional; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/AbstractRowBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/AbstractRowBlock.java index 6f1492a1bb84c..9af96953d7135 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/AbstractRowBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/AbstractRowBlock.java @@ -14,8 +14,7 @@ package com.facebook.presto.common.block; import io.airlift.slice.SliceOutput; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.OptionalInt; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/ArrayBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/ArrayBlock.java index 801c589ad1330..8f5c2a60eaa17 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/ArrayBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/ArrayBlock.java @@ -13,10 +13,9 @@ */ package com.facebook.presto.common.block; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.Objects; import java.util.Optional; diff --git 
a/presto-common/src/main/java/com/facebook/presto/common/block/ArrayBlockBuilder.java b/presto-common/src/main/java/com/facebook/presto/common/block/ArrayBlockBuilder.java index d17581459a401..866022de09a97 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/ArrayBlockBuilder.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/ArrayBlockBuilder.java @@ -15,10 +15,9 @@ import com.facebook.presto.common.type.Type; import io.airlift.slice.SliceInput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.function.ObjLongConsumer; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/BlockUtil.java b/presto-common/src/main/java/com/facebook/presto/common/block/BlockUtil.java index d940f0dad103c..f2d9b3505d4fb 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/BlockUtil.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/BlockUtil.java @@ -16,8 +16,7 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceTooLargeException; import io.airlift.slice.Slices; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Arrays; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/ByteArrayBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/ByteArrayBlock.java index 437464b0d0277..3f025870df11e 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/ByteArrayBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/ByteArrayBlock.java @@ -14,10 +14,9 @@ package com.facebook.presto.common.block; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.Objects; import java.util.Optional; diff --git 
a/presto-common/src/main/java/com/facebook/presto/common/block/ByteArrayBlockBuilder.java b/presto-common/src/main/java/com/facebook/presto/common/block/ByteArrayBlockBuilder.java index a9c658da9547d..ba38e0017bc3d 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/ByteArrayBlockBuilder.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/ByteArrayBlockBuilder.java @@ -15,10 +15,9 @@ import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.OptionalInt; import java.util.function.ObjLongConsumer; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/ClosingBlockLease.java b/presto-common/src/main/java/com/facebook/presto/common/block/ClosingBlockLease.java index ff48d5a5bf78f..4a8f366fc2cf8 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/ClosingBlockLease.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/ClosingBlockLease.java @@ -13,8 +13,6 @@ */ package com.facebook.presto.common.block; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.List; import static java.util.Arrays.asList; @@ -22,7 +20,6 @@ import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; -@NotThreadSafe public final class ClosingBlockLease implements BlockLease { diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/ColumnarMap.java b/presto-common/src/main/java/com/facebook/presto/common/block/ColumnarMap.java index 94f8b2c539147..c6d0c7ad9cbb2 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/ColumnarMap.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/ColumnarMap.java @@ -13,10 +13,9 @@ */ package com.facebook.presto.common.block; +import jakarta.annotation.Nullable; import 
org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import static io.airlift.slice.SizeOf.sizeOf; import static java.util.Objects.requireNonNull; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/DictionaryBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/DictionaryBlock.java index 9587f6c8bfc6d..29f28815f2994 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/DictionaryBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/DictionaryBlock.java @@ -508,6 +508,33 @@ public Block getLoadedBlock() return new DictionaryBlock(idsOffset, getPositionCount(), loadedDictionary, ids, false, randomDictionaryId()); } + public Block createProjection(Block newDictionary) + { + if (newDictionary.getPositionCount() != dictionary.getPositionCount()) { + throw new IllegalArgumentException("newDictionary must have the same position count"); + } + + // if the new dictionary is lazy be careful to not materialize it + if (newDictionary instanceof LazyBlock) { + return new LazyBlock(positionCount, (block) -> { + Block newDictionaryBlock = newDictionary.getBlock(0); + Block newBlock = createProjection(newDictionaryBlock); + block.setBlock(newBlock); + }); + } + if (newDictionary instanceof RunLengthEncodedBlock) { + RunLengthEncodedBlock rle = (RunLengthEncodedBlock) newDictionary; + return new RunLengthEncodedBlock(rle.getValue(), positionCount); + } + + // unwrap dictionary in dictionary + int[] newIds = new int[positionCount]; + for (int position = 0; position < positionCount; position++) { + newIds[position] = getIdUnchecked(position); + } + return new DictionaryBlock(0, positionCount, newDictionary, newIds, false, randomDictionaryId()); + } + public Block getDictionary() { return dictionary; @@ -518,7 +545,7 @@ Slice getIds() return Slices.wrappedIntArray(ids, idsOffset, positionCount); } - int[] getRawIds() + public int[] getRawIds() { return ids; } @@ -533,6 +560,11 @@ 
public int getId(int position) return ids[position + idsOffset]; } + private int getIdUnchecked(int position) + { + return ids[position + idsOffset]; + } + public DictionaryId getDictionarySourceId() { return dictionarySourceId; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/Int128ArrayBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/Int128ArrayBlock.java index 8e90829b456dc..d3b4c6a81abe2 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/Int128ArrayBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/Int128ArrayBlock.java @@ -16,10 +16,9 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.Objects; import java.util.Optional; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/Int128ArrayBlockBuilder.java b/presto-common/src/main/java/com/facebook/presto/common/block/Int128ArrayBlockBuilder.java index 1684dfd87e455..9daf19e7f1f90 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/Int128ArrayBlockBuilder.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/Int128ArrayBlockBuilder.java @@ -17,10 +17,9 @@ import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.OptionalInt; import java.util.function.ObjLongConsumer; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/IntArrayBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/IntArrayBlock.java index a5dc040acae80..11a4139425ddb 100644 --- 
a/presto-common/src/main/java/com/facebook/presto/common/block/IntArrayBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/IntArrayBlock.java @@ -14,10 +14,9 @@ package com.facebook.presto.common.block; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.Objects; import java.util.Optional; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/IntArrayBlockBuilder.java b/presto-common/src/main/java/com/facebook/presto/common/block/IntArrayBlockBuilder.java index e91286d02569c..1c3c2d5c10d3b 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/IntArrayBlockBuilder.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/IntArrayBlockBuilder.java @@ -15,10 +15,9 @@ import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.OptionalInt; import java.util.function.ObjLongConsumer; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/LongArrayBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/LongArrayBlock.java index d36fcfc59ef74..43e3cae5eda64 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/LongArrayBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/LongArrayBlock.java @@ -14,10 +14,9 @@ package com.facebook.presto.common.block; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.Objects; import java.util.Optional; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/LongArrayBlockBuilder.java 
b/presto-common/src/main/java/com/facebook/presto/common/block/LongArrayBlockBuilder.java index 45dd5b8e509fb..e3dd877a72f83 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/LongArrayBlockBuilder.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/LongArrayBlockBuilder.java @@ -15,10 +15,9 @@ import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.OptionalInt; import java.util.function.ObjLongConsumer; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/MapBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/MapBlock.java index fadf644fcce90..fbe0520ab7b7e 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/MapBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/MapBlock.java @@ -14,10 +14,9 @@ package com.facebook.presto.common.block; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.lang.invoke.MethodHandle; import java.util.Arrays; import java.util.Objects; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/MapBlockBuilder.java b/presto-common/src/main/java/com/facebook/presto/common/block/MapBlockBuilder.java index e8ab8cdd8f9b2..28630e5a78000 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/MapBlockBuilder.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/MapBlockBuilder.java @@ -17,10 +17,9 @@ import com.facebook.presto.common.NotSupportedException; import com.facebook.presto.common.type.Type; import io.airlift.slice.SliceInput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.lang.invoke.MethodHandle; import java.util.Arrays; import 
java.util.Optional; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/RowBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/RowBlock.java index db8856a217310..42ce04160cb20 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/RowBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/RowBlock.java @@ -13,14 +13,15 @@ */ package com.facebook.presto.common.block; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; +import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.function.ObjLongConsumer; +import java.util.stream.Collectors; import static com.facebook.presto.common.block.BlockUtil.ensureBlocksAreLoaded; import static io.airlift.slice.SizeOf.sizeOf; @@ -249,6 +250,39 @@ public void retainedBytesForEachPart(ObjLongConsumer consumer) consumer.accept(this, INSTANCE_SIZE); } + /** + * Returns the row fields from the specified block. The block may be a LazyBlock, RunLengthEncodedBlock, or + * DictionaryBlock, but the underlying block must be a RowBlock. The returned field blocks will be the same + * length as the specified block, which means they are not null suppressed. 
+ */ + public static List getRowFieldsFromBlock(Block block) + { + // if the block is lazy, be careful to not materialize the nested blocks + if (block instanceof LazyBlock) { + LazyBlock lazyBlock = (LazyBlock) block; + block = lazyBlock.getBlock(0); + } + + if (block instanceof RunLengthEncodedBlock) { + RunLengthEncodedBlock runLengthEncodedBlock = (RunLengthEncodedBlock) block; + RowBlock rowBlock = (RowBlock) runLengthEncodedBlock.getValue(); + return Arrays.stream(rowBlock.fieldBlocks) + .map(fieldBlock -> new RunLengthEncodedBlock(fieldBlock, runLengthEncodedBlock.getPositionCount())) + .collect(Collectors.toList()); + } + if (block instanceof DictionaryBlock) { + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + RowBlock rowBlock = (RowBlock) dictionaryBlock.getDictionary(); + return Arrays.stream(rowBlock.fieldBlocks) + .map(dictionaryBlock::createProjection) + .collect(Collectors.toList()); + } + if (block instanceof RowBlock) { + return Arrays.asList(((RowBlock) block).fieldBlocks); + } + throw new IllegalArgumentException("Unexpected block type: " + block.getClass().getSimpleName()); + } + @Override public String toString() { diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/RowBlockBuilder.java b/presto-common/src/main/java/com/facebook/presto/common/block/RowBlockBuilder.java index 4e8b2a78eaac7..905446bee6890 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/RowBlockBuilder.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/RowBlockBuilder.java @@ -16,10 +16,9 @@ import com.facebook.presto.common.type.Type; import io.airlift.slice.SliceInput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.List; import java.util.function.ObjLongConsumer; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/RunLengthEncodedBlock.java 
b/presto-common/src/main/java/com/facebook/presto/common/block/RunLengthEncodedBlock.java index 9deeaf7faa4cb..28a553ca9184b 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/RunLengthEncodedBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/RunLengthEncodedBlock.java @@ -17,10 +17,9 @@ import com.facebook.presto.common.type.Type; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Objects; import java.util.OptionalInt; import java.util.function.ObjLongConsumer; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/ShortArrayBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/ShortArrayBlock.java index 47744726f74c5..d2bc76964e1e6 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/ShortArrayBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/ShortArrayBlock.java @@ -14,10 +14,9 @@ package com.facebook.presto.common.block; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.Objects; import java.util.Optional; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/ShortArrayBlockBuilder.java b/presto-common/src/main/java/com/facebook/presto/common/block/ShortArrayBlockBuilder.java index 3b949233d4968..a60b736677dbd 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/ShortArrayBlockBuilder.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/ShortArrayBlockBuilder.java @@ -15,10 +15,9 @@ import io.airlift.slice.SliceInput; import io.airlift.slice.SliceOutput; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import 
java.util.Arrays; import java.util.OptionalInt; import java.util.function.ObjLongConsumer; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/SingleMapBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/SingleMapBlock.java index d09e889347aca..b3dae6395630e 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/SingleMapBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/SingleMapBlock.java @@ -17,10 +17,9 @@ import com.facebook.presto.common.GenericInternalException; import com.facebook.presto.common.NotSupportedException; import io.airlift.slice.Slice; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.lang.invoke.MethodHandle; import java.util.Objects; import java.util.OptionalInt; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/VariableWidthBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/VariableWidthBlock.java index 08978b4707b17..51d339f72bae4 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/VariableWidthBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/VariableWidthBlock.java @@ -17,10 +17,9 @@ import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; import io.airlift.slice.UnsafeSlice; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Arrays; import java.util.Objects; import java.util.Optional; diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/VariableWidthBlockBuilder.java b/presto-common/src/main/java/com/facebook/presto/common/block/VariableWidthBlockBuilder.java index 29b9a038a082f..6f14f596892b6 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/VariableWidthBlockBuilder.java +++ 
b/presto-common/src/main/java/com/facebook/presto/common/block/VariableWidthBlockBuilder.java @@ -19,10 +19,9 @@ import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; import io.airlift.slice.UnsafeSlice; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.io.IOException; import java.io.UncheckedIOException; import java.util.Arrays; diff --git a/presto-common/src/main/java/com/facebook/presto/common/resourceGroups/QueryType.java b/presto-common/src/main/java/com/facebook/presto/common/resourceGroups/QueryType.java index 3a732de316454..8f08f12b4e925 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/resourceGroups/QueryType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/resourceGroups/QueryType.java @@ -27,7 +27,9 @@ public enum QueryType INSERT(6), SELECT(7), CONTROL(8), - UPDATE(9) + UPDATE(9), + MERGE(10), + CALL_DISTRIBUTED_PROCEDURE(11) /**/; private final int value; diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/AbstractType.java b/presto-common/src/main/java/com/facebook/presto/common/type/AbstractType.java index f13fe5552d852..e52c73981fa2a 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/AbstractType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/AbstractType.java @@ -92,6 +92,12 @@ public void writeBoolean(BlockBuilder blockBuilder, boolean value) throw new UnsupportedOperationException(getClass().getName()); } + @Override + public byte getByte(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + @Override public long getLong(Block block, int position) { diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/DistinctType.java b/presto-common/src/main/java/com/facebook/presto/common/type/DistinctType.java index c79b7f8e0a3f0..1c5a6fd01fe99 100644 --- 
a/presto-common/src/main/java/com/facebook/presto/common/type/DistinctType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/DistinctType.java @@ -19,10 +19,9 @@ import com.facebook.presto.common.block.BlockBuilderStatus; import com.facebook.presto.common.block.UncheckedBlock; import com.facebook.presto.common.function.SqlFunctionProperties; +import com.google.errorprone.annotations.ThreadSafe; import io.airlift.slice.Slice; -import javax.annotation.concurrent.ThreadSafe; - import java.util.ArrayDeque; import java.util.ArrayList; import java.util.HashMap; diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/FunctionType.java b/presto-common/src/main/java/com/facebook/presto/common/type/FunctionType.java index 611cbde38db63..7da315cd4b6ca 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/FunctionType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/FunctionType.java @@ -144,6 +144,12 @@ public void writeBoolean(BlockBuilder blockBuilder, boolean value) throw new UnsupportedOperationException(getClass().getName()); } + @Override + public byte getByte(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + @Override public long getLong(Block block, int position) { diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/TimeZoneKey.java b/presto-common/src/main/java/com/facebook/presto/common/type/TimeZoneKey.java index ed9d688ab78e6..80b0fc44e1840 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/TimeZoneKey.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/TimeZoneKey.java @@ -157,7 +157,6 @@ public static TimeZoneKey getTimeZoneKeyForOffset(long offsetMinutes) private final short key; - @ThriftConstructor public TimeZoneKey(String id, short key) { this.id = requireNonNull(id, "id is null"); @@ -167,13 +166,20 @@ public TimeZoneKey(String id, short key) this.key = key; } - 
@ThriftField(1) + @ThriftConstructor + public TimeZoneKey(short timeZoneKey) + { + TimeZoneKey parsedTimeZoneKey = getTimeZoneKey(timeZoneKey); + this.id = parsedTimeZoneKey.getId(); + this.key = parsedTimeZoneKey.getKey(); + } + public String getId() { return id; } - @ThriftField(2) + @ThriftField(value = 1, name = "timeZoneKey") @JsonValue public short getKey() { diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/TinyintType.java b/presto-common/src/main/java/com/facebook/presto/common/type/TinyintType.java index b151f9291e61f..76103c3372af6 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/TinyintType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/TinyintType.java @@ -133,6 +133,12 @@ public long getLong(Block block, int position) return (long) block.getByte(position); } + @Override + public byte getByte(Block block, int position) + { + return block.getByte(position); + } + @Override public long getLongUnchecked(UncheckedBlock block, int internalPosition) { diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/Type.java b/presto-common/src/main/java/com/facebook/presto/common/type/Type.java index 251b3b11c055d..e4e2c24a5ca11 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/Type.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/Type.java @@ -98,6 +98,11 @@ default boolean equalValuesAreIdentical() */ boolean getBooleanUnchecked(UncheckedBlock block, int internalPosition); + /** + * Gets the value at the {@code block} {@code position} as a byte. + */ + byte getByte(Block block, int position); + /** * Gets the value at the {@code block} {@code position} as a long. 
*/ diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/TypeManager.java b/presto-common/src/main/java/com/facebook/presto/common/type/TypeManager.java index 277ef86ffb881..2dc2aa18cfaa2 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/TypeManager.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/TypeManager.java @@ -44,4 +44,9 @@ default Collection getParametricTypes() { throw new UnsupportedOperationException(); } + + /** + * Checks for the existence of this type. + */ + boolean hasType(TypeSignature signature); } diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/TypeWithName.java b/presto-common/src/main/java/com/facebook/presto/common/type/TypeWithName.java index 06d718115d20f..903b31d3e8478 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/TypeWithName.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/TypeWithName.java @@ -145,6 +145,12 @@ public boolean getBooleanUnchecked(UncheckedBlock block, int internalPosition) return type.getBooleanUnchecked(block, internalPosition); } + @Override + public byte getByte(Block block, int position) + { + return type.getByte(block, position); + } + @Override public long getLong(Block block, int position) { diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/UserDefinedType.java b/presto-common/src/main/java/com/facebook/presto/common/type/UserDefinedType.java index 06618c224ae92..22a4759fe1504 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/UserDefinedType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/UserDefinedType.java @@ -18,7 +18,7 @@ import static java.util.Objects.requireNonNull; /** - * UserDefinedType represents an enum type, or a distinct type. + * UserDefinedType represents an enum type, a distinct type, or a primitive type with a name. * Type definition is defined by user and is extracted at runtime. 
*/ public class UserDefinedType @@ -46,4 +46,14 @@ public boolean isDistinctType() { return representation.isDistinctType(); } + + public boolean isEnumType() + { + return representation.isEnum(); + } + + public boolean isPrimitiveType() + { + return !representation.isDistinctType() && representation.getTypeSignatureBase().hasStandardType(); + } } diff --git a/presto-common/src/test/java/com/facebook/presto/common/TestErrorCode.java b/presto-common/src/test/java/com/facebook/presto/common/TestErrorCode.java new file mode 100644 index 0000000000000..9f256c8721e2e --- /dev/null +++ b/presto-common/src/test/java/com/facebook/presto/common/TestErrorCode.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.common; + +import org.testng.annotations.Test; + +import static com.facebook.presto.common.ErrorType.INTERNAL_ERROR; +import static com.facebook.presto.common.ErrorType.USER_ERROR; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestErrorCode +{ + @Test + public void testCatchableByTryFlag() + { + // Error code with catchableByTry = true + ErrorCode catchable = new ErrorCode(1, "CATCHABLE_ERROR", USER_ERROR, false, true); + assertTrue(catchable.isCatchableByTry()); + assertFalse(catchable.isRetriable()); + + // Error code with catchableByTry = false + ErrorCode notCatchable = new ErrorCode(2, "NOT_CATCHABLE_ERROR", USER_ERROR, false, false); + assertFalse(notCatchable.isCatchableByTry()); + } + + @Test + public void testBackwardCompatibleConstructorDefaultsToNotCatchable() + { + // 3-parameter constructor should default catchableByTry to false + ErrorCode threeParam = new ErrorCode(1, "TEST_ERROR", USER_ERROR); + assertFalse(threeParam.isCatchableByTry()); + assertFalse(threeParam.isRetriable()); + + // 4-parameter constructor should default catchableByTry to false + ErrorCode fourParam = new ErrorCode(2, "TEST_ERROR_2", USER_ERROR, true); + assertFalse(fourParam.isCatchableByTry()); + assertTrue(fourParam.isRetriable()); + } + + @Test + public void testErrorCodeProperties() + { + ErrorCode errorCode = new ErrorCode(123, "TEST_ERROR", INTERNAL_ERROR, true, true); + + assertEquals(errorCode.getCode(), 123); + assertEquals(errorCode.getName(), "TEST_ERROR"); + assertEquals(errorCode.getType(), INTERNAL_ERROR); + assertTrue(errorCode.isRetriable()); + assertTrue(errorCode.isCatchableByTry()); + } + + @Test + public void testEquality() + { + // ErrorCode equality is based on code only (existing behavior) + ErrorCode error1 = new ErrorCode(1, "ERROR_A", USER_ERROR, false, true); + ErrorCode error2 = new ErrorCode(1, "ERROR_B", 
INTERNAL_ERROR, true, false); + + assertEquals(error1, error2); + assertEquals(error1.hashCode(), error2.hashCode()); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testNegativeCodeThrows() + { + new ErrorCode(-1, "NEGATIVE_CODE", USER_ERROR, false, false); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testNullNameThrows() + { + new ErrorCode(1, null, USER_ERROR, false, false); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testNullTypeThrows() + { + new ErrorCode(1, "TEST", null, false, false); + } +} diff --git a/presto-common/src/test/java/com/facebook/presto/common/TestPage.java b/presto-common/src/test/java/com/facebook/presto/common/TestPage.java index 5a5deb8fbc09e..c7b048a157763 100644 --- a/presto-common/src/test/java/com/facebook/presto/common/TestPage.java +++ b/presto-common/src/test/java/com/facebook/presto/common/TestPage.java @@ -182,6 +182,7 @@ public void testReplaceColumn() Page newPage = page.replaceColumn(1, newBlock); assertEquals(newPage.getChannelCount(), 3); + assertEquals(newPage.getPositionCount(), entries); assertEquals(newPage.getBlock(1).getLong(0), 0); assertEquals(newPage.getBlock(1).getLong(1), -1); } diff --git a/presto-common/src/test/java/com/facebook/presto/common/TestSubfieldTokenizer.java b/presto-common/src/test/java/com/facebook/presto/common/TestSubfieldTokenizer.java index 58c1726de7260..83307701800e1 100644 --- a/presto-common/src/test/java/com/facebook/presto/common/TestSubfieldTokenizer.java +++ b/presto-common/src/test/java/com/facebook/presto/common/TestSubfieldTokenizer.java @@ -82,6 +82,19 @@ public void testColumnNames() assertPath(new Subfield("a and b", ImmutableList.of())); } + @Test + public void testAngleBracketsInColumnNames() + { + assertPath(new Subfield("<>col", ImmutableList.of())); + assertPath(new Subfield("colbrackets", ImmutableList.of())); + assertPath(new Subfield("<>col", ImmutableList.of(new 
NestedField("<>field")))); + assertPath(new Subfield("table", ImmutableList.of(new NestedField("<>field")))); + assertPath(new Subfield("table", ImmutableList.of(new Subfield.StringSubscript("<>value>")))); + assertPath(new Subfield("", ImmutableList.of( + new NestedField(""), + new Subfield.StringSubscript("")))); + } + @Test public void testInvalidPaths() { diff --git a/presto-common/src/test/java/com/facebook/presto/common/predicate/TestSortedRangeSet.java b/presto-common/src/test/java/com/facebook/presto/common/predicate/TestSortedRangeSet.java index 087073c432bd5..a096c0462ff5d 100644 --- a/presto-common/src/test/java/com/facebook/presto/common/predicate/TestSortedRangeSet.java +++ b/presto-common/src/test/java/com/facebook/presto/common/predicate/TestSortedRangeSet.java @@ -64,7 +64,7 @@ public void testEmptySet() assertTrue(rangeSet.isNone()); assertFalse(rangeSet.isAll()); assertFalse(rangeSet.isSingleValue()); - assertTrue(Iterables.isEmpty(rangeSet.getOrderedRanges())); + assertFalse(rangeSet.getOrderedRanges().stream().findAny().isPresent()); assertEquals(rangeSet.getRangeCount(), 0); assertEquals(rangeSet.complement(), SortedRangeSet.all(BIGINT)); assertFalse(rangeSet.includesMarker(Marker.lowerUnbounded(BIGINT))); diff --git a/presto-common/src/test/java/com/facebook/presto/common/type/TestingTypeManager.java b/presto-common/src/test/java/com/facebook/presto/common/type/TestingTypeManager.java index f2e6f847a48a5..e94a0a9465cf0 100644 --- a/presto-common/src/test/java/com/facebook/presto/common/type/TestingTypeManager.java +++ b/presto-common/src/test/java/com/facebook/presto/common/type/TestingTypeManager.java @@ -59,4 +59,10 @@ public List getTypes() { return ImmutableList.of(BOOLEAN, INTEGER, BIGINT, DOUBLE, VARCHAR, VARBINARY, TIMESTAMP, DATE, ID, HYPER_LOG_LOG); } + + @Override + public boolean hasType(TypeSignature signature) + { + return getType(signature) != null; + } } diff --git a/presto-db-session-property-manager/pom.xml 
b/presto-db-session-property-manager/pom.xml new file mode 100644 index 0000000000000..44172ab07b34a --- /dev/null +++ b/presto-db-session-property-manager/pom.xml @@ -0,0 +1,180 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.297-edge10.1-SNAPSHOT + + + presto-db-session-property-manager + presto-db-session-property-manager + Presto - DB Session Property Manager + presto-plugin + + + ${project.parent.basedir} + + + + + com.facebook.presto + presto-session-property-managers-common + + + + com.facebook.airlift + bootstrap + + + + com.facebook.airlift + log + + + + com.facebook.airlift + json + + + + com.facebook.airlift + configuration + + + + com.google.guava + guava + + + + com.google.inject + guice + + + + javax.inject + javax.inject + + + + jakarta.annotation + jakarta.annotation-api + + + + jakarta.inject + jakarta.inject-api + + + + com.facebook.airlift + concurrent + + + + jakarta.validation + jakarta.validation-api + + + + org.jdbi + jdbi3-core + + + + org.jdbi + jdbi3-sqlobject + + + + com.mysql + mysql-connector-j + true + runtime + + + + org.mariadb.jdbc + mariadb-java-client + true + runtime + + + + + com.facebook.presto + presto-spi + provided + + + + com.facebook.presto + presto-common + provided + + + + com.facebook.airlift + units + provided + + + + io.airlift + slice + provided + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + org.openjdk.jol + jol-core + provided + + + + + com.facebook.presto + presto-testng-services + test + + + + org.testng + testng + test + + + + com.facebook.airlift + testing + test + + + + com.facebook.presto + testing-mysql-server-8 + test + + + + com.facebook.presto + testing-mysql-server-base + test + + + + com.facebook.presto + presto-session-property-managers-common + test-jar + test + + + diff --git a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerPlugin.java 
b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyConfigurationManagerPlugin.java similarity index 82% rename from presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerPlugin.java rename to presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyConfigurationManagerPlugin.java index bde86e9400f31..bbbc653abd2c9 100644 --- a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerPlugin.java +++ b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyConfigurationManagerPlugin.java @@ -11,18 +11,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.session; +package com.facebook.presto.session.db; import com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.session.SessionPropertyConfigurationManagerFactory; import com.google.common.collect.ImmutableList; -public class FileSessionPropertyManagerPlugin +public class DbSessionPropertyConfigurationManagerPlugin implements Plugin { @Override public Iterable getSessionPropertyConfigurationManagerFactories() { - return ImmutableList.of(new FileSessionPropertyManagerFactory()); + return ImmutableList.of( + new DbSessionPropertyManagerFactory()); } } diff --git a/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManager.java b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManager.java new file mode 100644 index 0000000000000..c0b7a059c88e4 --- /dev/null +++ b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManager.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.db; + +import com.facebook.presto.session.AbstractSessionPropertyManager; +import com.facebook.presto.session.SessionMatchSpec; +import com.facebook.presto.spi.session.SessionConfigurationContext; +import com.facebook.presto.spi.session.SessionPropertyConfigurationManager; +import jakarta.inject.Inject; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +/** + * A {@link SessionPropertyConfigurationManager} implementation that connects to a database for fetching information + * about session property overrides given {@link SessionConfigurationContext}. 
+ */ +public class DbSessionPropertyManager + extends AbstractSessionPropertyManager +{ + private final DbSpecsProvider specsProvider; + + @Inject + public DbSessionPropertyManager(DbSpecsProvider specsProvider) + { + this.specsProvider = requireNonNull(specsProvider, "specsProvider is null"); + } + + @Override + protected List getSessionMatchSpecs() + { + return this.specsProvider.get(); + } +} diff --git a/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManagerConfig.java b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManagerConfig.java new file mode 100644 index 0000000000000..55759dc278db8 --- /dev/null +++ b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManagerConfig.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.db; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.NotNull; + +import static java.util.concurrent.TimeUnit.SECONDS; + +public class DbSessionPropertyManagerConfig +{ + private String configDbUrl; + private String jdbcDriverName = "com.mysql.jdbc.Driver"; + private Duration specsRefreshPeriod = new Duration(10, SECONDS); + + @NotNull + public String getConfigDbUrl() + { + return configDbUrl; + } + + @Config("session-property-manager.db.url") + public DbSessionPropertyManagerConfig setConfigDbUrl(String configDbUrl) + { + this.configDbUrl = configDbUrl; + return this; + } + + @NotNull + public String getJdbcDriverName() + { + return jdbcDriverName; + } + + @Config("session-property-manager.db.driver-name") + public DbSessionPropertyManagerConfig setJdbcDriverName(String jdbcDriverName) + { + this.jdbcDriverName = jdbcDriverName; + return this; + } + + @NotNull + @MinDuration("1ms") + public Duration getSpecsRefreshPeriod() + { + return specsRefreshPeriod; + } + + @Config("session-property-manager.db.refresh-period") + public DbSessionPropertyManagerConfig setSpecsRefreshPeriod(Duration specsRefreshPeriod) + { + this.specsRefreshPeriod = specsRefreshPeriod; + return this; + } +} diff --git a/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManagerFactory.java b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManagerFactory.java new file mode 100644 index 0000000000000..7a6eba7821ef0 --- /dev/null +++ b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManagerFactory.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.db; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.spi.resourceGroups.SessionPropertyConfigurationManagerContext; +import com.facebook.presto.spi.session.SessionPropertyConfigurationManager; +import com.facebook.presto.spi.session.SessionPropertyConfigurationManagerFactory; +import com.google.inject.Injector; + +import java.util.Map; + +import static com.google.common.base.Throwables.throwIfUnchecked; + +public class DbSessionPropertyManagerFactory + implements SessionPropertyConfigurationManagerFactory +{ + @Override + public String getName() + { + return "db"; + } + + @Override + public SessionPropertyConfigurationManager create(Map config, SessionPropertyConfigurationManagerContext context) + { + try { + Bootstrap app = new Bootstrap(new JsonModule(), new DbSessionPropertyManagerModule()); + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + + return injector.getInstance(DbSessionPropertyManager.class); + } + catch (Exception e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } +} diff --git a/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManagerModule.java b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManagerModule.java new file mode 100644 index 0000000000000..48f480b7603aa --- /dev/null +++ 
b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSessionPropertyManagerModule.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.db; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; + +public class DbSessionPropertyManagerModule + implements Module +{ + @Override + public void configure(Binder binder) + { + configBinder(binder).bindConfig(DbSessionPropertyManagerConfig.class); + binder.bind(DbSessionPropertyManager.class).in(Scopes.SINGLETON); + binder.bind(SessionPropertiesDao.class).toProvider(SessionPropertiesDaoProvider.class).in(Scopes.SINGLETON); + binder.bind(DbSpecsProvider.class).to(RefreshingDbSpecsProvider.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSpecsProvider.java b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSpecsProvider.java new file mode 100644 index 0000000000000..278b52daadb67 --- /dev/null +++ b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/DbSpecsProvider.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.db; + +import com.facebook.presto.session.SessionMatchSpec; + +import java.util.List; +import java.util.function.Supplier; + +/** + * This interface was created to separate the scheduling logic for {@link SessionMatchSpec} loading. This also helps + * us test the core logic of {@link DbSessionPropertyManager} in a modular fashion by letting us use a test + * implementation of this interface. + */ +public interface DbSpecsProvider + extends Supplier> +{ +} diff --git a/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/RefreshingDbSpecsProvider.java b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/RefreshingDbSpecsProvider.java new file mode 100644 index 0000000000000..7c44a813cc13e --- /dev/null +++ b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/RefreshingDbSpecsProvider.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.db; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.session.SessionMatchSpec; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; + +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; + +/** + * Periodically schedules the loading of specs from the database during initialization. Returns the most recent successfully + * loaded specs on every get() invocation. + */ +public class RefreshingDbSpecsProvider + implements DbSpecsProvider +{ + private static final Logger log = Logger.get(RefreshingDbSpecsProvider.class); + + private final AtomicReference> sessionMatchSpecs = new AtomicReference<>(ImmutableList.of()); + private final SessionPropertiesDao dao; + + private final ScheduledExecutorService executor = newSingleThreadScheduledExecutor(daemonThreadsNamed("RefreshingDbSpecsProvider")); + private final AtomicBoolean started = new AtomicBoolean(); + private final long refreshPeriodMillis; + + @Inject + public RefreshingDbSpecsProvider(DbSessionPropertyManagerConfig config, SessionPropertiesDao dao) + { + requireNonNull(config, "config is null"); + this.dao = requireNonNull(dao, "dao is null"); + this.refreshPeriodMillis = config.getSpecsRefreshPeriod().toMillis(); + dao.createSessionSpecsTable(); + dao.createSessionClientTagsTable(); + dao.createSessionPropertiesTable(); + } + + @PostConstruct + public void initialize() + { + if (!started.getAndSet(true)) { + 
executor.scheduleWithFixedDelay(this::refresh, 0, refreshPeriodMillis, TimeUnit.MILLISECONDS); + } + } + + @VisibleForTesting + void refresh() + { + try { + sessionMatchSpecs.set(ImmutableList.copyOf(dao.getSessionMatchSpecs())); + } + catch (Throwable e) { + // Catch all exceptions here since throwing an exception from executor#scheduleWithFixedDelay method + // suppresses all future scheduled invocations + log.error(e, "Error loading configuration from database"); + } + } + + @PreDestroy + public void destroy() + { + executor.shutdownNow(); + } + + @Override + public List get() + { + return sessionMatchSpecs.get(); + } +} diff --git a/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/SessionPropertiesDao.java b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/SessionPropertiesDao.java new file mode 100644 index 0000000000000..71a93f0838c4a --- /dev/null +++ b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/SessionPropertiesDao.java @@ -0,0 +1,130 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.session.db; + +import com.facebook.presto.session.SessionMatchSpec; +import com.google.common.annotations.VisibleForTesting; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; +import org.jdbi.v3.sqlobject.statement.UseRowMapper; + +import java.util.List; + +/** + * Dao should guarantee that the list of SessionMatchSpecs is returned in increasing order of priority. i.e. if two + * rows in the ResultSet specify different values for the same property, the row coming in later will override the + * value set by the row coming in earlier. + */ +public interface SessionPropertiesDao +{ + String SESSION_SPECS_TABLE = "session_specs"; + String CLIENT_TAGS_TABLE = "session_client_tags"; + String PROPERTIES_TABLE = "session_property_values"; + public static String EMPTY_CATALOG = "__NULL__"; + + @SqlUpdate("CREATE TABLE IF NOT EXISTS " + SESSION_SPECS_TABLE + "(\n" + + "spec_id BIGINT NOT NULL AUTO_INCREMENT,\n" + + "user_regex VARCHAR(512),\n" + + "source_regex VARCHAR(512),\n" + + "query_type VARCHAR(512),\n" + + "group_regex VARCHAR(512),\n" + + "client_info_regex VARCHAR(512),\n" + + "override_session_properties TINYINT(1),\n" + + "priority INT NOT NULL,\n" + + "PRIMARY KEY (spec_id)\n" + + ")") + void createSessionSpecsTable(); + + @SqlUpdate("CREATE TABLE IF NOT EXISTS " + CLIENT_TAGS_TABLE + "(\n" + + "tag_spec_id BIGINT NOT NULL,\n" + + "client_tag VARCHAR(512) NOT NULL,\n" + + "PRIMARY KEY (tag_spec_id, client_tag),\n" + + "FOREIGN KEY (tag_spec_id) REFERENCES session_specs (spec_id)\n" + + ")") + void createSessionClientTagsTable(); + + @SqlUpdate("CREATE TABLE IF NOT EXISTS " + PROPERTIES_TABLE + "(\n" + + "property_spec_id BIGINT NOT NULL,\n" + + "session_property_name VARCHAR(512),\n" + + "session_property_value VARCHAR(512),\n" + + "catalog VARCHAR(512),\n" + + "PRIMARY KEY (property_spec_id, session_property_name),\n" + + "FOREIGN 
KEY (property_spec_id) REFERENCES session_specs (spec_id)\n" + + ")") + void createSessionPropertiesTable(); + + @SqlUpdate("DROP TABLE IF EXISTS " + SESSION_SPECS_TABLE) + void dropSessionSpecsTable(); + + @SqlUpdate("DROP TABLE IF EXISTS " + CLIENT_TAGS_TABLE) + void dropSessionClientTagsTable(); + + @SqlUpdate("DROP TABLE IF EXISTS " + PROPERTIES_TABLE) + void dropSessionPropertiesTable(); + + @SqlQuery("SELECT " + + "S.spec_id,\n" + + "S.user_regex,\n" + + "S.source_regex,\n" + + "S.query_type,\n" + + "S.group_regex,\n" + + "S.client_info_regex,\n" + + "S.override_session_properties,\n" + + "S.client_tags,\n" + + "GROUP_CONCAT(P.session_property_name ORDER BY P.session_property_name) session_property_names,\n" + + "GROUP_CONCAT(P.session_property_value ORDER BY P.session_property_name) session_property_values,\n" + + "GROUP_CONCAT(COALESCE(P.catalog, '" + EMPTY_CATALOG + "') ORDER BY P.session_property_name) session_property_catalogs\n" + + "FROM\n" + + "(SELECT\n" + + "A.spec_id, A.user_regex, A.source_regex, A.query_type, A.group_regex, A.client_info_regex, A.override_session_properties, A.priority,\n" + + "GROUP_CONCAT(DISTINCT B.client_tag) client_tags\n" + + "FROM " + SESSION_SPECS_TABLE + " A\n" + + "LEFT JOIN " + CLIENT_TAGS_TABLE + " B\n" + + "ON A.spec_id = B.tag_spec_id\n" + + "GROUP BY A.spec_id, A.user_regex, A.source_regex, A.query_type, A.group_regex, A.client_info_regex, A.override_session_properties, A.priority)\n" + + " S JOIN\n" + + PROPERTIES_TABLE + " P\n" + + "ON S.spec_id = P.property_spec_id\n" + + "GROUP BY S.spec_id, S.user_regex, S.source_regex, S.query_type, S.group_regex, S.client_info_regex, S.override_session_properties, S.priority, S.client_tags\n" + + "ORDER BY S.priority asc") + @UseRowMapper(SessionMatchSpec.Mapper.class) + List getSessionMatchSpecs(); + + @VisibleForTesting + @SqlUpdate("INSERT INTO " + SESSION_SPECS_TABLE + " (spec_id, user_regex, source_regex, query_type, group_regex, client_info_regex, 
override_session_properties, priority)\n" + + "VALUES (:spec_id, :user_regex, :source_regex, :query_type, :group_regex, :client_info_regex, :override_session_properties, :priority)") + void insertSpecRow( + @Bind("spec_id") long specId, + @Bind("user_regex") String userRegex, + @Bind("source_regex") String sourceRegex, + @Bind("query_type") String queryType, + @Bind("group_regex") String groupRegex, + @Bind("client_info_regex") String clientInfoRegex, + @Bind("override_session_properties") Integer overrideSessionProperties, + @Bind("priority") int priority); + + @VisibleForTesting + @SqlUpdate("INSERT INTO " + CLIENT_TAGS_TABLE + " (tag_spec_id, client_tag) VALUES (:spec_id, :client_tag)") + void insertClientTag(@Bind("spec_id") long specId, @Bind("client_tag") String clientTag); + + @VisibleForTesting + @SqlUpdate("INSERT INTO " + PROPERTIES_TABLE + " (property_spec_id, session_property_name, session_property_value, catalog)\n" + + "VALUES (:property_spec_id, :session_property_name, :session_property_value, :catalog)") + void insertSessionProperty( + @Bind("property_spec_id") long propertySpecId, + @Bind("session_property_name") String sessionPropertyName, + @Bind("session_property_value") String sessionPropertyValue, + @Bind("catalog") String catalog); +} diff --git a/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/SessionPropertiesDaoProvider.java b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/SessionPropertiesDaoProvider.java new file mode 100644 index 0000000000000..a567cf5338014 --- /dev/null +++ b/presto-db-session-property-manager/src/main/java/com/facebook/presto/session/db/SessionPropertiesDaoProvider.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.db; + +import jakarta.inject.Inject; +import org.jdbi.v3.core.Jdbi; +import org.jdbi.v3.sqlobject.SqlObjectPlugin; + +import javax.inject.Provider; + +import java.sql.DriverManager; + +import static java.util.Objects.requireNonNull; + +public class SessionPropertiesDaoProvider + implements Provider +{ + private final SessionPropertiesDao dao; + + @Inject + public SessionPropertiesDaoProvider(DbSessionPropertyManagerConfig config) + { + requireNonNull(config, "config is null"); + requireNonNull(config.getConfigDbUrl(), "db url is null"); + + try { + Class.forName(config.getJdbcDriverName()); + } + catch (ClassNotFoundException e) { + throw new RuntimeException("JDBC driver class not found: " + config.getJdbcDriverName(), e); + } + + this.dao = Jdbi.create(() -> DriverManager.getConnection(config.getConfigDbUrl())) + .installPlugin(new SqlObjectPlugin()) + .onDemand(SessionPropertiesDao.class); + } + + @Override + public SessionPropertiesDao get() + { + return dao; + } +} diff --git a/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManager.java b/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManager.java new file mode 100644 index 0000000000000..865e8f840edf4 --- /dev/null +++ b/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManager.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.db; + +import com.facebook.presto.session.AbstractTestSessionPropertyManager; +import com.facebook.presto.session.SessionMatchSpec; +import com.facebook.presto.spi.session.SessionPropertyConfigurationManager; +import com.facebook.presto.testing.mysql.MySqlOptions; +import com.facebook.presto.testing.mysql.TestingMySqlServer; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; + +import java.util.Map; +import java.util.regex.Pattern; + +import static org.testng.Assert.assertEquals; + +public abstract class TestDbSessionPropertyManager + extends AbstractTestSessionPropertyManager +{ + private static final MySqlOptions MY_SQL_OPTIONS = MySqlOptions.builder() + .build(); + + private final String driver; + + private final TestingMySqlServer mysqlServer; + + public TestDbSessionPropertyManager(String driver) + throws Exception + { + this.driver = driver; + this.mysqlServer = new TestingMySqlServer("testuser", "testpass", ImmutableList.of(), MY_SQL_OPTIONS); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + throws Exception + { + mysqlServer.close(); + } + + @Override + protected void assertProperties(Map defaultProperties, SessionMatchSpec... spec) + { + assertProperties(defaultProperties, ImmutableMap.of(), ImmutableMap.of(), spec); + } + + @Override + protected void assertProperties(Map defaultProperties, Map overrideProperties, SessionMatchSpec... 
specs) + { + assertProperties(defaultProperties, overrideProperties, ImmutableMap.of(), specs); + } + + @Override + protected void assertProperties(Map defaultProperties, Map overrideProperties, Map> catalogProperties, SessionMatchSpec... specs) + { + DbSessionPropertyManagerConfig config = new DbSessionPropertyManagerConfig() + .setConfigDbUrl(overrideJdbcUrl(mysqlServer.getJdbcUrl("session") + "&createDatabaseIfNotExist=true")); + + SessionPropertiesDaoProvider sessionPropertiesDaoProvider = new SessionPropertiesDaoProvider(config); + SessionPropertiesDao dao = sessionPropertiesDaoProvider.get(); + RefreshingDbSpecsProvider dbSpecsProvider = new RefreshingDbSpecsProvider(config, sessionPropertiesDaoProvider.get()); + SessionPropertyConfigurationManager manager = new DbSessionPropertyManager(dbSpecsProvider); + int id = 1; + try { + for (SessionMatchSpec spec : specs) { + int finalId = id; + dao.insertSpecRow( + finalId, + spec.getUserRegex().map(Pattern::pattern).orElse(null), + spec.getSourceRegex().map(Pattern::pattern).orElse(null), + spec.getQueryType().orElse(null), + spec.getResourceGroupRegex().map(Pattern::pattern).orElse(null), + spec.getClientInfoRegex().map(Pattern::pattern).orElse(null), + spec.getOverrideSessionProperties().map(val -> val ? 
1 : 0).orElse(null), + finalId); + spec.getClientTags().forEach(tag -> dao.insertClientTag(finalId, tag)); + spec.getSessionProperties().forEach((key, value) -> dao.insertSessionProperty(finalId, key, value, null)); + spec.getCatalogSessionProperties().forEach((catalog, property) -> property.forEach((key, value) -> dao.insertSessionProperty(finalId, key, value, catalog))); + id++; + } + dbSpecsProvider.refresh(); + SessionPropertyConfigurationManager.SystemSessionPropertyConfiguration propertyConfiguration = manager.getSystemSessionProperties(CONTEXT); + assertEquals(propertyConfiguration.systemPropertyDefaults, defaultProperties); + assertEquals(propertyConfiguration.systemPropertyOverrides, overrideProperties); + assertEquals(manager.getCatalogSessionProperties(CONTEXT), catalogProperties); + } + finally { + dao.dropSessionPropertiesTable(); + dao.dropSessionClientTagsTable(); + dao.dropSessionSpecsTable(); + dbSpecsProvider.destroy(); + } + } + + protected String overrideJdbcUrl(String url) + { + return url; + } +} diff --git a/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManagerConfig.java b/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManagerConfig.java new file mode 100644 index 0000000000000..f22548acf1fc7 --- /dev/null +++ b/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManagerConfig.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.db; + +import com.facebook.airlift.units.Duration; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class TestDbSessionPropertyManagerConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(DbSessionPropertyManagerConfig.class) + .setConfigDbUrl(null) + .setJdbcDriverName("com.mysql.jdbc.Driver") + .setSpecsRefreshPeriod(new Duration(10, SECONDS))); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("session-property-manager.db.url", "foo") + .put("session-property-manager.db.driver-name", "org.mariadb.jdbc.Driver") + .put("session-property-manager.db.refresh-period", "50s") + .build(); + + DbSessionPropertyManagerConfig expected = new DbSessionPropertyManagerConfig() + .setConfigDbUrl("foo") + .setJdbcDriverName("org.mariadb.jdbc.Driver") + .setSpecsRefreshPeriod(new Duration(50, TimeUnit.SECONDS)); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManagerMariadb.java b/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManagerMariadb.java new file mode 100644 index 0000000000000..3e440b95d3070 --- /dev/null +++ 
b/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManagerMariadb.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.db; + +public class TestDbSessionPropertyManagerMariadb + extends TestDbSessionPropertyManager +{ + public TestDbSessionPropertyManagerMariadb() throws Exception + { + super("mariadb"); + } + + @Override + public String overrideJdbcUrl(String url) + { + return url.replaceFirst("jdbc:mysql:", "jdbc:mariadb:"); + } +} diff --git a/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManagerMysql.java b/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManagerMysql.java new file mode 100644 index 0000000000000..806f5c8da9cac --- /dev/null +++ b/presto-db-session-property-manager/src/test/java/com/facebook/presto/session/db/TestDbSessionPropertyManagerMysql.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.db; + +public class TestDbSessionPropertyManagerMysql + extends TestDbSessionPropertyManager +{ + public TestDbSessionPropertyManagerMysql() throws Exception + { + super("mysql"); + } +} diff --git a/presto-delta/pom.xml b/presto-delta/pom.xml index 68f5e3c2dd5d9..470ba00fc823f 100644 --- a/presto-delta/pom.xml +++ b/presto-delta/pom.xml @@ -4,19 +4,32 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-delta + presto-delta Presto - Delta Table Connector presto-plugin ${project.parent.basedir} true - 3.2.0 + 3.3.2 + 17 + true + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + io.delta @@ -76,7 +89,6 @@ org.scala-lang scala-library - 2.12.11 @@ -115,13 +127,13 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -142,7 +154,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -154,7 +166,7 @@ - io.airlift + com.facebook.airlift units provided @@ -197,7 +209,7 @@ com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache @@ -309,6 +321,12 @@ test + + com.facebook.airlift + log-manager + runtime + + com.facebook.airlift node @@ -326,10 +344,15 @@ - javax.servlet - javax.servlet-api + jakarta.servlet + jakarta.servlet-api test + + + com.facebook.airlift.drift + drift-codec + @@ -341,6 +364,7 @@ org.scala-lang:scala-library:jar commons-io:commons-io:jar + com.facebook.airlift.drift:drift-codec:jar diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaClient.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaClient.java index bf8ff349c0d3e..919b59b953fd0 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaClient.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaClient.java @@ -30,11 +30,10 @@ 
import io.delta.kernel.internal.InternalScanFileUtils; import io.delta.kernel.internal.SnapshotImpl; import io.delta.kernel.utils.CloseableIterator; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; - import java.io.IOException; import java.io.UncheckedIOException; import java.time.Instant; diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaConfig.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaConfig.java index 2a87649c69ad5..74f91269ba77a 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaConfig.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaConfig.java @@ -14,8 +14,7 @@ package com.facebook.presto.delta; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class DeltaConfig { diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaConnector.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaConnector.java index 111bf2087ba7f..48c6d1e443c9c 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaConnector.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaConnector.java @@ -28,8 +28,7 @@ import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorSplitManager; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.List; diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaMetadata.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaMetadata.java index d76eed8e6cc40..8cdfde27b4f2f 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaMetadata.java +++ 
b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaMetadata.java @@ -43,10 +43,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -232,7 +232,7 @@ public DeltaTableHandle getTableHandle(ConnectorSession session, SchemaTableName } @Override - public List getTableLayouts( + public ConnectorTableLayoutResult getTableLayoutForConstraint( ConnectorSession session, ConnectorTableHandle table, Constraint constraint, @@ -260,7 +260,7 @@ public List getTableLayouts( ImmutableList.of(), Optional.empty()); - return ImmutableList.of(new ConnectorTableLayoutResult(newLayout, unenforcedPredicate)); + return new ConnectorTableLayoutResult(newLayout, unenforcedPredicate); } @Override @@ -332,11 +332,19 @@ private ConnectorTableMetadata getTableMetadata(ConnectorSession session, Schema return null; } - List columnMetadata = tableHandle.getDeltaTable().getColumns().stream() - .map(this::getColumnMetadata) + DeltaTable deltaTable = tableHandle.getDeltaTable(); + + // External location property + Map properties = new HashMap<>(1); + if (deltaTable.getTableLocation() != null) { + properties.put(DeltaTableProperties.EXTERNAL_LOCATION_PROPERTY, deltaTable.getTableLocation()); + } + + List columnMetadata = deltaTable.getColumns().stream() + .map(column -> getColumnMetadata(session, column)) .collect(Collectors.toList()); - return new ConnectorTableMetadata(tableName, columnMetadata); + return new ConnectorTableMetadata(tableName, columnMetadata, properties); } @Override @@ -362,10 +370,10 @@ private List listTables(ConnectorSession session, SchemaTablePr return ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName())); } - private ColumnMetadata 
getColumnMetadata(DeltaColumn deltaColumn) + private ColumnMetadata getColumnMetadata(ConnectorSession session, DeltaColumn deltaColumn) { return ColumnMetadata.builder() - .setName(deltaColumn.getName()) + .setName(normalizeIdentifier(session, deltaColumn.getName())) .setType(typeManager.getType(deltaColumn.getType())) .build(); } diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaModule.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaModule.java index 8898de615c37d..cbbe73ef3ec92 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaModule.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaModule.java @@ -42,6 +42,7 @@ import com.facebook.presto.hive.metastore.HivePartitionMutator; import com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore; import com.facebook.presto.hive.metastore.InvalidateMetastoreCacheProcedure; +import com.facebook.presto.hive.metastore.MetastoreCacheSpecProvider; import com.facebook.presto.hive.metastore.MetastoreCacheStats; import com.facebook.presto.hive.metastore.MetastoreConfig; import com.facebook.presto.hive.metastore.thrift.ThriftHiveMetastoreConfig; @@ -52,9 +53,8 @@ import com.google.inject.Provides; import com.google.inject.Scopes; import com.google.inject.multibindings.Multibinder; - -import javax.inject.Inject; -import javax.inject.Singleton; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; import java.util.concurrent.ExecutorService; @@ -106,6 +106,7 @@ protected void setup(Binder binder) configBinder(binder).bindConfig(HiveClientConfig.class); configBinder(binder).bindConfig(MetastoreClientConfig.class); configBinder(binder).bindConfig(ThriftHiveMetastoreConfig.class); + binder.bind(MetastoreCacheSpecProvider.class).in(Scopes.SINGLETON); binder.bind(MetastoreCacheStats.class).to(HiveMetastoreCacheStats.class).in(Scopes.SINGLETON); newExporter(binder).export(MetastoreCacheStats.class).as(generatedNameOf(MetastoreCacheStats.class, 
connectorId)); binder.bind(ExtendedHiveMetastore.class).to(InMemoryCachingHiveMetastore.class).in(Scopes.SINGLETON); diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaPageSourceProvider.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaPageSourceProvider.java index 9420438b6fd96..a47b6de549a88 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaPageSourceProvider.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaPageSourceProvider.java @@ -48,6 +48,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileStatus; @@ -64,8 +65,6 @@ import org.apache.parquet.schema.GroupType; import org.apache.parquet.schema.MessageType; -import javax.inject.Inject; - import java.io.FileNotFoundException; import java.io.IOException; import java.util.ArrayList; diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaSessionProperties.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaSessionProperties.java index f7eef23f202e0..73aa2fd9f09ce 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaSessionProperties.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaSessionProperties.java @@ -17,8 +17,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaSplitManager.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaSplitManager.java index 86e20b35a7a4a..150e1b1d03457 100644 --- 
a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaSplitManager.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaSplitManager.java @@ -26,8 +26,7 @@ import io.delta.kernel.internal.InternalScanFileUtils; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.FileStatus; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.IOException; import java.io.UncheckedIOException; diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaTableProperties.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaTableProperties.java index 77029fd359a9b..fad3a62c5e01f 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaTableProperties.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaTableProperties.java @@ -18,8 +18,7 @@ import com.facebook.presto.hive.HiveStorageFormat; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaTypeUtils.java b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaTypeUtils.java index b8d63cbfcf57f..be08af998ede3 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/DeltaTypeUtils.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/DeltaTypeUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.delta; +import com.facebook.presto.common.type.DateTimeEncoding; import com.facebook.presto.common.type.Decimals; import com.facebook.presto.common.type.NamedTypeSignature; import com.facebook.presto.common.type.RowFieldName; @@ -41,13 +42,13 @@ import io.delta.kernel.types.ShortType; import io.delta.kernel.types.StringType; import io.delta.kernel.types.StructType; +import io.delta.kernel.types.TimestampNTZType; import io.delta.kernel.types.TimestampType; import 
java.math.BigDecimal; import java.math.BigInteger; import java.sql.Timestamp; import java.time.LocalDate; -import java.time.ZoneOffset; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; import java.util.Locale; @@ -65,7 +66,9 @@ import static com.facebook.presto.common.type.SmallintType.SMALLINT; import static com.facebook.presto.common.type.StandardTypes.ARRAY; import static com.facebook.presto.common.type.StandardTypes.MAP; +import static com.facebook.presto.common.type.TimeZoneKey.UTC_KEY; import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.common.type.TinyintType.TINYINT; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; @@ -78,6 +81,7 @@ import static java.lang.Float.parseFloat; import static java.lang.Long.parseLong; import static java.lang.String.format; +import static java.time.ZoneOffset.UTC; /** * Contains utility methods to convert Delta data types (and data values) to Presto data types (and data values) @@ -181,7 +185,10 @@ public static Object convertPartitionValue( } if (type.equals(TIMESTAMP)) { // Delta partition serialized value contains up to the second precision - return Timestamp.valueOf(valueString).toLocalDateTime().toEpochSecond(ZoneOffset.UTC) * 1_000; + return Timestamp.valueOf(valueString).toLocalDateTime().toEpochSecond(UTC) * 1_000; + } + if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) { + return DateTimeEncoding.packDateTimeWithZone((Timestamp.valueOf(valueString).toLocalDateTime().toEpochSecond(UTC) * 1_000), UTC_KEY); } throw new PrestoException(DELTA_UNSUPPORTED_COLUMN_TYPE, format("Unsupported data type '%s' for partition column %s", type, columnName)); @@ -231,6 +238,9 @@ else if (deltaType instanceof StringType) { return 
createUnboundedVarcharType(); } else if (deltaType instanceof TimestampType) { + return TIMESTAMP_WITH_TIME_ZONE; + } + else if (deltaType instanceof TimestampNTZType) { return TIMESTAMP; } diff --git a/presto-delta/src/main/java/com/facebook/presto/delta/rule/DeltaPlanOptimizerProvider.java b/presto-delta/src/main/java/com/facebook/presto/delta/rule/DeltaPlanOptimizerProvider.java index 5367945535b1e..0f1a00aca9143 100644 --- a/presto-delta/src/main/java/com/facebook/presto/delta/rule/DeltaPlanOptimizerProvider.java +++ b/presto-delta/src/main/java/com/facebook/presto/delta/rule/DeltaPlanOptimizerProvider.java @@ -17,8 +17,7 @@ import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.facebook.presto.spi.relation.RowExpressionService; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Set; diff --git a/presto-delta/src/test/java/com/facebook/presto/delta/AbstractDeltaDistributedQueryTestBase.java b/presto-delta/src/test/java/com/facebook/presto/delta/AbstractDeltaDistributedQueryTestBase.java index 367f7a09db711..11e8126af8632 100644 --- a/presto-delta/src/test/java/com/facebook/presto/delta/AbstractDeltaDistributedQueryTestBase.java +++ b/presto-delta/src/test/java/com/facebook/presto/delta/AbstractDeltaDistributedQueryTestBase.java @@ -13,25 +13,16 @@ */ package com.facebook.presto.delta; -import com.facebook.presto.Session; -import com.facebook.presto.hive.HivePlugin; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; -import com.facebook.presto.tests.DistributedQueryRunner; -import com.facebook.presto.tpch.TpchPlugin; import com.google.common.collect.ImmutableMap; import org.testng.ITest; import org.testng.annotations.AfterClass; import org.testng.annotations.DataProvider; import java.nio.file.FileSystems; -import java.nio.file.Path; -import java.util.Map; -import static 
com.facebook.presto.common.type.TimeZoneKey.UTC_KEY; -import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; -import static java.util.Locale.US; public abstract class AbstractDeltaDistributedQueryTestBase extends AbstractTestQueryFramework implements ITest @@ -61,13 +52,14 @@ public abstract class AbstractDeltaDistributedQueryTestBase "test-lowercase", "test-partitions-lowercase", "test-uppercase", - "test-partitions-uppercase" + "test-partitions-uppercase", + "test-typing" }; /** * List of tables present in the test resources directory. Each table is replicated in reader version 1 and 3 */ - private static final String[] DELTA_TEST_TABLE_LIST = + public static final String[] DELTA_TEST_TABLE_LIST = new String[DELTA_VERSIONS.length * DELTA_TEST_TABLE_NAMES_LIST.length]; static { for (int i = 0; i < DELTA_VERSIONS.length; i++) { @@ -101,9 +93,9 @@ protected static String getVersionPrefix(String version) protected QueryRunner createQueryRunner() throws Exception { - QueryRunner queryRunner = createDeltaQueryRunner(ImmutableMap.of( + QueryRunner queryRunner = DeltaQueryRunner.builder().setExtraProperties(ImmutableMap.of( "experimental.pushdown-subfields-enabled", "true", - "experimental.pushdown-dereference-enabled", "true")); + "experimental.pushdown-dereference-enabled", "true")).build().getQueryRunner(); // Create the test Delta tables in HMS for (String deltaTestTable : DELTA_TEST_TABLE_LIST) { @@ -135,51 +127,6 @@ protected static String goldenTablePathWithPrefix(String prefix, String tableNam return goldenTablePath(prefix + FileSystems.getDefault().getSeparator() + tableName); } - private static DistributedQueryRunner createDeltaQueryRunner(Map extraProperties) - throws Exception - { - Session session = testSessionBuilder() - .setCatalog(DELTA_CATALOG) - .setSchema(DELTA_SCHEMA.toLowerCase(US)) - .setTimeZoneKey(UTC_KEY) - .build(); - - DistributedQueryRunner queryRunner = 
DistributedQueryRunner.builder(session) - .setExtraProperties(extraProperties) - .build(); - - // Install the TPCH plugin for test data (not in Delta format) - queryRunner.installPlugin(new TpchPlugin()); - queryRunner.createCatalog("tpch", "tpch"); - - Path dataDirectory = queryRunner.getCoordinator().getDataDirectory().resolve("delta_metadata"); - Path catalogDirectory = dataDirectory.getParent().resolve("catalog"); - - // Install a Delta connector catalog - queryRunner.installPlugin(new DeltaPlugin()); - Map deltaProperties = ImmutableMap.builder() - .put("hive.metastore", "file") - .put("hive.metastore.catalog.dir", catalogDirectory.toFile().toURI().toString()) - .put("delta.case-sensitive-partitions-enabled", "false") - .build(); - queryRunner.createCatalog(DELTA_CATALOG, "delta", deltaProperties); - - // Install a Hive connector catalog that uses the same metastore as Delta - // This catalog will be used to create tables in metastore as the Delta connector doesn't - // support creating tables yet. - queryRunner.installPlugin(new HivePlugin("hive")); - Map hiveProperties = ImmutableMap.builder() - .put("hive.metastore", "file") - .put("hive.metastore.catalog.dir", catalogDirectory.toFile().toURI().toString()) - .put("hive.allow-drop-table", "true") - .put("hive.security", "legacy") - .build(); - queryRunner.createCatalog(HIVE_CATALOG, "hive", hiveProperties); - queryRunner.execute(format("CREATE SCHEMA %s.%s", HIVE_CATALOG, DELTA_SCHEMA)); - - return queryRunner; - } - /** * Register the given deltaTableName as hiveTableName in HMS using the Delta catalog. * Hive and Delta catalogs share the same HMS in this test. 
diff --git a/presto-delta/src/test/java/com/facebook/presto/delta/DeltaQueryRunner.java b/presto-delta/src/test/java/com/facebook/presto/delta/DeltaQueryRunner.java new file mode 100644 index 0000000000000..7411d9c9b1b73 --- /dev/null +++ b/presto-delta/src/test/java/com/facebook/presto/delta/DeltaQueryRunner.java @@ -0,0 +1,169 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.delta; + +import com.facebook.airlift.log.Logging; +import com.facebook.presto.Session; +import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.hive.HivePlugin; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.facebook.presto.tpch.TpchPlugin; +import com.google.common.collect.ImmutableMap; + +import java.net.URI; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.function.BiFunction; + +import static com.facebook.airlift.log.Level.ERROR; +import static com.facebook.airlift.log.Level.WARN; +import static com.facebook.presto.common.type.TimeZoneKey.UTC_KEY; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.String.format; +import static java.util.Locale.US; +import static java.util.Objects.requireNonNull; + +public class DeltaQueryRunner +{ + public static final String DELTA_CATALOG = "delta"; + public static final String HIVE_CATALOG = "hive"; + public 
static final String DELTA_SCHEMA = "deltaTables"; // Schema in Hive which has test Delta tables + + private DistributedQueryRunner queryRunner; + + private DeltaQueryRunner(DistributedQueryRunner queryRunner) + { + this.queryRunner = requireNonNull(queryRunner, "queryRunner is null"); + } + + public DistributedQueryRunner getQueryRunner() + { + return queryRunner; + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private Builder() {} + + private Map extraProperties = new HashMap<>(); + // If externalWorkerLauncher is not provided, Java workers are used by default. + private Optional> externalWorkerLauncher = Optional.empty(); + private TimeZoneKey timeZoneKey = UTC_KEY; + private boolean caseSensitivePartitions; + private OptionalInt nodeCount = OptionalInt.of(4); + + public Builder setExternalWorkerLauncher(Optional> externalWorkerLauncher) + { + this.externalWorkerLauncher = requireNonNull(externalWorkerLauncher); + return this; + } + + public Builder setExtraProperties(Map extraProperties) + { + this.extraProperties = ImmutableMap.copyOf(extraProperties); + return this; + } + + public Builder setTimeZoneKey(TimeZoneKey timeZoneKey) + { + this.timeZoneKey = timeZoneKey; + return this; + } + + public Builder caseSensitivePartitions() + { + caseSensitivePartitions = true; + return this; + } + + public Builder setNodeCount(OptionalInt nodeCount) + { + this.nodeCount = nodeCount; + return this; + } + + public DeltaQueryRunner build() + throws Exception + { + setupLogging(); + Session session = testSessionBuilder() + .setCatalog(DELTA_CATALOG) + .setSchema(DELTA_SCHEMA.toLowerCase(US)) + .setTimeZoneKey(timeZoneKey) + .build(); + + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session) + .setExtraProperties(extraProperties) + .setNodeCount(nodeCount.orElse(4)) + .setExternalWorkerLauncher(externalWorkerLauncher) + .build(); + + // Install the TPCH plugin for test data (not in Delta 
format) + queryRunner.installPlugin(new TpchPlugin()); + queryRunner.createCatalog("tpch", "tpch"); + + Path dataDirectory = queryRunner.getCoordinator().getDataDirectory().resolve("delta_metadata"); + Path catalogDirectory = dataDirectory.getParent().resolve("catalog"); + + // Install a Delta connector catalog + queryRunner.installPlugin(new DeltaPlugin()); + Map deltaProperties = new HashMap<>(); + deltaProperties.put("hive.metastore", "file"); + deltaProperties.put("hive.metastore.catalog.dir", catalogDirectory.toFile().toURI().toString()); + deltaProperties.put("delta.case-sensitive-partitions-enabled", Boolean.toString(caseSensitivePartitions)); + queryRunner.createCatalog(DELTA_CATALOG, "delta", deltaProperties); + + // Install a Hive connector catalog that uses the same metastore as Delta + // This catalog will be used to create tables in metastore as the Delta connector doesn't + // support creating tables yet. + queryRunner.installPlugin(new HivePlugin("hive")); + Map hiveProperties = ImmutableMap.builder() + .put("hive.metastore", "file") + .put("hive.metastore.catalog.dir", catalogDirectory.toFile().toURI().toString()) + .put("hive.allow-drop-table", "true") + .put("hive.security", "legacy") + .build(); + queryRunner.createCatalog(HIVE_CATALOG, "hive", hiveProperties); + queryRunner.execute(format("CREATE SCHEMA %s.%s", HIVE_CATALOG, DELTA_SCHEMA)); + + return new DeltaQueryRunner(queryRunner); + } + } + + private static void setupLogging() + { + Logging logging = Logging.initialize(); + logging.setLevel("com.facebook.presto.event", WARN); + logging.setLevel("com.facebook.presto.security.AccessControlManager", WARN); + logging.setLevel("com.facebook.presto.server.PluginManager", WARN); + logging.setLevel("com.facebook.airlift.bootstrap.LifeCycleManager", WARN); + logging.setLevel("org.apache.parquet.hadoop", WARN); + logging.setLevel("org.eclipse.jetty.server.handler.ContextHandler", WARN); + logging.setLevel("org.eclipse.jetty.server.AbstractConnector", 
WARN); + logging.setLevel("org.glassfish.jersey.internal.inject.Providers", ERROR); + logging.setLevel("parquet.hadoop", WARN); + logging.setLevel("org.apache.iceberg", WARN); + logging.setLevel("com.facebook.airlift.bootstrap", WARN); + logging.setLevel("Bootstrap", WARN); + logging.setLevel("org.apache.hadoop.io.compress", WARN); + } +} diff --git a/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaIntegration.java b/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaIntegration.java index 156b6fd06ee61..4c016e734126d 100644 --- a/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaIntegration.java +++ b/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaIntegration.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.delta; +import com.facebook.presto.Session; +import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.testing.MaterializedResult; import com.google.common.base.Joiner; import org.testng.annotations.Test; @@ -25,6 +28,7 @@ import java.util.concurrent.TimeUnit; import static java.lang.String.format; +import static org.testng.Assert.assertEquals; /** * Integration tests for reading Delta tables. 
@@ -214,7 +218,7 @@ public void readPartitionedTableAllDataTypes(String version) " cast(0.0 as double), " + " '0', " + " DATE '2021-09-08', " + - " TIMESTAMP '2021-09-08 11:11:11', " + + " TIMESTAMP WITH TIME ZONE '2021-09-08 11:11:11 UTC', " + " cast(0 as decimal)," + " '0'" + // regular column "), " + @@ -227,7 +231,7 @@ public void readPartitionedTableAllDataTypes(String version) " cast(1.0 as double), " + " '1', " + " DATE '2021-09-08', " + - " TIMESTAMP '2021-09-08 11:11:11', " + + " TIMESTAMP WITH TIME ZONE '2021-09-08 11:11:11 UTC', " + " cast(1 as decimal), " + " '1'" + // regular column "), " + @@ -247,6 +251,62 @@ public void readPartitionedTableAllDataTypes(String version) assertQuery(testQuery, expResultsQuery); } + @Test(dataProvider = "deltaReaderVersions") + public void testDeltaTimezoneTypeSupportINT96(String version) + { + /* + https://docs.delta.io/3.2.1/api/java/kernel/index.html?io/delta/kernel/types/TimestampNTZType.html + According to delta's type specifications, the expected behaviour for TimestampNTZ + The timestamp without time zone type represents a local time in microsecond precision, which is independent of time zone. + So TimestampNTZ is independent of local timezones and should return the same value regardless of the timezone. + If legacy_timestamp is true, Presto TimestampNTZ (Timestamp) is adjusted to the timezone. + If legacy_timestamp is false, TimestampNTZ is not adjusted. + This test sets the timezone to UTC+3, and the original data file the timestamp is 12 AM. + The proper delta implementation would return 12 AM regardless of the timezone, but with + legacy_timestamp true we get 3 AM. legacy_timestamp set to false matches the specifications. 
+ */ + Session session = Session.builder(getSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("UTC+3")) + .setSystemProperty("legacy_timestamp", "false") + .build(); + String testQuery = format("SELECT tpep_dropoff_datetime, tpep_pickup_datetime FROM \"%s\".\"%s\"", + PATH_SCHEMA, goldenTablePathWithPrefix(version, "test-typing")); + + MaterializedResult actual = computeActual(session, testQuery); + + String timestamptzField = actual.getMaterializedRows().get(0).getField(0).toString(); + + assertEquals(timestamptzField, "2021-12-31T16:53:29Z[UTC]", "Delta Timestamp type not being read correctly."); + if (version.equals("delta_v3")) { + String timestamptzntz = actual.getMaterializedRows().get(0).getField(1).toString(); + assertEquals(timestamptzntz, "2022-01-01T00:35:40", "Delta TimestampNTZ type not being read correctly."); + } + } + + @Test(dataProvider = "deltaReaderVersions") + public void testDeltaTimezoneTypeSupportINT64(String version) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("UTC+3")) + .setSystemProperty("legacy_timestamp", "false") + .build(); + String testQuery = format("SELECT created_at_tz FROM \"%s\".\"%s\"", + PATH_SCHEMA, goldenTablePathWithPrefix(version, "timestamp_64")); + + MaterializedResult actual = computeActual(session, testQuery); + + String timestamptzField = actual.getMaterializedRows().get(0).getField(0).toString(); + + assertEquals(timestamptzField, "2025-05-22T09:24:11.321Z[UTC]", "Delta Timestamp type not being read correctly."); + if (version.equals("delta_v3")) { + String ntzTestQuery = format("SELECT created_at_ntz, created_at_ntz FROM \"%s\".\"%s\"", + PATH_SCHEMA, goldenTablePathWithPrefix(version, "timestamp_64")); + + actual = computeActual(session, ntzTestQuery); + String timestamptzntz = actual.getMaterializedRows().get(0).getField(0).toString(); + assertEquals(timestamptzntz, "2025-05-22T12:25:16.544", "Delta TimestampNTZ type not being read correctly."); + } + } 
/** * Expected results for table "data-reader-primitives" */ @@ -278,4 +338,33 @@ private static void setCommitFileModificationTime(String tableLocation, long com Paths.get(URI.create(tableLocation)).resolve("_delta_log/").resolve(format("%020d.json", commitId)), FileTime.from(commitTimeMillis, TimeUnit.MILLISECONDS)); } + + @Test(dataProvider = "deltaReaderVersions") + public void testShowCreateTable(String deltaVersion) + { + String tableName = deltaVersion + "/data-reader-primitives"; + String fullTableName = format("%s.%s.\"%s\"", DELTA_CATALOG, DELTA_SCHEMA.toLowerCase(), tableName); + + String createTableQueryTemplate = "CREATE TABLE %s (\n" + + " \"as_int\" integer,\n" + + " \"as_long\" bigint,\n" + + " \"as_byte\" tinyint,\n" + + " \"as_short\" smallint,\n" + + " \"as_boolean\" boolean,\n" + + " \"as_float\" real,\n" + + " \"as_double\" double,\n" + + " \"as_string\" varchar,\n" + + " \"as_binary\" varbinary,\n" + + " \"as_big_decimal\" decimal(1,0)\n" + + ")\n" + + "WITH (\n" + + " external_location = '%s'\n" + + ")"; + + String expectedSqlCommand = format(createTableQueryTemplate, fullTableName, goldenTablePath(tableName)); + + String showCreateTableCommandResult = (String) computeActual("SHOW CREATE TABLE " + fullTableName).getOnlyValue(); + + assertEquals(showCreateTableCommandResult, expectedSqlCommand); + } } diff --git a/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaTableHandle.java b/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaTableHandle.java index ae6ae1b4bec21..0a0d387331950 100644 --- a/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaTableHandle.java +++ b/presto-delta/src/test/java/com/facebook/presto/delta/TestDeltaTableHandle.java @@ -16,6 +16,7 @@ import com.facebook.airlift.bootstrap.Bootstrap; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.JsonModule; +import com.facebook.drift.codec.guice.ThriftCodecModule; import com.facebook.presto.block.BlockJsonSerde; import 
com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockEncoding; @@ -24,6 +25,7 @@ import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.HandleResolver; @@ -92,6 +94,8 @@ private JsonCodec getJsonCodec() Module module = binder -> { binder.install(new JsonModule()); binder.install(new HandleJsonModule()); + binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); + binder.install(new ThriftCodecModule()); configBinder(binder).bindConfig(FeaturesConfig.class); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); diff --git a/presto-delta/src/test/java/com/facebook/presto/delta/TestUppercasePartitionColumns.java b/presto-delta/src/test/java/com/facebook/presto/delta/TestUppercasePartitionColumns.java index 97c312d5877f3..6e7f304da7d34 100644 --- a/presto-delta/src/test/java/com/facebook/presto/delta/TestUppercasePartitionColumns.java +++ b/presto-delta/src/test/java/com/facebook/presto/delta/TestUppercasePartitionColumns.java @@ -13,22 +13,15 @@ */ package com.facebook.presto.delta; -import com.facebook.presto.Session; import com.facebook.presto.common.type.TimeZoneKey; -import com.facebook.presto.hive.HivePlugin; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; -import com.facebook.presto.tests.DistributedQueryRunner; -import com.facebook.presto.tpch.TpchPlugin; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; -import java.nio.file.Path; import java.util.Map; -import static 
com.facebook.presto.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; -import static java.util.Locale.US; import static org.testng.Assert.assertEquals; public class TestUppercasePartitionColumns @@ -38,54 +31,16 @@ public class TestUppercasePartitionColumns protected QueryRunner createQueryRunner() throws Exception { - return createDeltaQueryRunner(ImmutableMap.of( + Map extraProperties = ImmutableMap.of( "experimental.pushdown-subfields-enabled", "true", - "experimental.pushdown-dereference-enabled", "true")); - } + "experimental.pushdown-dereference-enabled", "true"); - private static DistributedQueryRunner createDeltaQueryRunner(Map extraProperties) - throws Exception - { - Session session = testSessionBuilder() - .setCatalog(DELTA_CATALOG) - .setSchema(DELTA_SCHEMA.toLowerCase(US)) + return DeltaQueryRunner.builder() .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("Europe/Madrid")) - .build(); - - DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session) .setExtraProperties(extraProperties) - .build(); - - // Install the TPCH plugin for test data (not in Delta format) - queryRunner.installPlugin(new TpchPlugin()); - queryRunner.createCatalog("tpch", "tpch"); - - Path dataDirectory = queryRunner.getCoordinator().getDataDirectory().resolve("delta_metadata"); - Path catalogDirectory = dataDirectory.getParent().resolve("catalog"); - - // Install a Delta connector catalog - queryRunner.installPlugin(new DeltaPlugin()); - Map deltaProperties = ImmutableMap.builder() - .put("hive.metastore", "file") - .put("hive.metastore.catalog.dir", catalogDirectory.toFile().toURI().toString()) - .put("delta.case-sensitive-partitions-enabled", "true") - .build(); - queryRunner.createCatalog(DELTA_CATALOG, "delta", deltaProperties); - - // Install a Hive connector catalog that uses the same metastore as Delta - // This catalog will be used to create tables in metastore as the Delta connector doesn't - // support creating tables yet. 
- queryRunner.installPlugin(new HivePlugin("hive")); - Map hiveProperties = ImmutableMap.builder() - .put("hive.metastore", "file") - .put("hive.metastore.catalog.dir", catalogDirectory.toFile().toURI().toString()) - .put("hive.allow-drop-table", "true") - .put("hive.security", "legacy") - .build(); - queryRunner.createCatalog(HIVE_CATALOG, "hive", hiveProperties); - queryRunner.execute(format("CREATE SCHEMA %s.%s", HIVE_CATALOG, DELTA_SCHEMA)); - - return queryRunner; + .caseSensitivePartitions() + .build() + .getQueryRunner(); } @Test(dataProvider = "deltaReaderVersions") diff --git a/presto-delta/src/test/resources/delta_v1/data-reader-partition-values/as_int=0/as_long=0/as_byte=0/as_short=0/as_boolean=true/as_float=0.0/as_double=0.0/as_string=0/as_date=2021-09-08/as_timestamp=2021-09-08 11%3A11%3A11/as_big_decimal=0/part-00005-2d4e572d-5bdd-43f4-9d13-7c5354deb1f6.c000.snappy.parquet b/presto-delta/src/test/resources/delta_v1/data-reader-partition-values/as_int=0/as_long=0/as_byte=0/as_short=0/as_boolean=true/as_float=0.0/as_double=0.0/as_string=0/as_date=2021-09-08/as_timestamp=2021-09-08 11%3A11%3A11/as_big_decimal=0/part-00005-2d4e572d-5bdd-43f4-9d13-7c5354deb1f6.c000.snappy.parquet new file mode 100644 index 0000000000000..0ff0067f430fc Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/data-reader-partition-values/as_int=0/as_long=0/as_byte=0/as_short=0/as_boolean=true/as_float=0.0/as_double=0.0/as_string=0/as_date=2021-09-08/as_timestamp=2021-09-08 11%3A11%3A11/as_big_decimal=0/part-00005-2d4e572d-5bdd-43f4-9d13-7c5354deb1f6.c000.snappy.parquet differ diff --git a/presto-delta/src/test/resources/delta_v1/data-reader-partition-values/as_int=1/as_long=1/as_byte=1/as_short=1/as_boolean=false/as_float=1.0/as_double=1.0/as_string=1/as_date=2021-09-08/as_timestamp=2021-09-08 11%3A11%3A11/as_big_decimal=1/part-00007-00fd7022-5e9c-4cba-b8c8-2297ef36e5ff.c000.snappy.parquet 
b/presto-delta/src/test/resources/delta_v1/data-reader-partition-values/as_int=1/as_long=1/as_byte=1/as_short=1/as_boolean=false/as_float=1.0/as_double=1.0/as_string=1/as_date=2021-09-08/as_timestamp=2021-09-08 11%3A11%3A11/as_big_decimal=1/part-00007-00fd7022-5e9c-4cba-b8c8-2297ef36e5ff.c000.snappy.parquet new file mode 100644 index 0000000000000..dee0f78633291 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/data-reader-partition-values/as_int=1/as_long=1/as_byte=1/as_short=1/as_boolean=false/as_float=1.0/as_double=1.0/as_string=1/as_date=2021-09-08/as_timestamp=2021-09-08 11%3A11%3A11/as_big_decimal=1/part-00007-00fd7022-5e9c-4cba-b8c8-2297ef36e5ff.c000.snappy.parquet differ diff --git a/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/.part-00000-16ab4bfc-37b2-4e87-961c-e9e18a355eeb-c000.snappy.parquet.crc b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/.part-00000-16ab4bfc-37b2-4e87-961c-e9e18a355eeb-c000.snappy.parquet.crc new file mode 100644 index 0000000000000..83724aa690668 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/.part-00000-16ab4bfc-37b2-4e87-961c-e9e18a355eeb-c000.snappy.parquet.crc differ diff --git a/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/.00000000000000000000.crc.crc b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/.00000000000000000000.crc.crc new file mode 100644 index 0000000000000..c29a2e2a1cf34 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/.00000000000000000000.crc.crc differ diff --git a/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/.00000000000000000000.json.crc b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/.00000000000000000000.json.crc new file mode 100644 index 0000000000000..1c308c0155d14 Binary files /dev/null and 
b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/.00000000000000000000.json.crc differ diff --git a/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/00000000000000000000.crc b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/00000000000000000000.crc new file mode 100644 index 0000000000000..4ae3f7a3665fc --- /dev/null +++ b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/00000000000000000000.crc @@ -0,0 +1 @@ +{"txnId":"1dc8f519-a27a-4e10-8cba-b676268f8757","tableSizeBytes":5526,"numFiles":1,"numMetadata":1,"numProtocol":1,"setTransactions":[],"domainMetadata":[],"metadata":{"id":"3b76f6bb-def8-40b9-acfc-f0c1fd74ff91","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"VendorID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_pickup_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_dropoff_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"passenger_count\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"trip_distance\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"RatecodeID\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"store_and_fwd_flag\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"PULocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"DOLocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"payment_type\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"fare_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"extra\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"mta_tax\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tip_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{
\"name\":\"tolls_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"improvement_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"total_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"congestion_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"airport_fee\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1738058626356},"protocol":{"minReaderVersion":1,"minWriterVersion":2},"allFiles":[{"path":"part-00000-16ab4bfc-37b2-4e87-961c-e9e18a355eeb-c000.snappy.parquet","partitionValues":{},"size":5526,"modificationTime":1738058627336,"dataChange":false,"stats":"{\"numRecords\":1,\"minValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2022-01-01T00:35:40.000+03:00\",\"tpep_dropoff_datetime\":\"2022-01-01T00:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"maxValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2022-01-01T00:35:40.000+03:00\",\"tpep_dropoff_datetime\":\"2022-01-01T00:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"nullCount\":{\"VendorID\":0,\"tpep_pickup_datetime\":0,\"tpep_dropoff_datetime\":0,\"passenger_count\":0,\"trip_distance\":0,\"RatecodeID\":0,\"store_and_fwd_flag\":0,\"PULocationID\":0,\"DOLocationID\":0,\"payment_type\":0,\"fare_amount\":0,\"e
xtra\":0,\"mta_tax\":0,\"tip_amount\":0,\"tolls_amount\":0,\"improvement_surcharge\":0,\"total_amount\":0,\"congestion_surcharge\":0,\"airport_fee\":0}}"}]} diff --git a/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/00000000000000000000.json b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/00000000000000000000.json new file mode 100644 index 0000000000000..f9aa0a0ef36fd --- /dev/null +++ b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/_delta_log/00000000000000000000.json @@ -0,0 +1,4 @@ +{"commitInfo":{"timestamp":1738058627380,"operation":"WRITE","operationParameters":{"mode":"ErrorIfExists","partitionBy":"[]"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"5526"},"engineInfo":"Apache-Spark/3.5.4 Delta-Lake/3.3.0","txnId":"1dc8f519-a27a-4e10-8cba-b676268f8757"}} +{"metaData":{"id":"3b76f6bb-def8-40b9-acfc-f0c1fd74ff91","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"VendorID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_pickup_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_dropoff_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"passenger_count\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"trip_distance\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"RatecodeID\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"store_and_fwd_flag\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"PULocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"DOLocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"payment_type\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"fare_amount\",\"type\":\"double\",\"nullable\":true,\"
metadata\":{}},{\"name\":\"extra\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"mta_tax\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tip_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tolls_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"improvement_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"total_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"congestion_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"airport_fee\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1738058626356}} +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} +{"add":{"path":"part-00000-16ab4bfc-37b2-4e87-961c-e9e18a355eeb-c000.snappy.parquet","partitionValues":{},"size":5526,"modificationTime":1738058627336,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2022-01-01T00:35:40.000+03:00\",\"tpep_dropoff_datetime\":\"2022-01-01T00:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"maxValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2022-01-01T00:35:40.000+03:00\",\"tpep_dropoff_datetime\":\"2022-01-01T00:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},
\"nullCount\":{\"VendorID\":0,\"tpep_pickup_datetime\":0,\"tpep_dropoff_datetime\":0,\"passenger_count\":0,\"trip_distance\":0,\"RatecodeID\":0,\"store_and_fwd_flag\":0,\"PULocationID\":0,\"DOLocationID\":0,\"payment_type\":0,\"fare_amount\":0,\"extra\":0,\"mta_tax\":0,\"tip_amount\":0,\"tolls_amount\":0,\"improvement_surcharge\":0,\"total_amount\":0,\"congestion_surcharge\":0,\"airport_fee\":0}}"}} diff --git a/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/part-00000-16ab4bfc-37b2-4e87-961c-e9e18a355eeb-c000.snappy.parquet b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/part-00000-16ab4bfc-37b2-4e87-961c-e9e18a355eeb-c000.snappy.parquet new file mode 100644 index 0000000000000..b9df9a77b19b4 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/snapshot-data3/test-typing1/part-00000-16ab4bfc-37b2-4e87-961c-e9e18a355eeb-c000.snappy.parquet differ diff --git a/presto-delta/src/test/resources/delta_v1/test-typing/.part-00000-4705e2d5-430d-42da-a737-4921ed7fc950-c000.snappy.parquet.crc b/presto-delta/src/test/resources/delta_v1/test-typing/.part-00000-4705e2d5-430d-42da-a737-4921ed7fc950-c000.snappy.parquet.crc new file mode 100644 index 0000000000000..2b3b1cafa5263 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/test-typing/.part-00000-4705e2d5-430d-42da-a737-4921ed7fc950-c000.snappy.parquet.crc differ diff --git a/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/.00000000000000000000.crc.crc b/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/.00000000000000000000.crc.crc new file mode 100644 index 0000000000000..e8cf87488b4a0 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/.00000000000000000000.crc.crc differ diff --git a/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/.00000000000000000000.json.crc b/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/.00000000000000000000.json.crc 
new file mode 100644 index 0000000000000..3edb1af2ddc9a Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/.00000000000000000000.json.crc differ diff --git a/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/00000000000000000000.crc b/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/00000000000000000000.crc new file mode 100644 index 0000000000000..741f73a9d85cb --- /dev/null +++ b/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/00000000000000000000.crc @@ -0,0 +1 @@ +{"txnId":"f768aae4-300a-4cab-819f-949d69bf126f","tableSizeBytes":5526,"numFiles":1,"numMetadata":1,"numProtocol":1,"setTransactions":[],"domainMetadata":[],"metadata":{"id":"bf378016-c2bd-406a-b90c-59d260f8ff52","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"VendorID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_pickup_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_dropoff_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"passenger_count\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"trip_distance\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"RatecodeID\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"store_and_fwd_flag\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"PULocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"DOLocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"payment_type\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"fare_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"extra\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"mta_tax\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tip_amount\",\"type\":\"double\",\"nullable\":true
,\"metadata\":{}},{\"name\":\"tolls_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"improvement_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"total_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"congestion_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"airport_fee\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1738136311793},"protocol":{"minReaderVersion":1,"minWriterVersion":2},"allFiles":[{"path":"part-00000-4705e2d5-430d-42da-a737-4921ed7fc950-c000.snappy.parquet","partitionValues":{},"size":5526,"modificationTime":1738136312747,"dataChange":false,"stats":"{\"numRecords\":1,\"minValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2021-12-31T19:35:40.000+03:00\",\"tpep_dropoff_datetime\":\"2021-12-31T19:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"maxValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2021-12-31T19:35:40.000+03:00\",\"tpep_dropoff_datetime\":\"2021-12-31T19:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"nullCount\":{\"VendorID\":0,\"tpep_pickup_datetime\":0,\"tpep_dropoff_datetime\":0,\"passenger_count\":0,\"trip_distance\":0,\"RatecodeID\":0,\"store_and_fwd_flag\":0,\"PULocationID\":0,\"DOLocationID\":0,\"payment_type\":0,\"
fare_amount\":0,\"extra\":0,\"mta_tax\":0,\"tip_amount\":0,\"tolls_amount\":0,\"improvement_surcharge\":0,\"total_amount\":0,\"congestion_surcharge\":0,\"airport_fee\":0}}"}]} diff --git a/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/00000000000000000000.json b/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/00000000000000000000.json new file mode 100644 index 0000000000000..179c604aaa68f --- /dev/null +++ b/presto-delta/src/test/resources/delta_v1/test-typing/_delta_log/00000000000000000000.json @@ -0,0 +1,4 @@ +{"commitInfo":{"timestamp":1738136312822,"operation":"WRITE","operationParameters":{"mode":"ErrorIfExists","partitionBy":"[]"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"5526"},"engineInfo":"Apache-Spark/3.5.4 Delta-Lake/3.3.0","txnId":"f768aae4-300a-4cab-819f-949d69bf126f"}} +{"metaData":{"id":"bf378016-c2bd-406a-b90c-59d260f8ff52","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"VendorID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_pickup_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_dropoff_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"passenger_count\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"trip_distance\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"RatecodeID\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"store_and_fwd_flag\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"PULocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"DOLocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"payment_type\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"fare_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"ex
tra\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"mta_tax\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tip_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tolls_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"improvement_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"total_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"congestion_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"airport_fee\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1738136311793}} +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} +{"add":{"path":"part-00000-4705e2d5-430d-42da-a737-4921ed7fc950-c000.snappy.parquet","partitionValues":{},"size":5526,"modificationTime":1738136312747,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2021-12-31T19:35:40.000+03:00\",\"tpep_dropoff_datetime\":\"2021-12-31T19:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"maxValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2021-12-31T19:35:40.000+03:00\",\"tpep_dropoff_datetime\":\"2021-12-31T19:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"nullCount\":{\"VendorID\":0
,\"tpep_pickup_datetime\":0,\"tpep_dropoff_datetime\":0,\"passenger_count\":0,\"trip_distance\":0,\"RatecodeID\":0,\"store_and_fwd_flag\":0,\"PULocationID\":0,\"DOLocationID\":0,\"payment_type\":0,\"fare_amount\":0,\"extra\":0,\"mta_tax\":0,\"tip_amount\":0,\"tolls_amount\":0,\"improvement_surcharge\":0,\"total_amount\":0,\"congestion_surcharge\":0,\"airport_fee\":0}}"}} diff --git a/presto-delta/src/test/resources/delta_v1/test-typing/part-00000-4705e2d5-430d-42da-a737-4921ed7fc950-c000.snappy.parquet b/presto-delta/src/test/resources/delta_v1/test-typing/part-00000-4705e2d5-430d-42da-a737-4921ed7fc950-c000.snappy.parquet new file mode 100644 index 0000000000000..5b9c94981a9f1 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/test-typing/part-00000-4705e2d5-430d-42da-a737-4921ed7fc950-c000.snappy.parquet differ diff --git a/presto-delta/src/test/resources/delta_v1/timestamp_64/_delta_log/00000000000000000000.json b/presto-delta/src/test/resources/delta_v1/timestamp_64/_delta_log/00000000000000000000.json new file mode 100644 index 0000000000000..f1bb9087020e3 --- /dev/null +++ b/presto-delta/src/test/resources/delta_v1/timestamp_64/_delta_log/00000000000000000000.json @@ -0,0 +1,4 @@ +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} 
+{"metaData":{"id":"2c5549b9-37d0-42ae-98fb-eb42570a1a3a","name":null,"description":null,"format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"product_id\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"product\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"sales_price\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"qt\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"available\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}},{\"name\":\"created_at_tz\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"createdTime":1747905891462,"configuration":{}}} +{"add":{"path":"part-00001-b2260725-1b2b-4dbe-9a2a-6b598cea10da-c000.snappy.parquet","partitionValues":{},"size":2063,"modificationTime":1747905891465,"dataChange":true,"stats":"{\"numRecords\":4,\"minValues\":{\"available\":true,\"sales_price\":1.0,\"created_at_tz\":\"2025-05-22T09:24:11.321346Z\",\"qt\":1,\"product\":\"broculi\",\"product_id\":\"1\"},\"maxValues\":{\"sales_price\":2.0,\"product\":\"water melon\",\"qt\":1,\"product_id\":\"4\",\"available\":true,\"created_at_tz\":\"2025-05-22T09:24:11.321346Z\"},\"nullCount\":{\"sales_price\":0,\"product_id\":0,\"qt\":0,\"created_at_tz\":0,\"available\":0,\"product\":0}}","tags":null,"deletionVector":null,"baseRowId":null,"defaultRowCommitVersion":null,"clusteringProvider":null}} +{"commitInfo":{"timestamp":1747905891465,"operation":"WRITE","operationParameters":{"mode":"ErrorIfExists"},"operationMetrics":{"execution_time_ms":2,"num_added_files":1,"num_added_rows":4,"num_partitions":0,"num_removed_files":0},"clientVersion":"delta-rs.py-0.25.5"}} \ No newline at end of file diff --git a/presto-delta/src/test/resources/delta_v1/timestamp_64/part-00001-b2260725-1b2b-4dbe-9a2a-6b598cea10da-c000.snappy.parquet 
b/presto-delta/src/test/resources/delta_v1/timestamp_64/part-00001-b2260725-1b2b-4dbe-9a2a-6b598cea10da-c000.snappy.parquet new file mode 100644 index 0000000000000..da277a9dd5be9 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v1/timestamp_64/part-00001-b2260725-1b2b-4dbe-9a2a-6b598cea10da-c000.snappy.parquet differ diff --git a/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000000.crc b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000000.crc new file mode 100644 index 0000000000000..a41700080650f --- /dev/null +++ b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000000.crc @@ -0,0 +1 @@ +{"txnId":"43073c59-3636-4134-9608-1f30c3f32ec2","tableSizeBytes":0,"numFiles":0,"numDeletedRecordsOpt":0,"numDeletionVectorsOpt":0,"numMetadata":1,"numProtocol":1,"setTransactions":[],"domainMetadata":[],"metadata":{"id":"379c3206-6fc5-4c37-887b-8128df98f1f5","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"as_int\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_long\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_byte\",\"type\":\"byte\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_short\",\"type\":\"short\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_boolean\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_float\",\"type\":\"float\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_double\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_string\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_date\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_timestamp\",\"type\":\"timestamp_ntz\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_big_decimal\",\"type\":\"decimal(38,18)\",\"nullable\":true,\"metadata\":{}},{\"name\
":\"value\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["as_int","as_long","as_byte","as_short","as_boolean","as_float","as_double","as_string","as_date","as_timestamp","as_big_decimal"],"configuration":{"delta.checkpoint.writeStatsAsJson":"false","delta.checkpoint.writeStatsAsStruct":"true","delta.enableDeletionVectors":"true"},"createdTime":1761674769368},"protocol":{"minReaderVersion":3,"minWriterVersion":7,"readerFeatures":["deletionVectors","timestampNtz"],"writerFeatures":["deletionVectors","timestampNtz","appendOnly","invariants"]},"histogramOpt":{"sortedBinBoundaries":[0,8192,16384,32768,65536,131072,262144,524288,1048576,2097152,4194304,8388608,12582912,16777216,20971520,25165824,29360128,33554432,37748736,41943040,50331648,58720256,67108864,75497472,83886080,92274688,100663296,109051904,117440512,125829120,130023424,134217728,138412032,142606336,146800640,150994944,167772160,184549376,201326592,218103808,234881024,251658240,268435456,285212672,301989888,318767104,335544320,352321536,369098752,385875968,402653184,419430400,436207616,452984832,469762048,486539264,503316480,520093696,536870912,553648128,570425344,587202560,603979776,671088640,738197504,805306368,872415232,939524096,1006632960,1073741824,1140850688,1207959552,1275068416,1342177280,1409286144,1476395008,1610612736,1744830464,1879048192,2013265920,2147483648,2415919104,2684354560,2952790016,3221225472,3489660928,3758096384,4026531840,4294967296,8589934592,17179869184,34359738368,68719476736,137438953472,274877906944],"fileCounts":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"totalBytes":[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]},"deletedRecordCountsHistogramOpt":{"delete
dRecordCounts":[0,0,0,0,0,0,0,0,0,0]},"allFiles":[]} diff --git a/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000000.json b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000000.json index 22679deca8247..e93403dc63c52 100644 --- a/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000000.json +++ b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000000.json @@ -1,6 +1,3 @@ -{"commitInfo":{"timestamp":1713955885069,"operation":"WRITE","operationParameters":{"mode":"Overwrite","partitionBy":"[\"as_int\",\"as_long\",\"as_byte\",\"as_short\",\"as_boolean\",\"as_float\",\"as_double\",\"as_string\",\"as_date\",\"as_timestamp\",\"as_big_decimal\"]"},"isolationLevel":"Serializable","isBlindAppend":false,"operationMetrics":{"numFiles":"3","numOutputRows":"3","numOutputBytes":"1347"},"engineInfo":"Apache-Spark/3.5.1 Delta-Lake/3.1.0","txnId":"9b250066-b90d-4fe5-abf7-04990dc85713"}} 
-{"metaData":{"id":"13d0a1fb-53aa-49f6-b5f0-c5347c805d6f","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"as_int\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_long\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_byte\",\"type\":\"byte\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_short\",\"type\":\"short\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_boolean\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_float\",\"type\":\"float\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_double\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_string\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_date\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_timestamp\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_big_decimal\",\"type\":\"decimal(1,0)\",\"nullable\":true,\"metadata\":{}},{\"name\":\"value\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["as_int","as_long","as_byte","as_short","as_boolean","as_float","as_double","as_string","as_date","as_timestamp","as_big_decimal"],"configuration":{},"createdTime":1713955883329}} -{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} -{"add":{"path":"as_int=1/as_long=1/as_byte=1/as_short=1/as_boolean=false/as_float=1.0/as_double=1.0/as_string=1/as_date=2021-09-08/as_timestamp=2021-09-08%252011%253A11%253A11/as_big_decimal=1/part-00000-889df1d9-6d79-4ea3-8b69-971cec9bf618.c000.snappy.parquet","partitionValues":{"as_big_decimal":"1","as_int":"1","as_byte":"1","as_long":"1","as_date":"2021-09-08","as_string":"1","as_timestamp":"2021-09-08 
11:11:11","as_float":"1.0","as_short":"1","as_boolean":"false","as_double":"1.0"},"size":449,"modificationTime":1713955884743,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"value\":\"1\"},\"maxValues\":{\"value\":\"1\"},\"nullCount\":{\"value\":0}}"}} -{"add":{"path":"as_int=__HIVE_DEFAULT_PARTITION__/as_long=__HIVE_DEFAULT_PARTITION__/as_byte=__HIVE_DEFAULT_PARTITION__/as_short=__HIVE_DEFAULT_PARTITION__/as_boolean=__HIVE_DEFAULT_PARTITION__/as_float=__HIVE_DEFAULT_PARTITION__/as_double=__HIVE_DEFAULT_PARTITION__/as_string=__HIVE_DEFAULT_PARTITION__/as_date=__HIVE_DEFAULT_PARTITION__/as_timestamp=__HIVE_DEFAULT_PARTITION__/as_big_decimal=__HIVE_DEFAULT_PARTITION__/part-00001-6c239ff5-7f3e-4bbd-b06a-4ea89364c08a.c000.snappy.parquet","partitionValues":{"as_big_decimal":null,"as_int":null,"as_byte":null,"as_long":null,"as_date":null,"as_string":null,"as_timestamp":null,"as_float":null,"as_short":null,"as_boolean":null,"as_double":null},"size":449,"modificationTime":1713955884743,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"value\":\"2\"},\"maxValues\":{\"value\":\"2\"},\"nullCount\":{\"value\":0}}"}} -{"add":{"path":"as_int=0/as_long=0/as_byte=0/as_short=0/as_boolean=true/as_float=0.0/as_double=0.0/as_string=0/as_date=2021-09-08/as_timestamp=2021-09-08%252011%253A11%253A11/as_big_decimal=0/part-00002-8e5e2719-f9d0-4e23-8c27-1ac72563c6ab.c000.snappy.parquet","partitionValues":{"as_big_decimal":"0","as_int":"0","as_byte":"0","as_long":"0","as_date":"2021-09-08","as_string":"0","as_timestamp":"2021-09-08 11:11:11","as_float":"0.0","as_short":"0","as_boolean":"true","as_double":"0.0"},"size":449,"modificationTime":1713955884743,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"value\":\"0\"},\"maxValues\":{\"value\":\"0\"},\"nullCount\":{\"value\":0}}"}} +{"commitInfo":{"timestamp":1761674769606,"userId":"6777253476263814","userName":"llozano@denodo.com","operation":"CREATE 
TABLE","operationParameters":{"partitionBy":"[\"as_int\",\"as_long\",\"as_byte\",\"as_short\",\"as_boolean\",\"as_float\",\"as_double\",\"as_string\",\"as_date\",\"as_timestamp\",\"as_big_decimal\"]","clusterBy":"[]","description":null,"isManaged":"false","properties":"{\"delta.checkpoint.writeStatsAsJson\":\"false\",\"delta.checkpoint.writeStatsAsStruct\":\"true\",\"delta.enableDeletionVectors\":\"true\"}","statsOnLoad":false},"job":{"jobId":"","runId":""},"isolationLevel":"WriteSerializable","isBlindAppend":true,"operationMetrics":{},"tags":{"restoresDeletedRows":"false"},"engineInfo":"Databricks-Runtime/17.2.x-photon-scala2.13","txnId":"43073c59-3636-4134-9608-1f30c3f32ec2"}} +{"metaData":{"id":"379c3206-6fc5-4c37-887b-8128df98f1f5","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"as_int\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_long\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_byte\",\"type\":\"byte\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_short\",\"type\":\"short\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_boolean\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_float\",\"type\":\"float\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_double\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_string\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_date\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_timestamp\",\"type\":\"timestamp_ntz\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_big_decimal\",\"type\":\"decimal(38,18)\",\"nullable\":true,\"metadata\":{}},{\"name\":\"value\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["as_int","as_long","as_byte","as_short","as_boolean","as_float","as_double","as_string","as_date","as_timestamp","as_big_decimal"],"configuration":{"delta.checkpoint.writeStatsAsJson":"f
alse","delta.checkpoint.writeStatsAsStruct":"true","delta.enableDeletionVectors":"true"},"createdTime":1761674769368}} +{"protocol":{"minReaderVersion":3,"minWriterVersion":7,"readerFeatures":["deletionVectors","timestampNtz"],"writerFeatures":["deletionVectors","timestampNtz","appendOnly","invariants"]}} diff --git a/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000001.crc b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000001.crc new file mode 100644 index 0000000000000..f7aa514eb85bb --- /dev/null +++ b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000001.crc @@ -0,0 +1 @@ +{"txnId":"65f81a4f-7285-45dd-8ce7-8a20f8aa8ef8","tableSizeBytes":1935,"numFiles":3,"numDeletedRecordsOpt":0,"numDeletionVectorsOpt":0,"numMetadata":1,"numProtocol":1,"setTransactions":[],"domainMetadata":[],"metadata":{"id":"379c3206-6fc5-4c37-887b-8128df98f1f5","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"as_int\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_long\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_byte\",\"type\":\"byte\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_short\",\"type\":\"short\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_boolean\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_float\",\"type\":\"float\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_double\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_string\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_date\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_timestamp\",\"type\":\"timestamp_ntz\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_big_decimal\",\"type\":\"decimal(38,18)\",\"nullable\":true,\"metadata\":{}},{\"name\":\"value\",\"type\":\"strin
g\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["as_int","as_long","as_byte","as_short","as_boolean","as_float","as_double","as_string","as_date","as_timestamp","as_big_decimal"],"configuration":{"delta.checkpoint.writeStatsAsJson":"false","delta.checkpoint.writeStatsAsStruct":"true","delta.enableDeletionVectors":"true"},"createdTime":1761674769368},"protocol":{"minReaderVersion":3,"minWriterVersion":7,"readerFeatures":["deletionVectors","timestampNtz"],"writerFeatures":["deletionVectors","timestampNtz","appendOnly","invariants"]},"histogramOpt":{"sortedBinBoundaries":[0,8192,16384,32768,65536,131072,262144,524288,1048576,2097152,4194304,8388608,12582912,16777216,20971520,25165824,29360128,33554432,37748736,41943040,50331648,58720256,67108864,75497472,83886080,92274688,100663296,109051904,117440512,125829120,130023424,134217728,138412032,142606336,146800640,150994944,167772160,184549376,201326592,218103808,234881024,251658240,268435456,285212672,301989888,318767104,335544320,352321536,369098752,385875968,402653184,419430400,436207616,452984832,469762048,486539264,503316480,520093696,536870912,553648128,570425344,587202560,603979776,671088640,738197504,805306368,872415232,939524096,1006632960,1073741824,1140850688,1207959552,1275068416,1342177280,1409286144,1476395008,1610612736,1744830464,1879048192,2013265920,2147483648,2415919104,2684354560,2952790016,3221225472,3489660928,3758096384,4026531840,4294967296,8589934592,17179869184,34359738368,68719476736,137438953472,274877906944],"fileCounts":[3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"totalBytes":[1935,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]},"deletedRecordCountsHistogramOpt":{"deletedRecordCounts":[3,0,0,0,0
,0,0,0,0,0]},"allFiles":[{"path":"as_int=__HIVE_DEFAULT_PARTITION__/as_long=__HIVE_DEFAULT_PARTITION__/as_byte=__HIVE_DEFAULT_PARTITION__/as_short=__HIVE_DEFAULT_PARTITION__/as_boolean=__HIVE_DEFAULT_PARTITION__/as_float=__HIVE_DEFAULT_PARTITION__/as_double=__HIVE_DEFAULT_PARTITION__/as_string=__HIVE_DEFAULT_PARTITION__/as_date=__HIVE_DEFAULT_PARTITION__/as_timestamp=__HIVE_DEFAULT_PARTITION__/as_big_decimal=__HIVE_DEFAULT_PARTITION__/part-00000-16f19de3-66b6-4c24-8558-f9d84ff07d49.c000.snappy.parquet","partitionValues":{"as_big_decimal":null,"as_int":null,"as_byte":null,"as_string":null,"as_timestamp":null,"as_float":null,"as_short":null,"as_boolean":null,"as_double":null,"as_long":null,"as_date":null},"size":645,"modificationTime":1761674772000,"dataChange":false,"stats":"{\"numRecords\":1,\"minValues\":{\"value\":\"2\"},\"maxValues\":{\"value\":\"2\"},\"nullCount\":{\"value\":0},\"tightBounds\":true}","tags":{"INSERTION_TIME":"1761674772000000","MIN_INSERTION_TIME":"1761674772000000","MAX_INSERTION_TIME":"1761674772000000","OPTIMIZE_TARGET_SIZE":"268435456"}},{"path":"as_int=1/as_long=1/as_byte=1/as_short=1/as_boolean=false/as_float=1.0/as_double=1.0/as_string=1/as_date=2021-09-08/as_timestamp=2021-09-08%2006%253A11%253A11/as_big_decimal=1.000000000000000000/part-00002-6c7a17a8-afab-413a-96af-894d78ab408e.c000.snappy.parquet","partitionValues":{"as_big_decimal":"1.000000000000000000","as_int":"1","as_byte":"1","as_string":"1","as_timestamp":"2021-09-08 
06:11:11","as_float":"1.0","as_short":"1","as_boolean":"false","as_double":"1.0","as_long":"1","as_date":"2021-09-08"},"size":645,"modificationTime":1761674772000,"dataChange":false,"stats":"{\"numRecords\":1,\"minValues\":{\"value\":\"1\"},\"maxValues\":{\"value\":\"1\"},\"nullCount\":{\"value\":0},\"tightBounds\":true}","tags":{"INSERTION_TIME":"1761674772000002","MIN_INSERTION_TIME":"1761674772000002","MAX_INSERTION_TIME":"1761674772000002","OPTIMIZE_TARGET_SIZE":"268435456"}},{"path":"as_int=0/as_long=0/as_byte=0/as_short=0/as_boolean=true/as_float=0.0/as_double=0.0/as_string=0/as_date=2021-09-08/as_timestamp=2021-09-08%2006%253A11%253A11/as_big_decimal=0.000000000000000000/part-00001-60f500b5-1817-43e3-863c-92012bcb0486.c000.snappy.parquet","partitionValues":{"as_big_decimal":"0.000000000000000000","as_int":"0","as_byte":"0","as_string":"0","as_timestamp":"2021-09-08 06:11:11","as_float":"0.0","as_short":"0","as_boolean":"true","as_double":"0.0","as_long":"0","as_date":"2021-09-08"},"size":645,"modificationTime":1761674772000,"dataChange":false,"stats":"{\"numRecords\":1,\"minValues\":{\"value\":\"0\"},\"maxValues\":{\"value\":\"0\"},\"nullCount\":{\"value\":0},\"tightBounds\":true}","tags":{"INSERTION_TIME":"1761674772000001","MIN_INSERTION_TIME":"1761674772000001","MAX_INSERTION_TIME":"1761674772000001","OPTIMIZE_TARGET_SIZE":"268435456"}}]} diff --git a/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000001.json b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000001.json index 4ab72ecf8ed73..3c61b2455d561 100644 --- a/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000001.json +++ b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/_delta_log/00000000000000000001.json @@ -1,3 +1,4 @@ -{"commitInfo":{"timestamp":1713955888815,"operation":"SET 
TBLPROPERTIES","operationParameters":{"properties":"{\"delta.minReaderVersion\":\"3\",\"delta.minWriterVersion\":\"7\"}"},"readVersion":0,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{},"engineInfo":"Apache-Spark/3.5.1 Delta-Lake/3.1.0","txnId":"5e3d0fee-e8f7-4471-b48a-21537d4b27cc"}} -{"metaData":{"id":"13d0a1fb-53aa-49f6-b5f0-c5347c805d6f","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"as_int\",\"type\":\"integer\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_long\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_byte\",\"type\":\"byte\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_short\",\"type\":\"short\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_boolean\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_float\",\"type\":\"float\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_double\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_string\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_date\",\"type\":\"date\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_timestamp\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"as_big_decimal\",\"type\":\"decimal(1,0)\",\"nullable\":true,\"metadata\":{}},{\"name\":\"value\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["as_int","as_long","as_byte","as_short","as_boolean","as_float","as_double","as_string","as_date","as_timestamp","as_big_decimal"],"configuration":{},"createdTime":1713955883329}} -{"protocol":{"minReaderVersion":3,"minWriterVersion":7,"readerFeatures":[],"writerFeatures":["appendOnly","invariants"]}} 
+{"commitInfo":{"timestamp":1761674771298,"userId":"6777253476263814","userName":"llozano@denodo.com","operation":"WRITE","operationParameters":{"mode":"Append","statsOnLoad":false,"partitionBy":"[]"},"job":{"jobId":"","runId":""},"readVersion":0,"isolationLevel":"WriteSerializable","isBlindAppend":true,"operationMetrics":{"numFiles":"3","numOutputRows":"3","numOutputBytes":"1935"},"tags":{"noRowsCopied":"true","restoresDeletedRows":"false"},"engineInfo":"Databricks-Runtime/17.2.x-photon-scala2.13","txnId":"65f81a4f-7285-45dd-8ce7-8a20f8aa8ef8"}} +{"add":{"path":"as_int=__HIVE_DEFAULT_PARTITION__/as_long=__HIVE_DEFAULT_PARTITION__/as_byte=__HIVE_DEFAULT_PARTITION__/as_short=__HIVE_DEFAULT_PARTITION__/as_boolean=__HIVE_DEFAULT_PARTITION__/as_float=__HIVE_DEFAULT_PARTITION__/as_double=__HIVE_DEFAULT_PARTITION__/as_string=__HIVE_DEFAULT_PARTITION__/as_date=__HIVE_DEFAULT_PARTITION__/as_timestamp=__HIVE_DEFAULT_PARTITION__/as_big_decimal=__HIVE_DEFAULT_PARTITION__/part-00000-16f19de3-66b6-4c24-8558-f9d84ff07d49.c000.snappy.parquet","partitionValues":{"as_big_decimal":null,"as_int":null,"as_byte":null,"as_string":null,"as_timestamp":null,"as_float":null,"as_short":null,"as_boolean":null,"as_double":null,"as_long":null,"as_date":null},"size":645,"modificationTime":1761674772000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"value\":\"2\"},\"maxValues\":{\"value\":\"2\"},\"nullCount\":{\"value\":0},\"tightBounds\":true}","tags":{"INSERTION_TIME":"1761674772000000","MIN_INSERTION_TIME":"1761674772000000","MAX_INSERTION_TIME":"1761674772000000","OPTIMIZE_TARGET_SIZE":"268435456"}}} 
+{"add":{"path":"as_int=0/as_long=0/as_byte=0/as_short=0/as_boolean=true/as_float=0.0/as_double=0.0/as_string=0/as_date=2021-09-08/as_timestamp=2021-09-08%2006%253A11%253A11/as_big_decimal=0.000000000000000000/part-00001-60f500b5-1817-43e3-863c-92012bcb0486.c000.snappy.parquet","partitionValues":{"as_big_decimal":"0.000000000000000000","as_int":"0","as_byte":"0","as_string":"0","as_timestamp":"2021-09-08 06:11:11","as_float":"0.0","as_short":"0","as_boolean":"true","as_double":"0.0","as_long":"0","as_date":"2021-09-08"},"size":645,"modificationTime":1761674772000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"value\":\"0\"},\"maxValues\":{\"value\":\"0\"},\"nullCount\":{\"value\":0},\"tightBounds\":true}","tags":{"INSERTION_TIME":"1761674772000001","MIN_INSERTION_TIME":"1761674772000001","MAX_INSERTION_TIME":"1761674772000001","OPTIMIZE_TARGET_SIZE":"268435456"}}} +{"add":{"path":"as_int=1/as_long=1/as_byte=1/as_short=1/as_boolean=false/as_float=1.0/as_double=1.0/as_string=1/as_date=2021-09-08/as_timestamp=2021-09-08%2006%253A11%253A11/as_big_decimal=1.000000000000000000/part-00002-6c7a17a8-afab-413a-96af-894d78ab408e.c000.snappy.parquet","partitionValues":{"as_big_decimal":"1.000000000000000000","as_int":"1","as_byte":"1","as_string":"1","as_timestamp":"2021-09-08 06:11:11","as_float":"1.0","as_short":"1","as_boolean":"false","as_double":"1.0","as_long":"1","as_date":"2021-09-08"},"size":645,"modificationTime":1761674772000,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"value\":\"1\"},\"maxValues\":{\"value\":\"1\"},\"nullCount\":{\"value\":0},\"tightBounds\":true}","tags":{"INSERTION_TIME":"1761674772000002","MIN_INSERTION_TIME":"1761674772000002","MAX_INSERTION_TIME":"1761674772000002","OPTIMIZE_TARGET_SIZE":"268435456"}}} diff --git 
a/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/as_int=0/as_long=0/as_byte=0/as_short=0/as_boolean=true/as_float=0.0/as_double=0.0/as_string=0/as_date=2021-09-08/as_timestamp=2021-09-08 06%3A11%3A11/as_big_decimal=0.000000000000000000/part-00001-60f500b5-1817-43e3-863c-92012bcb0486.c000.snappy.parquet b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/as_int=0/as_long=0/as_byte=0/as_short=0/as_boolean=true/as_float=0.0/as_double=0.0/as_string=0/as_date=2021-09-08/as_timestamp=2021-09-08 06%3A11%3A11/as_big_decimal=0.000000000000000000/part-00001-60f500b5-1817-43e3-863c-92012bcb0486.c000.snappy.parquet new file mode 100644 index 0000000000000..83b6fe21547fb Binary files /dev/null and b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/as_int=0/as_long=0/as_byte=0/as_short=0/as_boolean=true/as_float=0.0/as_double=0.0/as_string=0/as_date=2021-09-08/as_timestamp=2021-09-08 06%3A11%3A11/as_big_decimal=0.000000000000000000/part-00001-60f500b5-1817-43e3-863c-92012bcb0486.c000.snappy.parquet differ diff --git a/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/as_int=1/as_long=1/as_byte=1/as_short=1/as_boolean=false/as_float=1.0/as_double=1.0/as_string=1/as_date=2021-09-08/as_timestamp=2021-09-08 06%3A11%3A11/as_big_decimal=1.000000000000000000/part-00002-6c7a17a8-afab-413a-96af-894d78ab408e.c000.snappy.parquet b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/as_int=1/as_long=1/as_byte=1/as_short=1/as_boolean=false/as_float=1.0/as_double=1.0/as_string=1/as_date=2021-09-08/as_timestamp=2021-09-08 06%3A11%3A11/as_big_decimal=1.000000000000000000/part-00002-6c7a17a8-afab-413a-96af-894d78ab408e.c000.snappy.parquet new file mode 100644 index 0000000000000..e36d00d402eea Binary files /dev/null and 
b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/as_int=1/as_long=1/as_byte=1/as_short=1/as_boolean=false/as_float=1.0/as_double=1.0/as_string=1/as_date=2021-09-08/as_timestamp=2021-09-08 06%3A11%3A11/as_big_decimal=1.000000000000000000/part-00002-6c7a17a8-afab-413a-96af-894d78ab408e.c000.snappy.parquet differ diff --git a/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/as_int=__HIVE_DEFAULT_PARTITION__/as_long=__HIVE_DEFAULT_PARTITION__/as_byte=__HIVE_DEFAULT_PARTITION__/as_short=__HIVE_DEFAULT_PARTITION__/as_boolean=__HIVE_DEFAULT_PARTITION__/as_float=__HIVE_DEFAULT_PARTITION__/as_double=__HIVE_DEFAULT_PARTITION__/as_string=__HIVE_DEFAULT_PARTITION__/as_date=__HIVE_DEFAULT_PARTITION__/as_timestamp=__HIVE_DEFAULT_PARTITION__/as_big_decimal=__HIVE_DEFAULT_PARTITION__/part-00000-16f19de3-66b6-4c24-8558-f9d84ff07d49.c000.snappy.parquet b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/as_int=__HIVE_DEFAULT_PARTITION__/as_long=__HIVE_DEFAULT_PARTITION__/as_byte=__HIVE_DEFAULT_PARTITION__/as_short=__HIVE_DEFAULT_PARTITION__/as_boolean=__HIVE_DEFAULT_PARTITION__/as_float=__HIVE_DEFAULT_PARTITION__/as_double=__HIVE_DEFAULT_PARTITION__/as_string=__HIVE_DEFAULT_PARTITION__/as_date=__HIVE_DEFAULT_PARTITION__/as_timestamp=__HIVE_DEFAULT_PARTITION__/as_big_decimal=__HIVE_DEFAULT_PARTITION__/part-00000-16f19de3-66b6-4c24-8558-f9d84ff07d49.c000.snappy.parquet new file mode 100644 index 0000000000000..aba6075eb8f18 Binary files /dev/null and 
b/presto-delta/src/test/resources/delta_v3/data-reader-partition-values/as_int=__HIVE_DEFAULT_PARTITION__/as_long=__HIVE_DEFAULT_PARTITION__/as_byte=__HIVE_DEFAULT_PARTITION__/as_short=__HIVE_DEFAULT_PARTITION__/as_boolean=__HIVE_DEFAULT_PARTITION__/as_float=__HIVE_DEFAULT_PARTITION__/as_double=__HIVE_DEFAULT_PARTITION__/as_string=__HIVE_DEFAULT_PARTITION__/as_date=__HIVE_DEFAULT_PARTITION__/as_timestamp=__HIVE_DEFAULT_PARTITION__/as_big_decimal=__HIVE_DEFAULT_PARTITION__/part-00000-16f19de3-66b6-4c24-8558-f9d84ff07d49.c000.snappy.parquet differ diff --git a/presto-delta/src/test/resources/delta_v3/test-typing/.part-00000-994a8669-28a4-4c94-a8bb-a437508996d7-c000.snappy.parquet.crc b/presto-delta/src/test/resources/delta_v3/test-typing/.part-00000-994a8669-28a4-4c94-a8bb-a437508996d7-c000.snappy.parquet.crc new file mode 100644 index 0000000000000..21ef729a48a5f Binary files /dev/null and b/presto-delta/src/test/resources/delta_v3/test-typing/.part-00000-994a8669-28a4-4c94-a8bb-a437508996d7-c000.snappy.parquet.crc differ diff --git a/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/.00000000000000000000.crc.crc b/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/.00000000000000000000.crc.crc new file mode 100644 index 0000000000000..c6a4be078e0ec Binary files /dev/null and b/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/.00000000000000000000.crc.crc differ diff --git a/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/.00000000000000000000.json.crc b/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/.00000000000000000000.json.crc new file mode 100644 index 0000000000000..6d33b4e46148a Binary files /dev/null and b/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/.00000000000000000000.json.crc differ diff --git a/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/00000000000000000000.crc 
b/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/00000000000000000000.crc new file mode 100644 index 0000000000000..6d834d1b50143 --- /dev/null +++ b/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/00000000000000000000.crc @@ -0,0 +1 @@ +{"txnId":"c7efdcd9-93db-4ff1-b39f-13f429c718e5","tableSizeBytes":5521,"numFiles":1,"numMetadata":1,"numProtocol":1,"setTransactions":[],"domainMetadata":[],"metadata":{"id":"19eeeb6e-0857-4b50-97bc-41ef9d836887","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"VendorID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_pickup_datetime\",\"type\":\"timestamp_ntz\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_dropoff_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"passenger_count\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"trip_distance\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"RatecodeID\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"store_and_fwd_flag\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"PULocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"DOLocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"payment_type\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"fare_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"extra\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"mta_tax\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tip_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tolls_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"improvement_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"total_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"
congestion_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"airport_fee\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1738136249198},"protocol":{"minReaderVersion":3,"minWriterVersion":7,"readerFeatures":["timestampNtz"],"writerFeatures":["timestampNtz","appendOnly","invariants"]},"allFiles":[{"path":"part-00000-994a8669-28a4-4c94-a8bb-a437508996d7-c000.snappy.parquet","partitionValues":{},"size":5521,"modificationTime":1738136249854,"dataChange":false,"stats":"{\"numRecords\":1,\"minValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2022-01-01T00:35:40.000\",\"tpep_dropoff_datetime\":\"2021-12-31T19:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"maxValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2022-01-01T00:35:40.000\",\"tpep_dropoff_datetime\":\"2021-12-31T19:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"nullCount\":{\"VendorID\":0,\"tpep_pickup_datetime\":0,\"tpep_dropoff_datetime\":0,\"passenger_count\":0,\"trip_distance\":0,\"RatecodeID\":0,\"store_and_fwd_flag\":0,\"PULocationID\":0,\"DOLocationID\":0,\"payment_type\":0,\"fare_amount\":0,\"extra\":0,\"mta_tax\":0,\"tip_amount\":0,\"tolls_amount\":0,\"improvement_surcharge\":0,\"total_amount\":0,\"congestion_surcharge\":0,\"airport_fee\":0}}"}]} diff --git 
a/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/00000000000000000000.json b/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/00000000000000000000.json new file mode 100644 index 0000000000000..70bc1fcdc1505 --- /dev/null +++ b/presto-delta/src/test/resources/delta_v3/test-typing/_delta_log/00000000000000000000.json @@ -0,0 +1,4 @@ +{"commitInfo":{"timestamp":1738136249889,"operation":"WRITE","operationParameters":{"mode":"ErrorIfExists","partitionBy":"[]"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"5521"},"engineInfo":"Apache-Spark/3.5.4 Delta-Lake/3.3.0","txnId":"c7efdcd9-93db-4ff1-b39f-13f429c718e5"}} +{"metaData":{"id":"19eeeb6e-0857-4b50-97bc-41ef9d836887","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"VendorID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_pickup_datetime\",\"type\":\"timestamp_ntz\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tpep_dropoff_datetime\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"passenger_count\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"trip_distance\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"RatecodeID\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"store_and_fwd_flag\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"PULocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"DOLocationID\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"payment_type\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"fare_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"extra\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"mta_tax\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"tip_amount\",\"type\":\"double\",
\"nullable\":true,\"metadata\":{}},{\"name\":\"tolls_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"improvement_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"total_amount\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"congestion_surcharge\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"airport_fee\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"configuration":{},"createdTime":1738136249198}} +{"protocol":{"minReaderVersion":3,"minWriterVersion":7,"readerFeatures":["timestampNtz"],"writerFeatures":["timestampNtz","appendOnly","invariants"]}} +{"add":{"path":"part-00000-994a8669-28a4-4c94-a8bb-a437508996d7-c000.snappy.parquet","partitionValues":{},"size":5521,"modificationTime":1738136249854,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2022-01-01T00:35:40.000\",\"tpep_dropoff_datetime\":\"2021-12-31T19:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"maxValues\":{\"VendorID\":1,\"tpep_pickup_datetime\":\"2022-01-01T00:35:40.000\",\"tpep_dropoff_datetime\":\"2021-12-31T19:53:29.000+03:00\",\"passenger_count\":2.0,\"trip_distance\":3.8,\"RatecodeID\":1.0,\"store_and_fwd_flag\":\"N\",\"PULocationID\":142,\"DOLocationID\":236,\"payment_type\":1,\"fare_amount\":14.5,\"extra\":3.0,\"mta_tax\":0.5,\"tip_amount\":3.65,\"tolls_amount\":0.0,\"improvement_surcharge\":0.3,\"total_amount\":21.95,\"congestion_surcharge\":2.5,\"airport_fee\":0.0},\"nullCount\":{\"VendorID\":0,\"tpep_pickup_datetime\":0,\"tpep_dropoff_datetime\":0,\"passenger_count\":0,\"trip_distance\":0,\"R
atecodeID\":0,\"store_and_fwd_flag\":0,\"PULocationID\":0,\"DOLocationID\":0,\"payment_type\":0,\"fare_amount\":0,\"extra\":0,\"mta_tax\":0,\"tip_amount\":0,\"tolls_amount\":0,\"improvement_surcharge\":0,\"total_amount\":0,\"congestion_surcharge\":0,\"airport_fee\":0}}"}} diff --git a/presto-delta/src/test/resources/delta_v3/test-typing/part-00000-994a8669-28a4-4c94-a8bb-a437508996d7-c000.snappy.parquet b/presto-delta/src/test/resources/delta_v3/test-typing/part-00000-994a8669-28a4-4c94-a8bb-a437508996d7-c000.snappy.parquet new file mode 100644 index 0000000000000..7becceba42d25 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v3/test-typing/part-00000-994a8669-28a4-4c94-a8bb-a437508996d7-c000.snappy.parquet differ diff --git a/presto-delta/src/test/resources/delta_v3/timestamp_64/_delta_log/00000000000000000000.json b/presto-delta/src/test/resources/delta_v3/timestamp_64/_delta_log/00000000000000000000.json new file mode 100644 index 0000000000000..9082ec2179ea9 --- /dev/null +++ b/presto-delta/src/test/resources/delta_v3/timestamp_64/_delta_log/00000000000000000000.json @@ -0,0 +1,4 @@ +{"protocol":{"minReaderVersion":3,"minWriterVersion":7,"readerFeatures":["timestampNtz"],"writerFeatures":["timestampNtz"]}} 
+{"metaData":{"id":"357c1852-a8c1-42cb-91f0-42b0d6d7d5df","name":null,"description":null,"format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"product_id\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"product\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"sales_price\",\"type\":\"double\",\"nullable\":true,\"metadata\":{}},{\"name\":\"qt\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}},{\"name\":\"available\",\"type\":\"boolean\",\"nullable\":true,\"metadata\":{}},{\"name\":\"created_at_tz\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}},{\"name\":\"created_at_ntz\",\"type\":\"timestamp_ntz\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":[],"createdTime":1747905934921,"configuration":{}}} +{"add":{"path":"part-00001-7316599e-eeb8-4a8c-be26-5442eaca8b3f-c000.snappy.parquet","partitionValues":{},"size":2412,"modificationTime":1747905934924,"dataChange":true,"stats":"{\"numRecords\":4,\"minValues\":{\"sales_price\":1.0,\"product_id\":\"1\",\"qt\":1,\"available\":true,\"product\":\"broculi\",\"created_at_tz\":\"2025-05-22T09:24:11.321346Z\",\"created_at_ntz\":\"2025-05-22T12:25:16.544646Z\"},\"maxValues\":{\"qt\":1,\"product\":\"water melon\",\"available\":true,\"created_at_tz\":\"2025-05-22T09:24:11.321346Z\",\"created_at_ntz\":\"2025-05-22T12:25:16.544646Z\",\"product_id\":\"4\",\"sales_price\":2.0},\"nullCount\":{\"created_at_ntz\":0,\"product_id\":0,\"available\":0,\"sales_price\":0,\"qt\":0,\"product\":0,\"created_at_tz\":0}}","tags":null,"deletionVector":null,"baseRowId":null,"defaultRowCommitVersion":null,"clusteringProvider":null}} +{"commitInfo":{"timestamp":1747905934924,"operation":"WRITE","operationParameters":{"mode":"ErrorIfExists"},"operationMetrics":{"execution_time_ms":3,"num_added_files":1,"num_added_rows":4,"num_partitions":0,"num_removed_files":0},"clientVersion":"delta-rs.py-0.25.5"}} \ No newline at end of file diff --git 
a/presto-delta/src/test/resources/delta_v3/timestamp_64/part-00001-7316599e-eeb8-4a8c-be26-5442eaca8b3f-c000.snappy.parquet b/presto-delta/src/test/resources/delta_v3/timestamp_64/part-00001-7316599e-eeb8-4a8c-be26-5442eaca8b3f-c000.snappy.parquet new file mode 100644 index 0000000000000..6dc7859cd5851 Binary files /dev/null and b/presto-delta/src/test/resources/delta_v3/timestamp_64/part-00001-7316599e-eeb8-4a8c-be26-5442eaca8b3f-c000.snappy.parquet differ diff --git a/presto-docs/pom.xml b/presto-docs/pom.xml index e425ed23193c6..322ce3389a79c 100644 --- a/presto-docs/pom.xml +++ b/presto-docs/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-docs @@ -14,6 +14,7 @@ ${project.parent.basedir} + true @@ -36,7 +37,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-maven-plugin diff --git a/presto-docs/requirements.txt b/presto-docs/requirements.txt index f5ee358d6d1e4..3203e335054ca 100644 --- a/presto-docs/requirements.txt +++ b/presto-docs/requirements.txt @@ -1,3 +1,2 @@ sphinx==8.2.1 sphinx-immaterial==0.13.0 -sphinx-copybutton==0.5.2 diff --git a/presto-docs/src/main/sphinx/admin.rst b/presto-docs/src/main/sphinx/admin.rst index da9997630ed5d..9c3bc6f547b5c 100644 --- a/presto-docs/src/main/sphinx/admin.rst +++ b/presto-docs/src/main/sphinx/admin.rst @@ -13,6 +13,7 @@ Administration admin/spill admin/exchange-materialization admin/cte-materialization + admin/materialized-views admin/resource-groups admin/session-property-managers admin/function-namespace-managers @@ -20,3 +21,4 @@ Administration admin/spark admin/verifier admin/grafana-cloud + admin/version-support diff --git a/presto-docs/src/main/sphinx/admin/benchmark-driver.rst b/presto-docs/src/main/sphinx/admin/benchmark-driver.rst index e2dcc0308daa5..cfbabe2741453 100644 --- a/presto-docs/src/main/sphinx/admin/benchmark-driver.rst +++ b/presto-docs/src/main/sphinx/admin/benchmark-driver.rst @@ -8,7 +8,7 @@ It is used to continuously evaluate the 
performance of trunk. Installation ------------ -Download :maven_download:`benchmark-driver`. +Download :github_download:`benchmark-driver`. Rename the JAR file to ``presto-benchmark-driver`` with the following command (replace ``*`` with the version number of the downloaded jar file): diff --git a/presto-docs/src/main/sphinx/admin/function-namespace-managers.rst b/presto-docs/src/main/sphinx/admin/function-namespace-managers.rst index 39d9bdcc9b686..dd1782bfef7d9 100644 --- a/presto-docs/src/main/sphinx/admin/function-namespace-managers.rst +++ b/presto-docs/src/main/sphinx/admin/function-namespace-managers.rst @@ -55,6 +55,12 @@ following contents:: function-namespaces-table-name=example_function_namespaces functions-table-name=example_sql_functions +To use the MariaDB Java driver instead of the MySQL Connector Java +driver, use the following properties for the ``database-`` fields:: + + database-driver-name=org.mariadb.jdbc.Driver + database-url=jdbc:mariadb://example.net:3306/database?user=root&password=password + When Presto first starts with the above MySQL function namespace manager configuration, two MySQL tables will be created if they do not exist. @@ -86,7 +92,11 @@ The following table lists all configuration properties supported by the MySQL fu =========================================== ================================================================================================== Name Description =========================================== ================================================================================================== -``database-url`` The URL of the MySQL database used by the MySQL function namespace manager. +``database-url`` The JDBC URL of the MySQL database used by the MySQL function namespace manager. 
If using the MariaDB Java driver, ensure the URL uses the MariaDB connection string format where the string starts with ``jdbc:mariadb:`` +``database-driver-name`` (optional) The name of the JDBC driver class to use for connecting to the MySQL database. For the MariaDb Java client use ``org.mariadb.jdbc.Driver``. Defaults to ``com.mysql.jdbc.Driver``. +``database-connection-timeout`` (optional) The timeout in milliseconds for establishing a connection to the database. Defaults to 30 seconds. +``database-connection-max-lifetime`` (optional) The maximum lifetime of a connection in milliseconds. Defaults to 30 minutes. +``function-namespace-manager.name`` The name of the function namespace manager to instantiate. Currently, only ``mysql`` is supported. ``function-namespaces-table-name`` The name of the table that stores all the function namespaces managed by this manager. ``functions-table-name`` The name of the table that stores all the functions managed by this manager. =========================================== ================================================================================================== diff --git a/presto-docs/src/main/sphinx/admin/materialized-views.rst b/presto-docs/src/main/sphinx/admin/materialized-views.rst new file mode 100644 index 0000000000000..9e330c5509511 --- /dev/null +++ b/presto-docs/src/main/sphinx/admin/materialized-views.rst @@ -0,0 +1,98 @@ +================== +Materialized Views +================== + +Introduction +------------ + +A materialized view stores the results of a query physically, unlike regular views which are virtual. +Queries can read pre-computed results instead of re-executing the underlying query against base tables. + +Materialized views improve performance for expensive queries by calculating results once during +refresh rather than on every query. Common use cases include aggregations, joins, and filtered +subsets of large tables. + +.. warning:: + + Materialized views are experimental. 
The SPI and behavior may change in future releases. + Use at your own risk in production environments. + + To enable materialized views, set :ref:`admin/properties:\`\`experimental.legacy-materialized-views\`\`` = ``false`` + in your configuration properties. + + Alternatively, you can use ``SET SESSION legacy_materialized_views = false`` to enable them for a session, + but only if :ref:`admin/properties:\`\`experimental.allow-legacy-materialized-views-toggle\`\`` = ``true`` + is set in the server configuration. The toggle option should only be used in non-production environments + for testing and migration purposes. + +Security Modes +-------------- + +When ``legacy_materialized_views=false``, materialized views support SQL standard security modes +that control whose permissions are used when querying the view: + +**SECURITY DEFINER** + The materialized view executes with the permissions of the user who created it. This is the + default mode and matches the behavior of most SQL systems. When using DEFINER mode, column + masks and row filters on base tables are permitted. + +**SECURITY INVOKER** + The materialized view executes with the permissions of the user querying it. Each user must + have appropriate permissions on the underlying base tables. When using INVOKER mode and there + are column masks or row filters on the base tables, the materialized view is treated as stale, + since the data would vary by user. + +The default security mode can be configured using the ``default_view_security_mode`` session +property. When the ``SECURITY`` clause is not specified in ``CREATE MATERIALIZED VIEW``, this +default is used. + +.. note:: + + The ``REFRESH`` operation always uses DEFINER rights regardless of the view's security mode. + +Stale Data Handling +------------------- + +Connectors report the freshness state of materialized views to the engine. 
When a materialized +view is stale (base tables have been modified since the data was last known to be fresh), the +engine determines how to handle the query based on configuration. + +Connectors can configure staleness handling per materialized view, including a behavior setting +and staleness tolerance window. See connector-specific documentation for details (for example, +:ref:`Iceberg `). + +When no per-view configuration is specified, the default behavior is ``USE_VIEW_QUERY`` (Presto +falls back to executing the underlying view query against the base tables). This can be changed +using the ``materialized_view_stale_read_behavior`` session property or the +``materialized-view-stale-read-behavior`` configuration property. Setting it to ``FAIL`` causes +the query to fail with an error when the materialized view is stale. + +Required Permissions +-------------------- + +The following permissions are required for materialized view operations when +``legacy_materialized_views=false``: + +**CREATE MATERIALIZED VIEW** + * ``CREATE TABLE`` permission + * ``CREATE VIEW`` permission + +**REFRESH MATERIALIZED VIEW** + * ``INSERT`` permission (to write new data) + * ``DELETE`` permission (to remove old data) + +**DROP MATERIALIZED VIEW** + * ``DROP TABLE`` permission + * ``DROP VIEW`` permission + +**Querying a materialized view** + * For DEFINER mode: User needs ``SELECT`` permission on the view itself. Additionally, the + view owner must have ``CREATE_VIEW_WITH_SELECT_COLUMNS`` permission on base tables when + non-owners query the view to prevent privilege escalation. 
+ * For INVOKER mode: User needs ``SELECT`` permission on all underlying base tables + +See Also +-------- + +:doc:`/sql/create-materialized-view`, :doc:`/sql/drop-materialized-view`, +:doc:`/sql/refresh-materialized-view`, :doc:`/sql/show-create-materialized-view` \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/admin/properties-session.rst b/presto-docs/src/main/sphinx/admin/properties-session.rst index ab0efbc74a363..afe069be78a15 100644 --- a/presto-docs/src/main/sphinx/admin/properties-session.rst +++ b/presto-docs/src/main/sphinx/admin/properties-session.rst @@ -2,12 +2,12 @@ Presto Session Properties ========================= -This section describes session properties that may be used to tune +This section describes session properties that may be used to tune Presto or alter its behavior when required. -The following is not a complete list of all session properties -available in Presto, and does not include any connector-specific -catalog properties. +The following is not a complete list of all session properties +available in Presto, and does not include any connector-specific +catalog properties. For information on catalog properties, see the :doc:`connector documentation `. @@ -41,9 +41,9 @@ only need to fit in distributed memory across all nodes. When set to ``AUTOMATIC Presto will make a cost based decision as to which distribution type is optimal. It will also consider switching the left and right inputs to the join. In ``AUTOMATIC`` mode, Presto will default to hash distributed joins if no cost could be computed, such as if -the tables do not have statistics. +the tables do not have statistics. -The corresponding configuration property is :ref:`admin/properties:\`\`join-distribution-type\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`join-distribution-type\`\``. ``redistribute_writes`` @@ -58,7 +58,7 @@ across nodes in the cluster. 
It can be disabled when it is known that the output data set is not skewed in order to avoid the overhead of hashing and redistributing all the data across the network. -The corresponding configuration property is :ref:`admin/properties:\`\`redistribute-writes\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`redistribute-writes\`\``. ``task_writer_count`` ^^^^^^^^^^^^^^^^^^^^^ @@ -69,7 +69,7 @@ The corresponding configuration property is :ref:`admin/properties:\`\`redistrib Default number of local parallel table writer threads per worker. It is required to be a power of two for a Java query engine. -The corresponding configuration property is :ref:`admin/properties:\`\`task.writer-count\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`task.writer-count\`\``. ``task_partitioned_writer_count`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -93,6 +93,61 @@ be executed within a single node. The corresponding configuration property is :ref:`admin/properties:\`\`single-node-execution-enabled\`\``. +``offset_clause_enabled`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +To enable the ``OFFSET`` clause in SQL query expressions, set this property to ``true``. + +The corresponding configuration property is :ref:`admin/properties:\`\`offset-clause-enabled\`\``. + +``check_access_control_on_utilized_columns_only`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + +Apply access control rules on only those columns that are required to produce the query output. + +Note: Setting this property to true with the following kinds of queries: + +* queries that have ``USING`` in a join condition +* queries that have duplicate named common table expressions (CTE) + +causes the query to be evaluated as if the property is set to false and checks the access control for all columns. 
+ +To avoid these problems: + +* replace all ``USING`` join conditions in a query with ``ON`` join conditions +* set unique names for all CTEs in a query + +The corresponding configuration property is :ref:`admin/properties:\`\`check-access-control-on-utilized-columns-only\`\``. + +``max_serializable_object_size`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``long`` +* **Default value:** ``1000`` + +Maximum object size in bytes that can be considered serializable in a function call by the coordinator. + +The corresponding configuration property is :ref:`admin/properties:\`\`max-serializable-object-size\`\``. + +``max_prefixes_count`` +^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Minimum value:** ``1`` +* **Default value:** ``100`` + +Maximum number of prefixes (catalog/schema/table scopes used to narrow metadata lookups) that Presto generates when querying information_schema. +If the number of computed prefixes exceeds this limit, Presto falls back to a single broader prefix (catalog only). +If it’s below the limit, the generated prefixes are used. + +The corresponding configuration property is :ref:`admin/properties:\`\`max-prefixes-count\`\``. + Spilling Properties ------------------- @@ -105,13 +160,12 @@ Spilling Properties Try spilling memory to disk to avoid exceeding memory limits for the query. Spilling works by offloading memory to disk. This process can allow a query with a large memory -footprint to pass at the cost of slower execution times. Currently, spilling is supported only for -aggregations and joins (inner and outer), so this property will not reduce memory usage required for -window functions, sorting and other join types. +footprint to pass at the cost of slower execution times. See :ref:`spill-operations` +for a list of operations that support spilling. Be aware that this is an experimental feature and should be used with care. 
-The corresponding configuration property is :ref:`admin/properties:\`\`experimental.spill-enabled\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`experimental.spill-enabled\`\``. ``join_spill_enabled`` ^^^^^^^^^^^^^^^^^^^^^^ @@ -122,7 +176,7 @@ The corresponding configuration property is :ref:`admin/properties:\`\`experimen When ``spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for joins to avoid exceeding memory limits for the query. -The corresponding configuration property is :ref:`admin/properties:\`\`experimental.join-spill-enabled\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`experimental.join-spill-enabled\`\``. ``aggregation_spill_enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -133,7 +187,7 @@ The corresponding configuration property is :ref:`admin/properties:\`\`experimen When ``spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for aggregations to avoid exceeding memory limits for the query. -The corresponding configuration property is :ref:`admin/properties:\`\`experimental.aggregation-spill-enabled\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`experimental.aggregation-spill-enabled\`\``. ``distinct_aggregation_spill_enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -144,7 +198,7 @@ The corresponding configuration property is :ref:`admin/properties:\`\`experimen When ``aggregation_spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for distinct aggregations to avoid exceeding memory limits for the query. -The corresponding configuration property is :ref:`admin/properties:\`\`experimental.distinct-aggregation-spill-enabled\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`experimental.distinct-aggregation-spill-enabled\`\``. 
``order_by_aggregation_spill_enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -155,7 +209,7 @@ The corresponding configuration property is :ref:`admin/properties:\`\`experimen When ``aggregation_spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for order by aggregations to avoid exceeding memory limits for the query. -The corresponding configuration property is :ref:`admin/properties:\`\`experimental.order-by-aggregation-spill-enabled\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`experimental.order-by-aggregation-spill-enabled\`\``. ``window_spill_enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -166,7 +220,7 @@ The corresponding configuration property is :ref:`admin/properties:\`\`experimen When ``spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for window functions to avoid exceeding memory limits for the query. -The corresponding configuration property is :ref:`admin/properties:\`\`experimental.window-spill-enabled\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`experimental.window-spill-enabled\`\``. ``order_by_spill_enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -177,7 +231,7 @@ The corresponding configuration property is :ref:`admin/properties:\`\`experimen When ``spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for order by to avoid exceeding memory limits for the query. -The corresponding configuration property is :ref:`admin/properties:\`\`experimental.order-by-spill-enabled\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`experimental.order-by-spill-enabled\`\``. ``aggregation_operator_unspill_memory_limit`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -187,7 +241,7 @@ The corresponding configuration property is :ref:`admin/properties:\`\`experimen Limit for memory used for unspilling a single aggregation operator instance. 
-The corresponding configuration property is :ref:`admin/properties:\`\`experimental.aggregation-operator-unspill-memory-limit\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`experimental.aggregation-operator-unspill-memory-limit\`\``. Task Properties --------------- @@ -205,9 +259,9 @@ resource utilization. Lower values are better for clusters that run many queries concurrently because the cluster will already be utilized by all the running queries, so adding more concurrency will result in slow downs due to context switching and other overhead. Higher values are better for clusters that only run -one or a few queries at a time. +one or a few queries at a time. -The corresponding configuration property is :ref:`admin/properties:\`\`task.concurrency\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`task.concurrency\`\``. ``task_writer_count`` ^^^^^^^^^^^^^^^^^^^^^ @@ -220,9 +274,9 @@ The number of concurrent writer threads per worker per query. Increasing this va increase write speed, especially when a query is not I/O bound and can take advantage of additional CPU for parallel writes (some connectors can be bottlenecked on CPU when writing due to compression or other factors). Setting this too high may cause the cluster -to become overloaded due to excessive resource utilization. +to become overloaded due to excessive resource utilization. -The corresponding configuration property is :ref:`admin/properties:\`\`task.writer-count\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`task.writer-count\`\``. Optimizer Properties -------------------- @@ -233,9 +287,9 @@ Optimizer Properties * **Type:** ``boolean`` * **Default value:** ``false`` -Enables optimization for aggregations on dictionaries. +Enables optimization for aggregations on dictionaries. -The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.dictionary-aggregation\`\``. 
+The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.dictionary-aggregation\`\``. ``optimize_hash_generation`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -247,12 +301,12 @@ Compute hash codes for distribution, joins, and aggregations early during execut allowing result to be shared between operations later in the query. This can reduce CPU usage by avoiding computing the same hash multiple times, but at the cost of additional network transfer for the hashes. In most cases it will decrease overall -query processing time. +query processing time. It is often helpful to disable this property when using :doc:`/sql/explain` in order to make the query plan easier to read. -The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.optimize-hash-generation\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.optimize-hash-generation\`\``. ``push_aggregation_through_join`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -272,9 +326,9 @@ over an outer join. For example:: Enabling this optimization can substantially speed up queries by reducing the amount of data that needs to be processed by the join. However, it may slow down some -queries that have very selective joins. +queries that have very selective joins. -The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.push-aggregation-through-join\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.push-aggregation-through-join\`\``. ``push_table_write_through_union`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -286,9 +340,9 @@ Parallelize writes when using ``UNION ALL`` in queries that write data. This imp speed of writing output tables in ``UNION ALL`` queries because these writes do not require additional synchronization when collecting results. Enabling this optimization can improve ``UNION ALL`` speed when write speed is not yet saturated. 
However, it may slow down queries -in an already heavily loaded system. +in an already heavily loaded system. -The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.push-table-write-through-union\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.push-table-write-through-union\`\``. ``join_reordering_strategy`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -302,9 +356,9 @@ query. ``ELIMINATE_CROSS_JOINS`` reorders joins to eliminate cross joins where otherwise maintains the original query order. When reordering joins it also strives to maintain the original table order as much as possible. ``AUTOMATIC`` enumerates possible orders and uses statistics-based cost estimation to determine the least cost order. If stats are not available or if -for any reason a cost could not be computed, the ``ELIMINATE_CROSS_JOINS`` strategy is used. +for any reason a cost could not be computed, the ``ELIMINATE_CROSS_JOINS`` strategy is used. -The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.join-reordering-strategy\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.join-reordering-strategy\`\``. ``confidence_based_broadcast`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -316,7 +370,7 @@ Enable broadcasting based on the confidence of the statistics that are being use broadcasting the side of a joinNode which has the highest (``HIGH`` or ``FACT``) confidence statistics. If both sides have the same confidence statistics, then the original behavior will be followed. -The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.confidence-based-broadcast\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.confidence-based-broadcast\`\``. 
``treat-low-confidence-zero-estimation-as-unknown`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -324,9 +378,9 @@ The corresponding configuration property is :ref:`admin/properties:\`\`optimizer * **Type:** ``boolean`` * **Default value:** ``false`` -Enable treating ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. +Enable treating ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. -The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.treat-low-confidence-zero-estimation-as-unknown\`\``. +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.treat-low-confidence-zero-estimation-as-unknown\`\``. ``retry-query-with-history-based-optimization`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -334,7 +388,7 @@ The corresponding configuration property is :ref:`admin/properties:\`\`optimizer * **Type:** ``boolean`` * **Default value:** ``false`` -Enable retry for failed queries who can potentially be helped by HBO. +Enable retry for failed queries who can potentially be helped by HBO. The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.retry-query-with-history-based-optimization\`\``. @@ -358,6 +412,128 @@ The corresponding configuration property is :ref:`admin/properties:\`\`optimizer Enable push down inner join inequality predicates to database. For this configuration to be enabled, :ref:`admin/properties-session:\`\`optimizer_inner_join_pushdown_enabled\`\`` should be set to ``true``. The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.inequality-join-pushdown-enabled\`\``. +``verbose_optimizer_info_enabled`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Use this and ``optimizers_to_enable_verbose_runtime_stats`` in development to collect valuable debugging information about the optimizer. 
+ +Set to ``true`` to use as shown in this example: + +``SET SESSION verbose_optimizer_info_enabled=true;`` + +``optimizers_to_enable_verbose_runtime_stats`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Allowed values:** ``ALL``, an optimizer rule name, or multiple comma-separated optimization rule names +* **Default value:** ``none`` + +Use this and ``verbose_optimizer_info_enabled`` in development to collect valuable debugging information about the optimizer. + +Run the following command to use ``optimizers_to_enable_verbose_runtime_stats``: + +``SET SESSION optimizers_to_enable_verbose_runtime_stats=ALL;`` + +``pushdown_subfields_for_map_functions`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +* **Type:** ``boolean`` +* **Default value:** ``true`` + +Use this to optimize the ``map_filter()`` and ``map_subset()`` functions. + +It controls whether subfield access is executed at the data source or not. + +``pushdown_subfields_for_cardinality`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Enable subfield pruning for the ``cardinality()`` function to skip reading keys and values. + +When enabled, the query optimizer can push down subfield pruning for cardinality operations, +allowing the data source to skip reading the actual keys and values when only the cardinality +(count of elements) is needed. + +``schedule_splits_based_on_task_load`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +* **Type:** ``boolean`` +* **Default value:** ``false`` + +If true then splits are scheduled to the tasks based on task load, rather than on the node load. +This is particularly useful for the native worker as it runs splits for tasks differently than the Java worker. +The corresponding configuration property is :ref:`admin/properties:\`\`node-scheduler.max-splits-per-task\`\``. 
+ +Set to ``true`` to use as shown in this example: + +``SET SESSION schedule_splits_based_on_task_load=true;`` + +``table_scan_shuffle_parallelism_threshold`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``double`` +* **Default value:** ``0.1`` + +Parallelism threshold for adding a shuffle above table scan. When the table's parallelism factor +is below this threshold (0.0-1.0) and ``table_scan_shuffle_strategy`` is ``COST_BASED``, +a round-robin shuffle exchange is added above the table scan to redistribute data. + +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.table-scan-shuffle-parallelism-threshold\`\``. + +``table_scan_shuffle_strategy`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Allowed values:** ``DISABLED``, ``ALWAYS_ENABLED``, ``COST_BASED`` +* **Default value:** ``DISABLED`` + +Strategy for adding shuffle above table scan to redistribute data. When set to ``DISABLED``, +no shuffle is added. When set to ``ALWAYS_ENABLED``, a round-robin shuffle exchange is always +added above table scans. When set to ``COST_BASED``, a shuffle is added only when the table's +parallelism factor is below the ``table_scan_shuffle_parallelism_threshold``. + +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.table-scan-shuffle-strategy\`\``. + +``remote_function_names_for_fixed_parallelism`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Default value:** ``""`` (empty string, disabled) + +A regular expression pattern to match fully qualified remote function names, such as ``catalog.schema.function_name``, +that should use fixed parallelism. When a remote function matches this pattern, the optimizer inserts +round-robin shuffle exchanges before and after the projection containing the remote function call. 
+This ensures that the remote function executes with a fixed degree of parallelism, which can be useful +for controlling resource usage when calling external services. + +This property only applies to external/remote functions (functions where ``isExternalExecution()`` returns ``true``, +such as functions using THRIFT, GRPC, or REST implementation types). + +Example patterns: + +* ``myschema.myfunction`` - matches an exact function name +* ``catalog.schema.remote_.*`` - matches all functions starting with ``remote_`` in the specified catalog and schema +* ``.*remote.*`` - matches any function containing ``remote`` in its fully qualified name + +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.remote-function-names-for-fixed-parallelism\`\``. + +``remote_function_fixed_parallelism_task_count`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Default value:** ``null`` (uses the default hash partition count) + +The number of tasks to use for remote functions matching the ``remote_function_names_for_fixed_parallelism`` pattern. +When set, this value determines the degree of parallelism for the round-robin shuffle exchanges inserted +around matching remote function projections. If not set, the default hash partition count will be used. + +This property is only effective when ``remote_function_names_for_fixed_parallelism`` is set to a non-empty pattern. + +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.remote-function-fixed-parallelism-task-count\`\``. + + JDBC Properties --------------- @@ -418,4 +594,83 @@ Query Manager Properties This property can be used to configure how long a query runs without contact from the client application, such as the CLI, before it's abandoned. -The corresponding configuration property is :ref:`admin/properties:\`\`query.client.timeout\`\``. 
\ No newline at end of file +The corresponding configuration property is :ref:`admin/properties:\`\`query.client.timeout\`\``. + +``query_priority`` +^^^^^^^^^^^^^^^^^^ + +* **Type:** ``int`` +* **Default value:** ``1`` + +This property defines the priority of queries for execution and plays an important role in query admission. +Queries with higher priority are scheduled before the ones with lower priority. A higher number indicates higher priority. + +``query_max_queued_time`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``Duration`` +* **Default value:** ``100d`` + +Use this property to configure how long a query can be queued before it is terminated. + +The corresponding configuration property is :ref:`admin/properties:\`\`query.max-queued-time\`\``. + +View and Materialized View Properties +-------------------------------------- + +``default_view_security_mode`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Allowed values:** ``DEFINER``, ``INVOKER`` +* **Default value:** ``DEFINER`` + +Sets the default security mode for views and materialized views when the ``SECURITY`` +clause is not explicitly specified in ``CREATE VIEW`` or ``CREATE MATERIALIZED VIEW`` +statements. + +* ``DEFINER``: Views execute with the permissions of the user who created them +* ``INVOKER``: Views execute with the permissions of the user querying them + +The corresponding configuration property is :ref:`admin/properties:\`\`default-view-security-mode\`\``. + +``legacy_materialized_views`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + +Use the legacy materialized views implementation. Set to ``false`` to enable the new materialized +views implementation with security modes (DEFINER and INVOKER), automatic query rewriting, and +freshness tracking. + +By default, this session property is locked to the server configuration value and cannot be +changed. 
To allow runtime toggling of this property (for testing/migration purposes only), +set :ref:`admin/properties:\`\`experimental.allow-legacy-materialized-views-toggle\`\`` = ``true`` +in the server configuration. + +The corresponding configuration property is :ref:`admin/properties:\`\`experimental.legacy-materialized-views\`\``. + +``materialized_view_stale_read_behavior`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Default value:** ``USE_VIEW_QUERY`` + +Controls behavior when a materialized view is stale and no per-view staleness config is set. +Valid values are ``FAIL`` (throw an error) or ``USE_VIEW_QUERY`` (query base tables instead). + +The corresponding configuration property is :ref:`admin/properties:\`\`materialized-view-stale-read-behavior\`\``. + +.. warning:: + + Materialized views are experimental. The SPI and behavior may change in future releases. + +``optimizer.optimize_multiple_approx_distinct_on_same_type`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Enable optimization to combine multiple :func:`!approx_distinct` function calls on expressions +of the same type into a single aggregation using ``set_agg`` with array operations (``array_constructor``, ``array_transpose``). diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 997515dae4620..2720d8fffa362 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -2,12 +2,12 @@ Presto Configuration Properties =============================== -This section describes configuration properties that may be used to tune +This section describes configuration properties that may be used to tune Presto or alter its behavior when required. 
-The following is not a complete list of all configuration properties +The following is not a complete list of all configuration properties available in Presto, and does not include any connector-specific -catalog configuration properties. +catalog configuration properties. For information on catalog configuration properties, see the :doc:`connector documentation `. @@ -40,9 +40,9 @@ only need to fit in distributed memory across all nodes. When set to ``AUTOMATIC Presto will make a cost based decision as to which distribution type is optimal. It will also consider switching the left and right inputs to the join. In ``AUTOMATIC`` mode, Presto will default to hash distributed joins if no cost could be computed, such as if -the tables do not have statistics. +the tables do not have statistics. -The corresponding session property is :ref:`admin/properties-session:\`\`join_distribution_type\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`join_distribution_type\`\``. ``redistribute-writes`` @@ -55,9 +55,31 @@ This property enables redistribution of data before writing. This can eliminate the performance impact of data skew when writing by hashing it across nodes in the cluster. It can be disabled when it is known that the output data set is not skewed in order to avoid the overhead of hashing and -redistributing all the data across the network. +redistributing all the data across the network. -The corresponding session property is :ref:`admin/properties-session:\`\`redistribute_writes\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`redistribute_writes\`\``. + +``check-access-control-on-utilized-columns-only`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + +Apply access control rules on only those columns that are required to produce the query output. 
+ +Note: Setting this property to true with the following kinds of queries: + +* queries that have ``USING`` in a join condition +* queries that have duplicate named common table expressions (CTE) + +causes the query to be evaluated as if the property is set to false and checks the access control for all columns. + +To avoid these problems: + +* replace all ``USING`` join conditions in a query with ``ON`` join conditions +* set unique names for all CTEs in a query + +The corresponding session property is :ref:`admin/properties-session:\`\`check_access_control_on_utilized_columns_only\`\``. ``eager-plan-validation-enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -102,8 +124,8 @@ session properties are included. * **Minimum value:** ``0`` * **Default value:** ``0`` -The number of times that a query is automatically retried in the case of a transient query or communications failure. -The default value ``0`` means that retries are disabled. +The number of times that a query is automatically retried in the case of a transient query or communications failure. +The default value ``0`` means that retries are disabled. ``http-server.max-request-header-size`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -111,12 +133,54 @@ The default value ``0`` means that retries are disabled. * **Type:** ``data size`` * **Default value:** ``8 kB`` -The maximum size of the request header from the HTTP server. +The maximum size of the request header from the HTTP server. -Note: The default value can cause errors when large session properties -or other large session information is involved. +Note: The default value can cause errors when large session properties +or other large session information is involved. See :ref:`troubleshoot/query:\`\`Request Header Fields Too Large\`\``. +``offset-clause-enabled`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +To enable the ``OFFSET`` clause in SQL query expressions, set this property to ``true``. 
+ +The corresponding session property is :ref:`admin/properties-session:\`\`offset_clause_enabled\`\``. + +``max-serializable-object-size`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``long`` +* **Default value:** ``1000`` + +Maximum object size in bytes that can be considered serializable in a function call by the coordinator. + +The corresponding session property is :ref:`admin/properties-session:\`\`max_serializable_object_size\`\``. + +``max-prefixes-count`` +^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Minimum value:** ``1`` +* **Default value:** ``100`` + +Maximum number of prefixes (catalog/schema/table scopes used to narrow metadata lookups) that Presto generates when querying information_schema. +If the number of computed prefixes exceeds this limit, Presto falls back to a single broader prefix (catalog only). +If it’s below the limit, the generated prefixes are used. + +The corresponding session property is :ref:`admin/properties-session:\`\`max_prefixes_count\`\``. + +``cluster-tag`` +^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Default value:** (none) + +An optional identifier for the cluster. When set, this tag is included in the response from the +``/v1/cluster`` REST API endpoint, allowing clients to identify which cluster provided the response. + Memory Management Properties ---------------------------- @@ -210,13 +274,12 @@ Spilling Properties Try spilling memory to disk to avoid exceeding memory limits for the query. Spilling works by offloading memory to disk. This process can allow a query with a large memory -footprint to pass at the cost of slower execution times. Currently, spilling is supported only for -aggregations and joins (inner and outer), so this property will not reduce memory usage required for -window functions, sorting and other join types. +footprint to pass at the cost of slower execution times. See :ref:`spill-operations` +for a list of operations that support spilling. 
Be aware that this is an experimental feature and should be used with care. -The corresponding session property is :ref:`admin/properties-session:\`\`spill_enabled\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`spill_enabled\`\``. ``experimental.join-spill-enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -227,7 +290,7 @@ The corresponding session property is :ref:`admin/properties-session:\`\`spill_e When ``spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for joins to avoid exceeding memory limits for the query. -The corresponding session property is :ref:`admin/properties-session:\`\`join_spill_enabled\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`join_spill_enabled\`\``. ``experimental.aggregation-spill-enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -238,7 +301,7 @@ The corresponding session property is :ref:`admin/properties-session:\`\`join_sp When ``spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for aggregations to avoid exceeding memory limits for the query. -The corresponding session property is :ref:`admin/properties-session:\`\`aggregation_spill_enabled\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`aggregation_spill_enabled\`\``. ``experimental.distinct-aggregation-spill-enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -249,7 +312,7 @@ The corresponding session property is :ref:`admin/properties-session:\`\`aggrega When ``aggregation_spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for distinct aggregations to avoid exceeding memory limits for the query. -The corresponding session property is :ref:`admin/properties-session:\`\`distinct_aggregation_spill_enabled\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`distinct_aggregation_spill_enabled\`\``. 
``experimental.order-by-aggregation-spill-enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -260,7 +323,7 @@ The corresponding session property is :ref:`admin/properties-session:\`\`distinc When ``aggregation_spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for order by aggregations to avoid exceeding memory limits for the query. -The corresponding session property is :ref:`admin/properties-session:\`\`order_by_aggregation_spill_enabled\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`order_by_aggregation_spill_enabled\`\``. ``experimental.window-spill-enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -271,7 +334,7 @@ The corresponding session property is :ref:`admin/properties-session:\`\`order_b When ``spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for window functions to avoid exceeding memory limits for the query. -The corresponding session property is :ref:`admin/properties-session:\`\`window_spill_enabled\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`window_spill_enabled\`\``. ``experimental.order-by-spill-enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -282,7 +345,7 @@ The corresponding session property is :ref:`admin/properties-session:\`\`window_ When ``spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for order by to avoid exceeding memory limits for the query. -The corresponding session property is :ref:`admin/properties-session:\`\`order_by_spill_enabled\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`order_by_spill_enabled\`\``. ``experimental.spiller.task-spilling-strategy`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -408,7 +471,7 @@ Max spill space to be used by a single query on a single node. Limit for memory used for unspilling a single aggregation operator instance. 
-The corresponding session property is :ref:`admin/properties-session:\`\`aggregation_operator_unspill_memory_limit\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`aggregation_operator_unspill_memory_limit\`\``. ``experimental.spill-compression-codec`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -529,6 +592,24 @@ shared across all of the partitioned consumers. Increasing this value may improve network throughput for data transferred between stages if the network has high latency or if there are many nodes in the cluster. +``use-connector-provided-serialization-codecs`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Enables the use of custom connector-provided serialization codecs for handles. +This feature allows connectors to use their own serialization format for +handle objects (such as table handles, column handles, and splits) instead +of standard JSON serialization. + +When enabled, connectors that provide a ``ConnectorCodecProvider`` with +appropriate codecs will have their handles serialized using custom binary +formats, which are then Base64-encoded for transport. Connectors without +codec support automatically fall back to standard JSON serialization. +Internal Presto handles (prefixed with ``$``) always use JSON serialization +regardless of this setting. + .. _task-properties: Task Properties @@ -547,9 +628,9 @@ resource utilization. Lower values are better for clusters that run many queries concurrently because the cluster will already be utilized by all the running queries, so adding more concurrency will result in slow downs due to context switching and other overhead. Higher values are better for clusters that only run -one or a few queries at a time. +one or a few queries at a time. -The corresponding session property is :ref:`admin/properties-session:\`\`task_concurrency\`\``. 
+The corresponding session property is :ref:`admin/properties-session:\`\`task_concurrency\`\``. ``task.http-response-threads`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -640,9 +721,9 @@ The number of concurrent writer threads per worker per query. Increasing this va increase write speed, especially when a query is not I/O bound and can take advantage of additional CPU for parallel writes (some connectors can be bottlenecked on CPU when writing due to compression or other factors). Setting this too high may cause the cluster -to become overloaded due to excessive resource utilization. +to become overloaded due to excessive resource utilization. -The corresponding session property is :ref:`admin/properties-session:\`\`task_writer_count\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`task_writer_count\`\``. ``task.interrupt-runaway-splits-timeout`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -683,6 +764,30 @@ due to splits not being balanced across workers. Ideally, it should be set such that there is always at least one split waiting to be processed, but not higher. +``node-scheduler.max-splits-per-task`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Default value:** ``10`` + +The target value for the number of splits that can be running for +each task, assuming all splits have the standard split weight. + +Using a higher value is recommended if task parallelism is higher than 10. +Increasing this value may improve query latency by ensuring that the workers +have enough splits to keep them fully utilized. + +When connectors support weight-based split scheduling, the number of splits +assigned will depend on the weight of the individual splits. If splits are +small, more of them are allowed to be assigned to each worker to compensate. + +Setting this too high will waste memory and may result in lower performance +due to splits not being balanced across workers. 
Ideally, it should be set +such that there is always at least one split waiting to be processed, but +not higher. + +The corresponding session property is :ref:`admin/properties-session:\`\`schedule_splits_based_on_task_load\`\``. + ``node-scheduler.max-pending-splits-per-task`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -734,9 +839,9 @@ Optimizer Properties * **Type:** ``boolean`` * **Default value:** ``false`` -Enables optimization for aggregations on dictionaries. +Enables optimization for aggregations on dictionaries. -The corresponding session property is :ref:`admin/properties-session:\`\`dictionary_aggregation\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`dictionary_aggregation\`\``. ``optimizer.optimize-hash-generation`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -748,12 +853,12 @@ Compute hash codes for distribution, joins, and aggregations early during execut allowing result to be shared between operations later in the query. This can reduce CPU usage by avoiding computing the same hash multiple times, but at the cost of additional network transfer for the hashes. In most cases it will decrease overall -query processing time. +query processing time. It is often helpful to disable this property when using :doc:`/sql/explain` in order to make the query plan easier to read. -The corresponding session property is :ref:`admin/properties-session:\`\`optimize_hash_generation\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`optimize_hash_generation\`\``. ``optimizer.optimize-metadata-queries`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -799,9 +904,9 @@ over an outer join. For example:: Enabling this optimization can substantially speed up queries by reducing the amount of data that needs to be processed by the join. However, it may slow down some -queries that have very selective joins. +queries that have very selective joins. 
-The corresponding session property is :ref:`admin/properties-session:\`\`push_aggregation_through_join\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`push_aggregation_through_join\`\``. ``optimizer.push-table-write-through-union`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -813,9 +918,9 @@ Parallelize writes when using ``UNION ALL`` in queries that write data. This imp speed of writing output tables in ``UNION ALL`` queries because these writes do not require additional synchronization when collecting results. Enabling this optimization can improve ``UNION ALL`` speed when write speed is not yet saturated. However, it may slow down queries -in an already heavily loaded system. +in an already heavily loaded system. -The corresponding session property is :ref:`admin/properties-session:\`\`push_table_write_through_union\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`push_table_write_through_union\`\``. ``optimizer.join-reordering-strategy`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -829,9 +934,9 @@ query. ``ELIMINATE_CROSS_JOINS`` reorders joins to eliminate cross joins where otherwise maintains the original query order. When reordering joins it also strives to maintain the original table order as much as possible. ``AUTOMATIC`` enumerates possible orders and uses statistics-based cost estimation to determine the least cost order. If stats are not available or if -for any reason a cost could not be computed, the ``ELIMINATE_CROSS_JOINS`` strategy is used. +for any reason a cost could not be computed, the ``ELIMINATE_CROSS_JOINS`` strategy is used. -The corresponding session property is :ref:`admin/properties-session:\`\`join_reordering_strategy\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`join_reordering_strategy\`\``. 
``optimizer.max-reordered-joins`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -914,7 +1019,7 @@ Enable broadcasting based on the confidence of the statistics that are being use broadcasting the side of a joinNode which has the highest (``HIGH`` or ``FACT``) confidence statistics. If both sides have the same confidence statistics, then the original behavior will be followed. -The corresponding session property is :ref:`admin/properties-session:\`\`confidence_based_broadcast\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`confidence_based_broadcast\`\``. ``optimizer.treat-low-confidence-zero-estimation-as-unknown`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -922,9 +1027,9 @@ The corresponding session property is :ref:`admin/properties-session:\`\`confide * **Type:** ``boolean`` * **Default value:** ``false`` -Enable treating ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. +Enable treating ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. -The corresponding session property is :ref:`admin/properties-session:\`\`treat-low-confidence-zero-estimation-as-unknown\`\``. +The corresponding session property is :ref:`admin/properties-session:\`\`treat-low-confidence-zero-estimation-as-unknown\`\``. ``optimizer.retry-query-with-history-based-optimization`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -932,7 +1037,7 @@ The corresponding session property is :ref:`admin/properties-session:\`\`treat-l * **Type:** ``boolean`` * **Default value:** ``false`` -Enable retry for failed queries who can potentially be helped by HBO. +Enable retry for failed queries that can potentially be helped by HBO. The corresponding session property is :ref:`admin/properties-session:\`\`retry-query-with-history-based-optimization\`\``.
@@ -963,11 +1068,74 @@ The corresponding session property is :ref:`admin/properties-session:\`\`optimiz * **Default Value:** ``false`` Enables the optimizer to use histograms when available to perform cost estimate calculations -during query optimization. When set to ``false``, this parameter does not prevent histograms -from being collected by ``ANALYZE``, but prevents them from being used during query -optimization. This behavior can be controlled on a per-query basis using the +during query optimization. When set to ``false``, this parameter prevents histograms from +being collected by ``ANALYZE``, and also prevents the existing histograms from being used +during query optimization. This behavior can be controlled on a per-query basis using the ``optimizer_use_histograms`` session property. +``optimizer.table-scan-shuffle-parallelism-threshold`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``double`` +* **Default value:** ``0.1`` + +Parallelism threshold for adding a shuffle above table scan. When the table's parallelism factor +is below this threshold (0.0-1.0) and ``optimizer.table-scan-shuffle-strategy`` is ``COST_BASED``, +a round-robin shuffle exchange is added above the table scan to redistribute data. + +The corresponding session property is :ref:`admin/properties-session:\`\`table_scan_shuffle_parallelism_threshold\`\``. + +``optimizer.table-scan-shuffle-strategy`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Allowed values:** ``DISABLED``, ``ALWAYS_ENABLED``, ``COST_BASED`` +* **Default value:** ``DISABLED`` + +Strategy for adding shuffle above table scan to redistribute data. When set to ``DISABLED``, +no shuffle is added. When set to ``ALWAYS_ENABLED``, a round-robin shuffle exchange is always +added above table scans. When set to ``COST_BASED``, a shuffle is added only when the table's +parallelism factor is below the ``optimizer.table-scan-shuffle-parallelism-threshold``. 
+ +The corresponding session property is :ref:`admin/properties-session:\`\`table_scan_shuffle_strategy\`\``. + +``optimizer.remote-function-names-for-fixed-parallelism`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Default value:** ``""`` (empty string, disabled) + +A regular expression pattern to match fully qualified remote function names, such as ``catalog.schema.function_name``, +that should use fixed parallelism. When a remote function matches this pattern, the optimizer inserts +round-robin shuffle exchanges before and after the projection containing the remote function call. +This ensures that the remote function executes with a fixed degree of parallelism, which can be useful +for controlling resource usage when calling external services. + +This property only applies to external/remote functions (functions where ``isExternalExecution()`` returns ``true``, +such as functions using THRIFT, GRPC, or REST implementation types). + +Example patterns: + +* ``myschema.myfunction`` - matches an exact function name +* ``catalog.schema.remote_.*`` - matches all functions starting with ``remote_`` in the specified catalog and schema +* ``.*remote.*`` - matches any function containing ``remote`` in its fully qualified name + +The corresponding session property is :ref:`admin/properties-session:\`\`remote_function_names_for_fixed_parallelism\`\``. + +``optimizer.remote-function-fixed-parallelism-task-count`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Default value:** ``null`` (uses the default hash partition count) + +The number of tasks to use for remote functions matching the ``optimizer.remote-function-names-for-fixed-parallelism`` pattern. +When set, this value determines the degree of parallelism for the round-robin shuffle exchanges inserted +around matching remote function projections. If not set, the default hash partition count will be used. 
+ +This property is only effective when ``optimizer.remote-function-names-for-fixed-parallelism`` is set to a non-empty pattern. + +The corresponding session property is :ref:`admin/properties-session:\`\`remote_function_fixed_parallelism_task_count\`\``. + Planner Properties ------------------ @@ -1102,4 +1270,156 @@ Query Manager Properties * **Default value:** ``5m`` This property can be used to configure how long a query runs without contact -from the client application, such as the CLI, before it's abandoned. \ No newline at end of file +from the client application, such as the CLI, before it's abandoned. + +The corresponding session property is :ref:`admin/properties-session:\`\`query_client_timeout\`\``. + +``query.max-queued-time`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``Duration`` +* **Default value:** ``100d`` + +Used to configure how long a query can be queued before it is terminated. + +The corresponding session property is :ref:`admin/properties-session:\`\`query_max_queued_time\`\``. + +``query-manager.query-pacing.max-queries-per-second`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Minimum value:** ``1`` +* **Default value:** ``2147483647`` (unlimited) + +Maximum number of queries that can be admitted per second globally across +all resource groups. This property enables query admission pacing to prevent +worker overload when many queries start simultaneously. Pacing only activates +when the number of running queries exceeds the threshold configured by +``query-manager.query-pacing.min-running-queries``. + +Set to a lower value such as ``10`` to limit query admission rate during +periods of high cluster load. The default value effectively disables pacing.
+ +``query-manager.query-pacing.min-running-queries`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Minimum value:** ``0`` +* **Default value:** ``30`` + +Minimum number of running queries required before query admission pacing +is applied. When the total number of running queries is below this threshold, +queries are admitted immediately without rate limiting, regardless of the +``query-manager.query-pacing.max-queries-per-second`` setting. + +This allows the cluster to quickly ramp up when idle while still providing +protection against overload when the cluster is busy. Set to ``0`` to always +apply pacing when ``max-queries-per-second`` is configured. + +Query Retry Properties +---------------------- + +``retry.enabled`` +^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + +Enable cross-cluster retry functionality. When enabled, queries that fail with +specific error codes can be automatically retried on a backup cluster if a +retry URL is provided. + +``retry.allowed-domains`` +^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Default value:** (empty, signifying current second-level domain allowed only) + +Comma-separated list of allowed domains for retry URLs. Supports wildcards +like ``*.example.com``. For example: ``cluster1.example.com,*.backup.example.net``. +When empty (default), only retry URLs from the same domain as the current server +are allowed. + +``retry.require-https`` +^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Require HTTPS for retry URLs. When enabled, only HTTPS URLs will be accepted +for cross-cluster retry operations. + +``retry.cross-cluster-error-codes`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Default value:** ``REMOTE_TASK_ERROR`` + +Comma-separated list of error codes that allow cross-cluster retry. 
When a query +fails with one of these error codes, it can be automatically retried on a backup +cluster if a retry URL is provided. Available error codes include standard Presto +error codes such as ``REMOTE_TASK_ERROR``, ``CLUSTER_OUT_OF_MEMORY``, etc. + +View and Materialized View Properties +------------------------------------- + +``default-view-security-mode`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Allowed values:** ``DEFINER``, ``INVOKER`` +* **Default value:** ``DEFINER`` + +Sets the default security mode for views and materialized views when the ``SECURITY`` +clause is not explicitly specified in ``CREATE VIEW`` or ``CREATE MATERIALIZED VIEW`` +statements. + +* ``DEFINER``: Views execute with the permissions of the user who created them +* ``INVOKER``: Views execute with the permissions of the user querying them + +The corresponding session property is :ref:`admin/properties-session:\`\`default_view_security_mode\`\``. + +``experimental.legacy-materialized-views`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + +Use legacy materialized views implementation. Set to ``false`` to enable materialized +views with security modes (DEFINER and INVOKER), automatic query rewriting, and +freshness tracking. + +The corresponding session property is :ref:`admin/properties-session:\`\`legacy_materialized_views\`\``. + +.. warning:: + + Materialized views are experimental. The SPI and behavior may change in future releases. + +``experimental.allow-legacy-materialized-views-toggle`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Allow the ``legacy_materialized_views`` session property to be changed at runtime. +By default, the session property value is locked to the server configuration value +and cannot be changed per-session. 
+ +Set this to ``true`` to allow users to toggle between legacy and new materialized +views implementations using the session property. This is intended for testing and +migration purposes only. + +.. warning:: + + This should only be enabled in non-production environments. + +``materialized-view-stale-read-behavior`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``string`` +* **Default value:** ``USE_VIEW_QUERY`` + +Controls behavior when a materialized view is stale and no per-view staleness config is set. +Valid values are ``FAIL`` (throw an error) or ``USE_VIEW_QUERY`` (query base tables instead). + +The corresponding session property is :ref:`admin/properties-session:\`\`materialized_view_stale_read_behavior\`\``. diff --git a/presto-docs/src/main/sphinx/admin/resource-groups.rst b/presto-docs/src/main/sphinx/admin/resource-groups.rst index b2693bc224daa..54b48dc14601a 100644 --- a/presto-docs/src/main/sphinx/admin/resource-groups.rst +++ b/presto-docs/src/main/sphinx/admin/resource-groups.rst @@ -154,7 +154,7 @@ Database Resource Group Manager Properties * - ``resource-groups.exact-match-selector-enabled`` - Setting this flag enables usage of an additional ``exact_match_source_selectors`` table to configure resource group - selection rules defined exact name based matches for source, environment + selection rules which define exact name based matches for source, environment and query type. By default, the rules are only loaded from the ``selectors`` table, with a regex-based filter for ``source``, among other filters. @@ -199,7 +199,7 @@ Here are the key properties that can be set for a Resource Group: sub-group is computed based on the weights for all currently eligible sub-groups. The sub-group with the least concurrency relative to its share is selected to start the next query. - ``weighted``: queued queries are selected stochastically in proportion to their priority, - specified via the ``query_priority`` {doc} ``session property ``. 
Sub groups are selected + specified by the :ref:`admin/properties-session:\`\`query_priority\`\``. Sub groups are selected to start new queries in proportion to their ``schedulingWeight``. - ``query_priority``: all sub-groups must also be configured with ``query_priority``. Queued queries are selected strictly according to their priority. @@ -277,6 +277,8 @@ Here are the key components of selector rules in PrestoDB: - ``ANALYZE``: ``ANALYZE`` queries. - ``DATA_DEFINITION``: Queries that alter/create/drop the metadata of schemas/tables/views, and that manage prepared statements, privileges, sessions, and transactions. + - ``CONTROL``: Transaction control queries like ``COMMIT``, ``ROLLBACK`` and session control queries like + ``USE``, ``SET SESSION``. * ``clientTags`` (optional): List of tags. To match, every tag in this list must be in the list of client-provided tags associated with the query. @@ -352,13 +354,13 @@ There are four selectors, that define which queries run in which resource group: dynamically-created per-user pipeline group under the ``global.pipeline`` group. * The fourth selector matches queries that come from BI tools which have a source matching the regular - expression ``jdbc#(?.*)``, and have client provided tags that are a superset of ``hi-pri``. - These are placed in a dynamically-created sub-group under the ``global.pipeline.tools`` group. The dynamic - sub-group is created based on the named variable ``toolname``, which is extracted from the - regular expression for source. + expression ``jdbc#(?.*)``, and have client provided tags that are a superset of ``hipri``. + These are placed in a dynamically created sub-group under the ``global.adhoc`` group. The dynamic sub-groups + are created based on the values of named variables ``toolname`` and ``USER``. The values are derived from the + source regular expression and the query user respectively. 
Consider a query with a source ``jdbc#powerfulbi``, user ``kayla``, and - client tags ``hipri`` and ``fast``. This query is routed to the ``global.pipeline.bi-powerfulbi.kayla`` + client tags ``hipri`` and ``fast``. This query is routed to the ``global.adhoc.bi-powerfulbi.kayla`` resource group. * The last selector is a catch-all, which places all queries that have not yet been matched into a per-user @@ -373,9 +375,9 @@ For the remaining users: * No more than 100 total queries may run concurrently. -* Up to 5 concurrent DDL queries with a source ``pipeline`` can run. Queries are run in FIFO order. +* Up to 5 concurrent DDL queries with a source that includes ``pipeline`` in its name can run. Queries are run in FIFO order. -* Non-DDL queries will run under the ``global.pipeline`` group, with a total concurrency of 45, and a per-user +* Non-DDL queries with a source that includes ``pipeline`` in its name run under the ``global.pipeline`` group, with a total concurrency of 45, and a per-user concurrency of 5. Queries are run in FIFO order. * For BI tools, each tool can run up to 10 concurrent queries, and each user can run up to 3. If the total demand diff --git a/presto-docs/src/main/sphinx/admin/session-property-managers.rst b/presto-docs/src/main/sphinx/admin/session-property-managers.rst index 23b494f6b3852..4035e8dbb474f 100644 --- a/presto-docs/src/main/sphinx/admin/session-property-managers.rst +++ b/presto-docs/src/main/sphinx/admin/session-property-managers.rst @@ -5,10 +5,12 @@ Session Property Managers Administrators can add session properties to control the behavior for subsets of their workload. These properties are defaults and can be overridden by users (if authorized to do so). Session properties can be used to control resource usage, enable or disable features, and change query -characteristics. Session property managers are pluggable. +characteristics. Session property managers are pluggable. 
A session property manager can either +be database-based or file-based. For production environments, the database-based manager is +recommended as the properties can be updated without requiring a cluster restart. -Add an ``etc/session-property-config.properties`` file with the following contents to enable -the built-in manager that reads a JSON config file: +To enable a built-in manager that reads a JSON configuration file, add an +``etc/session-property-config.properties`` file with the following contents: .. code-block:: none @@ -24,6 +26,66 @@ by default. All matching rules contribute to constructing a list of session prop are applied in the order they are specified. Rules specified later in the file override values for properties that have been previously encountered. + +For the database-based built-in manager, add an +``etc/session-property-config.properties`` file with the following contents: + +.. code-block:: text + + session-property-config.configuration-manager=db + session-property-manager.db.url=jdbc:mysql://localhost:3306/session_properties?user=user&password=pass&createDatabaseIfNotExist=true + session-property-manager.db.refresh-period=50s + +Change the value of ``session-property-manager.db.url`` to the JDBC URL of a database. + +``session-property-manager.db.refresh-period`` should be set to how often Presto refreshes +to fetch the latest session properties from the database. + +This database consists of three tables: ``session_specs``, ``session_client_tags`` and ``session_property_values``. +Presto will create the database on startup if you set ``createDatabaseIfNotExist`` to ``true`` in your JDBC URL. +If the tables do not exist, Presto will create them on startup. + +.. 
code-block:: text + + mysql> DESCRIBE session_specs; + +-----------------------------+--------------+------+-----+---------+----------------+ + | Field | Type | Null | Key | Default | Extra | + +-----------------------------+--------------+------+-----+---------+----------------+ + | spec_id | bigint | NO | PRI | NULL | auto_increment | + | user_regex | varchar(512) | YES | | NULL | | + | source_regex | varchar(512) | YES | | NULL | | + | query_type | varchar(512) | YES | | NULL | | + | group_regex | varchar(512) | YES | | NULL | | + | client_info_regex | varchar(512) | YES | | NULL | | + | override_session_properties | tinyint(1) | YES | | NULL | | + | priority | int | NO | | NULL | | + +-----------------------------+--------------+------+-----+---------+----------------+ + 8 rows in set (0.016 sec) + +.. code-block:: text + + mysql> DESCRIBE session_client_tags; + +-------------+--------------+------+-----+---------+-------+ + | Field | Type | Null | Key | Default | Extra | + +-------------+--------------+------+-----+---------+-------+ + | tag_spec_id | bigint | NO | PRI | NULL | | + | client_tag | varchar(512) | NO | PRI | NULL | | + +-------------+--------------+------+-----+---------+-------+ + 2 rows in set (0.062 sec) + +.. 
code-block:: text + + mysql> DESCRIBE session_property_values; + +--------------------------+--------------+------+-----+---------+-------+ + | Field | Type | Null | Key | Default | Extra | + +--------------------------+--------------+------+-----+---------+-------+ + | property_spec_id | bigint | NO | PRI | NULL | | + | session_property_name | varchar(512) | NO | PRI | NULL | | + | session_property_value | varchar(512) | YES | | NULL | | + | session_property_catalog | varchar(512) | YES | | NULL | | + +--------------------------+--------------+------+-----+---------+-------+ + 3 rows in set (0.009 sec) + Match Rules ----------- @@ -52,9 +114,15 @@ Match Rules Note that once a session property has been overridden by ANY rule it remains overridden even if later higher precedence rules change the value, but don't specify override. -* ``sessionProperties``: map with string keys and values. Each entry is a system or catalog property name and +* ``sessionProperties``: map with string keys and values. Each entry is a system property name and corresponding value. Values must be specified as strings, no matter the actual data type. +* ``catalogSessionProperties``: map with string keys corresponding to the catalog name, and a map with string keys + and values as the value. Each entry is a catalog name and corresponding map of session property values. + +* For the database session property manager, catalog & system session properties are located in the same table. + ``session_property_catalog`` should be null for system session properties. + Example ------- @@ -71,6 +139,8 @@ Consider the following set of requirements: * All high memory ETL queries (tagged with 'high_mem_etl') are routed to subgroups under the ``global.pipeline`` group, and must be configured to enable :doc:`/admin/exchange-materialization`. +* All iceberg catalog queries should override the ``delete-as-join-rewrite-enabled`` property + These requirements can be expressed with the following rules: .. 
code-block:: json @@ -79,7 +149,7 @@ These requirements can be expressed with the following rules: { "group": "global.*", "sessionProperties": { - "query_max_execution_time": "8h", + "query_max_execution_time": "8h" } }, { @@ -104,5 +174,12 @@ These requirements can be expressed with the following rules: "partitioning_provider_catalog": "hive", "hash_partition_count": 4096 } + }, + { + "catalogSessionProperties": { + "iceberg": { + "delete_as_join_rewrite_enabled": "true" + } + } } ] diff --git a/presto-docs/src/main/sphinx/admin/spill.rst b/presto-docs/src/main/sphinx/admin/spill.rst index 3dd6f1dec22de..095da4fd72abe 100644 --- a/presto-docs/src/main/sphinx/admin/spill.rst +++ b/presto-docs/src/main/sphinx/admin/spill.rst @@ -30,7 +30,7 @@ of memory to queries and prevents deadlock caused by memory allocation. It is efficient when there are a lot of small queries in the cluster, but leads to killing large queries that don't stay within the limits. -To overcome this inefficiency, the concept of revocable memory was introduced. A +To overcome this limitation, the concept of revocable memory was introduced. A query can request memory that does not count toward the limits, but this memory can be revoked by the memory manager at any time. When memory is revoked, the query runner spills intermediate data from memory to disk and continues to @@ -107,10 +107,7 @@ When spill encryption is enabled (``spill-encryption-enabled`` property in (per spill file) secret key. Enabling this will decrease the performance of spilling to disk but can protect spilled data from being recovered from the files written to disk. -**Note**: Some distributions of Java ship with policy files that limit the strength -of the cryptographic keys that can be used. Spill encryption uses -256-bit AES keys and may require Unlimited Strength :abbr:`JCE (Java Cryptography Extension)` -policy files to work correctly. +.. 
_spill-operations: Supported Operations -------------------- diff --git a/presto-docs/src/main/sphinx/admin/version-support.rst b/presto-docs/src/main/sphinx/admin/version-support.rst new file mode 100644 index 0000000000000..883efecd129de --- /dev/null +++ b/presto-docs/src/main/sphinx/admin/version-support.rst @@ -0,0 +1,230 @@ +=============== +Version Support +=============== + +Overview +-------- + +Presto is maintained by volunteers. This document describes which versions receive support and what level of support to expect. + +Support Philosophy +------------------ + +* Data correctness issues are taken extremely seriously and typically fixed quickly +* Runtime bugs and security vulnerabilities are prioritized and addressed promptly +* Support depends on volunteer availability - no formal SLAs +* Users are encouraged to contribute fixes for issues affecting them + +.. _current-version-support: + +Current Version Support +----------------------- + +**Latest Release** + * Primary focus for bug fixes + * Recommended for new deployments after testing + +**Past 4 Releases (N-1 through N-4)** + * Critical fixes only when: + + - Data correctness issues are found + - Volunteers are available to backport + + * Patch releases for severe issues only + * Support decreases with age + +**Older Releases (N-5 and earlier)** + * Not supported + * Exceptions only when: + + - Volunteer provides the backport + - Fix applies cleanly + - Testing is available + + * Upgrade required + +**Trunk/Master Branch** + * Development branch + * **Never use in production** + * Contains experimental features and bugs + * For testing upcoming changes only + +**Edge Releases** + * Weekly builds from master + * **Never use in production** + * **Not supported** - no fixes provided + * For testing upcoming features + +Support Lifecycle +----------------- + +Timeframes are approximate and depend on volunteer availability. + +A typical release follows this lifecycle: + +1. 
**Release Candidates** (2-4 weeks) + + - One RC version per release + - Active bug fixing with fixes verified in the existing RC + - High community engagement + +2. **Current Release** (approximately 2 months) + + - Primary focus for bug fixes + - Active monitoring for issues + - Most community attention + +3. **Supported Releases** (N-1 through N-4, approximately 8 months) + + - Critical fixes only + - Progressively reduced community focus + - Patch releases for severe issues become less likely with age + +4. **Archived** (N-5 and older) + + - No active support + - Users strongly encouraged to upgrade + - See :ref:`current-version-support` for details + +Types of Support +---------------- + +**Bug Fixes** + Highest priority (typically fixed very quickly): + + * Data correctness issues - taken extremely seriously + + High priority: + + * Runtime bugs and crashes + * Severe performance regressions + + Lower priority: + + * Minor performance issues + * UI/cosmetic problems + * Feature enhancements + +**Security Vulnerabilities** + * Upgrade to latest release (default recommendation) + * Patches for N-1 through N-4 available upon request + * Backport availability depends on volunteers and severity + * Plan to upgrade rather than rely on backports + +**Documentation** + * Release notes and full documentation for all versions remain available + * Migration guides for major changes + * Community-contributed upgrade experiences + +Getting Support +--------------- + +**Community Channels** + +* `Presto Slack `_ - Real-time community discussion +* `GitHub Issues `_ - Bug reports and feature requests +* `Mailing List `_ - Development discussions + +**Self-Support Resources** + +* Release notes and documentation +* Community Slack search history +* GitHub issues and pull requests +* Stack Overflow questions tagged 'presto' + +Recommendations for Production Use +---------------------------------- + +**Version Selection** + +1. 
**For new deployments**: Use the latest stable release after thorough testing +2. **For existing deployments**: Stay within 4 versions of the latest release +3. **For conservative environments**: Wait for at least one patch release (if any) before upgrading +4. **Never use trunk/master or edge** in production + +**Upgrade Strategy** + +* Plan regular upgrades (every 2-4 months) +* Test thoroughly in staging environments +* Monitor community channels for known issues +* Maintain ability to rollback if needed +* Consider skipping releases if stable (but don't fall too far behind) + +**Risk Mitigation** + +* Maintain test environments matching production +* Participate in release candidate testing +* Monitor community discussions for your version +* Contribute test cases for critical workflows + +Contributing to Support +----------------------- + +Ways to contribute: + +**Report Issues** + * File detailed bug reports with reproduction steps on `GitHub Issues `_ + * Test fixes and provide feedback + * Share workarounds with the community + +**Contribute Fixes** + * Submit `pull requests `_ for bugs affecting you + * Help review and test others' fixes + * Backport critical fixes to versions you use + +**Share Knowledge** + * Document upgrade experiences + * Answer questions in `Presto Slack `_ + * Write blog posts about solutions + * Contribute to `documentation `_ + +**Sponsor Development** + * Allocate engineering resources to the project + * Fund specific feature development + * Support maintainers and release shepherds + +Special Considerations +---------------------- + +**Long-Term Support (LTS)** + * Not available + * Volunteer model incompatible with LTS commitments + +**End-of-Life Announcements** + * No formal EOL process + * Versions become unsupported as community moves forward + * Check release announcements for migration guidance + +**Compatibility** + * Breaking changes documented in release notes + * Migration guides provided for major changes + * Test 
when upgrading across multiple versions + +Support Expectations +-------------------- + +**Available:** + +* Typically quick response to data correctness and runtime bugs +* Priority focus on critical issues +* Active community troubleshooting help +* Transparency about known issues +* Documentation for old versions + +**Not Available:** + +* Guaranteed response times +* Fixes for all issues +* Support for old versions +* Feature backports +* 24/7 support + +Summary +------- + +Running Presto in production requires: + +* Regular upgrades (every 2-4 months) +* Thorough testing before deploying +* Understanding that support is volunteer-based +* Contributing fixes for issues you encounter \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/cache/service.rst b/presto-docs/src/main/sphinx/cache/service.rst index f0ce914a123f4..4b11cee30e703 100644 --- a/presto-docs/src/main/sphinx/cache/service.rst +++ b/presto-docs/src/main/sphinx/cache/service.rst @@ -78,7 +78,7 @@ to learn how to configure Alluxio file system in Presto. Here is a simple exampl Start a Prest CLI connecting to the server started in the previous step. -Download :maven_download:`cli`, rename it to ``presto``, +Download :github_download:`cli`, rename it to ``presto``, make it executable with ``chmod +x``, then run it: .. code-block:: none diff --git a/presto-docs/src/main/sphinx/clients/presto-cli.rst b/presto-docs/src/main/sphinx/clients/presto-cli.rst index 8aad3dda889c3..ffeb9733d22d8 100644 --- a/presto-docs/src/main/sphinx/clients/presto-cli.rst +++ b/presto-docs/src/main/sphinx/clients/presto-cli.rst @@ -18,7 +18,7 @@ over HTTP using :doc:`Presto Client REST API `. Installation ============ -Download :maven_download:`cli`. +Download :github_download:`cli`. 
Rename the JAR file to ``presto`` with the following command (replace ``*`` with the version number of the downloaded jar file): diff --git a/presto-docs/src/main/sphinx/conf.py b/presto-docs/src/main/sphinx/conf.py index 4f1e7342c5a16..0ba35583db294 100644 --- a/presto-docs/src/main/sphinx/conf.py +++ b/presto-docs/src/main/sphinx/conf.py @@ -64,7 +64,7 @@ def get_version(): needs_sphinx = '8.2.1' extensions = [ - 'sphinx_immaterial', 'sphinx_copybutton', 'download', 'issue', 'pr', 'sphinx.ext.autosectionlabel' + 'sphinx_immaterial', 'download', 'issue', 'pr', 'sphinx.ext.autosectionlabel' ] copyright = 'The Presto Foundation. All rights reserved. Presto is a registered trademark of LF Projects, LLC' @@ -106,13 +106,8 @@ def get_version(): html_logo = 'images/logo.png' html_favicon = 'images/favicon.ico' -# doesn't seem to do anything -# html_baseurl = 'overview.html' - html_static_path = ['.'] -templates_path = ['_templates'] - # Set the primary domain to js because if left as the default python # the theme errors when functions aren't available in a python module primary_domain = 'js' @@ -135,6 +130,7 @@ def get_version(): 'features': [ 'toc.follow', 'toc.sticky', + 'content.code.copy', ], 'palette': [ { diff --git a/presto-docs/src/main/sphinx/connector.rst b/presto-docs/src/main/sphinx/connector.rst index a45ab8f9bafa4..d337fe4ed12d1 100644 --- a/presto-docs/src/main/sphinx/connector.rst +++ b/presto-docs/src/main/sphinx/connector.rst @@ -14,7 +14,6 @@ from different data sources. 
connector/blackhole connector/cassandra connector/clickhouse - connector/clp connector/deltalake connector/druid connector/elasticsearch diff --git a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst index ba5d4bce9ebfd..a9e0cd73a53c7 100644 --- a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst +++ b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst @@ -52,11 +52,44 @@ Property Name Description ========================================== ============================================================== ``arrow-flight.server`` Endpoint of the Flight server ``arrow-flight.server.port`` Flight server port -``arrow-flight.server-ssl-certificate`` Pass ssl certificate +``arrow-flight.server-ssl-certificate`` Path to SSL certificate of Flight server +``arrow-flight.client-ssl-certificate`` Path to SSL certificate that Flight clients will use for mTLS authentication with the Flight server +``arrow-flight.client-ssl-key`` Path to SSL key that Flight clients will use for mTLS authentication with the Flight server ``arrow-flight.server.verify`` To verify server ``arrow-flight.server-ssl-enabled`` Port is ssl enabled +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema, table, and column names for the connector. When disabled, names are matched case-insensitively using lowercase normalization. Defaults to ``false``. ========================================== ============================================================== +Mutual TLS (mTLS) Support +------------------------- + + +To connect the Presto client to an Arrow Flight server with mutual TLS (mTLS) enabled, you must configure the client to present a valid certificate and key that the server can validate. This enhances security by ensuring both the client and server authenticate each other. 
+ +To enable mTLS, the following properties must be configured: + +- ``arrow-flight.server-ssl-enabled=true``: Explicitly enables TLS for the connection. +- ``arrow-flight.server-ssl-certificate``: Path to the server's SSL certificate. +- ``arrow-flight.client-ssl-certificate``: Path to the client's SSL certificate. +- ``arrow-flight.client-ssl-key``: Path to the client's SSL private key. + +These properties must be used alongside the existing SSL configurations for the server, such as ``arrow-flight.server-ssl-certificate`` and ``arrow-flight.server-ssl-enabled=true``. Make sure the server is configured to trust the client certificates (typically via a shared CA). + +Below is an example code snippet to configure the Arrow Flight server with mTLS: + +.. code-block:: java + + File certChainFile = new File("src/test/resources/certs/server.crt"); + File privateKeyFile = new File("src/test/resources/certs/server.key"); + File caCertFile = new File("src/test/resources/certs/ca.crt"); + + server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator)) + .useTls(certChainFile, privateKeyFile) + .useMTlsClientVerification(caCertFile) + .build(); + + server.start(); + Querying Arrow-Flight --------------------- diff --git a/presto-docs/src/main/sphinx/connector/bigquery.rst b/presto-docs/src/main/sphinx/connector/bigquery.rst index d71c600b758c8..4e90e92555da0 100644 --- a/presto-docs/src/main/sphinx/connector/bigquery.rst +++ b/presto-docs/src/main/sphinx/connector/bigquery.rst @@ -137,6 +137,9 @@ Property Description ``bigquery.max-read-rows-retries`` The number of retries in case of retryable server issues ``3`` ``bigquery.credentials-key`` credentials key (base64 encoded) None. See `authentication <#authentication>`_ ``bigquery.credentials-file`` JSON credentials file path None. 
See `authentication <#authentication>`_ +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema and table ``false`` + names for the connector. When disabled, names are matched + case-insensitively using lowercase normalization. ========================================= ============================================================== ============================================== Data Types diff --git a/presto-docs/src/main/sphinx/connector/cassandra.rst b/presto-docs/src/main/sphinx/connector/cassandra.rst index 44c6fc96a5bed..7be7e244d2cdf 100644 --- a/presto-docs/src/main/sphinx/connector/cassandra.rst +++ b/presto-docs/src/main/sphinx/connector/cassandra.rst @@ -74,6 +74,10 @@ Property Name Description ``cassandra.protocol-version`` It is possible to override the protocol version for older Cassandra clusters. This property defaults to ``V3``. Possible values include ``V2``, ``V3`` and ``V4``. + +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema, table, and column names for the connector. + When disabled, names are matched case-insensitively using lowercase normalization. + Defaults to ``false``. ================================================== ====================================================================== .. note:: @@ -220,6 +224,7 @@ SET VARCHAR TEXT VARCHAR TIMESTAMP TIMESTAMP TIMEUUID VARCHAR +TUPLE VARCHAR VARCHAR VARCHAR VARINT VARCHAR SMALLINT INTEGER @@ -230,7 +235,7 @@ DATE DATE Any collection (LIST/MAP/SET) can be designated as FROZEN, and the value is mapped to VARCHAR. Additionally, blobs have the limitation that they cannot be empty. -Types not mentioned in the table above are not supported (e.g. tuple or UDT). +Data types not listed in the table above, such as UDT, are not supported. 
Partition keys can only be of the following types: diff --git a/presto-docs/src/main/sphinx/connector/clickhouse.rst b/presto-docs/src/main/sphinx/connector/clickhouse.rst index f8510cc8424a4..f2f2f41e4d7d8 100644 --- a/presto-docs/src/main/sphinx/connector/clickhouse.rst +++ b/presto-docs/src/main/sphinx/connector/clickhouse.rst @@ -58,6 +58,10 @@ Property Name Default Value Description ``clickhouse.allow-drop-table`` false Allow delete table operation. +``case-sensitive-name-matching`` false Enable case sensitive identifier support for schema, table, and column names for the connector. + When disabled, names are matched case-insensitively using lowercase normalization. + Defaults to ``false``. + ========================================= ================ ============================================================================================================== diff --git a/presto-docs/src/main/sphinx/connector/deltalake.rst b/presto-docs/src/main/sphinx/connector/deltalake.rst index 96c7899e56972..05393e6102d74 100644 --- a/presto-docs/src/main/sphinx/connector/deltalake.rst +++ b/presto-docs/src/main/sphinx/connector/deltalake.rst @@ -55,8 +55,8 @@ Property Name Description =============================================== ========================================================= ============ Delta Lake connector reuses many of the modules existing in Hive connector. -Modules for connectivity and security such as S3, Azure Data Lake, Glue metastore etc. -So the configurations for these modules is same those available in Hive connector documentation. +Modules for connectivity and security such as S3, Azure Data Lake, and Glue metastore. +Configuration options for these modules are identical to those described in the :doc:`/connector/hive`. Querying Delta Lake Tables -------------------------- @@ -66,21 +66,21 @@ Example query SELECT * FROM sales.apac.sales_data LIMIT 200; -In the above query +In the above query, * ``sales`` refers to the Delta Lake catalog. 
* ``apac`` refers to the database in Hive metastore. * ``sales_data`` refers to the Delta Lake table registered in the ``apac`` database. -If the table is not registered in Hive metastore, it can be registered using the following DDL +If the table is not registered in the Hive metastore, it can be registered using the following DDL command. .. note:: - To register a table in Hive metastore, full schema of the table is not required in DDL + To register a table in Hive metastore, the full schema of the table is not required in DDL as the Delta Lake connector gets the schema from the metadata located at the Delta Lake - table location. To get around no columns error in Hive metastore, provide a dummy column - as schema of the Delta table being registered. + table location. To avoid a ``no columns`` error in Hive metastore, provide a dummy column + as the schema of the Delta table being registered. Examples -------- @@ -105,7 +105,7 @@ Another option is querying the table directly using the table location as table In the above query the schema ``$path$`` indicates the table name is a path. Table name given as `s3://db-sa-datasets/presto/sales_date` is a path where the -Delta Lake table is located. The path based option allows users to query a +Delta Lake table is located. The path-based option allows users to query a Delta table without registering it in the Hive metastore. To query a specific snapshot of the Delta Lake table use the snapshot identifier @@ -133,3 +133,47 @@ in the table ``sales.apac.sales_data``. Above query drops the external table ``sales.apac.sales_data_new``. This only drops the metadata for the table. The referenced data directory is not deleted. + +Delta Lake to PrestoDB type mapping +----------------------------------- + +Map of Delta Lake types to the relevant PrestoDB types: + +.. 
list-table:: Delta Lake to PrestoDB type mapping + :widths: 50, 50 + :header-rows: 1 + + * - Delta Lake type + - PrestoDB type + * - ``BOOLEAN`` + - ``BOOLEAN`` + * - ``SMALLINT`` + - ``SMALLINT`` + * - ``TINYINT`` + - ``TINYINT`` + * - ``INT`` + - ``INTEGER`` + * - ``LONG`` + - ``BIGINT`` + * - ``FLOAT`` + - ``REAL`` + * - ``DOUBLE`` + - ``DOUBLE`` + * - ``DECIMAL`` + - ``DECIMAL`` + * - ``STRING`` + - ``VARCHAR`` + * - ``BINARY`` + - ``VARBINARY`` + * - ``DATE`` + - ``DATE`` + * - ``TIMESTAMP_NTZ`` + - ``TIMESTAMP`` + * - ``TIMESTAMP`` + - ``TIMESTAMP WITH TIME ZONE`` + * - ``ARRAY`` + - ``ARRAY`` + * - ``MAP`` + - ``MAP`` + * - ``STRUCT`` + - ``ROW`` diff --git a/presto-docs/src/main/sphinx/connector/druid.rst b/presto-docs/src/main/sphinx/connector/druid.rst index 4954ee124471b..79d527e1edf4d 100644 --- a/presto-docs/src/main/sphinx/connector/druid.rst +++ b/presto-docs/src/main/sphinx/connector/druid.rst @@ -32,17 +32,21 @@ Configuration Properties The following configuration properties are available: -=================================================== ============================================================ -Property Name Description -=================================================== ============================================================ -``druid.coordinator-url`` Druid coordinator url. -``druid.broker-url`` Druid broker url. -``druid.schema-name`` Druid schema name. -``druid.compute-pushdown-enabled`` Whether to pushdown all query processing to Druid. -``druid.case-insensitive-name-matching`` Match dataset and table names case-insensitively. -``druid.case-insensitive-name-matching.cache-ttl`` Duration for which remote dataset and table names will be - cached. 
Set to ``0ms`` to disable the cache -=================================================== ============================================================ +================================== ============================================================ +Property Name Description +================================== ============================================================ +``druid.coordinator-url`` Druid coordinator url. +``druid.broker-url`` Druid broker url. +``druid.schema-name`` Druid schema name. +``druid.compute-pushdown-enabled`` Whether to pushdown all query processing to Druid. +``case-sensitive-name-matching`` Enable case-sensitive identifier support for schema, + table, and column names for the connector. When disabled, + names are matched case-insensitively using lowercase + normalization. Default is ``false``. +``druid.tls.enabled`` Enable TLS when connecting to Druid. +``druid.tls.truststore-path`` Path to the trust certificate file. +``druid.tls.truststore-password`` Password for the trust certificate file. +================================== ============================================================ ``druid.coordinator-url`` ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -68,19 +72,22 @@ Whether to pushdown all query processing to Druid. the default is ``false``. -``druid.case-insensitive-name-matching`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +``druid.tls.enabled`` +^^^^^^^^^^^^^^^^^^^^^ -Match dataset and table names case-insensitively. +Enable TLS when connecting to Druid. The default is ``false``. -``druid.case-insensitive-name-matching.cache-ttl`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +``druid.tls.truststore-path`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Path to the trust certificate file. -Duration for which remote dataset and table names will be cached. Set to ``0ms`` to disable the cache. +``druid.tls.truststore-password`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The default is ``1m``. +Password for the trust certificate file. 
Data Types ---------- diff --git a/presto-docs/src/main/sphinx/connector/elasticsearch.rst b/presto-docs/src/main/sphinx/connector/elasticsearch.rst index ca427b6fd0797..345079df29d9f 100644 --- a/presto-docs/src/main/sphinx/connector/elasticsearch.rst +++ b/presto-docs/src/main/sphinx/connector/elasticsearch.rst @@ -52,6 +52,9 @@ Property Name Description ``elasticsearch.max-http-connections`` Maximum number of persistent HTTP connections to Elasticsearch. ``elasticsearch.http-thread-count`` Number of threads handling HTTP connections to Elasticsearch. ``elasticsearch.ignore-publish-address`` Whether to ignore the published address and use the configured address. +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema and column names for the connector. + When disabled, names are matched case-insensitively using lowercase normalization. + Default is ``false``. ============================================= ============================================================================== ``elasticsearch.host`` diff --git a/presto-docs/src/main/sphinx/connector/hana.rst b/presto-docs/src/main/sphinx/connector/hana.rst index 7f18e3e156a26..7c53be61a2c45 100644 --- a/presto-docs/src/main/sphinx/connector/hana.rst +++ b/presto-docs/src/main/sphinx/connector/hana.rst @@ -91,6 +91,10 @@ Property Name Description cached. Set to ``0ms`` to disable the cache. ``1m`` ``list-schemas-ignored-schemas`` List of schemas to ignore when listing schemas. ``information_schema`` + +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema and table ``false`` + names for the connector. When disabled, names are matched + case-insensitively using lowercase normalization. 
================================================== ==================================================================== =========== Querying HANA diff --git a/presto-docs/src/main/sphinx/connector/hive-security.rst b/presto-docs/src/main/sphinx/connector/hive-security.rst index cbbcf0b4694aa..e7e43b3ea2b14 100644 --- a/presto-docs/src/main/sphinx/connector/hive-security.rst +++ b/presto-docs/src/main/sphinx/connector/hive-security.rst @@ -349,12 +349,6 @@ Keytab files must be distributed to every node in the cluster that runs Presto. :ref:`Additional Information About Keytab Files.` -Impersonation Accessing the Hive Metastore ------------------------------------------- - -Presto does not currently support impersonating the end user when accessing the -Hive metastore. - .. _configuring-hadoop-impersonation: Impersonation in Hadoop @@ -450,6 +444,18 @@ limitations and differences: * ``SET ROLE ALL`` enables all of a user's roles except ``admin``. * The ``admin`` role must be enabled explicitly by executing ``SET ROLE admin``. +Configuration properties +^^^^^^^^^^^^^^^^^^^^^^^^ + +================================================== ================================================================ ============ +Property Name Description Default +================================================== ================================================================ ============ +``hive.restrict-procedure-call`` A configuration property that controls whether procedure true + calls are restricted. It defaults to ``true``, meaning procedure + calls are not allowed. Set it to ``false`` to allow procedure + calls. +================================================== ================================================================ ============ + .. _hive-file-based-authorization: File Based Authorization @@ -496,6 +502,19 @@ These rules govern who may set session properties. * ``allowed`` (required): boolean indicating whether this session property may be set. 
+Procedure Rules +^^^^^^^^^^^^^^^ + +These rules govern the privileges granted on specific procedures. + +* ``user`` (optional): regex to match against user name. + +* ``schema`` (optional): regex to match against schema name. + +* ``procedure`` (optional): regex to match against procedure name. + +* ``privileges`` (required): a list that is empty or contains ``EXECUTE``. + See below for an example. .. code-block:: json @@ -541,6 +560,24 @@ See below for an example. "property": "max_split_size", "allow": true } + ], + "procedures": [ + { + "user": "admin", + "schema": ".*", + "privileges": ["EXECUTE"] + }, + { + "user": "alice", + "schema": "alice_schema", + "privileges": ["EXECUTE"] + }, + { + "user": "guest", + "schema": "alice_schema", + "procedure": "test_procedure", + "privileges": ["EXECUTE"] + } ] } @@ -608,13 +645,13 @@ properties are enabled by default. Configuration properties ^^^^^^^^^^^^^^^^^^^^^^^^ -================================================== ============================================================ ============ -Property Name Description Default -================================================== ============================================================ ============ +================================================== ================================================================ ============ +Property Name Description Default +================================================== ================================================================ ============ ``hive.ranger.rest-endpoint`` URL address of the Ranger REST service. Kerberos authentication is not supported yet. 
-``hive.ranger.refresh-policy-period`` Interval at which cached policies are refreshed 60s +``hive.ranger.refresh-policy-period`` Interval at which cached policies are refreshed 60s ``hive.ranger.policy.hive-servicename`` Ranger Hive plugin service name @@ -634,4 +671,8 @@ Property Name Description ``ranger.http-client.trust-store-password`` Ranger SSL configuration - client trust-store password -================================================== ============================================================ ============ +``hive.restrict-procedure-call`` A configuration property that controls whether procedure true + calls are restricted. It defaults to ``true``, meaning procedure + calls are not allowed. Set it to ``false`` to allow procedure + calls. +================================================== ================================================================ ============ diff --git a/presto-docs/src/main/sphinx/connector/hive.rst b/presto-docs/src/main/sphinx/connector/hive.rst index 113c5520dd12c..510557d3b38d2 100644 --- a/presto-docs/src/main/sphinx/connector/hive.rst +++ b/presto-docs/src/main/sphinx/connector/hive.rst @@ -37,6 +37,37 @@ The following file types are supported for the Hive connector: * JSON * Text + +Hive Metastore +-------------- + +The Hive Metastore is a central metadata repository that the Hive connector uses to access table definitions, partition information, +and other structural details about your Hive tables. + +The Hive Metastore: + +* Stores metadata about tables, columns, partitions, and storage locations +* Enables schema-on-read functionality +* Supports multiple metastore backends (Apache Hive Metastore Service, AWS Glue) + +See the `Metastore `_ design documentation for more details. 
+ + +Additional Resources for Metastore Configuration +------------------------------------------------ + +* `Metastore Configuration Properties`_ +* `How to invalidate metastore cache?`_ +* :ref:`installation/deployment:File-Based Metastore` +* :doc:`/connector/hive-security` +* `AWS Glue Catalog Configuration Properties`_ + +Additional authentication-related configuration properties are covered in +:ref:`connector/hive-security:Hive Metastore Thrift Service Authentication` and +:ref:`connector/hive-security:HDFS Authentication`. + + + Configuration ------------- @@ -103,10 +134,8 @@ Accessing Hadoop clusters protected with Kerberos authentication Kerberos authentication is supported for both HDFS and the Hive metastore. However, Kerberos authentication by ticket cache is not yet supported. -The properties that apply to Hive connector security are listed in the -`Hive Configuration Properties`_ table. Please see the -:doc:`/connector/hive-security` section for a more detailed discussion of the -security options in the Hive connector. +For authentication-related configuration of the Hive Metastore Thrift service and HDFS, +see :doc:`/connector/hive-security`. File-Based Metastore ^^^^^^^^^^^^^^^^^^^^ @@ -117,98 +146,138 @@ filesystem directory as a Hive Metastore. See :ref:`installation/deployment:File Hive Configuration Properties ----------------------------- -================================================== ============================================================ ============ -Property Name Description Default -================================================== ============================================================ ============ -``hive.metastore.uri`` The URI(s) of the Hive metastore to connect to using the - Thrift protocol. If multiple URIs are provided, the first - URI is used by default and the rest of the URIs are - fallback metastores. This property is required. 
- Example: ``thrift://192.0.2.3:9083`` or - ``thrift://192.0.2.3:9083,thrift://192.0.2.4:9083`` +======================================================== ============================================================ ============ +Property Name Description Default +======================================================== ============================================================ ============ +``hive.metastore.uri`` The URI(s) of the Hive metastore to connect to using the + Thrift protocol. If multiple URIs are provided, the first + URI is used by default and the rest of the URIs are + fallback metastores. This property is required. + Example: ``thrift://192.0.2.3:9083`` or + ``thrift://192.0.2.3:9083,thrift://192.0.2.4:9083`` + +``hive.metastore.username`` The username Presto will use to access the Hive metastore. + +``hive.config.resources`` An optional comma-separated list of HDFS + configuration files. These files must exist on the + machines running Presto. Only specify this if + absolutely necessary to access HDFS. + Example: ``/etc/hdfs-site.xml`` + +``hive.storage-format`` The default file format used when creating new tables. ``ORC`` + +``hive.compression-codec`` The compression codec to use when writing files. ``GZIP`` + +``hive.force-local-scheduling`` Force splits to be scheduled on the same node as the Hadoop ``false`` + DataNode process serving the split data. This is useful for + installations where Presto is collocated with every + DataNode. + +``hive.order-based-execution-enabled`` Enable order-based execution. When enabled, Hive files ``false`` + become non-splittable and the table ordering properties + would be exposed to plan optimizer. + +``hive.respect-table-format`` Should new partitions be written using the existing table ``true`` + format or the default Presto format? -``hive.metastore.username`` The username Presto will use to access the Hive metastore. +``hive.immutable-partitions`` Can new data be inserted into existing partitions? 
``false`` -``hive.config.resources`` An optional comma-separated list of HDFS - configuration files. These files must exist on the - machines running Presto. Only specify this if - absolutely necessary to access HDFS. - Example: ``/etc/hdfs-site.xml`` +``hive.create-empty-bucket-files`` Should empty files be created for buckets that have no data? ``true`` -``hive.storage-format`` The default file format used when creating new tables. ``ORC`` +``hive.max-partitions-per-writers`` Maximum number of partitions per writer. 100 -``hive.compression-codec`` The compression codec to use when writing files. ``GZIP`` +``hive.max-partitions-per-scan`` Maximum number of partitions for a single table scan. 100,000 -``hive.force-local-scheduling`` Force splits to be scheduled on the same node as the Hadoop ``false`` - DataNode process serving the split data. This is useful for - installations where Presto is collocated with every - DataNode. +``hive.dynamic-split-sizes-enabled`` Enable dynamic sizing of splits based on data scanned by ``false`` + the query. -``hive.order-based-execution-enabled`` Enable order-based execution. When it's enabled, hive files ``false`` - become non-splittable and the table ordering properties - would be exposed to plan optimizer +``hive.non-managed-table-writes-enabled`` Enable writes to non-managed (external) Hive tables. ``false`` -``hive.respect-table-format`` Should new partitions be written using the existing table ``true`` - format or the default Presto format? +``hive.non-managed-table-creates-enabled`` Enable creating non-managed (external) Hive tables. ``true`` -``hive.immutable-partitions`` Can new data be inserted into existing partitions? ``false`` +``hive.collect-column-statistics-on-write`` Enables automatic column level statistics collection ``false`` + on write. See `Table Statistics <#table-statistics>`__ for + details. -``hive.create-empty-bucket-files`` Should empty files be created for buckets that have no data? 
``true`` +``hive.s3select-pushdown.enabled`` Enable query pushdown to AWS S3 Select service. ``false`` -``hive.max-partitions-per-writers`` Maximum number of partitions per writer. 100 +``hive.s3select-pushdown.max-connections`` Maximum number of simultaneously open connections to S3 for 500 + S3SelectPushdown. -``hive.max-partitions-per-scan`` Maximum number of partitions for a single table scan. 100,000 +``hive.metastore.load-balancing-enabled`` Enable load balancing between multiple Metastore instances ``false`` -``hive.dynamic-split-sizes-enabled`` Enable dynamic sizing of splits based on data scanned by ``false`` - the query. +``hive.skip-empty-files`` Enable skipping empty files. Otherwise, it will produce an ``false`` + error iterating through empty files. -``hive.metastore.authentication.type`` Hive metastore authentication type. ``NONE`` - Possible values are ``NONE`` or ``KERBEROS``. +``hive.file-status-cache.max-retained-size`` Maximum size in bytes of the directory listing cache ``0KB`` -``hive.metastore.service.principal`` The Kerberos principal of the Hive metastore service. +``hive.metastore.catalog.name`` Specifies the catalog name to be passed to the metastore. -``hive.metastore.client.principal`` The Kerberos principal that Presto will use when connecting - to the Hive metastore service. +``hive.experimental.symlink.optimized-reader.enabled`` Experimental: Enable optimized SymlinkTextInputFormat reader ``true`` -``hive.metastore.client.keytab`` Hive metastore client keytab location. +``hive.copy-on-first-write-configuration-enabled`` Optimize the number of configuration copies by enabling ``false`` + copy-on-write technique. -``hive.hdfs.authentication.type`` HDFS authentication type. ``NONE`` - Possible values are ``NONE`` or ``KERBEROS``. + CopyOnFirstWriteConfiguration acts as a wrapper around the + standard Hadoop Configuration object, extending its + behaviour by introducing an additional layer of + indirection. 
However, many third-party libraries that + integrate with Presto rely directly on the Configuration + copy `constructor`_. Since this constructor does not + recognise or account for the wrapped nature of + CopyOnFirstWriteConfiguration, it can result in silent + failures where critical configuration properties are not + correctly propagated. + + ``hive.orc.use-column-names`` Enable accessing ORC columns by name in the ORC file ``false`` + metadata, instead of their ordinal position. Also toggleable + through the ``hive.orc_use_column_names`` session property. +======================================================== ============================================================ ============ -``hive.hdfs.impersonation.enabled`` Enable HDFS end user impersonation. ``false`` +.. _constructor: https://github.com/apache/hadoop/blob/02a9190af5f8264e25966a80c8f9ea9bb6677899/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/conf/Configuration.java#L844-L875 -``hive.hdfs.presto.principal`` The Kerberos principal that Presto will use when connecting - to HDFS. +Avro Configuration Properties +----------------------------- -``hive.hdfs.presto.keytab`` HDFS client keytab location. +When querying or creating Avro-formatted tables with the Hive connector, you may need to supply or override the Avro schema. In addition, Hive Metastore, especially Hive 3.x, must be configured to read storage schemas for Avro tables. -``hive.security`` See :doc:`hive-security`. +Table Properties +^^^^^^^^^^^^^^^^ -``security.config-file`` Path of config file to use when ``hive.security=file``. - See :ref:`hive-file-based-authorization` for details. +These properties can be used when creating or querying Avro tables in Presto: -``hive.non-managed-table-writes-enabled`` Enable writes to non-managed (external) Hive tables. 
``false`` +======================================================== ============================================================================== ====================================================================================== +Property Name Description Default +======================================================== ============================================================================== ====================================================================================== +``avro_schema_url`` URL or path (HDFS, S3, HTTP, or others) to the Avro schema file for None (must be specified if Metastore does not provide or you need to + reading an Avro-formatted table. If specified, Presto will fetch override schema) + and use this schema instead of relying on any schema in the + Metastore. -``hive.non-managed-table-creates-enabled`` Enable creating non-managed (external) Hive tables. ``true`` +``skip_header_line_count`` Number of header lines to skip when reading CSV or TEXTFILE tables. None (ignored if not set). Must be non-negative. Only valid for + When set to ``1``, a header row will be written when creating new CSV and TEXTFILE formats. Values greater than ``1`` are not + CSV or TEXTFILE tables. supported for ``CREATE TABLE AS`` or ``INSERT`` operations. -``hive.collect-column-statistics-on-write`` Enables automatic column level statistics collection ``false`` - on write. See `Table Statistics <#table-statistics>`__ for - details. +``skip_footer_line_count`` Number of footer lines to skip when reading CSV or TEXTFILE tables. None (ignored if not set). Must be non-negative. Only valid for + Cannot be used when inserting into a table. CSV and TEXTFILE formats. This property is not + supported for ``CREATE TABLE AS`` or ``INSERT`` operations. 
+======================================================== ============================================================================== ====================================================================================== -``hive.s3select-pushdown.enabled`` Enable query pushdown to AWS S3 Select service. ``false`` +Hive Metastore Configuration for Avro +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -``hive.s3select-pushdown.max-connections`` Maximum number of simultaneously open connections to S3 for 500 - S3SelectPushdown. +To support Avro tables with schema properties when using Hive 3.x, you must configure the Hive Metastore service: -``hive.metastore.load-balancing-enabled`` Enable load balancing between multiple Metastore instances +Add the ``metastore.storage.schema.reader.impl`` property to ``hive-site.xml`` where the metastore service is running: -``hive.skip-empty-files`` Enable skipping empty files. Otherwise, it will produce an ``false`` - error iterating through empty files. +.. code-block:: xml - ``hive.file-status-cache.max-retained-size`` Maximum size in bytes of the directory listing cache ``0KB`` + + metastore.storage.schema.reader.impl + org.apache.hadoop.hive.metastore.SerDeStorageSchemaReader + - ``hive.metastore.catalog.name`` Specifies the catalog name to be passed to the metastore. -================================================== ============================================================ ============ +You must restart the metastore service for this configuration to take effect. This setting allows the metastore to read storage schemas for Avro tables and avoids ``Storage schema reading not supported`` errors. Metastore Configuration Properties ---------------------------------- @@ -220,15 +289,29 @@ Property Name Descriptio ======================================================== ============================================================= ============ ``hive.metastore-timeout`` Timeout for Hive metastore requests. 
``10s`` -``hive.metastore-cache-ttl`` Duration how long cached metastore data should be considered ``0s`` +``hive.metastore.cache.enabled-caches`` Comma-separated list of metastore cache types to enable. NONE + The value should be a valid ``CACHE_TYPE``. + +``hive.metastore.cache.disabled-caches`` Comma-separated list of metastore cache types to disable. NONE + The value should be a valid ``CACHE_TYPE``. + +``hive.metastore.cache.ttl.default`` Duration how long cached metastore data should be considered ``0s`` valid. +``hive.metastore.cache.ttl-by-type`` Per-cache time-to-live (TTL) overrides for Hive metastore NONE + caches. The value is a comma-separated list of + ``CACHE_TYPE``:``TTL`` pairs. + ``hive.metastore-cache-maximum-size`` Hive metastore cache maximum size. 10000 -``hive.metastore-refresh-interval`` Asynchronously refresh cached metastore data after access ``0s`` +``hive.metastore.cache.refresh-interval.default`` Asynchronously refresh cached metastore data after access ``0s`` if it is older than this but is not yet expired, allowing subsequent accesses to see fresh data. +``hive.metastore.cache.refresh-interval-by-type`` Per-cache refresh interval overrides for Hive metastore NONE + caches. The value is a comma-separated list of + ``CACHE_TYPE``:``REFRESH_INTERVAL`` pairs. + ``hive.metastore-refresh-max-threads`` Maximum threads used to refresh cached metastore data. 100 ``hive.invalidate-metastore-cache-procedure-enabled`` When enabled, users will be able to invalidate metastore false @@ -246,6 +329,26 @@ Property Name Descriptio ======================================================== ============================================================= ============ +.. note:: + + The supported values for ``CACHE_TYPE`` when enabling Hive Metastore Cache are: + + * ``ALL``: Represents all supported Hive metastore cache types. + * ``DATABASE``: Caches metadata for individual Hive databases. + * ``DATABASE_NAMES``: Caches the list of all database names in the metastore. + * ``TABLE``: Caches metadata for individual Hive tables. 
+ * ``TABLE_NAMES``: Caches the list of table names within a database. + * ``TABLE_STATISTICS``: Caches column-level statistics for Hive tables. + * ``TABLE_CONSTRAINTS``: Caches table constraint metadata such as primary and unique keys. + * ``PARTITION``: Caches metadata for individual Hive partitions. + * ``PARTITION_STATISTICS``: Caches column-level statistics for individual partitions. + * ``PARTITION_FILTER``: Caches partition name lookups based on partition filter predicates. + * ``PARTITION_NAMES``: Caches the list of partition names for a table. + * ``VIEW_NAMES``: Caches the list of view names within a database. + * ``TABLE_PRIVILEGES``: Caches table-level privilege information for users and roles. + * ``ROLES``: Caches the list of available Hive roles. + * ``ROLE_GRANTS``: Caches role grant mappings for principals. + AWS Glue Catalog Configuration Properties ----------------------------------------- @@ -953,12 +1056,16 @@ Hive catalog is called ``web``:: CALL web.system.example_procedure() -The following procedures are available: +Create Empty Partition +^^^^^^^^^^^^^^^^^^^^^^ * ``system.create_empty_partition(schema_name, table_name, partition_columns, partition_values)`` Create an empty partition in the specified table. +Sync Partition Metadata +^^^^^^^^^^^^^^^^^^^^^^^ + * ``system.sync_partition_metadata(schema_name, table_name, mode, case_sensitive)`` Check and update partitions list in metastore. There are three modes available: @@ -972,13 +1079,22 @@ The following procedures are available: file system paths to use lowercase (e.g. ``col_x=SomeValue``). Partitions on the file system not conforming to this convention are ignored, unless the argument is set to ``false``. -* ``system.invalidate_directory_list_cache()`` +Invalidate Directory List Cache +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - Flush full directory list cache. 
+Invalidating directory list cache is useful when the files are added or deleted in the cache directory path and you want to make the changes visible to Presto immediately. +There are a couple of ways for invalidating this cache: -* ``system.invalidate_directory_list_cache(directory_path)`` +* The Hive connector exposes a procedure over JMX (``com.facebook.presto.hive.CachingDirectoryLister#flushCache``) to invalidate the directory list cache. You can call this procedure to invalidate the directory list cache by connecting via jconsole or jmxterm. This procedure flushes all the cache entries. - Invalidate directory list cache for specified directory_path. +* The Hive connector exposes ``system.invalidate_directory_list_cache`` procedure which gives the flexibility to invalidate the list cache completely or partially as per the requirement and can be invoked in various ways. + + * ``system.invalidate_directory_list_cache()`` : Flush full directory list cache. + + * ``system.invalidate_directory_list_cache(directory_path)`` : Invalidate directory list cache for specified directory_path. + +Invalidate Metastore Cache +^^^^^^^^^^^^^^^^^^^^^^^^^^ * ``system.invalidate_metastore_cache()`` @@ -996,8 +1112,10 @@ The following procedures are available: Invalidate all metastore cache entries linked to a specific partition. -Note: To enable ``system.invalidate_metastore_cache`` procedure, please refer to the properties that -apply to Hive Metastore and are listed in the `Metastore Configuration Properties`_ table. + .. note:: + + To enable ``system.invalidate_metastore_cache`` procedure, ``hive.invalidate-metastore-cache-procedure-enabled`` must be set to ``true``. + See the properties in `Metastore Configuration Properties`_ table for more information. Extra Hidden Columns -------------------- @@ -1012,22 +1130,12 @@ columns as a part of the query like any other columns of the table. How to invalidate metastore cache? 
---------------------------------- -The Hive connector exposes a procedure over JMX (``com.facebook.presto.hive.metastore.CachingHiveMetastore#flushCache``) to invalidate the metastore cache. -You can call this procedure to invalidate the metastore cache by connecting via jconsole or jmxterm. - -This is useful when the Hive metastore is updated outside of Presto and you want to make the changes visible to Presto immediately. - -Currently, this procedure flushes the cache for all the tables in all the schemas. This is a known limitation and will be enhanced in the future. - -How to invalidate directory list cache? ---------------------------------------- - -The Hive connector exposes a procedure over JMX (``com.facebook.presto.hive.HiveDirectoryLister#flushCache``) to invalidate the directory list cache. -You can call this procedure to invalidate the directory list cache by connecting via jconsole or jmxterm. +Invalidating metastore cache is useful when the Hive metastore is updated outside of Presto and you want to make the changes visible to Presto immediately. +There are a couple of ways to invalidate this cache; they are listed below: -This is useful when the files are added or deleted in the cache directory path and you want to make the changes visible to Presto immediately. +* The Hive connector exposes a procedure over JMX (``com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore#invalidateAll``) to invalidate the metastore cache. You can call this procedure to invalidate the metastore cache by connecting via jconsole or jmxterm. However, this procedure flushes the cache for all the tables in all the schemas. -Currently, this procedure flushes all the cache entries. This is a known limitation and will be enhanced in the future. +* The Hive connector exposes ``system.invalidate_metastore_cache`` procedure which enables users to invalidate the metastore cache completely or partially as per the requirement and can be invoked with various arguments. 
See `Invalidate Metastore Cache`_ for more information. Examples -------- diff --git a/presto-docs/src/main/sphinx/connector/iceberg.rst b/presto-docs/src/main/sphinx/connector/iceberg.rst index ef5692f41408d..6f04490f91e0a 100644 --- a/presto-docs/src/main/sphinx/connector/iceberg.rst +++ b/presto-docs/src/main/sphinx/connector/iceberg.rst @@ -93,6 +93,10 @@ Property Name Description ``iceberg.hive.table-refresh.backoff-scale-factor`` The multiple used to scale subsequent wait time between 4.0 retries. + +``iceberg.engine.hive.lock-enabled`` Whether to use locks to ensure atomicity of commits. true + This will turn off locks but is overridden at a table level + with the table configuration ``engine.hive.lock-enabled``. ======================================================== ============================================================= ============ Nessie catalog @@ -341,9 +345,13 @@ Property Name Description ``iceberg.file-format`` The storage file format for Iceberg tables. The available ``PARQUET`` Yes No, write is not supported yet values are ``PARQUET`` and ``ORC``. -``iceberg.compression-codec`` The compression codec to use when writing files. The ``GZIP`` Yes No, write is not supported yet +``iceberg.compression-codec`` The compression codec to use when writing files. The ``ZSTD`` Yes No, write is not supported yet available values are ``NONE``, ``SNAPPY``, ``GZIP``, ``LZ4``, and ``ZSTD``. + + Note: ``LZ4`` is only available when + ``iceberg.file-format=ORC``. + ``iceberg.max-partitions-per-writer`` The maximum number of partitions handled per writer. ``100`` Yes No, write is not supported yet @@ -358,6 +366,23 @@ Property Name Description ``iceberg.delete-as-join-rewrite-enabled`` When enabled, equality delete row filtering is applied ``true`` Yes No, Equality delete read is not supported as a join with the data of the equality delete files. + Deprecated: This property is deprecated and will be removed + in a future release. 
Use the + ``iceberg.delete-as-join-rewrite-max-delete-columns`` + configuration property instead. + +``iceberg.delete-as-join-rewrite-max-delete-columns`` When set to a number greater than 0, this property enables ``400`` Yes No, Equality delete read is not supported + equality delete row filtering as a join with the data of the + equality delete files. The value of this property is the + maximum number of columns that can be used in the equality + delete files. If the number of columns in the equality delete + files exceeds this value, then the optimization is not + applied and the equality delete files are applied directly to + each row in the data files. + + This property is only applicable when + ``iceberg.delete-as-join-rewrite-enabled`` is set to + ``true``. ``iceberg.enable-parquet-dereference-pushdown`` Enable parquet dereference pushdown. ``true`` Yes No @@ -449,9 +474,13 @@ Property Name Description ``write.metadata.metrics.max-inferred-column-defaults`` Optionally specifies the maximum number of columns for which ``100`` Yes No, write is not supported yet metrics are collected. -``write.update.mode`` Optionally specifies the write delete mode of the Iceberg ``merge-on-read`` Yes No, write is not supported yet +``write.update.mode`` Optionally specifies the write update mode of the Iceberg ``merge-on-read`` Yes No, write is not supported yet specification to use for new tables, either ``copy-on-write`` or ``merge-on-read``. 
+ +``engine.hive.lock-enabled`` Whether to use Hive metastore locks when committing to Yes No + a Hive metastore + ======================================================== =============================================================== ===================== =================== ============================================= The table definition below specifies format ``ORC``, partitioning by columns ``c1`` and ``c2``, @@ -499,6 +528,12 @@ Property Name Description ===================================================== ======================================================================= =================== ============================================= ``iceberg.delete_as_join_rewrite_enabled`` Overrides the behavior of the connector property Yes No, Equality delete read is not supported ``iceberg.delete-as-join-rewrite-enabled`` in the current session. + + Deprecated: This property is deprecated and will be removed. Use + ``iceberg.delete_as_join_rewrite_max_delete_columns`` instead. +``iceberg.delete_as_join_rewrite_max_delete_columns`` Overrides the behavior of the connector property Yes No, Equality delete read is not supported + ``iceberg.delete-as-join-rewrite-max-delete-columns`` in the + current session. ``iceberg.hive_statistics_merge_strategy`` Overrides the behavior of the connector property Yes Yes ``iceberg.hive-statistics-merge-strategy`` in the current session. ``iceberg.rows_for_metadata_optimization_threshold`` Overrides the behavior of the connector property Yes Yes @@ -514,9 +549,14 @@ Property Name Description assign a split to. Splits which read data from the same file within the same chunk will hash to the same node. A smaller chunk size will result in a higher probability splits being distributed evenly across - the cluster, but reduce locality. + the cluster, but reduce locality. + See :ref:`develop/connectors:Node Selection Strategy`. 
``iceberg.parquet_dereference_pushdown_enabled`` Overrides the behavior of the connector property Yes No ``iceberg.enable-parquet-dereference-pushdown`` in the current session. +``materialized_view_storage_table_name_prefix`` Prefix for automatically generated materialized view storage table Yes No + names. Default: ``__mv_storage__`` +``materialized_view_missing_base_table_behavior`` Behavior when a base table referenced by a materialized view is Yes No + missing. Valid values: ``FAIL``, ``IGNORE``. Default: ``FAIL`` ===================================================== ======================================================================= =================== ============================================= Caching Support @@ -639,7 +679,21 @@ File and stripe footer cache is not applicable for Presto C++. Metastore Cache ^^^^^^^^^^^^^^^ -Iceberg Connector does not support Metastore Caching. +Iceberg Connector supports Metastore Caching with some exceptions. Iceberg Connector does not allow enabling TABLE cache. +Metastore Caching is only supported when ``iceberg.catalog.type`` is ``HIVE``. + +The Iceberg connector supports the same configuration properties for +`Hive Metastore Caching `_ +as a Hive connector. + +The following configuration properties are the minimum set of configurations required to be added in the Iceberg catalog file ``catalog/iceberg.properties``: + +.. code-block:: none + + # Hive Metastore Cache + hive.metastore.cache.disabled-caches=TABLE + hive.metastore.cache.ttl.default=10m + hive.metastore.cache.refresh-interval.default=5m Extra Hidden Metadata Columns ----------------------------- @@ -675,6 +729,42 @@ The Iceberg data sequence number in which this row was added. ----------------------------------+------------ 2 | 3 +``$deleted`` column +^^^^^^^^^^^^^^^^^^^ +Whether this row is a deleted row. When this column is used, deleted rows +from delete files will be marked as ``true`` instead of being filtered out of the results. + +.. 
code-block:: sql + + DELETE FROM "ctas_nation" WHERE regionkey = 0; + + SELECT "$deleted", regionkey FROM "ctas_nation"; + +.. code-block:: text + + $deleted | regionkey + ----------+----------- + true | 0 + false | 1 + +``$delete_file_path`` column +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The path of the delete file corresponding to a deleted row, or NULL if the row was not deleted. +When this column is used, deleted rows will not be filtered out of the results. + +.. code-block:: sql + + DELETE FROM "ctas_nation" WHERE regionkey = 0; + + SELECT "$delete_file_path", regionkey FROM "ctas_nation"; + +.. code-block:: text + + $delete_file_path | regionkey + -----------------------------------------------------------------------------------+----------- + file:/path/to/table/data/delete_file_d8510b3e-510a-4fc2-b2b2-e59ead7fd386.parquet | 0 + NULL | 1 + Presto C++ Support ^^^^^^^^^^^^^^^^^^ @@ -696,9 +786,9 @@ General properties of the given table. .. code-block:: text - key | value - ----------------------+--------- - write.format.default | PARQUET + key | value | is_supported_by_presto + ----------------------+----------+------------------------ + write.format.default | PARQUET | true ``$history`` Table ^^^^^^^^^^^^^^^^^^ @@ -860,6 +950,22 @@ Details about Iceberg references including branches and tags. For more informati testBranch | BRANCH | 3374797416068698476 | NULL | NULL | NULL testTag | TAG | 4686954189838128572 | 10 | NULL | NULL +``$metadata_log_entries`` Table +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Provides metadata log entries for the table. + +.. code-block:: sql + + SELECT * FROM "region$metadata_log_entries"; + +.. 
code-block:: text + + timestamp | file | latest_snapshot_id | latest_schema_id | latest_sequence_number + -------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+---------------------+------------------+------------------------ + 2024-12-28 23:41:30.451 Asia/Kolkata | hdfs://localhost:9000/user/hive/warehouse/iceberg_schema.db/region1/metadata/00000-395385ba-3b69-47a7-9c5b-61d056de55c6.metadata.json | 5983271822201743253 | 0 | 1 + 2024-12-28 23:42:42.207 Asia/Kolkata | hdfs://localhost:9000/user/hive/warehouse/iceberg_schema.db/region1/metadata/00001-61151efc-0e01-4a47-a5e6-7b72749cc4a8.metadata.json | 5841566266546816471 | 0 | 2 + 2024-12-28 23:42:47.591 Asia/Kolkata | hdfs://localhost:9000/user/hive/warehouse/iceberg_schema.db/region1/metadata/00002-d4a9c326-5053-4a26-9082-d9fbf1d6cd14.metadata.json | 6894018661156805064 | 0 | 3 + Presto C++ Support ^^^^^^^^^^^^^^^^^^ @@ -876,7 +982,7 @@ Register Table Iceberg tables for which table data and metadata already exist in the file system can be registered with the catalog. Use the ``register_table`` procedure on the catalog's ``system`` schema to register a table which -already exists but does not known by the catalog. +already exists but is not known by the catalog. The following arguments are available: @@ -1161,6 +1267,93 @@ Examples: CALL iceberg.system.set_table_property('schema_name', 'table_name', 'commit.retry.num-retries', '10'); +Rewrite Data Files +^^^^^^^^^^^^^^^^^^ + +Iceberg tracks all data files under different partition specs in a table. More data files require +more metadata to be stored in manifest files, and small data files can cause an unnecessary amount of metadata and +less efficient queries due to file open costs. Also, data files under different partition specs can +prevent metadata level deletion or thorough predicate push down for Presto. 
+ +Use ``rewrite_data_files`` to rewrite the data files of a specified table so that they are +merged into fewer but larger files under the newest partition spec. If the table is partitioned, the data +files compaction can act separately on the selected partitions to improve read performance by reducing +metadata overhead and runtime file open cost. + +The following arguments are available: + +===================== ========== =============== ======================================================================= +Argument Name required type Description +===================== ========== =============== ======================================================================= +``schema`` Yes string Schema of the table to update. + +``table_name`` Yes string Name of the table to update. + +``filter`` string Predicate as a string used for filtering the files. Currently + only rewrite of whole partitions is supported. Filter on partition + columns. The default value is ``true``. + +``sorted_by`` array of Specify an array of one or more columns to use for sorting. When + strings performing a rewrite, the specified sorting definition must be + compatible with the table's own sorting property, if one exists. + +``options`` map Options to be used for data files rewrite. 
(to be expanded) +===================== ========== =============== ======================================================================= + +Examples: + +* Rewrite all the data files in table ``db.sample`` to the newest partition spec and combine small files to larger ones:: + + CALL iceberg.system.rewrite_data_files('db', 'sample'); + CALL iceberg.system.rewrite_data_files(schema => 'db', table_name => 'sample'); + +* Rewrite the data files in partitions specified by a filter in table ``db.sample`` to the newest partition spec:: + + CALL iceberg.system.rewrite_data_files('db', 'sample', 'partition_key = 1'); + CALL iceberg.system.rewrite_data_files(schema => 'db', table_name => 'sample', filter => 'partition_key = 1'); + +* Rewrite the data files in partitions specified by a filter in table ``db.sample`` to the newest partition spec and a sorting definition:: + + CALL iceberg.system.rewrite_data_files('db', 'sample', 'partition_key = 1', ARRAY['join_date DESC NULLS FIRST', 'emp_id ASC NULLS LAST']); + CALL iceberg.system.rewrite_data_files(schema => 'db', table_name => 'sample', filter => 'partition_key = 1', sorted_by => ARRAY['join_date']); + +Rewrite Manifests +^^^^^^^^^^^^^^^^^ + +This procedure rewrites the manifest files of an Iceberg table to optimize table metadata. +The procedure is a metadata-only operation and commits a new snapshot with `operation = replace`. + +The following arguments are available: + +===================== ========== =============== ======================================================================== +Argument Name required type Description +===================== ========== =============== ======================================================================== +``schema`` Yes string Schema of the table to update + +``table_name`` Yes string Name of the table to update + +``spec_id`` No integer Partition spec ID to rewrite manifests for. + If not specified, manifests for the current partition spec are rewritten. 
+===================== ========== =============== ======================================================================== + +``rewrite_manifests`` does not modify data files and does not change query results. +The procedure may be a logical no-op if the existing manifests are already optimal. + +Delete-only manifests are retained as long as snapshots that reference them are valid. +To allow cleanup of such manifests, old snapshots must first be expired using ``CALL system.expire_snapshots``. + +The procedure always commits a snapshot with `operation = replace`, even when no physical rewrite is required. + +Examples: + +* Rewrite manifests for a table using positional arguments: :: + + CALL iceberg.system.rewrite_manifests('schema_name', 'table_name'); + +* Rewrite manifests for a specific partition spec: :: + + CALL iceberg.system.rewrite_manifests('schema_name', 'table_name', 0); + Presto C++ Support ^^^^^^^^^^^^^^^^^^ @@ -1205,6 +1398,8 @@ SQL Operation Presto Java Presto C++ Comments ``DESCRIBE`` Yes Yes ``UPDATE`` Yes No + +``MERGE`` Yes No ============================== ============= ============ ============================================================================ The Iceberg connector supports querying and manipulating Iceberg tables and schemas @@ -1404,6 +1599,10 @@ Alter table operations are supported in the Iceberg connector:: ALTER TABLE iceberg.web.page_views DROP COLUMN location; + ALTER TABLE iceberg.web.page_views DROP BRANCH 'branch1'; + + ALTER TABLE iceberg.web.page_views DROP TAG 'tag1'; + To add a new column as a partition column, identify the transform functions for the column. The table is partitioned by the transformed value of the column:: @@ -1421,10 +1620,27 @@ The table is partitioned by the transformed value of the column:: ALTER TABLE iceberg.web.page_views ADD COLUMN ts timestamp WITH (partitioning = 'hour'); -Table properties can be modified for an Iceberg table using an ALTER TABLE SET PROPERTIES statement. 
Only `commit_retries` can be modified at present. -For example, to set `commit_retries` to 6 for the table `iceberg.web.page_views_v2`, use:: +Use ``ARRAY[...]`` instead of a string to specify multiple partition transforms when adding a column. For example:: + + ALTER TABLE iceberg.web.page_views ADD COLUMN location VARCHAR WITH (partitioning = ARRAY['truncate(2)', 'bucket(8)', 'identity']); + + ALTER TABLE iceberg.web.page_views ADD COLUMN dt date WITH (partitioning = ARRAY['year', 'bucket(16)', 'identity']); - ALTER TABLE iceberg.web.page_views_v2 SET PROPERTIES (commit_retries = 6); +Some Iceberg table properties can be modified using an ``ALTER TABLE SET PROPERTIES`` statement. The modifiable table properties are: + +* ``commit.retry.num-retries`` +* ``read.split.target-size`` +* ``write.metadata.delete-after-commit.enabled`` +* ``engine.hive.lock-enabled`` +* ``write.metadata.previous-versions-max`` + +For example, to set ``commit.retry.num-retries`` to 6 for the table ``iceberg.web.page_views_v2``, use:: + + ALTER TABLE iceberg.web.page_views_v2 SET PROPERTIES ("commit.retry.num-retries" = 6); + +To set ``write.metadata.delete-after-commit.enabled`` to true and set ``write.metadata.previous-versions-max`` to 5, use:: + + ALTER TABLE iceberg.web.page_views_v2 SET PROPERTIES ("write.metadata.delete-after-commit.enabled" = true, "write.metadata.previous-versions-max" = 5); ALTER VIEW ^^^^^^^^^^ @@ -1638,11 +1854,11 @@ For example, ``DESCRIBE`` from the partitioned Iceberg table ``customer``: comment | varchar | | (3 rows) -UPDATE -^^^^^^ +UPDATE and MERGE +^^^^^^^^^^^^^^^^ -The Iceberg connector supports :doc:`../sql/update` operations on Iceberg -tables. Only some tables support updates. These tables must be at minimum format +The Iceberg connector supports :doc:`../sql/update` and :doc:`../sql/merge` operations on Iceberg +tables. Only some tables support them. 
These tables must be at minimum format version 2, and the ``write.update.mode`` must be set to `merge-on-read`. .. code-block:: sql @@ -1662,6 +1878,16 @@ updates. Query 20250204_010445_00022_ymwi5 failed: Iceberg table updates require at least format version 2 and update mode must be merge-on-read +Iceberg tables do not support running multiple :doc:`../sql/merge` statements on the same table in parallel. If two or more ``MERGE`` operations are executed concurrently on the same Iceberg table: + +* The first operation to complete will succeed. +* Subsequent operations will fail due to conflicting writes and will return the following error: + +.. code-block:: text + + Failed to commit Iceberg update to table:
+ Found conflicting files that can contain records matching true + Schema Evolution ---------------- @@ -2149,7 +2375,7 @@ Sorting can be combined with partitioning on the same column. For example:: sorted_by = ARRAY['join_date'] ) -The Iceberg connector does not support sort order transforms. The following sort order transformations are not supported: +Sort order does not support transforms. The following transforms are not supported: .. code-block:: text @@ -2172,4 +2398,243 @@ For example:: ) If a user creates a table externally with non-identity sort columns and then inserts data, the following warning message will be shown. -``Iceberg table sort order has sort fields of , , ... which are not currently supported by Presto`` \ No newline at end of file +``Iceberg table sort order has sort fields of , , ... which are not currently supported by Presto`` + +Materialized Views +------------------ + +The Iceberg connector supports materialized views. See :doc:`/admin/materialized-views` for general information and :doc:`/sql/create-materialized-view` for SQL syntax. + +Storage +^^^^^^^ + +Materialized views use a dedicated Iceberg storage table to persist the pre-computed results. By default, the storage table is created with the prefix ``__mv_storage__`` followed by the materialized view name in the same schema as the view. + +Table Properties +^^^^^^^^^^^^^^^^ + +The following table properties can be specified when creating a materialized view: + +========================================================== ============================================================================ +Property Name Description +========================================================== ============================================================================ +``storage_schema`` Schema name for the storage table. Defaults to the materialized view's + schema. + +``storage_table`` Custom name for the storage table. Defaults to the prefix plus the + materialized view name. 
+ +``stale_read_behavior`` Behavior when reading from a materialized view that is stale beyond the + staleness window. Valid values: ``FAIL`` (throw an error), + ``USE_VIEW_QUERY`` (query base tables instead). + +``staleness_window`` Duration window for staleness tolerance (e.g., ``1h``, ``30m``, ``0s``). + Defaults to ``0s`` if only ``stale_read_behavior`` is set. When set to + ``0s``, any staleness triggers the configured behavior. + +``refresh_type`` Refresh strategy for the materialized view. Currently only ``FULL`` is + supported. Default: ``FULL`` +========================================================== ============================================================================ + +The storage table inherits standard Iceberg table properties for partitioning, sorting, and file format. + +Freshness and Refresh +^^^^^^^^^^^^^^^^^^^^^ + +Materialized views track the snapshot IDs of their base tables to determine staleness. When base tables are modified, the materialized view becomes stale and returns results by querying the base tables directly. After running ``REFRESH MATERIALIZED VIEW``, queries read from the pre-computed storage table. + +The refresh operation uses a full refresh strategy, replacing all data in the storage table with the current query results. + +.. _iceberg-stale-data-handling: + +Stale Data Handling +^^^^^^^^^^^^^^^^^^^ + +By default, when no staleness properties are configured, queries against a stale materialized +view will fall back to executing the underlying view query against the base tables. You can +change this default using the ``materialized_view_stale_read_behavior`` session property. 
+ +To configure staleness handling per view, set both of these properties together: + +- ``stale_read_behavior``: What to do when reading stale data (``FAIL`` or ``USE_VIEW_QUERY``) +- ``staleness_window``: How much staleness to tolerate (e.g., ``1h``, ``30m``, ``0s``) + +The Iceberg connector automatically detects staleness based on base table modifications. +A materialized view is considered stale if base tables have changed AND the time since +the last base table modification exceeds the staleness window. + +Example with staleness handling: + +.. code-block:: sql + + CREATE MATERIALIZED VIEW hourly_sales + WITH ( + stale_read_behavior = 'FAIL', + staleness_window = '1h' + ) + AS SELECT date_trunc('hour', sale_time) as hour, SUM(amount) as total + FROM sales GROUP BY 1; + +Limitations +^^^^^^^^^^^ + +- All refreshes recompute the entire result set +- REFRESH does not provide snapshot isolation across multiple base tables +- Querying materialized views at specific snapshots or timestamps is not supported + +Example +^^^^^^^ + +Create a materialized view with custom storage configuration: + +.. code-block:: sql + + CREATE MATERIALIZED VIEW regional_sales + WITH ( + storage_schema = 'analytics', + storage_table = 'sales_summary' + ) + AS SELECT region, SUM(amount) as total FROM orders GROUP BY region; + +Authorization +------------- + +Enable authorization checks for the :doc:`/connector/iceberg` by setting +the ``iceberg.security`` property in the Iceberg catalog properties file. This +property must be one of the following values: + +================================================== ============================================================ +Property Value Description +================================================== ============================================================ +``allow-all`` (default value) No authorization checks are enforced, thus allowing all + operations. 
+ +``file`` Authorization checks are enforced using a config file specified + by the Iceberg configuration property ``security.config-file``. + See :ref:`iceberg-file-based-authorization` for details. +================================================== ============================================================ + +.. _iceberg-file-based-authorization: + +File Based Authorization +^^^^^^^^^^^^^^^^^^^^^^^^ + +The config file is specified using JSON and is composed of three sections, +each of which is a list of rules that are matched in the order specified +in the config file. The user is granted the privileges from the first +matching rule. All regexes default to ``.*`` if not specified. + +Schema Rules +~~~~~~~~~~~~ + +These rules govern who is considered an owner of a schema. + +* ``user`` (optional): regex to match against user name. + +* ``schema`` (optional): regex to match against schema name. + +* ``owner`` (required): boolean indicating ownership. + +Table Rules +~~~~~~~~~~~ + +These rules govern the privileges granted on specific tables. + +* ``user`` (optional): regex to match against user name. + +* ``schema`` (optional): regex to match against schema name. + +* ``table`` (optional): regex to match against table name. + +* ``privileges`` (required): zero or more of ``SELECT``, ``INSERT``, + ``DELETE``, ``OWNERSHIP``, ``GRANT_SELECT``. + +Session Property Rules +~~~~~~~~~~~~~~~~~~~~~~ + +These rules govern who may set session properties. + +* ``user`` (optional): regex to match against user name. + +* ``property`` (optional): regex to match against session property name. + +* ``allow`` (required): boolean indicating whether this session property may be set. + +Procedure Rules +~~~~~~~~~~~~~~~ + +These rules govern the privileges granted on specific procedures. + +* ``user`` (optional): regex to match against user name. + +* ``schema`` (optional): regex to match against schema name. + +* ``procedure`` (optional): regex to match against procedure name. 
+ +* ``privileges`` (required): a list that is empty or contains ``EXECUTE``. + +See below for an example. + +.. code-block:: json + + { + "schemas": [ + { + "user": "admin", + "schema": ".*", + "owner": true + }, + { + "user": "guest", + "owner": false + }, + { + "schema": "default", + "owner": true + } + ], + "tables": [ + { + "user": "admin", + "privileges": ["SELECT", "INSERT", "DELETE", "OWNERSHIP"] + }, + { + "user": "banned_user", + "privileges": [] + }, + { + "schema": "default", + "table": ".*", + "privileges": ["SELECT"] + } + ], + "sessionProperties": [ + { + "property": "force_local_scheduling", + "allow": true + }, + { + "user": "admin", + "property": "max_split_size", + "allow": true + } + ], + "procedures": [ + { + "user": "admin", + "schema": ".*", + "privileges": ["EXECUTE"] + }, + { + "user": "alice", + "schema": "alice_schema", + "privileges": ["EXECUTE"] + }, + { + "user": "guest", + "schema": "alice_schema", + "procedure": "test_procedure", + "privileges": ["EXECUTE"] + } + ] + } diff --git a/presto-docs/src/main/sphinx/connector/kafka.rst b/presto-docs/src/main/sphinx/connector/kafka.rst index c210ab08128ef..39a9f4ed0664b 100644 --- a/presto-docs/src/main/sphinx/connector/kafka.rst +++ b/presto-docs/src/main/sphinx/connector/kafka.rst @@ -73,6 +73,10 @@ Property Name Description ``kafka.truststore.password`` Password for the truststore file ``kafka.truststore.type`` File format of the truststore file, defaults to ``JKS`` ``kafka.config.resources`` A comma-separated list of Kafka client configuration files. If a specialized authentication method is required, it can be specified in these additional Kafka client properties files. Example: `/etc/kafka-configuration.properties` +``case-sensitive-name-matching`` Enable case-sensitive identifier support for schema, + table, and column names for the connector. When disabled, + names are matched case-insensitively using lowercase + normalization. Default is ``false``. 
=================================== ============================================================== ``kafka.table-names`` diff --git a/presto-docs/src/main/sphinx/connector/mongodb.rst b/presto-docs/src/main/sphinx/connector/mongodb.rst index f40e7b0c50554..b375491e44bcd 100644 --- a/presto-docs/src/main/sphinx/connector/mongodb.rst +++ b/presto-docs/src/main/sphinx/connector/mongodb.rst @@ -34,29 +34,38 @@ Configuration Properties The following configuration properties are available: -===================================== ============================================================== -Property Name Description -===================================== ============================================================== -``mongodb.seeds`` List of all mongod servers -``mongodb.schema-collection`` A collection which contains schema information +===================================== ============================================================== =========== +Property Name Description Default +===================================== ============================================================== =========== +``mongodb.seeds`` List of all MongoDB servers +``mongodb.schema-collection`` A collection which contains schema information ``_schema`` ``mongodb.credentials`` List of credentials -``mongodb.min-connections-per-host`` The minimum size of the connection pool per host -``mongodb.connections-per-host`` The maximum size of the connection pool per host -``mongodb.max-wait-time`` The maximum wait time -``mongodb.connection-timeout`` The socket connect timeout -``mongodb.socket-timeout`` The socket timeout -``mongodb.socket-keep-alive`` Whether keep-alive is enabled on each socket -``mongodb.ssl.enabled`` Use TLS/SSL for connections to mongod/mongos -``mongodb.read-preference`` The read preference -``mongodb.write-concern`` The write concern +``mongodb.min-connections-per-host`` The minimum size of the connection pool per host ``0`` +``mongodb.connections-per-host`` The maximum size of 
the connection pool per host ``100`` +``mongodb.max-wait-time`` The maximum wait time ``120000ms`` +``mongodb.connection-timeout`` The socket connect timeout ``10000ms`` +``mongodb.socket-timeout`` The socket timeout ``0ms`` +``mongodb.socket-keep-alive`` Whether keep-alive is enabled on each socket ``false`` +``mongodb.tls.enabled`` Use TLS/SSL for connections to MongoDB ``false`` +``mongodb.tls.keystore-path`` Path to the keystore file for client certificates +``mongodb.tls.keystore-password`` Password for the keystore file +``mongodb.tls.truststore-path`` Path to the truststore file for trusted Certificate + Authorities (CA) +``mongodb.tls.truststore-password`` Password for the truststore file +``mongodb.read-preference`` The read preference ``primary`` +``mongodb.write-concern`` The write concern ``acknowledged`` ``mongodb.required-replica-set`` The required replica set name ``mongodb.cursor-batch-size`` The number of elements to return in a batch -===================================== ============================================================== +``case-sensitive-name-matching`` Enable case-sensitive identifier support for schema, + table, and column names for the connector. When disabled, + names are matched case-insensitively using lowercase + normalization. Default is ``false`` +===================================== ============================================================== =========== ``mongodb.seeds`` ^^^^^^^^^^^^^^^^^ -Comma-separated list of ``hostname[:port]`` all mongod servers in the same replica set or a list of mongos servers in the same sharded cluster. If port is not specified, port 27017 will be used. +Comma-separated list of ``hostname[:port]`` all MongoDB servers in the same replica set or a list of MongoDB servers in the same sharded cluster. If port is not specified, port 27017 will be used. This property is required; there is no default and at least one seed must be defined. 
@@ -119,14 +128,45 @@ This flag controls the socket keep alive feature that keeps a connection alive t This property is optional; the default is ``false``. -``mongodb.ssl.enabled`` -^^^^^^^^^^^^^^^^^^^^^^^^ +``mongodb.tls.enabled`` +^^^^^^^^^^^^^^^^^^^^^^^ + +This flag enables TLS/SSL connections to MongoDB servers. + +This property is optional and defaults to ``false``. When enabled, you can optionally configure client certificate authentication and custom certificate authorities using the related TLS properties. + +.. note:: + + The ``mongodb.ssl.enabled`` property is deprecated and will be removed in a future version. + Use ``mongodb.tls.enabled`` instead. The old property name is supported for backward compatibility. + +``mongodb.tls.keystore-path`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Path to the Java KeyStore file containing the client certificate and private key for TLS authentication. The connector supports both Java KeyStore (JKS) format and Privacy-Enhanced Mail (PEM) file format. -This flag enables SSL connections to MongoDB servers. +This property is optional and only used when ``mongodb.tls.enabled`` is ``true``. Unlike the truststore, there is no default keystore - you must provide one if client certificate authentication is required. -This property is optional and defaults to ``false``. If you set it to ``true`` and host Presto yourself, it’s likely that you also use a TLS CA file. +``mongodb.tls.keystore-password`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -For setup instructions, see :ref:`tls-ca-definition-label`. +Password for the keystore file specified in ``mongodb.tls.keystore-path``. + +This property is optional and only used when a keystore path is specified. + +``mongodb.tls.truststore-path`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Path to the Java TrustStore file containing the trusted certificate authorities for TLS connections. The connector supports both Java KeyStore (JKS) format and Privacy-Enhanced Mail (PEM) file format. 
+ +This property is optional and only used when ``mongodb.tls.enabled`` is ``true``. If not specified, the default system truststore will be used. + +``mongodb.tls.truststore-password`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Password for the truststore file specified in ``mongodb.tls.truststore-path``. + +This property is optional and only used when a truststore path is specified. ``mongodb.read-preference`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -170,6 +210,45 @@ This property is optional; the default is ``0``. .. _tls-ca-definition-label: +TLS/SSL Configuration +--------------------- + +The MongoDB connector supports comprehensive TLS/SSL configuration for secure connections to MongoDB clusters. + +Basic TLS Configuration +^^^^^^^^^^^^^^^^^^^^^^^ + +To enable basic TLS connections, set the following property: + +.. code-block:: none + + mongodb.tls.enabled=true + +This enables TLS using the system's default certificate authorities. + +Advanced TLS Configuration +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For advanced TLS configuration including client certificate authentication and custom certificate authorities, use the following properties: + +.. code-block:: none + + mongodb.tls.enabled=true + mongodb.tls.keystore-path=/path/to/client.jks + mongodb.tls.keystore-password=keystore_password + mongodb.tls.truststore-path=/path/to/truststore.jks + mongodb.tls.truststore-password=truststore_password + +Certificate Format Support +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The connector supports both Java KeyStore (JKS) and PEM file formats for certificates: + +- **Java KeyStore (JKS)**: Traditional Java keystore format +- **PEM Files**: Privacy-Enhanced Mail format, commonly used with OpenSSL + +The connector automatically detects the format and handles the certificates appropriately. 
+ Configuring the MongoDB Connector to Use a TLS CA File ------------------------------------------------------ @@ -237,7 +316,11 @@ To configure a MongoDB catalog for this cluster, follow these steps: connector.name=mongodb mongodb.seeds=:27017 mongodb.credentials=:@ - mongodb.ssl.enabled=true + mongodb.tls.enabled=true + mongodb.tls.keystore-path=/path/to/client.jks + mongodb.tls.keystore-password=keystore_password + mongodb.tls.truststore-path=/path/to/truststore.jks + mongodb.tls.truststore-password=truststore_password mongodb.required-replica-set= Run Queries @@ -396,4 +479,4 @@ ALTER TABLE returns an error similar to the following: - ``Query 20240720_123348_00014_v7vrn failed: line 1:55: mismatched input 'int'. Expecting: 'FUNCTION', 'SCHEMA', 'TABLE'`` + ``Query 20240720_123348_00014_v7vrn failed: line 1:55: mismatched input 'int'. Expecting: 'FUNCTION', 'SCHEMA', 'TABLE'`` \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/connector/mysql.rst b/presto-docs/src/main/sphinx/connector/mysql.rst index 85070da8fb6e6..758bfa110ceb9 100644 --- a/presto-docs/src/main/sphinx/connector/mysql.rst +++ b/presto-docs/src/main/sphinx/connector/mysql.rst @@ -68,6 +68,10 @@ Property Name Description cached. Set to ``0ms`` to disable the cache. ``1m`` ``list-schemas-ignored-schemas`` List of schemas to ignore when listing schemas. ``information_schema,mysql`` + +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema and table ``false`` + names for the connector. When disabled, names are matched + case-insensitively using lowercase normalization. 
================================================== ==================================================================== =========== Querying MySQL diff --git a/presto-docs/src/main/sphinx/connector/oracle.rst b/presto-docs/src/main/sphinx/connector/oracle.rst index dfa84c873515d..a10732ed5c668 100644 --- a/presto-docs/src/main/sphinx/connector/oracle.rst +++ b/presto-docs/src/main/sphinx/connector/oracle.rst @@ -70,6 +70,10 @@ Property Name Description cached. Set to ``0ms`` to disable the cache. ``1m`` ``list-schemas-ignored-schemas`` List of schemas to ignore when listing schemas. ``information_schema`` + +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema and table ``false`` + names for the connector. When disabled, names are matched + case-insensitively using lowercase normalization. ================================================== ==================================================================== =========== Querying Oracle @@ -98,6 +102,27 @@ Finally, you can access the ``clicks`` table in the ``web`` database:: If you used a different name for your catalog properties file, use that catalog name instead of ``oracle`` in the above examples. +Type mapping +------------ + +PrestoDB and Oracle each support types that the other does not. When reading from Oracle, Presto converts +the data types from Oracle to equivalent Presto data types. +Refer to the following section for type mapping in each direction. + +Oracle to PrestoDB type mapping +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The connector maps Oracle types to the corresponding PrestoDB types: + +.. 
list-table:: Oracle to PrestoDB type mapping + :widths: 50, 50 + :header-rows: 1 + + * - Oracle type + - PrestoDB type + * - ``BLOB`` + - ``VARBINARY`` + Oracle Connector Limitations ---------------------------- diff --git a/presto-docs/src/main/sphinx/connector/pinot.rst b/presto-docs/src/main/sphinx/connector/pinot.rst index ff3fa43c43334..57b4150bab224 100644 --- a/presto-docs/src/main/sphinx/connector/pinot.rst +++ b/presto-docs/src/main/sphinx/connector/pinot.rst @@ -84,6 +84,8 @@ Property Name Description ``pinot.broker-authentication-user`` Broker username for basic authentication method. ``pinot.broker-authentication-password`` Broker password for basic authentication method. ``pinot.query-options`` Pinot query-related case-sensitive options. E.g. skipUpsert:true,enableNullHandling:true +``case-sensitive-name-matching`` Enable case-sensitive identifier support for schema, table, and column names for the connector. When disabled, + names are matched case-insensitively using lowercase normalization. Default is ``false``. ========================================================== ============================================================================================================= If ``pinot.controller-authentication-type`` is set to ``PASSWORD`` then both ``pinot.controller-authentication-user`` and diff --git a/presto-docs/src/main/sphinx/connector/postgresql.rst b/presto-docs/src/main/sphinx/connector/postgresql.rst index 29839d6613f6c..065b54ec11baa 100644 --- a/presto-docs/src/main/sphinx/connector/postgresql.rst +++ b/presto-docs/src/main/sphinx/connector/postgresql.rst @@ -56,6 +56,10 @@ Property Name Description cached. Set to ``0ms`` to disable the cache. ``1m`` ``list-schemas-ignored-schemas`` List of schemas to ignore when listing schemas. ``information_schema`` + +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema and table ``false`` + names for the connector. 
When disabled, names are matched + case-insensitively using lowercase normalization. ================================================== ==================================================================== =========== Querying PostgreSQL @@ -141,6 +145,10 @@ The connector maps PostgreSQL types to the corresponding PrestoDB types: - ``JSON`` * - ``JSONB`` - ``JSON`` + * - ``GEOMETRY`` + - ``VARCHAR`` + * - ``GEOGRAPHY`` + - ``VARCHAR`` No other types are supported. diff --git a/presto-docs/src/main/sphinx/connector/redis.rst b/presto-docs/src/main/sphinx/connector/redis.rst index 0f4b0ff287210..ab91303a31d9f 100644 --- a/presto-docs/src/main/sphinx/connector/redis.rst +++ b/presto-docs/src/main/sphinx/connector/redis.rst @@ -51,6 +51,10 @@ Property Name Description ``redis.hide-internal-columns`` Controls whether internal columns are part of the table schema or not ``redis.database-index`` Redis database index ``redis.password`` Redis server password +``redis.user`` Redis server username +``redis.tls.enabled`` Whether TLS security is enabled (defaults to ``false``) +``redis.tls.truststore-path`` Path to the TLS certificate file +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema, table, and column names for the connector. When disabled, names are matched case-insensitively using lowercase normalization. Defaults to ``false``. ================================= ============================================================== ``redis.table-names`` @@ -130,19 +134,39 @@ show up in ``DESCRIBE `` or ``SELECT *``. This property is optional; the default is ``true``. ``redis.database-index`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^ The Redis database to query. This property is optional; the default is ``0``. ``redis.password`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^ The password for password-protected Redis server. This property is optional; the default is ``null``. 
+``redis.user`` +^^^^^^^^^^^^^^ + +Redis server username. + +This property is required; there is no default. + +``redis.tls.enabled`` +^^^^^^^^^^^^^^^^^^^^^ + +Enable or disable TLS security. + +This property is optional; default is ``false``. + +``redis.tls.truststore-path`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Path to the TLS certificate file. + +This property is required if ``redis.tls.enabled`` is set to ``true``. Internal Columns ---------------- diff --git a/presto-docs/src/main/sphinx/connector/redshift.rst b/presto-docs/src/main/sphinx/connector/redshift.rst index b28960572f4e4..9ad6fa53dc53b 100644 --- a/presto-docs/src/main/sphinx/connector/redshift.rst +++ b/presto-docs/src/main/sphinx/connector/redshift.rst @@ -56,6 +56,10 @@ Property Name Description cached. Set to ``0ms`` to disable the cache. ``1m`` ``list-schemas-ignored-schemas`` List of schemas to ignore when listing schemas. ``information_schema`` + +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema and table ``false`` + names for the connector. When disabled, names are matched + case-insensitively using lowercase normalization. ================================================== ==================================================================== =========== Querying Redshift diff --git a/presto-docs/src/main/sphinx/connector/singlestore.rst b/presto-docs/src/main/sphinx/connector/singlestore.rst index 3545bd3d99ee1..f55999410da90 100644 --- a/presto-docs/src/main/sphinx/connector/singlestore.rst +++ b/presto-docs/src/main/sphinx/connector/singlestore.rst @@ -75,6 +75,50 @@ For :doc:`/sql/create-table` statement, the default table type is ``columnstore` The table type can be configured by setting the ``default_table_type`` engine variable, see `Creating a Columnstore Table `_. +SingleStore to PrestoDB type mapping +------------------------------------ + +Map of SingleStore types to the relevant PrestoDB types: + +.. 
list-table:: SingleStore to PrestoDB type mapping + :widths: 50, 50 + :header-rows: 1 + + * - SingleStore type + - PrestoDB type + * - ``BOOLEAN`` + - ``BOOLEAN`` + * - ``INTEGER`` + - ``INTEGER`` + * - ``FLOAT`` + - ``REAL`` + * - ``DOUBLE`` + - ``DOUBLE`` + * - ``DECIMAL`` + - ``DECIMAL`` + * - ``LARGETEXT`` + - ``VARCHAR (unbounded)`` + * - ``VARCHAR(len)`` + - ``VARCHAR(len) len < 21845`` + * - ``CHAR(len)`` + - ``CHAR(len)`` + * - ``MEDIUMTEXT`` + - ``VARCHAR(len) 21845 <= len < 5592405`` + * - ``LONGTEXT`` + - ``VARCHAR(len) 5592405 <= len < 1431655765`` + * - ``MEDIUMBLOB`` + - ``VARBINARY`` + * - ``UUID`` + - ``UUID`` + * - ``DATE`` + - ``DATE`` + * - ``TIME`` + - ``TIME`` + * - ``DATETIME`` + - ``TIMESTAMP`` + +No other types are supported. + The following SQL statements are not supported: * :doc:`/sql/alter-schema` diff --git a/presto-docs/src/main/sphinx/connector/sqlserver.rst b/presto-docs/src/main/sphinx/connector/sqlserver.rst index dc28f3a0ae603..0df4e755273f5 100644 --- a/presto-docs/src/main/sphinx/connector/sqlserver.rst +++ b/presto-docs/src/main/sphinx/connector/sqlserver.rst @@ -134,6 +134,10 @@ Property Name Description cached. Set to ``0ms`` to disable the cache. ``1m`` ``list-schemas-ignored-schemas`` List of schemas to ignore when listing schemas. ``information_schema`` + +``case-sensitive-name-matching`` Enable case sensitive identifier support for schema and table ``false`` + names for the connector. When disabled, names are matched + case-insensitively using lowercase normalization. 
================================================== ==================================================================== =========== Querying SQL Server diff --git a/presto-docs/src/main/sphinx/connector/system.rst b/presto-docs/src/main/sphinx/connector/system.rst index 9f0b8038e4479..0e388c38627fc 100644 --- a/presto-docs/src/main/sphinx/connector/system.rst +++ b/presto-docs/src/main/sphinx/connector/system.rst @@ -36,7 +36,40 @@ System Connector Tables ``metadata.catalogs`` ^^^^^^^^^^^^^^^^^^^^^ -The catalogs table contains the list of available catalogs. +The catalogs table contains the list of available catalogs. The columns in ``metadata.catalogs`` are: + +======================================= ====================================================================== +Column Name Description +======================================= ====================================================================== +``catalog_name`` The value of this column is derived from the names of + catalog.properties files present under ``etc/catalog`` path under + presto installation directory. Everything except the suffix + ``.properties`` is treated as the catalog name. For example, if there + is a file named ``my_catalog.properties``, then ``my_catalog`` will be + listed as the value for this column. + +``connector_id`` The values in this column are a duplicate of the values in the + ``catalog_name`` column. + +``connector_name`` This column represents the actual name of the underlying connector + that a particular catalog is using. This column contains the value of + ``connector.name`` property from the catalog.properties file. 
+======================================= ====================================================================== + +Example: + +Suppose a user configures a single catalog by creating a file named ``my_catalog.properties`` with the +below contents:: + + connector.name=hive-hadoop2 + hive.metastore.uri=thrift://localhost:9083 + +``metadata.catalogs`` table will show below output:: + + presto> select * from system.metadata.catalogs; + catalog_name | connector_id | connector_name + --------------+--------------+---------------- + my_catalog | my_catalog | hive-hadoop2 ``metadata.schema_properties`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-docs/src/main/sphinx/develop.rst b/presto-docs/src/main/sphinx/develop.rst index 3f65e72865df2..8f58863395ad8 100644 --- a/presto-docs/src/main/sphinx/develop.rst +++ b/presto-docs/src/main/sphinx/develop.rst @@ -13,6 +13,7 @@ This guide is intended for Presto contributors and plugin developers. develop/delete-and-update develop/types develop/functions + develop/procedures develop/system-access-control develop/password-authenticator develop/event-listener @@ -22,3 +23,4 @@ This guide is intended for Presto contributors and plugin developers. develop/presto-console develop/presto-authenticator develop/client-request-filter + develop/release-process diff --git a/presto-docs/src/main/sphinx/develop/client-protocol.rst b/presto-docs/src/main/sphinx/develop/client-protocol.rst index c5ebed6765ee4..a612b08231a95 100644 --- a/presto-docs/src/main/sphinx/develop/client-protocol.rst +++ b/presto-docs/src/main/sphinx/develop/client-protocol.rst @@ -122,6 +122,9 @@ Request Header Name Description ``X-Presto-Extra-Credential`` Provides extra credentials to the connector. The header is a name=value string that is saved in the session ``Identity`` object. The name and value are only meaningful to the connector. +``X-Presto-Retry-Query`` Boolean flag indicating that this query is a placeholder for potential retry. 
+ When set to ``true``, marks the query on the backup cluster as a retry placeholder + and prevents retry chains in cross-cluster retry scenarios. ====================================== ========================================================================================= @@ -184,3 +187,69 @@ Data Member Type Notes ================= Class ``PrestoHeaders`` enumerates all the HTTP request and response headers allowed by the Presto client REST API. + + +Cross-Cluster Query Retry +========================= + +Presto supports automatic query retry on a backup cluster when a query fails on the primary cluster. This feature enables +high availability by transparently redirecting failed queries to a backup cluster. + +The cross-cluster retry mechanism works as follows: + +Query Parameters +---------------- + +When a router or load balancer handles a query that should support cross-cluster retry, it includes the following +query parameters when redirecting the client to the primary cluster: + +* ``retryUrl`` - The URL-encoded endpoint of the backup cluster where the query can be retried if it fails +* ``retryExpirationInSeconds`` - The number of seconds until the retry URL expires (must be at least 1). This value + should be set based on the ``Cache-Control`` headers returned by Presto query endpoints. Presto uses ``Cache-Control`` + headers to indicate how long a query will be retained in the server's memory. The retry expiration should not exceed + this cache duration to ensure the placeholder query is still available when the retry occurs. + +Both parameters must be provided together. If only one is provided, the request will be rejected with a 400 Bad Request error. 
+ +Example request to primary cluster:: + + POST /v1/statement?retryUrl=https%3A%2F%2Fbackup.example.com%3A8080%2Fv1%2Fstatement&retryExpirationInSeconds=300 + +Retry Header +------------ + +The ``X-Presto-Retry-Query`` header is used to indicate that a query is being created as a placeholder for potential +retry. When set to ``true``, this header: + +* Indicates the query is a retry placeholder on the backup cluster +* Prevents retry chains - a query marked with this header will not trigger another retry if it fails + +Retry Flow +---------- + +1. Router/load balancer POSTs the query to the backup cluster with ``X-Presto-Retry-Query: true`` header to create + a placeholder query that can be used as a retry destination +2. Router redirects (HTTP 307) the client to the primary cluster with ``retryUrl`` and ``retryExpirationInSeconds`` + query parameters +3. Client follows the redirect and POSTs the query to the primary cluster +4. Primary cluster executes the query normally +5. If the query fails with a retriable error code (configured on the server), the Presto server modifies the + ``nextUri`` in the response to point to the retry URL of the backup cluster +6. Client follows the ``nextUri`` to the backup cluster where the placeholder query executes the actual query +7. If the retry query fails, it will not trigger another retry since it's marked with ``X-Presto-Retry-Query`` + +Limitations +----------- + +Cross-cluster retry has the following limitations: + +* **Query types**: Retry only works when no results have been sent back to the client. In practice, this feature + works well for: + + - ``CREATE TABLE AS SELECT`` statements + - DDL operations (``CREATE``, ``ALTER``, ``DROP``, etc.) + - ``INSERT`` statements + - ``SELECT`` queries that fail before any results are produced + + For ``SELECT`` queries that produce results, retry will only occur if the failure happens during planning or + before the first batch of results is generated. 
diff --git a/presto-docs/src/main/sphinx/develop/connectors.rst b/presto-docs/src/main/sphinx/develop/connectors.rst index c77877890e3d3..24713c68cc176 100644 --- a/presto-docs/src/main/sphinx/develop/connectors.rst +++ b/presto-docs/src/main/sphinx/develop/connectors.rst @@ -8,7 +8,8 @@ you adapt your data source to the API expected by Presto, you can write queries against this data. ConnectorSplit ----------------- +-------------- + Instances of your connector splits. The ``getNodeSelectionStrategy`` method indicates the node affinity @@ -81,3 +82,14 @@ Given a split and a list of columns, the record set provider is responsible for delivering data to the Presto execution engine. It creates a ``RecordSet``, which in turn creates a ``RecordCursor`` that is used by Presto to read the column values for each row. + +Node Selection Strategy +----------------------- + +The node selection strategy is specified by a connector on each split. The possible values are: + +* HARD_AFFINITY - The Presto runtime must schedule this split on the nodes specified on ``ConnectorSplit#getPreferredNodes``. +* SOFT_AFFINITY - The Presto runtime should prefer ``ConnectorSplit#getPreferredNodes`` nodes, but doesn't have to. Use this value primarily for caching. +* NO_PREFERENCE - No preference. + +Use the ``node_selection_strategy`` session property in Hive and Iceberg to override this. \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/develop/functions.rst b/presto-docs/src/main/sphinx/develop/functions.rst index 8afa657818aee..2de3e550d149f 100644 --- a/presto-docs/src/main/sphinx/develop/functions.rst +++ b/presto-docs/src/main/sphinx/develop/functions.rst @@ -2,12 +2,15 @@ Functions ========= +Functions in Presto can be implemented at Plugin and the Connector level. +The following two sections describe how to implement them. + Plugin Implementation --------------------- The function framework is used to implement SQL functions. 
Presto includes a number of built-in functions. In order to implement new functions, you can -write a plugin that returns one more functions from ``getFunctions()``: +write a plugin that returns one or more functions from ``getFunctions()``: .. code-block:: java @@ -31,10 +34,44 @@ Note that the ``ImmutableSet`` class is a utility class from Guava. The ``getFunctions()`` method contains all of the classes for the functions that we will implement below in this tutorial. +Functions registered using this method are available in the default +namespace ``presto.default``. + For a full example in the codebase, see either the ``presto-ml`` module for machine learning functions or the ``presto-teradata-functions`` module for Teradata-compatible functions, both in the root of the Presto source. +Connector Functions Implementation +---------------------------------- + +To implement new functions at the connector level, in your +connector implementation, override the ``getSystemFunctions()`` method that returns one +or more functions: + +.. code-block:: java + + public class ExampleFunctionsConnector + implements Connector + { + @Override + public Set> getSystemFunctions() + { + return ImmutableSet.>builder() + .add(ExampleNullFunction.class) + .add(IsNullFunction.class) + .add(IsEqualOrNullFunction.class) + .add(ExampleStringFunction.class) + .add(ExampleAverageFunction.class) + .build(); + } + } + +Functions registered using this interface are available in the namespace +``.system`` where ```` is the catalog name used +in the Presto deployment for this connector type. + +At present, connector-level functions do not support window functions or scalar operators. + Scalar Function Implementation ------------------------------ @@ -76,7 +113,7 @@ a wrapper around ``byte[]``, rather than ``String`` for its native container typ ``@SqlNullable`` if it can return ``NULL`` when the arguments are non-null. 
Parametric Scalar Functions ---------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^^ Scalar functions that have type parameters have some additional complexity. To make our previous example work with any type we need the following: @@ -150,7 +187,7 @@ To make our previous example work with any type we need the following: } Another Scalar Function Example -------------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The ``lowercaser`` function takes a single ``VARCHAR`` argument and returns a ``VARCHAR``, which is the argument converted to lower case: @@ -177,7 +214,7 @@ has no ``@SqlNullable`` annotations, meaning that if the argument is ``NULL``, the result will automatically be ``NULL`` (the function will not be called). Codegen Scalar Function Implementation --------------------------------------- +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Scalar functions can also be implemented in bytecode, allowing us to specialize and optimize functions according to the ``@TypeParameter`` diff --git a/presto-docs/src/main/sphinx/develop/presto-authenticator.rst b/presto-docs/src/main/sphinx/develop/presto-authenticator.rst index 44f2781bf8819..6f11be98c76a7 100644 --- a/presto-docs/src/main/sphinx/develop/presto-authenticator.rst +++ b/presto-docs/src/main/sphinx/develop/presto-authenticator.rst @@ -12,13 +12,36 @@ Implementation ``PrestoAuthenticator`` instance. It also defines the name of this authenticator which is used by the administrator in a Presto configuration. -``PrestoAuthenticator`` contains a single method, ``createAuthenticatedPrincipal()``, -that validates the request and returns a ``Principal``, which is then +``PrestoAuthenticator`` contains a single method, ``createAuthenticatedPrincipal(Map> headers)``, +that validates the request headers and returns a ``Principal``, which is then authorized by the :doc:`system-access-control`. The implementation of ``PrestoAuthenticatorFactory`` must be wrapped as a plugin and installed on the Presto cluster. 
+Error Handling +-------------- + +The ``createAuthenticatedPrincipal(Map> headers)`` method can throw two types of exceptions, +depending on the authentication outcome: + +* ``AuthenticatorNotApplicableException``: + + Thrown when the required authentication header is missing or invalid. This signals + to Presto that the current authentication method is not applicable, so it should + skip this authenticator and try the next configured one. The exception message is + not returned to the user, since authentication was never intended for this request. + +* ``AccessDeniedException``: + + Thrown when the required header is present but authentication fails. In this case, + Presto will still try the next configured authenticator but the error message is + passed back to the user, indicating that the authentication attempt was valid but + unsuccessful. + +This distinction ensures that Presto can properly chain multiple authenticators +while providing meaningful feedback to the user only when appropriate. + Configuration ------------- diff --git a/presto-docs/src/main/sphinx/develop/procedures.rst b/presto-docs/src/main/sphinx/develop/procedures.rst new file mode 100644 index 0000000000000..1f91d1dc2c576 --- /dev/null +++ b/presto-docs/src/main/sphinx/develop/procedures.rst @@ -0,0 +1,507 @@ +========== +Procedures +========== + +PrestoDB's procedures allow users to perform data manipulation and management tasks. Unlike traditional databases where procedural objects are +defined by users by using SQL, the procedures in PrestoDB are a set of system routines provided by developers through Connectors. The overall type hierarchy +is illustrated in the diagram below: + +.. code-block:: text + + «abstract» BaseProcedure + |-- Procedure // (Normal) Procedure + |-- «abstract» DistributedProcedure + |-- TableDataRewriteDistributedProcedure + |-- ...... 
// Other future subtypes + +PrestoDB supports two categories of procedures: + +* **Normal Procedure** + +These procedures are executed directly on the Coordinator node, and PrestoDB does not build distributed execution plans for them. +They are designed mainly for administrative tasks involving table or system metadata and cache management, such as ``sync_partition_metadata`` +and ``invalidate_directory_list_cache`` in Hive Connector, or ``expire_snapshots`` and ``invalidate_statistics_file_cache`` in Iceberg Connector. + +* **Distributed Procedure** + +Procedures of this type are executed with a distributed execution plan constructed by PrestoDB, which utilizes the entire cluster of Worker nodes +for distributed computation. They are suitable for operations involving table data—such as data optimization, re-partitioning, sorting, +and pre-processing—as well as for administrative tasks that need to be executed across the Worker nodes, for instance, clearing caches on specific workers. + +The type hierarchy for Distributed Procedures is designed to be extensible. Different distributed tasks can have different invocation parameters and +are planned into differently shaped execution plans; as such, they can be implemented as distinct subtypes of ``DistributedProcedure``. + +For example, for table data rewrite tasks, PrestoDB provides the ``TableDataRewriteDistributedProcedure`` subtype. +Connector developers can leverage this subtype to implement specific data-rewrite distributed procedures—such as table data optimization, compression, +repartitioning, or sorting—for their connectors. Within the PrestoDB engine, tasks of this subtype are uniformly planned into an execution plan tree with the +following shape: + +.. 
code-block:: text + + TableScanNode → FilterNode → CallDistributedProcedureNode → TableFinishNode → OutputNode + +In addition, developers can implement other kinds of distributed procedures by extending the type hierarchy—defining new subtypes that are mapped to +execution plans of varying shapes. + +For further design details, see `RFC-0021 for Presto `_. + +Normal Procedure +---------------- + +To make a procedure callable, a connector must first expose it to the PrestoDB engine. PrestoDB leverages the `Guice dependency injection framework `_ +to manage procedure registration and lifecycle. A specific procedure is implemented and bound as a ``Provider``, +thus creating an instance only when it is actually needed for execution, enabling on-demand instantiation. The following steps will guide you on how to +implement and provide a procedure in a connector. + +1. Procedure Provider Class +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +An implementation class must implement the ``Provider`` interface. Its constructor can use ``@Inject`` to receive any required dependencies +that are managed by Guice. The class must then implement the ``get()`` method from the Provider interface, which is responsible for constructing and +returning a new Procedure instance. + +2. Creation of a Procedure Instance +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A Procedure object requires the following parameters upon creation: + +* ``String schema`` - The schema namespace to which this procedure belongs (typically ``system`` in PrestoDB). +* ``String name`` - The name of this procedure, for example, ``expire_snapshots``. +* ``List arguments`` - The parameter declarations list for this procedure. +* ``MethodHandle methodHandle`` - PrestoDB abstracts procedure execution through ``MethodHandle``. A procedure provider implements the core logic in + a dedicated method and exposes it as a ``MethodHandle`` that is injected into the procedure instance. + +.. 
note:: + + * The Java method corresponding to the ``MethodHandle`` must have a correspondence with the procedure parameters. + + * Its first parameter must be a session object of type ``ConnectorSession.class``. + * The subsequent parameters must have a strict one-to-one correspondence in both type and order with the + procedure parameters declared in the arguments list. + + * The method implementation for the ``MethodHandle`` must account for classloader isolation. Since PrestoDB employs a plugin isolation mechanism where each + connector has its own ClassLoader, the engine must temporarily switch to the connector's specific ClassLoader when invoking its procedure logic. This context + switch is critical to prevent ``ClassNotFoundException`` or ``NoClassDefFoundError`` issues. + +As an example, the following is the ``expire_snapshots`` procedure implemented in the Iceberg connector: + +.. code-block:: java + + public class ExpireSnapshotsProcedure + implements Provider + { + private static final MethodHandle EXPIRE_SNAPSHOTS = methodHandle( + ExpireSnapshotsProcedure.class, + "expireSnapshots", + ConnectorSession.class, + String.class, + String.class, + SqlTimestamp.class, + Integer.class, + List.class); + private final IcebergMetadataFactory metadataFactory; + + @Inject + public ExpireSnapshotsProcedure(IcebergMetadataFactory metadataFactory) + { + this.metadataFactory = requireNonNull(metadataFactory, "metadataFactory is null"); + } + + @Override + public Procedure get() + { + return new Procedure( + "system", + "expire_snapshots", + ImmutableList.of( + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("older_than", TIMESTAMP, false, null), + new Argument("retain_last", INTEGER, false, null), + new Argument("snapshot_ids", "array(bigint)", false, null)), + EXPIRE_SNAPSHOTS.bindTo(this)); + } + + public void expireSnapshots(ConnectorSession clientSession, String schema, String tableName, SqlTimestamp olderThan, Integer retainLast, 
List snapshotIds) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + doExpireSnapshots(clientSession, schema, tableName, olderThan, retainLast, snapshotIds); + } + } + + private void doExpireSnapshots(ConnectorSession clientSession, String schema, String tableName, SqlTimestamp olderThan, Integer retainLast, List snapshotIds) + { + // Execute the snapshot expiration for the target table using the Iceberg interface + // ...... + } + } + +3. Exposing the Procedure Provider to PrestoDB +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the Guice binding module of your target connector, add a binding for the procedure provider class defined above: + +.. code-block:: java + + ...... + Multibinder> procedures = newSetBinder(binder, new TypeLiteral>() {}); + procedures.addBinding().toProvider(ExpireSnapshotsProcedure.class).in(Scopes.SINGLETON); + ...... + +During startup, the PrestoDB engine collects the procedure providers exposed by all connectors and maintains them within their +respective namespaces (for example, ``hive.system`` or ``iceberg.system``). Once startup is complete, users can invoke these procedures by specifying +the corresponding connector namespace, for example: + +.. code-block:: sql + + call iceberg.system.expire_snapshots('default', 'test_table'); + call hive.system.invalidate_directory_list_cache(); + ...... + +Distributed Procedure +--------------------- + +PrestoDB supports building distributed execution plans for certain types of procedures, enabling them to leverage the calculation resources of +the entire cluster. Since different kinds of distributed procedures may correspond to distinct execution plan shapes, extending and implementing them +should be approached at two levels: + +* For a category of procedures that share the same execution plan shape, extend a subtype of ``DistributedProcedure``. 
The currently supported + ``TableDataRewriteDistributedProcedure`` subtype is designed for table data rewrite operations. +* Implement a concrete distributed procedure in a connector by building upon a specific ``DistributedProcedure`` subtype. For instance, + ``rewrite_data_files`` in the Iceberg connector is built upon the ``TableDataRewriteDistributedProcedure`` subtype. + +.. important:: + The ``DistributedProcedure`` class is abstract. Connector developers cannot implement it directly. + You must build your concrete distributed procedure upon a specific **subtype** (like ``TableDataRewriteDistributedProcedure``) + that the PrestoDB engine already knows how to analyze and plan. + +Extending a DistributedProcedure Subtype +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +1. Adding a DistributedProcedureType Enum Value +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Add a new value to the ``DistributedProcedure.DistributedProcedureType`` enum, for example: ``TABLE_DATA_REWRITE``. + +This enum value is important, as it is used during both the analysis and planning phases to distinguish between different ``DistributedProcedure`` subtypes +and execute the corresponding branch logic. + +2. Creating a subclass of DistributedProcedure +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Create a new subclass of ``DistributedProcedure``, such as: + +.. code-block:: java + + public class TableDataRewriteDistributedProcedure + extends DistributedProcedure + +* In the constructor, pass the corresponding ``DistributedProcedureType`` enum value such as ``TABLE_DATA_REWRITE`` to the ``super(...)`` method. +* In addition to the base parameters required by ``BaseProcedure`` (schema, name, and arguments, which are consistent with those in ``Procedure``), + a subtype can define and is responsible for processing and validating any additional parameters it requires. 
+ +Additionally, the following three abstract methods defined by the base class ``DistributedProcedure`` should be implemented: + +.. code-block:: java + + /** + * Creates a connector-specific, or even a distributed procedure subtype-specific context object. + * For connectors that support distributed procedures, this method is invoked at the start of a distributed procedure's execution. + * The generated procedure context is then bound to the current ConnectorMetadata, maintaining all contextual information + * throughout the execution. This context would be accessed during calls to the procedure's {@link #begin} and {@link #finish} methods. + */ + public ConnectorProcedureContext createContext(Object... arguments); + + /** + * Performs the preparatory work required when starting the execution of this distributed procedure. + * */ + public abstract ConnectorDistributedProcedureHandle begin(ConnectorSession session, + ConnectorProcedureContext procedureContext, + ConnectorTableLayoutHandle tableLayoutHandle, + Object[] arguments); + + /** + * Performs the work required for the final centralized commit, after all distributed execution tasks have completed. + * */ + public abstract void finish(ConnectorSession session, + ConnectorProcedureContext procedureContext, + ConnectorDistributedProcedureHandle procedureHandle, + Collection fragments); + +.. note:: + + At this architectural level, distributed procedure subtypes are designed to be decoupled from specific connectors. When implementing + the three aforementioned abstract methods, it is recommended to focus solely on the common logic of the subtype. Connector-specific functionality + should be abstracted into method interfaces and delegated to the final, concrete distributed procedure implementations. + +As an illustration, the ``TableDataRewriteDistributedProcedure`` subtype, which handles table data rewrite operations, is defined as follows: + +.. 
code-block:: java + + public class TableDataRewriteDistributedProcedure + extends DistributedProcedure + { + private final BeginCallDistributedProcedure beginCallDistributedProcedure; + private final FinishCallDistributedProcedure finishCallDistributedProcedure; + private final Function contextProvider; + + public TableDataRewriteDistributedProcedure(String schema, String name, + List arguments, + BeginCallDistributedProcedure beginCallDistributedProcedure, + FinishCallDistributedProcedure finishCallDistributedProcedure, + Function contextProvider) + { + super(TABLE_DATA_REWRITE, schema, name, arguments); + this.beginCallDistributedProcedure = requireNonNull(beginCallDistributedProcedure, "beginCallDistributedProcedure is null"); + this.finishCallDistributedProcedure = requireNonNull(finishCallDistributedProcedure, "finishCallDistributedProcedure is null"); + this.contextProvider = requireNonNull(contextProvider, "contextProvider is null"); + + // Performs subtype-specific validation and processing logic on the parameters + ...... + } + + @Override + public ConnectorDistributedProcedureHandle begin(ConnectorSession session, ConnectorProcedureContext procedureContext, ConnectorTableLayoutHandle tableLayoutHandle, Object[] arguments) + { + return this.beginCallDistributedProcedure.begin(session, procedureContext, tableLayoutHandle, arguments); + } + + @Override + public void finish(ConnectorSession session, ConnectorProcedureContext procedureContext, ConnectorDistributedProcedureHandle procedureHandle, Collection fragments) + { + this.finishCallDistributedProcedure.finish(session, procedureContext, procedureHandle, fragments); + } + + @Override + public ConnectorProcedureContext createContext(Object... 
arguments) + { + return contextProvider.apply(arguments); + } + + @FunctionalInterface + public interface BeginCallDistributedProcedure + { + ConnectorDistributedProcedureHandle begin(ConnectorSession session, ConnectorProcedureContext procedureContext, ConnectorTableLayoutHandle tableLayoutHandle, Object[] arguments); + } + + @FunctionalInterface + public interface FinishCallDistributedProcedure + { + void finish(ConnectorSession session, ConnectorProcedureContext procedureContext, ConnectorDistributedProcedureHandle procedureHandle, Collection fragments); + } + } + +3. Processing of Subtypes in the Analysis Phase +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In the ``visitCall(...)`` method of ``StatementAnalyzer``, add a branch to handle the newly defined subtypes, such as ``TABLE_DATA_REWRITE``: + +.. code-block:: java + + @Override + protected Scope visitCall(Call call, Optional scope) + { + QualifiedObjectName procedureName = analysis.getProcedureName() + .orElse(createQualifiedObjectName(session, call, call.getName(), metadata)); + ConnectorId connectorId = metadata.getCatalogHandle(session, procedureName.getCatalogName()) + .orElseThrow(() -> new SemanticException(MISSING_CATALOG, call, "Catalog %s does not exist", procedureName.getCatalogName())); + + if (!metadata.getProcedureRegistry().isDistributedProcedure(connectorId, toSchemaTableName(procedureName))) { + throw new SemanticException(PROCEDURE_NOT_FOUND, "Distributed procedure not registered: " + procedureName); + } + DistributedProcedure procedure = metadata.getProcedureRegistry().resolveDistributed(connectorId, toSchemaTableName(procedureName)); + Object[] values = extractParameterValuesInOrder(call, procedure, metadata, session, analysis.getParameters()); + + analysis.setUpdateInfo(call.getUpdateInfo()); + analysis.setDistributedProcedureType(Optional.of(procedure.getType())); + analysis.setProcedureArguments(Optional.of(values)); + switch (procedure.getType()) { + case TABLE_DATA_REWRITE: + 
TableDataRewriteDistributedProcedure tableDataRewriteDistributedProcedure = (TableDataRewriteDistributedProcedure) procedure; + + // Performs analysis on the tableDataRewriteDistributedProcedure + ...... + + break; + default: + throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Unsupported distributed procedure type: " + procedure.getType()); + } + return createAndAssignScope(call, scope, Field.newUnqualified(Optional.empty(), "rows", BIGINT)); + } + +4. Processing of Subtypes in the Planning Phase +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In the ``planStatementWithoutOutput(...)`` method of ``LogicalPlanner``, when the statement type is Call, add a branch to handle newly defined subtypes +such as ``TABLE_DATA_REWRITE``: + +.. code-block:: java + + private RelationPlan planStatementWithoutOutput(Analysis analysis, Statement statement) + { + ...... + else if (statement instanceof Call) { + checkState(analysis.getDistributedProcedureType().isPresent(), "Call distributed procedure analysis is missing"); + switch (analysis.getDistributedProcedureType().get()) { + case TABLE_DATA_REWRITE: + return createCallDistributedProcedurePlanForTableDataRewrite(analysis, (Call) statement); + default: + throw new PrestoException(NOT_SUPPORTED, "Unsupported distributed procedure type: " + analysis.getDistributedProcedureType().get()); + } + } + else { + throw new PrestoException(NOT_SUPPORTED, "Unsupported statement type " + statement.getClass().getSimpleName()); + } + } + + private RelationPlan createCallDistributedProcedurePlanForTableDataRewrite(Analysis analysis, Call statement) + { + // Builds the logical plan for the table data rewrite procedure subtype from the analysis results: + // TableScanNode → FilterNode → CallDistributedProcedureNode → TableFinishNode → OutputNode + ...... + } + +.. note:: + + If a custom plan node is required, it must subsequently be handled in the plan visitors, the optimizers, and the local execution planner. 
If a custom + local execution operator ultimately needs to be generated, it must be implemented within PrestoDB as well. (This part is beyond the scope of this document + and will not be elaborated further) + +Implementing a Concrete Distributed Procedure in a Specific Connector +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Similar to normal procedures, PrestoDB uses the Guice dependency injection framework to manage the registration and lifecycle of distributed procedures, +enabling connectors to dynamically provide these callable distributed procedures to the engine. A concrete distributed procedure is implemented and bound +as a ``Provider``, which ensures an instance is created on-demand when a procedure needs to be executed. The following steps will +guide you through implementing and supplying a distributed procedure in your connector. + +1. Procedure Provider Class +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +An implementation class must implement the ``Provider`` interface. In its constructor, it can use @Inject to receive any Guice-managed +dependencies. The class is required to implement the ``get()`` method from the Provider interface, which is responsible for constructing and returning a +specific subclass instance of ``DistributedProcedure``. + +2. Creation of a DistributedProcedure Instance +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The parameters required to create a ``DistributedProcedure`` subclass differ, but the following common parameters are mandatory and consistent with those +described in the normal ``Procedure`` above. 
+ +* ``String schema`` - The schema namespace to which this procedure belongs (typically ``system`` in PrestoDB) +* ``String name`` - The name of this procedure, for example, ``rewrite_data_files`` +* ``List arguments`` - The parameter declarations list for this procedure + +The following code demonstrates how to implement ``rewrite_data_files`` for the Iceberg connector, based on the ``TableDataRewriteDistributedProcedure`` class: + +.. code-block:: java + + public class RewriteDataFilesProcedure + implements Provider + { + TypeManager typeManager; + JsonCodec commitTaskCodec; + + @Inject + public RewriteDataFilesProcedure( + TypeManager typeManager, + JsonCodec commitTaskCodec) + { + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); + } + + @Override + public DistributedProcedure get() + { + return new TableDataRewriteDistributedProcedure( + "system", + "rewrite_data_files", + ImmutableList.of( + new Argument(SCHEMA, VARCHAR), + new Argument(TABLE_NAME, VARCHAR), + new Argument("filter", VARCHAR, false, "TRUE"), + new Argument("options", "map(varchar, varchar)", false, null)), + (session, procedureContext, tableLayoutHandle, arguments) -> beginCallDistributedProcedure(session, (IcebergProcedureContext) procedureContext, (IcebergTableLayoutHandle) tableLayoutHandle, arguments), + ((session, procedureContext, tableHandle, fragments) -> finishCallDistributedProcedure(session, (IcebergProcedureContext) procedureContext, tableHandle, fragments)), + arguments -> { + checkArgument(arguments.length == 2, format("invalid number of arguments: %s (should have %s)", arguments.length, 2)); + checkArgument(arguments[0] instanceof Table && arguments[1] instanceof Transaction, "Invalid arguments, required: [Table, Transaction]"); + return new IcebergProcedureContext((Table) arguments[0], (Transaction) arguments[1]); + }); + } + + private ConnectorDistributedProcedureHandle 
beginCallDistributedProcedure(ConnectorSession session, IcebergProcedureContext procedureContext, IcebergTableLayoutHandle layoutHandle, Object[] arguments) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + Table icebergTable = procedureContext.getTable().orElseThrow(() -> new VerifyException("No partition data for partitioned table")); + IcebergTableHandle tableHandle = layoutHandle.getTable(); + + // Performs the preparatory work required when starting the execution of ``rewrite_data_files``, + // and encapsulates the necessary information and handling logic within the ``procedureContext`` + ...... + + return new IcebergDistributedProcedureHandle( + tableHandle.getSchemaName(), + tableHandle.getIcebergTableName(), + toPrestoSchema(icebergTable.schema(), typeManager), + toPrestoPartitionSpec(icebergTable.spec(), typeManager), + getColumns(icebergTable.schema(), icebergTable.spec(), typeManager), + icebergTable.location(), + getFileFormat(icebergTable), + getCompressionCodec(session), + icebergTable.properties()); + } + } + + private void finishCallDistributedProcedure(ConnectorSession session, IcebergProcedureContext procedureContext, ConnectorDistributedProcedureHandle procedureHandle, Collection fragments) + { + if (fragments.isEmpty() && + procedureContext.getScannedDataFiles().isEmpty() && + procedureContext.getFullyAppliedDeleteFiles().isEmpty()) { + return; + } + + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + IcebergDistributedProcedureHandle handle = (IcebergDistributedProcedureHandle) procedureHandle; + Table icebergTable = procedureContext.getTransaction().table(); + + // Performs the final atomic commit by leveraging Iceberg's `RewriteFiles` API. 
+ // This integrates the commit information from the distributed tasks (in `commitTasks`) + // with the file change tracking (for example, `scannedDataFiles`, `fullyAppliedDeleteFiles`, `newFiles`) + // maintained within the `procedureContext`. + List commitTasks = fragments.stream() + .map(slice -> commitTaskCodec.fromJson(slice.getBytes())) + .collect(toImmutableList()); + ...... + + RewriteFiles rewriteFiles = procedureContext.getTransaction().newRewrite(); + Set scannedDataFiles = procedureContext.getScannedDataFiles(); + Set fullyAppliedDeleteFiles = procedureContext.getFullyAppliedDeleteFiles(); + rewriteFiles.rewriteFiles(scannedDataFiles, fullyAppliedDeleteFiles, newFiles, ImmutableSet.of()); + + ...... + rewriteFiles.commit(); + } + } + } + +3. Exposing the DistributedProcedure Provider to PrestoDB +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In the Guice binding module of your target connector, add a binding for the distributed procedure provider class defined above: + +.. code-block:: java + + Multibinder> procedures = newSetBinder(binder, new TypeLiteral>() {}); + procedures.addBinding().toProvider(RewriteDataFilesProcedure.class).in(Scopes.SINGLETON); + ...... + +During startup, the PrestoDB engine collects the distributed procedure providers the same way as normal procedure providers exposed by +all connectors and maintains them within their respective namespaces (for example, ``hive.system`` or ``iceberg.system``). Once startup is complete, users +can invoke these distributed procedures by specifying the corresponding connector namespace, for example: + +.. 
code-block:: sql + + call iceberg.system.rewrite_data_files('default', 'test_table'); + call iceberg.system.rewrite_data_files(table_name => 'test_table', schema => 'default', filter => 'c1 > 3'); \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/develop/release-process.rst b/presto-docs/src/main/sphinx/develop/release-process.rst new file mode 100644 index 0000000000000..1dbbcf6fdce85 --- /dev/null +++ b/presto-docs/src/main/sphinx/develop/release-process.rst @@ -0,0 +1,280 @@ +=============== +Release Process +=============== + +Overview +======== + +Presto releases are managed by volunteer committers. Releases occur approximately every 2 months, with extended testing periods to ensure stability. + +Release Cadence +=============== + +Releases target a 2-month cycle. Actual timing depends on: + +* Release shepherd availability +* Testing feedback +* Critical issues requiring delay +* Contributing organization resources + +Schedules adjust based on volunteer availability. + +Release Quality Model +===================== + +**Note:** Trunk is not stable. Do not use master branch in production. 
+ +**Extended Release Candidate Period** + * 2-4 week RC period for testing + * Issues fixed before final release + * Fixes verified in existing RC (no new RCs) + +**Community Testing** + * Organizations test RCs in their environments + * Weekly edge releases from master for early testing + * Join #releases in `Presto Slack `_ to participate + +**Automated Testing** + * Unit tests on all commits and RCs + * Product tests with full cluster deployment + * Shared connector tests with `AbstractTestQueries `_ + * CI for basic stability checks + +**Testing Contributions Needed** + * Additional `product test scenarios `_ + * Performance regression testing (shadow traffic) + * Test coverage for new features + +Version Numbering +================= + +* Format: ``0.XXX`` (for example, 0.293, 0.294) +* Major version fixed at 0 +* Minor version increments each release +* Patch releases (for example, 0.293.1) for critical fixes only + +Not semantic versioning. + +Release Types +============= + +**Regular Releases** + * Every ~2 months + * New features, improvements, bug fixes + * Extended RC testing period + +**Patch Releases** + * Critical issues only: + + - Security vulnerabilities (upon request - default: upgrade) + - Data correctness + - Performance regressions + - Stability issues + + * Based on previous stable release + * Minimal changes + +**Release Candidates** + * 2-4 week testing period + * One RC per release + * Fixes verified in existing RC + +**Edge Releases** + * Weekly from master + * Early access to features + * Testing only - not for production + +Release Shepherd Responsibilities +================================= + +Requirements: + +* Must be a committer +* Must understand codebase for judgment calls and rewrites + +Responsibilities: + +* Complete release process +* Ensure release notes follow guidelines +* Fix release note issues +* Cut and deploy release +* Coordinate with community +* Make go/no-go decisions + +See `Release Shepherding `_ for schedule and 
details. + +Contributing to Release Quality +=============================== + +**Testing** + * Run RCs in test environments + * Report issues with reproduction steps + * Verify cherry-picked fixes + +**Test Development** + * Add product tests for uncovered scenarios + * Unit tests for bug fixes + * Performance benchmarks (`pbench `_) + +**Documentation** + * Document behavior changes + * Write clear release notes + * Note compatibility issues + +**Code Review** + * Review PRs for correctness + * Identify compatibility issues + * Suggest test coverage + +Backward Compatibility Guidelines +================================= + +**Must Maintain Compatibility:** + +* **Client Libraries**: Evolve slowly. New server features must work with older clients. +* **SQL Syntax**: Keep stable. Deprecate with warnings before removal. +* **SPI**: Stable for connectors/plugins. Use ``@Deprecated`` for at least two releases before removal. + When adding new SPI methods, provide reasonable defaults to minimize connector updates. + Documented SPI interfaces must remain stable even without public implementations. + Exception: Undocumented AND unused SPI aspects. +* **Configuration**: Session and config properties need deprecation paths. Provide aliases for renames. + +**Can Change:** + +* Internal APIs (not SPI) +* Performance characteristics +* Query plans +* Default config values (document changes) + +**Developer Requirements:** + +* Document breaking changes in release notes +* Provide migration paths +* Revert inadvertent breaking changes to client protocol, SQL, SPI, or config + +Revert Guidelines +================= + +When to Revert +^^^^^^^^^^^^^^ + +Data Correctness Issues or Critical Bugs +---------------------------------------- + +Any change that introduces data correctness issues, major crashes, or severe stability problems +must be reverted immediately if a fix is not quick, especially near the RC finalization window. 
+ +**Must revert**: + +- Wrong query results, data corruption, frequent crashes +- Memory leaks or resource exhaustion in common code paths + +**Should revert**: + +- Performance regressions of more than 50% in common queries + +Backwards Incompatible Client Changes +------------------------------------- + +Client libraries evolve slowly and many users cannot easily upgrade clients. Breaking changes to +the client protocol, SQL syntax, or session/config properties without proper migration paths must +be reverted if they cannot be fixed quickly, particularly near RC finalization: + +**Must revert**: + +- Breaking client-server protocol compatibility +- Removing SQL syntax without deprecation warnings + +**Should revert**: + +- Changing session/config property behavior without aliases + +Backwards Incompatible SPI Changes Without Migration Path +--------------------------------------------------------- + +If a backwards incompatible change to the SPI is discovered that lacks the required migration path +(for example, no deprecation period, no reasonable defaults for new methods), the change should be reverted +if a proper migration path cannot be added quickly, especially near RC finalization. Use these criteria: + +**Must revert**: + +- Breaking documented SPI interfaces or core connectors (Hive, Iceberg, Delta, Kafka) +- Breaking maintained connectors with active usage in the repository + +**Should revert**: + +- Breaking experimental or rarely-used connectors (weigh maintenance burden) + +Consider both documented interfaces and public implementations in the Presto repository. +Create a GitHub issue marked as "release blocker" to alert the release shepherd. + +When NOT to Revert +^^^^^^^^^^^^^^^^^^ + +If the fix is simpler than the revert and can be completed quickly (especially before RC finalization), prefer fixing forward. 
+ +**Fix forward**: + +- Typos +- Logging issues +- Minor UI problems +- Test failures that don't affect production code +- Documentation errors or missing documentation + +Performance Trade-offs +---------------------- + +Performance changes with mixed impact require case-by-case evaluation based on community feedback: + +- **Evaluate carefully**: What's rare for one user may be critical for another +- **Consider configuration**: Can the optimization be made optional by using session properties? +- **Gather data**: Solicit feedback from multiple organizations during RC testing + +If multiple users report significant regressions, consider reverting or adding a feature flag. +Always document performance changes and workarounds in release notes. + +Proprietary or Hidden Infrastructure Dependencies +------------------------------------------------- + +Changes cannot be reverted based on impacts to proprietary infrastructure, private forks, or +non-public connectors or plugins. All revert decisions must be justifiable using only publicly +available code, documentation, and usage patterns visible in the open source project. + +Feature Additions With Minor Issues +----------------------------------- + +New features that don't affect existing functionality should be fixed. +Consider adding feature flags if stability is a concern. 
+ +How to Revert +------------- + +- Create a GitHub issue that describes the problem and label it "release blocker" immediately +- Raise a PR to revert the problematic change and link to the issue + +Release Communication +===================== + +* `Presto Slack `_ #releases channel +* `GitHub Releases `_ +* `Mailing List `_ +* `Release notes in docs `_ + +Best Practices for Developers +============================= + +* Avoid risky merges near release cuts +* Create as many automated tests as possible, and for large changes, consider product tests and manual testing +* Always consider how new features are enabled, whether they're enabled by default, and if not opt-in through SPI or SQL, gate them with a session property +* Document breaking changes in release notes +* Monitor #releases channel during release cycles +* Fix release blockers promptly + +Getting Involved +================ + +* Join #releases in `Presto Slack `_ +* Test release candidates +* Volunteer as a `release shepherd `_ (committers only) +* Contribute tests +* Share production experiences \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/ext/download.py b/presto-docs/src/main/sphinx/ext/download.py index 3cb83a8ce083f..cea1fd2efb006 100644 --- a/presto-docs/src/main/sphinx/ext/download.py +++ b/presto-docs/src/main/sphinx/ext/download.py @@ -41,25 +41,33 @@ def maven_download(group, artifact, version, packaging, classifier): filename = maven_filename(artifact, version, packaging, classifier) return base + '/'.join((group_path, artifact, version, filename)) +def github_download(group, artifact, version, packaging, classifier): + base = 'https://github.com/prestodb/presto/releases/download/' + filename = maven_filename(artifact, version, packaging, classifier) + return base + '/'.join((version, filename)) def setup(app): # noinspection PyDefaultArgument,PyUnusedLocal - def download_link_role(role, rawtext, text, lineno, inliner, options={}, content=[]): - version = 
app.config.release + def create_download_role(download_func): + def download_link_role(role, rawtext, text, lineno, inliner, options={}, content=[]): + version = app.config.release + + if not text in ARTIFACTS: + inliner.reporter.error('Unsupported download type: ' + text) + return [], [] - if not text in ARTIFACTS: - inliner.reporter.error('Unsupported download type: ' + text) - return [], [] + artifact, packaging, classifier = ARTIFACTS[text] - artifact, packaging, classifier = ARTIFACTS[text] + title = maven_filename(artifact, version, packaging, classifier) + uri = download_func(GROUP_ID, artifact, version, packaging, classifier) - title = maven_filename(artifact, version, packaging, classifier) - uri = maven_download(GROUP_ID, artifact, version, packaging, classifier) + node = nodes.reference(title, title, internal=False, refuri=uri) - node = nodes.reference(title, title, internal=False, refuri=uri) + return [node], [] + return download_link_role - return [node], [] - app.add_role('maven_download', download_link_role) + app.add_role('maven_download', create_download_role(maven_download)) + app.add_role('github_download', create_download_role(github_download)) return { 'parallel_read_safe': True, diff --git a/presto-docs/src/main/sphinx/functions/array.rst b/presto-docs/src/main/sphinx/functions/array.rst index 7cd3c60589dde..a63e7fc0cd99e 100644 --- a/presto-docs/src/main/sphinx/functions/array.rst +++ b/presto-docs/src/main/sphinx/functions/array.rst @@ -198,6 +198,18 @@ Array Functions -1, IF(cardinality(x) = cardinality(y), 0, 1))); -- [[1, 2], [2, 3, 1], [4, 2, 1, 4]] +.. function:: array_sort(array(T), function(T,U)) -> array(T) + + Sorts and returns the ``array`` using a lambda function to extract sorting keys. The function is applied + to each element of the array to produce a key, and the array is sorted based on these keys in ascending order. + Null array elements and null keys are placed at the end. 
:: + + SELECT array_sort(ARRAY['pear', 'apple', 'banana', 'kiwi'], x -> length(x)); -- ['pear', 'kiwi', 'apple', 'banana'] + SELECT array_sort(ARRAY[5, 20, 3, 9, 100], x -> x); -- [3, 5, 9, 20, 100] + SELECT array_sort(ARRAY['apple', NULL, 'banana', NULL], x -> length(x)); -- ['apple', 'banana', NULL, NULL] + SELECT array_sort(ARRAY[CAST(0.0 AS DOUBLE), CAST('NaN' AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE)], x -> x); -- [-Infinity, 0.0, Infinity, NaN] + SELECT array_sort(ARRAY[ROW('a', 3), ROW('b', 1), ROW('c', 2)], x -> x[2]); -- [ROW('b', 1), ROW('c', 2), ROW('a', 3)] + .. function:: array_sort_desc(x) -> array Returns the ``array`` sorted in the descending order. Elements of the ``array`` must be orderable. @@ -207,6 +219,18 @@ Array Functions SELECT array_sort_desc(ARRAY [null, 100, null, 1, 10, 50]); -- [100, 50, 10, 1, null, null] SELECT array_sort_desc(ARRAY [ARRAY ["a", null], null, ARRAY ["a"]); -- [["a", null], ["a"], null] +.. function:: array_sort_desc(array(T), function(T,U)) -> array(T) + + Sorts and returns the ``array`` in descending order using a lambda function to extract sorting keys. + The function is applied to each element of the array to produce a key, and the array is sorted based + on these keys in descending order. Null array elements and null keys are placed at the end. :: + + SELECT array_sort_desc(ARRAY['pear', 'apple', 'banana', 'kiwi'], x -> length(x)); -- ['banana', 'apple', 'pear', 'kiwi'] + SELECT array_sort_desc(ARRAY[5, 20, 3, 9, 100], x -> x); -- [100, 20, 9, 5, 3] + SELECT array_sort_desc(ARRAY['apple', NULL, 'banana', NULL], x -> length(x)); -- ['banana', 'apple', NULL, NULL] + SELECT array_sort_desc(ARRAY[CAST(0.0 AS DOUBLE), CAST('NaN' AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE)], x -> x); -- [NaN, Infinity, 0.0, -Infinity] + SELECT array_sort_desc(ARRAY[ROW('a', 3), ROW('b', 1), ROW('c', 2)], x -> x[2]); -- [ROW('a', 3), ROW('c', 2), ROW('b', 1)] + .. 
function:: array_split_into_chunks(array(T), int) -> array(array(T)) Returns an ``array`` of arrays splitting the input ``array`` into chunks of given length. @@ -234,6 +258,28 @@ Array Functions SELECT array_top_n(ARRAY [1, 100], 5); -- [100, 1] SELECT array_top_n(ARRAY ['a', 'zzz', 'zz', 'b', 'g', 'f'], 3); -- ['zzz', 'zz', 'g'] +.. function:: array_top_n(array(T), int, function(T,T,int)) -> array(T) + + Returns an array of the top ``n`` elements from a given ``array`` using the specified comparator ``function``. + The comparator will take two nullable arguments representing two nullable elements of the ``array``. It returns -1, 0, or 1 + as the first nullable element is less than, equal to, or greater than the second nullable element. + If the comparator function returns other values (including ``NULL``), the query will fail and raise an error. + If ``n`` is larger than the size of the given ``array``, the returned list will be the same size as the input instead of ``n``. :: + + SELECT array_top_n(ARRAY [100, 1, 3, -10, 6, -5], 3, (x, y) -> IF(abs(x) < abs(y), -1, IF(abs(x) = abs(y), 0, 1))); -- [100, -10, 6] + SELECT array_top_n(ARRAY [CAST(ROW(1, 2) AS ROW(x INT, y INT)), CAST(ROW(0, 11) AS ROW(x INT, y INT)), CAST(ROW(5, 10) AS ROW(x INT, y INT))], 2, (a, b) -> IF(a.x*a.y < b.x*b.y, -1, IF(a.x*a.y = b.x*b.y, 0, 1))); -- [ROW(5, 10), ROW(1, 2)] + +.. function:: array_transpose(array(array(T))) -> array(array(T)) + + Returns a transpose of a 2D array (matrix), where rows become columns and columns become rows. + Converts ``a[x][y]`` to ``transpose(a)[y][x]``. All rows in the input array must have the same length, otherwise the function will fail with an error. + Returns an empty array if the input is empty or if all rows are empty. 
:: + + SELECT array_transpose(ARRAY [ARRAY [1, 2, 3], ARRAY [4, 5, 6]]) -- [[1, 4], [2, 5], [3, 6]] + SELECT array_transpose(ARRAY [ARRAY ['a', 'b'], ARRAY ['c', 'd'], ARRAY ['e', 'f']]) -- [['a', 'c', 'e'], ['b', 'd', 'f']] + SELECT array_transpose(ARRAY [ARRAY [1]]) -- [[1]] + SELECT array_transpose(ARRAY []) -- [] + .. function:: arrays_overlap(x, y) -> boolean Tests if arrays ``x`` and ``y`` have any non-null elements in common. diff --git a/presto-docs/src/main/sphinx/functions/binary.rst b/presto-docs/src/main/sphinx/functions/binary.rst index b878e31b822c0..e38ccc731d60f 100644 --- a/presto-docs/src/main/sphinx/functions/binary.rst +++ b/presto-docs/src/main/sphinx/functions/binary.rst @@ -146,6 +146,10 @@ Binary Functions Computes the xxhash64 hash of ``binary``. +.. function:: xxhash64(binary, bigint) -> varbinary + + Computes the xxhash64 hash of ``binary`` with seed ``bigint``. + .. function:: spooky_hash_v2_32(binary) -> varbinary Computes the 32-bit SpookyHashV2 hash of ``binary``. @@ -173,4 +177,4 @@ Binary Functions .. function:: reverse(binary) -> varbinary :noindex: - Returns ``binary`` with the bytes in reverse order. \ No newline at end of file + Returns ``binary`` with the bytes in reverse order. diff --git a/presto-docs/src/main/sphinx/functions/geospatial.rst b/presto-docs/src/main/sphinx/functions/geospatial.rst index da603f6a71f5e..6dc9f7a2ddeca 100644 --- a/presto-docs/src/main/sphinx/functions/geospatial.rst +++ b/presto-docs/src/main/sphinx/functions/geospatial.rst @@ -268,7 +268,7 @@ Accessors Returns the minimum convex geometry that encloses all input geometries. -.. function:: ST_CoordDim(Geometry) -> bigint +.. function:: ST_CoordDim(Geometry) -> tinyint Return the coordinate dimension of the geometry. 
diff --git a/presto-docs/src/main/sphinx/functions/khyperloglog.rst b/presto-docs/src/main/sphinx/functions/khyperloglog.rst index 94fa4c78e15ca..9a449f83ec0e0 100644 --- a/presto-docs/src/main/sphinx/functions/khyperloglog.rst +++ b/presto-docs/src/main/sphinx/functions/khyperloglog.rst @@ -51,13 +51,14 @@ Functions of the HyperLogLog that is mapped from the MinHash bucket that corresponds to ``x'``. This function returns a histogram that represents the uniqueness distribution, the X-axis being the ``uniqueness`` and the Y-axis being the relative - frequency of ``x`` values. + frequency of ``x`` values. The histogram size defaults to the current size of the + MinHash structure in the ``KHyperLogLog`` sketch. .. function:: uniqueness_distribution(khll, histogramSize) -> map - Returns the uniqueness histogram with the given amount of buckets. If omitted, - the value defaults to 256. All ``uniqueness`` values greater than ``histogramSize`` are - accumulated in the last bucket. + Returns the uniqueness histogram with the given number of buckets, ``histogramSize``. + All ``uniqueness`` values greater than ``histogramSize`` are accumulated + in the last bucket. .. function:: reidentification_potential(khll, threshold) -> double diff --git a/presto-docs/src/main/sphinx/functions/map.rst b/presto-docs/src/main/sphinx/functions/map.rst index bc9a62bc5cd2f..8205869e504db 100644 --- a/presto-docs/src/main/sphinx/functions/map.rst +++ b/presto-docs/src/main/sphinx/functions/map.rst @@ -132,7 +132,7 @@ Map Functions Returns top ``n`` keys in the map ``x`` by sorting its values in descending order. If two or more keys have equal values, the higher key takes precedence. ``n`` must be a non-negative integer.:: - SELECT map_top_n_keys_by_value(map(ARRAY['a', 'b', 'c'], ARRAY[2, 1, 3]), 2) --- ['c', 'a'] + SELECT map_keys_by_top_n_values(map(ARRAY['a', 'b', 'c'], ARRAY[2, 1, 3]), 2) --- ['c', 'a'] .. 
function:: map_top_n(x(K,V), n) -> map(K, V) @@ -214,3 +214,14 @@ Map Functions SELECT transform_values(MAP(ARRAY ['a', 'b'], ARRAY [1, 2]), (k, v) -> k || CAST(v as VARCHAR)); -- {a -> a1, b -> b2} SELECT transform_values(MAP(ARRAY [1, 2], ARRAY [1.0, 1.4]), -- {1 -> one_1.0, 2 -> two_1.4} (k, v) -> MAP(ARRAY[1, 2], ARRAY['one', 'two'])[k] || '_' || CAST(v AS VARCHAR)); + +.. function:: map_int_keys_to_array(map(int,V)) -> array(V) + Returns an ``array`` of values from the ``map`` with value at indexed by the original keys from ``map``:: + SELECT MAP_INT_KEYS_TO_ARRAY(MAP(ARRAY[3, 5, 6, 9], ARRAY['a', 'b', 'c', 'd'])) -> ARRAY[null, null, 'a', null, 'b', 'c', null, null, 'd'] + SELECT MAP_INT_KEYS_TO_ARRAY(MAP(ARRAY[3, 5, 6, 9], ARRAY['a', null, 'c', 'd'])) -> ARRAY[null, null, 'a', null, null, 'c', 'd'] + +.. function:: array_to_map_int_keys(array(v)) -> map(int, v) + Returns an ``map`` with indices of all non-null values from the ``array`` as keys and element at the specified index as the value:: + SELECT ARRAY_TO_MAP_INT_KEYS(CAST(ARRAY[3, 5, 6, 9] AS ARRAY)) -> MAP(ARRAY[1, 2, 3,4], ARRAY[3, 5, 6, 9]) + SELECT ARRAY_TO_MAP_INT_KEYS(CAST(ARRAY[3, 5, null, 6, 9] AS ARRAY)) -> MAP(ARRAY[1, 2, 4, 5], ARRAY[3, 5, 6, 9]) + SELECT ARRAY_TO_MAP_INT_KEYS(CAST(ARRAY[3, 5, null, 6, 9, null, null, 1] AS ARRAY)) -> MAP(ARRAY[1, 2, 4, 5, 8], ARRAY[3, 5, 6, 9, 1]) \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/functions/math.rst b/presto-docs/src/main/sphinx/functions/math.rst index ebea116a3bf3c..b3623502e502e 100644 --- a/presto-docs/src/main/sphinx/functions/math.rst +++ b/presto-docs/src/main/sphinx/functions/math.rst @@ -40,6 +40,41 @@ Mathematical Functions SELECT cosine_similarity(MAP(ARRAY['a'], ARRAY[1.0]), MAP(ARRAY['a'], ARRAY[2.0])); -- 1.0 +.. function:: cosine_similarity(x, y) -> double + + Returns the cosine similarity between the arrays ``x`` and ``y``. 
+ If the input arrays have different sizes or if the input arrays contain a null, the function throws user error:: + + SELECT cosine_similarity(ARRAY[1.2], ARRAY[2.0]); -- 1.0 + +.. function:: l2_squared(array(real), array(real)) -> real + + Returns the squared `Euclidean distance `_ between the vectors represented as array(real). + If the input arrays have different sizes or if the input arrays contain a null, the function throws user error:: + + SELECT l2_squared(ARRAY[1.0], ARRAY[2.0]); -- 1.0 + +.. function:: l2_squared(array(double), array(double)) -> double + + Returns the squared `Euclidean distance `_ between the vectors represented as array(double). + If the input arrays have different sizes or if the input arrays contain a null, the function throws user error:: + + SELECT l2_squared(ARRAY[1.0], ARRAY[2.0]); -- 1.0 + +.. function:: dot_product(array(real), array(real)) -> real + + Returns the dot product of two vectors represented as array(real). + If the input arrays have different sizes or if the input arrays contain a null, the function throws user error:: + + SELECT dot_product(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]); -- 11.0 + +.. function:: dot_product(array(double), array(double)) -> double + + Returns the dot product of two vectors represented as array(double). + If the input arrays have different sizes or if the input arrays contain a null, the function throws user error:: + + SELECT dot_product(ARRAY[1.0, 2.0], ARRAY[3.0, 4.0]); -- 11.0 + .. function:: degrees(x) -> double Converts angle ``x`` in radians to degrees. @@ -231,9 +266,14 @@ Probability Functions: cdf Compute the Poisson cdf with given lambda (mean) parameter: P(N <= value; lambda). The lambda parameter must be a positive real number (of type DOUBLE) and value must be a non-negative integer. +.. function:: t_cdf(df, value) -> double + + Compute the Student's t cdf with given degrees of freedom: P(N < value; df). 
The degrees of freedom must be a positive real number and ``value`` must be a real number.
function:: quantiles_at_values(tdigest, values) -> array + + Returns an array of approximate quantile numbers between 0 and 1 from the T-digest + and array of ``values``. + .. function:: trimmed_mean(tdigest, lower_quantile, upper_quantile) -> double Returns an estimate of the mean, excluding portions of the distribution diff --git a/presto-docs/src/main/sphinx/installation.rst b/presto-docs/src/main/sphinx/installation.rst index c67f47b4324c4..daef03acb5288 100644 --- a/presto-docs/src/main/sphinx/installation.rst +++ b/presto-docs/src/main/sphinx/installation.rst @@ -6,6 +6,6 @@ Installation :maxdepth: 1 installation/deployment - installation/deploy-docker installation/deploy-brew + installation/deploy-docker installation/deploy-helm diff --git a/presto-docs/src/main/sphinx/installation/deploy-brew.rst b/presto-docs/src/main/sphinx/installation/deploy-brew.rst index 51cb3008da70e..3f7b768535753 100644 --- a/presto-docs/src/main/sphinx/installation/deploy-brew.rst +++ b/presto-docs/src/main/sphinx/installation/deploy-brew.rst @@ -1,37 +1,31 @@ -===================================== -Deploy Presto on a Mac using Homebrew -===================================== +=========================== +Deploy Presto with Homebrew +=========================== -- If you are deploying Presto on an Intel Mac, see `Deploy Presto on an Intel Mac using Homebrew`_. +This guide explains how to install and get started with Presto on macOS, Linux or WSL2 using the Homebrew package manager. -- If you are deploying Presto on an Apple Silicon Mac that has an M1 or M2 chip, see `Deploy Presto on an Apple Silicon Mac using Homebrew`_. +Prerequisites +------------- -Deploy Presto on an Intel Mac using Homebrew --------------------------------------------- -*Note*: These steps were developed and tested on Mac OS X on Intel. These steps will not work with Apple Silicon (M1 or M2) chips. 
- -Following these steps, you will: - -- install the Presto service and CLI on an Intel Mac using `Homebrew `_ -- start and stop the Presto service -- start the Presto CLI +`Homebrew `_ installed. Install Presto -^^^^^^^^^^^^^^ +-------------- -Follow these steps to install Presto on an Intel Mac using `Homebrew `_. +Run the following command to install the latest version of Presto using the `Homebrew Formulae `_: -1. If you do not have brew installed, run the following command: +.. code-block:: none - ``/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install.sh)"`` + brew install prestodb -2. To install Presto, run the following command: +Homebrew installs packages in the ``Cellar`` directory, which can be found with this command: - ``brew install prestodb`` +.. code-block:: none - Presto is installed in the directory */usr/local/Cellar/prestodb/.* + brew --cellar -The following files are created in the *libexec/etc* directory in the Presto install directory: +The directory ``$(brew --cellar)/prestodb//libexec`` contains the Presto files used to run and configure the service. +For example, the ``etc`` directory within the Presto installation contains the following default configuration files: - node.properties - jvm.config @@ -39,35 +33,28 @@ The following files are created in the *libexec/etc* directory in the Presto ins - log.properties - catalog/jmx.properties -For example, the full path to the node.properties file is */usr/local/Cellar/prestodb//libexec/etc/node.properties*. - -The Presto CLI is installed in the *bin* directory of the Presto install directory: */usr/local/Cellar/prestodb//bin*. - -The executables are added to */usr/local/bin* path and should be available as part of $PATH. - Start and Stop Presto -^^^^^^^^^^^^^^^^^^^^^ +--------------------- -To start Presto, use the ``presto-server`` helper script. +Presto is installed with the ``presto-server`` helper script, which simplifies managing the cluster. 
+For example, run the following command to start the Presto service in the foreground: -To start the Presto service in the background, run the following command: - -``presto-server start`` +.. code-block:: none -To start the Presto service in the foreground, run the following command: + presto-server run -``presto-server run`` +To stop Presto from running in the foreground, press ``Ctrl + C`` until the terminal prompt appears, or close the terminal. -To stop the Presto service in the background, run the following command: +For more available commands and options, use help: -``presto-server stop`` +.. code-block:: none -To stop the Presto service in the foreground, close the terminal or select Ctrl + C until the terminal prompt is shown. + presto-server --help Open the Presto Console -^^^^^^^^^^^^^^^^^^^^^^^ +----------------------- -After starting Presto, you can access the web UI at the default port ``8080`` using the following link in a browser: +After starting the service, Presto Console can be accessible at the default port ``8080`` using the following link in a browser: .. code-block:: none @@ -79,117 +66,23 @@ After starting Presto, you can access the web UI at the default port ``8080`` us For more information about the Presto Console, see :doc:`/clients/presto-console`. Start the Presto CLI -^^^^^^^^^^^^^^^^^^^^ +-------------------- The Presto CLI is a terminal-based interactive shell for running queries, and is a `self-executing `_ JAR file that acts like a normal UNIX executable. -The Presto CLI is installed in the *bin* directory of the Presto install directory: */usr/local/Cellar/prestodb//bin*. - -To run the Presto CLI, run the following command: - -``presto`` - -The Presto CLI starts and displays the prompt ``presto>``. - -For more information, see :doc:`/clients/presto-cli`. 
- -Deploy Presto on an Apple Silicon Mac using Homebrew ----------------------------------------------------- -*Note*: These steps were developed and tested on Mac OS X on Apple Silicon. These steps will not work with Intel chips. - -Following these steps, you will: - -- install the Presto service and CLI on an Apple Silicon Mac using `Homebrew `_ -- start and stop the Presto service -- start the Presto CLI - -Install Presto -^^^^^^^^^^^^^^ - -Follow these steps to install Presto on an Apple Silicon Mac using `Homebrew `_. - -1. If you do not have brew installed, run the following command: - - ``arch -x86_64 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install.sh)"`` - - This installs Homebrew into ``/usr/local/bin``. - - *Note*: The default installation of Homebrew on Apple Silicon is to ``/opt/homebrew``. - -2. To allow the shell to look for Homebrew in ``/usr/local/bin`` before it looks for Homebrew in ``/opt/homebrew``, run the following command: - - ``export PATH=/usr/local/bin:$PATH`` - -3. To install Presto, run the following command: - - ``arch -x86_64 brew install prestodb`` - - Presto is installed in the directory */usr/local/Cellar/prestodb/.* The executables ``presto`` - and ``presto-server`` are installed in ``/usr/local/bin/``. - -The following files are created in the *libexec/etc* directory in the Presto install directory: - -- node.properties -- jvm.config -- config.properties -- log.properties -- catalog/jmx.properties - -For example, the full path to the node.properties file is */usr/local/Cellar/prestodb//libexec/etc/node.properties*. - -The Presto CLI is installed in the *bin* directory of the Presto install directory: */usr/local/Cellar/prestodb//bin*. - -The executables are added to */usr/local/bin* path and should be available as part of $PATH. - -Start and Stop Presto -^^^^^^^^^^^^^^^^^^^^^ - -To start Presto, use the ``presto-server`` helper script. 
- -To start the Presto service in the background, run the following command: - -``arch -x86_64 presto-server start`` - -To start the Presto service in the foreground, run the following command: - -``arch -x86_64 presto-server run`` - -To stop the Presto service in the background, run the following command: - -``presto-server stop`` - -To stop the Presto service in the foreground, close the terminal or select Ctrl + C until the terminal prompt is shown. - -Open the Presto Console -^^^^^^^^^^^^^^^^^^^^^^^ - -After starting Presto, you can access the web UI at the default port ``8080`` using the following link in a browser: +The Presto CLI is installed in the directory ``$(brew --cellar)/prestodb//bin``. +To run the Presto CLI, use the following command: .. code-block:: none - http://localhost:8080 - -.. figure:: ../images/presto_console.png - :align: center - -For more information about the Presto Console, see :doc:`/clients/presto-console`. - -Start the Presto CLI -^^^^^^^^^^^^^^^^^^^^ - -The Presto CLI is a terminal-based interactive shell for running queries, and is a -`self-executing `_ -JAR file that acts like a normal UNIX executable. - -The Presto CLI is installed in the *bin* directory of the Presto install directory: */usr/local/Cellar/prestodb//bin*. -The executable ``presto`` is installed in ``/usr/local/bin/``. + presto -To run the Presto CLI, run the following command: +The Presto CLI starts and displays its prompt: -``presto`` +.. code-block:: none -The Presto CLI starts and displays the prompt ``presto>``. + presto> -For more information, see :doc:`/clients/presto-cli`. \ No newline at end of file +For more information, see :doc:`/clients/presto-cli`. 
diff --git a/presto-docs/src/main/sphinx/installation/deploy-docker.rst b/presto-docs/src/main/sphinx/installation/deploy-docker.rst index b6a824916c40e..d1a7882c7fced 100644 --- a/presto-docs/src/main/sphinx/installation/deploy-docker.rst +++ b/presto-docs/src/main/sphinx/installation/deploy-docker.rst @@ -1,60 +1,66 @@ -================================= -Deploy Presto From a Docker Image -================================= +========================= +Deploy Presto with Docker +========================= + +This guide explains how to install and get started with Presto using Docker. + +.. note:: + + These steps were developed and tested on Mac OS X, on both Intel and Apple Silicon chips. -These steps were developed and tested on Mac OS X, on both Intel and Apple Silicon chips. +Prepare the container environment +================================= -Follow these steps to: +If Docker is already installed, skip to step 4 to verify the setup. +Otherwise, follow the instructions below to install Docker and Colima using Homebrew or choose an alternative method. -- install the command line tools for brew, docker, and `Colima `_ -- verify your Docker setup -- pull the Docker image of the Presto server -- start your local Presto server +1. Install `Homebrew `_ if it is not already present on the system. -Installing brew, Docker, and Colima -=================================== +2. Install the Docker command line and `Colima `_ tools via the following command: -This task shows how to install brew, then to use brew to install Docker and Colima. + .. code-block:: shell -Note: If you have Docker installed you can skip steps 1-3, but you should -verify your Docker setup by running the command in step 4. + brew install docker colima -1. If you do not have brew installed, run the following command: +3. Run the following command to start Colima with defaults: - ``/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install.sh)"`` + .. 
code-block:: shell -2. To install the Docker command line and `Colima `_ tools, run the following command: + colima start - ``brew install docker colima`` + .. note:: -3. Run the following command: + The default VM created by Colima uses 2 CPUs, 2GiB memory and 100GiB storage. To customize the VM resources, + see the Colima README for `Customizing the VM `_. - ``colima start`` +4. Verify the local setup by running the following command: - *Note*: The default VM created by Colima uses 2 CPUs, 2GB memory and 60GB storage. To customize the VM resources, - see the Colima README for `Customizing the VM `_. + .. code-block:: shell -4. To verify your local setup, run the following command: + docker run hello-world - ``docker run hello-world`` + The following output confirms a successful installation. - If you see a response similar to the following, you are ready. + .. code-block:: shell + :class: no-copy - ``Hello from Docker!`` - ``This message shows that your installation appears to be working correctly.`` + Hello from Docker! + This message shows that your installation appears to be working correctly. Installing and Running the Presto Docker container ================================================== -1. Download the latest non-edge Presto container from `Presto on DockerHub `_. Run the following command: +1. Download the latest non-edge Presto container from `Presto on DockerHub `_: + + .. code-block:: shell - ``docker pull prestodb/presto:latest`` + docker pull prestodb/presto:latest Downloading the container may take a few minutes. When the download completes, go on to the next step. -2. On your local system, create a file named ``config.properties`` containing the following text: +2. On the local system, create a file named ``config.properties`` containing the following text: - .. code-block:: none + .. 
code-block:: properties coordinator=true node-scheduler.include-coordinator=true @@ -62,7 +68,7 @@ Installing and Running the Presto Docker container discovery-server.enabled=true discovery.uri=http://localhost:8080 -3. On your local system, create a file named ``jvm.config`` containing the following text: +3. On the local system, create a file named ``jvm.config`` containing the following text: .. code-block:: none @@ -78,20 +84,26 @@ Installing and Running the Presto Docker container 4. To start the Presto server in the Docker container, run the command: - ``docker run -p 8080:8080 -it -v ./config.properties:/opt/presto-server/etc/config.properties -v ./jvm.config:/opt/presto-server/etc/jvm.config --name presto prestodb/presto:latest`` + .. code-block:: shell + + docker run -p 8080:8080 -it -v ./config.properties:/opt/presto-server/etc/config.properties -v ./jvm.config:/opt/presto-server/etc/jvm.config --name presto prestodb/presto:latest This command assigns the name ``presto`` for the newly-created container that uses the downloaded image ``prestodb/presto:latest``. - The Presto server logs startup information in the terminal window. Once you see a response similar to the following, the Presto server is running in the Docker container. + The Presto server logs startup information in the terminal window. The following output confirms the Presto server is running in the Docker container. + + .. code-block:: shell + :class: no-copy - ``======== SERVER STARTED ========`` + ======== SERVER STARTED ======== Removing the Presto Docker container ==================================== -To remove the Presto Docker container, run the following two commands: +To stop and remove the Presto Docker container, run the following commands: -``docker stop presto`` +.. code-block:: shell -``docker rm presto`` + docker stop presto + docker rm presto These commands return the name of the container ``presto`` when they succeed. 
diff --git a/presto-docs/src/main/sphinx/installation/deploy-helm.rst b/presto-docs/src/main/sphinx/installation/deploy-helm.rst index 5fe7d0225e8c2..1fefb226fd142 100644 --- a/presto-docs/src/main/sphinx/installation/deploy-helm.rst +++ b/presto-docs/src/main/sphinx/installation/deploy-helm.rst @@ -1,5 +1,5 @@ -=============================== -Deploy Presto Using Helm Charts -=============================== +======================= +Deploy Presto with Helm +======================= -To deploy Presto using Helm charts, see the `Presto Helm Charts README `_. \ No newline at end of file +To deploy Presto using Helm, see the `Presto Helm Charts README `_. diff --git a/presto-docs/src/main/sphinx/installation/deployment.rst b/presto-docs/src/main/sphinx/installation/deployment.rst index a67cc5fb027c2..1478b8ca2d98e 100644 --- a/presto-docs/src/main/sphinx/installation/deployment.rst +++ b/presto-docs/src/main/sphinx/installation/deployment.rst @@ -301,8 +301,11 @@ See :doc:`/connector` for more information about configuring connectors. Running Presto -------------- -The installation directory contains the launcher script in ``bin/launcher``. -Presto can be started as a daemon by running the following: +Presto requires Java 17 to run both the coordinator and workers. Please ensure +that your system has a valid Java 17 installation before starting Presto. + +The installation directory contains the launcher script +in ``bin/launcher``. Presto can be started as a daemon by running the following: .. code-block:: none @@ -456,7 +459,7 @@ in the Hive connector catalog file are set to the following: .. code-block:: none - connector.name=hive + connector.name=hive-hadoop2 hive.metastore=file hive.metastore.catalog.dir=file:///data/hive_data/ @@ -496,7 +499,7 @@ and specify an entry point to run the server. .. 
code-block:: docker - FROM openjdk:8-jre + FROM openjdk:17-jre # Presto version will be passed in at build time ARG PRESTO_VERSION @@ -519,7 +522,7 @@ and specify an entry point to run the server. COPY etc /opt/presto/etc # Download the Presto CLI and put it in the image - RUN wget --quiet https://repo1.maven.org/maven2/com/facebook/presto/presto-cli/${PRESTO_VERSION}/presto-cli-${PRESTO_VERSION}-executable.jar + RUN wget --quiet https://github.com/prestodb/presto/releases/download/${PRESTO_VERSION}/presto-cli-${PRESTO_VERSION}-executable.jar RUN mv presto-cli-${PRESTO_VERSION}-executable.jar /usr/local/bin/presto RUN chmod +x /usr/local/bin/presto diff --git a/presto-docs/src/main/sphinx/language/reserved.rst b/presto-docs/src/main/sphinx/language/reserved.rst index 14b769a996f6f..d3560c96c8537 100644 --- a/presto-docs/src/main/sphinx/language/reserved.rst +++ b/presto-docs/src/main/sphinx/language/reserved.rst @@ -6,6 +6,8 @@ The following table lists all of the keywords that are reserved in Presto, along with their status in the SQL standard. These reserved keywords must be quoted (using double quotes) in order to be used as an identifier. +The use of `$` as a prefix for column names is reserved. This is a convention only and it is not enforced. + ============================== ============= ============= Keyword SQL:2016 SQL-92 ============================== ============= ============= diff --git a/presto-docs/src/main/sphinx/optimizer/history-based-optimization.rst b/presto-docs/src/main/sphinx/optimizer/history-based-optimization.rst index eeb5884e55be7..a5a38a5599bd6 100644 --- a/presto-docs/src/main/sphinx/optimizer/history-based-optimization.rst +++ b/presto-docs/src/main/sphinx/optimizer/history-based-optimization.rst @@ -38,6 +38,7 @@ Configuration Property Name Description ``optimizer.history-based-optimizer-timeout`` Timeout for history based optimizer. 
``10 seconds`` ``optimizer.enforce-timeout-for-hbo-query-registration`` Enforce timeout for query registration in HBO optimizer ``False`` ``optimizer.treat-low-confidence-zero-estimation-as-unknown`` Treat ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. ``False`` +``optimizer.query-types-enabled-for-hbo`` Query types which are enabled for history based optimization. ``SELECT,INSERT`` ``optimizer.confidence-based-broadcast`` Broadcast based on the confidence of the statistics that are being used, by broadcasting the side of a joinNode which ``False`` has the highest confidence statistics. If confidence is the same, then the original behavior will be followed. ``optimizer.retry-query-with-history-based-optimization`` Retry a failed query automatically if HBO can help change the existing query plan ``False`` @@ -67,12 +68,14 @@ Session property Name Description ``restrict_history_based_optimization_to_complex_query`` Enable history based optimization only for complex queries, i.e. queries with join and aggregation. ``True`` ``history_input_table_statistics_matching_threshold`` Overrides the behavior of the configuration property ``hbo.history-matching-threshold`` ``hbo.history-matching-threshold`` in the current session. -``treat-low-confidence-zero-estimation-as-unknown`` Overrides the behavior of the configuration property +``treat_low_confidence_zero_estimation_unknown_enabled`` Overrides the behavior of the configuration property ``optimizer.treat-low-confidence-zero-estimation-as-unknown`` in the current session. ``optimizer.treat-low-confidence-zero-estimation-as-unknown`` -``confidence-based-broadcast`` Overrides the behavior of the configuration property +``confidence_based_broadcast_enabled`` Overrides the behavior of the configuration property ``optimizer.confidence-based-broadcast`` in the current session. 
``optimizer.confidence-based-broadcast`` -``retry-query-with-history-based-optimization`` Overrides the behavior of the configuration property +``retry_query_with_history_based_optimization`` Overrides the behavior of the configuration property ``optimizer.retry-query-with-history-based-optimization`` in the current session. ``optimizer.retry-query-with-history-based-optimization`` +``query_types_enabled_for_history_based_optimization`` Overrides the behavior of the configuration property + ``optimizer.query-types-enabled-for-hbo`` in the current session. ``optimizer.query-types-enabled-for-hbo`` =========================================================== ==================================================================================================== ============================================================== Example diff --git a/presto-docs/src/main/sphinx/optimizer/statistics.rst b/presto-docs/src/main/sphinx/optimizer/statistics.rst index eeb763575c82a..6be080b501f0f 100644 --- a/presto-docs/src/main/sphinx/optimizer/statistics.rst +++ b/presto-docs/src/main/sphinx/optimizer/statistics.rst @@ -46,8 +46,9 @@ The following statistics are available in Presto: The set of statistics available for a particular query depends on the connector being used and can also vary by table or even by table layout. For example, the -Hive connector does not currently provide statistics on data size. +Hive connector does not currently provide statistics on data size or histograms, +while the Iceberg connector provides both. -Table statistics can be can be fetched using the :doc:`/sql/show-stats` query. +Table statistics can be fetched using the :doc:`/sql/show-stats` query. For the Hive connector, refer to the :ref:`Hive connector ` documentation to learn how to update table statistics. 
diff --git a/presto-docs/src/main/sphinx/plugin.rst b/presto-docs/src/main/sphinx/plugin.rst index e5b321c09adc1..9d0c49f8d9c6c 100644 --- a/presto-docs/src/main/sphinx/plugin.rst +++ b/presto-docs/src/main/sphinx/plugin.rst @@ -8,4 +8,5 @@ This chapter outlines the plugins in Presto that are available for various use c :maxdepth: 1 plugin/redis-hbo-provider + plugin/native-sidecar-plugin diff --git a/presto-docs/src/main/sphinx/plugin/native-sidecar-plugin.rst b/presto-docs/src/main/sphinx/plugin/native-sidecar-plugin.rst new file mode 100644 index 0000000000000..f64a36d58485c --- /dev/null +++ b/presto-docs/src/main/sphinx/plugin/native-sidecar-plugin.rst @@ -0,0 +1,121 @@ +============== +Native Sidecar +============== + +Use the native sidecar plugin in a native cluster to extend the capabilities of the cluster by allowing the underlying execution engine to be the source of truth. +All supported functions, types and session properties are retrieved directly from the sidecar - a specialized worker node equipped with enhanced capabilities. +Only features supported by the native execution engine are exposed during function resolution, resolving session properties and the types supported. + +The native sidecar plugin also provides a native plan checker that validates compatibility between Presto and the native execution engine +by sending Presto plan fragments to a sidecar endpoint `v1/velox/plan`, where they are translated into Velox compatible plan fragments. +If any incompatibilities or errors are detected during the translation, they are surfaced immediately, allowing Presto to fail fast before allocating resources or scheduling execution. + +To use the sidecar functionalities, at least one sidecar worker must be present in the cluster. The system supports flexible configurations: a mixed setup with both sidecar +and regular C++ workers, or a cluster composed entirely of sidecar workers. 
A worker can be configured as a sidecar by adding the properties listed in :ref:`sidecar-worker-properties` section. + +Configuration +------------- + +Coordinator properties +^^^^^^^^^^^^^^^^^^^^^^ +To enable sidecar support on the coordinator, add the following properties to your coordinator configuration: + +============================================ ===================================================================== ============================== +Property Name Description Value +============================================ ===================================================================== ============================== +``coordinator-sidecar-enabled`` Enables sidecar in the coordinator true +``native-execution-enabled`` Enables native execution true +``presto.default-namespace`` Sets the default function namespace `native.default` +``plugin.dir`` Specifies which directory under installation root `{root-directory}/native-plugin/` + to scan for plugins at startup. +============================================ ===================================================================== ============================== + +.. 
_sidecar-worker-properties: + +Sidecar worker properties +^^^^^^^^^^^^^^^^^^^^^^^^^ +Enable sidecar functionality with: + +============================================ ===================================================================== ============================== +Property Name Description Value +============================================ ===================================================================== ============================== +``native-sidecar`` Enables a sidecar worker true +``presto.default-namespace`` Sets the default function namespace `native.default` +============================================ ===================================================================== ============================== + +Regular worker properties +^^^^^^^^^^^^^^^^^^^^^^^^^ +For regular workers (not acting as sidecars), configure: + +============================================ ===================================================================== ============================== +Property Name Description Value +============================================ ===================================================================== ============================== +``presto.default-namespace`` Sets the default function namespace `native.default` +============================================ ===================================================================== ============================== + +The Native Sidecar plugin is designed to run with all its components enabled. Individual configuration properties must be specified in conjunction with one another to ensure the plugin operates as intended. +While the Native Sidecar plugin allows modular configuration, the recommended usage is to enable all the components for full functionality. + +Function registry +----------------- + +These properties must be configured in ``etc/function-namespace/native.properties`` to use the function namespace manager from the ``NativeSidecarPlugin``. 
+ +============================================ ===================================================================== ============================== +Property Name Description Value +============================================ ===================================================================== ============================== +``function-namespace-manager.name`` Identifier used to register the function namespace manager `native` +``function-implementation-type`` Indicates the language in which functions in this namespace CPP + are implemented. +``supported-function-languages`` Languages supported by the namespace manager. CPP +============================================ ===================================================================== ============================== + +Session properties +------------------ + +These properties must be configured in ``etc/session-property-providers/native-worker.properties`` to use the session property provider of the ``NativeSidecarPlugin``. + +============================================ ===================================================================== ============================== +Property Name Description Value +============================================ ===================================================================== ============================== +``session-property-provider.name`` Identifier for the session property provider backed by the sidecar. `native-worker` + Enables discovery of supported session properties in native engine. +============================================ ===================================================================== ============================== + +Type Manager +----------------- + +These properties must be configured in ``etc/type-managers/native.properties`` to use the type manager of the ``NativeSidecarPlugin``. 
+ +============================================ ===================================================================== ============================== +Property Name Description Value +============================================ ===================================================================== ============================== +``type-manager.name`` Identifier for the type manager. Registers types `native` + supported by the native engine. +============================================ ===================================================================== ============================== + +Plan checker +----------------- + +These properties must be configured in ``etc/plan-checker-providers/native.properties`` to use the native plan checker of the ``NativeSidecarPlugin``. + +============================================ ===================================================================== ============================== +Property Name Description Value +============================================ ===================================================================== ============================== +``plan-checker-provider.name`` Identifier for the plan checker. Enables validation of Presto `native` + query plans against native engine, ensuring execution compatibility. +============================================ ===================================================================== ============================== + +Expression optimizer +-------------------- + +These properties must be configured in ``etc/expression-manager/native.properties`` to use the native expression optimizer of the ``NativeSidecarPlugin``. 
+ +============================================ ===================================================================== ============================== +Property Name Description Value +============================================ ===================================================================== ============================== +``expression-manager-factory.name`` Identifier for the expression optimizer. Enables optimization of `native` + expressions using the native expression optimizer. +============================================ ===================================================================== ============================== + +To enable the native expression optimizer for your session, set the expression_optimizer_name session property to native: ``SET SESSION expression_optimizer_name = 'native'`` diff --git a/presto-docs/src/main/sphinx/plugin/redis-hbo-provider.rst b/presto-docs/src/main/sphinx/plugin/redis-hbo-provider.rst index 85079d719736f..423e9e4db9cc1 100644 --- a/presto-docs/src/main/sphinx/plugin/redis-hbo-provider.rst +++ b/presto-docs/src/main/sphinx/plugin/redis-hbo-provider.rst @@ -9,27 +9,27 @@ Redis HBO Provider supports loading a custom configured Redis Client for storing Configuration ------------- -Create ``etc/catalog/redis-provider.properties`` to mount the Redis HBO Provider Plugin. +Create ``etc/redis-provider.properties`` to mount the Redis HBO Provider Plugin. 
Edit the configuration properties as appropriate: Configuration properties ------------------------ -The following configuration properties are available for use in ``etc/catalog/redis-provider.properties``: +The following configuration properties are available for use in ``etc/redis-provider.properties``: ============================================ ===================================================================== Property Name Description ============================================ ===================================================================== -``coordinator`` Boolean property whether Presto server is a coordinator +``coordinator`` Boolean property to decide whether Presto server is a coordinator ``hbo.redis-provider.server_uri`` Redis Server URI ``hbo.redis-provider.total-fetch-timeoutms`` Maximum timeout in ms for Redis fetch requests ``hbo.redis-provider.total-set-timeoutms`` Maximum timeout in ms for Redis set requests ``hbo.redis-provider.default-ttl-seconds`` TTL in seconds of the Redis data to be stored -``hbo.redis-provider.enabled`` Boolean property whether this plugin is enabled in production +``hbo.redis-provider.enabled`` Boolean property to enable this plugin ``credentials-path`` Path for Redis credentials -``hbo.redis-provider.cluster-mode-enabled`` Boolean property whether cluster mode is enabled +``hbo.redis-provider.cluster-mode-enabled`` Boolean property to enable cluster mode ============================================ ===================================================================== Coordinator Configuration for Historical Based Optimization @@ -80,29 +80,18 @@ You can place the plugin JARs in the production's ``plugins`` directory. Alternatively, follow this method to ensure that the plugin is loaded during the Presto build. -1. Add the following to register the plugin in ```` in ``presto-server/src/main/assembly/presto.xml``: +1. Add the following to register the plugin in ``presto-server/src/main/provisio/presto.xml``: .. 
code-block:: text - - - ${project.build.directory}/dependency/redis-hbo-provider-${project.version} - plugin/redis-hbo-provider - + + + + + + 2. In ``redis-hbo-provider/src/main/resources``, create the file ``META-INF.services`` with the Plugin entry class ``com.facebook.presto.statistic.RedisProviderPlugin``. -3. Add the dependency on the module in ``presto-server/pom.xml``: - - .. code-block:: text - - - com.facebook.presto - redis-hbo-provider - ${project.version} - zip - provided - - -4. (Optional) Add your custom Redis client connection login in ``com.facebook.presto.statistic.RedisClusterAsyncCommandsFactory``. +3. (Optional) Add your custom Redis client connection login in ``com.facebook.presto.statistic.RedisClusterAsyncCommandsFactory``. Note: The AsyncCommands must be provided properly. diff --git a/presto-docs/src/main/sphinx/presto-cpp.rst b/presto-docs/src/main/sphinx/presto-cpp.rst index 104ab1989e483..ee468a876c2b9 100644 --- a/presto-docs/src/main/sphinx/presto-cpp.rst +++ b/presto-docs/src/main/sphinx/presto-cpp.rst @@ -7,9 +7,11 @@ Note: Presto C++ is in active development. See :doc:`Limitations /v1/functions////`` + +For example, if the base URL is ``http://localhost:8080`` and you have a +function ``my_schema.my_function``, the endpoint would be: +``http://localhost:8080/v1/functions/my_schema/my_function/...`` + +``remote-function-server.serde`` +"""""""""""""""""""""""""""""""" + +* **Type:** ``string`` +* **Default value:** ``"presto_page"`` + +This property (shared with Thrift-based remote functions) determines the +serialization format for data sent to and received from the REST server. + +Supported values: + +* ``presto_page``: Uses Presto's native page serialization format +* ``spark_unsafe_row``: Uses Spark's unsafe row serialization format + +Setup and Usage +^^^^^^^^^^^^^^^ + +To use REST-based remote functions in your Presto C++ cluster: + +1. 
**Deploy a REST Function Server**: Implement a REST service that conforms to the + `REST Function Server API specification `_. + The server must implement endpoints for function discovery, management, and execution. + + Key requirements: + + * Implement ``GET /v1/functions`` to list available functions + * Implement ``POST /v1/functions/{schema}/{functionName}/{functionId}/{version}`` for function execution + * Accept serialized input data with appropriate Content-Type: + + * ``Content-Type: application/X-presto-pages`` for Presto page format + * ``Content-Type: application/X-spark-unsafe-row`` for Spark unsafe row format + + * Return serialized results with the same Content-Type as the request + +2. **Configure the Presto C++ Worker**: Add the following to your worker's + configuration file (for example, ``config.properties``): + + .. code-block:: properties + + remote-function-server.rest.url=http://your-function-server:8080 + remote-function-server.serde=presto_page + +3. **Register Functions**: Functions are registered when the coordinator sends + function metadata to the worker during query execution. The function + signatures and metadata are managed by the coordinator's function namespace + manager. + +4. **Use Functions in Queries**: Once configured, remote functions can be used + in SQL queries like any other function: + + .. code-block:: sql + + SELECT catalog.schema.remote_function(column1, column2) + FROM your_table; + +REST Function Server API Specification +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The REST function server must implement the API specification defined in +`rest_function_server.yaml `_. + +A sample implementation using Presto Java functions is available in +`FunctionServer.java `_. 
+ +The key endpoints include: + +**Function Discovery:** + +* ``GET /v1/functions`` - List all available functions +* ``GET /v1/functions/{schema}`` - List functions in a specific schema +* ``GET /v1/functions/{schema}/{functionName}`` - Get specific function metadata + +**Function Management:** + +* ``POST /v1/functions/{schema}/{functionName}`` - Create a new function +* ``PUT /v1/functions/{schema}/{functionName}/{functionId}`` - Update an existing function +* ``DELETE /v1/functions/{schema}/{functionName}/{functionId}`` - Delete a function + +**Function Execution:** + +* ``POST /v1/functions/{schema}/{functionName}/{functionId}/{version}`` - Execute a function + + * **Request Headers**: + + * ``Content-Type: application/X-presto-pages`` (for Presto page format) + * ``Content-Type: application/X-spark-unsafe-row`` (for Spark unsafe row format) + + * **Request Body**: Serialized input vectors in the configured format (Presto page or Spark unsafe row) + * **Response Headers**: Same ``Content-Type`` as request + * **Response Body**: Serialized output vectors in the same format + * **Response Status**: ``200 OK`` on success, appropriate error codes on failure + +The function execution endpoint is responsible for: + +1. Deserializing the input data from the request body +2. Executing the function logic with the provided inputs +3. Serializing the results +4. Returning the serialized results in the response + +For complete API details, request/response schemas, and examples, refer to the +`OpenAPI specification `_. + JWT authentication support -------------------------- @@ -180,11 +319,11 @@ authentication failure (HTTP 401). 
LinuxMemoryChecker ------------------ -The LinuxMemoryChecker extends from PeriodicMemoryChecker and periodically checks -memory usage using memory calculation from inactive_anon + active_anon in the memory stat +The LinuxMemoryChecker extends from PeriodicMemoryChecker and periodically checks +memory usage using memory calculation from inactive_anon + active_anon in the memory stat file from Linux cgroups V1 or V2. The LinuxMemoryChecker is used for Linux systems only. -The LinuxMemoryChecker can be enabled by setting the CMake flag ``PRESTO_MEMORY_CHECKER_TYPE=LINUX_MEMORY_CHECKER``. +The LinuxMemoryChecker can be enabled by setting the CMake flag ``PRESTO_MEMORY_CHECKER_TYPE=LINUX_MEMORY_CHECKER``. .. _async_data_caching_and_prefetching: diff --git a/presto-docs/src/main/sphinx/presto_cpp/installation.rst b/presto-docs/src/main/sphinx/presto_cpp/installation.rst new file mode 100644 index 0000000000000..9182c16ae043c --- /dev/null +++ b/presto-docs/src/main/sphinx/presto_cpp/installation.rst @@ -0,0 +1,254 @@ +======================= +Presto C++ Installation +======================= + +.. contents:: + :local: + :backlinks: none + :depth: 1 + +This shows how to install and run a lightweight Presto cluster utilizing a PrestoDB Java Coordinator and Prestissimo (Presto C++) Workers using Docker. + +For more information about Presto C++, see the :ref:`presto-cpp:overview`. + +The setup uses Meta's high-performance Velox engine for worker-side query execution to configure a cluster and run a test query with the built-in TPC-H connector. + +Prerequisites +------------- + +To follow this tutorial, you need: + +* Docker installed. +* Basic familiarity with the terminal and shell commands. + +Create a Working Directory +-------------------------- +The recommended directory structure uses ``presto-lab`` as the root directory. + +Create a clean root directory to hold all necessary configuration files and the ``docker-compose.yml`` file. + +.. 
code-block:: bash + + mkdir -p ~/presto-lab + cd ~/presto-lab + +Configure the Presto Java Coordinator +------------------------------------- + +The Coordinator requires configuration to define its role, enable the discovery service, and set up a catalog for querying. + +1. Create Configuration Directory +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To create the necessary directories for the coordinator and its catalogs, run the following command: + +.. code-block:: bash + + mkdir -p coordinator/etc/catalog + + +2. Create the Coordinator Configuration File +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Create the file ``coordinator/etc/config.properties`` with the following contents. This file enables the coordinator mode, the discovery server, and sets the HTTP port to ``8080``. + +.. code-block:: properties + + # coordinator/etc/config.properties + coordinator=true + node-scheduler.include-coordinator=true + http-server.http.port=8080 + discovery-server.enabled=true + discovery.uri=http://localhost:8080 + +* ``coordinator=true``: Enables the coordinator mode. +* ``discovery-server.enabled=true``: Designates the coordinator as the host for the worker discovery service. +* ``http-server.http.port=8080``: Starts the HTTP server on port 8080 for the coordinator (and workers, if enabled). + +3. Create the JVM Configuration File +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Create the file ``coordinator/etc/jvm.config`` with the following content. These are standard Java 17 flags for Presto that ensure compatibility with Java 17's module system, provide stable garbage collection and memory behavior, and enforce safe failure handling. + +.. 
code-block:: properties + + # coordinator/etc/jvm.config + -server + -Xmx1G + -XX:+UseG1GC + -XX:G1HeapRegionSize=32M + -XX:+UseGCOverheadLimit + -XX:+ExplicitGCInvokesConcurrent + -XX:+HeapDumpOnOutOfMemoryError + -XX:+ExitOnOutOfMemoryError + -Djdk.attach.allowAttachSelf=true + --add-opens=java.base/java.io=ALL-UNNAMED + --add-opens=java.base/java.lang=ALL-UNNAMED + --add-opens=java.base/java.lang.ref=ALL-UNNAMED + --add-opens=java.base/java.lang.reflect=ALL-UNNAMED + --add-opens=java.base/java.net=ALL-UNNAMED + --add-opens=java.base/java.nio=ALL-UNNAMED + --add-opens=java.base/java.security=ALL-UNNAMED + --add-opens=java.base/javax.security.auth=ALL-UNNAMED + --add-opens=java.base/javax.security.auth.login=ALL-UNNAMED + --add-opens=java.base/java.text=ALL-UNNAMED + --add-opens=java.base/java.util=ALL-UNNAMED + --add-opens=java.base/java.util.concurrent=ALL-UNNAMED + --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED + --add-opens=java.base/java.util.regex=ALL-UNNAMED + --add-opens=java.base/jdk.internal.loader=ALL-UNNAMED + --add-opens=java.base/sun.security.action=ALL-UNNAMED + --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED + +4. Create the Node Properties File +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Create the file ``coordinator/etc/node.properties`` with the following content to set the node environment and the data directory. + +.. code-block:: properties + + # coordinator/etc/node.properties + node.id=${ENV:HOSTNAME} + node.environment=test + node.data-dir=/var/lib/presto/data + +5. Create the TPC-H Catalog Configuration File +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Create the file ``coordinator/etc/catalog/tpch.properties`` with the following content. The TPC-H catalog enables running test queries against an in-memory dataset. + +.. 
code-block:: properties + + # coordinator/etc/catalog/tpch.properties + connector.name=tpch + +Configure the Prestissimo (C++) Worker +-------------------------------------- + +Configure the Worker to locate the Coordinator or Discovery service and identify itself within the network. + +1. Create Worker Configuration Directory +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + mkdir -p worker-1/etc/catalog + +2. Create ``worker-1/etc/config.properties`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Configure the worker to point to the discovery service running on the coordinator. + +Note: You can repeat this step to add more workers, such as ``worker-2``. + +.. code-block:: properties + + # worker-1/etc/config.properties + discovery.uri=http://coordinator:8080 + presto.version=0.288-15f14bb + http-server.http.port=7777 + shutdown-onset-sec=1 + runtime-metrics-collection-enabled=true + +* ``discovery.uri=http://coordinator:8080``: This uses the coordinator service name as defined in the ``docker-compose.yml`` file for network communication within Docker. + +3. Configure ``worker-1/etc/node.properties`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Define the worker’s internal address to ensure reliable registration. + +.. code-block:: properties + + # worker-1/etc/node.properties + node.environment=test + node.internal-address=worker-1 + node.location=docker + node.id=worker-1 + +* ``node.internal-address=worker-1``: This setting matches the service name defined in :ref:`Docker Compose `. + +4. Add TPC-H Catalog Configuration +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Configure the worker with the same catalog definitions as the coordinator to execute query stages + +.. code-block:: properties + + # worker-1/etc/catalog/tpch.properties + connector.name=tpch + +.. 
_create-docker-compose-yml: + +Create ``docker-compose.yml`` +----------------------------- + +Create a ``docker-compose.yml`` file in the ``~/presto-lab`` directory to orchestrate both the Java Coordinator and the C++ Worker containers. + +.. code-block:: yaml + + # docker-compose.yml + services: + coordinator: + image: public.ecr.aws/oss-presto/presto:latest + platform: linux/amd64 + container_name: presto-coordinator + hostname: coordinator + ports: + - "8080:8080" + volumes: + - ./coordinator/etc:/opt/presto-server/etc:ro + restart: unless-stopped + + worker-1: + image: public.ecr.aws/oss-presto/presto-native:latest + platform: linux/amd64 + container_name: prestissimo-worker-1 + hostname: worker-1 + depends_on: + - coordinator + volumes: + - ./worker-1/etc:/opt/presto-server/etc:ro + restart: unless-stopped + + worker-2: + image: public.ecr.aws/oss-presto/presto-native:latest + platform: linux/amd64 + container_name: prestissimo-worker-2 + hostname: worker-2 + depends_on: + - coordinator + volumes: + - ./worker-2/etc:/opt/presto-server/etc:ro + restart: unless-stopped + +* The coordinator service uses the standard Java Presto image (presto:latest). +* The worker-1 and worker-2 services use the Prestissimo (C++ Native) image (presto-native:latest). +* The setting ``platform: linux/amd64`` is essential for users running on Apple Silicon Macs. +* The ``volumes`` section mounts your local configuration directories (``./coordinator/etc``, ``./worker-1/etc``) into the container's expected path (``/opt/presto-server/etc``). + +Start the Cluster and Verify +---------------------------- + +1. Start the Cluster +^^^^^^^^^^^^^^^^^^^^ + +Use Docker Compose to start the cluster in detached mode (``-d``). + +.. code-block:: bash + + docker compose up -d + +2. Verify +^^^^^^^^^ + +1. **Check the Web UI:** Open the Presto Web UI at http://localhost:8080. + + * You should see the UI displaying 3 Active Workers (1 Coordinator and 2 Workers). + +2. 
**Check Detailed Node Status** : Run the following SQL query to check the detailed status and metadata about every node (Coordinator and Workers). + + .. code-block:: sql + + select * from system.runtime.nodes; + + This confirms the cluster nodes are registered and active. \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/presto_cpp/plugin.rst b/presto-docs/src/main/sphinx/presto_cpp/plugin.rst new file mode 100644 index 0000000000000..33b1619bde66f --- /dev/null +++ b/presto-docs/src/main/sphinx/presto_cpp/plugin.rst @@ -0,0 +1,24 @@ +****************** +Presto C++ Plugins +****************** + +This page lists the plugins in Presto C++ that are available for various use cases such as to load User Defined Functions (UDFs) and describes the setup needed to use these plugins. + +.. toctree:: + :maxdepth: 1 + + plugin/function_plugin + + +Setup +----- + +1. Place the plugin shared libraries in the ``plugin`` directory. + +2. For each worker, edit the ``config.properties`` file to set the ``plugin.dir`` property to the path of the ``plugin`` directory. + + If ``plugin.dir`` is not specified, the path to this directory defaults to the ``plugin`` directory relative to the directory in which the process is being run. + +3. Start or restart the coordinator and workers to pick up any placed libraries. + +Note: To avoid issues with ABI compatibility, it is strongly recommended to recompile all shared library plugins during OS and Presto version upgrades. 
\ No newline at end of file diff --git a/presto-docs/src/main/sphinx/presto_cpp/plugin/function_plugin.rst b/presto-docs/src/main/sphinx/presto_cpp/plugin/function_plugin.rst new file mode 100644 index 0000000000000..fdb390bf14c6b --- /dev/null +++ b/presto-docs/src/main/sphinx/presto_cpp/plugin/function_plugin.rst @@ -0,0 +1,50 @@ +=============== +Function Plugin +=============== + +Creating a Shared Library for UDFs +---------------------------------- +User defined functions (UDFs) allow users to create custom functions without the need to rebuild the executable. +There are many benefits to UDFs, such as: + +* Simplify SQL queries by creating UDFs for repetitive logic. +* Implement custom logic pertaining to the specific business use cases of the users. +* Once defined, easily reusable and called multiple times just like built-in functions. +* Shorter compile times. + +1. To create the UDF, create a new C++ file in the same format as the below example file named ``ExampleFunction.cpp``: + + .. code-block:: c++ + + #include "presto_cpp/main/dynamic_registry/DynamicFunctionRegistrar.h" + + template <typename T> + struct NameOfStruct { + VELOX_DEFINE_FUNCTION_TYPES(T); + FOLLY_ALWAYS_INLINE bool call(out_type<int64_t>& result, const arg_type<facebook::velox::Varchar>& in) { + ... + } + }; + + extern "C" { + void registerExtensions() { + facebook::presto::registerPrestoFunction< + NameOfStruct, + int64_t, facebook::velox::Varchar>("function_name", "test.namespace"); + } + } + + Note: The ``int64_t`` return type and the ``Varchar`` input type can be changed as needed. Additional or no arguments may be specified as well. For more examples, see the `examples `_. + The functions should follow the `Velox scalar function API `_. Return types and argument types include: + + * simple types: BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, DOUBLE, TIMESTAMP, DATE, VARCHAR, VARBINARY. + * complex types: ARRAY, MAP, ROW. They can be mixed with other simple and complex types. For example: MAP(INTEGER, ARRAY(BIGINT)). + +2. 
Create a shared library using ``CMakeLists.txt`` like the following: + + .. code-block:: text + + add_library(name_of_dynamic_fn SHARED ExampleFunction.cpp) + target_link_libraries(name_of_dynamic_fn PRIVATE presto_dynamic_function_registrar fmt::fmt gflags::gflags xsimd) + +3. Place your shared libraries in the ``plugin`` directory. The path to this directory defaults to the ``plugin`` directory relative to the directory in which the process is being run but it is configurable in the ``plugin.dir`` property set in :doc:`../plugin`. \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/presto_cpp/properties-session.rst b/presto-docs/src/main/sphinx/presto_cpp/properties-session.rst index fea735e86f63b..cac0763034dba 100644 --- a/presto-docs/src/main/sphinx/presto_cpp/properties-session.rst +++ b/presto-docs/src/main/sphinx/presto_cpp/properties-session.rst @@ -97,18 +97,28 @@ If set to ``true``, disables the optimization in expression evaluation to delay This should only be used for debugging purposes. -``native_execution_type_rewrite_enabled`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +``native_debug_memory_pool_name_regex`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -* **Type:** ``boolean`` -* **Default value:** ``false`` +* **Type:** ``varchar`` +* **Default value:** ``""`` + +Native Execution only. Regular expression pattern to match memory pool names for allocation callsite tracking. +Matched pools will also perform leak checks at destruction. Empty string disables tracking. -When set to ``true``: - - Custom type names are peeled in the coordinator. Only the actual base type is preserved. - - ``CAST(col AS EnumType)`` is rewritten as ``CAST(col AS )``. - - ``ENUM_KEY(EnumType)`` is rewritten as ``ELEMENT_AT(MAP(, VARCHAR))``. +This should only be used for debugging purposes. -This property can only be enabled with native execution. 
+``native_debug_memory_pool_warn_threshold_bytes`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``bigint`` +* **Default value:** ``0`` + +Native Execution only. Warning threshold for memory pool allocations. Logs callsites when exceeded. +Requires allocation tracking to be enabled with ``native_debug_memory_pool_name_regex``. +Accepts B/KB/MB/GB units. Set to 0B to disable. + +This should only be used for debugging purposes. ``native_selective_nimble_reader_enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -345,7 +355,7 @@ underlying file system. * **Default value:** ``false`` Enable query tracing. After enabled, trace data will be generated with query execution, and -can be used by TraceReplayer. It needs to be used together with native_query_trace_node_ids, +can be used by TraceReplayer. It needs to be used together with native_query_trace_node_id, native_query_trace_max_bytes, native_query_trace_fragment_id, and native_query_trace_shard_id to match the task to be traced. @@ -358,14 +368,13 @@ to match the task to be traced. The location to store the trace files. -``native_query_trace_node_ids`` +``native_query_trace_node_id`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * **Type:** ``varchar`` * **Default value:** ``""`` -A comma-separated list of plan node ids whose input data will be traced. -Empty string if only want to trace the query metadata. +The plan node id whose input data will be traced. ``native_query_trace_max_bytes`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -470,3 +479,92 @@ to produce a batch of the size specified by this property. If set to ``0``, then * **Default value:** ``10`` Maximum wait time for exchange long poll requests in seconds. + +``native_query_memory_reclaimer_priority`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Default value:** ``2147483647`` + +Priority of the query in the memory pool reclaimer. Lower value means higher priority. 
+This is used in global arbitration victim selection. + +``native_max_num_splits_listened_to`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Default value:** ``0`` + +Maximum number of splits to listen to by the SplitListener per table scan node per +native worker. + +``native_max_split_preload_per_driver`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Default value:** ``0`` + +Maximum number of splits to preload per driver. Set to 0 to disable preloading. + +``native_index_lookup_join_max_prefetch_batches`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Default value:** ``0`` + +Specifies the max number of input batches to prefetch to do index lookup ahead. +If it is zero, then process one input batch at a time. + +``native_index_lookup_join_split_output`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + +If this is true, then the index join operator might split output for each input +batch based on the output batch size control. Otherwise, it tries to produce a +single output for each input batch. + +``native_unnest_split_output`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + +If this is true, then the unnest operator might split output for each input +batch based on the output batch size control. Otherwise, it produces a single +output for each input batch. + +``native_use_velox_geospatial_join`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + +If this is true, then the protocol::SpatialJoinNode is converted to a +velox::core::SpatialJoinNode. Otherwise, it is converted to a +velox::core::NestedLoopJoinNode. + +``native_aggregation_compaction_bytes_threshold`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``bigint`` +* **Default value:** ``0`` + +Native Execution only. 
Memory threshold in bytes for triggering string compaction +during global aggregation. When total string storage exceeds this limit and the +unused memory ratio is high, compaction is triggered to reclaim dead strings. +Disabled by default (0). Currently only applies to approx_most_frequent aggregate +with StringView type during global aggregation. + +``native_aggregation_compaction_unused_memory_ratio`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``double`` +* **Minimum value:** ``0`` +* **Maximum value:** ``1`` +* **Default value:** ``0.25`` + +Native Execution only. Ratio of unused (evicted) bytes to total bytes that triggers +compaction. The value is in the range of [0, 1). Currently only applies to +approx_most_frequent aggregate with StringView type during global aggregation. diff --git a/presto-docs/src/main/sphinx/presto_cpp/properties.rst b/presto-docs/src/main/sphinx/presto_cpp/properties.rst index 8179cc0ddc627..b9f01446670fa 100644 --- a/presto-docs/src/main/sphinx/presto_cpp/properties.rst +++ b/presto-docs/src/main/sphinx/presto_cpp/properties.rst @@ -38,6 +38,14 @@ Presto C++ workers. These Presto coordinator configuration properties are described here, in alphabetical order. +``driver.max-split-preload`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +* **Type:** ``integer`` +* **Default value:** ``2`` + + Maximum number of splits to preload per driver. + Set to 0 to disable preloading. + ``driver.cancel-tasks-with-stuck-operators-threshold-ms`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * **Type:** ``string`` @@ -57,18 +65,6 @@ alphabetical order. This property is required when running Presto C++ workers because of underlying differences in behavior from Java workers. -``native-execution-type-rewrite-enabled`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -* **Type:** ``boolean`` -* **Default value:** ``false`` - - When set to ``true``: - - Custom type names are peeled in the coordinator. 
Only the actual base type is preserved. - - ``CAST(col AS EnumType)`` is rewritten as ``CAST(col AS )``. - - ``ENUM_KEY(EnumType)`` is rewritten as ``ELEMENT_AT(MAP(, VARCHAR))``. - This property can only be enabled with native execution. - ``optimizer.optimize-hash-generation`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -107,7 +103,19 @@ alphabetical order. * **Type:** ``string`` * **Default value:** ``presto.default`` - Specifies the namespace prefix for native C++ functions. + Specifies the namespace prefix for native C++ functions. This prefix is used when + registering Velox functions in Prestissimo to ensure proper function resolution in + multi-catalog environments. + + .. warning:: + + **Critical**: When registering Velox functions, you **must** follow the + ``catalog.schema.`` prefix pattern. Functions registered without this pattern + will cause worker node crashes. + + The configured value (for example, ``presto.default``) is automatically appended with a + trailing dot (``.``) to form the complete prefix (``presto.default.``). This results + in fully qualified function names like ``presto.default.substr`` or ``presto.default.sum``. Internal functions (prefixed with ``$internal$``) do not follow this pattern and are exempt from the three-part naming requirement. Worker Properties ----------------- @@ -164,7 +172,7 @@ The configuration properties of Presto C++ workers are described here, in alphab worker node. Memory for system usage such as disk spilling and cache prefetch are not counted in it. -``max_spill_bytes`` +``max-spill-bytes`` ^^^^^^^^^^^^^^^^^^^ * **Type:** ``integer`` @@ -218,6 +226,14 @@ avoid exceeding memory limits for the query. When ``spill_enabled`` is ``true``, this determines whether Presto will try spilling memory to disk for order by to avoid exceeding memory limits for the query. 
+``local-exchange.max-partition-buffer-size`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``integer`` +* **Default value:** ``65536`` (64KB) + + Specifies the maximum size in bytes to accumulate for a single partition of a local exchange before flushing. + ``shared-arbitrator.reserved-capacity`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -414,19 +430,27 @@ avoid exceeding memory limits for the query. only by aborting. This flag is only effective if ``shared-arbitrator.global-arbitration-enabled`` is ``true``. +``text-writer-enabled`` +^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + + Enables writing data in ``TEXTFILE`` format. + +``text-reader-enabled`` +^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``true`` + + Enables reading data in ``TEXTFILE`` format. + Cache Properties ---------------- The configuration properties of AsyncDataCache and SSD cache are described here. -``async-cache-persistence-interval`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -* **Type:** ``string`` -* **Default value:** ``0s`` - - The interval for persisting in-memory cache to SSD. Set this - to a non-zero value to activate periodic cache persistence. - ``async-data-cache-enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -508,6 +532,17 @@ The configuration properties of AsyncDataCache and SSD cache are described here. When enabled, a CRC-based checksum is calculated for each cache entry written to SSD. The checksum is stored in the next checkpoint file. +``ssd-cache-max-entries`` +^^^^^^^^^^^^^^^^^^^^^^^^^ +* **Type:** ``integer`` +* **Default value:** ``10000000`` + + Maximum number of entries allowed in the SSD cache. A value of 0 means no limit. + When the limit is reached, new entry writes will be skipped. + + The default of 10 million entries keeps metadata memory usage around 500MB, as each + cache entry uses approximately 50-60 bytes for the key, value, and hash overhead. 
+ ``ssd-cache-read-verification-enabled`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * **Type:** ``bool`` @@ -552,6 +587,17 @@ Exchange Properties Maximum wait time for exchange request in seconds. +HTTP Client Properties +---------------------- + +``http-client.http2-enabled`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Specifies whether HTTP/2 should be enabled for HTTP client. + Memory Checker Properties ------------------------- @@ -617,6 +663,16 @@ memory use. Ignored if zero. CPU threshold in % above which the worker is considered overloaded in terms of CPU use. Ignored if zero. +``worker-overloaded-threshold-num-queued-drivers-hw-multiplier`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``double`` +* **Default value:** ``0.0`` + +Floating point number used in calculating how many drivers must be queued +for the worker to be considered overloaded. +Number of drivers is calculated as hw_concurrency x multiplier. Ignored if zero. + ``worker-overloaded-cooldown-period-sec`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-docs/src/main/sphinx/presto_cpp/sidecar.rst b/presto-docs/src/main/sphinx/presto_cpp/sidecar.rst index 7f8e595f10ac7..ee55ed931277f 100644 --- a/presto-docs/src/main/sphinx/presto_cpp/sidecar.rst +++ b/presto-docs/src/main/sphinx/presto_cpp/sidecar.rst @@ -24,12 +24,28 @@ The following HTTP endpoints are implemented by the Presto C++ sidecar. Presto C++ worker. Each function's metadata is serialized to JSON in format ``JsonBasedUdfFunctionMetadata``. +.. function:: GET /v1/functions/{catalog} + + Returns a list of function metadata for all functions registered in the + Presto C++ worker that belong to the specified catalog. Each function's + metadata is serialized to JSON in format ``JsonBasedUdfFunctionMetadata``. + This endpoint allows filtering functions by catalog to support namespace + separation. + .. 
function:: POST /v1/velox/plan Converts a Presto plan fragment to its corresponding Velox plan and validates the Velox plan. Returns any errors encountered during plan conversion. +.. function:: POST /v1/expressions + + Optimizes a list of ``RowExpression``\s from the http request using + a combination of constant folding and logical rewrites by leveraging + the ``ExprOptimizer`` from Velox. Returns a list of ``RowExpressionOptimizationResult``, + that contains either the optimized ``RowExpression`` or the ``NativeSidecarFailureInfo`` + in case the expression optimization failed. + Configuration Properties ------------------------ diff --git a/presto-docs/src/main/sphinx/release.rst b/presto-docs/src/main/sphinx/release.rst index d6ee388bb32f7..a0b8d0f9ca67f 100644 --- a/presto-docs/src/main/sphinx/release.rst +++ b/presto-docs/src/main/sphinx/release.rst @@ -5,6 +5,10 @@ Release Notes .. toctree:: :maxdepth: 1 + Release-0.296 [2025-12-01] + Release-0.295 [2025-10-01] + Release-0.294 [2025-07-28] + Release-0.293 [2025-05-29] Release-0.292 [2025-03-28] Release-0.291 [2025-01-27] Release-0.290 [2024-11-01] diff --git a/presto-docs/src/main/sphinx/release/release-0.292.rst b/presto-docs/src/main/sphinx/release/release-0.292.rst index 87dd9d485d524..a6d49ca03bb32 100644 --- a/presto-docs/src/main/sphinx/release/release-0.292.rst +++ b/presto-docs/src/main/sphinx/release/release-0.292.rst @@ -49,7 +49,7 @@ _______________ * Add pagesink for DELETES to support future use. `#24528 `_ * Add serialization for new types. `#24528 `_ * Add support to build Presto with JDK 17. `#24677 `_ -* Add a new optimizer rule to add exchanges below a combination of partial aggregation+ GroupId . Enabled with the boolean session property ``enable_forced_exchange_below_group_id``. `#24047 `_ +* Add a new optimizer rule to add exchanges below a combination of partial aggregation+ GroupId . Enabled with the boolean session property ``add_exchange_below_partial_aggregation_over_group_id``. 
`#24047 `_ * Add module presto-native-tests to run end-to-end tests with Presto native workers. `#24234 `_ * Add map of node ID to plan node to QueryCompletedEvent in the event listener interface. `#24590 `_ * Add support for multiple query event listeners. `#24456 `_ diff --git a/presto-docs/src/main/sphinx/release/release-0.293.rst b/presto-docs/src/main/sphinx/release/release-0.293.rst new file mode 100644 index 0000000000000..7ad38cc0b90e2 --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.293.rst @@ -0,0 +1,143 @@ +============= +Release 0.293 +============= + +**Highlights** +============== + +* Fix ROLLBACK statement to ensure it successfully aborts non-auto commit transactions corrupted by failed statements. `#23247 `_ +* Improve coordinator performance by introducing Thrift serialization. `#25079 `_ and `#25020 `_ +* Improve performance of ``ORDER BY`` queries on single node execution. `#25022 `_ +* Add authentication capabilities to Presto router. `#24407 `_ +* Add coordinator health checks to Presto router. `#24449 `_ +* Add support for custom scheduler plugin in the Presto Router. `#24439 `_ +* Add DDL SQL support for ``SHOW CREATE SCHEMA``. `#24356 `_ +* Add :func:`longest_common_prefix(string1, string2) -> varchar()` string function. `#24891 `_ +* Add support for row filtering and column masking in access control. `#24277 `_ +* Add security-related headers to the static resources served from the Presto Router UI, including: ``Content-Security-Policy``, ``X-Content-Type-Options``. See reference docs `Content-Security-Policy `_ and `X-Content-Type-Options `_. `#25165 `_ +* Add support for SSL/TLS encryption for HMS. `#24745 `_ +* Add support for the procedure ``.system.invalidate_manifest_file_cache()`` for ManifestFile cache invalidation in Iceberg. `#24831 `_ +* Add support for `JSON `_ type in MongoDB. `#25089 `_ +* Add support for `GEOMETRY `_ type in the MySQL connector. 
`#24996 `_ +* Add a display for number of queued and running queries for each Resource Group subgroup in the UI. `#24830 `_ +* Add `runtime metrics collection for S3 Filesystem `_. `#24554 `_ + +**Details** +=========== + +General Changes +_______________ +* Fix ROLLBACK statement to ensure it successfully aborts non-auto commit transactions corrupted by failed statements. `#23247 `_ +* Fix a bug in left join to semi join optimizer which leads to filter source variable not found error. `#25111 `_ +* Fix a bug where a mirrored :func:`arrays_overlap(x, y) -> boolean` function does not return the correct value. `#23845 `_ +* Fix returning incorrect results from the :func:`second(x) -> bigint()` UDF when a timestamp is in a time zone with an offset that is at the granularity of seconds. `#25090 `_ +* Fix issue with loading Redis HBO provider. `#24835 `_ +* Improve memory usage of readers of complex type columns. `#24912 `_ +* Improve the efficacy of ACL checks by delaying them until after SQL view processing. `#24955 `_ and `#24927 `_ +* Improve coordinator performance by introducing Thrift serialization. `#25079 `_ and `#25020 `_ +* Improve performance of operator stats reporting. `#24921 `_ +* Improve performance of ``ORDER BY`` queries on single node execution. `#25022 `_ +* Improve query plans by converting table scans without data to empty values nodes. `#25155 `_ +* Improve performance of ``LOJ + IS NULL`` queries by adding distinct on right side of semi-join for it. `#24884 `_ +* Add DDL SQL support for ``SHOW CREATE SCHEMA``. `#24356 `_ +* Add configuration property ``hive.metastore.catalog.name`` to pass catalog names to the metastore, enabling catalog-based schema management and filtering. `#24235 `_ +* Add :func:`cosine_similarity(x, y) -> double()` for array arguments. `#25056 `_ +* Add type rewrite support for native execution. 
This feature can be enabled by ``native-execution-type-rewrite-enabled`` configuration property and ``native_execution_type_rewrite_enabled`` session property. `#24916 `_ +* Add session property :ref:`admin/properties-session:\`\`query_client_timeout\`\`` to configure how long a query can run without contact from the client application, such as the CLI, before it is abandoned. `#25210 `_ +* Add :func:`longest_common_prefix(string1, string2) -> varchar()` string function. `#24891 `_ +* Replace ``exchange.compression-enabled``, ``fragment-result-cache.block-encoding-compression-enabled``, ``experimental.spill-compression-enabled`` with ``exchange.compression-codec``, ``fragment-result-cache.block-encoding-compression-codec`` to enable compression codec configurations. Supported codecs include GZIP, LZ4, LZO, SNAPPY, ZLIB and ZSTD. `#24670 `_ +* Replace dependency from PostgreSQL to redshift-jdbc42 to address `CVE-2024-1597 `_, `CVE-2022-31197 `_, and `CVE-2020-13692 `_. `#25106 `_ +* Upgrade netty version to 4.1.119.Final. `#24971 `_ + +Prestissimo (Native Execution) Changes +______________________________________ +* Improve batch shuffle performance by doing sorted serialization. `#24953 `_ +* Improve batch shuffle sorted serialization by using appropriate sorting key values for each buffer. `#25015 `_ +* Add type rewrite support for native execution. This feature can be enabled by ``native-execution-type-rewrite-enabled`` configuration property and ``native_execution_type_rewrite_enabled`` session property. `#24916 `_ +* Add `runtime metrics collection for S3 Filesystem `_. `#24554 `_ +* Add session property ``native_request_data_sizes_max_wait_sec`` for the maximum wait time for exchange long poll requests in seconds. `#24918 `_ +* Add session property ``native_streaming_aggregation_eager_flush`` to control if streaming aggregation should flush its output rows as quickly as it can. 
`#24947 `_ +* Add session property ``native_debug_memory_pool_name_regex`` to trace allocations of memory pools matching the regex. `#24833 `_ +* Replace using native functions with Java functions for creating failure functions when native execution is enabled. `#24792 `_ +* Remove worker configuration property ``register-test-functions``. `#24853 `_ + + +Router Changes +______________ + +* Add support for custom scheduler plugin in the Presto Router. `#24439 `_ +* Fix Round Round robin scheduler candidate cluster index, by adding group specific index. `#24580 `_ +* Add authentication capabilities to Presto router. `#24407 `_ +* Add coordinator health checks to Presto router. `#24449 `_ +* Add counter JMX metrics to Presto router. `#24449 `_ + +Security Changes +________________ +* Fix the issue of sensitive data such as passwords and access keys being exposed in logs by redacting sensitive field values. `#24886 `_ +* Add security-related headers to the static resources served from the Presto Router UI, including: ``Content-Security-Policy``, ``X-Content-Type-Options``. See reference docs `Content-Security-Policy `_ and `X-Content-Type-Options `_. `#25165 `_ +* Add support for access control row filters and column masks on views. `#25052 `_ +* Add support for row filtering and column masking in access control. `#24277 `_ +* Upgrade commons-beanutils to version 1.9.4 in response to `CVE-2014-0114 `_. `#24665 `_ +* Upgrade plexus-utils to version 3.6.0 in response to `CVE-2017-1000487 `_. `#24665 `_ +* Upgrade zookeeper to 3.9.3 to fix security vulnerability in presto-accumulo, presto-delta, presto-hive, presto-kafka, and presto-hudi in response to `CVE-2023-44981 `_. `#24403 `_ +* Upgrade MySQL to 9.2.0 to fix `CVE-2023-22102 `_. `#24754 `_ +* Upgrade kotlin-stdlib-jdk8 to 1.9.25. `#24971 `_ +* Upgrade snappy-java version at 1.1.10.4 across the codebase to address `CVE-2023-43642 `_. 
`#25106 `_ +* Upgrade commons-compress version to 1.26.2 across the codebase to address `CVE-2021-35517 `_, `CVE-2021-35516 `_, `CVE-2021-36090 `_, `CVE-2021-35515 `_, and `CVE-2024-25710 `_. `#25106 `_ + +Web UI Changes +______________ + +* Add a display for number of queued and running queries for each Resource Group subgroup in the UI. `#24830 `_ + +Delta Lake Connector Changes +____________________________ +* Fix a bug where after an incremental update with null values is made, reads start timing out. `#24920 `_ + +Elasticsearch Connector Changes +_______________________________ +* Upgrade elasticsearch to 7.17.27 in response to `CVE-2024-43709 `_. `#23894 `_ + +Hive Connector Changes +______________________ +* Add support for Web Identity authentication in S3 security mapping with the ``hive.s3.webidentity.enabled`` property. `#24645 `_ +* Add support for SSL/TLS encryption for HMS with configuration properties ``hive.metastore.thrift.client.tls.enabled``, ``hive.metastore.thrift.client.tls.keystore-path``, ``hive.metastore.thrift.client.tls.keystore-password``, and ``hive.metastore.thrift.client.tls.truststore-password``. `#24745 `_ +* Replace listObjects with listObjectsV2 in PrestoS3FileSystem listPrefix. `#24794 `_ + + +Iceberg Connector Changes +_________________________ +* Fix to pass full session to avoid ``Unknown connector`` errors using the Nessie catalog. `#24803 `_ +* Add support for the procedure ``.system.invalidate_manifest_file_cache()`` for ManifestFile cache invalidation in Iceberg. `#24831 `_ +* Add support for the procedure ``.system.invalidate_statistics_file_cache()`` for StatisticsFile cache invalidation in Iceberg. `#24831 `_ +* Add support for bucket transform for columns of type ``TimeType`` in Iceberg table. `#24829 `_ +* Replace RowDelta with AppendFiles for insert-only statements such as INSERT and CTAS. 
`#24989 `_ + +JDBC Connector Changes +______________________ +* Add ``list-schemas-ignored-schemas`` configuration property for JDBC connectors. `#24994 `_ + +Kafka Connector Changes +_______________________ +* Add support for optional Apache Kafka SASL. `#24798 `_ + +MongoDB Connector Changes +_________________________ +* Add support for `JSON `_ type in MongoDB. `#25089 `_ + +MySQL Connector Changes +_______________________ +* Add support for `GEOMETRY `_ type in the MySQL connector. `#24996 `_ + +SQL Server Connector Changes +____________________________ +* Upgrade SQL Server driver to version 12.8.1 to support NTLM authentication. See :ref:`connector/sqlserver:authentication`. This is a breaking change for existing connections, as the driver sets the encrypt property to ``true`` by default. To connect to a non-SSL SQL Server instance, you must set ``encrypt=false`` in your connection configuration to avoid connectivity issues. `#24686 `_ + +Documentation Changes +_____________________ +* Document :doc:`../presto_cpp/sidecar` and native sidecar plugin. 
`#24883 `_ + +**Credits** +=========== + +Akinori Musha, Amit Dutta, Anant Aneja, Andrew Xie, Andrii Rosa, Anurag Dwivedi, Arjun Gupta, Bryan Cutler, Chen Yang, Christian Zentgraf, Deepak Majeti, Deepak Mehra, Denodo Research Labs, Elbin Pallimalil, Emily (Xuetong) Sun, Ethan Zhang, Facebook Community Bot, Feilong Liu, Gary Helmling, Haritha Koloth, Hazmi, HeidiHan0000, Heng Xiao, Jacob Khaliqi, James Petty, Jay Narale, Jim Simon, Jimmy Lu, Joe Abraham, Ke Wang, Ke Wang, Kevin Tang, Kevin Wilfong, Krishna Pai, Li Zhou, Linsong Wang, Mariam Almesfer, Miguel Blanco Godón, Najib Adan, Natasha Sehgal, Nidhin Varghese, Nikhil Collooru, Nivin C S, Pradeep Vaka, Pramod Satya, Prashant Golash, Pratik Joseph Dabre, Rebecca Schlussel, Reetika Agrawal, Samuel Majoros, Sayari Mukherjee, Serge Druzkin, Sergey Pershin, Shahim Sharafudeen, Shang Ma, Shelton Cai, Shijin, Steve Burnett, Tim Meehan, Xiao Du, Xiaoxuan Meng, Xin Zhang, Yihong Wang, Ying, Yuanda (Yenda) Li, Zac Blanco, Zac Wen, aditi-pandit, auden-woolfson, ebonnal, jp-sivaprasad, lukmanulhakkeem, mecit-san, mima0000, mohsaka, namya28, tanjialiang, vhsu14, wangd, wraymo diff --git a/presto-docs/src/main/sphinx/release/release-0.294.rst b/presto-docs/src/main/sphinx/release/release-0.294.rst new file mode 100644 index 0000000000000..73b7d871ff7de --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.294.rst @@ -0,0 +1,126 @@ +============= +Release 0.294 +============= + +**Highlights** +============== +* Improve query resource usage by enabling subfield pushdown for :func:`map_filter` when selected keys are constants. `#25451 `_ +* Improve query resource usage by enabling subfield pushdown for :func:`map_subset` when the input array is a constant array. `#25394 `_ +* Improve the efficiency of queries that involve with serialization operator by processing data in large group instead of one by one. `#25569 `_ +* Improve efficiency of queries with distinct aggregation and semi joins. 
`#25238 `_ +* Add changes to populate data source metadata to support combined lineage tracking. `#25127 `_ +* Add mixed case support for schema and table names. `#24551 `_ +* Add case-sensitive support for column names. It can be enabled for JDBC based connector by setting ``case-sensitive-name-matching=true`` at the catalog level. `#24983 `_ +* Update ``presto-plan-checker-router-plugin router`` plugin to use ``EXPLAIN (TYPE VALIDATE)`` in place of ``EXPLAIN (TYPE DISTRIBUTED)``, enabling faster routing of queries to either native or Java clusters. `#25545 `_ +* From release 0.294, due to Maven Central publishing limitations, executable jar files including ``presto-cli``, ``presto-benchmark-driver``, and ``presto-test-server-launcher`` are no longer published in the Maven Central repository. These jars can now be found on the `Presto GitHub release page `_. + +**Details** +=========== + +General Changes +_______________ +* Fix filter pushdown to enable subfield pushdown for maps which are accessed with negative keys. `#25445 `_ +* Fix error classification for unsupported array comparison with null elements, converting it as a user error. `#25187 `_ +* Fix for :ref:`sql/update:UPDATE` statements involving multiple identical target column values. `#25599 `_ +* Fix inconsistent ordering with offset and limit. `#25216 `_ +* Fix precision loss in ``parse_duration`` function for large millisecond values. `#25538 `_ +* Fix randomize null join optimizer to keep HBO information for join input. `#25466 `_ +* Improve and optimize Docker image layers. `#25487 `_ +* Improve efficiency of inserts on ORC files. `#24913 `_ +* Improve query resource usage by enabling subfield pushdown for :func:`map_filter` when selected keys are constants. `#25451 `_ +* Improve query resource usage by enabling subfield pushdown for :func:`map_subset` when the input array is a constant array. `#25394 `_ +* Improve semi join performance for large filtering tables. 
`#25236 `_ +* Improve efficiency of queries with distinct aggregation and semi joins. `#25238 `_ +* Improve performance of min_by/max_by aggregations. `#25190 `_ +* Add :func:`dot_product(array(real), array(real)) -> real()` to calculate the sum of element wise product between two identically sized vectors represented as arrays. This function supports both array(real) and array(double) input types. For more information, refer to the `Dot Product definition `_. `#25508 `_ +* Add ``broadcast_semi_join_for_delete`` session property to disable the ReplicateSemiJoinInDelete optimizer. `#25256 `_ +* Add ``history_based_optimizer_estimate_size_using_variables`` session property to have HBO estimate plan node output size using individual variables. `#25400 `_ +* Add changes to populate data source metadata to support combined lineage tracking. `#25127 `_ +* Add mixed case support for schema and table names. `#24551 `_ +* Add session property ``native_query_memory_reclaimer_priority`` which controls which queries are killed first when a worker is running low on memory. Higher value means lower priority to be consistent with Velox memory reclaimer's convention. See :doc:`/presto_cpp/properties-session`. `#25325 `_ +* Add xxhash64 override with seed argument. `#25521 `_ +* Add the :func:`l2_squared(array(real), array(real)) -> real()` function to Java workers. `#25409 `_ +* Update QueryPlanner to only include the optional ``$row_id`` column in :ref:`sql/delete:DELETE` query output variables when it is actually used by the connector. `#25284 `_ +* Update the default value of ``check_access_control_on_utilized_columns_only`` session property to ``true``. The ``false`` value makes the access check apply to all columns. See :ref:`admin/properties-session:\`\`check_access_control_on_utilized_columns_only\`\``. `#25469 `_ + +Prestissimo (Native Execution) Changes +______________________________________ +* Fix Native Plan Checker for CTAS and Insert queries. 
`#25115 `_ +* Fix native session property manager reading plugin configs from file. `#25553 `_ +* Fix PrestoExchangeSource 400 Bad Request by adding the "Host" header. `#25272 `_ +* Improve memory usage in the ``PartitionAndSerialize`` operator and lower memory usage when serializing a sort key. `#25393 `_ +* Improve the efficiency of queries that involve with serialization operator by processing data in large groups instead of one by one. `#25569 `_ +* Add geometry type to the list of supported types in NativeTypeManager. `#25560 `_ +* Update stats API and Presto UI to report number of drivers and splits separately. `#24671 `_ + +Router Changes +______________ +* Add the `Presto Plan Checker Router Scheduler Plugin `_. `#25035 `_ +* Replace the parameters in router schedulers to use `RouterRequestInfo` to get the URL destination. `#25244 `_ +* Update ``presto-plan-checker-router-plugin router`` plugin to use ``EXPLAIN (TYPE VALIDATE)`` in place of ``EXPLAIN (TYPE DISTRIBUTED)``, enabling faster routing of queries to either native or Java clusters. `#25545 `_ +* Update router UI to eliminate vulnerabilities. `#25206 `_ + +Security Changes +________________ +* Add authorization support for ``SHOW CREATE TABLE``, ``SHOW CREATE VIEW``, ``SHOW COLUMNS``, and ``DESCRIBE`` queries. `#25364 `_ +* Upgrade ``commons-beanutils`` dependency to address `CVE-2025-48734 `_. `#25235 `_ +* Upgrade ``commons-lang3`` to 3.18.0 to address `CVE-2025-48924 `_. `#25549 `_ +* Upgrade ``kafka`` to 3.9.1 in response to `CVE-2025-27817 `_. `#25312 `_ + +JDBC Driver Changes +___________________ +* Fix issue introduced in `#25127 `_ by introducing `TableLocationProvider` interface to decouple table location logic from JDBC configuration. `#25582 `_ +* Improve type mapping API to add WriteMapping functionality. `#25437 `_ +* Add mixed case support related catalog property in JDBC connector ``case-sensitive-name-matching``. `#24551 `_ +* Add case-sensitive support for column names. 
It can be enabled for JDBC based connector by setting ``case-sensitive-name-matching=true`` at the catalog level. `#24983 `_ + +Arrow Flight Connector Changes +______________________________ +* Add support for mTLS authentication in Arrow Flight client. See :ref:`connector/base-arrow-flight:Configuration`. `#25179 `_ + +Delta Lake Connector Changes +____________________________ +* Improve mapping of ``TIMESTAMP`` column type by changing it from Presto ``TIMESTAMP`` type to ``TIMESTAMP_WITH_TIME_ZONE``. `#24418 `_ +* Add support for ``TIMESTAMP_NTZ`` column type as Presto ``TIMESTAMP`` type. ``legacy_timestamp`` should be set to ``false`` to match delta type specifications. When set to ``false``, ``TIMESTAMP`` will not adjust based on local timezone. `#24418 `_ + +Hive Connector Changes +______________________ +* Fix an issue while accessing symlink tables. `#25307 `_ +* Fix incorrectly ignoring computed table statistics in ``ANALYZE``. `#24973 `_ +* Improve split generation and read throughput for symlink tables. `#25277 `_ +* Add support for symlink files in :ref:`connector/hive:Quick Stats`. `#25250 `_ +* Update default value of ``hive.copy-on-first-write-configuration-enabled`` to ``false``. `#25420 `_ + +Iceberg Connector Changes +_________________________ +* Fix error querying ``$data_sequence_number`` metadata column for table with equality deletes. `#25293 `_ +* Fix the :ref:`connector/iceberg:Remove Orphan Files` procedure after deletion operations. `#25220 `_ +* Add ``iceberg.delete-as-join-rewrite-max-delete-columns`` configuration property and ``delete_as_join_rewrite_max_delete_columns`` session property to control when equality delete as join optimization is applied. The optimization is now only applied when the number of equality delete columns is less than or equal to this threshold (default: 400). Set to 0 to disable the optimization. See :doc:`/connector/iceberg`. `#25462 `_ +* Add support for ``$delete_file_path`` metadata column. 
`#25280 `_ +* Add support for ``$deleted`` metadata column. `#25280 `_ +* Add support for ``rename view`` for Iceberg connector when configured with ``REST`` and ``NESSIE``. `#25202 `_ +* Deprecate ``iceberg.delete-as-join-rewrite-enabled`` configuration property and ``delete_as_join_rewrite_enabled`` session property. Use ``iceberg.delete-as-join-rewrite-max-delete-columns`` instead. `#25462 `_ + +MySQL Connector Changes +_______________________ +* Add support for mixed-case in MySQL. It can be enabled by setting ``case-sensitive-name-matching=true`` in the catalog configuration. `#24551 `_ + +Redshift Connector Changes +__________________________ +* Fix Redshift ``VARBYTE`` column handling for JDBC driver version 2.1.0.32+ by mapping ``jdbcType=1111`` and ``jdbcTypeName="binary varying"`` to Presto's ``VARBINARY`` type. `#25488 `_ +* Fix Redshift connector runtime failure due to a missing dependency on ``com.amazonaws.util.StringUtils``. Add ``aws-java-sdk-core`` as a runtime dependency to support Redshift JDBC driver (v2.1.0.32) which relies on this class for metadata operations. `#25265 `_ + +SPI Changes +___________ +* Add a function to SPI ``Constraint`` class to return the input arguments for the predicate. `#25248 `_ +* Add support for ``UnnestNode`` in connector optimizers. `#25317 `_ + +Documentation Changes +_____________________ +* Add :ref:`connector/hive:Avro Configuration Properties` to Hive Connector documentation. `#25311 `_ +* Add documentation for ``hive.copy-on-first-write-configuration-enabled`` configuration property to :ref:`connector/hive:Hive Configuration Properties`. 
`#25443 `_ + +**Credits** +=========== + +Amit Dutta, Anant Aneja, Andrew Xie, Andrii Rosa, Auden Woolfson, Beinan, Chandra Vankayalapati, Chandrashekhar Kumar Singh, Chen Yang, Christian Zentgraf, Deepak Majeti, Denodo Research Labs, Elbin Pallimalil, Emily (Xuetong) Sun, Facebook Community Bot, Feilong Liu, Gary Helmling, Hazmi, HeidiHan0000, Henry Edwin Dikeman, Jalpreet Singh Nanda (:imjalpreet), Joe Abraham, Ke Wang, Ke Wang, Kevin Tang, Li Zhou, Mahadevuni Naveen Kumar, Natasha Sehgal, Nidhin Varghese, Nikhil Collooru, Nishitha-Bhaskaran, Ping Liu, Pradeep Vaka, Pramod Satya, Pratik Joseph Dabre, Raaghav Ravishankar, Rebecca Schlussel, Reetika Agrawal, Sebastiano Peluso, Sergey Pershin, Sergii Druzkin, Shahim Sharafudeen, Shakyan Kushwaha, Shang Ma, Shelton Cai, Shrinidhi Joshi, Soumya Duriseti, Sreeni Viswanadha, Steve Burnett, Thanzeel Hassan, Tim Meehan, Vincent Crabtree, Wei He, XiaoDu, Xiaoxuan, Yihong Wang, Ying, Zac Blanco, Zac Wen, Zhichen Xu, Zhiying Liang, Zoltan Arnold Nagy, aditi-pandit, ajay kharat, duhow, github username, jay.narale, lingbin, martinsander00, mohsaka, namya28, pratyakshsharma, vhsu14, wangd diff --git a/presto-docs/src/main/sphinx/release/release-0.295.rst b/presto-docs/src/main/sphinx/release/release-0.295.rst new file mode 100644 index 0000000000000..962fcabb941c8 --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.295.rst @@ -0,0 +1,172 @@ +============= +Release 0.295 +============= + +**Breaking Changes** +==================== +* Add all inline SQL invoked functions into a new plugin ``presto-sql-invoked-functions-plugin``. 
The following functions were moved: ``replace_first``, ``trail``, ``key_sampling_percent``, ``no_values_match``, ``no_keys_match``, ``any_values_match``, ``any_keys_match``, ``all_keys_match``, ``map_remove_null_values``, ``map_top_n_values``, ``map_top_n_keys``, ``map_top_n``, ``map_key_exists``, ``map_keys_by_top_n_values``, ``map_normalize``, ``array_top_n``, ``remove_nulls``, ``array_sort_desc``, ``array_min_by``, ``array_max_by``, ``array_least_frequent``, ``array_has_duplicates``, ``array_duplicates``, ``array_frequency``, ``array_split_into_chunks``, ``array_average``, ``array_intersect``. See `#26025 `_ and `presto-sql-helpers/README.md `_. `#25818 `_ +* Upgrade Presto to require Java 17. The Presto client and Presto-on-Spark remain Java 8-compatible. Presto now requires a Java 17 VM to run both coordinator and workers. `#24866 `_ + +**Highlights** +============== +* Add OAuth2 support for WebUI and JDBC Presto Client. `#24443 `_ +* Add a new configuration property ``query.max-queued-time`` to specify maximum queued time for a query before killing it. This can be overridden by the ``query_max_queued_time`` session property. `#25589 `_ +* Add spatial join support for native execution. `#25823 `_ +* Add support for `mutual TLS (mTLS) authentication `_ in the Arrow Flight connector. `#25388 `_ +* Add support for `GEOMETRY `_ type in the PostgreSQL connector. `#25240 `_ +* Add documentation about the Presto :doc:`/develop/release-process` and :doc:`/admin/version-support`. `#25742 `_ +* Add support for configuring http2 server on worker for communication between coordinator and workers. To enable, set the configuration property ``http-server.http2.enabled`` to ``true``. `#25708 `_ +* Add support for cross-cluster query retry. Failed queries can be automatically retried on a backup cluster by providing the retry URL and expiration time as query parameters. 
`#25625 `_ + +**Details** +=========== + +General Changes +_______________ +* Fix `localtime` and `current_time` issues in legacy timestamp semantics. `#25985 `_ +* Fix a bug where ``map(varchar, json)`` does not canonicalize values. See :doc:`/functions/map`. `#24232 `_ +* Fix add exchange and add local exchange optimizers to simplify query plans with unique columns. `#25882 `_ +* Fix failure when preparing statements or creating views that contain a quoted reserved word as a table name. `#25528 `_ +* Fix weak cipher mode usage during spilling by switching to a stronger algorithm. `#25603 `_ +* Improve ``DELETE`` on columns with special characters in their names. `#25737 `_ +* Improve the protocol efficiency of the C++ worker by supporting thrift codec for connector-specific data. `#25595 `_ +* Improve the protocol efficiency of coordinator by supporting thrift codec for connector-specific data. `#25242 `_ +* Add Scale and Precision columns to :doc:`/sql/show-columns` to get the respective scale of the decimal value and precision of numerical values. A Length column is introduced to get the length of ``CHAR`` and ``VARCHAR`` fields. `#25351 `_ +* Add ``Cache-Control`` header with max-age to statement API responses. `#25433 `_ +* Add ``X-Presto-Retry-Query`` header to identify queries that are being retried on a backup cluster. `#25625 `_ +* Add ``presto-sql-helpers`` directory for inlined SQL invoked function plugins with plugin loading rules. `#26025 `_ +* Add a new plugin ``presto-native-sql-invoked-functions-plugin`` that contains all inline SQL functions, except those with overridden native implementations. `#25870 `_ +* Add ``max_serializable_object_size`` session property to change the maximum serializable object size at the coordinator. `#25616 `_ +* Add all inline SQL invoked functions into a new plugin ``presto-sql-invoked-functions-plugin``. 
The following functions were moved: ``replace_first``, ``trail``, ``key_sampling_percent``, ``no_values_match``, ``no_keys_match``, ``any_values_match``, ``any_keys_match``, ``all_keys_match``, ``map_remove_null_values``, ``map_top_n_values``, ``map_top_n_keys``, ``map_top_n``, ``map_key_exists``, ``map_keys_by_top_n_values``, ``map_normalize``, ``array_top_n``, ``remove_nulls``, ``array_sort_desc``, ``array_min_by``, ``array_max_by``, ``array_least_frequent``, ``array_has_duplicates``, ``array_duplicates``, ``array_frequency``, ``array_split_into_chunks``, ``array_average``, ``array_intersect``. See `#26025 `_ and `presto-sql-helpers/README.md `_. `#25818 `_ +* Add ``array_sort(array, function)`` support for key-based sorting. See :doc:`/functions/array`. `#25851 `_ +* Add ``array_sort_desc(array, function)`` support for key-based sorting. See :doc:`/functions/array`. `#25851 `_ +* Add OAuth2 support for WebUI and JDBC Presto Client. `#24443 `_ +* Add a new configuration property ``query.max-queued-time`` to specify maximum queued time for a query before killing it. This can be overridden by the ``query_max_queued_time`` session property. `#25589 `_ +* Add support for BuiltInFunctionKind enum parameter in BuiltInFunctionHandle's JSON constructor creator. `#25821 `_ +* Add support for configuring http2 server on worker for communication between coordinator and workers. To enable, set the configuration property ``http-server.http2.enabled`` to ``true``. `#25708 `_ +* Add support for cross-cluster query retry. Failed queries can be automatically retried on a backup cluster by providing the retry URL and expiration time as query parameters. `#25625 `_ +* Add support for using a Netty client to do HTTP communication between coordinator and worker. To enable, set the configuration property ``reactor.netty-http-client-enabled`` to ``true`` on the coordinator. 
`#25573 `_ +* Add test methods ``assertStartTransaction`` and ``assertEndTransaction`` to better support non-autocommit transaction testing scenarios. `#25053 `_ +* Add a database-based session property manager. See :doc:`/admin/session-property-managers`. `#24995 `_ +* Add support to use the MariaDB Java client with a MySQL based function server. `#25698 `_ +* Add support and plumbing for ``DELETE`` queries to identify modified partitions as outputs in the generated QueryIOMetadata. `#26134 `_ +* Add reporting lineage details for columns which are created or inserted to the event listener. `#25913 `_ +* Upgrade Presto to require Java 17. The Presto client and Presto-on-Spark remain Java 8-compatible. Presto now requires a Java 17 VM to run both coordinator and workers. `#24866 `_ +* Update Provisio packaging to split plugin packaging into ``plugins`` and ``native-plugin`` directory. `#25984 `_ +* Update Provisio plugin to package the memory connector plugin under the ``native-plugin`` directory. `#26044 `_ +* Update to preserve table name quoting in the output of :doc:`/sql/show-create-view`. `#25528 `_ + +Prestissimo (Native Execution) Changes +______________________________________ +* Fix an issue when processing multiple splits for the same plan node from multiple sources. `#26031 `_ +* Fix constant folding to handle deeply nested call statements. `#26088 `_ +* Fix constant folding in sidecar enabled clusters. `#26125 `_ +* Improve native execution of sidecar query analysis by enabling Presto built-in functions. `#25135 `_ +* Add the parameterized ``VARCHAR`` type in the list of supported types in NativeTypeManager. `#26003 `_ +* Add session property :ref:`presto_cpp/properties-session:\`\`native_index_lookup_join_max_prefetch_batches\`\`` which controls the max number of input batches to prefetch to do index lookup ahead. If it is set to ``0``, then process one input batch at a time. 
`#25886 `_ +* Add session property :ref:`presto_cpp/properties-session:\`\`native_index_lookup_join_split_output\`\``. If set to ``true``, then the index join operator might split output for each input batch based on the output batch size control. Otherwise, it tries to produce a single output for each input batch. `#25886 `_ +* Add session property :ref:`presto_cpp/properties-session:\`\`native_unnest_split_output\`\``. If this is set to ``true``, then the unnest operator might split output for each input batch based on the output batch size control. Otherwise, it produces a single output for each input batch. `#25886 `_ +* Add session properties :ref:`presto_cpp/properties-session:\`\`native_debug_memory_pool_name_regex\`\`` and :ref:`presto_cpp/properties-session:\`\`native_debug_memory_pool_warn_threshold_bytes\`\`` to help debug memory pool usage patterns. `#25750 `_ +* Add limited use of the ``CHAR(N)`` type with PrestoC++. When ``CHAR(N)`` is used in a query it is mapped to the Velox ``VARCHAR`` type. As a result ``CHAR(N)`` semantics are not preserved in the execution engine. `#25843 `_ +* Add spatial join support for native execution. `#25823 `_ +* Rename ``native_query_trace_node_ids`` to ``native_query_trace_node_id`` to provide a single plan node id for tracing. `#25684 `_ +* Update coordinator behavior to validate sidecar function signatures against plugin loaded function signatures at startup. `#25919 `_ + +Security Changes +________________ +* Fix the Content Security Policy (CSP) by adding ``form-action 'self'`` and setting ``img-src 'self'`` in response to `CWE-693 `_. `#25910 `_ +* Upgrade Netty to version 4.1.126.Final to address `CVE-2025-58056 `_ and `CVE-2025-58057 `_. `#26006 `_ +* Upgrade commons-lang3 to 3.18.0 to address `CVE-2025-48924 `_. `#25751 `_ +* Upgrade jaxb-runtime to v4.0.5 in response to `CVE-2020-15250 `_. `#26024 `_ +* Upgrade netty dependency to address `CVE-2025-55163 `_. 
`#25806 `_ +* Upgrade reactor-netty-http dependency to address `CVE-2025-22227 `_. `#25739 `_ + +JDBC Driver Changes +___________________ +* Add ``DECIMAL`` type support to query builder. `#25699 `_ + +Web UI Changes +______________ +* Fix the query id tooltip being displayed at an incorrect position. `#25809 `_ + +Arrow Flight Connector Changes +______________________________ +* Add support for `mutual TLS (mTLS) authentication `_. `#25388 `_ + +BigQuery Connector Changes +__________________________ +* Fix query failures on ``SELECT`` operations by aligning BigQuery v1beta1 with protobuf-java 3.25.8, preventing runtime incompatibility with protobuf 4.x. `#25805 `_ +* Add support for case-sensitive identifiers in BigQuery. To enable, set the configuration property ``case-sensitive-name-matching=true`` in the catalog file. `#25764 `_ + +Cassandra Connector Changes +___________________________ +* Add support to read ``TUPLE`` type as a Presto ``VARCHAR``. `#25516 `_ + +ClickHouse Connector Changes +____________________________ +* Add support for case-sensitive identifiers in Clickhouse. To enable, set the configuration property ``case-sensitive-name-matching=true`` in the catalog file. `#25863 `_ + +Delta Lake Connector Changes +____________________________ +* Upgrade to Hadoop 3.4.1. `#24799 `_ + +Hive Connector Changes +______________________ +* Fix Hive connector to ignore unsupported table formats when querying ``system.jdbc.columns`` to prevent errors. `#25779 `_ +* Add session property ``hive.orc_use_column_names`` to toggle the accessing of columns based on the names recorded in the ORC file rather than their ordinal position in the file. `#25285 `_ +* Upgrade to Hadoop 3.4.1. `#24799 `_ + +Hudi Connector Changes +______________________ +* Upgrade to Hadoop 3.4.1. `#24799 `_ + +Iceberg Connector Changes +_________________________ +* Fix null pointer exception (NPE) error in getViews API call when a schema is not provided. 
`#25695 `_ +* Fix implementation of commit to do one operation as opposed to two. `#25615 `_ +* Fix an Iceberg connector failure when renaming a column that is used as the source column of a non-identity transform. `#25697 `_ +* Improve Iceberg's ``apply_changelog`` function by migrating it from the global namespace to the connector-specific namespace. The function is now available as ``iceberg.system.apply_changelog()`` instead of ``apply_changelog()``. `#25871 `_ +* Improve the property mechanism to enable a property to accept and process property values of multiple types. `#25862 `_ +* Add Iceberg bucket scalar function. `#25951 `_ +* Add ``iceberg.engine.hive.lock-enabled`` configuration to disable Hive locks. `#25615 `_ +* Add support for specifying multiple transforms when adding a column. `#25862 `_ +* Upgrade Iceberg version to 1.8.1. `#25999 `_ +* Upgrade Nessie to version 0.95.0. `#25593 `_ +* Upgrade to Hadoop 3.4.1. `#24799 `_ +* Update to implement ConnectorMetadata::finishDeleteWithOutput(). `#26134 `_ + +Kudu Connector Changes +______________________ +* Update to implement ConnectorMetadata::finishDeleteWithOutput(). `#26134 `_ + +MongoDB Connector Changes +_________________________ +* Add support for case-sensitive identifiers in MongoDB. To enable, set the configuration property ``case-sensitive-name-matching=true`` in the catalog file. `#25853 `_ +* Upgrade MongoDB java driver to 3.12.14. `#25436 `_ + +PostgreSQL Connector Changes +____________________________ +* Add support for `GEOMETRY `_ type in the PostgreSQL connector. `#25240 `_ + +Redis Connector Changes +_______________________ +* Add changes to enable TLS support. `#25373 `_ + +SPI Changes +___________ +* Add a new ``getSqlInvokedFunctions`` SPI in Presto, which only supports SQL invoked functions. `#25597 `_ +* Add a new ``ConnectorMetadata::finishDeleteWithOutput()`` method, returning ``Optional``. 
This allows connectors implementing ``DELETE`` to identify partitions modified in queries, which can be important for tracing lineage. `#26134 `_ +* Add AuthenticatorNotApplicableException to prevent irrelevant authenticator errors from being returned to clients. `#25606 `_ +* Deprecate the existing ``ConnectorMetadata::finishDelete()`` method. By default, the new ``finishDeleteWithOutput()`` method delegates to the existing ``finishDelete()`` method, and returns ``Optional.empty()``. This allows existing connectors to continue working without changes. `#26134 `_ + +Documentation Changes +_____________________ +* Improve :doc:`/installation/deploy-brew`. `#25924 `_ +* Add documentation about the Presto :doc:`/develop/release-process` and :doc:`/admin/version-support`. `#25742 `_ + + + +**Credits** +=========== + +Abhash Jain, Adrian Carpente (Denodo), Amit Dutta, Amritanshu Darbari, Anant Aneja, Andrew Xie, Arjun Gupta, Artem Selishchev, Bryan Cutler, Christian Zentgraf, Dilli-Babu-Godari, Elbin Pallimalil, Facebook Community Bot, Feilong Liu, Gary Helmling, Ge Gao, Hazmi, HeidiHan0000, Jalpreet Singh Nanda (:imjalpreet), James Gill, Jay Narale, Jialiang Tan, Joe Abraham, Joe O'Hallaron, Karthikeyan Natarajan, Ke Wang, Ke Wang, Kevin Tang, Kewen Wang, Krishna Pai, Mahadevuni Naveen Kumar, Maria Basmanova, Mariam Almesfer, Matt Karrmann, Miguel Blanco Godón, Natasha Sehgal, Naveen Nitturu, Nidhin Varghese, Nikhil Collooru, Nishitha-Bhaskaran, PRASHANT GOLASH, Ping Liu, Pradeep Vaka, Pramod Satya, Prashant Sharma, Pratik Joseph Dabre, Raaghav Ravishankar, Rebecca Schlussel, Rebecca Whitworth, Reetika Agrawal, Richard Barnes, Sayari Mukherjee, Sergey Pershin, Shahim Sharafudeen, Shang Ma, Shijin, Shrinidhi Joshi, Steve Burnett, Sumi Mathew, Timothy Meehan, Valery Mironov, Vamsi Karnika, Vivian Hsu, Wei He, Xiaoxuan Meng, Xin Zhang, Yihong Wang, Ying, Zac Blanco, Zac Wen, abhinavmuk04, aditi-pandit, adkharat, aspegren_david, auden-woolfson, beinan, dnskr, ericyuliu, 
haneel-kumar, j-sund, juwentus1234, lingbin, mehradpk, mohsaka, pratik.pugalia@gmail.com, pratyakshsharma, singcha, unidevel, wangd, yangbin09 diff --git a/presto-docs/src/main/sphinx/release/release-0.296.rst b/presto-docs/src/main/sphinx/release/release-0.296.rst new file mode 100644 index 0000000000000..768b4dbef8995 --- /dev/null +++ b/presto-docs/src/main/sphinx/release/release-0.296.rst @@ -0,0 +1,153 @@ +============= +Release 0.296 +============= + +**Breaking Changes** +==================== +* Replace default Iceberg compression codec from ``GZIP`` to ``ZSTD``. Existing tables are unaffected, but new tables will use ZSTD compression by default if ``iceberg.compression-codec`` is not set. `#26399 `_ +* Replace the ``String serializedCommitOutput`` argument with ``Optional commitOutput`` in the ``com.facebook.presto.spi.eventlistener.QueryInputMetadata`` and ``com.facebook.presto.spi.eventlistener.QueryOutputMetadata`` constructors. `#26331 `_ + +**Highlights** +============== +* Add support for :doc:`Materialized Views `. `#26492 `_ +* Add support for the :doc:`/sql/merge` command in the Presto engine. `#26278 `_ +* Add support for distributed execution of procedures. `#26373 `_ +* Add HTTP/2 support for internal cluster communication with data compression. `#26439 `_ `#26381 `_ +* Add support for basic insertion to Iceberg tables on C++ worker clusters. `#26338 `_ + +**Details** +=========== + +General Changes +_______________ +* Improve sort-merge join performance when one side of the join input is already sorted. `#26361 `_ +* Improve query performance for semi joins (used in ``IN`` and ``EXISTS`` subqueries) when join keys contain many null values. `#26251 `_ +* Improve connector optimizer to support queries involving multiple connectors. `#26246 `_ +* Add :func:`array_transpose` to return a transpose of an array. 
`#26470 `_ +* Add :ref:`admin/properties:\`\`cluster-tag\`\`` configuration property to assign a custom identifier to the cluster, which is displayed in the Web UI. `#26485 `_ +* Add a session property ``query_types_enabled_for_history_based_optimization`` to specify query types which will use HBO. See :doc:`/optimizer/history-based-optimization`. `#26183 `_ +* Add data compression support for HTTP/2 protocol. `#26381 `_ `#26382 `_ +* Add :ref:`admin/properties:\`\`max-prefixes-count\`\`` configuration property to limit the number of catalog/schema/table scope prefixes generated when querying ``information_schema``, which can improve metadata query performance. `#25550 `_ +* Add detailed latency and failure count metrics for the system access control plugin. `#26116 `_ +* Add experimental support for sorted exchanges to improve sort-merge join performance. When enabled with the ``sorted_exchange_enabled`` session property or the ``optimizer.experimental.sorted-exchange-enabled`` configuration property, this optimization eliminates redundant sorting steps and reduces memory usage for distributed queries with sort-merge joins. This feature is disabled by default. `#26403 `_ +* Add HTTP/2 support for internal cluster communication. `#26439 `_ +* Add ``native_use_velox_geospatial_join`` session property to enable an optimized implementation for geospatial joins in native execution. This feature is enabled by default. `#26057 `_ +* Add support for the :doc:`/sql/merge` command in the Presto engine. `#26278 `_ +* Add ``enable-java-cluster-query-retry`` configuration property in ``router-scheduler.properties`` to retry queries on ``router-java-url`` when they fail on ``router-native-url``. `#25720 `_ +* Add :func:`array_to_map_int_keys` function. `#26681 `_ +* Add :func:`map_int_keys_to_array` function. `#26681 `_ +* Add :func:`t_cdf` and :func:`inverse_t_cdf` functions for Student's t-distribution calculations. 
`#26363 `_ +* Add support for distributed execution of procedures. `#26373 `_ +* Add support for :doc:`Materialized Views `. `#26492 `_ +* Update encoding of refresh token secret key from HMAC to AES. `#26487 `_ + +Prestissimo (Native Execution) Changes +______________________________________ +* Fix query errors when using mixed case column names with the Iceberg connector. `#26163 `_ +* Add ``native_max_partial_aggregation_memory`` session property to control memory limits for partial aggregation. `#26389 `_ +* Add :ref:`presto_cpp/properties-session:\`\`native_max_split_preload_per_driver\`\`` session property to configure the maximum number of splits to preload per driver. `#26591 `_ +* Add support for basic insertion to Iceberg tables. `#26338 `_ +* Add support for custom schemas in native sidecar function registry. `#26236 `_ +* Add support for the TPC-DS connector. `#24751 `_ +* Add support for REST API for remote functions. `#23568 `_ + +Security Changes +________________ +* Upgrade dagre-d3-es to 7.0.13 in response to `CVE-2025-57347 `_. `#26422 `_ +* Upgrade Netty to 4.1.128.Final to address `CVE-2025-59419 `_. `#26349 `_ +* Upgrade at.favre.lib:bcrypt version to 0.10.2 in response to `CVE-2020-15250 `_. `#26463 `_ +* Upgrade calcite-core to 1.41.0 in response to `CVE-2025-48924 `_. `#26248 `_ +* Upgrade io.grpc:grpc-netty-shaded from 1.70.0 to 1.75.0 to address `CVE-2025-55163 `_. `#26273 `_ +* Upgrade mssql-jdbc to 12.10.2.jre8 to address `CVE-2025-59250 `_. `#26534 `_ +* Upgrade org.apache.calcite to 1.38.0 in response to `CVE-2022-36944 `_. `#26400 `_ +* Upgrade zookeeper to 3.9.4 to address `CVE-2025-58457 `_. `#26180 `_. + +Arrow Flight Connector Changes +______________________________ +* Add support for case-sensitive identifiers in Arrow. To enable, set ``case-sensitive-name-matching=true``. `#26176 `_ + +Cassandra Connector Changes +___________________________ +* Add support for case-sensitive identifiers in Cassandra. 
To enable, set ``case-sensitive-name-matching=true`` configuration in the catalog configuration. `#25690 `_ + +Delta Connector Changes +_______________________ +* Fix problem reading Delta Lake tables with spaces in location or partition values. `#26397 `_ + +Druid Connector Changes +_______________________ +* Fix Druid connector to use strict application/json content type. `#26200 `_ +* Add TLS support. `#26027 `_ +* Add support for case-sensitive identifiers in Druid. To enable, set ``case-sensitive-name-matching=true`` configuration in the catalog configuration. `#26038 `_ + +Elasticsearch Connector Changes +_______________________________ +* Add support for case-sensitive identifiers in Elasticsearch. To enable, set ``case-sensitive-name-matching=true`` in the catalog configuration. `#26352 `_ + +Hive Connector Changes +______________________ +* Add support for ``LZ4`` compression codec in ORC format. `#26346 `_ +* Add support for ``ZSTD`` compression codec in Parquet format. `#26346 `_ + +Iceberg Connector Changes +_________________________ +* Fix Bearer authentication with Nessie catalog. `#26512 `_ +* Fix ``SHOW STATS`` for Timestamp with Timezone columns. `#26305 `_ +* Fix reading decimal partition values when using native execution. `#26240 `_ +* Fix handling of ``TIME`` columns in Iceberg tables. `#26523 `_ +* Add support for ``LZ4`` compression codec in ORC format. `#26346 `_ +* Add support for ``ZSTD`` compression codec in Parquet format. `#26346 `_ +* Add support for ``engine.hive.lock-enabled`` property when creating or altering Iceberg tables. `#26234 `_ +* Add support to access Nessie with S3 using Iceberg REST catalog. `#26610 `_ +* Add support for :ref:`Materialized Views `. `#26603 `_ +* Replace default Iceberg compression codec from ``GZIP`` to ``ZSTD``. Existing tables are unaffected, but new tables will use ZSTD compression by default if ``iceberg.compression-codec`` is not set. 
`#26399 `_ + +Kafka Connector Changes +_______________________ +* Add support for case-sensitive identifiers in Kafka. To enable, set ``case-sensitive-name-matching=true`` in the catalog configuration. `#26023 `_ + +Memory Connector Changes +________________________ +* Add support for :doc:`Materialized Views `. `#26405 `_ + +MongoDB Connector Changes +_________________________ +* Add TLS/SSL support with automatic JKS and PEM certificate format detection. Configure using ``mongodb.tls.enabled``, ``mongodb.tls.keystore-path``, ``mongodb.tls.keystore-password``, ``mongodb.tls.truststore-path``, and ``mongodb.tls.truststore-password`` properties. `#25374 `_ +* Upgrade MongoDB Java Driver to 3.12.14. `#25374 `_ + +MySQL Connector Changes +_______________________ +* Fix timestamp handling when ``legacy_timestamp`` is disabled. Timestamp values are now correctly stored and retrieved as wall-clock times without timezone conversion. Previously, values were incorrectly converted using the JVM timezone, causing data corruption. `#26449 `_ + +Oracle Connector Changes +________________________ +* Add support for fetching table statistics from Oracle source tables. `#26120 `_ +* Add support for reading Oracle ``BLOB`` columns as ``VARBINARY``. `#25354 `_ + +Pinot Connector Changes +_______________________ +* Add support for case-sensitive identifiers in Pinot. To enable, set ``case-sensitive-name-matching=true`` configuration in the catalog configuration. `#26239 `_ +* Upgrade Pinot version to 1.3.0. `#25785 `_ + +PostgreSQL Connector Changes +____________________________ +* Fix timestamp handling when ``legacy_timestamp`` is disabled. Timestamp values are now correctly stored and retrieved as wall-clock times without timezone conversion. Previously, values were incorrectly converted using the JVM timezone, causing data corruption. `#26449 `_ + +Redis Connector Changes +_______________________ +* Add support for case-sensitive identifiers in Redis. 
To enable, set ``case-sensitive-name-matching=true`` configuration in the catalog configuration. `#26078 `_ + +SingleStore Connector Changes +_____________________________ +* Fix string type mapping to support ``VARCHAR(len)`` where len <= 21844. `#25476 `_ + +SPI Changes +___________ +* Add ``getCommitOutputForRead()`` and ``getCommitOutputForWrite()`` methods to ``ConnectorCommitHandle``, and deprecates the existing ``getSerializedCommitOutputForRead()`` and ``getSerializedCommitOutputForWrite()`` methods. `#26331 `_ +* Add new metric ``getTotalScheduledTime()`` to QueryStatistics SPI. This value is the sum of wall time across all threads of all tasks/stages of a query that were actually scheduled for execution. `#26279 `_ +* Replace the ``String serializedCommitOutput`` argument with ``Optional commitOutput`` in the ``com.facebook.presto.spi.eventlistener.QueryInputMetadata`` and ``com.facebook.presto.spi.eventlistener.QueryOutputMetadata`` constructors. `#26331 `_ + +**Credits** +=========== + +Aditi Pandit, Adrian Carpente (Denodo), Alex Austin Chettiar, Amit Dutta, Anant Aneja, Andrew X, Andrii Rosa, Artem Selishchev, Auden Woolfson, Bryan Cutler, Chris Matzenbach, Christian Zentgraf, Deepak Majeti, Denodo Research Labs, Dilli-Babu-Godari, Dong Wang, Elbin Pallimalil, Gary Helmling, Ge Gao, Han Yan, HeidiHan0000, Jalpreet Singh Nanda, James Gill, Jay Feldblum, Jiaqi Zhang, Joe Abraham, Joe O'Hallaron, Karthikeyan, Ke, Kevin Tang, Li Zhou, LingBin, Maria Basmanova, Mariam AlMesfer, Namya Sehgal, Natasha Sehgal, Nidhin Varghese, Nikhil Collooru, Nivin C S, PRASHANT GOLASH, Pedro Pedreira, Ping Liu, Pramod Satya, Prashant Sharma, Pratyaksh Sharma, Rebecca Schlussel, Reetika Agrawal, RindsSchei225e, Sayari Mukherjee, Sergey Pershin, Shahad Shamsan, Shahim Sharafudeen, Shang Ma, Shrinidhi Joshi, Sreeni Viswanadha, Steve Burnett, Tal Galili, Timothy Meehan, Weitao Wan, XiaoDu, Xiaoxuan, Xin Zhang, Yihong Wang, Yolande Yan, Zac, Zoltán Arnold Nagy, abhinavmuk04, 
adheer-araokar, bibith4, dependabot[bot], ericyuliu, feilong-liu, inf, jkhaliqi, maniloya, mohsaka, nishithakbhaskaran, rkurniawati, shanhao-203, singcha, sumi-mathew, tanjialiang, vhsu14 diff --git a/presto-docs/src/main/sphinx/security.rst b/presto-docs/src/main/sphinx/security.rst index f8c5432555e97..6d793c4c785e6 100644 --- a/presto-docs/src/main/sphinx/security.rst +++ b/presto-docs/src/main/sphinx/security.rst @@ -13,3 +13,4 @@ Security security/built-in-system-access-control security/internal-communication security/authorization + security/oauth2 diff --git a/presto-docs/src/main/sphinx/security/authorization.rst b/presto-docs/src/main/sphinx/security/authorization.rst index 6092a45ca6235..82649077ba3e3 100644 --- a/presto-docs/src/main/sphinx/security/authorization.rst +++ b/presto-docs/src/main/sphinx/security/authorization.rst @@ -46,7 +46,7 @@ so make sure you have authentication enabled. http-server.authentication.type=CERTIFICATE - It is also possible to specify other authentication types such as - ``KERBEROS``, ``PASSWORD`` and ``JWT``. Additional configuration may be + ``KERBEROS``, ``PASSWORD``, ``JWT``, and ``OAUTH2``. Additional configuration may be needed. .. code-block:: none diff --git a/presto-docs/src/main/sphinx/security/cli.rst b/presto-docs/src/main/sphinx/security/cli.rst index aead5094a55f2..ccbeb86added0 100644 --- a/presto-docs/src/main/sphinx/security/cli.rst +++ b/presto-docs/src/main/sphinx/security/cli.rst @@ -33,8 +33,6 @@ principal. .. include:: ktadd-note.fragment -.. include:: jce-policy.fragment - Java Keystore File for TLS ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-docs/src/main/sphinx/security/ldap.rst b/presto-docs/src/main/sphinx/security/ldap.rst index 123bdabce2f5a..f487403f66697 100644 --- a/presto-docs/src/main/sphinx/security/ldap.rst +++ b/presto-docs/src/main/sphinx/security/ldap.rst @@ -85,6 +85,7 @@ Property Description Should be set to ``true``. Default value is ``false``. 
``http-server.https.port`` HTTPS server port. +``http-server.http2.enabled`` Enables HTTP2 server on the worker. ``http-server.https.keystore.path`` The location of the Java Keystore file that will be used to secure TLS. ``http-server.https.keystore.key`` The password for the keystore. This must match the @@ -317,3 +318,37 @@ with the appropriate :abbr:`SAN (Subject Alternative Name)` added. Adding a SAN to this certificate is required in cases where ``https://`` uses IP address in the URL rather than the domain contained in the coordinator's certificate, and the certificate does not contain the :abbr:`SAN (Subject Alternative Name)` parameter with the matching IP address as an alternative attribute. + +No console from which to read password +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you use the command line options ``--user`` and ``--password`` along with +supplying the query statements to execute using ``--file`` or ``--execute`` command line options as below - + +.. code-block:: none + + ./presto \ + --server https://presto-coordinator.example.com:8443 \ + --execute "select * from tpch.tiny.nation" + --user \ + --password + + OR + + ./presto \ + --server https://presto-coordinator.example.com:8443 \ + --file input_queries.txt + --user \ + --password > output.txt + +The following error might be displayed: + +.. code-block:: none + + java.lang.RuntimeException: No console from which to read password + +To avoid this error, export the password using the ``PRESTO_PASSWORD`` environment variable: + +.. 
code-block:: none + + export PRESTO_PASSWORD= \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/security/oauth2.rst b/presto-docs/src/main/sphinx/security/oauth2.rst new file mode 100644 index 0000000000000..65b676c42d394 --- /dev/null +++ b/presto-docs/src/main/sphinx/security/oauth2.rst @@ -0,0 +1,68 @@ +======================== +Oauth 2.0 Authentication +======================== + +Presto can be configured to enable frontend OAuth2 authentication over HTTPS for clients such as the CLI, JDBC, and ODBC drivers. OAuth2 provides a secure and flexible way to authenticate users by using an external identity provider (IdP), such as Okta, Auth0, Azure AD, or Google. + +OAuth2 authentication in Presto uses the Authorization Code Flow with PKCE and OpenID Connect (OIDC). The Presto coordinator initiates an OAuth2 challenge, and the client completes the flow by obtaining an access token from the identity provider. + +Presto Server Configuration +--------------------------- + +To enable OAuth2 authentication, configuration changes are made **only on the Presto coordinator**. No changes are required on the workers. + +Secure Communication +-------------------- + +Access to the Presto coordinator must be secured with HTTPS. You must configure a valid TLS certificate and keystore on the coordinator. See the `TLS setup guide `_ for details. + +OAuth2 Configuration +-------------------- + +Below are the key configuration properties for enabling OAuth2 authentication in ``config.properties``: + +.. 
code-block:: properties + + http-server.authentication.type=OAUTH2 + + http-server.authentication.oauth2.issuer=https://your-idp.com/oauth2/default + http-server.authentication.oauth2.client-id=your-client-id + http-server.authentication.oauth2.client-secret=your-client-secret + http-server.authentication.oauth2.scopes=openid,email,profile + http-server.authentication.oauth2.principal-field=sub + http-server.authentication.oauth2.groups-field=groups + http-server.authentication.oauth2.challenge-timeout=15m + http-server.authentication.oauth2.max-clock-skew=1m + http-server.authentication.oauth2.refresh-tokens=true + http-server.authentication.oauth2.oidc.discovery=true + http-server.authentication.oauth2.state-key=your-hmac-secret + http-server.authentication.oauth2.additional-audiences=your-client-id,another-audience + http-server.authentication.oauth2.user-mapping.pattern=(.*) + +It is worth noting that ``configuration-based-authorizer.role-regex-map.file-path`` must be configured if +authentication type is set to ``OAUTH2``. + +TLS Truststore for IdP +---------------------- + +If your IdP uses a custom or self-signed certificate, import it into the Java truststore on the Presto coordinator: + +.. code-block:: bash + + keytool -import \ + -keystore $JAVA_HOME/lib/security/cacerts \ + -trustcacerts \ + -alias idp_cert \ + -file idp_cert.crt + +Notes +----- + +- **Issuer**: The base URL of your IdP’s OIDC discovery endpoint. +- **Client ID/Secret**: Registered credentials for Presto in your IdP. +- **Scopes**: Must include ``openid``; others like ``email``, ``profile``, or ``groups`` are optional. +- **Principal Field**: The claim in the ID token used as the Presto username. +- **Groups Field**: Optional claim used for role-based access control. +- **State Key**: A secret used to sign the OAuth2 state parameter (HMAC). +- **Refresh Tokens**: Enable if your IdP supports issuing refresh tokens. 
+- **Callback**: When configuring your IdP the callback URI must be set to ``[presto]/oauth2/callback`` diff --git a/presto-docs/src/main/sphinx/security/server.rst b/presto-docs/src/main/sphinx/security/server.rst index a0070995d347d..bc1b81f1246b1 100644 --- a/presto-docs/src/main/sphinx/security/server.rst +++ b/presto-docs/src/main/sphinx/security/server.rst @@ -47,8 +47,6 @@ In addition, the Presto coordinator needs a `keytab file .. include:: ktadd-note.fragment -.. include:: jce-policy.fragment - Java Keystore File for TLS ^^^^^^^^^^^^^^^^^^^^^^^^^^ When using Kerberos authentication, access to the Presto coordinator should be @@ -115,6 +113,7 @@ Property Description ``http-server.https.enabled`` Enables HTTPS access for the Presto coordinator. Should be set to ``true``. ``http-server.https.port`` HTTPS server port. +``http-server.http2.enabled`` Enables HTTP2 server on the worker. ``http-server.https.keystore.path`` The location of the Java Keystore file that will be used to secure TLS. ``http-server.https.keystore.key`` The password for the keystore. This must match the diff --git a/presto-docs/src/main/sphinx/sql.rst b/presto-docs/src/main/sphinx/sql.rst index c18015a1444c5..042bcfea7ad35 100644 --- a/presto-docs/src/main/sphinx/sql.rst +++ b/presto-docs/src/main/sphinx/sql.rst @@ -20,6 +20,7 @@ This chapter describes the SQL syntax used in Presto. sql/create-table sql/create-table-as sql/create-view + sql/create-materialized-view sql/deallocate-prepare sql/delete sql/describe @@ -30,13 +31,16 @@ This chapter describes the SQL syntax used in Presto. sql/drop-schema sql/drop-table sql/drop-view + sql/drop-materialized-view sql/execute sql/explain sql/explain-analyze sql/grant sql/grant-roles sql/insert + sql/merge sql/prepare + sql/refresh-materialized-view sql/reset-session sql/revoke sql/revoke-roles @@ -50,6 +54,7 @@ This chapter describes the SQL syntax used in Presto. 
sql/show-create-schema sql/show-create-table sql/show-create-view + sql/show-create-materialized-view sql/show-functions sql/show-grants sql/show-role-grants diff --git a/presto-docs/src/main/sphinx/sql/alter-table.rst b/presto-docs/src/main/sphinx/sql/alter-table.rst index 451b4dd2794b6..93778714a24bb 100644 --- a/presto-docs/src/main/sphinx/sql/alter-table.rst +++ b/presto-docs/src/main/sphinx/sql/alter-table.rst @@ -15,6 +15,8 @@ Synopsis ALTER TABLE [ IF EXISTS ] name DROP CONSTRAINT [ IF EXISTS ] constraint_name ALTER TABLE [ IF EXISTS ] name ALTER [ COLUMN ] column_name { SET | DROP } NOT NULL ALTER TABLE [ IF EXISTS ] name SET PROPERTIES (property_name=value, [, ...]) + ALTER TABLE [ IF EXISTS ] name DROP BRANCH [ IF EXISTS ] branch_name + ALTER TABLE [ IF EXISTS ] name DROP TAG [ IF EXISTS ] tag_name Description ----------- @@ -94,6 +96,14 @@ Set table property (``x=y``) to table ``users``:: ALTER TABLE users SET PROPERTIES (x='y'); +Drop branch ``branch1`` from the ``users`` table:: + + ALTER TABLE users DROP BRANCH 'branch1'; + +Drop tag ``tag1`` from the ``users`` table:: + + ALTER TABLE users DROP TAG 'tag1'; + See Also -------- diff --git a/presto-docs/src/main/sphinx/sql/create-materialized-view.rst b/presto-docs/src/main/sphinx/sql/create-materialized-view.rst new file mode 100644 index 0000000000000..0fd5d69de2aa6 --- /dev/null +++ b/presto-docs/src/main/sphinx/sql/create-materialized-view.rst @@ -0,0 +1,98 @@ +======================== +CREATE MATERIALIZED VIEW +======================== + +.. warning:: + + Materialized views are experimental. The SPI and behavior may change in future releases. + + To enable, set :ref:`admin/properties:\`\`experimental.legacy-materialized-views\`\`` = ``false`` + in configuration properties. + +Synopsis +-------- + +.. code-block:: none + + CREATE MATERIALIZED VIEW [ IF NOT EXISTS ] view_name + [ COMMENT 'string' ] + [ SECURITY { DEFINER | INVOKER } ] + [ WITH ( property_name = expression [, ...] 
) ] + AS query + +Description +----------- + +Create a new materialized view of a :doc:`select` query. The materialized view physically stores +the query results, unlike regular views which are virtual. Queries can read pre-computed results +instead of re-executing the underlying query. + +The optional ``IF NOT EXISTS`` clause causes the materialized view to be created only if it does +not already exist. + +The optional ``COMMENT`` clause stores a description of the materialized view in the metastore. + +The optional ``SECURITY`` clause specifies the security mode for the materialized view. When +``legacy_materialized_views=false``: + +* ``SECURITY DEFINER``: The view executes with the permissions of the user who created it. This is the default mode if ``SECURITY`` is not specified and matches the behavior of most SQL systems. The view owner must have ``CREATE_VIEW_WITH_SELECT_COLUMNS`` permission on base tables for non-owners to query the view. +* ``SECURITY INVOKER``: The view executes with the permissions of the user querying it. Each user must have appropriate permissions on the underlying base tables. + +When ``legacy_materialized_views=true``, the ``SECURITY`` clause is not supported and will +cause an error if used. + +The optional ``WITH`` clause specifies connector-specific properties. Connector properties vary by +connector implementation. Consult connector documentation for supported properties. 
+ +Examples +-------- + +Create a materialized view with daily aggregations:: + + CREATE MATERIALIZED VIEW daily_sales AS + SELECT date_trunc('day', order_date) AS day, + region, + SUM(amount) AS total_sales, + COUNT(*) AS order_count + FROM orders + GROUP BY date_trunc('day', order_date), region + +Create a materialized view with DEFINER security mode:: + + CREATE MATERIALIZED VIEW daily_sales + SECURITY DEFINER + AS + SELECT date_trunc('day', order_date) AS day, + region, + SUM(amount) AS total_sales + FROM orders + GROUP BY date_trunc('day', order_date), region + +Create a materialized view with INVOKER security mode:: + + CREATE MATERIALIZED VIEW user_specific_sales + SECURITY INVOKER + AS + SELECT date_trunc('day', order_date) AS day, + SUM(amount) AS total_sales + FROM orders + GROUP BY date_trunc('day', order_date) + +Create a materialized view with connector properties:: + + CREATE MATERIALIZED VIEW partitioned_sales + WITH ( + partitioned_by = ARRAY['year', 'month'] + ) + AS + SELECT year(order_date) AS year, + month(order_date) AS month, + SUM(amount) AS total_sales + FROM orders + GROUP BY year(order_date), month(order_date) + +See Also +-------- + +:doc:`drop-materialized-view`, :doc:`refresh-materialized-view`, +:doc:`show-create-materialized-view`, :doc:`/admin/materialized-views` \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/sql/drop-materialized-view.rst b/presto-docs/src/main/sphinx/sql/drop-materialized-view.rst new file mode 100644 index 0000000000000..ddfa19dabb1c9 --- /dev/null +++ b/presto-docs/src/main/sphinx/sql/drop-materialized-view.rst @@ -0,0 +1,42 @@ +====================== +DROP MATERIALIZED VIEW +====================== + +.. warning:: + + Materialized views are experimental. The SPI and behavior may change in future releases. + + To enable, set :ref:`admin/properties:\`\`experimental.legacy-materialized-views\`\`` = ``false`` + in configuration properties. + +Synopsis +-------- + +.. 
code-block:: none + + DROP MATERIALIZED VIEW [ IF EXISTS ] view_name + +Description +----------- + +Drop an existing materialized view and delete its stored data. + +The optional ``IF EXISTS`` clause causes the statement to succeed even if the materialized view +does not exist. + +Examples +-------- + +Drop the materialized view ``daily_sales``:: + + DROP MATERIALIZED VIEW daily_sales + +Drop the materialized view if it exists:: + + DROP MATERIALIZED VIEW IF EXISTS daily_sales + +See Also +-------- + +:doc:`create-materialized-view`, :doc:`refresh-materialized-view`, +:doc:`show-create-materialized-view`, :doc:`/admin/materialized-views` \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/sql/merge.rst b/presto-docs/src/main/sphinx/sql/merge.rst new file mode 100644 index 0000000000000..b4b738dda86ab --- /dev/null +++ b/presto-docs/src/main/sphinx/sql/merge.rst @@ -0,0 +1,71 @@ +===== +MERGE +===== + +Synopsis +-------- + +.. code-block:: text + + MERGE INTO target_table [ [ AS ] target_alias ] + USING { source_table | query } [ [ AS ] source_alias ] + ON search_condition + WHEN MATCHED THEN + UPDATE SET ( column = expression [, ...] ) + WHEN NOT MATCHED THEN + INSERT [ column_list ] + VALUES (expression, ...) + +Description +----------- + +The ``MERGE`` statement inserts or updates rows in a ``target_table`` based on the contents of the ``source_table``. +The ``search_condition`` defines a relation between the source and target tables. +When the condition is met, the target row is updated. When the condition is not met, a new row is inserted into the target table. +In the ``MATCHED`` case, the ``UPDATE`` column value expressions can depend on any field of the target or the source. +In the ``NOT MATCHED`` case, the ``INSERT`` expressions can depend on any field of the source. + +The ``MERGE`` command requires each target row to match at most one source row. An exception is raised when a single target table row matches more than one source row. 
+If a source row is not matched by the ``WHEN`` clause and there is no ``WHEN NOT MATCHED`` clause, the source row is ignored. + +The ``MERGE`` statement is commonly used to integrate data from two tables with different contents but similar structures. +For example, the source table could be part of a production transactional system, while the target table might be located in a data warehouse for analytics. +Regularly, MERGE operations are performed to update the analytics warehouse with the latest production data. +You can also use MERGE with tables that have different structures, as long as you can define a condition to match the rows between them. + +MERGE Command Privileges +------------------------ + +The ``MERGE`` statement does not have a dedicated privilege. Instead, executing a ``MERGE`` statement requires the privileges associated with the individual actions it performs: + +* ``UPDATE`` actions: require the ``UPDATE`` privilege on the target table columns referenced in the ``SET`` clause. +* ``INSERT`` actions: require the ``INSERT`` privilege on the target table. + +Each privilege must be granted to the user executing the ``MERGE`` command, based on the specific operations included in the statement. + +Example +------- + +Update the sales information for existing products and insert the sales information for the new products in the market. + +.. code-block:: text + + MERGE INTO product_sales AS s + USING monthly_sales AS ms + ON s.product_id = ms.product_id + WHEN MATCHED THEN + UPDATE SET + sales = sales + ms.sales + , last_sale = ms.sale_date + , current_price = ms.price + WHEN NOT MATCHED THEN + INSERT (product_id, sales, last_sale, current_price) + VALUES (ms.product_id, ms.sales, ms.sale_date, ms.price) + +Limitations +----------- + +Any connector can be used as a source table for a ``MERGE`` statement. +Only connectors which support the ``MERGE`` statement can be the target of a merge operation. 
+See the :doc:`connector documentation ` for more information. +The ``MERGE`` statement is currently supported only by the Iceberg connector. diff --git a/presto-docs/src/main/sphinx/sql/refresh-materialized-view.rst b/presto-docs/src/main/sphinx/sql/refresh-materialized-view.rst new file mode 100644 index 0000000000000..a8af256ab7b69 --- /dev/null +++ b/presto-docs/src/main/sphinx/sql/refresh-materialized-view.rst @@ -0,0 +1,36 @@ +========================= +REFRESH MATERIALIZED VIEW +========================= + +.. warning:: + + Materialized views are experimental. The SPI and behavior may change in future releases. + + To enable, set :ref:`admin/properties:\`\`experimental.legacy-materialized-views\`\`` = ``false`` + in configuration properties. + +Synopsis +-------- + +.. code-block:: none + + REFRESH MATERIALIZED VIEW view_name + +Description +----------- + +Refresh the data stored in a materialized view by re-executing the view query against the base +tables. + +Examples +-------- + +Refresh a materialized view:: + + REFRESH MATERIALIZED VIEW daily_sales + +See Also +-------- + +:doc:`create-materialized-view`, :doc:`drop-materialized-view`, +:doc:`show-create-materialized-view`, :doc:`/admin/materialized-views` \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/sql/show-columns.rst b/presto-docs/src/main/sphinx/sql/show-columns.rst index 898b233d796bb..294c3688f4bdb 100644 --- a/presto-docs/src/main/sphinx/sql/show-columns.rst +++ b/presto-docs/src/main/sphinx/sql/show-columns.rst @@ -12,4 +12,4 @@ Synopsis Description ----------- -List the columns in ``table`` along with their data type and other attributes. +List the columns in ``table`` along with their data type and other attributes such as Extra, Comment, Precision, Scale, and Length. 
diff --git a/presto-docs/src/main/sphinx/sql/show-create-materialized-view.rst b/presto-docs/src/main/sphinx/sql/show-create-materialized-view.rst new file mode 100644 index 0000000000000..6f56415c765aa --- /dev/null +++ b/presto-docs/src/main/sphinx/sql/show-create-materialized-view.rst @@ -0,0 +1,35 @@ +============================= +SHOW CREATE MATERIALIZED VIEW +============================= + +.. warning:: + + Materialized views are experimental. The SPI and behavior may change in future releases. + + To enable, set :ref:`admin/properties:\`\`experimental.legacy-materialized-views\`\`` = ``false`` + in configuration properties. + +Synopsis +-------- + +.. code-block:: none + + SHOW CREATE MATERIALIZED VIEW view_name + +Description +----------- + +Show the SQL statement that creates the specified materialized view. + +Examples +-------- + +Show the SQL for the ``daily_sales`` materialized view:: + + SHOW CREATE MATERIALIZED VIEW daily_sales + +See Also +-------- + +:doc:`create-materialized-view`, :doc:`drop-materialized-view`, +:doc:`refresh-materialized-view`, :doc:`/admin/materialized-views` \ No newline at end of file diff --git a/presto-docs/src/main/sphinx/troubleshoot/query.rst b/presto-docs/src/main/sphinx/troubleshoot/query.rst index 3232454d8697d..af51240d6b040 100644 --- a/presto-docs/src/main/sphinx/troubleshoot/query.rst +++ b/presto-docs/src/main/sphinx/troubleshoot/query.rst @@ -44,4 +44,7 @@ Edit ``config.properties`` for the Presto coordinator, and set the value of the http-server.max-request-header-size=5MB -See :ref:`admin/properties:\`\`http-server.max-request-header-size\`\``. \ No newline at end of file +See :ref:`admin/properties:\`\`http-server.max-request-header-size\`\``. + +Alternatively, avoid using prepared statements for large queries. Prepared statements +place the SQL template in request headers, which can exceed the header size limit. 
\ No newline at end of file diff --git a/presto-druid/pom.xml b/presto-druid/pom.xml index f940d1cfd4acc..12d089e8536d5 100644 --- a/presto-druid/pom.xml +++ b/presto-druid/pom.xml @@ -4,16 +4,19 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-druid + presto-druid Presto - Druid Connector presto-plugin ${project.parent.basedir} - 2.24.3 + 2.25.3 + 17 + true @@ -26,6 +29,21 @@ + + + + org.hibernate + hibernate-validator + 8.0.3.Final + + + org.glassfish + jakarta.el + 4.0.1 + + + + org.apache.druid @@ -99,7 +117,7 @@ commons-lang commons-lang - + org.apache.logging.log4j log4j-api @@ -168,14 +186,19 @@ guava + + com.google.code.findbugs + jsr305 + + com.google.inject guice - com.google.code.findbugs - jsr305 + jakarta.annotation + jakarta.annotation-api true @@ -212,8 +235,13 @@ - javax.validation - validation-api + jakarta.inject + jakarta.inject-api + + + + jakarta.validation + jakarta.validation-api @@ -223,7 +251,7 @@ com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache @@ -250,7 +278,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -267,7 +295,7 @@ - io.airlift + com.facebook.airlift units provided @@ -340,4 +368,19 @@ test + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + com.fasterxml.jackson.core:jackson-databind + javax.inject:javax.inject + + + + + diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidBrokerPageSource.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidBrokerPageSource.java index 96010fe22c17e..97058911341fb 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidBrokerPageSource.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidBrokerPageSource.java @@ -130,7 +130,7 @@ public Page getNextPage() Type type = columnTypes.get(i); BlockBuilder blockBuilder = pageBuilder.getBlockBuilder(i); JsonNode value = rootNode.get(((DruidColumnHandle) columnHandles.get(i)).getColumnName()); - if (value == null) { 
+ if (value == null || value.isNull()) { blockBuilder.appendNull(); continue; } diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidClient.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidClient.java index ee35e7159ae7e..74e3c77c4702f 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidClient.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidClient.java @@ -24,16 +24,9 @@ import com.facebook.presto.druid.metadata.DruidSegmentInfo; import com.facebook.presto.druid.metadata.DruidTableInfo; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.SchemaTableName; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.base.CharMatcher; -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; -import org.checkerframework.checker.nullness.qual.Nullable; import javax.inject.Inject; @@ -42,12 +35,7 @@ import java.io.InputStream; import java.io.InputStreamReader; import java.net.URI; -import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; import java.util.stream.Collectors; import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom; @@ -57,20 +45,14 @@ import static com.facebook.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; -import static com.facebook.presto.druid.DruidErrorCode.DRUID_AMBIGUOUS_OBJECT_NAME; import static com.facebook.presto.druid.DruidErrorCode.DRUID_BROKER_RESULT_ERROR; import static com.facebook.presto.druid.DruidResultFormat.OBJECT; import static 
com.facebook.presto.druid.DruidResultFormat.OBJECT_LINES; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.net.HttpHeaders.CONTENT_TYPE; -import static com.google.common.net.MediaType.JSON_UTF_8; import static java.lang.String.format; import static java.net.HttpURLConnection.HTTP_OK; -import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.MILLISECONDS; public class DruidClient { @@ -91,8 +73,6 @@ public class DruidClient private final URI druidCoordinator; private final URI druidBroker; private final String druidSchema; - protected final boolean caseInsensitiveNameMatching; - private final Cache> remoteTables; @Inject public DruidClient(DruidConfig config, @ForDruidClient HttpClient httpClient) @@ -102,44 +82,6 @@ public DruidClient(DruidConfig config, @ForDruidClient HttpClient httpClient) this.druidCoordinator = URI.create(config.getDruidCoordinatorUrl()); this.druidBroker = URI.create(config.getDruidBrokerUrl()); this.druidSchema = config.getDruidSchema(); - this.caseInsensitiveNameMatching = config.isCaseInsensitiveNameMatching(); - - Duration caseInsensitiveNameMatchingCacheTtl = requireNonNull(config.getCaseInsensitiveNameMatchingCacheTtl(), "caseInsensitiveNameMatchingCacheTtl is null"); - CacheBuilder remoteTableNamesCacheBuilder = CacheBuilder.newBuilder() - .expireAfterWrite(caseInsensitiveNameMatchingCacheTtl.toMillis(), MILLISECONDS); - this.remoteTables = remoteTableNamesCacheBuilder.build(); - } - - Optional toRemoteTable(SchemaTableName schemaTableName) - { - requireNonNull(schemaTableName, "schemaTableName is null"); - verify(CharMatcher.forPredicate(Character::isUpperCase).matchesNoneOf(schemaTableName.getTableName()), "Expected table name from internal metadata to be lowercase: %s", schemaTableName); 
- if (!caseInsensitiveNameMatching) { - return Optional.of(RemoteTableObject.of(schemaTableName.getTableName())); - } - - @Nullable Optional remoteTable = remoteTables.getIfPresent(schemaTableName); - if (remoteTable != null) { - return remoteTable; - } - - // Cache miss, reload the cache - Map> mapping = new HashMap<>(); - for (String table : getTables()) { - SchemaTableName cacheKey = new SchemaTableName(getSchema(), table); - mapping.merge( - cacheKey, - Optional.of(RemoteTableObject.of(table)), - (currentValue, collision) -> currentValue.map(current -> current.registerCollision(collision.get().getOnlyRemoteTableName()))); - remoteTables.put(cacheKey, mapping.get(cacheKey)); - } - - // explicitly cache if the requested table doesn't exist - if (!mapping.containsKey(schemaTableName)) { - remoteTables.put(schemaTableName, Optional.empty()); - } - - return mapping.containsKey(schemaTableName) ? mapping.get(schemaTableName) : Optional.empty(); } public URI getDruidBroker() @@ -202,7 +144,7 @@ public InputStream ingestData(DruidIngestTask ingestTask) private static Request.Builder setContentTypeHeaders(Request.Builder requestBuilder) { return requestBuilder - .setHeader(CONTENT_TYPE, JSON_UTF_8.toString()); + .setHeader(CONTENT_TYPE, "application/json"); } private static byte[] createRequestBody(String query, DruidResultFormat resultFormat, boolean queryHeader) @@ -305,46 +247,4 @@ public String toJson() return JsonCodec.jsonCodec(DruidRequestBody.class).toJson(this); } } - - static final class RemoteTableObject - { - private final Set remoteTableNames; - - private RemoteTableObject(Set remoteTableNames) - { - this.remoteTableNames = ImmutableSet.copyOf(remoteTableNames); - } - - public static RemoteTableObject of(String remoteName) - { - return new RemoteTableObject(ImmutableSet.of(remoteName)); - } - - public RemoteTableObject registerCollision(String ambiguousName) - { - return new RemoteTableObject(ImmutableSet.builderWithExpectedSize(remoteTableNames.size() + 1) 
- .addAll(remoteTableNames) - .add(ambiguousName) - .build()); - } - - public String getAnyRemoteTableName() - { - return Collections.min(remoteTableNames); - } - - public String getOnlyRemoteTableName() - { - if (!isAmbiguous()) { - return getOnlyElement(remoteTableNames); - } - - throw new PrestoException(DRUID_AMBIGUOUS_OBJECT_NAME, "Found ambiguous names in Druid when looking up '" + getAnyRemoteTableName().toLowerCase(ENGLISH) + "': " + remoteTableNames); - } - - public boolean isAmbiguous() - { - return remoteTableNames.size() > 1; - } - } } diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidConfig.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidConfig.java index a5b6a49d5785e..886c182c7e3df 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidConfig.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidConfig.java @@ -15,22 +15,20 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.configuration.ConfigSecuritySensitive; +import com.facebook.airlift.configuration.LegacyConfig; import com.google.common.base.Splitter; import com.google.common.base.StandardSystemProperty; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; +import jakarta.validation.constraints.NotNull; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import javax.annotation.Nullable; -import javax.validation.constraints.NotNull; import java.util.List; import java.util.Map; -import static java.util.concurrent.TimeUnit.MINUTES; - public class DruidConfig { private String coordinatorUrl; @@ -42,8 +40,11 @@ public class DruidConfig private String basicAuthenticationUsername; private String basicAuthenticationPassword; private String ingestionStoragePath = StandardSystemProperty.JAVA_IO_TMPDIR.value(); - private boolean 
caseInsensitiveNameMatching; - private Duration caseInsensitiveNameMatchingCacheTtl = new Duration(1, MINUTES); + private boolean caseSensitiveNameMatchingEnabled; + + private boolean tlsEnabled; + private String trustStorePath; + private String trustStorePassword; public enum DruidAuthenticationType { @@ -201,29 +202,55 @@ public DruidConfig setIngestionStoragePath(String ingestionStoragePath) return this; } - public boolean isCaseInsensitiveNameMatching() + public boolean isCaseSensitiveNameMatchingEnabled() { - return caseInsensitiveNameMatching; + return caseSensitiveNameMatchingEnabled; } - @Config("druid.case-insensitive-name-matching") - public DruidConfig setCaseInsensitiveNameMatching(boolean caseInsensitiveNameMatching) + @LegacyConfig("case-insensitive-name-matching") + @Config("case-sensitive-name-matching") + @ConfigDescription("Enable case-sensitive matching of schema, table and column names across the connector. " + + "When disabled, names are matched case-insensitively using lowercase normalization.") + public DruidConfig setCaseSensitiveNameMatchingEnabled(boolean caseSensitiveNameMatchingEnabled) { - this.caseInsensitiveNameMatching = caseInsensitiveNameMatching; + this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled; return this; } - @NotNull - @MinDuration("0ms") - public Duration getCaseInsensitiveNameMatchingCacheTtl() + public boolean isTlsEnabled() + { + return tlsEnabled; + } + + @Config("druid.tls.enabled") + public DruidConfig setTlsEnabled(boolean tlsEnabled) + { + this.tlsEnabled = tlsEnabled; + return this; + } + + public String getTrustStorePath() + { + return trustStorePath; + } + + @Config("druid.tls.truststore-path") + public DruidConfig setTrustStorePath(String path) + { + this.trustStorePath = path; + return this; + } + + public String getTrustStorePassword() { - return caseInsensitiveNameMatchingCacheTtl; + return trustStorePassword; } - @Config("druid.case-insensitive-name-matching.cache-ttl") - public 
DruidConfig setCaseInsensitiveNameMatchingCacheTtl(Duration caseInsensitiveNameMatchingCacheTtl) + @Config("druid.tls.truststore-password") + @ConfigSecuritySensitive + public DruidConfig setTrustStorePassword(String password) { - this.caseInsensitiveNameMatchingCacheTtl = caseInsensitiveNameMatchingCacheTtl; + this.trustStorePassword = password; return this; } } diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidConnector.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidConnector.java index 109bd5e826e5d..2a638746894a3 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidConnector.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidConnector.java @@ -27,8 +27,7 @@ import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidErrorCode.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidErrorCode.java index 4fae4a399ef52..07770f91be311 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidErrorCode.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidErrorCode.java @@ -28,8 +28,7 @@ public enum DruidErrorCode DRUID_UNSUPPORTED_TYPE_ERROR(3, EXTERNAL), DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION(4, EXTERNAL), DRUID_QUERY_GENERATOR_FAILURE(5, EXTERNAL), - DRUID_BROKER_RESULT_ERROR(6, EXTERNAL), - DRUID_AMBIGUOUS_OBJECT_NAME(7, EXTERNAL); + DRUID_BROKER_RESULT_ERROR(6, EXTERNAL); private final ErrorCode errorCode; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidMetadata.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidMetadata.java index 824fbf1c4e80f..dfeca1ddde175 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidMetadata.java +++ 
b/presto-druid/src/main/java/com/facebook/presto/druid/DruidMetadata.java @@ -13,8 +13,6 @@ */ package com.facebook.presto.druid; -import com.facebook.airlift.log.Logger; -import com.facebook.presto.druid.DruidClient.RemoteTableObject; import com.facebook.presto.druid.ingestion.DruidIngestionTableHandle; import com.facebook.presto.druid.metadata.DruidColumnInfo; import com.facebook.presto.druid.metadata.DruidColumnType; @@ -30,6 +28,7 @@ import com.facebook.presto.spi.ConnectorTableLayoutResult; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SchemaTablePrefix; import com.facebook.presto.spi.connector.ConnectorMetadata; @@ -38,8 +37,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; @@ -48,21 +46,25 @@ import java.util.Set; import java.util.stream.Collectors; +import static com.facebook.presto.druid.DruidTableHandle.fromSchemaTableName; +import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.lang.String.format; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; public class DruidMetadata implements ConnectorMetadata { - private static final Logger log = Logger.get(DruidMetadata.class); - private final DruidClient druidClient; + private final DruidConfig druidConfig; @Inject - public DruidMetadata(DruidClient druidClient) + public DruidMetadata(DruidClient druidClient, DruidConfig druidConfig) { this.druidClient = requireNonNull(druidClient, "druidClient is null"); + this.druidConfig = 
requireNonNull(druidConfig, "druidConfig is null"); } @Override @@ -72,25 +74,29 @@ public List listSchemaNames(ConnectorSession session) } @Override - public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName schemaTableName) + public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) { - String remoteTableName = druidClient.toRemoteTable(schemaTableName) - .map(RemoteTableObject::getOnlyRemoteTableName) - .orElse(schemaTableName.getTableName()); - + if (!normalizeIdentifier(session, druidClient.getSchema()).equals + (normalizeIdentifier(session, tableName.getSchemaName()))) { + throw new PrestoException(NOT_FOUND, format("Schema %s does not exist", tableName.getSchemaName())); + } return druidClient.getTables().stream() - .filter(name -> name.equals(remoteTableName)) - .map(name -> new DruidTableHandle(druidClient.getSchema(), remoteTableName, Optional.empty())) + .filter(name -> name.equals(tableName.getTableName())) + .map(name -> fromSchemaTableName(tableName)) .findFirst() .orElse(null); } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { DruidTableHandle handle = (DruidTableHandle) table; ConnectorTableLayout layout = new ConnectorTableLayout(new DruidTableLayoutHandle(handle, constraint.getSummary())); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -104,7 +110,7 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect { DruidTableHandle druidTable = (DruidTableHandle) tableHandle; List columns = druidClient.getColumnDataType(druidTable.getTableName()).stream() - 
.map(column -> toColumnMetadata(column)) + .map(column -> toColumnMetadata(session, column)) .collect(toImmutableList()); return new ConnectorTableMetadata(druidTable.toSchemaTableName(), columns); @@ -113,22 +119,9 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect @Override public List listTables(ConnectorSession session, Optional schemaName) { - ImmutableList.Builder tableNames = ImmutableList.builder(); - for (String table : druidClient.getTables()) { - // Ignore ambiguous tables - boolean isAmbiguous = druidClient.toRemoteTable(new SchemaTableName(druidClient.getSchema(), table)) - .filter(RemoteTableObject::isAmbiguous) - .isPresent(); - - if (!isAmbiguous) { - tableNames.add(new SchemaTableName(druidClient.getSchema(), table)); - } - else { - log.debug("Filtered out [%s.%s] from list of tables due to ambiguous name", druidClient.getSchema(), table); - } - } - - return tableNames.build(); + return druidClient.getTables().stream() + .map(tableName -> new SchemaTableName(druidClient.getSchema(), tableName)) + .collect(toImmutableList()); } @Override @@ -145,7 +138,7 @@ public Map> listTableColumns(ConnectorSess requireNonNull(prefix, "prefix is null"); ImmutableMap.Builder> columns = ImmutableMap.builder(); for (SchemaTableName tableName : listTables(session, prefix)) { - ConnectorTableMetadata tableMetadata = getTableMetadata(session, getTableHandle(session, tableName)); + ConnectorTableMetadata tableMetadata = getTableMetadata(session, fromSchemaTableName(tableName)); if (tableMetadata != null) { columns.put(tableName, tableMetadata.getColumns()); } @@ -190,18 +183,24 @@ public Optional finishCreateTable(ConnectorSession sess return Optional.empty(); } + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return druidConfig.isCaseSensitiveNameMatchingEnabled() ? 
identifier : identifier.toLowerCase(ROOT); + } + private List listTables(ConnectorSession session, SchemaTablePrefix prefix) { if (prefix.getTableName() == null) { - return listTables(session, Optional.of(prefix.getSchemaName())); + return listTables(session, Optional.ofNullable(prefix.getSchemaName())); } return ImmutableList.of(prefix.toSchemaTableName()); } - private static ColumnMetadata toColumnMetadata(DruidColumnInfo column) + private ColumnMetadata toColumnMetadata(ConnectorSession session, DruidColumnInfo column) { return ColumnMetadata.builder() - .setName(column.getColumnName()) + .setName(normalizeIdentifier(session, column.getColumnName())) .setType(column.getDataType().getPrestoType()) .build(); } diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidPageSourceProvider.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidPageSourceProvider.java index ffc710e7fb1dc..58f01cded0b1c 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidPageSourceProvider.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidPageSourceProvider.java @@ -30,13 +30,12 @@ import com.facebook.presto.spi.SplitContext; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; - import java.io.IOException; import java.util.List; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidPlanOptimizer.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidPlanOptimizer.java index 0259bbc8c2234..9b94c16d40a80 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidPlanOptimizer.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidPlanOptimizer.java @@ -34,8 +34,7 @@ 
import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.IdentityHashMap; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGenerator.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGenerator.java index 714bd5ddef09d..16101c5eeac63 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGenerator.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidQueryGenerator.java @@ -35,8 +35,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.HashSet; import java.util.LinkedHashMap; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidSessionProperties.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidSessionProperties.java index b717ff69a8005..e2213f93590e7 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidSessionProperties.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidSessionProperties.java @@ -17,8 +17,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/DruidSplitManager.java b/presto-druid/src/main/java/com/facebook/presto/druid/DruidSplitManager.java index be8b0a468e27e..b2bba614a628d 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/DruidSplitManager.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/DruidSplitManager.java @@ 
-21,8 +21,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/ForDruidClient.java b/presto-druid/src/main/java/com/facebook/presto/druid/ForDruidClient.java index 7e2161ab92565..6c22b27d53afe 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/ForDruidClient.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/ForDruidClient.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.druid; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/authentication/DruidAuthenticationModule.java b/presto-druid/src/main/java/com/facebook/presto/druid/authentication/DruidAuthenticationModule.java index 7bb42ac8e7a01..747d22befb97d 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/authentication/DruidAuthenticationModule.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/authentication/DruidAuthenticationModule.java @@ -14,6 +14,7 @@ package com.facebook.presto.druid.authentication; import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; +import com.facebook.airlift.http.client.HttpClientConfig; import com.facebook.presto.druid.DruidConfig; import com.facebook.presto.druid.ForDruidClient; import com.google.inject.Binder; @@ -26,6 +27,7 @@ import static com.facebook.presto.druid.DruidConfig.DruidAuthenticationType.BASIC; import static com.facebook.presto.druid.DruidConfig.DruidAuthenticationType.KERBEROS; import static com.facebook.presto.druid.DruidConfig.DruidAuthenticationType.NONE; +import static java.util.Objects.requireNonNull; public class DruidAuthenticationModule extends 
AbstractConfigurationAwareModule @@ -33,17 +35,19 @@ public class DruidAuthenticationModule @Override protected void setup(Binder binder) { + DruidConfig druidConfig = buildConfigObject(DruidConfig.class); + bindAuthenticationModule( config -> config.getDruidAuthenticationType() == NONE, - noneAuthenticationModule()); + noneAuthenticationModule(druidConfig)); bindAuthenticationModule( config -> config.getDruidAuthenticationType() == BASIC, - basicAuthenticationModule()); + basicAuthenticationModule(druidConfig)); bindAuthenticationModule( config -> config.getDruidAuthenticationType() == KERBEROS, - kerberosbAuthenticationModule()); + kerberosbAuthenticationModule(druidConfig)); } private void bindAuthenticationModule(Predicate predicate, Module module) @@ -51,24 +55,40 @@ private void bindAuthenticationModule(Predicate predicate, Module m install(installModuleIf(DruidConfig.class, predicate, module)); } - private static Module noneAuthenticationModule() + private static Module noneAuthenticationModule(DruidConfig druidConfig) { - return binder -> httpClientBinder(binder).bindHttpClient("druid-client", ForDruidClient.class); + return binder -> httpClientBinder(binder).bindHttpClient("druid-client", ForDruidClient.class) + .withConfigDefaults(config -> applyTlsConfig(config, druidConfig)); } - private static Module basicAuthenticationModule() + private static Module basicAuthenticationModule(DruidConfig druidConfig) { return binder -> httpClientBinder(binder).bindHttpClient("druid-client", ForDruidClient.class) .withConfigDefaults( - config -> config.setAuthenticationEnabled(false) //disable Kerberos auth + config -> { + config.setAuthenticationEnabled(false); //disable Kerberos auth + applyTlsConfig(config, druidConfig); + } ).withFilter( DruidBasicAuthHttpRequestFilter.class); } - private static Module kerberosbAuthenticationModule() + private static Module kerberosbAuthenticationModule(DruidConfig druidConfig) { return binder -> 
httpClientBinder(binder).bindHttpClient("druid-client", ForDruidClient.class) .withConfigDefaults( - config -> config.setAuthenticationEnabled(true)); + config -> { + config.setAuthenticationEnabled(true); + applyTlsConfig(config, druidConfig); + }); + } + + private static void applyTlsConfig(HttpClientConfig config, DruidConfig druidConfig) + { + if (druidConfig.isTlsEnabled()) { + requireNonNull(druidConfig.getTrustStorePath(), "druid.tls.truststore-path is null"); + config.setTrustStorePath(druidConfig.getTrustStorePath()); + config.setTrustStorePassword(druidConfig.getTrustStorePassword()); // Not adding null check for truststore password as it can be null for self-signed certs + } } } diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/authentication/DruidBasicAuthHttpRequestFilter.java b/presto-druid/src/main/java/com/facebook/presto/druid/authentication/DruidBasicAuthHttpRequestFilter.java index 34ffbcc78d534..64be14f787e67 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/authentication/DruidBasicAuthHttpRequestFilter.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/authentication/DruidBasicAuthHttpRequestFilter.java @@ -17,8 +17,7 @@ import com.facebook.airlift.http.client.HttpRequestFilter; import com.facebook.airlift.http.client.Request; import com.facebook.presto.druid.DruidConfig; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/ingestion/DruidPageSinkProvider.java b/presto-druid/src/main/java/com/facebook/presto/druid/ingestion/DruidPageSinkProvider.java index 34f32088e993c..6e0d430fd39e5 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/ingestion/DruidPageSinkProvider.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/ingestion/DruidPageSinkProvider.java @@ -22,8 +22,7 @@ import com.facebook.presto.spi.PageSinkContext; import 
com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/ingestion/DruidPageWriter.java b/presto-druid/src/main/java/com/facebook/presto/druid/ingestion/DruidPageWriter.java index b12edc35cd203..8c0e308ded35a 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/ingestion/DruidPageWriter.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/ingestion/DruidPageWriter.java @@ -21,13 +21,12 @@ import com.facebook.presto.spi.PrestoException; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonGenerator; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; - import java.io.IOException; import java.util.UUID; import java.util.zip.GZIPOutputStream; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/metadata/DruidSegmentInfo.java b/presto-druid/src/main/java/com/facebook/presto/druid/metadata/DruidSegmentInfo.java index ca38860637b1c..f0126630fd1a3 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/metadata/DruidSegmentInfo.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/metadata/DruidSegmentInfo.java @@ -16,8 +16,7 @@ import com.facebook.presto.spi.PrestoException; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.io.UnsupportedEncodingException; import java.net.URI; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/zip/ZipFileData.java 
b/presto-druid/src/main/java/com/facebook/presto/druid/zip/ZipFileData.java index d7b1de996ba58..2ee4ceb5ffbcb 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/zip/ZipFileData.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/zip/ZipFileData.java @@ -28,8 +28,7 @@ import com.facebook.presto.druid.DataInputSource; import com.facebook.presto.spi.PrestoException; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.io.IOException; import java.nio.charset.Charset; diff --git a/presto-druid/src/main/java/com/facebook/presto/druid/zip/ZipFileEntry.java b/presto-druid/src/main/java/com/facebook/presto/druid/zip/ZipFileEntry.java index 4242d14d6619a..0ccebe0b562d7 100644 --- a/presto-druid/src/main/java/com/facebook/presto/druid/zip/ZipFileEntry.java +++ b/presto-druid/src/main/java/com/facebook/presto/druid/zip/ZipFileEntry.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.druid.zip; -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.EnumSet; diff --git a/presto-druid/src/test/java/com/facebook/presto/druid/DruidQueryRunner.java b/presto-druid/src/test/java/com/facebook/presto/druid/DruidQueryRunner.java index 35a05f00c5a6a..1793da7cbbd4d 100644 --- a/presto-druid/src/test/java/com/facebook/presto/druid/DruidQueryRunner.java +++ b/presto-druid/src/test/java/com/facebook/presto/druid/DruidQueryRunner.java @@ -18,7 +18,6 @@ import com.facebook.presto.tests.DistributedQueryRunner; import com.google.common.collect.ImmutableMap; -import java.util.HashMap; import java.util.Map; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -36,17 +35,17 @@ private DruidQueryRunner() {} private static String broker = "http://localhost:8082"; private static String coordinator = "http://localhost:8081"; - public static DistributedQueryRunner createDruidQueryRunner(Map connectorProperties) + public static DistributedQueryRunner createDruidQueryRunner() throws 
Exception { DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(createSession()).build(); try { queryRunner.installPlugin(new DruidPlugin()); - connectorProperties = new HashMap<>(ImmutableMap.copyOf(connectorProperties)); - connectorProperties.putIfAbsent("druid.coordinator-url", coordinator); - connectorProperties.putIfAbsent("druid.broker-url", broker); - - queryRunner.createCatalog(DEFAULT_CATALOG, "druid", connectorProperties); + Map properties = ImmutableMap.builder() + .put("druid.coordinator-url", coordinator) + .put("druid.broker-url", broker) + .build(); + queryRunner.createCatalog(DEFAULT_CATALOG, "druid", properties); return queryRunner; } catch (Exception e) { @@ -67,7 +66,7 @@ public static Session createSession() public static void main(String[] args) throws Exception { - DistributedQueryRunner queryRunner = createDruidQueryRunner(ImmutableMap.of()); + DistributedQueryRunner queryRunner = createDruidQueryRunner(); log.info(format("Presto server started: %s", queryRunner.getCoordinator().getBaseUrl())); } } diff --git a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidConfig.java b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidConfig.java index 5b5b90043ab32..aedbfcaeb1819 100644 --- a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidConfig.java +++ b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidConfig.java @@ -17,7 +17,6 @@ import com.google.common.base.StandardSystemProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; @@ -26,8 +25,6 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; import static com.facebook.presto.druid.DruidConfig.DruidAuthenticationType.BASIC; import static com.facebook.presto.druid.DruidConfig.DruidAuthenticationType.NONE; -import static 
java.util.concurrent.TimeUnit.MINUTES; -import static java.util.concurrent.TimeUnit.SECONDS; public class TestDruidConfig { @@ -44,16 +41,18 @@ public void testDefaults() .setBasicAuthenticationUsername(null) .setBasicAuthenticationPassword(null) .setIngestionStoragePath(StandardSystemProperty.JAVA_IO_TMPDIR.value()) - .setCaseInsensitiveNameMatching(false) - .setCaseInsensitiveNameMatchingCacheTtl(new Duration(1, MINUTES))); + .setCaseSensitiveNameMatchingEnabled(false) + .setTlsEnabled(false) + .setTrustStorePath(null) + .setTrustStorePassword(null)); } @Test public void testExplicitPropertyMappings() { Map properties = new ImmutableMap.Builder() - .put("druid.broker-url", "http://druid.broker:1234") - .put("druid.coordinator-url", "http://druid.coordinator:4321") + .put("druid.broker-url", "https://druid.broker:1234") + .put("druid.coordinator-url", "https://druid.coordinator:4321") .put("druid.schema-name", "test") .put("druid.compute-pushdown-enabled", "true") .put("druid.hadoop.config.resources", "/etc/core-site.xml,/etc/hdfs-site.xml") @@ -61,13 +60,15 @@ public void testExplicitPropertyMappings() .put("druid.basic.authentication.username", "http_basic_username") .put("druid.basic.authentication.password", "http_basic_password") .put("druid.ingestion.storage.path", "hdfs://foo/bar/") - .put("druid.case-insensitive-name-matching", "true") - .put("druid.case-insensitive-name-matching.cache-ttl", "1s") + .put("case-sensitive-name-matching", "true") + .put("druid.tls.enabled", "true") + .put("druid.tls.truststore-path", "/tmp/truststore") + .put("druid.tls.truststore-password", "truststore-password") .build(); DruidConfig expected = new DruidConfig() - .setDruidBrokerUrl("http://druid.broker:1234") - .setDruidCoordinatorUrl("http://druid.coordinator:4321") + .setDruidBrokerUrl("https://druid.broker:1234") + .setDruidCoordinatorUrl("https://druid.coordinator:4321") .setDruidSchema("test") .setComputePushdownEnabled(true) 
.setHadoopConfiguration(ImmutableList.of("/etc/core-site.xml", "/etc/hdfs-site.xml")) @@ -75,8 +76,10 @@ public void testExplicitPropertyMappings() .setBasicAuthenticationUsername("http_basic_username") .setBasicAuthenticationPassword("http_basic_password") .setIngestionStoragePath("hdfs://foo/bar/") - .setCaseInsensitiveNameMatching(true) - .setCaseInsensitiveNameMatchingCacheTtl(new Duration(1, SECONDS)); + .setCaseSensitiveNameMatchingEnabled(true) + .setTlsEnabled(true) + .setTrustStorePath(("/tmp/truststore")) + .setTrustStorePassword("truststore-password"); ConfigAssertions.assertFullMapping(properties, expected); } diff --git a/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidPageSourceNullHandling.java b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidPageSourceNullHandling.java new file mode 100644 index 0000000000000..8be69063be38a --- /dev/null +++ b/presto-druid/src/test/java/com/facebook/presto/druid/TestDruidPageSourceNullHandling.java @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.druid; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.HttpStatus; +import com.facebook.airlift.http.client.testing.TestingHttpClient; +import com.facebook.airlift.http.client.testing.TestingResponse; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.spi.ColumnHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import org.testng.annotations.Test; + +import java.nio.charset.StandardCharsets; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static org.testng.Assert.assertTrue; + +public class TestDruidPageSourceNullHandling +{ + @Test + public void testNullAndMissingColumns() + { + String jsonRows = + "{\"region.Id\":1,\"city\":\"Boston\",\"fare\":10.0}\n" + + "{\"region.Id\":2,\"city\":null,\"fare\":20.0}\n" + // city column is having null value + "{\"region.Id\":3,\"fare\":30.0}\n" + // missing city column + "\n"; + + ListMultimap headers = ImmutableListMultimap.of( + "Content-Type", "application/json"); + TestingResponse response = new TestingResponse( + HttpStatus.OK, + headers, + jsonRows.getBytes(StandardCharsets.UTF_8)); + HttpClient httpClient = new TestingHttpClient(request -> response); + + DruidConfig druidConfig = new DruidConfig() + .setDruidSchema("default") + .setDruidCoordinatorUrl("http://localhost:8081") + .setDruidBrokerUrl("http://localhost:8082"); + + ImmutableList columnHandles = ImmutableList.of(new DruidColumnHandle("region.Id", BIGINT), + new DruidColumnHandle("city", VARCHAR), + new DruidColumnHandle("fare", DOUBLE)); + + DruidBrokerPageSource pageSource = new DruidBrokerPageSource( + new 
DruidQueryGenerator.GeneratedDql("testTable", "SELECT region.Id, city, fare FROM test", true), + columnHandles, + new DruidClient(druidConfig, httpClient)); + + Page page; + boolean foundNull = false; + boolean foundMissing = false; + + while ((page = pageSource.getNextPage()) != null) { + Block cityBlock = page.getBlock(1); + for (int i = 0; i < cityBlock.getPositionCount(); i++) { + if (cityBlock.isNull(i)) { + if (i == 1) { + foundNull = true; // row with "city":null + } + if (i == 2) { + foundMissing = true; // row missing "city" + } + } + } + } + + assertTrue(foundNull, "Expected null value in column 'city'"); + assertTrue(foundMissing, "Expected missing column to be treated as null"); + } +} diff --git a/presto-elasticsearch/pom.xml b/presto-elasticsearch/pom.xml index 23087f7310adc..31c776a5f9092 100644 --- a/presto-elasticsearch/pom.xml +++ b/presto-elasticsearch/pom.xml @@ -4,16 +4,19 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-elasticsearch + presto-elasticsearch Presto - Elasticsearch Connector presto-plugin ${project.parent.basedir} 7.17.27 - 2.24.3 + 2.25.3 + true + 8.11.3 @@ -69,18 +72,18 @@ - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -191,7 +194,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -203,7 +206,7 @@ - io.airlift + com.facebook.airlift units provided @@ -223,7 +226,6 @@ org.jetbrains annotations - 19.0.0 test @@ -328,8 +330,8 @@ - javax.servlet - javax.servlet-api + jakarta.servlet + jakarta.servlet-api test @@ -383,6 +385,9 @@ org.apache.maven.plugins maven-dependency-plugin + + org.elasticsearch:elasticsearch-x-content + org.yaml:snakeyaml:jar diff --git a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/AwsSecurityConfig.java 
b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/AwsSecurityConfig.java index ebe350f8c5347..ad4dc9de3599f 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/AwsSecurityConfig.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/AwsSecurityConfig.java @@ -14,8 +14,7 @@ package com.facebook.presto.elasticsearch; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.util.Optional; diff --git a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConfig.java b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConfig.java index 43cfd6ba33d8a..6521a3644abc8 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConfig.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConfig.java @@ -16,11 +16,10 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.ConfigSecuritySensitive; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.io.File; import java.util.Optional; @@ -57,6 +56,7 @@ public enum Security private boolean ignorePublishAddress; private boolean verifyHostnames = true; private Security security; + private boolean caseSensitiveNameMatching; @NotNull public String getHost() @@ -324,4 +324,17 @@ public ElasticsearchConfig setSecurity(Security security) this.security = security; return this; } + + public boolean isCaseSensitiveNameMatching() + { 
+ return caseSensitiveNameMatching; + } + + @Config("case-sensitive-name-matching") + @ConfigDescription("Enable case-sensitive matching. When disabled, names are matched case-insensitively using lowercase normalization.") + public ElasticsearchConfig setCaseSensitiveNameMatching(boolean caseSensitiveNameMatching) + { + this.caseSensitiveNameMatching = caseSensitiveNameMatching; + return this; + } } diff --git a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConnector.java b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConnector.java index 00666829bddf7..6c4ba99c8ccad 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConnector.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConnector.java @@ -23,8 +23,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Set; diff --git a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConnectorModule.java b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConnectorModule.java index 4cd973f38793e..8358adffa92d7 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConnectorModule.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchConnectorModule.java @@ -22,8 +22,7 @@ import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer; import com.google.inject.Binder; import com.google.inject.Scopes; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.airlift.configuration.ConditionalModule.installModuleIf; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; 
diff --git a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchMetadata.java b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchMetadata.java index 6217ce9ce1e48..88e1e4eeba006 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchMetadata.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchMetadata.java @@ -45,8 +45,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.BaseEncoding; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -87,6 +86,7 @@ public class ElasticsearchMetadata private final ElasticsearchClient client; private final String schemaName; private final Type ipAddressType; + private final boolean caseSensitiveNameMatching; @Inject public ElasticsearchMetadata(TypeManager typeManager, ElasticsearchClient client, ElasticsearchConfig config) @@ -96,6 +96,7 @@ public ElasticsearchMetadata(TypeManager typeManager, ElasticsearchClient client this.client = requireNonNull(client, "client is null"); requireNonNull(config, "config is null"); this.schemaName = config.getDefaultSchema(); + this.caseSensitiveNameMatching = config.isCaseSensitiveNameMatching(); Type jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON)); queryResultColumnMetadata = ColumnMetadata.builder() @@ -159,11 +160,15 @@ public ElasticsearchTableHandle getTableHandle(ConnectorSession session, SchemaT } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { ElasticsearchTableHandle handle = (ElasticsearchTableHandle) table; 
ConnectorTableLayout layout = new ConnectorTableLayout(new ElasticsearchTableLayoutHandle(handle, constraint.getSummary())); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -182,39 +187,39 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect new SchemaTableName(handle.getSchema(), handle.getIndex()), ImmutableList.of(queryResultColumnMetadata)); } - return getTableMetadata(handle.getSchema(), handle.getIndex()); + return getTableMetadata(session, handle.getSchema(), handle.getIndex()); } - private ConnectorTableMetadata getTableMetadata(String schemaName, String tableName) + private ConnectorTableMetadata getTableMetadata(ConnectorSession session, String schemaName, String tableName) { - InternalTableMetadata internalTableMetadata = makeInternalTableMetadata(schemaName, tableName); + InternalTableMetadata internalTableMetadata = makeInternalTableMetadata(session, schemaName, tableName); return new ConnectorTableMetadata(new SchemaTableName(schemaName, tableName), internalTableMetadata.getColumnMetadata()); } - private InternalTableMetadata makeInternalTableMetadata(ConnectorTableHandle table) + private InternalTableMetadata makeInternalTableMetadata(ConnectorSession session, ConnectorTableHandle table) { ElasticsearchTableHandle handle = (ElasticsearchTableHandle) table; - return makeInternalTableMetadata(handle.getSchema(), handle.getIndex()); + return makeInternalTableMetadata(session, handle.getSchema(), handle.getIndex()); } - private InternalTableMetadata makeInternalTableMetadata(String schema, String tableName) + private InternalTableMetadata makeInternalTableMetadata(ConnectorSession session, String schema, String tableName) { IndexMetadata metadata = client.getIndexMetadata(tableName); - List fields = getColumnFields(metadata); - return new InternalTableMetadata(new SchemaTableName(schema, 
tableName), makeColumnMetadata(fields), makeColumnHandles(fields)); + List fields = getColumnFields(session, metadata); + return new InternalTableMetadata(new SchemaTableName(schema, tableName), makeColumnMetadata(session, fields), makeColumnHandles(fields)); } - private List getColumnFields(IndexMetadata metadata) + private List getColumnFields(ConnectorSession session, IndexMetadata metadata) { ImmutableList.Builder result = ImmutableList.builder(); Map counts = metadata.getSchema() .getFields().stream() - .collect(Collectors.groupingBy(f -> f.getName().toLowerCase(ENGLISH), Collectors.counting())); + .collect(Collectors.groupingBy(f -> normalizeIdentifier(session, f.getName()), Collectors.counting())); for (IndexMetadata.Field field : metadata.getSchema().getFields()) { Type type = toPrestoType(field); - if (type == null || counts.get(field.getName().toLowerCase(ENGLISH)) > 1) { + if (type == null || counts.get(normalizeIdentifier(session, field.getName())) > 1) { continue; } result.add(field); @@ -222,7 +227,7 @@ private List getColumnFields(IndexMetadata metadata) return result.build(); } - private List makeColumnMetadata(List fields) + private List makeColumnMetadata(ConnectorSession session, List fields) { ImmutableList.Builder result = ImmutableList.builder(); @@ -231,7 +236,7 @@ private List makeColumnMetadata(List fields } for (IndexMetadata.Field field : fields) { - result.add(ColumnMetadata.builder().setName(field.getName()).setType(toPrestoType(field)).build()); + result.add(ColumnMetadata.builder().setName(normalizeIdentifier(session, field.getName())).setType(toPrestoType(field)).build()); } return result.build(); } @@ -374,7 +379,7 @@ public Map getColumnHandles(ConnectorSession session, Conn return queryTableColumns; } - InternalTableMetadata tableMetadata = makeInternalTableMetadata(tableHandle); + InternalTableMetadata tableMetadata = makeInternalTableMetadata(session, tableHandle); return tableMetadata.getColumnHandles(); } @@ -411,15 +416,21 @@ 
public Map> listTableColumns(ConnectorSess } if (prefix.getSchemaName() != null && prefix.getTableName() != null) { - ConnectorTableMetadata metadata = getTableMetadata(prefix.getSchemaName(), prefix.getTableName()); + ConnectorTableMetadata metadata = getTableMetadata(session, prefix.getSchemaName(), prefix.getTableName()); return ImmutableMap.of(metadata.getTable(), metadata.getColumns()); } return listTables(session, prefix.getSchemaName()).stream() - .map(name -> getTableMetadata(name.getSchemaName(), name.getTableName())) + .map(name -> getTableMetadata(session, name.getSchemaName(), name.getTableName())) .collect(toImmutableMap(ConnectorTableMetadata::getTable, ConnectorTableMetadata::getColumns)); } + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return caseSensitiveNameMatching ? identifier : identifier.toLowerCase(ENGLISH); + } + private static class InternalTableMetadata { private final SchemaTableName tableName; diff --git a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchPageSourceProvider.java b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchPageSourceProvider.java index c1b67b4178712..e59a11feb73d7 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchPageSourceProvider.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchPageSourceProvider.java @@ -27,8 +27,7 @@ import com.facebook.presto.spi.SplitContext; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchSplitManager.java b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchSplitManager.java index 
be0c5d48ddb1b..4c4e1a0caa0a6 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchSplitManager.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/ElasticsearchSplitManager.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/NodesSystemTable.java b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/NodesSystemTable.java index 42265df65512c..69f94bda2b35f 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/NodesSystemTable.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/NodesSystemTable.java @@ -29,8 +29,7 @@ import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Set; diff --git a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/PasswordConfig.java b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/PasswordConfig.java index d2def0d8a3e40..159b4f304b4f6 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/PasswordConfig.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/PasswordConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigSecuritySensitive; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class PasswordConfig { diff --git 
a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/client/ElasticsearchClient.java b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/client/ElasticsearchClient.java index 42a0e4a69680b..fef600e26e725 100644 --- a/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/client/ElasticsearchClient.java +++ b/presto-elasticsearch/src/main/java/com/facebook/presto/elasticsearch/client/ElasticsearchClient.java @@ -22,6 +22,7 @@ import com.facebook.airlift.json.JsonObjectMapperProvider; import com.facebook.airlift.log.Logger; import com.facebook.airlift.security.pem.PemReader; +import com.facebook.airlift.units.Duration; import com.facebook.presto.elasticsearch.AwsSecurityConfig; import com.facebook.presto.elasticsearch.ElasticsearchConfig; import com.facebook.presto.elasticsearch.PasswordConfig; @@ -33,7 +34,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.apache.http.HttpEntity; import org.apache.http.HttpHost; import org.apache.http.auth.AuthScope; @@ -62,9 +65,6 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; diff --git a/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/ElasticsearchLoader.java b/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/ElasticsearchLoader.java index 5a0e736c70c44..3c7d62d7a7567 100644 --- a/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/ElasticsearchLoader.java +++ 
b/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/ElasticsearchLoader.java @@ -119,7 +119,7 @@ public void addResults(QueryStatusInfo statusInfo, QueryData data) } @Override - public Void build(Map setSessionProperties, Set resetSessionProperties) + public Void build(Map setSessionProperties, Set resetSessionProperties, String startTransactionId, boolean clearTransactionId) { return null; } diff --git a/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/ElasticsearchQueryRunner.java b/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/ElasticsearchQueryRunner.java index 4619ba865073d..f59a061f08c6a 100644 --- a/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/ElasticsearchQueryRunner.java +++ b/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/ElasticsearchQueryRunner.java @@ -31,9 +31,9 @@ import java.util.Map; import static com.facebook.airlift.testing.Closeables.closeAllSuppress; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; -import static io.airlift.units.Duration.nanosSince; import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static java.util.concurrent.TimeUnit.SECONDS; @@ -95,21 +95,23 @@ private static void installElasticsearchPlugin( Map extraConnectorProperties) { queryRunner.installPlugin(new ElasticsearchPlugin(factory)); - Map config = ImmutableMap.builder() + ImmutableMap.Builder config = ImmutableMap.builder() .put("elasticsearch.host", address.getHost()) .put("elasticsearch.port", Integer.toString(address.getPort())) // Node discovery relies on the publish_address exposed via the Elasticseach API // This doesn't work well within a docker environment that maps ES's port to a random public port .put("elasticsearch.ignore-publish-address", "true") - 
.put("elasticsearch.default-schema-name", TPCH_SCHEMA) .put("elasticsearch.scroll-size", "1000") .put("elasticsearch.scroll-timeout", "1m") .put("elasticsearch.max-hits", "1000000") .put("elasticsearch.request-timeout", "2m") - .putAll(extraConnectorProperties) - .build(); + .putAll(extraConnectorProperties); + if (!extraConnectorProperties.containsKey("elasticsearch.default-schema-name")) { + config.put("elasticsearch.default-schema-name", TPCH_SCHEMA); + } + Map newconfig = config.build(); - queryRunner.createCatalog("elasticsearch", "elasticsearch", config); + queryRunner.createCatalog("elasticsearch", "elasticsearch", newconfig); } private static void loadTpchTopic(RestHighLevelClient client, TestingPrestoClient prestoClient, TpchTable table) diff --git a/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/TestElasticsearchConfig.java b/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/TestElasticsearchConfig.java index 0b4acfc9edf43..8d50ea21b94bc 100644 --- a/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/TestElasticsearchConfig.java +++ b/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/TestElasticsearchConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.elasticsearch; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.io.File; @@ -52,7 +52,8 @@ public void testDefaults() .setTruststorePassword(null) .setVerifyHostnames(true) .setIgnorePublishAddress(false) - .setSecurity(null)); + .setSecurity(null) + .setCaseSensitiveNameMatching(false)); } @Test @@ -79,6 +80,7 @@ public void testExplicitPropertyMappings() .put("elasticsearch.tls.verify-hostnames", "false") .put("elasticsearch.ignore-publish-address", "true") .put("elasticsearch.security", "AWS") + .put("case-sensitive-name-matching", "true") .build(); ElasticsearchConfig expected = new 
ElasticsearchConfig() @@ -101,7 +103,8 @@ public void testExplicitPropertyMappings() .setTruststorePassword("truststore-password") .setVerifyHostnames(false) .setIgnorePublishAddress(true) - .setSecurity(AWS); + .setSecurity(AWS) + .setCaseSensitiveNameMatching(true); assertFullMapping(properties, expected); } diff --git a/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/TestElasticsearchMixedCaseTest.java b/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/TestElasticsearchMixedCaseTest.java new file mode 100644 index 0000000000000..b877af191bc36 --- /dev/null +++ b/presto-elasticsearch/src/test/java/com/facebook/presto/elasticsearch/TestElasticsearchMixedCaseTest.java @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.elasticsearch; + +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import com.google.common.net.HostAndPort; +import io.airlift.tpch.TpchTable; +import org.apache.http.HttpHost; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.client.RestClient; +import org.elasticsearch.client.RestHighLevelClient; +import org.testng.annotations.AfterClass; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.util.Map; + +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.elasticsearch.ElasticsearchQueryRunner.createElasticsearchQueryRunner; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.tests.QueryAssertions.assertContains; +import static org.elasticsearch.client.RequestOptions.DEFAULT; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +@Test +public class TestElasticsearchMixedCaseTest + extends AbstractTestQueryFramework +{ + private final String elasticsearchServer = "docker.elastic.co/elasticsearch/elasticsearch:7.17.27"; + private ElasticsearchServer elasticsearch; + private RestHighLevelClient client; + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + elasticsearch = new ElasticsearchServer(elasticsearchServer, ImmutableMap.of()); + HostAndPort address = elasticsearch.getAddress(); + client = new RestHighLevelClient(RestClient.builder(new HttpHost(address.getHost(), address.getPort()))); + + return createElasticsearchQueryRunner(elasticsearch.getAddress(), + TpchTable.getTables(), + ImmutableMap.of(), + ImmutableMap.of("case-sensitive-name-matching", "true", 
"elasticsearch.default-schema-name", "MySchema")); + } + + @AfterClass(alwaysRun = true) + public final void destroy() + throws IOException + { + elasticsearch.stop(); + client.close(); + } + private void index(String index, Map document) + throws IOException + { + client.index(new IndexRequest(index, "_doc") + .source(document) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), DEFAULT); + } + + @Test + public void testShowColumns() + throws IOException + { + String indexName = "mixed_case"; + index(indexName, ImmutableMap.builder() + .put("NAME", "JOHN") + .put("Profession", "Developer") + .put("id", 2) + .put("name", "john") + .build()); + + MaterializedResult actual = computeActual("SHOW columns FROM MySchema.mixed_case"); + assertEquals(actual.getMaterializedRows().get(0).getField(0), "NAME"); + assertEquals(actual.getMaterializedRows().get(1).getField(0), "Profession"); + assertEquals(actual.getMaterializedRows().get(2).getField(0), "id"); + assertEquals(actual.getMaterializedRows().get(3).getField(0), "name"); + } + + @Test + public void testSelect() + throws IOException + { + String indexName = "mixed_case_select"; + index(indexName, ImmutableMap.builder() + .put("NAME", "JOHN") + .put("Profession", "Developer") + .put("name", "john") + .build()); + + MaterializedResult actualRow = computeActual("SELECT * from MySchema.mixed_case_select"); + MaterializedResult expectedRow = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR) + .row("JOHN", "Developer", "john") + .build(); + assertTrue(actualRow.equals(expectedRow)); + } + + @Test + public void testSchema() + { + MaterializedResult actualRow = computeActual("SHOW schemas from elasticsearch"); + MaterializedResult expectedRow = resultBuilder(getSession(), VARCHAR) + .row("MySchema") + .build(); + assertContains(actualRow, expectedRow); + } +} diff --git a/presto-example-http/pom.xml b/presto-example-http/pom.xml index a2c847a6abf87..dbd812da0e411 100644 --- a/presto-example-http/pom.xml +++ 
b/presto-example-http/pom.xml @@ -4,15 +4,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-example-http + presto-example-http Presto - Example HTTP Connector presto-plugin ${project.parent.basedir} + true @@ -47,8 +49,13 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api + + + + jakarta.inject + jakarta.inject-api @@ -75,7 +82,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -87,7 +94,7 @@ - io.airlift + com.facebook.airlift units provided @@ -136,9 +143,23 @@ - javax.servlet - javax.servlet-api + jakarta.servlet + jakarta.servlet-api test + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + javax.inject:javax.inject + + + + + diff --git a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleClient.java b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleClient.java index 07c7b1bc1367d..f4d9bbcd15c6d 100644 --- a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleClient.java +++ b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleClient.java @@ -21,8 +21,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.Resources; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.IOException; import java.io.UncheckedIOException; diff --git a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleConfig.java b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleConfig.java index 7ef582d3195a0..49730aa053385 100644 --- a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleConfig.java +++ b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleConfig.java @@ -14,8 +14,7 @@ package com.facebook.presto.example; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import 
jakarta.validation.constraints.NotNull; import java.net.URI; diff --git a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleConnector.java b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleConnector.java index 64bf524813eb4..3e539575cb997 100644 --- a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleConnector.java +++ b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleConnector.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.example.ExampleTransactionHandle.INSTANCE; import static java.util.Objects.requireNonNull; diff --git a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleMetadata.java b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleMetadata.java index b67bbcdde8afa..9b6f4f2c504b9 100644 --- a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleMetadata.java +++ b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleMetadata.java @@ -29,8 +29,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -38,6 +37,7 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class ExampleMetadata @@ -81,11 +81,15 @@ public ExampleTableHandle getTableHandle(ConnectorSession session, SchemaTableNa } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint 
constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { ExampleTableHandle tableHandle = (ExampleTableHandle) table; ConnectorTableLayout layout = new ConnectorTableLayout(new ExampleTableLayoutHandle(tableHandle)); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -101,7 +105,7 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect checkArgument(exampleTableHandle.getConnectorId().equals(connectorId), "tableHandle is not for this connector"); SchemaTableName tableName = new SchemaTableName(exampleTableHandle.getSchemaName(), exampleTableHandle.getTableName()); - return getTableMetadata(tableName); + return getTableMetadata(session, tableName); } @Override @@ -137,20 +141,25 @@ public Map getColumnHandles(ConnectorSession session, Conn ImmutableMap.Builder columnHandles = ImmutableMap.builder(); int index = 0; - for (ColumnMetadata column : table.getColumnsMetadata()) { + List columns = table.getColumnsMetadata().stream() + .map(column -> column.toBuilder() + .setName(normalizeIdentifier(session, column.getName())) + .build()) + .collect(toImmutableList()); + + for (ColumnMetadata column : columns) { columnHandles.put(column.getName(), new ExampleColumnHandle(connectorId, column.getName(), column.getType(), index)); index++; } return columnHandles.build(); } - @Override public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) { requireNonNull(prefix, "prefix is null"); ImmutableMap.Builder> columns = ImmutableMap.builder(); for (SchemaTableName tableName : listTables(session, prefix)) { - ConnectorTableMetadata tableMetadata = getTableMetadata(tableName); + ConnectorTableMetadata tableMetadata = getTableMetadata(session, 
tableName); // table can disappear during listing operation if (tableMetadata != null) { columns.put(tableName, tableMetadata.getColumns()); @@ -159,7 +168,7 @@ public Map> listTableColumns(ConnectorSess return columns.build(); } - private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) + private ConnectorTableMetadata getTableMetadata(ConnectorSession session, SchemaTableName tableName) { if (!listSchemaNames().contains(tableName.getSchemaName())) { return null; @@ -169,8 +178,13 @@ private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) if (table == null) { return null; } + List columns = table.getColumnsMetadata().stream() + .map(column -> column.toBuilder() + .setName(normalizeIdentifier(session, column.getName())) + .build()) + .collect(toImmutableList()); - return new ConnectorTableMetadata(tableName, table.getColumnsMetadata()); + return new ConnectorTableMetadata(tableName, columns); } private List listTables(ConnectorSession session, SchemaTablePrefix prefix) diff --git a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleModule.java b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleModule.java index 0b8409342203c..dae71956efec1 100644 --- a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleModule.java +++ b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleModule.java @@ -20,8 +20,7 @@ import com.google.inject.Binder; import com.google.inject.Module; import com.google.inject.Scopes; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; import static com.facebook.airlift.json.JsonBinder.jsonBinder; diff --git a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleRecordSetProvider.java b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleRecordSetProvider.java index 7c4b3e8ab54da..f2b2f5c3c99b5 100644 --- 
a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleRecordSetProvider.java +++ b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleRecordSetProvider.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleSplitManager.java b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleSplitManager.java index d0b4142c34259..ec7a51f2d80d1 100644 --- a/presto-example-http/src/main/java/com/facebook/presto/example/ExampleSplitManager.java +++ b/presto-example-http/src/main/java/com/facebook/presto/example/ExampleSplitManager.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.FixedSplitSource; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.net.URI; import java.util.ArrayList; diff --git a/presto-example-http/src/test/java/com/facebook/presto/example/ExampleHttpServer.java b/presto-example-http/src/test/java/com/facebook/presto/example/ExampleHttpServer.java index 62c47f670d74e..679464ceec072 100644 --- a/presto-example-http/src/test/java/com/facebook/presto/example/ExampleHttpServer.java +++ b/presto-example-http/src/test/java/com/facebook/presto/example/ExampleHttpServer.java @@ -25,11 +25,10 @@ import com.google.inject.Injector; import com.google.inject.Module; import com.google.inject.TypeLiteral; - -import javax.servlet.Servlet; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServlet; +import 
jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; import java.net.URI; diff --git a/presto-expressions/pom.xml b/presto-expressions/pom.xml index 22c0b8203ecd5..ada4f99dc69c2 100644 --- a/presto-expressions/pom.xml +++ b/presto-expressions/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-expressions @@ -13,6 +13,8 @@ ${project.parent.basedir} + 8 + true diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/DynamicFilters.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/DynamicFilters.java index fce21bb6f52e4..fec5ea16fd44f 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/DynamicFilters.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/DynamicFilters.java @@ -36,10 +36,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import static com.facebook.presto.common.function.OperatorType.EQUAL; -import static com.facebook.presto.common.function.OperatorType.GREATER_THAN; -import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL; -import static com.facebook.presto.common.function.OperatorType.LESS_THAN; -import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL; import static com.facebook.presto.common.type.StandardTypes.BOOLEAN; import static com.facebook.presto.common.type.StandardTypes.VARCHAR; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; diff --git a/presto-file-session-property-manager/pom.xml b/presto-file-session-property-manager/pom.xml new file mode 100644 index 0000000000000..43de710ff62e4 --- /dev/null +++ b/presto-file-session-property-manager/pom.xml @@ -0,0 +1,139 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.297-edge10.1-SNAPSHOT + + + presto-file-session-property-manager + presto-file-session-property-manager + Presto - File Session 
Property Manager + presto-plugin + + + ${project.parent.basedir} + + + + + com.facebook.presto + presto-session-property-managers-common + + + + com.facebook.airlift + bootstrap + + + + com.facebook.airlift + json + + + + com.facebook.airlift + configuration + + + + com.google.guava + guava + + + + com.google.inject + guice + + + + javax.inject + javax.inject + + + + jakarta.inject + jakarta.inject-api + + + + jakarta.validation + jakarta.validation-api + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.core + jackson-core + + + + + com.facebook.presto + presto-spi + provided + + + + com.facebook.presto + presto-common + provided + + + + com.facebook.airlift + units + provided + + + + io.airlift + slice + provided + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + org.openjdk.jol + jol-core + provided + + + + + com.facebook.presto + presto-testng-services + test + + + + org.testng + testng + test + + + + com.facebook.airlift + testing + test + + + + com.facebook.presto + presto-session-property-managers-common + test-jar + test + + + diff --git a/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyConfigurationManagerPlugin.java b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyConfigurationManagerPlugin.java new file mode 100644 index 0000000000000..573c62295572c --- /dev/null +++ b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyConfigurationManagerPlugin.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.session.file; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.session.SessionPropertyConfigurationManagerFactory; +import com.google.common.collect.ImmutableList; + +public class FileSessionPropertyConfigurationManagerPlugin + implements Plugin +{ + @Override + public Iterable getSessionPropertyConfigurationManagerFactories() + { + return ImmutableList.of( + new FileSessionPropertyManagerFactory()); + } +} diff --git a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManager.java b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManager.java similarity index 59% rename from presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManager.java rename to presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManager.java index d6402324050bf..6f559a00f2b91 100644 --- a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManager.java +++ b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManager.java @@ -11,41 +11,36 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.session; +package com.facebook.presto.session.file; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.JsonCodecFactory; import com.facebook.airlift.json.JsonObjectMapperProvider; -import com.facebook.presto.spi.session.SessionConfigurationContext; -import com.facebook.presto.spi.session.SessionPropertyConfigurationManager; +import com.facebook.presto.session.AbstractSessionPropertyManager; +import com.facebook.presto.session.SessionMatchSpec; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.exc.UnrecognizedPropertyException; -import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import java.io.IOException; import java.io.UncheckedIOException; import java.nio.file.Files; import java.nio.file.Path; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; -import java.util.Map; -import java.util.Set; import static com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES; import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class FileSessionPropertyManager - implements SessionPropertyConfigurationManager + extends AbstractSessionPropertyManager { public static final JsonCodec> CODEC = new JsonCodecFactory( () -> new JsonObjectMapperProvider().get().enable(FAIL_ON_UNKNOWN_PROPERTIES)) .listJsonCodec(SessionMatchSpec.class); - private final List sessionMatchSpecs; + private final ImmutableList sessionMatchSpecs; @Inject public FileSessionPropertyManager(FileSessionPropertyManagerConfig config) @@ -54,7 +49,7 @@ public FileSessionPropertyManager(FileSessionPropertyManagerConfig config) Path configurationFile = config.getConfigFile().toPath(); try { - sessionMatchSpecs = CODEC.fromJson(Files.readAllBytes(configurationFile)); + sessionMatchSpecs = 
ImmutableList.copyOf(CODEC.fromJson(Files.readAllBytes(configurationFile))); } catch (IOException e) { throw new UncheckedIOException(e); @@ -81,31 +76,8 @@ public FileSessionPropertyManager(FileSessionPropertyManagerConfig config) } @Override - public SystemSessionPropertyConfiguration getSystemSessionProperties(SessionConfigurationContext context) - { - // later properties override earlier properties - Map defaultProperties = new HashMap<>(); - Set overridePropertyNames = new HashSet<>(); - for (SessionMatchSpec sessionMatchSpec : sessionMatchSpecs) { - Map newProperties = sessionMatchSpec.match(context); - defaultProperties.putAll(newProperties); - if (sessionMatchSpec.getOverrideSessionProperties().orElse(false)) { - overridePropertyNames.addAll(newProperties.keySet()); - } - } - - // Once a property has been overridden it stays that way and the value is updated by any rule - Map overrideProperties = new HashMap<>(); - for (String propertyName : overridePropertyNames) { - overrideProperties.put(propertyName, defaultProperties.get(propertyName)); - } - - return new SystemSessionPropertyConfiguration(ImmutableMap.copyOf(defaultProperties), ImmutableMap.copyOf(overrideProperties)); - } - - @Override - public Map> getCatalogSessionProperties(SessionConfigurationContext context) + protected List getSessionMatchSpecs() { - return ImmutableMap.of(); + return sessionMatchSpecs; } } diff --git a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerConfig.java b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManagerConfig.java similarity index 91% rename from presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerConfig.java rename to presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManagerConfig.java index 67b1c656ee81b..54164b968b105 100644 --- 
a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerConfig.java +++ b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManagerConfig.java @@ -11,11 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.session; +package com.facebook.presto.session.file; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.io.File; diff --git a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerFactory.java b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManagerFactory.java similarity index 97% rename from presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerFactory.java rename to presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManagerFactory.java index aec39239cfbb5..08a49b74f54f5 100644 --- a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerFactory.java +++ b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManagerFactory.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.session; +package com.facebook.presto.session.file; import com.facebook.airlift.bootstrap.Bootstrap; import com.facebook.airlift.json.JsonModule; diff --git a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerModule.java b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManagerModule.java similarity index 96% rename from presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerModule.java rename to presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManagerModule.java index 03fee8d5207dc..9d50e6144df4c 100644 --- a/presto-session-property-managers/src/main/java/com/facebook/presto/session/FileSessionPropertyManagerModule.java +++ b/presto-file-session-property-manager/src/main/java/com/facebook/presto/session/file/FileSessionPropertyManagerModule.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.session; +package com.facebook.presto.session.file; import com.google.inject.Binder; import com.google.inject.Module; diff --git a/presto-file-session-property-manager/src/test/java/com/facebook/presto/session/file/TestFileSessionPropertyManager.java b/presto-file-session-property-manager/src/test/java/com/facebook/presto/session/file/TestFileSessionPropertyManager.java new file mode 100644 index 0000000000000..f071f2547b353 --- /dev/null +++ b/presto-file-session-property-manager/src/test/java/com/facebook/presto/session/file/TestFileSessionPropertyManager.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.session.file; + +import com.facebook.airlift.testing.TempFile; +import com.facebook.presto.session.AbstractTestSessionPropertyManager; +import com.facebook.presto.session.SessionMatchSpec; +import com.facebook.presto.spi.session.SessionPropertyConfigurationManager; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.session.file.FileSessionPropertyManager.CODEC; +import static org.testng.Assert.assertEquals; + +public class TestFileSessionPropertyManager + extends AbstractTestSessionPropertyManager +{ + @Override + protected void assertProperties(Map defaultProperties, SessionMatchSpec... specs) + throws IOException + { + assertProperties(defaultProperties, ImmutableMap.of(), ImmutableMap.of(), specs); + } + + @Override + protected void assertProperties(Map defaultProperties, Map overrideProperties, SessionMatchSpec... specs) + throws IOException + { + assertProperties(defaultProperties, overrideProperties, ImmutableMap.of(), specs); + } + + protected void assertProperties(Map defaultProperties, Map overrideProperties, Map> catalogProperties, SessionMatchSpec... 
specs) + throws IOException + { + try (TempFile tempFile = new TempFile()) { + Path configurationFile = tempFile.path(); + Files.write(configurationFile, CODEC.toJsonBytes(Arrays.asList(specs))); + SessionPropertyConfigurationManager manager = new FileSessionPropertyManager(new FileSessionPropertyManagerConfig().setConfigFile(configurationFile.toFile())); + SessionPropertyConfigurationManager.SystemSessionPropertyConfiguration propertyConfiguration = manager.getSystemSessionProperties(CONTEXT); + assertEquals(propertyConfiguration.systemPropertyDefaults, defaultProperties); + assertEquals(propertyConfiguration.systemPropertyOverrides, overrideProperties); + assertEquals(manager.getCatalogSessionProperties(CONTEXT), catalogProperties); + } + } + + @Test + public void testNullSessionProperties() + throws IOException + { + ImmutableMap> catalogProperties = ImmutableMap.of("CATALOG", ImmutableMap.of("PROPERTY", "VALUE")); + SessionMatchSpec spec = new SessionMatchSpec( + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + null, + catalogProperties); + + assertProperties(ImmutableMap.of(), ImmutableMap.of(), catalogProperties, spec); + } + + @Test + public void testNullCatalogSessionProperties() + throws IOException + { + Map properties = ImmutableMap.of("PROPERTY1", "VALUE1", "PROPERTY2", "VALUE2"); + SessionMatchSpec spec = new SessionMatchSpec( + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + properties, + null); + + assertProperties(properties, spec); + } + + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Either sessionProperties or catalogSessionProperties must be provided") + public void testNullBothSessionProperties() + throws IOException + { + SessionMatchSpec spec = new SessionMatchSpec( + Optional.empty(), + Optional.empty(), + 
Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + null, + null); + + assertProperties(ImmutableMap.of(), spec); + } +} diff --git a/presto-session-property-managers/src/test/java/com/facebook/presto/session/TestFileSessionPropertyManagerConfig.java b/presto-file-session-property-manager/src/test/java/com/facebook/presto/session/file/TestFileSessionPropertyManagerConfig.java similarity index 92% rename from presto-session-property-managers/src/test/java/com/facebook/presto/session/TestFileSessionPropertyManagerConfig.java rename to presto-file-session-property-manager/src/test/java/com/facebook/presto/session/file/TestFileSessionPropertyManagerConfig.java index b555593fa09a0..43403ef267bce 100644 --- a/presto-session-property-managers/src/test/java/com/facebook/presto/session/TestFileSessionPropertyManagerConfig.java +++ b/presto-file-session-property-manager/src/test/java/com/facebook/presto/session/file/TestFileSessionPropertyManagerConfig.java @@ -11,12 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.session; +package com.facebook.presto.session.file; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; -import java.io.File; +import java.nio.file.Paths; import java.util.Map; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; @@ -40,7 +40,7 @@ public void testExplicitPropertyMappings() .build(); FileSessionPropertyManagerConfig expected = new FileSessionPropertyManagerConfig() - .setConfigFile(new File("/test.json")); + .setConfigFile(Paths.get("/test.json").toFile()); assertFullMapping(properties, expected); } diff --git a/presto-function-namespace-managers-common/pom.xml b/presto-function-namespace-managers-common/pom.xml index c9d58efd0f56c..084b979288ff6 100644 --- a/presto-function-namespace-managers-common/pom.xml +++ b/presto-function-namespace-managers-common/pom.xml @@ -5,14 +5,17 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT ${project.parent.basedir} + 8 + true presto-function-namespace-managers-common + presto-function-namespace-managers-common com.facebook.presto @@ -25,8 +28,8 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -41,14 +44,21 @@ - io.airlift + com.facebook.airlift units provided - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations + true + + + + jakarta.annotation + jakarta.annotation-api + true diff --git a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java index b5d17f3c59844..4f691d066d998 100644 --- a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java +++ 
b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java @@ -44,9 +44,8 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.util.concurrent.UncheckedExecutionException; - -import javax.annotation.ParametersAreNonnullByDefault; -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nonnull; import java.util.Collection; import java.util.List; @@ -90,8 +89,7 @@ public AbstractSqlInvokedFunctionNamespaceManager(String catalogName, SqlFunctio .build(new CacheLoader>() { @Override - @ParametersAreNonnullByDefault - public Collection load(QualifiedObjectName functionName) + public Collection load(@Nonnull QualifiedObjectName functionName) { Collection functions = fetchFunctionsDirect(functionName); for (SqlInvokedFunction function : functions) { @@ -109,8 +107,7 @@ public Collection load(QualifiedObjectName functionName) .build(new CacheLoader() { @Override - @ParametersAreNonnullByDefault - public FunctionMetadata load(SqlFunctionHandle functionHandle) + public FunctionMetadata load(@Nonnull SqlFunctionHandle functionHandle) { return fetchFunctionMetadataDirect(functionHandle); } @@ -255,7 +252,8 @@ private static PrestoException convertToPrestoException(UncheckedExecutionExcept return new PrestoException(GENERIC_INTERNAL_ERROR, failureMessage, cause); } - protected String getCatalogName() + @Override + public String getCatalogName() { return catalogName; } diff --git a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java index 5c855a87968f2..d193e513e8cf5 100644 --- 
a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java +++ b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/JsonBasedUdfFunctionMetadata.java @@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; +import java.net.URI; import java.util.List; import java.util.Optional; import java.util.stream.IntStream; @@ -77,6 +78,10 @@ public class JsonBasedUdfFunctionMetadata private final Optional> longVariableConstraints; private final Optional functionId; private final Optional version; + /** + * Optional execution endpoint for routing function execution to a different server + */ + private final Optional executionEndpoint; @JsonCreator public JsonBasedUdfFunctionMetadata( @@ -91,7 +96,8 @@ public JsonBasedUdfFunctionMetadata( @JsonProperty("functionId") Optional functionId, @JsonProperty("version") Optional version, @JsonProperty("typeVariableConstraints") Optional> typeVariableConstraints, - @JsonProperty("longVariableConstraints") Optional> longVariableConstraints) + @JsonProperty("longVariableConstraints") Optional> longVariableConstraints, + @JsonProperty("executionEndpoint") Optional executionEndpoint) { this.docString = requireNonNull(docString, "docString is null"); this.functionKind = requireNonNull(functionKind, "functionKind is null"); @@ -108,6 +114,13 @@ public JsonBasedUdfFunctionMetadata( this.version = requireNonNull(version, "version is null"); this.typeVariableConstraints = requireNonNull(typeVariableConstraints, "typeVariableConstraints is null"); this.longVariableConstraints = requireNonNull(longVariableConstraints, "longVariableConstraints is null"); + this.executionEndpoint = requireNonNull(executionEndpoint, "executionEndpoint is null"); + executionEndpoint.ifPresent(uri -> { + String scheme = uri.getScheme(); + if (scheme == null || (!scheme.equalsIgnoreCase("http") && 
!scheme.equalsIgnoreCase("https"))) { + throw new IllegalArgumentException("Execution endpoint must use HTTP or HTTPS protocol: " + uri); + } + }); } @JsonProperty @@ -187,4 +200,10 @@ public Optional> getLongVariableConstraints() { return longVariableConstraints; } + + @JsonProperty + public Optional getExecutionEndpoint() + { + return executionEndpoint; + } } diff --git a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/ServingCatalog.java b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/ServingCatalog.java index 6178f60cf45ea..63b034c82a129 100644 --- a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/ServingCatalog.java +++ b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/ServingCatalog.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.functionNamespace; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/SqlInvokedFunctionNamespaceManagerConfig.java b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/SqlInvokedFunctionNamespaceManagerConfig.java index cf9552fd10f08..b3278b2efaf50 100644 --- a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/SqlInvokedFunctionNamespaceManagerConfig.java +++ b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/SqlInvokedFunctionNamespaceManagerConfig.java @@ -14,10 +14,10 @@ package com.facebook.presto.functionNamespace; import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; import com.google.common.base.Splitter; import 
com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; import java.util.Set; diff --git a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/testing/InMemoryFunctionNamespaceManager.java b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/testing/InMemoryFunctionNamespaceManager.java index 19c5279b570b2..827dd206867ab 100644 --- a/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/testing/InMemoryFunctionNamespaceManager.java +++ b/presto-function-namespace-managers-common/src/main/java/com/facebook/presto/functionNamespace/testing/InMemoryFunctionNamespaceManager.java @@ -28,8 +28,7 @@ import com.facebook.presto.spi.function.SqlFunctionHandle; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.Collection; import java.util.List; @@ -109,7 +108,7 @@ public void addUserDefinedType(UserDefinedType type) checkArgument( !userDefinedTypes.containsKey(name), "Parametric type %s already registered", - name); + name); userDefinedTypes.put(name, type); } diff --git a/presto-function-namespace-managers-common/src/test/java/com/facebook/presto/functionNamespace/TestSqlInvokedFunctionNamespaceManager.java b/presto-function-namespace-managers-common/src/test/java/com/facebook/presto/functionNamespace/TestSqlInvokedFunctionNamespaceManager.java index fc40f90bedff0..120f41607b573 100644 --- a/presto-function-namespace-managers-common/src/test/java/com/facebook/presto/functionNamespace/TestSqlInvokedFunctionNamespaceManager.java +++ b/presto-function-namespace-managers-common/src/test/java/com/facebook/presto/functionNamespace/TestSqlInvokedFunctionNamespaceManager.java @@ -13,6 +13,7 @@ */ package 
com.facebook.presto.functionNamespace; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.common.type.UserDefinedType; @@ -31,7 +32,6 @@ import com.facebook.presto.spi.function.SqlInvokedFunction; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Collection; @@ -47,7 +47,7 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.SQL; -import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.MoreCollectors.onlyElement; import static java.lang.String.format; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.testng.Assert.assertEquals; @@ -111,19 +111,19 @@ public void testTransactionalGetFunction() FunctionNamespaceTransactionHandle transaction2 = functionNamespaceManager.beginTransaction(); Collection functions2 = functionNamespaceManager.getFunctions(Optional.of(transaction2), POWER_TOWER); assertEquals(functions2.size(), 1); - assertEquals(getOnlyElement(functions2), FUNCTION_POWER_TOWER_DOUBLE.withVersion("1")); + assertEquals(functions2.stream().collect(onlyElement()), FUNCTION_POWER_TOWER_DOUBLE.withVersion("1")); // update the function, second transaction still sees the old functions functionNamespaceManager.createFunction(FUNCTION_POWER_TOWER_DOUBLE_UPDATED, true); functions2 = functionNamespaceManager.getFunctions(Optional.of(transaction2), POWER_TOWER); assertEquals(functions2.size(), 1); - assertEquals(getOnlyElement(functions2), FUNCTION_POWER_TOWER_DOUBLE.withVersion("1")); + assertEquals(functions2.stream().collect(onlyElement()), 
FUNCTION_POWER_TOWER_DOUBLE.withVersion("1")); // third transaction sees the updated function FunctionNamespaceTransactionHandle transaction3 = functionNamespaceManager.beginTransaction(); Collection functions3 = functionNamespaceManager.getFunctions(Optional.of(transaction3), POWER_TOWER); assertEquals(functions3.size(), 1); - assertEquals(getOnlyElement(functions3), FUNCTION_POWER_TOWER_DOUBLE_UPDATED.withVersion("2")); + assertEquals(functions3.stream().collect(onlyElement()), FUNCTION_POWER_TOWER_DOUBLE_UPDATED.withVersion("2")); functionNamespaceManager.commit(transaction1); functionNamespaceManager.commit(transaction2); @@ -137,8 +137,8 @@ public void testCaching() functionNamespaceManager.createFunction(FUNCTION_POWER_TOWER_DOUBLE, false); // fetchFunctionsDirect does not produce the same function reference - SqlInvokedFunction function1 = getOnlyElement(functionNamespaceManager.fetchFunctionsDirect(POWER_TOWER)); - SqlInvokedFunction function2 = getOnlyElement(functionNamespaceManager.fetchFunctionsDirect(POWER_TOWER)); + SqlInvokedFunction function1 = functionNamespaceManager.fetchFunctionsDirect(POWER_TOWER).stream().collect(onlyElement()); + SqlInvokedFunction function2 = functionNamespaceManager.fetchFunctionsDirect(POWER_TOWER).stream().collect(onlyElement()); assertEquals(function1, function2); assertNotSame(function1, function2); diff --git a/presto-function-namespace-managers-common/src/test/java/com/facebook/presto/functionNamespace/TestSqlInvokedFunctionNamespaceManagerConfig.java b/presto-function-namespace-managers-common/src/test/java/com/facebook/presto/functionNamespace/TestSqlInvokedFunctionNamespaceManagerConfig.java index 8982d2c7097af..dbb8a1d6add7d 100644 --- a/presto-function-namespace-managers-common/src/test/java/com/facebook/presto/functionNamespace/TestSqlInvokedFunctionNamespaceManagerConfig.java +++ 
b/presto-function-namespace-managers-common/src/test/java/com/facebook/presto/functionNamespace/TestSqlInvokedFunctionNamespaceManagerConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.functionNamespace; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-function-namespace-managers/pom.xml b/presto-function-namespace-managers/pom.xml index ad7a4887e20fe..f0d982f3b5626 100644 --- a/presto-function-namespace-managers/pom.xml +++ b/presto-function-namespace-managers/pom.xml @@ -3,7 +3,7 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT 4.0.0 @@ -13,8 +13,20 @@ ${project.parent.basedir} + 17 + true - + + + + + org.slf4j + slf4j-api + 2.0.16 + + + + com.facebook.airlift @@ -42,18 +54,18 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided - com.facebook.drift + com.facebook.airlift.drift drift-client - com.facebook.drift + com.facebook.airlift.drift drift-transport-spi @@ -102,24 +114,29 @@ - io.airlift + com.facebook.airlift units provided - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api - javax.validation - validation-api + jakarta.validation + jakarta.validation-api + + + + javax.inject + javax.inject @@ -128,6 +145,12 @@ runtime + + org.mariadb.jdbc + mariadb-java-client + runtime + + org.jdbi jdbi3-core @@ -152,15 +175,17 @@ com.fasterxml.jackson.core jackson-databind + runtime com.fasterxml.jackson.datatype jackson-datatype-jdk8 + runtime - com.facebook.drift + com.facebook.airlift.drift drift-transport-netty @@ -208,8 +233,8 @@ - javax.ws.rs - javax.ws.rs-api + jakarta.ws.rs + jakarta.ws.rs-api test diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/AbstractAnnotatedProvider.java 
b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/AbstractAnnotatedProvider.java new file mode 100644 index 0000000000000..44ed673668311 --- /dev/null +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/AbstractAnnotatedProvider.java @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2012 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.functionNamespace.execution.thrift; + +import com.google.inject.Injector; +import jakarta.inject.Inject; + +import javax.inject.Provider; + +import java.lang.annotation.Annotation; +import java.util.Objects; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public abstract class AbstractAnnotatedProvider + implements Provider +{ + private final Annotation annotation; + private Injector injector; + + protected AbstractAnnotatedProvider(Annotation annotation) + { + this.annotation = requireNonNull(annotation, "annotation is null"); + } + + @Inject + public final void setInjector(Injector injector) + { + this.injector = injector; + } + + @Override + public final T get() + { + checkState(injector != null, "injector was not set"); + return get(injector, annotation); + } + + protected abstract T get(Injector injector, Annotation annotation); + + @Override + public final boolean equals(Object obj) + { + if (this == 
obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + AbstractAnnotatedProvider other = (AbstractAnnotatedProvider) obj; + return Objects.equals(this.annotation, other.annotation); + } + + @Override + public final int hashCode() + { + return Objects.hash(annotation); + } +} diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ContextualAddressSelector.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ContextualAddressSelector.java index eb2a84750f416..9ea74681b7094 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ContextualAddressSelector.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ContextualAddressSelector.java @@ -44,6 +44,6 @@ public Optional selectAddress(Optional context) public Optional selectAddress(Optional context, Set attempted) { checkArgument(context.isPresent(), "context is empty"); - return delegates.get(context.get()).selectAddress(Optional.empty(), attempted); + return delegates.get(context.orElseThrow()).selectAddress(Optional.empty(), attempted); } } diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ContextualSimpleAddressSelectorBinder.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ContextualSimpleAddressSelectorBinder.java index cd324fa713d71..d81d3323785de 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ContextualSimpleAddressSelectorBinder.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ContextualSimpleAddressSelectorBinder.java @@ -17,7 +17,6 @@ import 
com.facebook.drift.client.address.SimpleAddressSelector; import com.facebook.drift.client.address.SimpleAddressSelector.SimpleAddress; import com.facebook.drift.client.address.SimpleAddressSelectorConfig; -import com.facebook.drift.client.guice.AbstractAnnotatedProvider; import com.facebook.drift.client.guice.AddressSelectorBinder; import com.google.common.collect.ImmutableMap; import com.google.inject.Binder; @@ -56,6 +55,7 @@ private static class SimpleAddressSelectorProvider extends AbstractAnnotatedProvider> { private final Map addresses; + protected SimpleAddressSelectorProvider(Annotation annotation, Map addresses) { super(annotation); diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ThriftSqlFunctionExecutor.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ThriftSqlFunctionExecutor.java index 3d05d7bb5e420..981cf8130b2f8 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ThriftSqlFunctionExecutor.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/execution/thrift/ThriftSqlFunctionExecutor.java @@ -56,7 +56,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.MoreCollectors.onlyElement; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; @@ -184,7 +184,7 @@ private SqlFunctionResult toSqlFunctionResult(ThriftUdfResult result, Type retur ThriftUdfPage page = result.getResult(); switch (page.getPageFormat()) { case PRESTO_THRIFT: - return new 
SqlFunctionResult(getOnlyElement(page.getThriftPage().getThriftBlocks()).toBlock(returnType), result.getUdfStats().getTotalCpuTimeMs()); + return new SqlFunctionResult(page.getThriftPage().getThriftBlocks().stream().collect(onlyElement()).toBlock(returnType), result.getUdfStats().getTotalCpuTimeMs()); case PRESTO_SERIALIZED: checkState(blockEncodingSerde != null, "blockEncodingSerde not set"); PagesSerde pagesSerde = new PagesSerde(blockEncodingSerde, Optional.empty(), Optional.empty(), Optional.empty()); diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionDefinitionProvider.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionDefinitionProvider.java index 749f7cc0d2824..a25fed005d786 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionDefinitionProvider.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionDefinitionProvider.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; -import java.nio.file.Paths; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -61,7 +60,7 @@ public UdfFunctionSignatureMap getUdfDefinition(String filePath) private List getFilesInPath(String filePath, int maxDirectoryDepth) throws IOException { try (Stream stream = Files.find( - Paths.get(filePath), + Path.of(filePath), maxDirectoryDepth, (p, basicFileAttributes) -> p.getFileName().toString().toLowerCase(ENGLISH).endsWith(JSON_FILE_EXTENSION))) { return stream.collect(Collectors.toList()); diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManager.java 
b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManager.java index b830b1c4e1863..39c8a75ff97fd 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManager.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManager.java @@ -36,8 +36,7 @@ import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; @@ -48,9 +47,7 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; -import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.onlyElement; import static java.lang.Long.parseLong; @@ -139,7 +136,6 @@ private void populateNameSpaceManager(UdfFunctionSignatureMap udfFunctionSignatu private SqlInvokedFunction createSqlInvokedFunction(String functionName, JsonBasedUdfFunctionMetadata jsonBasedUdfFunctionMetaData) { - checkState(jsonBasedUdfFunctionMetaData.getRoutineCharacteristics().getLanguage().equals(CPP), "JsonFileBasedFunctionNamespaceManager only supports CPP UDF"); QualifiedObjectName qualifiedFunctionName = QualifiedObjectName.valueOf(new CatalogSchemaName(getCatalogName(), jsonBasedUdfFunctionMetaData.getSchema()), functionName); List parameterNameList = 
jsonBasedUdfFunctionMetaData.getParamNames(); List parameterTypeList = jsonBasedUdfFunctionMetaData.getParamTypes(); @@ -170,6 +166,25 @@ protected Collection fetchFunctionsDirect(QualifiedObjectNam .collect(toImmutableList()); } + @Override + protected FunctionMetadata sqlInvokedFunctionToMetadata(SqlInvokedFunction function) + { + return new FunctionMetadata( + function.getSignature().getName(), + function.getSignature().getArgumentTypes(), + function.getParameters().stream() + .map(Parameter::getName) + .collect(toImmutableList()), + function.getSignature().getReturnType(), + function.getSignature().getKind(), + function.getRoutineCharacteristics().getLanguage(), + getFunctionImplementationType(function), + function.isDeterministic(), + function.isCalledOnNullInput(), + function.getVersion(), + function.getDescription()); + } + @Override protected UserDefinedType fetchUserDefinedTypeDirect(QualifiedObjectName typeName) { diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManagerConfig.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManagerConfig.java index 46bd6f17a84e7..b9d8a8ec2f472 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManagerConfig.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManagerConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class JsonFileBasedFunctionNamespaceManagerConfig { diff --git 
a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManagerFactory.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManagerFactory.java index 31eff88c580f7..a4f7f7ee359a0 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManagerFactory.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonFileBasedFunctionNamespaceManagerFactory.java @@ -36,7 +36,7 @@ public class JsonFileBasedFunctionNamespaceManagerFactory { public static final String NAME = "json_file"; - private static final SqlFunctionHandle.Resolver HANDLE_RESOLVER = new SqlFunctionHandle.Resolver(); + private static final SqlFunctionHandle.Resolver HANDLE_RESOLVER = SqlFunctionHandle.Resolver.getInstance(); @Override public String getName() diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlConnectionConfig.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlConnectionConfig.java index 1952a395b4a69..3f61396d95c8e 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlConnectionConfig.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlConnectionConfig.java @@ -14,11 +14,12 @@ package com.facebook.presto.functionNamespace.mysql; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class MySqlConnectionConfig { + private String jdbcDriverName = "com.mysql.jdbc.Driver"; + private String databaseUrl; @NotNull @@ -33,4 +34,16 @@ public MySqlConnectionConfig setDatabaseUrl(String databaseUrl) this.databaseUrl 
= databaseUrl; return this; } + + public String getJdbcDriverName() + { + return jdbcDriverName; + } + + @Config("database-driver-name") + public MySqlConnectionConfig setJdbcDriverName(String jdbcDriverName) + { + this.jdbcDriverName = jdbcDriverName; + return this; + } } diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlConnectionModule.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlConnectionModule.java index 1fbe6101b0e29..4623bbbb48fd2 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlConnectionModule.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlConnectionModule.java @@ -17,11 +17,11 @@ import com.google.inject.Binder; import com.google.inject.Injector; import com.google.inject.TypeLiteral; +import jakarta.inject.Inject; import org.jdbi.v3.core.ConnectionFactory; import org.jdbi.v3.core.Jdbi; import org.jdbi.v3.sqlobject.SqlObjectPlugin; -import javax.inject.Inject; import javax.inject.Provider; import java.sql.DriverManager; @@ -36,9 +36,11 @@ protected void setup(Binder binder) { configBinder(binder).bindConfig(MySqlConnectionConfig.class); - String databaseUrl = buildConfigObject(MySqlConnectionConfig.class).getDatabaseUrl(); + MySqlConnectionConfig mySqlConnectionConfig = buildConfigObject(MySqlConnectionConfig.class); + String databaseUrl = mySqlConnectionConfig.getDatabaseUrl(); + String jdbcDriverName = mySqlConnectionConfig.getJdbcDriverName(); try { - Class.forName("com.mysql.jdbc.Driver"); + Class.forName(jdbcDriverName); } catch (ClassNotFoundException e) { throw new RuntimeException(e); diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManager.java 
b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManager.java index 62e7318828683..15f2783437969 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManager.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManager.java @@ -32,11 +32,10 @@ import com.facebook.presto.spi.function.SqlFunctionHandle; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; +import jakarta.annotation.PostConstruct; +import jakarta.inject.Inject; import org.jdbi.v3.core.Jdbi; -import javax.annotation.PostConstruct; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -50,7 +49,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.MoreCollectors.onlyElement; import static com.google.common.hash.Hashing.sha256; import static java.lang.Long.parseLong; import static java.lang.String.format; @@ -118,10 +117,7 @@ public void addUserDefinedType(UserDefinedType type) public UserDefinedType fetchUserDefinedTypeDirect(QualifiedObjectName typeName) { Optional type = functionNamespaceDao.getUserDefinedType(typeName.getCatalogName(), typeName.getSchemaName(), typeName.getObjectName()); - if (!type.isPresent()) { - throw new PrestoException(NOT_FOUND, format("Type %s not found", typeName)); - } - return type.get(); + return type.orElseThrow(() -> new PrestoException(NOT_FOUND, format("Type %s not found", typeName))); } @Override @@ -139,10 +135,7 @@ protected FunctionMetadata 
fetchFunctionMetadataDirect(SqlFunctionHandle functio { checkCatalog(functionHandle); Optional function = functionNamespaceDao.getFunction(hash(functionHandle.getFunctionId()), functionHandle.getFunctionId(), getLongVersion(functionHandle)); - if (!function.isPresent()) { - throw new InvalidFunctionHandleException(functionHandle); - } - return sqlInvokedFunctionToMetadata(function.get()); + return sqlInvokedFunctionToMetadata(function.orElseThrow(() -> new InvalidFunctionHandleException(functionHandle))); } @Override @@ -150,10 +143,7 @@ protected ScalarFunctionImplementation fetchFunctionImplementationDirect(SqlFunc { checkCatalog(functionHandle); Optional function = functionNamespaceDao.getFunction(hash(functionHandle.getFunctionId()), functionHandle.getFunctionId(), getLongVersion(functionHandle)); - if (!function.isPresent()) { - throw new InvalidFunctionHandleException(functionHandle); - } - return sqlInvokedFunctionToImplementation(function.get()); + return sqlInvokedFunctionToImplementation(function.orElseThrow(() -> new InvalidFunctionHandleException(functionHandle))); } @Override @@ -189,16 +179,16 @@ public void createFunction(SqlInvokedFunction function, boolean replace) jdbi.useTransaction(handle -> { FunctionNamespaceDao transactionDao = handle.attach(functionNamespaceDaoClass); Optional latestVersion = transactionDao.getLatestRecordForUpdate(hash(function.getFunctionId()), function.getFunctionId()); - if (!replace && latestVersion.isPresent() && !latestVersion.get().isDeleted()) { + if (!replace && latestVersion.isPresent() && !latestVersion.orElseThrow().isDeleted()) { throw new PrestoException(ALREADY_EXISTS, "Function already exists: " + function.getFunctionId()); } - if (!latestVersion.isPresent() || !latestVersion.get().getFunction().hasSameDefinitionAs(function)) { + if (latestVersion.isEmpty() || !latestVersion.orElseThrow().getFunction().hasSameDefinitionAs(function)) { long newVersion = 
latestVersion.map(SqlInvokedFunctionRecord::getFunction).map(MySqlFunctionNamespaceManager::getLongVersion).orElse(0L) + 1; insertSqlInvokedFunction(transactionDao, function, newVersion); } - else if (latestVersion.get().isDeleted()) { - SqlInvokedFunction latest = latestVersion.get().getFunction(); + else if (latestVersion.orElseThrow().isDeleted()) { + SqlInvokedFunction latest = latestVersion.orElseThrow().getFunction(); checkState(latest.hasVersion(), "Function version missing: %s", latest.getFunctionId()); transactionDao.setDeletionStatus(hash(latest.getFunctionId()), latest.getFunctionId(), getLongVersion(latest), false); } @@ -250,7 +240,7 @@ public void dropFunction(QualifiedObjectName functionName, Optional getSqlFunctions(FunctionNamespaceDao functionNa { List records = new ArrayList<>(); if (parameterTypes.isPresent()) { - SqlFunctionId functionId = new SqlFunctionId(functionName, parameterTypes.get()); + SqlFunctionId functionId = new SqlFunctionId(functionName, parameterTypes.orElseThrow()); functionNamespaceDao.getLatestRecordForUpdate(hash(functionId), functionId).ifPresent(records::add); } else { diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManagerConfig.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManagerConfig.java index 9dd187c605d8f..2f9d7970db2a4 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManagerConfig.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManagerConfig.java @@ -14,8 +14,7 @@ package com.facebook.presto.functionNamespace.mysql; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class 
MySqlFunctionNamespaceManagerConfig { diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManagerFactory.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManagerFactory.java index bb3b9489e0214..666d47ce13f7c 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManagerFactory.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/mysql/MySqlFunctionNamespaceManagerFactory.java @@ -32,7 +32,7 @@ public class MySqlFunctionNamespaceManagerFactory { public static final String NAME = "mysql"; - private static final SqlFunctionHandle.Resolver HANDLE_RESOLVER = new SqlFunctionHandle.Resolver(); + private static final SqlFunctionHandle.Resolver HANDLE_RESOLVER = SqlFunctionHandle.Resolver.getInstance(); @Override public String getName() diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestBasedFunctionNamespaceManager.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestBasedFunctionNamespaceManager.java index c82a0cdab8be8..4694911f2e3d0 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestBasedFunctionNamespaceManager.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestBasedFunctionNamespaceManager.java @@ -40,8 +40,7 @@ import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.Collection; @@ -156,7 +155,8 @@ private SqlInvokedFunction createSqlInvokedFunction(String functionName, JsonBas 
qualifiedFunctionName, jsonBasedUdfFunctionMetaData.getFunctionKind(), jsonBasedUdfFunctionMetaData.getOutputType(), - jsonBasedUdfFunctionMetaData.getParamTypes())))); + jsonBasedUdfFunctionMetaData.getParamTypes()), + jsonBasedUdfFunctionMetaData.getExecutionEndpoint()))); } @Override @@ -193,11 +193,7 @@ protected FunctionMetadata fetchFunctionMetadataDirect(SqlFunctionHandle functio checkCatalog(functionHandle); Optional function = getSqlInvokedFunction(functionHandle); - if (!function.isPresent()) { - throw new InvalidFunctionHandleException(functionHandle); - } - - return sqlInvokedFunctionToMetadata(function.get()); + return sqlInvokedFunctionToMetadata(function.orElseThrow(() -> new InvalidFunctionHandleException(functionHandle))); } @Override @@ -206,11 +202,7 @@ protected ScalarFunctionImplementation fetchFunctionImplementationDirect(SqlFunc checkCatalog(functionHandle); Optional function = getSqlInvokedFunction(functionHandle); - if (!function.isPresent()) { - throw new InvalidFunctionHandleException(functionHandle); - } - - return sqlInvokedFunctionToImplementation(function.get()); + return sqlInvokedFunctionToImplementation(function.orElseThrow(() -> new InvalidFunctionHandleException(functionHandle))); } @Override diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestBasedFunctionNamespaceManagerConfig.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestBasedFunctionNamespaceManagerConfig.java index c4bcea72210ba..303a575bb1a6f 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestBasedFunctionNamespaceManagerConfig.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestBasedFunctionNamespaceManagerConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; - 
-import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class RestBasedFunctionNamespaceManagerConfig { diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestSqlFunctionExecutor.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestSqlFunctionExecutor.java index d6aca0eb38961..a2d530c425030 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestSqlFunctionExecutor.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/rest/RestSqlFunctionExecutor.java @@ -14,7 +14,6 @@ package com.facebook.presto.functionNamespace.rest; import com.facebook.airlift.http.client.HttpClient; -import com.facebook.airlift.http.client.HttpUriBuilder; import com.facebook.airlift.http.client.Request; import com.facebook.airlift.http.client.Response; import com.facebook.airlift.http.client.ResponseHandler; @@ -26,6 +25,7 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.FunctionImplementationType; import com.facebook.presto.spi.function.RemoteScalarFunctionImplementation; +import com.facebook.presto.spi.function.RestFunctionHandle; import com.facebook.presto.spi.function.SqlFunctionExecutor; import com.facebook.presto.spi.function.SqlFunctionHandle; import com.facebook.presto.spi.function.SqlFunctionId; @@ -36,21 +36,17 @@ import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.InputStreamSliceInput; import io.airlift.slice.SliceInput; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.IOException; -import java.io.UnsupportedEncodingException; import java.net.URI; import java.net.URLEncoder; -import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; import static 
com.facebook.airlift.concurrent.MoreFutures.failedFuture; import static com.facebook.airlift.concurrent.MoreFutures.toCompletableFuture; -import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static com.facebook.airlift.http.client.Request.Builder.preparePost; import static com.facebook.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; import static com.facebook.presto.functionNamespace.rest.RestErrorCode.REST_SERVER_BAD_RESPONSE; @@ -71,6 +67,7 @@ import static java.net.HttpURLConnection.HTTP_NOT_FOUND; import static java.net.HttpURLConnection.HTTP_OK; import static java.net.HttpURLConnection.HTTP_SERVER_ERROR; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; public class RestSqlFunctionExecutor @@ -113,13 +110,15 @@ public CompletableFuture executeFunction( Type returnType) { SqlFunctionHandle functionHandle = functionImplementation.getFunctionHandle(); + checkArgument(functionHandle instanceof RestFunctionHandle, "Expected RestFunctionHandle but got %s", functionHandle.getClass().getName()); + RestFunctionHandle restFunctionHandle = (RestFunctionHandle) functionHandle; SqlFunctionId functionId = functionHandle.getFunctionId(); String functionVersion = functionHandle.getVersion(); DynamicSliceOutput sliceOutput = new DynamicSliceOutput((int) input.getRetainedSizeInBytes()); writeSerializedPage(sliceOutput, pageSerde.serialize(input)); try { Request request = preparePost() - .setUri(getExecutionEndpoint(functionId, functionVersion)) + .setUri(getExecutionEndpoint(restFunctionHandle, functionId, functionVersion)) .setBodyGenerator(createStaticBodyGenerator(sliceOutput.slice().byteArray())) .setHeader(CONTENT_TYPE, PRESTO_PAGES) .setHeader(ACCEPT, PRESTO_PAGES) @@ -133,24 +132,21 @@ public CompletableFuture executeFunction( } } - private URI getExecutionEndpoint(SqlFunctionId functionId, String functionVersion) + private URI getExecutionEndpoint(RestFunctionHandle 
restFunctionHandle, SqlFunctionId functionId, String functionVersion) { String encodedFunctionId; - try { - encodedFunctionId = URLEncoder.encode(functionId.toJsonString(), StandardCharsets.UTF_8.toString()); - } - catch (UnsupportedEncodingException e) { - // Should never happen - throw new IllegalStateException("UTF-8 encoding is not supported", e); - } - - HttpUriBuilder uri = uriBuilderFrom(URI.create(restBasedFunctionNamespaceManagerConfig.getRestUrl())) - .appendPath(format("/v1/functions/%s/%s/%s/%s", - functionId.getFunctionName().getSchemaName(), - functionId.getFunctionName().getObjectName(), - encodedFunctionId, - functionVersion)); - return uri.build(); + encodedFunctionId = URLEncoder.encode(functionId.toJsonString(), UTF_8); + + // Use execution endpoint from handle if present, otherwise use default + URI baseUri = restFunctionHandle.getExecutionEndpoint() + .orElse(URI.create(restBasedFunctionNamespaceManagerConfig.getRestUrl())); + String path = format("/v1/functions/%s/%s/%s/%s", + functionId.getFunctionName().getSchemaName(), + functionId.getFunctionName().getObjectName(), + encodedFunctionId, + functionVersion); + + return URI.create(String.format("%s%s", baseUri, path)); } public static class SqlFunctionResultResponseHandler diff --git a/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/TestRestBasedFunctionNamespaceManager.java b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/TestRestBasedFunctionNamespaceManager.java index 6dbe4b355e5d7..ecce6415ea0f3 100644 --- a/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/TestRestBasedFunctionNamespaceManager.java +++ b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/TestRestBasedFunctionNamespaceManager.java @@ -35,18 +35,17 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import 
com.google.inject.Injector; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HEAD; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriBuilder; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import javax.ws.rs.GET; -import javax.ws.rs.HEAD; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriBuilder; - import java.net.URI; import java.util.ArrayList; import java.util.Arrays; @@ -147,7 +146,8 @@ public static Map> createUdfSignature Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.default.square"), ImmutableList.of(parseTypeSignature("integer")))), Optional.of("1"), Optional.of(emptyList()), - Optional.of(emptyList()))); + Optional.of(emptyList()), + Optional.empty())); squareFunctions.add(new JsonBasedUdfFunctionMetadata( "square a double", FunctionKind.SCALAR, @@ -160,7 +160,8 @@ public static Map> createUdfSignature Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.square"), ImmutableList.of(parseTypeSignature("double")))), Optional.of("1"), Optional.of(emptyList()), - Optional.of(emptyList()))); + Optional.of(emptyList()), + Optional.empty())); udfSignatureMap.put("square", squareFunctions); // array_function_1 @@ -177,7 +178,8 @@ public static Map> createUdfSignature Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.default.array_function_1"), ImmutableList.of(parseTypeSignature("ARRAY>"), parseTypeSignature("ARRAY>")))), Optional.of("1"), Optional.of(emptyList()), - Optional.of(emptyList()))); + Optional.of(emptyList()), + Optional.empty())); arrayFunction1.add(new JsonBasedUdfFunctionMetadata( "combines two float arrays into one", FunctionKind.SCALAR, @@ -190,7 +192,8 @@ public static 
Map> createUdfSignature Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.array_function_1"), ImmutableList.of(parseTypeSignature("ARRAY>"), parseTypeSignature("ARRAY>")))), Optional.of("1"), Optional.of(emptyList()), - Optional.of(emptyList()))); + Optional.of(emptyList()), + Optional.empty())); arrayFunction1.add(new JsonBasedUdfFunctionMetadata( "combines two double arrays into one", FunctionKind.SCALAR, @@ -203,7 +206,8 @@ public static Map> createUdfSignature Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.array_function_1"), ImmutableList.of(parseTypeSignature("ARRAY"), parseTypeSignature("ARRAY")))), Optional.of("1"), Optional.of(emptyList()), - Optional.of(emptyList()))); + Optional.of(emptyList()), + Optional.empty())); udfSignatureMap.put("array_function_1", arrayFunction1); // array_function_2 @@ -220,7 +224,8 @@ public static Map> createUdfSignature Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.default.array_function_2"), ImmutableList.of(parseTypeSignature("ARRAY>"), parseTypeSignature("ARRAY")))), Optional.of("1"), Optional.of(emptyList()), - Optional.of(emptyList()))); + Optional.of(emptyList()), + Optional.empty())); arrayFunction2.add(new JsonBasedUdfFunctionMetadata( "transforms inputs into the output", FunctionKind.SCALAR, @@ -233,7 +238,8 @@ public static Map> createUdfSignature Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.array_function_2"), ImmutableList.of(parseTypeSignature("ARRAY>"), parseTypeSignature("ARRAY>"), parseTypeSignature("ARRAY")))), Optional.of("1"), Optional.of(emptyList()), - Optional.of(emptyList()))); + Optional.of(emptyList()), + Optional.empty())); udfSignatureMap.put("array_function_2", arrayFunction2); return udfSignatureMap; @@ -257,7 +263,8 @@ public static Map> createUpdatedUdfSi Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.default.square"), 
ImmutableList.of(parseTypeSignature("integer")))), Optional.of("1"), Optional.of(emptyList()), - Optional.of(emptyList()))); + Optional.of(emptyList()), + Optional.empty())); squareFunctions.add(new JsonBasedUdfFunctionMetadata( "square a double", FunctionKind.SCALAR, @@ -270,14 +277,16 @@ public static Map> createUpdatedUdfSi Optional.of(new SqlFunctionId(QualifiedObjectName.valueOf("unittest.test_schema.square"), ImmutableList.of(parseTypeSignature("double")))), Optional.of("1"), Optional.of(emptyList()), - Optional.of(emptyList()))); + Optional.of(emptyList()), + Optional.empty())); udfSignatureMap.put("square", squareFunctions); return udfSignatureMap; } @BeforeMethod - public void setUp() throws Exception + public void setUp() + throws Exception { resource = new TestingFunctionResource(createUdfSignatureMap()); ObjectMapper mapper = new ObjectMapper(); diff --git a/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlConnectionConfig.java b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlConnectionConfig.java index 4f67e71ff31f2..2161a2a5ecb74 100644 --- a/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlConnectionConfig.java +++ b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlConnectionConfig.java @@ -28,7 +28,8 @@ public class TestMySqlConnectionConfig public void testDefault() { assertRecordedDefaults(recordDefaults(MySqlConnectionConfig.class) - .setDatabaseUrl(null)); + .setDatabaseUrl(null) + .setJdbcDriverName("com.mysql.jdbc.Driver")); } @Test @@ -36,9 +37,11 @@ public void testExplicitPropertyMappings() { Map properties = new ImmutableMap.Builder() .put("database-url", "localhost:1080") + .put("database-driver-name", "org.mariadb.jdbc.Driver") .build(); MySqlConnectionConfig expected = new MySqlConnectionConfig() - 
.setDatabaseUrl("localhost:1080"); + .setDatabaseUrl("localhost:1080") + .setJdbcDriverName("org.mariadb.jdbc.Driver"); assertFullMapping(properties, expected); } diff --git a/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlFunctionNamespaceManager.java b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlFunctionNamespaceManager.java index 72279f9ba7c19..d2194de7b44f1 100644 --- a/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlFunctionNamespaceManager.java +++ b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlFunctionNamespaceManager.java @@ -79,9 +79,9 @@ @Test(singleThreaded = true) public class TestMySqlFunctionNamespaceManager { - private static final String DB = "presto"; + protected static final String DB = "presto"; - private TestingMySqlServer mySqlServer; + protected TestingMySqlServer mySqlServer; private Jdbi jdbi; private Injector injector; private MySqlFunctionNamespaceManager functionNamespaceManager; @@ -102,16 +102,10 @@ public void setup() new DriftNettyClientModule(), new MySqlConnectionModule()); - Map config = ImmutableMap.builder() - .put("function-cache-expiration", "0s") - .put("function-instance-cache-expiration", "0s") - .put("database-url", mySqlServer.getJdbcUrl(DB)) - .build(); - try { this.injector = app .doNotInitializeLogging() - .setRequiredConfigurationProperties(config) + .setRequiredConfigurationProperties(getConfig()) .initialize(); this.functionNamespaceManager = injector.getInstance(MySqlFunctionNamespaceManager.class); this.jdbi = injector.getInstance(Jdbi.class); @@ -122,6 +116,15 @@ public void setup() } } + protected Map getConfig() + { + return ImmutableMap.builder() + .put("function-cache-expiration", "0s") + .put("function-instance-cache-expiration", "0s") + .put("database-url", mySqlServer.getJdbcUrl(DB)) + 
.build(); + } + @BeforeMethod public void setupFunctionNamespace() { @@ -443,7 +446,7 @@ private FunctionHandle getLatestFunctionHandle(SqlFunctionId functionId) .max(comparing(SqlInvokedFunction::getRequiredVersion)); assertTrue(function.isPresent()); functionNamespaceManager.commit(transactionHandle); - return function.get().getRequiredFunctionHandle(); + return function.orElseThrow().getRequiredFunctionHandle(); } private void assertListFunctions(Optional likePattern, Optional escape, SqlInvokedFunction... functions) diff --git a/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlFunctionNamespaceManagerWithMariaDbDriver.java b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlFunctionNamespaceManagerWithMariaDbDriver.java new file mode 100644 index 0000000000000..1e4aa1f4fdba0 --- /dev/null +++ b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/mysql/TestMySqlFunctionNamespaceManagerWithMariaDbDriver.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.functionNamespace.mysql; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +@Test(singleThreaded = true) +public class TestMySqlFunctionNamespaceManagerWithMariaDbDriver + extends TestMySqlFunctionNamespaceManager +{ + @Override + protected Map getConfig() + { + String jdbcUrl = mySqlServer.getJdbcUrl(DB).replaceFirst("jdbc:mysql:", "jdbc:mariadb:"); + + return ImmutableMap.builder() + .put("function-cache-expiration", "0s") + .put("function-instance-cache-expiration", "0s") + .put("database-url", jdbcUrl) + .put("database-driver-name", "org.mariadb.jdbc.Driver") + .build(); + } +} diff --git a/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/rest/TestRestSqlFunctionExecutorRouting.java b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/rest/TestRestSqlFunctionExecutorRouting.java new file mode 100644 index 0000000000000..fb9a66fd35818 --- /dev/null +++ b/presto-function-namespace-managers/src/test/java/com/facebook/presto/functionNamespace/rest/TestRestSqlFunctionExecutorRouting.java @@ -0,0 +1,237 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.functionNamespace.rest; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.http.client.testing.TestingHttpClient; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.function.SqlFunctionResult; +import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.RemoteScalarFunctionImplementation; +import com.facebook.presto.spi.function.RestFunctionHandle; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.SqlFunctionId; +import com.facebook.presto.spi.page.PagesSerde; +import com.google.common.collect.ImmutableList; +import com.google.common.net.MediaType; +import io.airlift.slice.DynamicSliceOutput; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.airlift.http.client.HttpStatus.OK; +import static com.facebook.airlift.http.client.testing.TestingResponse.mockResponse; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.spi.function.FunctionImplementationType.REST; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; +import static com.facebook.presto.spi.page.PagesSerdeUtil.writeSerializedPage; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestRestSqlFunctionExecutorRouting +{ + private static final MediaType PRESTO_PAGES = 
MediaType.create("application", "X-presto-pages"); + + private RestSqlFunctionExecutor executor; + private AtomicReference capturedRequest; + private RestBasedFunctionNamespaceManagerConfig config; + + @BeforeMethod + public void setup() + { + capturedRequest = new AtomicReference<>(); + config = new RestBasedFunctionNamespaceManagerConfig() + .setRestUrl("http://default-server.example.com"); + + HttpClient httpClient = new TestingHttpClient(request -> { + capturedRequest.set(request); + // Return a valid response + PagesSerde pagesSerde = new PagesSerde(new BlockEncodingManager(), Optional.empty(), Optional.empty(), Optional.empty()); + PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(BIGINT)); + pageBuilder.declarePosition(); + BIGINT.writeLong(pageBuilder.getBlockBuilder(0), 42); + Page resultPage = pageBuilder.build(); + + DynamicSliceOutput sliceOutput = new DynamicSliceOutput(resultPage.getPositionCount()); + writeSerializedPage(sliceOutput, pagesSerde.serialize(resultPage)); + return mockResponse(OK, PRESTO_PAGES, sliceOutput.slice().toStringUtf8()); + }); + + executor = new RestSqlFunctionExecutor(config, httpClient); + executor.setBlockEncodingSerde(new BlockEncodingManager()); + } + + @Test + public void testRoutesToDefaultServerWhenNoExecutionEndpoint() + throws Exception + { + SqlFunctionId functionId = new SqlFunctionId( + QualifiedObjectName.valueOf("test.schema.function"), + ImmutableList.of(parseTypeSignature("bigint"))); + + Signature signature = new Signature( + QualifiedObjectName.valueOf("test.schema.function"), + FunctionKind.SCALAR, + parseTypeSignature("bigint"), + ImmutableList.of(parseTypeSignature("bigint"))); + + RestFunctionHandle handle = new RestFunctionHandle( + functionId, + "1.0", + signature); // No execution endpoint + + RemoteScalarFunctionImplementation implementation = new RemoteScalarFunctionImplementation( + handle, + CPP, + REST); + + PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(BIGINT)); + 
pageBuilder.declarePosition(); + BIGINT.writeLong(pageBuilder.getBlockBuilder(0), 10); + Page input = pageBuilder.build(); + + SqlFunctionResult result = executor.executeFunction( + "test-source", + implementation, + input, + ImmutableList.of(0), + ImmutableList.of(BIGINT), + BIGINT).get(); + + assertTrue(result.getResult().getPositionCount() == 1); + assertEquals(BIGINT.getLong(result.getResult(), 0), 42L); + assertNotNull(capturedRequest.get()); + URI uri = capturedRequest.get().getUri(); + assertEquals(uri.getHost(), "default-server.example.com"); + assertTrue(uri.getPath().contains("/v1/functions/schema/function/")); + } + + @Test + public void testRoutesToCustomExecutionEndpoint() + throws Exception + { + SqlFunctionId functionId = new SqlFunctionId( + QualifiedObjectName.valueOf("test.schema.function"), + ImmutableList.of(parseTypeSignature("bigint"))); + + Signature signature = new Signature( + QualifiedObjectName.valueOf("test.schema.function"), + FunctionKind.SCALAR, + parseTypeSignature("bigint"), + ImmutableList.of(parseTypeSignature("bigint"))); + + RestFunctionHandle handle = new RestFunctionHandle( + functionId, + "1.0", + signature, + Optional.of(URI.create("https://compute-cluster-1.example.com"))); + + RemoteScalarFunctionImplementation implementation = new RemoteScalarFunctionImplementation( + handle, + CPP, + REST); + + PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(BIGINT)); + pageBuilder.declarePosition(); + BIGINT.writeLong(pageBuilder.getBlockBuilder(0), 10); + Page input = pageBuilder.build(); + + SqlFunctionResult result = executor.executeFunction( + "test-source", + implementation, + input, + ImmutableList.of(0), + ImmutableList.of(BIGINT), + BIGINT).get(); + + assertTrue(result.getResult().getPositionCount() == 1); + assertEquals(BIGINT.getLong(result.getResult(), 0), 42L); + assertNotNull(capturedRequest.get()); + URI uri = capturedRequest.get().getUri(); + assertEquals(uri.getScheme(), "https"); + assertEquals(uri.getHost(), 
"compute-cluster-1.example.com"); + assertTrue(uri.getPath().contains("/v1/functions/schema/function/")); + } + + @Test + public void testMultipleFunctionsRouteToDifferentServers() + throws Exception + { + // Test that different functions can route to different execution servers + SqlFunctionId functionId1 = new SqlFunctionId( + QualifiedObjectName.valueOf("test.schema.function1"), + ImmutableList.of(parseTypeSignature("bigint"))); + + SqlFunctionId functionId2 = new SqlFunctionId( + QualifiedObjectName.valueOf("test.schema.function2"), + ImmutableList.of(parseTypeSignature("bigint"))); + + Signature signature1 = new Signature( + QualifiedObjectName.valueOf("test.schema.function1"), + FunctionKind.SCALAR, + parseTypeSignature("bigint"), + ImmutableList.of(parseTypeSignature("bigint"))); + + Signature signature2 = new Signature( + QualifiedObjectName.valueOf("test.schema.function2"), + FunctionKind.SCALAR, + parseTypeSignature("bigint"), + ImmutableList.of(parseTypeSignature("bigint"))); + + RestFunctionHandle handle1 = new RestFunctionHandle( + functionId1, + "1.0", + signature1, + Optional.of(URI.create("https://server1.example.com"))); + + RestFunctionHandle handle2 = new RestFunctionHandle( + functionId2, + "1.0", + signature2, + Optional.of(URI.create("https://server2.example.com"))); + + RemoteScalarFunctionImplementation implementation1 = new RemoteScalarFunctionImplementation( + handle1, + CPP, + REST); + + RemoteScalarFunctionImplementation implementation2 = new RemoteScalarFunctionImplementation( + handle2, + CPP, + REST); + + PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(BIGINT)); + pageBuilder.declarePosition(); + BIGINT.writeLong(pageBuilder.getBlockBuilder(0), 10); + Page input = pageBuilder.build(); + + // Execute function 1 + executor.executeFunction("test-source", implementation1, input, ImmutableList.of(0), ImmutableList.of(BIGINT), BIGINT).get(); + assertEquals(capturedRequest.get().getUri().getHost(), "server1.example.com"); + + // 
Execute function 2 + executor.executeFunction("test-source", implementation2, input, ImmutableList.of(0), ImmutableList.of(BIGINT), BIGINT).get(); + assertEquals(capturedRequest.get().getUri().getHost(), "server2.example.com"); + } +} diff --git a/presto-function-server/pom.xml b/presto-function-server/pom.xml index 55ca029d7923a..bf1e4f8da5042 100644 --- a/presto-function-server/pom.xml +++ b/presto-function-server/pom.xml @@ -5,13 +5,18 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT presto-function-server - presto-function-namespace-managers + presto-function-server jar + + 17 + true + + @@ -34,11 +39,6 @@ presto-common - - com.facebook.presto - presto-function-namespace-managers - - com.facebook.airlift node @@ -90,8 +90,8 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -110,18 +110,13 @@ - javax.ws.rs - javax.ws.rs-api + jakarta.ws.rs + jakarta.ws.rs-api - com.google.code.findbugs - jsr305 - - - - com.facebook.drift - drift-api + com.google.errorprone + error_prone_annotations @@ -140,6 +135,13 @@ ${project.version} + + + com.facebook.presto + presto-function-namespace-managers + test + + com.facebook.presto presto-tests @@ -164,26 +166,6 @@ - - org.apache.maven.plugins - maven-shade-plugin - - - package - - shade - - - presto-function-server-executable - - - com.facebook.presto.server.FunctionServer - - - - - - org.gaul modernizer-maven-plugin @@ -195,4 +177,59 @@ + + + + executable-jar + + + !skipExecutableJar + + + + + + org.apache.maven.plugins + maven-shade-plugin + + + package + + shade + + + true + executable + false + + + + com.facebook.presto.server.FunctionServer + + + + + + + + + org.skife.maven + really-executable-jar-maven-plugin + + -Xms128m + executable + + + + package + + really-executable-jar + + + + + + + + diff --git a/presto-function-server/src/main/java/com/facebook/presto/server/FunctionPluginManager.java 
b/presto-function-server/src/main/java/com/facebook/presto/server/FunctionPluginManager.java index ce337f9d96da7..4631ffa801814 100644 --- a/presto-function-server/src/main/java/com/facebook/presto/server/FunctionPluginManager.java +++ b/presto-function-server/src/main/java/com/facebook/presto/server/FunctionPluginManager.java @@ -20,13 +20,12 @@ import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.google.common.collect.ImmutableList; import com.google.common.collect.Ordering; +import com.google.errorprone.annotations.ThreadSafe; import io.airlift.resolver.ArtifactResolver; import io.airlift.resolver.DefaultArtifact; +import jakarta.inject.Inject; import org.sonatype.aether.artifact.Artifact; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.io.File; import java.io.IOException; import java.net.URL; @@ -61,7 +60,7 @@ public class FunctionPluginManager .add("com.fasterxml.jackson.annotation.") .add("com.fasterxml.jackson.module.afterburner.") .add("io.airlift.slice.") - .add("io.airlift.units.") + .add("com.facebook.airlift.units.") .add("org.openjdk.jol.") .add("com.facebook.presto.common") .add("com.facebook.drift.annotations.") diff --git a/presto-function-server/src/main/java/com/facebook/presto/server/FunctionResource.java b/presto-function-server/src/main/java/com/facebook/presto/server/FunctionResource.java index 0d560ae5665f3..6c1e9744fa1be 100644 --- a/presto-function-server/src/main/java/com/facebook/presto/server/FunctionResource.java +++ b/presto-function-server/src/main/java/com/facebook/presto/server/FunctionResource.java @@ -39,17 +39,16 @@ import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; import io.airlift.slice.Slices; - -import javax.inject.Inject; -import javax.ws.rs.Consumes; -import javax.ws.rs.GET; -import javax.ws.rs.HEAD; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import 
javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; +import jakarta.inject.Inject; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HEAD; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.io.UnsupportedEncodingException; import java.lang.invoke.MethodHandle; @@ -141,7 +140,8 @@ private static JsonBasedUdfFunctionMetadata sqlFunctionToMetadata(SqlFunction fu function.getSignature().getArgumentTypes())), Optional.of("1"), Optional.of(function.getSignature().getTypeVariableConstraints()), - Optional.ofNullable(function.getSignature().getLongVariableConstraints())); + Optional.ofNullable(function.getSignature().getLongVariableConstraints()), + Optional.empty()); } @GET @@ -223,26 +223,41 @@ public byte[] execute( SerializedPage serializedPage = readSerializedPage(new BasicSliceInput(slice)); Page inputPage = pagesSerde.deserialize(serializedPage); - // Use functionId to retrieve argument types List argumentTypeSignatures = extractArgumentTypeSignatures(functionId); + Type[] types = new Type[argumentTypeSignatures.size()]; + Block[] blocks = new Block[argumentTypeSignatures.size()]; + for (int i = 0; i < argumentTypeSignatures.size(); i++) { + types[i] = manager.getType(argumentTypeSignatures.get(i).getTypeSignature()); + blocks[i] = inputPage.getBlock(i); + } FunctionHandle functionHandle = manager.lookupFunction(functionName, argumentTypeSignatures); BuiltInScalarFunctionImplementation functionImplementation = (BuiltInScalarFunctionImplementation) manager.getJavaScalarFunctionImplementation(functionHandle); - Object[] inputValues = new Object[inputPage.getChannelCount()]; - for (int i = 0; i < inputPage.getChannelCount(); i++) { - TypeSignatureProvider typeSignatureProvider = argumentTypeSignatures.get(i); - Type type = 
manager.getType(typeSignatureProvider.getTypeSignature()); + int positionCount = inputPage.getPositionCount(); + int channelCount = inputPage.getChannelCount(); + Type returnType = manager.getType(manager.getFunctionMetadata(functionHandle).getReturnType()); + PageBuilder pageBuilder = new PageBuilder(Collections.singletonList(returnType)); - inputValues[i] = deserializeBlock(type, inputPage.getBlock(i)); + for (int position = 0; position < positionCount; position++) { + Object[] inputValues = new Object[blocks.length]; + for (int i = 0; i < blocks.length; i++) { + if (blocks[i].isNull(position)) { + inputValues[i] = null; + } + else { + inputValues[i] = deserializeBlock(types[i], blocks[i].getRegion(position, 1)); + } + } + Object result = executeFunction(functionImplementation, inputValues); + pageBuilder.declarePosition(); + BlockBuilder output = pageBuilder.getBlockBuilder(0); + createResultBlock(output, returnType, result); } - Object result = executeFunction(functionImplementation, inputValues); - Type returnType = manager.getType(manager.getFunctionMetadata(functionHandle).getReturnType()); - Page outputPage = createResultPage(returnType, result); + Page outputPage = pageBuilder.build(); DynamicSliceOutput sliceOutput = new DynamicSliceOutput((int) outputPage.getRetainedSizeInBytes()); writeSerializedPage(sliceOutput, pagesSerde.serialize(outputPage)); - return sliceOutput.slice().byteArray(); } @@ -310,11 +325,8 @@ private static PagesSerde createPagesSerde() Optional.empty()); } - private Page createResultPage(Type type, Object result) + private void createResultBlock(BlockBuilder output, Type type, Object result) { - PageBuilder pageBuilder = new PageBuilder(Collections.singletonList(type)); - pageBuilder.declarePosition(); - BlockBuilder output = pageBuilder.getBlockBuilder(0); switch (type.getTypeSignature().getBase()) { case "integer": case "bigint": @@ -358,7 +370,6 @@ else if (result instanceof byte[]) { type.writeObject(output, result); break; } - 
return pageBuilder.build(); } @HEAD diff --git a/presto-function-server/src/main/java/com/facebook/presto/server/FunctionServerModule.java b/presto-function-server/src/main/java/com/facebook/presto/server/FunctionServerModule.java index ceb03b55a6650..fef77d2489302 100644 --- a/presto-function-server/src/main/java/com/facebook/presto/server/FunctionServerModule.java +++ b/presto-function-server/src/main/java/com/facebook/presto/server/FunctionServerModule.java @@ -26,6 +26,7 @@ import com.facebook.presto.functionNamespace.JsonBasedUdfFunctionMetadata; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.TableFunctionRegistry; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.facebook.presto.transaction.NoOpTransactionManager; @@ -53,6 +54,7 @@ protected void setup(Binder binder) { jaxrsBinder(binder).bind(FunctionResource.class); binder.bind(FunctionAndTypeManager.class).in(Scopes.SINGLETON); + binder.bind(TableFunctionRegistry.class).in(Scopes.SINGLETON); binder.bind(TransactionManager.class).to(NoOpTransactionManager.class).in(Scopes.SINGLETON); binder.bind(HandleResolver.class).in(Scopes.SINGLETON); install(new InternalCommunicationModule()); diff --git a/presto-function-server/src/test/java/com/facebook/presto/tests/TestRestRemoteFunctions.java b/presto-function-server/src/test/java/com/facebook/presto/tests/TestRestRemoteFunctions.java index a538d59a1c099..ad0165da96778 100644 --- a/presto-function-server/src/test/java/com/facebook/presto/tests/TestRestRemoteFunctions.java +++ b/presto-function-server/src/test/java/com/facebook/presto/tests/TestRestRemoteFunctions.java @@ -41,13 +41,22 @@ public class TestRestRemoteFunctions extends AbstractTestQueryFramework { - private TestingFunctionServer functionServer; private static final Session session = testSessionBuilder() .setSource("test") 
.setCatalog("tpch") .setSchema("tiny") .setSystemProperty("remote_functions_enabled", "true") .build(); + private TestingFunctionServer functionServer; + + private static int findRandomPort() + throws IOException + { + try (ServerSocket socket = new ServerSocket(0)) { + return socket.getLocalPort(); + } + } + @Override protected QueryRunner createQueryRunner() throws Exception @@ -59,6 +68,12 @@ protected QueryRunner createQueryRunner() ImmutableMap.of("list-built-in-functions-only", "false")); } + @Override + protected Session getSession() + { + return session; + } + @Test public void testShowFunction() { @@ -106,6 +121,19 @@ public void testFunctionPlugins() "false"); } + @Test + public void testRemoteFunctionAppliedToColumn() + { + assertQueryWithSameQueryRunner( + "SELECT rest.default.floor(totalprice) FROM orders", + "SELECT floor(totalprice) FROM orders"); + assertQueryWithSameQueryRunner( + "SELECT rest.default.abs(discount) FROM tpch.sf1.lineitem", + "SELECT abs(discount) FROM tpch.sf1.lineitem"); + assertEquals(computeActual("SELECT rest.default.length(CAST(comment AS VARBINARY)) FROM tpch.sf1.orders") + .getMaterializedRows().size(), 1500000); + } + private static final class DummyPlugin implements Plugin { @@ -126,12 +154,4 @@ public static boolean isPositive(@SqlType(BIGINT) long input) return input > 0; } } - - private static int findRandomPort() - throws IOException - { - try (ServerSocket socket = new ServerSocket(0)) { - return socket.getLocalPort(); - } - } } diff --git a/presto-geospatial-toolkit/pom.xml b/presto-geospatial-toolkit/pom.xml index 791da798a1510..083b5a00a16ae 100644 --- a/presto-geospatial-toolkit/pom.xml +++ b/presto-geospatial-toolkit/pom.xml @@ -4,14 +4,16 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-geospatial-toolkit + presto-geospatial-toolkit Presto - Geospatial utilities ${project.parent.basedir} + true @@ -56,8 +58,8 @@ - com.google.code.findbugs - jsr305 + jakarta.annotation + 
jakarta.annotation-api true diff --git a/presto-geospatial-toolkit/src/main/java/com/facebook/presto/geospatial/serde/EsriGeometrySerde.java b/presto-geospatial-toolkit/src/main/java/com/facebook/presto/geospatial/serde/EsriGeometrySerde.java index 8df5b565790ce..e591fc5fd18a0 100644 --- a/presto-geospatial-toolkit/src/main/java/com/facebook/presto/geospatial/serde/EsriGeometrySerde.java +++ b/presto-geospatial-toolkit/src/main/java/com/facebook/presto/geospatial/serde/EsriGeometrySerde.java @@ -38,8 +38,7 @@ import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; import io.airlift.slice.SliceInput; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.nio.ByteBuffer; import java.util.ArrayList; diff --git a/presto-geospatial-toolkit/src/main/java/com/facebook/presto/geospatial/serde/JtsGeometrySerde.java b/presto-geospatial-toolkit/src/main/java/com/facebook/presto/geospatial/serde/JtsGeometrySerde.java index d8d7c43366d60..8f052128ec972 100644 --- a/presto-geospatial-toolkit/src/main/java/com/facebook/presto/geospatial/serde/JtsGeometrySerde.java +++ b/presto-geospatial-toolkit/src/main/java/com/facebook/presto/geospatial/serde/JtsGeometrySerde.java @@ -38,7 +38,7 @@ import static com.facebook.presto.geospatial.GeometryUtils.isEsriNaN; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.google.common.base.Verify.verify; -import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.collect.MoreCollectors.onlyElement; import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE; import static java.lang.Double.NaN; import static java.lang.Double.isNaN; @@ -200,7 +200,7 @@ private static Geometry readPolygon(SliceInput input, boolean multitype) if (multitype) { return GEOMETRY_FACTORY.createMultiPolygon(polygons.toArray(new Polygon[0])); } - return getOnlyElement(polygons); + return polygons.stream().collect(onlyElement()); } private 
static Geometry readGeometryCollection(BasicSliceInput input) diff --git a/presto-geospatial-toolkit/src/test/java/com/facebook/presto/geospatial/TestKdbTree.java b/presto-geospatial-toolkit/src/test/java/com/facebook/presto/geospatial/TestKdbTree.java index 75b5c77724cbb..491232d2e527b 100644 --- a/presto-geospatial-toolkit/src/test/java/com/facebook/presto/geospatial/TestKdbTree.java +++ b/presto-geospatial-toolkit/src/test/java/com/facebook/presto/geospatial/TestKdbTree.java @@ -15,13 +15,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import org.testng.annotations.Test; import java.util.Map; import java.util.Set; import static com.facebook.presto.geospatial.KdbTree.buildKdbTree; +import static com.google.common.collect.MoreCollectors.onlyElement; import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.POSITIVE_INFINITY; import static org.testng.Assert.assertEquals; @@ -69,7 +69,7 @@ private void testSinglePartition(double width, double height) assertEquals(tree.getLeaves().size(), 1); - Map.Entry entry = Iterables.getOnlyElement(tree.getLeaves().entrySet()); + Map.Entry entry = tree.getLeaves().entrySet().stream().collect(onlyElement()); assertEquals(entry.getKey().intValue(), 0); assertEquals(entry.getValue(), Rectangle.getUniverseRectangle()); } diff --git a/presto-google-sheets/pom.xml b/presto-google-sheets/pom.xml index d9c489c5220bb..9479d80ddf25b 100644 --- a/presto-google-sheets/pom.xml +++ b/presto-google-sheets/pom.xml @@ -5,15 +5,18 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-google-sheets + presto-google-sheets Presto - Google Sheets Connector presto-plugin ${project.parent.basedir} + true + 2.0.0 @@ -32,7 +35,7 @@ com.google.apis google-api-services-sheets - v4-rev516-1.23.0 + v4-rev20250616-2.0.0 com.google.guava @@ -79,7 +82,7 @@ - io.airlift + com.facebook.airlift units provided @@ -87,7 +90,7 
@@ com.google.oauth-client google-oauth-client - 1.33.3 + 1.39.0 com.google.http-client @@ -102,19 +105,19 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api - javax.validation - validation-api + jakarta.validation + jakarta.validation-api com.google.http-client google-http-client - 1.27.0 + ${dep.http-client.version} commons-logging @@ -126,7 +129,7 @@ com.google.http-client google-http-client-jackson2 - 1.27.0 + ${dep.http-client.version} @@ -143,7 +146,7 @@ com.google.api-client google-api-client - 1.27.0 + 2.8.0 com.google.guava diff --git a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsClient.java b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsClient.java index 441237054faa0..be31ca870882a 100644 --- a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsClient.java +++ b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsClient.java @@ -30,8 +30,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.UncheckedExecutionException; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.FileInputStream; import java.io.IOException; diff --git a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsConfig.java b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsConfig.java index 3377d0a4e24ac..428af875fa362 100644 --- a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsConfig.java +++ b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsConfig.java @@ -15,11 +15,10 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import 
com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.util.concurrent.TimeUnit; diff --git a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsConnector.java b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsConnector.java index f5855fc43d915..358faefcc66dc 100644 --- a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsConnector.java +++ b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsConnector.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsMetadata.java b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsMetadata.java index 69b73d4c6915e..c077c77c0e323 100644 --- a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsMetadata.java +++ b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsMetadata.java @@ -29,8 +29,7 @@ import com.facebook.presto.spi.connector.ConnectorMetadata; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -82,13 +81,15 @@ public SheetsTableHandle getTableHandle(ConnectorSession session, SchemaTableNam } @Override - public List getTableLayouts( - ConnectorSession session, ConnectorTableHandle table, - Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + 
ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { SheetsTableHandle tableHandle = (SheetsTableHandle) table; ConnectorTableLayout layout = new ConnectorTableLayout(new SheetsTableLayoutHandle(tableHandle)); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -100,21 +101,27 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa @Override public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table) { - Optional connectorTableMetadata = getTableMetadata(((SheetsTableHandle) table).toSchemaTableName()); + Optional connectorTableMetadata = getTableMetadata(session, ((SheetsTableHandle) table).toSchemaTableName()); if (!connectorTableMetadata.isPresent()) { throw new PrestoException(SHEETS_UNKNOWN_TABLE_ERROR, "Metadata not found for table " + ((SheetsTableHandle) table).getTableName()); } return connectorTableMetadata.get(); } - private Optional getTableMetadata(SchemaTableName tableName) + private Optional getTableMetadata(ConnectorSession session, SchemaTableName tableName) { if (!listSchemaNames().contains(tableName.getSchemaName())) { return Optional.empty(); } Optional table = sheetsClient.getTable(tableName.getTableName()); if (table.isPresent()) { - return Optional.of(new ConnectorTableMetadata(tableName, table.get().getColumnsMetadata())); + List columns = table.get().getColumnsMetadata().stream() + .map(column -> column.toBuilder() + .setName(normalizeIdentifier(session, column.getName())) + .build()) + .collect(toImmutableList()); + + return Optional.of(new ConnectorTableMetadata(tableName, columns)); } return Optional.empty(); } @@ -130,7 +137,13 @@ public Map getColumnHandles(ConnectorSession session, Conn ImmutableMap.Builder columnHandles = ImmutableMap.builder(); int index = 0; - for (ColumnMetadata 
column : table.get().getColumnsMetadata()) { + List columns = table.get().getColumnsMetadata().stream() + .map(column -> column.toBuilder() + .setName(normalizeIdentifier(session, column.getName())) + .build()) + .collect(toImmutableList()); + + for (ColumnMetadata column : columns) { columnHandles.put(column.getName(), new SheetsColumnHandle(column.getName(), column.getType(), index)); index++; } @@ -148,8 +161,8 @@ public Map> listTableColumns(ConnectorSess { requireNonNull(prefix, "prefix is null"); ImmutableMap.Builder> columns = ImmutableMap.builder(); - for (SchemaTableName tableName : listTables(session, Optional.of(prefix.getSchemaName()))) { - Optional tableMetadata = getTableMetadata(tableName); + for (SchemaTableName tableName : listTables(session, Optional.ofNullable(prefix.getSchemaName()))) { + Optional tableMetadata = getTableMetadata(session, tableName); // table can disappear during listing operation if (tableMetadata.isPresent()) { columns.put(tableName, tableMetadata.get().getColumns()); diff --git a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsModule.java b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsModule.java index 9ddfb5036963c..f59e3c52c4d6d 100644 --- a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsModule.java +++ b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsModule.java @@ -20,8 +20,7 @@ import com.google.inject.Binder; import com.google.inject.Module; import com.google.inject.Scopes; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; import static com.facebook.airlift.json.JsonBinder.jsonBinder; diff --git a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsSplitManager.java b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsSplitManager.java index e458a2c60bfef..58da2d9249434 100644 
--- a/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsSplitManager.java +++ b/presto-google-sheets/src/main/java/com/facebook/presto/google/sheets/SheetsSplitManager.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.TableNotFoundException; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.Collections; diff --git a/presto-google-sheets/src/test/java/com/facebook/presto/google/sheets/TestSheetsConfig.java b/presto-google-sheets/src/test/java/com/facebook/presto/google/sheets/TestSheetsConfig.java index 64cb0233d22be..22fce99ee0c09 100644 --- a/presto-google-sheets/src/test/java/com/facebook/presto/google/sheets/TestSheetsConfig.java +++ b/presto-google-sheets/src/test/java/com/facebook/presto/google/sheets/TestSheetsConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.google.sheets; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.io.IOException; diff --git a/presto-grpc-api/pom.xml b/presto-grpc-api/pom.xml index 63a50b6b14546..986ea49ef47d5 100644 --- a/presto-grpc-api/pom.xml +++ b/presto-grpc-api/pom.xml @@ -3,7 +3,7 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT 4.0.0 @@ -14,6 +14,7 @@ ${project.parent.basedir} + true @@ -79,8 +80,8 @@ - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api @@ -106,6 +107,9 @@ com.google.protobuf:protoc:3.25.1:exe:${os.detected.classifier} grpc-java io.grpc:protoc-gen-grpc-java:1.64.0:exe:${os.detected.classifier} + + @generated=omit + @@ -122,7 +126,7 @@ - javax.annotation:javax.annotation-api + jakarta.annotation:jakarta.annotation-api diff --git a/presto-grpc-testing-udf-server/pom.xml 
b/presto-grpc-testing-udf-server/pom.xml index 88f00f0faf648..5c235bc2a3de1 100644 --- a/presto-grpc-testing-udf-server/pom.xml +++ b/presto-grpc-testing-udf-server/pom.xml @@ -3,7 +3,7 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT 4.0.0 @@ -14,6 +14,7 @@ ${project.parent.basedir} com.facebook.presto.udf.thrift.TestingThriftUdfServer + true diff --git a/presto-hana/pom.xml b/presto-hana/pom.xml index c9f2e718886ff..31b889efe0165 100644 --- a/presto-hana/pom.xml +++ b/presto-hana/pom.xml @@ -5,15 +5,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-hana + presto-hana Presto - HANA Connector presto-plugin ${project.parent.basedir} + true @@ -48,8 +50,8 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -66,7 +68,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -78,7 +80,7 @@ - io.airlift + com.facebook.airlift units provided @@ -159,7 +161,6 @@ org.jetbrains annotations - 19.0.0 test diff --git a/presto-hana/src/main/java/com/facebook/presto/plugin/hana/HanaClient.java b/presto-hana/src/main/java/com/facebook/presto/plugin/hana/HanaClient.java index 3873eb729f897..16ca9804403e3 100644 --- a/presto-hana/src/main/java/com/facebook/presto/plugin/hana/HanaClient.java +++ b/presto-hana/src/main/java/com/facebook/presto/plugin/hana/HanaClient.java @@ -27,8 +27,7 @@ import com.facebook.presto.spi.SchemaTableName; import com.google.common.base.Joiner; import com.sap.db.jdbc.Driver; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -145,4 +144,10 @@ private static String singleQuote(String literal) // HANA only accepts upper case return "\"" + literal.toUpperCase() + "\""; } + + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return caseSensitiveNameMatchingEnabled ? 
identifier : identifier.toLowerCase(ENGLISH); + } } diff --git a/presto-hdfs-core/pom.xml b/presto-hdfs-core/pom.xml index 248e48d42aa83..b2038bc71d1af 100644 --- a/presto-hdfs-core/pom.xml +++ b/presto-hdfs-core/pom.xml @@ -5,14 +5,16 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT ${project.parent.basedir} + true presto-hdfs-core + presto-hdfs-core com.facebook.presto @@ -20,9 +22,8 @@ - com.google.code.findbugs - jsr305 - true + jakarta.annotation + jakarta.annotation-api @@ -37,7 +38,7 @@ com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache @@ -46,7 +47,7 @@ - io.airlift + com.facebook.airlift units @@ -61,7 +62,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api @@ -79,13 +80,13 @@ - com.facebook.drift + com.facebook.airlift.drift drift-codec test - com.facebook.drift + com.facebook.airlift.drift drift-transport-netty test @@ -96,4 +97,18 @@ test + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + io.airlift:slice + + + + + diff --git a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/BlockLocation.java b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/BlockLocation.java index c3089d8b1fae9..7f96bd00a6fb8 100644 --- a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/BlockLocation.java +++ b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/BlockLocation.java @@ -17,10 +17,9 @@ import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; import com.google.common.collect.ImmutableList; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.io.IOException; import java.util.List; import java.util.Objects; diff --git a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuota.java b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuota.java index f7d26cdba1823..57ea433f11192 100644 --- 
a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuota.java +++ b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuota.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import io.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize; import java.util.Objects; import java.util.Optional; diff --git a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaRequirement.java b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaRequirement.java index fcc2404926d91..ea73cc01e9e38 100644 --- a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaRequirement.java +++ b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaRequirement.java @@ -13,9 +13,12 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.DataSize; import java.util.Objects; import java.util.Optional; @@ -24,6 +27,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; +@ThriftStruct public class CacheQuotaRequirement { public static final CacheQuotaRequirement NO_CACHE_REQUIREMENT = new CacheQuotaRequirement(GLOBAL, Optional.empty()); @@ -32,6 +36,7 @@ public class CacheQuotaRequirement private final Optional quota; @JsonCreator + @ThriftConstructor public CacheQuotaRequirement( @JsonProperty("cacheQuotaScope") CacheQuotaScope cacheQuotaScope, @JsonProperty("quota") Optional quota) @@ -41,12 +46,14 @@ public CacheQuotaRequirement( } @JsonProperty + @ThriftField(1) public CacheQuotaScope getCacheQuotaScope() { return cacheQuotaScope; } @JsonProperty + @ThriftField(2) public Optional getQuota() { return quota; diff --git 
a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaScope.java b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaScope.java index 2c8dddd3b6ba7..c10f05acebd4b 100644 --- a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaScope.java +++ b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaScope.java @@ -13,7 +13,28 @@ */ package com.facebook.presto.hive; +import com.facebook.drift.annotations.ThriftEnum; +import com.facebook.drift.annotations.ThriftEnumValue; + +@ThriftEnum public enum CacheQuotaScope { - GLOBAL, SCHEMA, TABLE, PARTITION + GLOBAL(0), + SCHEMA(1), + TABLE(2), + PARTITION(3), + /**/; + + private final int value; + + CacheQuotaScope(int value) + { + this.value = value; + } + + @ThriftEnumValue + public int getValue() + { + return value; + } } diff --git a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/HdfsContext.java b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/HdfsContext.java index fe30292219ea3..0ba86159e0047 100644 --- a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/HdfsContext.java +++ b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/HdfsContext.java @@ -38,8 +38,8 @@ public class HdfsContext private final Optional session; /** - * Table information is expected to be provided when accessing a storage. - * Do not use this constructor. + * Table information is expected to be provided when accessing a storage. + * Do not use this constructor. */ @Deprecated public HdfsContext(ConnectorIdentity identity) @@ -58,8 +58,8 @@ public HdfsContext(ConnectorIdentity identity) } /** - * Table information is expected to be provided when accessing a storage. - * Do not use this constructor. + * Table information is expected to be provided when accessing a storage. + * Do not use this constructor. 
*/ @Deprecated public HdfsContext(ConnectorSession session) @@ -68,10 +68,10 @@ public HdfsContext(ConnectorSession session) } /** - * Table information is expected to be provided when accessing a storage. - * Currently the only legit use case for this constructor is the schema - * level operations (e.g.: create/drop schema) or drop a view - * Do not use this constructor for any other use cases. + * Table information is expected to be provided when accessing a storage. + * Currently the only legit use case for this constructor is the schema + * level operations (e.g.: create/drop schema) or drop a view + * Do not use this constructor for any other use cases. */ @Deprecated public HdfsContext(ConnectorSession session, String schemaName) @@ -80,8 +80,8 @@ public HdfsContext(ConnectorSession session, String schemaName) } /** - * Table information is expected to be provided when accessing a storage. - * Do not use this constructor. + * Table information is expected to be provided when accessing a storage. + * Do not use this constructor. 
*/ @Deprecated public HdfsContext(ConnectorSession session, String schemaName, String tableName) @@ -110,6 +110,7 @@ public HdfsContext( Optional.of(isNewTable), Optional.of(isPathValidationNeeded)); } + public HdfsContext( ConnectorSession session, String schemaName, @@ -141,6 +142,7 @@ private HdfsContext( isNewTable, Optional.empty()); } + private HdfsContext( ConnectorSession session, Optional schemaName, @@ -149,14 +151,41 @@ private HdfsContext( Optional isNewTable, Optional isPathValidationNeeded) { - this.session = Optional.of(requireNonNull(session, "session is null")); - this.identity = requireNonNull(session.getIdentity(), "session.getIdentity() is null"); - this.source = requireNonNull(session.getSource(), "session.getSource() is null"); - this.queryId = Optional.of(session.getQueryId()); + this( + Optional.of(requireNonNull(session, "session is null")), + session.getIdentity(), + session.getSource(), + Optional.of(session.getQueryId()), + schemaName, + tableName, + session.getClientInfo(), + Optional.of(session.getClientTags()), + tablePath, + isNewTable, + isPathValidationNeeded); + } + + public HdfsContext( + Optional session, + ConnectorIdentity identity, + Optional source, + Optional queryId, + Optional schemaName, + Optional tableName, + Optional clientInfo, + Optional> clientTags, + Optional tablePath, + Optional isNewTable, + Optional isPathValidationNeeded) + { + this.session = session; + this.identity = requireNonNull(identity, "identity is null"); + this.source = requireNonNull(source, "source is null"); + this.queryId = requireNonNull(queryId, "queryId is null"); this.schemaName = requireNonNull(schemaName, "schemaName is null"); this.tableName = requireNonNull(tableName, "tableName is null"); - this.clientInfo = session.getClientInfo(); - this.clientTags = Optional.of(session.getClientTags()); + this.clientInfo = requireNonNull(clientInfo, "clientInfo is null"); + this.clientTags = requireNonNull(clientTags, "clientTags is null"); 
this.tablePath = requireNonNull(tablePath, "tablePath is null"); this.isNewTable = requireNonNull(isNewTable, "isNewTable is null"); this.isPathValidationNeeded = requireNonNull(isPathValidationNeeded, "isPathValidationNeeded is null"); diff --git a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/HiveFileContext.java b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/HiveFileContext.java index e41c91690030f..23565da4e24df 100644 --- a/presto-hdfs-core/src/main/java/com/facebook/presto/hive/HiveFileContext.java +++ b/presto-hdfs-core/src/main/java/com/facebook/presto/hive/HiveFileContext.java @@ -15,9 +15,12 @@ import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.RuntimeUnit; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.security.ConnectorIdentity; import java.util.Optional; import java.util.OptionalLong; +import java.util.Set; import static com.facebook.presto.hive.CacheQuota.NO_CACHE_CONSTRAINTS; import static java.util.Objects.requireNonNull; @@ -47,6 +50,12 @@ public class HiveFileContext private final OptionalLong length; private final long modificationTime; private final boolean verboseRuntimeStatsEnabled; + private final Optional source; + private final Optional identity; + private final Optional queryId; + private final Optional schema; + private final Optional clientInfo; + private final Optional> clientTags; private final RuntimeStats stats; @@ -73,6 +82,71 @@ public HiveFileContext( long modificationTime, boolean verboseRuntimeStatsEnabled, RuntimeStats runtimeStats) + { + this( + cacheable, + cacheQuota, + extraFileInfo, + fileSize, + startOffset, + length, + modificationTime, + verboseRuntimeStatsEnabled, + runtimeStats, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + } + + public HiveFileContext( + boolean cacheable, + CacheQuota cacheQuota, + Optional> extraFileInfo, + OptionalLong fileSize, + 
OptionalLong startOffset, + OptionalLong length, + long modificationTime, + boolean verboseRuntimeStatsEnabled, + RuntimeStats runtimeStats, + ConnectorSession session) + { + this( + cacheable, + cacheQuota, + extraFileInfo, + fileSize, + startOffset, + length, + modificationTime, + verboseRuntimeStatsEnabled, + runtimeStats, + session.getSource(), + Optional.of(session.getIdentity()), + Optional.of(session.getQueryId()), + session.getSchema(), + session.getClientInfo(), + Optional.of(session.getClientTags())); + } + + public HiveFileContext( + boolean cacheable, + CacheQuota cacheQuota, + Optional> extraFileInfo, + OptionalLong fileSize, + OptionalLong startOffset, + OptionalLong length, + long modificationTime, + boolean verboseRuntimeStatsEnabled, + RuntimeStats runtimeStats, + Optional source, + Optional identity, + Optional queryId, + Optional schema, + Optional clientInfo, + Optional> clientTags) { this.cacheable = cacheable; this.cacheQuota = requireNonNull(cacheQuota, "cacheQuota is null"); @@ -83,6 +157,12 @@ public HiveFileContext( this.modificationTime = modificationTime; this.verboseRuntimeStatsEnabled = verboseRuntimeStatsEnabled; this.stats = requireNonNull(runtimeStats, "runtimeStats is null"); + this.source = requireNonNull(source, "source is null"); + this.identity = requireNonNull(identity, "identity is null"); + this.queryId = requireNonNull(queryId, "queryId is null"); + this.schema = requireNonNull(schema, "schema is null"); + this.clientInfo = requireNonNull(clientInfo, "clientInfo is null"); + this.clientTags = requireNonNull(clientTags, "clientTags is null"); } /** @@ -142,4 +222,34 @@ public RuntimeStats getStats() { return stats; } + + public Optional getSource() + { + return source; + } + + public Optional getIdentity() + { + return identity; + } + + public Optional getQueryId() + { + return queryId; + } + + public Optional getSchema() + { + return schema; + } + + public Optional getClientInfo() + { + return clientInfo; + } + + public 
Optional> getClientTags() + { + return clientTags; + } } diff --git a/presto-hdfs-core/src/main/java/org/apache/hadoop/fs/HadoopExtendedFileSystemCache.java b/presto-hdfs-core/src/main/java/org/apache/hadoop/fs/HadoopExtendedFileSystemCache.java index 09dd855aff7e2..2c768e41903ac 100644 --- a/presto-hdfs-core/src/main/java/org/apache/hadoop/fs/HadoopExtendedFileSystemCache.java +++ b/presto-hdfs-core/src/main/java/org/apache/hadoop/fs/HadoopExtendedFileSystemCache.java @@ -13,10 +13,6 @@ */ package org.apache.hadoop.fs; -import java.lang.reflect.Field; -import java.lang.reflect.Method; -import java.lang.reflect.Modifier; - public class HadoopExtendedFileSystemCache { private static PrestoExtendedFileSystemCache cache; @@ -26,49 +22,9 @@ private HadoopExtendedFileSystemCache() {} public static synchronized void initialize() { if (cache == null) { - cache = setFinalStatic(FileSystem.class, "CACHE", new PrestoExtendedFileSystemCache()); - } - } - - private static T setFinalStatic(Class clazz, String name, T value) - { - try { - Field field = clazz.getDeclaredField(name); - field.setAccessible(true); - - Field modifiersField = getModifiersField(); - modifiersField.setAccessible(true); - modifiersField.setInt(field, field.getModifiers() & ~Modifier.FINAL); - - field.set(null, value); - - return value; - } - catch (ReflectiveOperationException e) { - throw new AssertionError(e); - } - } - - private static Field getModifiersField() throws NoSuchFieldException - { - try { - return Field.class.getDeclaredField("modifiers"); - } - catch (NoSuchFieldException e) { - try { - Method getDeclaredFields0 = Class.class.getDeclaredMethod("getDeclaredFields0", boolean.class); - getDeclaredFields0.setAccessible(true); - Field[] fields = (Field[]) getDeclaredFields0.invoke(Field.class, false); - for (Field field : fields) { - if ("modifiers".equals(field.getName())) { - return field; - } - } - } - catch (ReflectiveOperationException ex) { - e.addSuppressed(ex); - } - throw e; + 
PrestoExtendedFileSystemCache newCache = new PrestoExtendedFileSystemCache(); + FileSystem.setCache(newCache); + cache = newCache; } } } diff --git a/presto-hdfs-core/src/test/java/com/facebook/presto/hive/TestHiveFileInfo.java b/presto-hdfs-core/src/test/java/com/facebook/presto/hive/TestHiveFileInfo.java index 751c622636d87..0cb4643b16927 100644 --- a/presto-hdfs-core/src/test/java/com/facebook/presto/hive/TestHiveFileInfo.java +++ b/presto-hdfs-core/src/test/java/com/facebook/presto/hive/TestHiveFileInfo.java @@ -14,12 +14,12 @@ package com.facebook.presto.hive; import com.facebook.airlift.http.client.thrift.ThriftProtocolUtils; +import com.facebook.airlift.units.DataSize; import com.facebook.drift.codec.ThriftCodec; import com.facebook.drift.codec.ThriftCodecManager; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.BlockLocation; import org.apache.hadoop.fs.LocatedFileStatus; import org.apache.hadoop.fs.Path; @@ -28,9 +28,9 @@ import java.io.IOException; import java.util.Optional; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.drift.transport.netty.codec.Protocol.FB_COMPACT; import static com.facebook.presto.hive.HiveFileInfo.createHiveFileInfo; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.Math.toIntExact; import static org.testng.Assert.assertEquals; diff --git a/presto-hive-common/pom.xml b/presto-hive-common/pom.xml index e0c4aa6111838..ce37e2f1f1e11 100644 --- a/presto-hive-common/pom.xml +++ b/presto-hive-common/pom.xml @@ -5,14 +5,16 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT ${project.parent.basedir} + true presto-hive-common + presto-hive-common com.facebook.presto @@ -35,11 +37,21 @@ - com.google.code.findbugs - jsr305 + com.facebook.airlift + stats + + + + com.google.errorprone + error_prone_annotations true + + 
jakarta.annotation + jakarta.annotation-api + + io.airlift slice @@ -61,13 +73,8 @@ - com.facebook.presto.hadoop - hadoop-apache2 - - - - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -81,7 +88,7 @@ - io.airlift + com.facebook.airlift units @@ -91,13 +98,29 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api + + + + software.amazon.awssdk + metrics-spi + + + + software.amazon.awssdk + sdk-core + + + + org.weakref + jmxutils com.facebook.presto presto-hdfs-core + test @@ -108,13 +131,13 @@ - com.facebook.drift + com.facebook.airlift.drift drift-codec test - com.facebook.drift + com.facebook.airlift.drift drift-transport-netty test @@ -124,5 +147,11 @@ http-client test + + + com.facebook.presto.hadoop + hadoop-apache + test + diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/BaseHiveTableLayoutHandle.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/BaseHiveTableLayoutHandle.java index 264df0281689a..f855cf408e0b6 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/BaseHiveTableLayoutHandle.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/BaseHiveTableLayoutHandle.java @@ -37,7 +37,7 @@ public class BaseHiveTableLayoutHandle private final TupleDomain partitionColumnPredicate; // coordinator-only properties - private final Optional> partitions; + private final Optional partitions; public BaseHiveTableLayoutHandle( List partitionColumns, @@ -45,7 +45,7 @@ public BaseHiveTableLayoutHandle( RowExpression remainingPredicate, boolean pushdownFilterEnabled, TupleDomain partitionColumnPredicate, - Optional> partitions) + Optional partitions) { this.partitionColumns = ImmutableList.copyOf(requireNonNull(partitionColumns, "partitionColumns is null")); this.domainPredicate = requireNonNull(domainPredicate, "domainPredicate is null"); @@ -91,7 +91,7 @@ public TupleDomain getPartitionColumnPredicate() * @return list of partitions if available, {@code 
Optional.empty()} if dropped */ @JsonIgnore - public Optional> getPartitions() + public Optional getPartitions() { return partitions; } diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveCommonClientConfig.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveCommonClientConfig.java index 9a2ec46dd00f9..1e54401247e2f 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveCommonClientConfig.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveCommonClientConfig.java @@ -15,16 +15,15 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.orc.OrcWriteValidation.OrcWriteValidationMode; import com.facebook.presto.spi.schedule.NodeSelectionStrategy; -import io.airlift.units.DataSize; - -import javax.validation.constraints.DecimalMax; -import javax.validation.constraints.DecimalMin; -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.DecimalMax; +import jakarta.validation.constraints.DecimalMin; +import jakarta.validation.constraints.NotNull; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; public class HiveCommonClientConfig { @@ -38,6 +37,7 @@ public class HiveCommonClientConfig private DataSize orcStreamBufferSize = new DataSize(8, MEGABYTE); private OrcWriteValidationMode orcWriterValidationMode = OrcWriteValidationMode.BOTH; private double orcWriterValidationPercentage; + private boolean useOrcColumnNames; private DataSize orcTinyStripeThreshold = new DataSize(8, MEGABYTE); private boolean parquetBatchReadOptimizationEnabled; private boolean parquetEnableBatchReaderVerification; @@ -184,6 +184,19 @@ public HiveCommonClientConfig setOrcWriterValidationPercentage(double 
orcWriterV return this; } + public boolean isUseOrcColumnNames() + { + return useOrcColumnNames; + } + + @Config("hive.orc.use-column-names") + @ConfigDescription("Access ORC columns using names from the file first, and fallback to Hive schema column names if not found to ensure backward compatibility with old data") + public HiveCommonClientConfig setUseOrcColumnNames(boolean useOrcColumnNames) + { + this.useOrcColumnNames = useOrcColumnNames; + return this; + } + @NotNull public DataSize getOrcTinyStripeThreshold() { diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveCommonSessionProperties.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveCommonSessionProperties.java index b67955d5b7024..b505cd37d8188 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveCommonSessionProperties.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveCommonSessionProperties.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.orc.OrcWriteValidation.OrcWriteValidationMode; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; @@ -20,9 +21,7 @@ import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.concurrent.ThreadLocalRandom; @@ -43,6 +42,8 @@ public class HiveCommonSessionProperties public static final String RANGE_FILTERS_ON_SUBSCRIPTS_ENABLED = "range_filters_on_subscripts_enabled"; @VisibleForTesting public static final String PARQUET_BATCH_READ_OPTIMIZATION_ENABLED = "parquet_batch_read_optimization_enabled"; + @VisibleForTesting + public static final String ORC_USE_COLUMN_NAMES = "orc_use_column_names"; public static final String 
NODE_SELECTION_STRATEGY = "node_selection_strategy"; private static final String ORC_BLOOM_FILTERS_ENABLED = "orc_bloom_filters_enabled"; @@ -154,6 +155,11 @@ public HiveCommonSessionProperties(HiveCommonClientConfig hiveCommonClientConfig "use JNI based zstd decompression for reading ORC files", hiveCommonClientConfig.isZstdJniDecompressionEnabled(), true), + booleanProperty( + ORC_USE_COLUMN_NAMES, + "Access ORC columns using names from the file first, and fallback to Hive schema column names if not found to ensure backward compatibility with old data", + hiveCommonClientConfig.isUseOrcColumnNames(), + false), booleanProperty( PARQUET_BATCH_READ_OPTIMIZATION_ENABLED, "Is Parquet batch read optimization enabled", @@ -263,6 +269,11 @@ public static boolean isOrcZstdJniDecompressionEnabled(ConnectorSession session) return session.getProperty(ORC_ZSTD_JNI_DECOMPRESSION_ENABLED, Boolean.class); } + public static boolean isUseOrcColumnNames(ConnectorSession session) + { + return session.getProperty(ORC_USE_COLUMN_NAMES, Boolean.class); + } + public static boolean isParquetBatchReadsEnabled(ConnectorSession session) { return session.getProperty(PARQUET_BATCH_READ_OPTIMIZATION_ENABLED, Boolean.class); diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveErrorCode.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveErrorCode.java index bee080b4e64ba..aa122e1c28f1f 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveErrorCode.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveErrorCode.java @@ -76,6 +76,7 @@ public enum HiveErrorCode HIVE_RANGER_SERVER_ERROR(48, EXTERNAL), HIVE_FUNCTION_INITIALIZATION_ERROR(49, EXTERNAL), HIVE_METASTORE_INITIALIZE_SSL_ERROR(50, EXTERNAL), + UNKNOWN_TABLE_TYPE(51, EXTERNAL), /**/; private final ErrorCode errorCode; diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveOutputInfo.java 
b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveOutputInfo.java new file mode 100644 index 0000000000000..d2d2085857574 --- /dev/null +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveOutputInfo.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.hive; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class HiveOutputInfo +{ + private final List partitionNames; + private final String tableLocation; + + @JsonCreator + public HiveOutputInfo( + @JsonProperty("partitionNames") List partitionNames, + @JsonProperty("tableLocation") String tableLocation) + { + this.partitionNames = ImmutableList.copyOf(requireNonNull(partitionNames, "partitionNames is null")); + this.tableLocation = requireNonNull(tableLocation, "tableLocation is null"); + } + + @JsonProperty + public List getPartitionNames() + { + return partitionNames; + } + + @JsonProperty + public String getTableLocation() + { + return tableLocation; + } + + @Override + public boolean equals(Object o) + { + if (o == null || getClass() != o.getClass()) { + return false; + } + HiveOutputInfo that = (HiveOutputInfo) o; + return Objects.equals(partitionNames, that.partitionNames) && 
Objects.equals(tableLocation, that.tableLocation); + } + + @Override + public int hashCode() + { + return Objects.hash(partitionNames, tableLocation); + } + + @Override + public String toString() + { + return "HiveOutputInfo{" + + "partitionNames=" + partitionNames + + ", tableLocation='" + tableLocation + '\'' + + '}'; + } +} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWrittenPartitions.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveOutputMetadata.java similarity index 68% rename from presto-hive/src/main/java/com/facebook/presto/hive/HiveWrittenPartitions.java rename to presto-hive-common/src/main/java/com/facebook/presto/hive/HiveOutputMetadata.java index a01676c204066..f78f4f174f125 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWrittenPartitions.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveOutputMetadata.java @@ -16,26 +16,24 @@ import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.collect.ImmutableList; - -import java.util.List; import static java.util.Objects.requireNonNull; -public class HiveWrittenPartitions +public class HiveOutputMetadata implements ConnectorOutputMetadata { - private final List partitionNames; + private final HiveOutputInfo hiveOutputInfo; @JsonCreator - public HiveWrittenPartitions(@JsonProperty("partitionNames") List partitionNames) + public HiveOutputMetadata(@JsonProperty("hiveOutputInfo") HiveOutputInfo hiveOutputInfo) { - this.partitionNames = ImmutableList.copyOf(requireNonNull(partitionNames, "partitionNames is null")); + this.hiveOutputInfo = requireNonNull(hiveOutputInfo, "hiveOutputInfo is null"); } @JsonProperty - public List getInfo() + @Override + public HiveOutputInfo getInfo() { - return partitionNames; + return hiveOutputInfo; } } diff --git 
a/presto-hive-common/src/main/java/com/facebook/presto/hive/MetadataUtils.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/MetadataUtils.java index 67e4c4201968c..6da270538c586 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/MetadataUtils.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/MetadataUtils.java @@ -29,8 +29,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.HashSet; import java.util.List; @@ -55,16 +54,24 @@ public final class MetadataUtils private static final String CATALOG_DB_THRIFT_NAME_MARKER = "@"; private static final String DB_EMPTY_MARKER = "!"; private static final String DEFAULT_DATABASE = "default"; + private MetadataUtils() {} - public static Optional getDiscretePredicates(List partitionColumns, List partitions) + public static Optional getDiscretePredicates(List partitionColumns, Iterable partitions) { Optional discretePredicates = Optional.empty(); - if (!partitionColumns.isEmpty() && !(partitions.size() == 1 && partitions.get(0).getPartitionId().equals(UNPARTITIONED_ID))) { + if (!partitionColumns.isEmpty()) { // Do not create tuple domains for every partition at the same time! // There can be a huge number of partitions so use an iterable so // all domains do not need to be in memory at the same time. 
- Iterable> partitionDomains = Iterables.transform(partitions, (hivePartition) -> TupleDomain.fromFixedValues(hivePartition.getKeys())); + Iterable> partitionDomains = Iterables.transform(partitions, (hivePartition) -> { + if (hivePartition.getPartitionId().equals(UNPARTITIONED_ID)) { + return TupleDomain.all(); + } + else { + return TupleDomain.fromFixedValues(hivePartition.getKeys()); + } + }); discretePredicates = Optional.of(new DiscretePredicates(partitionColumns, partitionDomains)); } return discretePredicates; diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/PartitionNameWithVersion.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/PartitionNameWithVersion.java index e86afe9064adc..6597d5f29e60a 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/PartitionNameWithVersion.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/PartitionNameWithVersion.java @@ -16,8 +16,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ComparisonChain; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; import java.util.Optional; diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/PartitionSet.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/PartitionSet.java new file mode 100644 index 0000000000000..ae6148c612202 --- /dev/null +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/PartitionSet.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive; + +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.ThreadSafe; + +import java.util.Iterator; +import java.util.List; + +import static java.util.Objects.requireNonNull; + +@ThreadSafe +public class PartitionSet + implements Iterable +{ + private volatile boolean fullyLoaded; + private PartitionLoader partitionLoader; + private List partitions; + + public PartitionSet(List partitions) + { + this.partitions = requireNonNull(partitions, "partitions is null"); + this.fullyLoaded = true; + } + + public PartitionSet(PartitionLoader partitionLoader) + { + this.partitionLoader = requireNonNull(partitionLoader, "partitionLoader is null"); + } + + @Override + public Iterator iterator() + { + return new LazyIterator(this); + } + + public List getFullyLoadedPartitions() + { + tryFullyLoad(); + return partitions; + } + + public boolean isEmpty() + { + if (fullyLoaded) { + return partitions.isEmpty(); + } + else { + synchronized (this) { + if (fullyLoaded) { + return partitions.isEmpty(); + } + return partitionLoader.isEmpty(); + } + } + } + + private void tryFullyLoad() + { + if (!fullyLoaded) { + synchronized (this) { + if (!fullyLoaded) { + partitions = ImmutableList.copyOf(partitionLoader.loadPartitions()); + fullyLoaded = true; + partitionLoader = null; + } + } + } + } + + public interface PartitionLoader + { + List loadPartitions(); + + boolean isEmpty(); + } + + private static class LazyIterator + extends AbstractIterator + 
{ + private final PartitionSet lazyPartitions; + private List partitions; + private int position = -1; + + private LazyIterator(PartitionSet lazyPartitions) + { + this.lazyPartitions = lazyPartitions; + } + + @Override + protected HivePartition computeNext() + { + if (partitions == null) { + partitions = lazyPartitions.getFullyLoadedPartitions(); + } + + position++; + if (position >= partitions.size()) { + return endOfData(); + } + return partitions.get(position); + } + } +} diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/SortingFileWriterConfig.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/SortingFileWriterConfig.java index 29e835c7fd777..cd327cf022afb 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/SortingFileWriterConfig.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/SortingFileWriterConfig.java @@ -16,14 +16,13 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.DataSize; -import io.airlift.units.MaxDataSize; -import io.airlift.units.MinDataSize; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.MaxDataSize; +import com.facebook.airlift.units.MinDataSize; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; -import javax.validation.constraints.Max; -import javax.validation.constraints.Min; - -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class SortingFileWriterConfig { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/UnknownTableTypeException.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/UnknownTableTypeException.java similarity index 68% rename from presto-iceberg/src/main/java/com/facebook/presto/iceberg/UnknownTableTypeException.java rename to 
presto-hive-common/src/main/java/com/facebook/presto/hive/UnknownTableTypeException.java index b7d6323ec6ff2..cb740e975b535 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/UnknownTableTypeException.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/UnknownTableTypeException.java @@ -11,18 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.iceberg; +package com.facebook.presto.hive; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.SchemaTableName; -import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_UNKNOWN_TABLE_TYPE; +import static com.facebook.presto.hive.HiveErrorCode.UNKNOWN_TABLE_TYPE; public class UnknownTableTypeException extends PrestoException { - public UnknownTableTypeException(SchemaTableName tableName) + public UnknownTableTypeException(String message) { - super(ICEBERG_UNKNOWN_TABLE_TYPE, "Not an Iceberg table: " + tableName); + super(UNKNOWN_TABLE_TYPE, message); } } diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/metrics/AwsSdkClientStats.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/metrics/AwsSdkClientStats.java new file mode 100644 index 0000000000000..5b7fd14ee9261 --- /dev/null +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/metrics/AwsSdkClientStats.java @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive.aws.metrics; + +import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.stats.TimeStat; +import com.google.errorprone.annotations.ThreadSafe; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; +import software.amazon.awssdk.metrics.MetricCollection; +import software.amazon.awssdk.metrics.MetricPublisher; + +import java.time.Duration; + +import static java.time.Duration.ZERO; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static software.amazon.awssdk.core.internal.metrics.SdkErrorType.THROTTLING; +import static software.amazon.awssdk.core.metrics.CoreMetric.API_CALL_DURATION; +import static software.amazon.awssdk.core.metrics.CoreMetric.BACKOFF_DELAY_DURATION; +import static software.amazon.awssdk.core.metrics.CoreMetric.ERROR_TYPE; +import static software.amazon.awssdk.core.metrics.CoreMetric.RETRY_COUNT; +import static software.amazon.awssdk.core.metrics.CoreMetric.SERVICE_CALL_DURATION; + +/** + * For reference on AWS SDK v2 Metrics: https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/metrics-list.html + * Metrics Publisher: https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/metrics.html + */ +@ThreadSafe +public final class AwsSdkClientStats +{ + private final CounterStat awsRequestCount = new CounterStat(); + private final CounterStat awsRetryCount = new CounterStat(); + private final CounterStat awsThrottleExceptions = new CounterStat(); + private final TimeStat awsServiceCallDuration = new TimeStat(MILLISECONDS); + private final TimeStat awsApiCallDuration = new TimeStat(MILLISECONDS); + private final TimeStat awsBackoffDelayDuration = new TimeStat(MILLISECONDS); + + @Managed + @Nested + public CounterStat getAwsRequestCount() + { + return awsRequestCount; + } + + @Managed + @Nested + 
public CounterStat getAwsRetryCount() + { + return awsRetryCount; + } + + @Managed + @Nested + public CounterStat getAwsThrottleExceptions() + { + return awsThrottleExceptions; + } + + @Managed + @Nested + public TimeStat getAwsServiceCallDuration() + { + return awsServiceCallDuration; + } + + @Managed + @Nested + public TimeStat getAwsApiCallDuration() + { + return awsApiCallDuration; + } + + @Managed + @Nested + public TimeStat getAwsBackoffDelayDuration() + { + return awsBackoffDelayDuration; + } + + public AwsSdkClientRequestMetricsPublisher newRequestMetricsPublisher() + { + return new AwsSdkClientRequestMetricsPublisher(this); + } + + public static class AwsSdkClientRequestMetricsPublisher + implements MetricPublisher + { + private final AwsSdkClientStats stats; + + protected AwsSdkClientRequestMetricsPublisher(AwsSdkClientStats stats) + { + this.stats = requireNonNull(stats, "stats is null"); + } + + @Override + public void publish(MetricCollection metricCollection) + { + long requestCount = metricCollection.metricValues(RETRY_COUNT) + .stream() + .map(i -> i + 1) + .reduce(Integer::sum).orElse(0); + + stats.awsRequestCount.update(requestCount); + + long retryCount = metricCollection.metricValues(RETRY_COUNT) + .stream() + .reduce(Integer::sum).orElse(0); + + stats.awsRetryCount.update(retryCount); + + long throttleExceptions = metricCollection + .childrenWithName("ApiCallAttempt") + .flatMap(mc -> mc.metricValues(ERROR_TYPE).stream()) + .filter(s -> s.equals(THROTTLING.toString())) + .count(); + + stats.awsThrottleExceptions.update(throttleExceptions); + + Duration serviceCallDuration = metricCollection + .childrenWithName("ApiCallAttempt") + .flatMap(mc -> mc.metricValues(SERVICE_CALL_DURATION).stream()) + .reduce(Duration::plus).orElse(ZERO); + + stats.awsServiceCallDuration.add(serviceCallDuration.toMillis(), MILLISECONDS); + + Duration apiCallDuration = metricCollection + .metricValues(API_CALL_DURATION) + .stream().reduce(Duration::plus).orElse(ZERO); 
+ + stats.awsApiCallDuration.add(apiCallDuration.toMillis(), MILLISECONDS); + + Duration backoffDelayDuration = metricCollection + .childrenWithName("ApiCallAttempt") + .flatMap(mc -> mc.metricValues(BACKOFF_DELAY_DURATION).stream()) + .reduce(Duration::plus).orElse(ZERO); + + stats.awsBackoffDelayDuration.add(backoffDelayDuration.toMillis(), MILLISECONDS); + } + + @Override + public void close() + { + } + } +} diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/security/AWSSecurityMappingConfig.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/security/AWSSecurityMappingConfig.java index e66bc59b757b2..799bae622b06f 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/security/AWSSecurityMappingConfig.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/security/AWSSecurityMappingConfig.java @@ -15,11 +15,10 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.annotation.Nullable; -import javax.validation.constraints.AssertTrue; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.AssertTrue; import java.io.File; import java.util.Optional; diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/security/AWSSecurityMappingsSupplier.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/security/AWSSecurityMappingsSupplier.java index 73d6052d01586..cfecf3aeacd11 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/security/AWSSecurityMappingsSupplier.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/aws/security/AWSSecurityMappingsSupplier.java @@ -14,8 +14,8 @@ package com.facebook.presto.hive.aws.security; import com.facebook.airlift.log.Logger; 
+import com.facebook.airlift.units.Duration; import com.google.common.base.Suppliers; -import io.airlift.units.Duration; import java.io.File; import java.util.Optional; diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/rule/BaseSubfieldExtractionRewriter.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/rule/BaseSubfieldExtractionRewriter.java index 1894a49afe965..657cb732737fa 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/rule/BaseSubfieldExtractionRewriter.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/rule/BaseSubfieldExtractionRewriter.java @@ -345,7 +345,8 @@ private Constraint extractDeterministicConjuncts( RowExpression deterministicPredicate = logicalRowExpressions.filterDeterministicConjuncts(decomposedFilter.getRemainingExpression()); if (!TRUE_CONSTANT.equals(deterministicPredicate)) { ConstraintEvaluator evaluator = new ConstraintEvaluator(rowExpressionService, session, columnHandles, deterministicPredicate); - constraint = new Constraint<>(entireColumnDomain, evaluator::isCandidate); + List predicateInputs = ImmutableList.builder().addAll(evaluator.getArguments()).build(); + constraint = new Constraint<>(entireColumnDomain, Optional.of(evaluator::isCandidate), Optional.of(predicateInputs)); } } return constraint; @@ -417,6 +418,11 @@ public ConstraintEvaluator( .collect(toImmutableSet()); } + public Set getArguments() + { + return arguments; + } + private boolean isCandidate(Map bindings) { if (intersection(bindings.keySet(), arguments).isEmpty()) { diff --git a/presto-hive-common/src/test/java/com/facebook/presto/hive/TestHiveCommonClientConfig.java b/presto-hive-common/src/test/java/com/facebook/presto/hive/TestHiveCommonClientConfig.java index b543c456cd308..42579c3048b92 100644 --- a/presto-hive-common/src/test/java/com/facebook/presto/hive/TestHiveCommonClientConfig.java +++ b/presto-hive-common/src/test/java/com/facebook/presto/hive/TestHiveCommonClientConfig.java 
@@ -14,16 +14,16 @@ package com.facebook.presto.hive; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.orc.OrcWriteValidation; import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Map; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.HARD_AFFINITY; -import static io.airlift.units.DataSize.Unit.MEGABYTE; public class TestHiveCommonClientConfig { @@ -45,6 +45,7 @@ public void testDefaults() .setOrcOptimizedWriterEnabled(true) .setOrcWriterValidationPercentage(0.0) .setOrcWriterValidationMode(OrcWriteValidation.OrcWriteValidationMode.BOTH) + .setUseOrcColumnNames(false) .setZstdJniDecompressionEnabled(false) .setParquetBatchReaderVerificationEnabled(false) .setParquetBatchReadOptimizationEnabled(false) @@ -71,6 +72,7 @@ public void testExplicitPropertyMappings() .put("hive.orc.optimized-writer.enabled", "false") .put("hive.orc.writer.validation-percentage", "0.16") .put("hive.orc.writer.validation-mode", "DETAILED") + .put("hive.orc.use-column-names", "true") .put("hive.zstd-jni-decompression-enabled", "true") .put("hive.enable-parquet-batch-reader-verification", "true") .put("hive.parquet-batch-read-optimization-enabled", "true") @@ -94,6 +96,7 @@ public void testExplicitPropertyMappings() .setOrcOptimizedWriterEnabled(false) .setOrcWriterValidationPercentage(0.16) .setOrcWriterValidationMode(OrcWriteValidation.OrcWriteValidationMode.DETAILED) + .setUseOrcColumnNames(true) .setZstdJniDecompressionEnabled(true) .setParquetBatchReaderVerificationEnabled(true) .setParquetBatchReadOptimizationEnabled(true) diff --git a/presto-hive-common/src/test/java/com/facebook/presto/hive/TestHiveFileInfo.java 
b/presto-hive-common/src/test/java/com/facebook/presto/hive/TestHiveFileInfo.java index 751c622636d87..0cb4643b16927 100644 --- a/presto-hive-common/src/test/java/com/facebook/presto/hive/TestHiveFileInfo.java +++ b/presto-hive-common/src/test/java/com/facebook/presto/hive/TestHiveFileInfo.java @@ -14,12 +14,12 @@ package com.facebook.presto.hive; import com.facebook.airlift.http.client.thrift.ThriftProtocolUtils; +import com.facebook.airlift.units.DataSize; import com.facebook.drift.codec.ThriftCodec; import com.facebook.drift.codec.ThriftCodecManager; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.BlockLocation; import org.apache.hadoop.fs.LocatedFileStatus; import org.apache.hadoop.fs.Path; @@ -28,9 +28,9 @@ import java.io.IOException; import java.util.Optional; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.drift.transport.netty.codec.Protocol.FB_COMPACT; import static com.facebook.presto.hive.HiveFileInfo.createHiveFileInfo; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.Math.toIntExact; import static org.testng.Assert.assertEquals; diff --git a/presto-hive-common/src/test/java/com/facebook/presto/hive/aws/security/TestAWSSecurityMappingConfig.java b/presto-hive-common/src/test/java/com/facebook/presto/hive/aws/security/TestAWSSecurityMappingConfig.java index c1d730d980ef4..90f289e689115 100644 --- a/presto-hive-common/src/test/java/com/facebook/presto/hive/aws/security/TestAWSSecurityMappingConfig.java +++ b/presto-hive-common/src/test/java/com/facebook/presto/hive/aws/security/TestAWSSecurityMappingConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.hive.aws.security; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import 
java.io.IOException; diff --git a/presto-hive-function-namespace/pom.xml b/presto-hive-function-namespace/pom.xml index 404048ee5a636..0ebef79cc7dc4 100644 --- a/presto-hive-function-namespace/pom.xml +++ b/presto-hive-function-namespace/pom.xml @@ -4,17 +4,29 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT presto-hive-function-namespace + presto-hive-function-namespace Hive functions for Presto presto-plugin ${project.parent.basedir} + true + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + com.google.guava @@ -22,13 +34,13 @@ - com.google.code.findbugs - jsr305 + jakarta.annotation + jakarta.annotation-api com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache @@ -45,8 +57,8 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -67,7 +79,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -79,7 +91,7 @@ - io.airlift + com.facebook.airlift units provided @@ -109,11 +121,6 @@ guice - - org.jetbrains - annotations - - org.openjdk.jol jol-core @@ -127,6 +134,12 @@ test + + org.jetbrains + annotations + test + + com.facebook.airlift testing diff --git a/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/ForHiveFunction.java b/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/ForHiveFunction.java index b3db7a981d638..3b8ffc28b6f69 100644 --- a/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/ForHiveFunction.java +++ b/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/ForHiveFunction.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive.functions; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/HiveFunctionNamespaceManager.java 
b/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/HiveFunctionNamespaceManager.java index 3ec1ec255856b..d68e613218795 100644 --- a/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/HiveFunctionNamespaceManager.java +++ b/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/HiveFunctionNamespaceManager.java @@ -38,14 +38,13 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import org.apache.hadoop.hive.ql.exec.UDAF; import org.apache.hadoop.hive.ql.exec.UDF; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver2; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; -import javax.inject.Inject; - import java.util.Collection; import java.util.List; import java.util.Objects; @@ -234,6 +233,12 @@ public AggregationFunctionImplementation getAggregateFunctionImplementation(Func return ((HiveAggregationFunction) function).getImplementation(); } + @Override + public String getCatalogName() + { + return catalogName; + } + private static class DummyHiveFunction extends HiveFunction { diff --git a/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/StaticHiveFunctionRegistry.java b/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/StaticHiveFunctionRegistry.java index 40018734b7b2d..0eae64f3c579c 100644 --- a/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/StaticHiveFunctionRegistry.java +++ b/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/StaticHiveFunctionRegistry.java @@ -16,10 +16,9 @@ import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import jakarta.inject.Inject; import 
org.apache.hadoop.hive.ql.parse.SemanticException; -import javax.inject.Inject; - import static com.facebook.presto.hive.functions.FunctionRegistry.getFunctionInfo; import static java.util.Objects.requireNonNull; diff --git a/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/gen/CompilerOperations.java b/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/gen/CompilerOperations.java index c4faeb7213442..4ec502a65e922 100644 --- a/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/gen/CompilerOperations.java +++ b/presto-hive-function-namespace/src/main/java/com/facebook/presto/hive/functions/gen/CompilerOperations.java @@ -14,8 +14,7 @@ package com.facebook.presto.hive.functions.gen; import com.facebook.presto.common.block.Block; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Optional; diff --git a/presto-hive-hadoop2/bin/common.sh b/presto-hive-hadoop2/bin/common.sh index 1b13e40732a7f..dc338858b99e7 100755 --- a/presto-hive-hadoop2/bin/common.sh +++ b/presto-hive-hadoop2/bin/common.sh @@ -33,8 +33,8 @@ function check_hadoop() { HADOOP_MASTER_CONTAINER=$(hadoop_master_container) docker exec ${HADOOP_MASTER_CONTAINER} supervisorctl status hive-server2 | grep -iq running && \ docker exec ${HADOOP_MASTER_CONTAINER} supervisorctl status hive-metastore | grep -iq running && \ - docker exec ${HADOOP_MASTER_CONTAINER} netstat -lpn | grep -iq 0.0.0.0:10000 && - docker exec ${HADOOP_MASTER_CONTAINER} netstat -lpn | grep -iq 0.0.0.0:9083 + docker exec ${HADOOP_MASTER_CONTAINER} netstat -lpn | grep -iq :10000 && + docker exec ${HADOOP_MASTER_CONTAINER} netstat -lpn | grep -iq :9083 } function exec_in_hadoop_master_container() { @@ -67,8 +67,8 @@ function termination_handler(){ exit 130 } -export HADOOP_BASE_IMAGE="${HADOOP_BASE_IMAGE:-prestodb/hdp2.6-hive}" -export DOCKER_IMAGES_VERSION=${DOCKER_IMAGES_VERSION:-5} +export 
HADOOP_BASE_IMAGE="${HADOOP_BASE_IMAGE:-prestodb/hdp3.1-hive}" +export DOCKER_IMAGES_VERSION=${DOCKER_IMAGES_VERSION:-11} SCRIPT_DIR="${BASH_SOURCE%/*}" INTEGRATION_TESTS_ROOT="${SCRIPT_DIR}/.." diff --git a/presto-hive-hadoop2/conf/docker-compose.yml b/presto-hive-hadoop2/conf/docker-compose.yml index c5cc888cce663..4984dce4a6c1f 100644 --- a/presto-hive-hadoop2/conf/docker-compose.yml +++ b/presto-hive-hadoop2/conf/docker-compose.yml @@ -20,4 +20,5 @@ services: - ./files/words:/usr/share/dict/words:ro - ./files/core-site.xml.s3-template:/etc/hadoop/conf/core-site.xml.s3-template:ro - ./files/hive-site.xml.s3-template:/etc/hive/conf/hive-site.xml.s3-template:ro + - ./files/tez-site.xml:/etc/tez/conf/tez-site.xml:ro - ./files:/tmp/files:ro \ No newline at end of file diff --git a/presto-hive-hadoop2/conf/files/tez-site.xml b/presto-hive-hadoop2/conf/files/tez-site.xml new file mode 100644 index 0000000000000..69e06e472e5c7 --- /dev/null +++ b/presto-hive-hadoop2/conf/files/tez-site.xml @@ -0,0 +1,100 @@ + + + + + + + tez.lib.uris.ignore + false + + + tez.lib.uris + file:///usr/hdp/current/tez-client/lib/tez.tar.gz + + + tez.am.mode.session + false + + + tez.am.acl.enabled + false + + + tez.am.log.level + WARN + + + tez.task.log.level + WARN + + + tez.runtime.io.sort.mb + 8 + + + tez.am.max.app.attempts + 1 + + + tez.am.task.max.failed.attempts + 1 + + + tez.shuffle-vertex-manager.min-src-fraction + 0.10 + + + tez.shuffle-vertex-manager.max-src-fraction + 1.00 + + + tez.am.launch.cmd-opts + -server -Djava.net.preferIPv4Stack=true -XX:+UseParallelGC -Dhadoop.metrics.log.level=WARN + + + tez.am.resource.memory.mb + 512 + + + tez.task.launch.cmd-opts + -server -Djava.net.preferIPv4Stack=true -XX:+UseParallelGC -Dhadoop.metrics.log.level=WARN + + + tez.task.resource.memory.mb + 512 + + + tez.task.resource.cpu.vcores + 1 + + + tez.runtime.sort.threads + 1 + + + tez.runtime.io.sort.factor + 100 + + + tez.runtime.shuffle.memory-to-memory.enable + false + + + 
tez.runtime.optimize.local.fetch + true + + + hive.tez.container.size + 2048 + + \ No newline at end of file diff --git a/presto-hive-hadoop2/pom.xml b/presto-hive-hadoop2/pom.xml index 3fbc5bf164705..8d3cb2288845f 100644 --- a/presto-hive-hadoop2/pom.xml +++ b/presto-hive-hadoop2/pom.xml @@ -5,17 +5,30 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-hive-hadoop2 + presto-hive-hadoop2 Presto - Hive Connector - Apache Hadoop 2.x presto-plugin ${project.parent.basedir} + 17 + true + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + com.facebook.presto @@ -39,10 +52,20 @@ com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache runtime + + com.facebook.airlift + json + + + + com.facebook.presto + presto-hdfs-core + + com.facebook.presto @@ -57,7 +80,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -69,7 +92,7 @@ - io.airlift + com.facebook.airlift units provided @@ -157,6 +180,27 @@ + + + + org.apache.maven.plugins + maven-dependency-plugin + + + com.google.guava:guava + com.facebook.airlift:concurrent + com.facebook.airlift:json + com.facebook.airlift:stats + com.facebook.presto:presto-cache + com.facebook.presto:presto-hive-common + com.facebook.presto:presto-hive-metastore + com.facebook.presto:presto-hdfs-core + + + + + + default diff --git a/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/S3SelectTestHelper.java b/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/S3SelectTestHelper.java index e28d632f64354..2ef51d76014d3 100644 --- a/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/S3SelectTestHelper.java +++ b/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/S3SelectTestHelper.java @@ -30,7 +30,6 @@ import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.hive.HiveCommonClientConfig; import com.facebook.presto.hive.HiveEncryptionInformationProvider; -import com.facebook.presto.hive.HiveFileRenamer; 
import com.facebook.presto.hive.HiveHdfsConfiguration; import com.facebook.presto.hive.HiveLocationService; import com.facebook.presto.hive.HiveMetadataFactory; @@ -178,7 +177,6 @@ public S3SelectTestHelper(String host, new HivePartitionObjectBuilder(), new HiveEncryptionInformationProvider(ImmutableSet.of()), new HivePartitionStats(), - new HiveFileRenamer(), columnConverterProvider, new QuickStatsProvider(metastoreClient, hdfsEnvironment, DO_NOTHING_DIRECTORY_LISTER, new HiveClientConfig(), new NamenodeStats(), ImmutableList.of()), new HiveTableWritabilityChecker(config)); diff --git a/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/TestHiveFileSystemS3SelectCsvPushdownWithSplits.java b/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/TestHiveFileSystemS3SelectCsvPushdownWithSplits.java index d5b5e5b3e7e85..9449fe2682815 100644 --- a/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/TestHiveFileSystemS3SelectCsvPushdownWithSplits.java +++ b/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/TestHiveFileSystemS3SelectCsvPushdownWithSplits.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.hive.s3select; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.HiveClientConfig; import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.testing.MaterializedResult; -import io.airlift.units.DataSize; import org.testng.annotations.BeforeClass; import org.testng.annotations.DataProvider; import org.testng.annotations.Parameters; @@ -26,12 +26,12 @@ import java.util.Optional; import static com.facebook.airlift.testing.Assertions.assertEqualsIgnoreOrder; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR; import static com.facebook.presto.hive.HiveFileSystemTestUtils.newSession; import static 
com.facebook.presto.hive.HiveType.HIVE_LONG; import static com.facebook.presto.hive.s3select.S3SelectTestHelper.expectedResult; import static com.facebook.presto.hive.s3select.S3SelectTestHelper.isSplitCountInOpenInterval; -import static io.airlift.units.DataSize.Unit.KILOBYTE; import static org.testng.Assert.assertTrue; public class TestHiveFileSystemS3SelectCsvPushdownWithSplits diff --git a/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/TestHiveFileSystemS3SelectJsonPushdownWithSplits.java b/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/TestHiveFileSystemS3SelectJsonPushdownWithSplits.java index 7be723cdac203..5c3adc1db663d 100644 --- a/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/TestHiveFileSystemS3SelectJsonPushdownWithSplits.java +++ b/presto-hive-hadoop2/src/test/java/com/facebook/presto/hive/s3select/TestHiveFileSystemS3SelectJsonPushdownWithSplits.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.hive.s3select; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.HiveClientConfig; import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.testing.MaterializedResult; -import io.airlift.units.DataSize; import org.testng.annotations.BeforeClass; import org.testng.annotations.DataProvider; import org.testng.annotations.Parameters; @@ -26,12 +26,12 @@ import java.util.Optional; import static com.facebook.airlift.testing.Assertions.assertEqualsIgnoreOrder; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR; import static com.facebook.presto.hive.HiveFileSystemTestUtils.newSession; import static com.facebook.presto.hive.HiveType.HIVE_LONG; import static com.facebook.presto.hive.s3select.S3SelectTestHelper.expectedResult; import static 
com.facebook.presto.hive.s3select.S3SelectTestHelper.isSplitCountInOpenInterval; -import static io.airlift.units.DataSize.Unit.KILOBYTE; import static org.testng.Assert.assertTrue; public class TestHiveFileSystemS3SelectJsonPushdownWithSplits diff --git a/presto-hive-metastore/pom.xml b/presto-hive-metastore/pom.xml index 85ca1695fe5e9..6f859b4a2b7d2 100644 --- a/presto-hive-metastore/pom.xml +++ b/presto-hive-metastore/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-hive-metastore @@ -13,8 +13,19 @@ ${project.parent.basedir} + true + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + com.facebook.presto @@ -62,7 +73,7 @@ - io.airlift + com.facebook.airlift units @@ -76,6 +87,17 @@ security + + com.google.errorprone + error_prone_annotations + true + + + + jakarta.annotation + jakarta.annotation-api + + com.fasterxml.jackson.core jackson-core @@ -102,13 +124,63 @@ - com.amazonaws - aws-java-sdk-glue + software.amazon.awssdk + auth - com.amazonaws - aws-java-sdk-sts + software.amazon.awssdk + aws-core + + + + software.amazon.awssdk + glue + + + + software.amazon.awssdk + http-client-spi + + + + software.amazon.awssdk + metrics-spi + + + + software.amazon.awssdk + netty-nio-client + + + + software.amazon.awssdk + regions + + + + software.amazon.awssdk + sdk-core + + + + software.amazon.awssdk + sts + + + + software.amazon.awssdk + utils + + + + software.amazon.awssdk + retries-spi + + + + software.amazon.awssdk + retries @@ -127,8 +199,13 @@ - javax.validation - validation-api + jakarta.inject + jakarta.inject-api + + + + jakarta.validation + jakarta.validation-api @@ -158,7 +235,7 @@ com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache provided diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForCachingHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForCachingHiveMetastore.java index b4cdb281658a8..bf677f6e33c81 100644 --- 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForCachingHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForCachingHiveMetastore.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForMetastoreHdfsEnvironment.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForMetastoreHdfsEnvironment.java index 54be368395ee5..cbac7bddbff09 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForMetastoreHdfsEnvironment.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForMetastoreHdfsEnvironment.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForRecordingHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForRecordingHiveMetastore.java index 3777fa7969377..a1f9b9d25bc74 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForRecordingHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/ForRecordingHiveMetastore.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HdfsEnvironment.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HdfsEnvironment.java index 58301f0d1aa04..a1de65ea082fa 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HdfsEnvironment.java +++ 
b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HdfsEnvironment.java @@ -17,13 +17,12 @@ import com.facebook.presto.hive.authentication.GenericExceptionAction; import com.facebook.presto.hive.authentication.HdfsAuthentication; import com.facebook.presto.hive.filesystem.ExtendedFileSystem; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.HadoopExtendedFileSystemCache; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; - import java.io.IOException; import static com.google.common.base.Preconditions.checkState; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveBasicStatistics.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveBasicStatistics.java index 4d8112fe73173..98f33b32283ad 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveBasicStatistics.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveBasicStatistics.java @@ -15,8 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; import java.util.OptionalLong; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveColumnConverterProvider.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveColumnConverterProvider.java index c9cca843ca11b..8395c812770ed 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveColumnConverterProvider.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveColumnConverterProvider.java @@ -14,8 +14,7 @@ package com.facebook.presto.hive; import com.facebook.presto.hive.metastore.HiveColumnConverter; - -import javax.inject.Inject; +import jakarta.inject.Inject; public class HiveColumnConverterProvider 
implements ColumnConverterProvider diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveStatisticsUtil.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveStatisticsUtil.java index b138a899c94cf..574e369464ec9 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveStatisticsUtil.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveStatisticsUtil.java @@ -21,10 +21,8 @@ import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; import com.facebook.presto.spi.statistics.ComputedStatistics; import com.google.common.base.VerifyException; -import com.google.common.collect.ImmutableMap; import org.joda.time.DateTimeZone; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalLong; @@ -89,13 +87,6 @@ public static PartitionStatistics createPartitionStatistics( return createPartitionStatistics(session, columnTypes, computedStatistics, computedStatistics.getColumnStatistics().keySet(), timeZone); } - public static Map getColumnStatistics(Map, ComputedStatistics> statistics, List partitionValues) - { - return Optional.ofNullable(statistics.get(partitionValues)) - .map(ComputedStatistics::getColumnStatistics) - .orElse(ImmutableMap.of()); - } - // TODO: Collect file count, on-disk size and in-memory size during ANALYZE /** diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveStorageFormat.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveStorageFormat.java index 2ed329b09d403..9e8577cc52e57 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveStorageFormat.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/HiveStorageFormat.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; import com.facebook.presto.hive.metastore.StorageFormat; -import 
io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat; import org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat; import org.apache.hadoop.hive.ql.io.RCFileInputFormat; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/MetastoreClientConfig.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/MetastoreClientConfig.java index 950e65202f3ea..a58a1edfa5d09 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/MetastoreClientConfig.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/MetastoreClientConfig.java @@ -15,29 +15,49 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.configuration.LegacyConfig; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; import com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheScope; +import com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.net.HostAndPort; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.DecimalMax; -import javax.validation.constraints.DecimalMin; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; - +import com.google.inject.ConfigurationException; +import com.google.inject.spi.Message; +import jakarta.annotation.PostConstruct; +import jakarta.validation.constraints.DecimalMax; +import jakarta.validation.constraints.DecimalMin; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; + +import java.util.Arrays; +import 
java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.Iterables.transform; +import static java.util.Locale.ENGLISH; import static java.util.concurrent.TimeUnit.MINUTES; public class MetastoreClientConfig { + private static final Splitter SPLITTER = Splitter.on(',').trimResults().omitEmptyStrings(); + private HostAndPort metastoreSocksProxy; private Duration metastoreTimeout = new Duration(10, TimeUnit.SECONDS); private boolean verifyChecksum = true; private boolean requireHadoopNative = true; - private Duration metastoreCacheTtl = new Duration(0, TimeUnit.SECONDS); - private Duration metastoreRefreshInterval = new Duration(0, TimeUnit.SECONDS); + private Set enabledCaches = ImmutableSet.of(); + private Set disabledCaches = ImmutableSet.of(); + private Duration defaultMetastoreCacheTtl = new Duration(0, TimeUnit.SECONDS); + private Map metastoreCacheTtlByType = ImmutableMap.of(); + private Duration defaultMetastoreCacheRefreshInterval = new Duration(0, TimeUnit.SECONDS); + private Map metastoreCacheRefreshIntervalByType = ImmutableMap.of(); private long metastoreCacheMaximumSize = 10000; private long perTransactionMetastoreCacheMaximumSize = 1000; private int maxMetastoreRefreshThreads = 100; @@ -91,31 +111,135 @@ public MetastoreClientConfig setVerifyChecksum(boolean verifyChecksum) return this; } + public Set getEnabledCaches() + { + return enabledCaches; + } + + @Config("hive.metastore.cache.enabled-caches") + @ConfigDescription("Comma-separated list of metastore cache types to enable") + public MetastoreClientConfig setEnabledCaches(String caches) + { + if (caches == null) { + this.enabledCaches = ImmutableSet.of(); + return this; + } + + this.enabledCaches = ImmutableSet.copyOf(transform( + SPLITTER.split(caches), + cache -> MetastoreCacheType.valueOf(cache.toUpperCase(ENGLISH)))); + return this; + } + + public Set 
getDisabledCaches() + { + return disabledCaches; + } + + @Config("hive.metastore.cache.disabled-caches") + @ConfigDescription("Comma-separated list of metastore cache types to disable") + public MetastoreClientConfig setDisabledCaches(String caches) + { + if (caches == null) { + this.disabledCaches = ImmutableSet.of(); + return this; + } + + this.disabledCaches = ImmutableSet.copyOf(transform( + SPLITTER.split(caches), + cache -> MetastoreCacheType.valueOf(cache.toUpperCase(ENGLISH)))); + return this; + } + + @PostConstruct + public void isBothEnabledAndDisabledConfigured() + { + if (!getEnabledCaches().isEmpty() && !getDisabledCaches().isEmpty()) { + throw new ConfigurationException(ImmutableList.of(new Message("Only one of 'hive.metastore.cache.enabled-caches' or 'hive.metastore.cache.disabled-caches' can be set. " + + "These configs are mutually exclusive."))); + } + } + @NotNull - public Duration getMetastoreCacheTtl() + public Duration getDefaultMetastoreCacheTtl() { - return metastoreCacheTtl; + return defaultMetastoreCacheTtl; } @MinDuration("0ms") - @Config("hive.metastore-cache-ttl") - public MetastoreClientConfig setMetastoreCacheTtl(Duration metastoreCacheTtl) + @Config("hive.metastore.cache.ttl.default") + @ConfigDescription("Default time-to-live for Hive metastore cache entries. " + + "It is used when no per-cache TTL override is configured. 
" + + "TTL of 0ms would mean cache is disabled.") + @LegacyConfig("hive.metastore-cache-ttl") + public MetastoreClientConfig setDefaultMetastoreCacheTtl(Duration defaultMetastoreCacheTtl) + { + this.defaultMetastoreCacheTtl = defaultMetastoreCacheTtl; + return this; + } + + public Map getMetastoreCacheTtlByType() { - this.metastoreCacheTtl = metastoreCacheTtl; + return metastoreCacheTtlByType; + } + + @Config("hive.metastore.cache.ttl-by-type") + @ConfigDescription("Per-cache time-to-live (TTL) overrides for Hive metastore caches.\n" + + "The value is a comma-separated list of : pairs.") + public MetastoreClientConfig setMetastoreCacheTtlByType(String metastoreCacheTtlByTypeValues) + { + if (metastoreCacheTtlByTypeValues == null || metastoreCacheTtlByTypeValues.isEmpty()) { + return this; + } + + this.metastoreCacheTtlByType = Arrays.stream(metastoreCacheTtlByTypeValues.split(",")) + .map(entry -> entry.split(":")) + .filter(parts -> parts.length == 2) + .collect(toImmutableMap( + parts -> MetastoreCacheType.valueOf(parts[0].trim().toUpperCase(ENGLISH)), + parts -> Duration.valueOf(parts[1].trim()))); + return this; } @NotNull - public Duration getMetastoreRefreshInterval() + public Duration getDefaultMetastoreCacheRefreshInterval() { - return metastoreRefreshInterval; + return defaultMetastoreCacheRefreshInterval; } @MinDuration("1ms") - @Config("hive.metastore-refresh-interval") - public MetastoreClientConfig setMetastoreRefreshInterval(Duration metastoreRefreshInterval) + @Config("hive.metastore.cache.refresh-interval.default") + @ConfigDescription("Default refresh interval for Hive metastore cache entries.\n" + + "Controls how often cached values are asynchronously refreshed.") + @LegacyConfig("hive.metastore-refresh-interval") + public MetastoreClientConfig setDefaultMetastoreCacheRefreshInterval(Duration defaultMetastoreCacheRefreshInterval) { - this.metastoreRefreshInterval = metastoreRefreshInterval; + this.defaultMetastoreCacheRefreshInterval = 
defaultMetastoreCacheRefreshInterval; + return this; + } + + public Map getMetastoreCacheRefreshIntervalByType() + { + return metastoreCacheRefreshIntervalByType; + } + + @Config("hive.metastore.cache.refresh-interval-by-type") + @ConfigDescription("Per-cache refresh interval overrides for Hive metastore caches.\n" + + "The value is a comma-separated list of : pairs.") + public MetastoreClientConfig setMetastoreCacheRefreshIntervalByType(String metastoreCacheRefreshIntervalByTypeValues) + { + if (metastoreCacheRefreshIntervalByTypeValues == null || metastoreCacheRefreshIntervalByTypeValues.isEmpty()) { + return this; + } + + this.metastoreCacheRefreshIntervalByType = Arrays.stream(metastoreCacheRefreshIntervalByTypeValues.split(",")) + .map(entry -> entry.split(":")) + .filter(parts -> parts.length == 2) + .collect(toImmutableMap( + parts -> MetastoreCacheType.valueOf(parts[0].trim().toUpperCase(ENGLISH)), + parts -> Duration.valueOf(parts[1].trim()))); + return this; } diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/RetryDriver.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/RetryDriver.java index 20859c24028dd..340c5103dea34 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/RetryDriver.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/RetryDriver.java @@ -14,8 +14,8 @@ package com.facebook.presto.hive; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; import java.util.ArrayList; import java.util.Arrays; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/aws/AbstractSdkMetricsCollector.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/aws/AbstractSdkMetricsCollector.java index e1dc61fde634f..ccb388659a34c 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/aws/AbstractSdkMetricsCollector.java +++ 
b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/aws/AbstractSdkMetricsCollector.java @@ -18,7 +18,7 @@ import com.amazonaws.metrics.RequestMetricCollector; import com.amazonaws.util.AWSRequestMetrics; import com.amazonaws.util.TimingInfo; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; import java.util.List; import java.util.function.Consumer; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/AbstractCachingHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/AbstractCachingHiveMetastore.java index 1c0df753f20d0..1f1bb61723889 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/AbstractCachingHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/AbstractCachingHiveMetastore.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.metastore; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.type.Type; import com.facebook.presto.hive.HiveType; @@ -20,9 +21,7 @@ import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.statistics.ColumnStatisticType; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.List; import java.util.Map; @@ -40,6 +39,25 @@ public enum MetastoreCacheScope ALL, PARTITION } + public enum MetastoreCacheType + { + ALL, + DATABASE, + DATABASE_NAMES, + TABLE, + TABLE_NAMES, + TABLE_STATISTICS, + TABLE_CONSTRAINTS, + PARTITION, + PARTITION_STATISTICS, + PARTITION_FILTER, + PARTITION_NAMES, + VIEW_NAMES, + TABLE_PRIVILEGES, + ROLES, + ROLE_GRANTS + } + public abstract ExtendedHiveMetastore getDelegate(); @Override diff --git 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/BooleanStatistics.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/BooleanStatistics.java index bbd624698cdaa..4f1c5eb415a56 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/BooleanStatistics.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/BooleanStatistics.java @@ -15,8 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; import java.util.OptionalLong; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Column.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Column.java index 5aa481580d087..bb1e7dc0c512b 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Column.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Column.java @@ -16,10 +16,9 @@ import com.facebook.presto.hive.HiveType; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.errorprone.annotations.Immutable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.concurrent.Immutable; - import java.util.Objects; import java.util.Optional; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Database.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Database.java index 2d751e714c694..ebb8e4764132a 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Database.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Database.java @@ -17,8 +17,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import 
com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableMap; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.LinkedHashMap; import java.util.Map; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DateStatistics.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DateStatistics.java index 8b422ad4f0f85..d36cf226c2cb1 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DateStatistics.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DateStatistics.java @@ -15,8 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.time.LocalDate; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DecimalStatistics.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DecimalStatistics.java index ada80c2c38944..b4d4b789146cb 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DecimalStatistics.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DecimalStatistics.java @@ -15,8 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.math.BigDecimal; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DoubleStatistics.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DoubleStatistics.java index c70ff272a93b8..a0a642a5ede3f 100644 --- 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DoubleStatistics.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/DoubleStatistics.java @@ -15,8 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; import java.util.OptionalDouble; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java index 0f108f67dbf87..4cfdc0d799a86 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/ExtendedHiveMetastore.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.metastore; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.NotSupportedException; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.type.Type; @@ -24,13 +25,13 @@ import com.facebook.presto.spi.security.RoleGrant; import com.facebook.presto.spi.statistics.ColumnStatisticType; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.Function; +import java.util.function.Supplier; public interface ExtendedHiveMetastore { @@ -89,6 +90,8 @@ default void dropTableFromMetastore(MetastoreContext metastoreContext, String da */ MetastoreOperationResult replaceTable(MetastoreContext metastoreContext, String databaseName, String tableName, Table newTable, PrincipalPrivileges principalPrivileges); + MetastoreOperationResult persistTable(MetastoreContext metastoreContext, String databaseName, 
String tableName, Table newTable, PrincipalPrivileges principalPrivileges, Supplier update, Map additionalParameters); + MetastoreOperationResult renameTable(MetastoreContext metastoreContext, String databaseName, String tableName, String newDatabaseName, String newTableName); MetastoreOperationResult addColumn(MetastoreContext metastoreContext, String databaseName, String tableName, String columnName, HiveType columnType, String columnComment); diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveColumnStatistics.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveColumnStatistics.java index df9fcf7ee7782..df477e9b07b49 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveColumnStatistics.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveColumnStatistics.java @@ -15,8 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.math.BigDecimal; import java.time.LocalDate; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveCommitHandle.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveCommitHandle.java index 151b39113e4c5..f0b4258b829ac 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveCommitHandle.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveCommitHandle.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorCommitHandle; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -24,6 +25,7 @@ import static 
java.util.Objects.requireNonNull; +@JsonIgnoreProperties({ "lastDataCommitTimesForRead", "lastDataCommitTimesForWrite" }) public class HiveCommitHandle implements ConnectorCommitHandle { diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePartitionMutator.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePartitionMutator.java index e2efcd85113da..a9b38c6bf21ed 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePartitionMutator.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePartitionMutator.java @@ -15,10 +15,9 @@ import com.facebook.presto.hive.PartitionMutator; import com.facebook.presto.hive.metastore.Partition.Builder; +import jakarta.inject.Inject; import org.apache.hadoop.hive.metastore.api.Partition; -import javax.inject.Inject; - public class HivePartitionMutator implements PartitionMutator { diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePartitionName.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePartitionName.java index ebb348c736498..a7568a02cc352 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePartitionName.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePartitionName.java @@ -17,8 +17,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePrivilegeInfo.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePrivilegeInfo.java index ca5b5a95b7966..dac23a80d7fac 100644 --- 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePrivilegeInfo.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HivePrivilegeInfo.java @@ -19,8 +19,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableSet; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; import java.util.Set; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveTableName.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveTableName.java index f80a945763f52..790b1dc1e71f4 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveTableName.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/HiveTableName.java @@ -15,8 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InMemoryCachingHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InMemoryCachingHiveMetastore.java index a1327c0bc910d..f7b47b3c08734 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InMemoryCachingHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InMemoryCachingHiveMetastore.java @@ -30,12 +30,10 @@ import com.google.common.collect.Iterables; import com.google.common.collect.SetMultimap; import com.google.common.util.concurrent.UncheckedExecutionException; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import 
org.weakref.jmx.Managed; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -50,10 +48,24 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; import static com.facebook.presto.hive.HiveErrorCode.HIVE_CORRUPTED_PARTITION_CACHE; import static com.facebook.presto.hive.HiveErrorCode.HIVE_PARTITION_DROPPED_DURING_QUERY; -import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheScope.ALL; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.DATABASE; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.DATABASE_NAMES; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.PARTITION; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.PARTITION_FILTER; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.PARTITION_NAMES; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.PARTITION_STATISTICS; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.ROLES; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.ROLE_GRANTS; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.TABLE; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.TABLE_CONSTRAINTS; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.TABLE_NAMES; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.TABLE_PRIVILEGES; +import static 
com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.TABLE_STATISTICS; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.VIEW_NAMES; import static com.facebook.presto.hive.metastore.HivePartitionName.hivePartitionName; import static com.facebook.presto.hive.metastore.HiveTableName.hiveTableName; import static com.facebook.presto.hive.metastore.NoopMetastoreCacheStats.NOOP_METASTORE_CACHE_STATS; @@ -95,6 +107,7 @@ public class InMemoryCachingHiveMetastore private final LoadingCache, Set> rolesCache; private final LoadingCache, Set> roleGrantsCache; private final MetastoreCacheStats metastoreCacheStats; + private final MetastoreCacheSpecProvider metastoreCacheSpecProvider; private final boolean metastoreImpersonationEnabled; private final boolean partitionVersioningEnabled; @@ -106,47 +119,43 @@ public InMemoryCachingHiveMetastore( @ForCachingHiveMetastore ExtendedHiveMetastore delegate, @ForCachingHiveMetastore ExecutorService executor, MetastoreCacheStats metastoreCacheStats, - MetastoreClientConfig metastoreClientConfig) + MetastoreClientConfig metastoreClientConfig, + MetastoreCacheSpecProvider metastoreCacheSpecProvider) { this( delegate, executor, metastoreClientConfig.isMetastoreImpersonationEnabled(), - metastoreClientConfig.getMetastoreCacheTtl(), - metastoreClientConfig.getMetastoreRefreshInterval(), metastoreClientConfig.getMetastoreCacheMaximumSize(), metastoreClientConfig.isPartitionVersioningEnabled(), - metastoreClientConfig.getMetastoreCacheScope(), metastoreClientConfig.getPartitionCacheValidationPercentage(), metastoreClientConfig.getPartitionCacheColumnCountLimit(), - metastoreCacheStats); + metastoreCacheStats, + metastoreCacheSpecProvider); } public InMemoryCachingHiveMetastore( ExtendedHiveMetastore delegate, ExecutorService executor, boolean metastoreImpersonationEnabled, - Duration cacheTtl, - Duration refreshInterval, long maximumSize, boolean 
partitionVersioningEnabled, - MetastoreCacheScope metastoreCacheScope, double partitionCacheValidationPercentage, int partitionCacheColumnCountLimit, - MetastoreCacheStats metastoreCacheStats) + MetastoreCacheStats metastoreCacheStats, + MetastoreCacheSpecProvider metastoreCacheSpecProvider) { this( delegate, executor, metastoreImpersonationEnabled, - OptionalLong.of(cacheTtl.toMillis()), - refreshInterval.toMillis() >= cacheTtl.toMillis() ? OptionalLong.empty() : OptionalLong.of(refreshInterval.toMillis()), maximumSize, partitionVersioningEnabled, - metastoreCacheScope, partitionCacheValidationPercentage, partitionCacheColumnCountLimit, - metastoreCacheStats); + metastoreCacheStats, + Optional.of(metastoreCacheSpecProvider), + false); } public static InMemoryCachingHiveMetastore memoizeMetastore(ExtendedHiveMetastore delegate, boolean isMetastoreImpersonationEnabled, long maximumSize, int partitionCacheMaxColumnCount) @@ -155,28 +164,26 @@ public static InMemoryCachingHiveMetastore memoizeMetastore(ExtendedHiveMetastor delegate, newDirectExecutorService(), isMetastoreImpersonationEnabled, - OptionalLong.empty(), - OptionalLong.empty(), maximumSize, false, - ALL, 0.0, partitionCacheMaxColumnCount, - NOOP_METASTORE_CACHE_STATS); + NOOP_METASTORE_CACHE_STATS, + Optional.empty(), + true); } private InMemoryCachingHiveMetastore( ExtendedHiveMetastore delegate, ExecutorService executor, boolean metastoreImpersonationEnabled, - OptionalLong expiresAfterWriteMillis, - OptionalLong refreshMills, long maximumSize, boolean partitionVersioningEnabled, - MetastoreCacheScope metastoreCacheScope, double partitionCacheValidationPercentage, int partitionCacheColumnCountLimit, - MetastoreCacheStats metastoreCacheStats) + MetastoreCacheStats metastoreCacheStats, + Optional metastoreCacheSpecProvider, + boolean perTransactionCache) { this.delegate = requireNonNull(delegate, "delegate is null"); requireNonNull(executor, "executor is null"); @@ -185,59 +192,40 @@ private 
InMemoryCachingHiveMetastore( this.partitionCacheValidationPercentage = partitionCacheValidationPercentage; this.partitionCacheColumnCountLimit = partitionCacheColumnCountLimit; this.metastoreCacheStats = metastoreCacheStats; + this.metastoreCacheSpecProvider = metastoreCacheSpecProvider.orElse(null); - OptionalLong cacheExpiresAfterWriteMillis; - OptionalLong cacheRefreshMills; - long cacheMaxSize; - - OptionalLong partitionCacheExpiresAfterWriteMillis; - OptionalLong partitionCacheRefreshMills; - long partitionCacheMaxSize; - - switch (metastoreCacheScope) { - case PARTITION: - partitionCacheExpiresAfterWriteMillis = expiresAfterWriteMillis; - partitionCacheRefreshMills = refreshMills; - partitionCacheMaxSize = maximumSize; - cacheExpiresAfterWriteMillis = OptionalLong.of(0); - cacheRefreshMills = OptionalLong.of(0); - cacheMaxSize = 0; - break; - - case ALL: - partitionCacheExpiresAfterWriteMillis = expiresAfterWriteMillis; - partitionCacheRefreshMills = refreshMills; - partitionCacheMaxSize = maximumSize; - cacheExpiresAfterWriteMillis = expiresAfterWriteMillis; - cacheRefreshMills = refreshMills; - cacheMaxSize = maximumSize; - break; - - default: - throw new IllegalArgumentException("Unknown metastore-cache-scope: " + metastoreCacheScope); - } - - databaseNamesCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadAllDatabases), executor)); + databaseNamesCache = buildCache( + executor, + DATABASE_NAMES, + CacheLoader.from(this::loadAllDatabases), + perTransactionCache, + maximumSize); - databaseCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadDatabase), executor)); + databaseCache = buildCache( + executor, + DATABASE, + CacheLoader.from(this::loadDatabase), + perTransactionCache, + maximumSize); - tableNamesCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, 
cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadAllTables), executor)); + tableNamesCache = buildCache( + executor, + TABLE_NAMES, + CacheLoader.from(this::loadAllTables), + perTransactionCache, + maximumSize); - tableStatisticsCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(new CacheLoader, PartitionStatistics>() - { - @Override - public PartitionStatistics load(KeyAndContext key) - { - return loadTableColumnStatistics(key); - } - }, executor)); + tableStatisticsCache = buildCache( + executor, + TABLE_STATISTICS, + CacheLoader.from(this::loadTableColumnStatistics), + perTransactionCache, + maximumSize); - partitionStatisticsCache = newCacheBuilder(partitionCacheExpiresAfterWriteMillis, partitionCacheRefreshMills, partitionCacheMaxSize) - .build(asyncReloading(new CacheLoader, PartitionStatistics>() + partitionStatisticsCache = buildCache( + executor, + PARTITION_STATISTICS, + new CacheLoader, PartitionStatistics>() { @Override public PartitionStatistics load(KeyAndContext key) @@ -250,27 +238,51 @@ public Map, PartitionStatistics> loadAll(Iterab { return loadPartitionColumnStatistics(keys); } - }, executor)); + }, + perTransactionCache, + maximumSize); - tableCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadTable), executor)); + tableCache = buildCache( + executor, + TABLE, + CacheLoader.from(this::loadTable), + perTransactionCache, + maximumSize); metastoreCacheStats.setTableCache(tableCache); - tableConstraintsCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadTableConstraints), executor)); + tableConstraintsCache = buildCache( + executor, + TABLE_CONSTRAINTS, + CacheLoader.from(this::loadTableConstraints), + perTransactionCache, + maximumSize); - viewNamesCache = newCacheBuilder(cacheExpiresAfterWriteMillis, 
cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadAllViews), executor)); + viewNamesCache = buildCache( + executor, + VIEW_NAMES, + CacheLoader.from(this::loadAllViews), + perTransactionCache, + maximumSize); - partitionNamesCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadPartitionNames), executor)); + partitionNamesCache = buildCache( + executor, + PARTITION_NAMES, + CacheLoader.from(this::loadPartitionNames), + perTransactionCache, + maximumSize); - partitionFilterCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadPartitionNamesByFilter), executor)); + partitionFilterCache = buildCache( + executor, + PARTITION_FILTER, + CacheLoader.from(this::loadPartitionNamesByFilter), + perTransactionCache, + maximumSize); metastoreCacheStats.setPartitionNamesCache(partitionFilterCache); - partitionCache = newCacheBuilder(partitionCacheExpiresAfterWriteMillis, partitionCacheRefreshMills, partitionCacheMaxSize) - .build(asyncReloading(new CacheLoader, Optional>() + partitionCache = buildCache( + executor, + PARTITION, + new CacheLoader, Optional>() { @Override public Optional load(KeyAndContext partitionName) @@ -283,17 +295,31 @@ public Map, Optional> loadAll(Iterab { return loadPartitionsByNames(partitionNames); } - }, executor)); + }, + perTransactionCache, + maximumSize); metastoreCacheStats.setPartitionCache(partitionCache); - tablePrivilegesCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadTablePrivileges), executor)); + tablePrivilegesCache = buildCache( + executor, + TABLE_PRIVILEGES, + CacheLoader.from(this::loadTablePrivileges), + perTransactionCache, + maximumSize); - rolesCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - 
.build(asyncReloading(CacheLoader.from(this::loadAllRoles), executor)); + rolesCache = buildCache( + executor, + ROLES, + CacheLoader.from(this::loadAllRoles), + perTransactionCache, + maximumSize); - roleGrantsCache = newCacheBuilder(cacheExpiresAfterWriteMillis, cacheRefreshMills, cacheMaxSize) - .build(asyncReloading(CacheLoader.from(this::loadRoleGrants), executor)); + roleGrantsCache = buildCache( + executor, + ROLE_GRANTS, + CacheLoader.from(this::loadRoleGrants), + perTransactionCache, + maximumSize); } @Override @@ -451,6 +477,18 @@ private Map, PartitionStatistics> loadPartition return result.build(); } + @Override + public MetastoreOperationResult persistTable(MetastoreContext metastoreContext, String databaseName, String tableName, Table newTable, PrincipalPrivileges principalPrivileges, Supplier update, Map additionalParameters) + { + try { + return getDelegate().persistTable(metastoreContext, databaseName, tableName, newTable, principalPrivileges, update, additionalParameters); + } + finally { + invalidateTableCache(databaseName, tableName); + invalidateTableCache(newTable.getDatabaseName(), newTable.getTableName()); + } + } + @Override public void updateTableStatistics(MetastoreContext metastoreContext, String databaseName, String tableName, Function update) { @@ -1076,4 +1114,30 @@ private static CacheBuilder newCacheBuilder(OptionalLong expires } return cacheBuilder.maximumSize(maximumSize).recordStats(); } + + private LoadingCache buildCache( + ExecutorService executor, + MetastoreCacheType cacheType, + CacheLoader loader, + boolean isPerTransactionCache, + long maximumSize) + { + if (isPerTransactionCache) { + return newCacheBuilder( + OptionalLong.empty(), + OptionalLong.empty(), + maximumSize) + .build(asyncReloading(loader, executor)); + } + + MetastoreCacheSpec spec = metastoreCacheSpecProvider.getMetastoreCacheSpec(cacheType); + long cacheTtlMillis = spec.getCacheTtlMillis(); + long refreshMillis = spec.getRefreshIntervalMillis(); + + return 
newCacheBuilder( + OptionalLong.of(cacheTtlMillis), + refreshMillis >= cacheTtlMillis ? OptionalLong.empty() : OptionalLong.of(refreshMillis), + spec.getMaximumSize()) + .build(asyncReloading(loader, executor)); + } } diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/IntegerStatistics.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/IntegerStatistics.java index 75b24ad200cae..7de33e1470bc5 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/IntegerStatistics.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/IntegerStatistics.java @@ -15,8 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; import java.util.OptionalLong; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InvalidateMetastoreCacheProcedure.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InvalidateMetastoreCacheProcedure.java index 29ee7a89a47db..e14ca6eb37c6a 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InvalidateMetastoreCacheProcedure.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/InvalidateMetastoreCacheProcedure.java @@ -20,8 +20,8 @@ import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreCacheSpec.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreCacheSpec.java new file mode 100644 index 
0000000000000..c690d76109e0d --- /dev/null +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreCacheSpec.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive.metastore; + +public class MetastoreCacheSpec +{ + private static final MetastoreCacheSpec DISABLED = new MetastoreCacheSpec(0, 0, 0); + private final long cacheTtlMillis; + private final long refreshIntervalMillis; + private final long maximumSize; + + public static MetastoreCacheSpec disabled() + { + return DISABLED; + } + + public static MetastoreCacheSpec enabled(long cacheTtlMillis, long refreshIntervalMillis, long maximumSize) + { + return new MetastoreCacheSpec(cacheTtlMillis, refreshIntervalMillis, maximumSize); + } + + private MetastoreCacheSpec(long cacheTtlMillis, long refreshIntervalMillis, long maximumSize) + { + this.cacheTtlMillis = cacheTtlMillis; + this.refreshIntervalMillis = refreshIntervalMillis; + this.maximumSize = maximumSize; + } + + public long getCacheTtlMillis() + { + return cacheTtlMillis; + } + + public long getRefreshIntervalMillis() + { + return refreshIntervalMillis; + } + + public long getMaximumSize() + { + return maximumSize; + } +} diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreCacheSpecProvider.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreCacheSpecProvider.java new file mode 100644 index 
0000000000000..1b645c7870917 --- /dev/null +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreCacheSpecProvider.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive.metastore; + +import com.facebook.presto.hive.MetastoreClientConfig; +import com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType; +import jakarta.inject.Inject; + +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.ALL; +import static java.util.Objects.requireNonNull; + +public class MetastoreCacheSpecProvider +{ + private final MetastoreClientConfig clientConfig; + + @Inject + public MetastoreCacheSpecProvider(MetastoreClientConfig clientConfig) + { + this.clientConfig = requireNonNull(clientConfig, "clientConfig is null"); + } + + public MetastoreCacheSpec getMetastoreCacheSpec(MetastoreCacheType type) + { + boolean enabled = isEnabled(type); + if (!enabled) { + return MetastoreCacheSpec.disabled(); + } + + long cacheTtlMillis = clientConfig.getMetastoreCacheTtlByType().getOrDefault( + type, clientConfig.getDefaultMetastoreCacheTtl()).toMillis(); + long refreshIntervalMillis = clientConfig.getMetastoreCacheRefreshIntervalByType().getOrDefault( + type, clientConfig.getDefaultMetastoreCacheRefreshInterval()).toMillis(); + + return MetastoreCacheSpec.enabled( + cacheTtlMillis, + refreshIntervalMillis, + 
clientConfig.getMetastoreCacheMaximumSize()); + } + + private boolean isEnabled(MetastoreCacheType type) + { + if (!clientConfig.getEnabledCaches().isEmpty()) { + return clientConfig.getEnabledCaches().contains(type) || clientConfig.getEnabledCaches().contains(ALL); + } + if (!clientConfig.getDisabledCaches().isEmpty()) { + return !(clientConfig.getDisabledCaches().contains(type) || clientConfig.getDisabledCaches().contains(ALL)); + } + + return isEnabledByLegacyMetastoreScope(type); + } + + private boolean isEnabledByLegacyMetastoreScope(MetastoreCacheType type) + { + switch (clientConfig.getMetastoreCacheScope()) { + case ALL: + return true; + case PARTITION: + return type == MetastoreCacheType.PARTITION || type == MetastoreCacheType.PARTITION_STATISTICS; + default: + return false; + } + } +} diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreConfig.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreConfig.java index 058b566682272..f4776d69c77dd 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreConfig.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreConfig.java @@ -14,8 +14,7 @@ package com.facebook.presto.hive.metastore; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class MetastoreConfig { diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java index 53ec58fc3547a..863522cae2b7a 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/MetastoreUtil.java @@ -69,6 +69,7 @@ import com.google.common.collect.ImmutableSet; import 
com.google.common.primitives.Longs; import io.airlift.slice.Slice; +import jakarta.annotation.Nullable; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.permission.FsPermission; @@ -80,8 +81,6 @@ import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.ISODateTimeFormat; -import javax.annotation.Nullable; - import java.io.IOException; import java.math.BigInteger; import java.sql.Date; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Partition.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Partition.java index 40163d4bc302e..f2a805da02935 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Partition.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Partition.java @@ -19,8 +19,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Arrays; import java.util.List; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/PartitionFilter.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/PartitionFilter.java index 94436f62fe65b..e3226b8bca417 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/PartitionFilter.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/PartitionFilter.java @@ -14,8 +14,7 @@ package com.facebook.presto.hive.metastore; import com.facebook.presto.common.predicate.Domain; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Map; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/PartitionStatistics.java 
b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/PartitionStatistics.java index 69ad53e90a74b..b8fde0d6b640d 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/PartitionStatistics.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/PartitionStatistics.java @@ -18,8 +18,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableMap; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Map; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/RecordingHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/RecordingHiveMetastore.java index 993b7da5ef3bf..4c70b9005690c 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/RecordingHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/RecordingHiveMetastore.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive.metastore; import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.type.Type; import com.facebook.presto.hive.ForRecordingHiveMetastore; @@ -33,12 +34,10 @@ import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.Immutable; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; -import javax.annotation.concurrent.Immutable; -import javax.inject.Inject; - import java.io.File; import java.io.IOException; import java.util.List; @@ -316,6 +315,13 @@ public MetastoreOperationResult replaceTable(MetastoreContext 
metastoreContext, return delegate.replaceTable(metastoreContext, databaseName, tableName, newTable, principalPrivileges); } + @Override + public MetastoreOperationResult persistTable(MetastoreContext metastoreContext, String databaseName, String tableName, Table newTable, PrincipalPrivileges principalPrivileges, Supplier update, Map additionalParameters) + { + verifyRecordingMode(); + return delegate.persistTable(metastoreContext, databaseName, tableName, newTable, principalPrivileges, update, additionalParameters); + } + @Override public MetastoreOperationResult renameTable(MetastoreContext metastoreContext, String databaseName, String tableName, String newDatabaseName, String newTableName) { diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java index d04d92e2210e9..ea9c61603379f 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/SemiTransactionalHiveMetastore.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive.metastore; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.type.Type; @@ -47,13 +48,11 @@ import com.google.common.collect.Lists; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import javax.annotation.concurrent.GuardedBy; - import java.io.FileNotFoundException; import 
java.io.IOException; import java.util.ArrayList; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/SortingColumn.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/SortingColumn.java index 717653db1a3ce..cbb62332c8ddc 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/SortingColumn.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/SortingColumn.java @@ -17,8 +17,7 @@ import com.facebook.presto.spi.PrestoException; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Storage.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Storage.java index bd5262bc36b90..693bc53bfa5fc 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Storage.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Storage.java @@ -17,8 +17,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableMap; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Map; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/StorageFormat.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/StorageFormat.java index 024f643280367..d05c8e2f73379 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/StorageFormat.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/StorageFormat.java @@ -17,8 +17,7 @@ import 
com.facebook.presto.spi.PrestoException; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Table.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Table.java index 9f8ee10dba1dc..108dc0a0cc01f 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Table.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/Table.java @@ -19,8 +19,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.ArrayList; import java.util.LinkedHashMap; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/UserDatabaseKey.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/UserDatabaseKey.java index c4835e81a1240..ca745dd5855fe 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/UserDatabaseKey.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/UserDatabaseKey.java @@ -15,8 +15,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/UserTableKey.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/UserTableKey.java index 2ddaa5c1cadf5..b4863926411ed 100644 --- 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/UserTableKey.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/UserTableKey.java @@ -16,8 +16,7 @@ import com.facebook.presto.spi.security.PrestoPrincipal; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java index 2b265384ebf7b..ee1879b73e698 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastore.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive.metastore.file; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.type.Type; import com.facebook.presto.hive.HdfsContext; @@ -59,15 +60,13 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.ByteStreams; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.io.IOException; import java.io.OutputStream; import java.util.ArrayDeque; @@ -87,6 +86,7 @@ import java.util.Optional; import java.util.Set; import java.util.function.Function; +import java.util.function.Supplier; 
import static com.facebook.presto.hive.HiveErrorCode.HIVE_METASTORE_ERROR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_PARTITION_DROPPED_DURING_QUERY; @@ -135,6 +135,7 @@ public class FileHiveMetastore protected final HdfsEnvironment hdfsEnvironment; protected final HdfsContext hdfsContext; + protected final FileSystem metadataFileSystem; private final Path catalogDirectory; @@ -478,6 +479,44 @@ public synchronized MetastoreOperationResult replaceTable(MetastoreContext metas return EMPTY_RESULT; } + @Override + public synchronized MetastoreOperationResult persistTable( + MetastoreContext metastoreContext, + String databaseName, + String tableName, + Table newTable, + PrincipalPrivileges principalPrivileges, + Supplier update, + Map additionalParameters) + { + checkArgument(!newTable.getTableType().equals(TEMPORARY_TABLE), "temporary tables must never be stored in the metastore"); + + Table oldTable = getRequiredTable(metastoreContext, databaseName, tableName); + validateReplaceTableType(oldTable, newTable); + if (!oldTable.getDatabaseName().equals(databaseName) || !oldTable.getTableName().equals(tableName)) { + throw new PrestoException(HIVE_METASTORE_ERROR, "Replacement table must have same name"); + } + + Path tableMetadataDirectory = getTableMetadataDirectory(oldTable); + + deleteTablePrivileges(oldTable); + for (Entry> entry : principalPrivileges.getUserPrivileges().asMap().entrySet()) { + setTablePrivileges(metastoreContext, new PrestoPrincipal(USER, entry.getKey()), databaseName, tableName, entry.getValue()); + } + for (Entry> entry : principalPrivileges.getRolePrivileges().asMap().entrySet()) { + setTablePrivileges(metastoreContext, new PrestoPrincipal(ROLE, entry.getKey()), databaseName, tableName, entry.getValue()); + } + PartitionStatistics updatedStatistics = update.get(); + + TableMetadata updatedMetadata = new TableMetadata(newTable) + .withParameters(updateStatisticsParameters(newTable.getParameters(), 
updatedStatistics.getBasicStatistics())) + .withColumnStatistics(updatedStatistics.getColumnStatistics()); + + writeSchemaFile("table", tableMetadataDirectory, tableCodec, updatedMetadata, true); + + return EMPTY_RESULT; + } + @Override public synchronized MetastoreOperationResult renameTable(MetastoreContext metastoreContext, String databaseName, String tableName, String newDatabaseName, String newTableName) { @@ -1293,7 +1332,6 @@ private List getChildSchemaDirectories(Path metadataDirectory) if (!metadataFileSystem.isDirectory(metadataDirectory)) { return ImmutableList.of(); } - ImmutableList.Builder childSchemaDirectories = ImmutableList.builder(); for (FileStatus child : metadataFileSystem.listStatus(metadataDirectory)) { if (!child.isDirectory()) { diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastoreConfig.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastoreConfig.java index 31015aadac281..ee1e628a43a1b 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastoreConfig.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileHiveMetastoreConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class FileHiveMetastoreConfig { diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileMetastoreModule.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileMetastoreModule.java index c37a22af47b7f..5ae8fd5cb7fb3 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileMetastoreModule.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/file/FileMetastoreModule.java @@ -18,6 
+18,7 @@ import com.facebook.presto.hive.HiveCommonClientConfig; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore; +import com.facebook.presto.hive.metastore.MetastoreCacheSpecProvider; import com.google.inject.Binder; import com.google.inject.Scopes; @@ -42,6 +43,7 @@ public void setup(Binder binder) { checkArgument(buildConfigObject(HiveCommonClientConfig.class).getCatalogName() == null, "'hive.metastore.catalog.name' should not be set for file metastore"); configBinder(binder).bindConfig(FileHiveMetastoreConfig.class); + binder.bind(MetastoreCacheSpecProvider.class).in(Scopes.SINGLETON); binder.bind(ExtendedHiveMetastore.class).annotatedWith(ForCachingHiveMetastore.class).to(FileHiveMetastore.class).in(Scopes.SINGLETON); binder.bind(ExtendedHiveMetastore.class).to(InMemoryCachingHiveMetastore.class).in(Scopes.SINGLETON); newExporter(binder).export(ExtendedHiveMetastore.class) diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/ForGlueHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/ForGlueHiveMetastore.java index a0c0635db36a9..d56e29bcf4b8d 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/ForGlueHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/ForGlueHiveMetastore.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive.metastore.glue; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueCatalogApiStats.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueCatalogApiStats.java index b1709c68465c1..0e276a8234720 100644 --- 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueCatalogApiStats.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueCatalogApiStats.java @@ -13,18 +13,16 @@ */ package com.facebook.presto.hive.metastore.glue; -import com.amazonaws.AmazonWebServiceRequest; -import com.amazonaws.handlers.AsyncHandler; import com.facebook.airlift.stats.CounterStat; import com.facebook.airlift.stats.TimeStat; +import com.google.errorprone.annotations.ThreadSafe; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.concurrent.ThreadSafe; - import java.util.function.Supplier; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; @ThreadSafe public class GlueCatalogApiStats @@ -54,23 +52,12 @@ public void record(Runnable action) } } - public AsyncHandler metricsAsyncHandler() + public void recordAsync(long executionTimeNanos, boolean failed) { - return new AsyncHandler() { - private final TimeStat.BlockTimer timer = time.time(); - @Override - public void onError(Exception exception) - { - timer.close(); - recordException(exception); - } - - @Override - public void onSuccess(R request, T result) - { - timer.close(); - } - }; + time.add(executionTimeNanos, NANOSECONDS); + if (failed) { + totalFailures.update(1); + } } @Managed diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueHiveMetastore.java index c9c2f502c3592..da980eef676f5 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueHiveMetastore.java @@ -13,51 +13,7 @@ */ package com.facebook.presto.hive.metastore.glue; -import com.amazonaws.AmazonServiceException; -import 
com.amazonaws.ClientConfiguration; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; -import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; -import com.amazonaws.metrics.RequestMetricCollector; -import com.amazonaws.regions.Region; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.glue.AWSGlueAsync; -import com.amazonaws.services.glue.AWSGlueAsyncClientBuilder; -import com.amazonaws.services.glue.model.AlreadyExistsException; -import com.amazonaws.services.glue.model.BatchCreatePartitionRequest; -import com.amazonaws.services.glue.model.BatchCreatePartitionResult; -import com.amazonaws.services.glue.model.BatchGetPartitionRequest; -import com.amazonaws.services.glue.model.BatchGetPartitionResult; -import com.amazonaws.services.glue.model.CreateDatabaseRequest; -import com.amazonaws.services.glue.model.CreateTableRequest; -import com.amazonaws.services.glue.model.DatabaseInput; -import com.amazonaws.services.glue.model.DeleteDatabaseRequest; -import com.amazonaws.services.glue.model.DeletePartitionRequest; -import com.amazonaws.services.glue.model.DeleteTableRequest; -import com.amazonaws.services.glue.model.EntityNotFoundException; -import com.amazonaws.services.glue.model.ErrorDetail; -import com.amazonaws.services.glue.model.GetDatabaseRequest; -import com.amazonaws.services.glue.model.GetDatabaseResult; -import com.amazonaws.services.glue.model.GetDatabasesRequest; -import com.amazonaws.services.glue.model.GetDatabasesResult; -import com.amazonaws.services.glue.model.GetPartitionRequest; -import com.amazonaws.services.glue.model.GetPartitionResult; -import com.amazonaws.services.glue.model.GetPartitionsRequest; -import com.amazonaws.services.glue.model.GetPartitionsResult; -import com.amazonaws.services.glue.model.GetTableRequest; -import 
com.amazonaws.services.glue.model.GetTableResult; -import com.amazonaws.services.glue.model.GetTablesRequest; -import com.amazonaws.services.glue.model.GetTablesResult; -import com.amazonaws.services.glue.model.PartitionError; -import com.amazonaws.services.glue.model.PartitionInput; -import com.amazonaws.services.glue.model.PartitionValueList; -import com.amazonaws.services.glue.model.Segment; -import com.amazonaws.services.glue.model.TableInput; -import com.amazonaws.services.glue.model.UpdateDatabaseRequest; -import com.amazonaws.services.glue.model.UpdatePartitionRequest; -import com.amazonaws.services.glue.model.UpdateTableRequest; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.type.Type; import com.facebook.presto.hive.HdfsContext; @@ -92,18 +48,69 @@ import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.RoleGrant; import com.facebook.presto.spi.statistics.ColumnStatisticType; +import com.google.common.base.Stopwatch; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; -import io.airlift.units.Duration; +import jakarta.annotation.Nullable; +import jakarta.inject.Inject; import org.apache.hadoop.fs.Path; import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; - -import javax.annotation.Nullable; -import javax.inject.Inject; - +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.awscore.retry.AwsRetryStrategy; +import software.amazon.awssdk.core.async.SdkPublisher; 
+import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.metrics.MetricPublisher; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.retries.StandardRetryStrategy; +import software.amazon.awssdk.services.glue.GlueAsyncClient; +import software.amazon.awssdk.services.glue.GlueAsyncClientBuilder; +import software.amazon.awssdk.services.glue.model.AlreadyExistsException; +import software.amazon.awssdk.services.glue.model.BatchCreatePartitionRequest; +import software.amazon.awssdk.services.glue.model.BatchCreatePartitionResponse; +import software.amazon.awssdk.services.glue.model.BatchGetPartitionRequest; +import software.amazon.awssdk.services.glue.model.BatchGetPartitionResponse; +import software.amazon.awssdk.services.glue.model.CreateDatabaseRequest; +import software.amazon.awssdk.services.glue.model.CreateTableRequest; +import software.amazon.awssdk.services.glue.model.DatabaseInput; +import software.amazon.awssdk.services.glue.model.DeleteDatabaseRequest; +import software.amazon.awssdk.services.glue.model.DeletePartitionRequest; +import software.amazon.awssdk.services.glue.model.DeleteTableRequest; +import software.amazon.awssdk.services.glue.model.EntityNotFoundException; +import software.amazon.awssdk.services.glue.model.ErrorDetail; +import software.amazon.awssdk.services.glue.model.GetDatabaseRequest; +import software.amazon.awssdk.services.glue.model.GetDatabaseResponse; +import software.amazon.awssdk.services.glue.model.GetDatabasesRequest; +import software.amazon.awssdk.services.glue.model.GetPartitionRequest; +import software.amazon.awssdk.services.glue.model.GetPartitionResponse; +import software.amazon.awssdk.services.glue.model.GetPartitionsRequest; +import software.amazon.awssdk.services.glue.model.GetTableRequest; +import software.amazon.awssdk.services.glue.model.GetTableResponse; +import 
software.amazon.awssdk.services.glue.model.GetTablesRequest; +import software.amazon.awssdk.services.glue.model.GlueException; +import software.amazon.awssdk.services.glue.model.GlueResponse; +import software.amazon.awssdk.services.glue.model.PartitionError; +import software.amazon.awssdk.services.glue.model.PartitionInput; +import software.amazon.awssdk.services.glue.model.PartitionValueList; +import software.amazon.awssdk.services.glue.model.Segment; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.TableInput; +import software.amazon.awssdk.services.glue.model.UpdateDatabaseRequest; +import software.amazon.awssdk.services.glue.model.UpdatePartitionRequest; +import software.amazon.awssdk.services.glue.model.UpdateTableRequest; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.StsClientBuilder; +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; + +import java.net.URI; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -111,12 +118,16 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionService; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.Future; +import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import static com.facebook.presto.hive.HiveErrorCode.HIVE_METASTORE_ERROR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_PARTITION_DROPPED_DURING_QUERY; @@ -143,6 +154,7 @@ import static com.google.common.collect.Comparators.lexicographical; import static 
java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.NANOSECONDS; import static java.util.function.UnaryOperator.identity; import static java.util.stream.Collectors.toMap; @@ -177,7 +189,7 @@ public class GlueHiveMetastore private final GlueMetastoreStats stats = new GlueMetastoreStats(); private final HdfsEnvironment hdfsEnvironment; private final HdfsContext hdfsContext; - private final AWSGlueAsync glueClient; + private final GlueAsyncClient glueClient; private final Optional defaultDir; private final String catalogId; private final int partitionSegments; @@ -191,51 +203,71 @@ public GlueHiveMetastore( { this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.hdfsContext = new HdfsContext(new ConnectorIdentity(DEFAULT_METASTORE_USER, Optional.empty(), Optional.empty())); - this.glueClient = createAsyncGlueClient(requireNonNull(glueConfig, "glueConfig is null"), stats.newRequestMetricsCollector()); + this.glueClient = createAsyncGlueClient(requireNonNull(glueConfig, "glueConfig is null"), stats.newRequestMetricPublisher()); this.defaultDir = glueConfig.getDefaultWarehouseDir(); this.catalogId = glueConfig.getCatalogId().orElse(null); this.partitionSegments = glueConfig.getPartitionSegments(); this.executor = requireNonNull(executor, "executor is null"); } - private static AWSGlueAsync createAsyncGlueClient(GlueHiveMetastoreConfig config, RequestMetricCollector metricsCollector) + private static GlueAsyncClient createAsyncGlueClient(GlueHiveMetastoreConfig config, MetricPublisher metricPublisher) { - ClientConfiguration clientConfig = new ClientConfiguration() - .withMaxConnections(config.getMaxGlueConnections()) - .withMaxErrorRetry(config.getMaxGlueErrorRetries()); - AWSGlueAsyncClientBuilder asyncGlueClientBuilder = AWSGlueAsyncClientBuilder.standard() - .withMetricsCollector(metricsCollector) - .withClientConfiguration(clientConfig); + 
NettyNioAsyncHttpClient.Builder nettyBuilder = NettyNioAsyncHttpClient.builder() + .maxConcurrency(config.getMaxGlueConnections()); + + StandardRetryStrategy strategy = AwsRetryStrategy.standardRetryStrategy() + .toBuilder() + .maxAttempts(config.getMaxGlueErrorRetries()) + .build(); + + ClientOverrideConfiguration.Builder overrideConfigBuilder = ClientOverrideConfiguration.builder() + .retryStrategy(strategy) + .addMetricPublisher(metricPublisher); + + GlueAsyncClientBuilder glueAsyncClientBuilder = GlueAsyncClient.builder() + .httpClientBuilder(nettyBuilder) + .overrideConfiguration(overrideConfigBuilder.build()); if (config.getGlueEndpointUrl().isPresent()) { checkArgument(config.getGlueRegion().isPresent(), "Glue region must be set when Glue endpoint URL is set"); - asyncGlueClientBuilder.setEndpointConfiguration(new EndpointConfiguration( - config.getGlueEndpointUrl().get(), - config.getGlueRegion().get())); + glueAsyncClientBuilder + .endpointOverride(URI.create(config.getGlueEndpointUrl().get())) + .region(Region.of(config.getGlueRegion().get())); } else if (config.getGlueRegion().isPresent()) { - asyncGlueClientBuilder.setRegion(config.getGlueRegion().get()); - } - else if (config.getPinGlueClientToCurrentRegion()) { - Region currentRegion = Regions.getCurrentRegion(); - if (currentRegion != null) { - asyncGlueClientBuilder.setRegion(currentRegion.getName()); - } + glueAsyncClientBuilder.region(Region.of(config.getGlueRegion().get())); } + AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create(); if (config.getAwsAccessKey().isPresent() && config.getAwsSecretKey().isPresent()) { - AWSCredentialsProvider credentialsProvider = new AWSStaticCredentialsProvider( - new BasicAWSCredentials(config.getAwsAccessKey().get(), config.getAwsSecretKey().get())); - asyncGlueClientBuilder.setCredentials(credentialsProvider); + credentialsProvider = StaticCredentialsProvider.create( + AwsBasicCredentials.create(config.getAwsAccessKey().get(), 
config.getAwsSecretKey().get())); } else if (config.getIamRole().isPresent()) { - AWSCredentialsProvider credentialsProvider = new STSAssumeRoleSessionCredentialsProvider - .Builder(config.getIamRole().get(), "roleSessionName") + StsClientBuilder stsClientBuilder = StsClient.builder() + .credentialsProvider(DefaultCredentialsProvider.create()); + + if (config.getGlueStsEndpointUrl().isPresent()) { + checkArgument(config.getGlueStsRegion().isPresent(), "Glue STS region must be set when Glue STS endpoint URL is set"); + stsClientBuilder + .endpointOverride(URI.create(config.getGlueStsEndpointUrl().get())) + .region(Region.of(config.getGlueStsRegion().get())); + } + else if (config.getGlueStsRegion().isPresent()) { + stsClientBuilder.region(Region.of(config.getGlueStsRegion().get())); + } + + credentialsProvider = StsAssumeRoleCredentialsProvider.builder() + .refreshRequest(() -> AssumeRoleRequest.builder() + .roleArn(config.getIamRole().get()) + .roleSessionName("presto-session").build()) + .stsClient(stsClientBuilder.build()) .build(); - asyncGlueClientBuilder.setCredentials(credentialsProvider); } - return asyncGlueClientBuilder.build(); + glueAsyncClientBuilder.credentialsProvider(credentialsProvider); + + return glueAsyncClientBuilder.build(); } @Managed @@ -256,36 +288,35 @@ public int getPartitionCommitBatchSize() @Override public Optional getDatabase(MetastoreContext metastoreContext, String databaseName) { - return stats.getGetDatabase().record(() -> { - try { - GetDatabaseResult result = glueClient.getDatabase(new GetDatabaseRequest().withCatalogId(catalogId).withName(databaseName)); - return Optional.of(GlueToPrestoConverter.convertDatabase(result.getDatabase())); - } - catch (EntityNotFoundException e) { - return Optional.empty(); - } - catch (AmazonServiceException e) { - throw new PrestoException(HIVE_METASTORE_ERROR, e); - } - }); + try { + GetDatabaseResponse response = awsSyncRequest(glueClient::getDatabase, + 
GetDatabaseRequest.builder().catalogId(catalogId).name(databaseName).build(), + stats.getGetDatabase()); + + return Optional.of(GlueToPrestoConverter.convertDatabase(response.database())); + } + catch (EntityNotFoundException e) { + return Optional.empty(); + } + catch (AwsServiceException e) { + throw new PrestoException(HIVE_METASTORE_ERROR, e); + } } @Override public List getAllDatabases(MetastoreContext metastoreContext) { try { - List databaseNames = new ArrayList<>(); - GetDatabasesRequest request = new GetDatabasesRequest().withCatalogId(catalogId); - do { - GetDatabasesResult result = stats.getGetDatabases().record(() -> glueClient.getDatabases(request)); - request.setNextToken(result.getNextToken()); - result.getDatabaseList().forEach(database -> databaseNames.add(database.getName())); - } - while (request.getNextToken() != null); + ImmutableList.Builder databaseNames = ImmutableList.builder(); + + awsSyncPaginatedRequest( + glueClient.getDatabasesPaginator(GetDatabasesRequest.builder().catalogId(catalogId).build()), + getDatabasesResponse -> getDatabasesResponse.databaseList().forEach(database -> databaseNames.add(database.name())), + stats.getGetDatabases()); - return databaseNames; + return databaseNames.build(); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -296,29 +327,28 @@ public Optional
getTable(MetastoreContext metastoreContext, String databa return getGlueTable(databaseName, tableName).map(table -> GlueToPrestoConverter.convertTable(table, databaseName)); } - private com.amazonaws.services.glue.model.Table getGlueTableOrElseThrow(String databaseName, String tableName) + private software.amazon.awssdk.services.glue.model.Table getGlueTableOrElseThrow(String databaseName, String tableName) { return getGlueTable(databaseName, tableName) .orElseThrow(() -> new TableNotFoundException(new SchemaTableName(databaseName, tableName))); } - private Optional getGlueTable(String databaseName, String tableName) + private Optional getGlueTable(String databaseName, String tableName) { - return stats.getGetTable().record(() -> { - try { - GetTableResult result = glueClient.getTable(new GetTableRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withName(tableName)); - return Optional.of(result.getTable()); - } - catch (EntityNotFoundException e) { - return Optional.empty(); - } - catch (AmazonServiceException e) { - throw new PrestoException(HIVE_METASTORE_ERROR, e); - } - }); + try { + GetTableResponse response = awsSyncRequest( + glueClient::getTable, + GetTableRequest.builder().catalogId(catalogId).databaseName(databaseName).name(tableName).build(), + stats.getGetTable()); + + return Optional.of(response.table()); + } + catch (EntityNotFoundException e) { + return Optional.empty(); + } + catch (AwsServiceException e) { + throw new PrestoException(HIVE_METASTORE_ERROR, e); + } } @Override @@ -367,17 +397,22 @@ public void updateTableStatistics(MetastoreContext metastoreContext, String data try { TableInput tableInput = GlueInputConverter.convertTable(table); - tableInput.setParameters(updateStatisticsParameters(table.getParameters(), updatedStatistics.getBasicStatistics())); - UpdateTableRequest request = new UpdateTableRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withTableInput(tableInput); - 
stats.getUpdateTable().record(() -> glueClient.updateTable(request)); + final Map statisticsParameters = + updateStatisticsParameters(table.getParameters(), updatedStatistics.getBasicStatistics()); + + awsSyncRequest( + glueClient::updateTable, + UpdateTableRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableInput(tableInput.toBuilder().parameters(statisticsParameters).build()) + .build(), + stats.getUpdateTable()); } catch (EntityNotFoundException e) { throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -399,18 +434,24 @@ public void updatePartitionStatistics(MetastoreContext metastoreContext, String .orElseThrow(() -> new PartitionNotFoundException(new SchemaTableName(databaseName, tableName), partitionValues)); try { PartitionInput partitionInput = GlueInputConverter.convertPartition(partition); - partitionInput.setParameters(updateStatisticsParameters(partition.getParameters(), updatedStatistics.getBasicStatistics())); - stats.getUpdatePartition().record(() -> glueClient.updatePartition(new UpdatePartitionRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withTableName(tableName) - .withPartitionValueList(partition.getValues()) - .withPartitionInput(partitionInput))); + final Map statisticsParameters = + updateStatisticsParameters(partition.getParameters(), updatedStatistics.getBasicStatistics()); + + awsSyncRequest( + glueClient::updatePartition, + UpdatePartitionRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableName(tableName) + .partitionValueList(partition.getValues()) + .partitionInput(partitionInput.toBuilder().parameters(statisticsParameters).build()) + .build(), + stats.getUpdatePartition()); } catch (EntityNotFoundException e) { throw new PartitionNotFoundException(new SchemaTableName(databaseName, tableName), 
partitionValues); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -419,22 +460,24 @@ public void updatePartitionStatistics(MetastoreContext metastoreContext, String public Optional> getAllTables(MetastoreContext metastoreContext, String databaseName) { try { - List tableNames = new ArrayList<>(); - GetTablesRequest request = new GetTablesRequest().withCatalogId(catalogId).withDatabaseName(databaseName); - do { - GetTablesResult result = stats.getGetTables().record(() -> glueClient.getTables(request)); - request.setNextToken(result.getNextToken()); - result.getTableList().forEach(table -> tableNames.add(table.getName())); - } - while (request.getNextToken() != null); + ImmutableList.Builder tableNames = ImmutableList.builder(); + + awsSyncPaginatedRequest( + glueClient.getTablesPaginator(GetTablesRequest.builder().catalogId(catalogId).databaseName(databaseName).build()), + getTablesResponse -> { + getTablesResponse.tableList().stream() + .map(software.amazon.awssdk.services.glue.model.Table::name) + .forEach(tableNames::add); + }, + stats.getGetTables()); - return Optional.of(tableNames); + return Optional.of(tableNames.build()); } catch (EntityNotFoundException e) { // database does not exist return Optional.empty(); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -443,25 +486,25 @@ public Optional> getAllTables(MetastoreContext metastoreContext, St public Optional> getAllViews(MetastoreContext metastoreContext, String databaseName) { try { - List views = new ArrayList<>(); - GetTablesRequest request = new GetTablesRequest().withCatalogId(catalogId).withDatabaseName(databaseName); - - do { - GetTablesResult result = stats.getGetTables().record(() -> glueClient.getTables(request)); - request.setNextToken(result.getNextToken()); - result.getTableList().stream() - .filter(table -> 
VIRTUAL_VIEW.name().equals(table.getTableType())) - .forEach(table -> views.add(table.getName())); - } - while (request.getNextToken() != null); + ImmutableList.Builder viewNames = ImmutableList.builder(); + + awsSyncPaginatedRequest( + glueClient.getTablesPaginator(GetTablesRequest.builder().catalogId(catalogId).databaseName(databaseName).build()), + getTablesResponse -> { + getTablesResponse.tableList().stream() + .filter(table -> VIRTUAL_VIEW.name().equals(table.tableType())) + .map(software.amazon.awssdk.services.glue.model.Table::name) + .forEach(viewNames::add); + }, + stats.getGetTables()); - return Optional.of(views); + return Optional.of(viewNames.build()); } catch (EntityNotFoundException e) { // database does not exist return Optional.empty(); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -478,12 +521,15 @@ public void createDatabase(MetastoreContext metastoreContext, Database database) try { DatabaseInput databaseInput = GlueInputConverter.convertDatabase(database); - stats.getCreateDatabase().record(() -> glueClient.createDatabase(new CreateDatabaseRequest().withCatalogId(catalogId).withDatabaseInput(databaseInput))); + awsSyncRequest( + glueClient::createDatabase, + CreateDatabaseRequest.builder().catalogId(catalogId).databaseInput(databaseInput).build(), + stats.getCreateDatabase()); } catch (AlreadyExistsException e) { throw new SchemaAlreadyExistsException(database.getDatabaseName()); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } @@ -496,12 +542,15 @@ public void createDatabase(MetastoreContext metastoreContext, Database database) public void dropDatabase(MetastoreContext metastoreContext, String databaseName) { try { - stats.getDeleteDatabase().record(() -> glueClient.deleteDatabase(new DeleteDatabaseRequest().withCatalogId(catalogId).withName(databaseName))); + awsSyncRequest( + 
glueClient::deleteDatabase, + DeleteDatabaseRequest.builder().catalogId(catalogId).name(databaseName).build(), + stats.getDeleteDatabase()); } catch (EntityNotFoundException e) { throw new SchemaNotFoundException(databaseName); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -511,13 +560,18 @@ public void renameDatabase(MetastoreContext metastoreContext, String databaseNam { try { Database database = getDatabase(metastoreContext, databaseName).orElseThrow(() -> new SchemaNotFoundException(databaseName)); - DatabaseInput renamedDatabase = GlueInputConverter.convertDatabase(database).withName(newDatabaseName); - stats.getUpdateDatabase().record(() -> glueClient.updateDatabase(new UpdateDatabaseRequest() - .withCatalogId(catalogId) - .withName(databaseName) - .withDatabaseInput(renamedDatabase))); - } - catch (AmazonServiceException e) { + DatabaseInput renamedDatabase = GlueInputConverter.convertDatabase(database); + + awsSyncRequest( + glueClient::updateDatabase, + UpdateDatabaseRequest.builder() + .catalogId(catalogId) + .name(databaseName) + .databaseInput(renamedDatabase.toBuilder().name(newDatabaseName).build()) + .build(), + stats.getUpdateDatabase()); + } + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -531,10 +585,14 @@ public MetastoreOperationResult createTable(MetastoreContext metastoreContext, T try { TableInput input = GlueInputConverter.convertTable(table); - stats.getCreateTable().record(() -> glueClient.createTable(new CreateTableRequest() - .withCatalogId(catalogId) - .withDatabaseName(table.getDatabaseName()) - .withTableInput(input))); + awsSyncRequest( + glueClient::createTable, + CreateTableRequest.builder() + .catalogId(catalogId) + .databaseName(table.getDatabaseName()) + .tableInput(input) + .build(), + stats.getCreateTable()); } catch (AlreadyExistsException e) { throw new TableAlreadyExistsException(new 
SchemaTableName(table.getDatabaseName(), table.getTableName())); @@ -542,7 +600,7 @@ public MetastoreOperationResult createTable(MetastoreContext metastoreContext, T catch (EntityNotFoundException e) { throw new SchemaNotFoundException(table.getDatabaseName()); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } @@ -555,12 +613,16 @@ public void dropTable(MetastoreContext metastoreContext, String databaseName, St Table table = getTableOrElseThrow(metastoreContext, databaseName, tableName); try { - stats.getDeleteTable().record(() -> glueClient.deleteTable(new DeleteTableRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withName(tableName))); - } - catch (AmazonServiceException e) { + awsSyncRequest( + glueClient::deleteTable, + DeleteTableRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .name(tableName) + .build(), + stats.getDeleteTable()); + } + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } @@ -575,17 +637,52 @@ public MetastoreOperationResult replaceTable(MetastoreContext metastoreContext, { try { TableInput newTableInput = GlueInputConverter.convertTable(newTable); - stats.getUpdateTable().record(() -> glueClient.updateTable(new UpdateTableRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withTableInput(newTableInput))); + + awsSyncRequest( + glueClient::updateTable, + UpdateTableRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableInput(newTableInput) + .build(), + stats.getUpdateTable()); + + return EMPTY_RESULT; + } + catch (EntityNotFoundException e) { + throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); + } + catch (AwsServiceException e) { + throw new PrestoException(HIVE_METASTORE_ERROR, e); + } + } + + public MetastoreOperationResult persistTable(MetastoreContext metastoreContext, String databaseName, String 
tableName, Table newTable, PrincipalPrivileges principalPrivileges, Supplier update, Map additionalParameters) + { + PartitionStatistics updatedStatistics = update.get(); + if (!updatedStatistics.getColumnStatistics().isEmpty()) { + throw new PrestoException(NOT_SUPPORTED, "Glue metastore does not support column level statistics"); + } + try { + TableInput newTableInput = GlueInputConverter.convertTable(newTable); + final Map statisticsParameters = + updateStatisticsParameters(newTableInput.parameters(), updatedStatistics.getBasicStatistics()); + + awsSyncRequest( + glueClient::updateTable, + UpdateTableRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableInput(newTableInput.toBuilder().parameters(statisticsParameters).build()) + .build(), + stats.getUpdateTable()); return EMPTY_RESULT; } catch (EntityNotFoundException e) { throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -599,36 +696,38 @@ public MetastoreOperationResult renameTable(MetastoreContext metastoreContext, S @Override public MetastoreOperationResult addColumn(MetastoreContext metastoreContext, String databaseName, String tableName, String columnName, HiveType columnType, String columnComment) { - com.amazonaws.services.glue.model.Table table = getGlueTableOrElseThrow(databaseName, tableName); - ImmutableList.Builder newDataColumns = ImmutableList.builder(); - newDataColumns.addAll(table.getStorageDescriptor().getColumns()); + software.amazon.awssdk.services.glue.model.Table table = getGlueTableOrElseThrow(databaseName, tableName); + ImmutableList.Builder newDataColumns = ImmutableList.builder(); + newDataColumns.addAll(table.storageDescriptor().columns()); newDataColumns.add(convertColumn(new Column(columnName, columnType, Optional.ofNullable(columnComment), Optional.empty()))); - 
table.getStorageDescriptor().setColumns(newDataColumns.build()); - replaceGlueTable(databaseName, tableName, table); + StorageDescriptor newStorageDescriptor = table.storageDescriptor().toBuilder().columns(newDataColumns.build()).build(); + replaceGlueTable(databaseName, tableName, table.toBuilder().storageDescriptor(newStorageDescriptor).build()); return EMPTY_RESULT; } @Override public MetastoreOperationResult renameColumn(MetastoreContext metastoreContext, String databaseName, String tableName, String oldColumnName, String newColumnName) { - com.amazonaws.services.glue.model.Table table = getGlueTableOrElseThrow(databaseName, tableName); - if (table.getPartitionKeys() != null && table.getPartitionKeys().stream().anyMatch(c -> c.getName().equals(oldColumnName))) { + software.amazon.awssdk.services.glue.model.Table table = getGlueTableOrElseThrow(databaseName, tableName); + if (table.partitionKeys() != null && table.partitionKeys().stream().anyMatch(c -> c.name().equals(oldColumnName))) { throw new PrestoException(NOT_SUPPORTED, "Renaming partition columns is not supported"); } - ImmutableList.Builder newDataColumns = ImmutableList.builder(); - for (com.amazonaws.services.glue.model.Column column : table.getStorageDescriptor().getColumns()) { - if (column.getName().equals(oldColumnName)) { - newDataColumns.add(new com.amazonaws.services.glue.model.Column() - .withName(newColumnName) - .withType(column.getType()) - .withComment(column.getComment())); + ImmutableList.Builder newDataColumns = ImmutableList.builder(); + for (software.amazon.awssdk.services.glue.model.Column column : table.storageDescriptor().columns()) { + if (column.name().equals(oldColumnName)) { + newDataColumns.add(software.amazon.awssdk.services.glue.model.Column.builder() + .name(newColumnName) + .type(column.type()) + .comment(column.comment()) + .build()); } else { newDataColumns.add(column); } } - table.getStorageDescriptor().setColumns(newDataColumns.build()); - 
replaceGlueTable(databaseName, tableName, table); + + StorageDescriptor newStorageDescriptor = table.storageDescriptor().toBuilder().columns(newDataColumns.build()).build(); + replaceGlueTable(databaseName, tableName, table.toBuilder().storageDescriptor(newStorageDescriptor).build()); return EMPTY_RESULT; } @@ -636,12 +735,12 @@ public MetastoreOperationResult renameColumn(MetastoreContext metastoreContext, public MetastoreOperationResult dropColumn(MetastoreContext metastoreContext, String databaseName, String tableName, String columnName) { verifyCanDropColumn(this, metastoreContext, databaseName, tableName, columnName); - com.amazonaws.services.glue.model.Table table = getGlueTableOrElseThrow(databaseName, tableName); + software.amazon.awssdk.services.glue.model.Table table = getGlueTableOrElseThrow(databaseName, tableName); - ImmutableList.Builder newDataColumns = ImmutableList.builder(); + ImmutableList.Builder newDataColumns = ImmutableList.builder(); boolean found = false; - for (com.amazonaws.services.glue.model.Column column : table.getStorageDescriptor().getColumns()) { - if (column.getName().equals(columnName)) { + for (software.amazon.awssdk.services.glue.model.Column column : table.storageDescriptor().columns()) { + if (column.name().equals(columnName)) { found = true; } else { @@ -654,24 +753,28 @@ public MetastoreOperationResult dropColumn(MetastoreContext metastoreContext, St throw new ColumnNotFoundException(name, columnName); } - table.getStorageDescriptor().setColumns(newDataColumns.build()); - replaceGlueTable(databaseName, tableName, table); + StorageDescriptor newStorageDescriptor = table.storageDescriptor().toBuilder().columns(newDataColumns.build()).build(); + replaceGlueTable(databaseName, tableName, table.toBuilder().storageDescriptor(newStorageDescriptor).build()); return EMPTY_RESULT; } - private void replaceGlueTable(String databaseName, String tableName, com.amazonaws.services.glue.model.Table newTable) + private void 
replaceGlueTable(String databaseName, String tableName, software.amazon.awssdk.services.glue.model.Table newTable) { try { - stats.getUpdateTable().record(() -> glueClient.updateTable(new UpdateTableRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withTableInput(toTableInput(newTable)))); + awsSyncRequest( + glueClient::updateTable, + UpdateTableRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableInput(toTableInput(newTable)) + .build(), + stats.getUpdateTable()); } catch (EntityNotFoundException e) { throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -679,22 +782,25 @@ private void replaceGlueTable(String databaseName, String tableName, com.amazona @Override public Optional getPartition(MetastoreContext metastoreContext, String databaseName, String tableName, List partitionValues) { - return stats.getGetPartition().record(() -> { - try { - GetPartitionResult result = glueClient.getPartition(new GetPartitionRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withTableName(tableName) - .withPartitionValues(partitionValues)); - return Optional.of(new GluePartitionConverter(databaseName, tableName).apply(result.getPartition())); - } - catch (EntityNotFoundException e) { - return Optional.empty(); - } - catch (AmazonServiceException e) { - throw new PrestoException(HIVE_METASTORE_ERROR, e); - } - }); + try { + GetPartitionResponse response = awsSyncRequest( + glueClient::getPartition, + GetPartitionRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableName(tableName) + .partitionValues(partitionValues) + .build(), + stats.getGetPartition()); + + return Optional.of(new GluePartitionConverter(databaseName, tableName).apply(response.partition())); + } + catch (EntityNotFoundException e) { + return 
Optional.empty(); + } + catch (AwsServiceException e) { + throw new PrestoException(HIVE_METASTORE_ERROR, e); + } } @Override @@ -748,7 +854,7 @@ private List getPartitions(String databaseName, String tableName, Str // Do parallel partition fetch. CompletionService> completionService = new ExecutorCompletionService<>(executor); for (int i = 0; i < partitionSegments; i++) { - Segment segment = new Segment().withSegmentNumber(i).withTotalSegments(partitionSegments); + Segment segment = Segment.builder().segmentNumber(i).totalSegments(partitionSegments).build(); completionService.submit(() -> getPartitions(databaseName, tableName, expression, segment)); } @@ -774,28 +880,30 @@ private List getPartitions(String databaseName, String tableName, Str { try { GluePartitionConverter converter = new GluePartitionConverter(databaseName, tableName); - ArrayList partitions = new ArrayList<>(); - GetPartitionsRequest request = new GetPartitionsRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withTableName(tableName) - .withExpression(expression) - .withSegment(segment) - .withMaxResults(AWS_GLUE_GET_PARTITIONS_MAX_RESULTS); - - do { - GetPartitionsResult result = stats.getGetPartitions().record(() -> glueClient.getPartitions(request)); - request.setNextToken(result.getNextToken()); - partitions.ensureCapacity(partitions.size() + result.getPartitions().size()); - result.getPartitions().stream() - .map(converter) - .forEach(partitions::add); - } - while (request.getNextToken() != null); - return partitions; + ImmutableList.Builder partitionBuilder = ImmutableList.builder(); + + GetPartitionsRequest partitionsRequest = GetPartitionsRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableName(tableName) + .expression(expression) + .segment(segment) + .maxResults(AWS_GLUE_GET_PARTITIONS_MAX_RESULTS) + .build(); + + awsSyncPaginatedRequest( + glueClient.getPartitionsPaginator(partitionsRequest), + getPartitionsResponse -> { + 
getPartitionsResponse.partitions().stream() + .map(converter) + .forEach(partitionBuilder::add); + }, + stats.getGetPartitions()); + + return partitionBuilder.build(); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -841,28 +949,40 @@ public Map> getPartitionsByNames(MetastoreContext me private List batchGetPartition(String databaseName, String tableName, List partitionNames) { try { - List> batchGetPartitionFutures = new ArrayList<>(); + List> batchGetPartitionFutures = new ArrayList<>(); for (List partitionNamesBatch : Lists.partition(partitionNames, BATCH_GET_PARTITION_MAX_PAGE_SIZE)) { - List partitionValuesBatch = mappedCopy(partitionNamesBatch, partitionName -> new PartitionValueList().withValues(toPartitionValues(partitionName))); - batchGetPartitionFutures.add(glueClient.batchGetPartitionAsync(new BatchGetPartitionRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withTableName(tableName) - .withPartitionsToGet(partitionValuesBatch), stats.getBatchGetPartitions().metricsAsyncHandler())); + List partitionValuesBatch = mappedCopy(partitionNamesBatch, partitionName -> PartitionValueList.builder().values(toPartitionValues(partitionName)).build()); + + GlueStatsAsyncHandler asyncHandler = new GlueStatsAsyncHandler(stats.getBatchGetPartitions()); + + batchGetPartitionFutures.add(glueClient.batchGetPartition(BatchGetPartitionRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableName(tableName) + .partitionsToGet(partitionValuesBatch) + .build()) + .whenCompleteAsync((response, exception) -> { + if (response != null) { + asyncHandler.onSuccess(response); + } + else if (exception != null) { + asyncHandler.onError(exception); + } + })); } GluePartitionConverter converter = new GluePartitionConverter(databaseName, tableName); ImmutableList.Builder resultsBuilder = ImmutableList.builderWithExpectedSize(partitionNames.size()); - for 
(Future future : batchGetPartitionFutures) { - future.get().getPartitions().stream() + for (Future future : batchGetPartitionFutures) { + future.get().partitions().stream() .map(converter) .forEach(resultsBuilder::add); } return resultsBuilder.build(); } - catch (AmazonServiceException | InterruptedException | ExecutionException e) { + catch (AwsServiceException | InterruptedException | ExecutionException e) { if (e instanceof InterruptedException) { Thread.currentThread().interrupt(); } @@ -874,25 +994,37 @@ private List batchGetPartition(String databaseName, String tableName, public MetastoreOperationResult addPartitions(MetastoreContext metastoreContext, String databaseName, String tableName, List partitions) { try { - List> futures = new ArrayList<>(); + List> futures = new ArrayList<>(); for (List partitionBatch : Lists.partition(partitions, BATCH_CREATE_PARTITION_MAX_PAGE_SIZE)) { List partitionInputs = mappedCopy(partitionBatch, GlueInputConverter::convertPartition); - futures.add(glueClient.batchCreatePartitionAsync(new BatchCreatePartitionRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withTableName(tableName) - .withPartitionInputList(partitionInputs), stats.getBatchCreatePartitions().metricsAsyncHandler())); + + GlueStatsAsyncHandler asyncHandler = new GlueStatsAsyncHandler(stats.getBatchCreatePartitions()); + + futures.add(glueClient.batchCreatePartition(BatchCreatePartitionRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableName(tableName) + .partitionInputList(partitionInputs) + .build()) + .whenCompleteAsync((response, exception) -> { + if (response != null) { + asyncHandler.onSuccess(response); + } + else if (exception != null) { + asyncHandler.onError(exception); + } + })); } - for (Future future : futures) { - BatchCreatePartitionResult result = future.get(); - propagatePartitionErrorToPrestoException(databaseName, tableName, result.getErrors()); + for (Future future : futures) { + 
BatchCreatePartitionResponse result = future.get(); + propagatePartitionErrorToPrestoException(databaseName, tableName, result.errors()); } return EMPTY_RESULT; } - catch (AmazonServiceException | InterruptedException | ExecutionException e) { + catch (AwsServiceException | InterruptedException | ExecutionException e) { if (e instanceof InterruptedException) { Thread.currentThread().interrupt(); } @@ -903,16 +1035,16 @@ public MetastoreOperationResult addPartitions(MetastoreContext metastoreContext, private static void propagatePartitionErrorToPrestoException(String databaseName, String tableName, List partitionErrors) { if (partitionErrors != null && !partitionErrors.isEmpty()) { - ErrorDetail errorDetail = partitionErrors.get(0).getErrorDetail(); - String glueExceptionCode = errorDetail.getErrorCode(); + ErrorDetail errorDetail = partitionErrors.get(0).errorDetail(); + String glueExceptionCode = errorDetail.errorCode(); switch (glueExceptionCode) { case "AlreadyExistsException": - throw new PrestoException(ALREADY_EXISTS, errorDetail.getErrorMessage()); + throw new PrestoException(ALREADY_EXISTS, errorDetail.errorMessage()); case "EntityNotFoundException": - throw new TableNotFoundException(new SchemaTableName(databaseName, tableName), errorDetail.getErrorMessage()); + throw new TableNotFoundException(new SchemaTableName(databaseName, tableName), errorDetail.errorMessage()); default: - throw new PrestoException(HIVE_METASTORE_ERROR, errorDetail.getErrorCode() + ": " + errorDetail.getErrorMessage()); + throw new PrestoException(HIVE_METASTORE_ERROR, errorDetail.errorCode() + ": " + errorDetail.errorMessage()); } } } @@ -925,13 +1057,17 @@ public void dropPartition(MetastoreContext metastoreContext, String databaseName .orElseThrow(() -> new PartitionNotFoundException(new SchemaTableName(databaseName, tableName), parts)); try { - stats.getDeletePartition().record(() -> glueClient.deletePartition(new DeletePartitionRequest() - .withCatalogId(catalogId) - 
.withDatabaseName(databaseName) - .withTableName(tableName) - .withPartitionValues(parts))); - } - catch (AmazonServiceException e) { + awsSyncRequest( + glueClient::deletePartition, + DeletePartitionRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableName(tableName) + .partitionValues(parts) + .build(), + stats.getDeletePartition()); + } + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } @@ -946,19 +1082,24 @@ public MetastoreOperationResult alterPartition(MetastoreContext metastoreContext { try { PartitionInput newPartition = GlueInputConverter.convertPartition(partition); - stats.getUpdatePartition().record(() -> glueClient.updatePartition(new UpdatePartitionRequest() - .withCatalogId(catalogId) - .withDatabaseName(databaseName) - .withTableName(tableName) - .withPartitionInput(newPartition) - .withPartitionValueList(partition.getPartition().getValues()))); + + awsSyncRequest( + glueClient::updatePartition, + UpdatePartitionRequest.builder() + .catalogId(catalogId) + .databaseName(databaseName) + .tableName(tableName) + .partitionInput(newPartition) + .partitionValueList(partition.getPartition().getValues()) + .build(), + stats.getUpdatePartition()); return EMPTY_RESULT; } catch (EntityNotFoundException e) { throw new PartitionNotFoundException(new SchemaTableName(databaseName, tableName), partition.getPartition().getValues()); } - catch (AmazonServiceException e) { + catch (AwsServiceException e) { throw new PrestoException(HIVE_METASTORE_ERROR, e); } } @@ -1047,4 +1188,80 @@ public MetastoreOperationResult addConstraint(MetastoreContext metastoreContext, { throw new PrestoException(NOT_SUPPORTED, "addConstraint is not supported by Glue"); } + + public static T awsSyncRequest( + Function> submission, + R request, + GlueCatalogApiStats stats) + { + requireNonNull(submission, "submission is null"); + requireNonNull(request, "request is null"); + + try { + if (stats != null) { + return 
stats.record(() -> submission.apply(request).join()); + } + + return submission.apply(request).join(); + } + catch (CompletionException e) { + if (e.getCause() instanceof GlueException) { + throw (GlueException) e.getCause(); + } + throw new PrestoException(HIVE_METASTORE_ERROR, e.getCause()); + } + } + + private static void awsSyncPaginatedRequest( + SdkPublisher paginator, + Consumer resultConsumer, + GlueCatalogApiStats stats) + { + requireNonNull(paginator, "paginator is null"); + requireNonNull(resultConsumer, "resultConsumer is null"); + + // Single join point so exception handling is consistent, and stats (when present) + // cover the full wall-clock time of the paginated request including completion. + Runnable paginationTask = () -> paginator.subscribe(resultConsumer).join(); + + try { + if (stats != null) { + stats.record(() -> { + paginationTask.run(); + return null; + }); + } + else { + paginationTask.run(); + } + } + catch (CompletionException e) { + if (e.getCause() instanceof GlueException) { + throw (GlueException) e.getCause(); + } + throw new PrestoException(HIVE_METASTORE_ERROR, e.getCause()); + } + } + + static class GlueStatsAsyncHandler + { + private final GlueCatalogApiStats stats; + private final Stopwatch stopwatch; + + public GlueStatsAsyncHandler(GlueCatalogApiStats stats) + { + this.stats = requireNonNull(stats, "stats is null"); + this.stopwatch = Stopwatch.createStarted(); + } + + public void onError(Throwable e) + { + stats.recordAsync(stopwatch.elapsed(NANOSECONDS), true); + } + + public void onSuccess(GlueResponse response) + { + stats.recordAsync(stopwatch.elapsed(NANOSECONDS), false); + } + } } diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueHiveMetastoreConfig.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueHiveMetastoreConfig.java index 53b3f0e5b0295..06aa90aef674b 100644 --- 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueHiveMetastoreConfig.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueHiveMetastoreConfig.java @@ -16,17 +16,19 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.ConfigSecuritySensitive; - -import javax.validation.constraints.Max; -import javax.validation.constraints.Min; +import com.facebook.airlift.configuration.DefunctConfig; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; import java.util.Optional; +@DefunctConfig("hive.metastore.glue.pin-client-to-current-region") public class GlueHiveMetastoreConfig { private Optional glueRegion = Optional.empty(); private Optional glueEndpointUrl = Optional.empty(); - private boolean pinGlueClientToCurrentRegion; + private Optional glueStsRegion = Optional.empty(); + private Optional glueStsEndpointUrl = Optional.empty(); private int maxGlueErrorRetries = 10; private int maxGlueConnections = 50; private Optional defaultWarehouseDir = Optional.empty(); @@ -63,16 +65,29 @@ public GlueHiveMetastoreConfig setGlueEndpointUrl(String glueEndpointUrl) return this; } - public boolean getPinGlueClientToCurrentRegion() + public Optional getGlueStsRegion() + { + return glueStsRegion; + } + + @Config("hive.metastore.glue.sts.region") + @ConfigDescription("AWS STS region for Glue authentication") + public GlueHiveMetastoreConfig setGlueStsRegion(String region) + { + this.glueStsRegion = Optional.ofNullable(region); + return this; + } + + public Optional getGlueStsEndpointUrl() { - return pinGlueClientToCurrentRegion; + return glueStsEndpointUrl; } - @Config("hive.metastore.glue.pin-client-to-current-region") - @ConfigDescription("Should the Glue client be pinned to the current EC2 region") - public GlueHiveMetastoreConfig setPinGlueClientToCurrentRegion(boolean 
pinGlueClientToCurrentRegion) + @Config("hive.metastore.glue.sts.endpoint-url") + @ConfigDescription("AWS STS endpoint URL for Glue authentication") + public GlueHiveMetastoreConfig setGlueStsEndpointUrl(String glueStsEndpointUrl) { - this.pinGlueClientToCurrentRegion = pinGlueClientToCurrentRegion; + this.glueStsEndpointUrl = Optional.ofNullable(glueStsEndpointUrl); return this; } diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueMetastoreModule.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueMetastoreModule.java index 0bc9bee6b81fc..607165178c5c6 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueMetastoreModule.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueMetastoreModule.java @@ -19,6 +19,7 @@ import com.facebook.presto.hive.HiveCommonClientConfig; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore; +import com.facebook.presto.hive.metastore.MetastoreCacheSpecProvider; import com.google.inject.Binder; import com.google.inject.Provides; import com.google.inject.Scopes; @@ -51,6 +52,7 @@ public void setup(Binder binder) checkArgument(buildConfigObject(HiveCommonClientConfig.class).getCatalogName() == null, "'hive.metastore.catalog.name' should not be set for glue metastore"); configBinder(binder).bindConfig(GlueHiveMetastoreConfig.class); binder.bind(GlueHiveMetastore.class).in(Scopes.SINGLETON); + binder.bind(MetastoreCacheSpecProvider.class).in(Scopes.SINGLETON); binder.bind(ExtendedHiveMetastore.class).annotatedWith(ForCachingHiveMetastore.class).to(GlueHiveMetastore.class).in(Scopes.SINGLETON); binder.bind(ExtendedHiveMetastore.class).to(InMemoryCachingHiveMetastore.class).in(Scopes.SINGLETON); newExporter(binder).export(ExtendedHiveMetastore.class) diff --git 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueMetastoreStats.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueMetastoreStats.java index 1c7ece62c038b..b0779572f46be 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueMetastoreStats.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/GlueMetastoreStats.java @@ -13,15 +13,11 @@ */ package com.facebook.presto.hive.metastore.glue; -import com.facebook.airlift.stats.CounterStat; -import com.facebook.airlift.stats.TimeStat; -import com.facebook.presto.hive.aws.AbstractSdkMetricsCollector; -import io.airlift.units.Duration; +import com.facebook.presto.hive.aws.metrics.AwsSdkClientStats; +import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; - -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.MILLISECONDS; +import software.amazon.awssdk.metrics.MetricPublisher; public class GlueMetastoreStats { @@ -42,13 +38,7 @@ public class GlueMetastoreStats private final GlueCatalogApiStats getPartitions = new GlueCatalogApiStats(); private final GlueCatalogApiStats getPartition = new GlueCatalogApiStats(); - // see AWSRequestMetrics - private final CounterStat awsRequestCount = new CounterStat(); - private final CounterStat awsRetryCount = new CounterStat(); - private final CounterStat awsThrottleExceptions = new CounterStat(); - private final TimeStat awsRequestTime = new TimeStat(MILLISECONDS); - private final TimeStat awsClientExecuteTime = new TimeStat(MILLISECONDS); - private final TimeStat awsClientRetryPauseTime = new TimeStat(MILLISECONDS); + private final AwsSdkClientStats awsSdkClientStats = new AwsSdkClientStats(); @Managed @Nested @@ -163,96 +153,14 @@ public GlueCatalogApiStats getGetPartition() } @Managed - @Nested - public CounterStat getAwsRequestCount() - { - return awsRequestCount; - } 
- - @Managed - @Nested - public CounterStat getAwsRetryCount() + @Flatten + public AwsSdkClientStats getAwsSdkClientStats() { - return awsRetryCount; + return awsSdkClientStats; } - @Managed - @Nested - public CounterStat getAwsThrottleExceptions() - { - return awsThrottleExceptions; - } - - @Managed - @Nested - public TimeStat getAwsRequestTime() - { - return awsRequestTime; - } - - @Managed - @Nested - public TimeStat getAwsClientExecuteTime() + public MetricPublisher newRequestMetricPublisher() { - return awsClientExecuteTime; - } - - @Managed - @Nested - public TimeStat getAwsClientRetryPauseTime() - { - return awsClientRetryPauseTime; - } - - public GlueSdkClientMetricsCollector newRequestMetricsCollector() - { - return new GlueSdkClientMetricsCollector(this); - } - - public static class GlueSdkClientMetricsCollector - extends AbstractSdkMetricsCollector - { - private final GlueMetastoreStats stats; - - public GlueSdkClientMetricsCollector(GlueMetastoreStats stats) - { - this.stats = requireNonNull(stats, "stats is null"); - } - - @Override - protected void recordRequestCount(long count) - { - stats.awsRequestCount.update(count); - } - - @Override - protected void recordRetryCount(long count) - { - stats.awsRetryCount.update(count); - } - - @Override - protected void recordThrottleExceptionCount(long count) - { - stats.awsThrottleExceptions.update(count); - } - - @Override - protected void recordHttpRequestTime(Duration duration) - { - stats.awsRequestTime.add(duration); - } - - @Override - protected void recordClientExecutionTime(Duration duration) - { - stats.awsClientExecuteTime.add(duration); - } - - @Override - protected void recordRetryPauseTime(Duration duration) - { - stats.awsClientRetryPauseTime.add(duration); - } + return awsSdkClientStats.newRequestMetricsPublisher(); } } diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/converter/GlueInputConverter.java 
b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/converter/GlueInputConverter.java index edf2b7d17fdf5..7c4cf09f3746f 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/converter/GlueInputConverter.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/converter/GlueInputConverter.java @@ -13,12 +13,6 @@ */ package com.facebook.presto.hive.metastore.glue.converter; -import com.amazonaws.services.glue.model.DatabaseInput; -import com.amazonaws.services.glue.model.Order; -import com.amazonaws.services.glue.model.PartitionInput; -import com.amazonaws.services.glue.model.SerDeInfo; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.glue.model.TableInput; import com.facebook.presto.hive.HiveBucketProperty; import com.facebook.presto.hive.metastore.Column; import com.facebook.presto.hive.metastore.Database; @@ -29,6 +23,12 @@ import com.facebook.presto.hive.metastore.Table; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableMap; +import software.amazon.awssdk.services.glue.model.DatabaseInput; +import software.amazon.awssdk.services.glue.model.Order; +import software.amazon.awssdk.services.glue.model.PartitionInput; +import software.amazon.awssdk.services.glue.model.SerDeInfo; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.TableInput; import java.util.EnumSet; import java.util.List; @@ -41,7 +41,6 @@ import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static java.util.stream.Collectors.toList; public final class GlueInputConverter { @@ -49,41 +48,42 @@ private GlueInputConverter() {} public static DatabaseInput convertDatabase(Database database) { - 
DatabaseInput input = new DatabaseInput(); - input.setName(database.getDatabaseName()); - input.setParameters(database.getParameters()); - database.getComment().ifPresent(input::setDescription); - database.getLocation().ifPresent(input::setLocationUri); - return input; + return DatabaseInput.builder() + .name(database.getDatabaseName()) + .parameters(database.getParameters()) + .applyMutation(builder -> database.getComment().ifPresent(builder::description)) + .applyMutation(builder -> database.getLocation().ifPresent(builder::locationUri)) + .build(); } public static TableInput convertTable(Table table) { - TableInput input = new TableInput(); - input.setName(table.getTableName()); - input.setOwner(table.getOwner()); checkArgument(EnumSet.of(MANAGED_TABLE, EXTERNAL_TABLE, VIRTUAL_VIEW).contains(table.getTableType()), "Invalid table type: %s", table.getTableType()); - input.setTableType(table.getTableType().toString()); - input.setStorageDescriptor(convertStorage(table.getStorage(), table.getDataColumns())); - input.setPartitionKeys(table.getPartitionColumns().stream().map(GlueInputConverter::convertColumn).collect(toList())); - input.setParameters(table.getParameters()); - table.getViewOriginalText().ifPresent(input::setViewOriginalText); - table.getViewExpandedText().ifPresent(input::setViewExpandedText); - return input; + + return TableInput.builder() + .name(table.getTableName()) + .owner(table.getOwner()) + .tableType(table.getTableType().toString()) + .storageDescriptor(convertStorage(table.getStorage(), table.getDataColumns())) + .partitionKeys(table.getPartitionColumns().stream().map(GlueInputConverter::convertColumn).collect(toImmutableList())) + .parameters(table.getParameters()) + .applyMutation(builder -> table.getViewOriginalText().ifPresent(builder::viewOriginalText)) + .applyMutation(builder -> table.getViewExpandedText().ifPresent(builder::viewExpandedText)) + .build(); } - public static TableInput toTableInput(com.amazonaws.services.glue.model.Table 
table) + public static TableInput toTableInput(software.amazon.awssdk.services.glue.model.Table table) { - TableInput input = new TableInput(); - input.setName(table.getName()); - input.setOwner(table.getOwner()); - input.setTableType(table.getTableType()); - input.setStorageDescriptor(table.getStorageDescriptor()); - input.setPartitionKeys(table.getPartitionKeys()); - input.setParameters(table.getParameters()); - input.setViewOriginalText(table.getViewOriginalText()); - input.setViewExpandedText(table.getViewExpandedText()); - return input; + return TableInput.builder() + .name(table.name()) + .owner(table.owner()) + .tableType(table.tableType()) + .storageDescriptor(table.storageDescriptor()) + .partitionKeys(table.partitionKeys()) + .parameters(table.parameters()) + .viewOriginalText(table.viewOriginalText()) + .viewExpandedText(table.viewExpandedText()) + .build(); } public static PartitionInput convertPartition(PartitionWithStatistics partitionWithStatistics) @@ -93,17 +93,17 @@ public static PartitionInput convertPartition(PartitionWithStatistics partitionW if (!statistics.getColumnStatistics().isEmpty()) { throw new PrestoException(NOT_SUPPORTED, "Glue metastore does not support column level statistics"); } - input.setParameters(updateStatisticsParameters(input.getParameters(), statistics.getBasicStatistics())); - return input; + return input.toBuilder().parameters(updateStatisticsParameters(input.parameters(), statistics.getBasicStatistics())) + .build(); } public static PartitionInput convertPartition(Partition partition) { - PartitionInput input = new PartitionInput(); - input.setValues(partition.getValues()); - input.setStorageDescriptor(convertStorage(partition.getStorage(), partition.getColumns())); - input.setParameters(partition.getParameters()); - return input; + return PartitionInput.builder() + .values(partition.getValues()) + .storageDescriptor(convertStorage(partition.getStorage(), partition.getColumns())) + 
.parameters(partition.getParameters()) + .build(); } private static StorageDescriptor convertStorage(Storage storage, List columns) @@ -111,37 +111,39 @@ private static StorageDescriptor convertStorage(Storage storage, List co if (storage.isSkewed()) { throw new IllegalArgumentException("Writing to skewed table/partition is not supported"); } - SerDeInfo serdeInfo = new SerDeInfo() - .withSerializationLibrary(storage.getStorageFormat().getSerDeNullable()) - .withParameters(storage.getSerdeParameters()); + SerDeInfo serDeInfo = SerDeInfo.builder() + .serializationLibrary(storage.getStorageFormat().getSerDeNullable()) + .parameters(storage.getSerdeParameters()) + .build(); - StorageDescriptor sd = new StorageDescriptor(); - sd.setLocation(storage.getLocation()); - sd.setColumns(columns.stream().map(GlueInputConverter::convertColumn).collect(toList())); - sd.setSerdeInfo(serdeInfo); - sd.setInputFormat(storage.getStorageFormat().getInputFormatNullable()); - sd.setOutputFormat(storage.getStorageFormat().getOutputFormatNullable()); - sd.setParameters(ImmutableMap.of()); + StorageDescriptor.Builder sd = StorageDescriptor.builder() + .location(storage.getLocation()) + .columns(columns.stream().map(GlueInputConverter::convertColumn).collect(toImmutableList())) + .serdeInfo(serDeInfo) + .inputFormat(storage.getStorageFormat().getInputFormatNullable()) + .outputFormat(storage.getStorageFormat().getOutputFormatNullable()) + .parameters(ImmutableMap.of()); Optional bucketProperty = storage.getBucketProperty(); if (bucketProperty.isPresent()) { - sd.setNumberOfBuckets(bucketProperty.get().getBucketCount()); - sd.setBucketColumns(bucketProperty.get().getBucketedBy()); + sd.numberOfBuckets(bucketProperty.get().getBucketCount()); + sd.bucketColumns(bucketProperty.get().getBucketedBy()); if (!bucketProperty.get().getSortedBy().isEmpty()) { - sd.setSortColumns(bucketProperty.get().getSortedBy().stream() - .map(column -> new 
Order().withColumn(column.getColumnName()).withSortOrder(column.getOrder().getHiveOrder())) + sd.sortColumns(bucketProperty.get().getSortedBy().stream() + .map(column -> Order.builder().column(column.getColumnName()).sortOrder(column.getOrder().getHiveOrder()).build()) .collect(toImmutableList())); } } - return sd; + return sd.build(); } - public static com.amazonaws.services.glue.model.Column convertColumn(Column prestoColumn) + public static software.amazon.awssdk.services.glue.model.Column convertColumn(Column prestoColumn) { - return new com.amazonaws.services.glue.model.Column() - .withName(prestoColumn.getName()) - .withType(prestoColumn.getType().toString()) - .withComment(prestoColumn.getComment().orElse(null)); + return software.amazon.awssdk.services.glue.model.Column.builder() + .name(prestoColumn.getName()) + .type(prestoColumn.getType().toString()) + .comment(prestoColumn.getComment().orElse(null)) + .build(); } } diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/converter/GlueToPrestoConverter.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/converter/GlueToPrestoConverter.java index 83de7d0150fc7..947bddbb80043 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/converter/GlueToPrestoConverter.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/glue/converter/GlueToPrestoConverter.java @@ -13,8 +13,6 @@ */ package com.facebook.presto.hive.metastore.glue.converter; -import com.amazonaws.services.glue.model.SerDeInfo; -import com.amazonaws.services.glue.model.StorageDescriptor; import com.facebook.presto.hive.HiveBucketProperty; import com.facebook.presto.hive.HiveStorageFormat; import com.facebook.presto.hive.HiveType; @@ -31,6 +29,8 @@ import com.facebook.presto.spi.security.PrincipalType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import 
software.amazon.awssdk.services.glue.model.SerDeInfo; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; import java.util.List; import java.util.Locale; @@ -57,47 +57,47 @@ public final class GlueToPrestoConverter private GlueToPrestoConverter() {} - public static Database convertDatabase(com.amazonaws.services.glue.model.Database glueDb) + public static Database convertDatabase(software.amazon.awssdk.services.glue.model.Database glueDb) { return Database.builder() - .setDatabaseName(glueDb.getName()) - .setLocation(Optional.ofNullable(glueDb.getLocationUri())) - .setComment(Optional.ofNullable(glueDb.getDescription())) - .setParameters(convertParameters(glueDb.getParameters())) + .setDatabaseName(glueDb.name()) + .setLocation(Optional.ofNullable(glueDb.locationUri())) + .setComment(Optional.ofNullable(glueDb.description())) + .setParameters(convertParameters(glueDb.parameters())) .setOwnerName(PUBLIC_OWNER) .setOwnerType(PrincipalType.ROLE) .build(); } - public static Table convertTable(com.amazonaws.services.glue.model.Table glueTable, String dbName) + public static Table convertTable(software.amazon.awssdk.services.glue.model.Table glueTable, String dbName) { - Map tableParameters = convertParameters(glueTable.getParameters()); + Map tableParameters = convertParameters(glueTable.parameters()); Table.Builder tableBuilder = Table.builder() .setDatabaseName(dbName) - .setTableName(glueTable.getName()) - .setOwner(nullToEmpty(glueTable.getOwner())) + .setTableName(glueTable.name()) + .setOwner(nullToEmpty(glueTable.owner())) // Athena treats missing table type as EXTERNAL_TABLE. 
- .setTableType(PrestoTableType.optionalValueOf(glueTable.getTableType()).orElse(EXTERNAL_TABLE)) + .setTableType(PrestoTableType.optionalValueOf(glueTable.tableType()).orElse(EXTERNAL_TABLE)) .setParameters(tableParameters) - .setViewOriginalText(Optional.ofNullable(glueTable.getViewOriginalText())) - .setViewExpandedText(Optional.ofNullable(glueTable.getViewExpandedText())); + .setViewOriginalText(Optional.ofNullable(glueTable.viewOriginalText())) + .setViewExpandedText(Optional.ofNullable(glueTable.viewExpandedText())); - StorageDescriptor sd = glueTable.getStorageDescriptor(); + StorageDescriptor sd = glueTable.storageDescriptor(); if (isIcebergTable(tableParameters) || (sd == null && isDeltaLakeTable(tableParameters))) { // Iceberg and Delta Lake tables do not use the StorageDescriptor field, but we need to return a Table so the caller can check that // the table is an Iceberg/Delta table and decide whether to redirect or fail. tableBuilder.setDataColumns(ImmutableList.of(new Column("dummy", HIVE_INT, Optional.empty(), Optional.empty()))); tableBuilder.getStorageBuilder().setStorageFormat(StorageFormat.fromHiveStorageFormat(HiveStorageFormat.PARQUET)); - tableBuilder.getStorageBuilder().setLocation(sd == null ? "" : sd.getLocation()); + tableBuilder.getStorageBuilder().setLocation(sd == null ? 
"" : sd.location()); } else { if (sd == null) { - throw new PrestoException(HIVE_UNSUPPORTED_FORMAT, format("Table StorageDescriptor is null for table %s.%s (%s)", dbName, glueTable.getName(), glueTable)); + throw new PrestoException(HIVE_UNSUPPORTED_FORMAT, format("Table StorageDescriptor is null for table %s.%s (%s)", dbName, glueTable.name(), glueTable)); } - tableBuilder.setDataColumns(convertColumns(sd.getColumns())); - if (glueTable.getPartitionKeys() != null) { - tableBuilder.setPartitionColumns(convertColumns(glueTable.getPartitionKeys())); + tableBuilder.setDataColumns(convertColumns(sd.columns())); + if (glueTable.partitionKeys() != null) { + tableBuilder.setPartitionColumns(convertColumns(glueTable.partitionKeys())); } else { tableBuilder.setPartitionColumns(ImmutableList.of()); @@ -109,12 +109,12 @@ public static Table convertTable(com.amazonaws.services.glue.model.Table glueTab return tableBuilder.build(); } - private static Column convertColumn(com.amazonaws.services.glue.model.Column glueColumn) + private static Column convertColumn(software.amazon.awssdk.services.glue.model.Column glueColumn) { - return new Column(glueColumn.getName(), HiveType.valueOf(glueColumn.getType().toLowerCase(Locale.ENGLISH)), Optional.ofNullable(glueColumn.getComment()), Optional.empty()); + return new Column(glueColumn.name(), HiveType.valueOf(glueColumn.type().toLowerCase(Locale.ENGLISH)), Optional.ofNullable(glueColumn.comment()), Optional.empty()); } - private static List convertColumns(List glueColumns) + private static List convertColumns(List glueColumns) { return mappedCopy(glueColumns, GlueToPrestoConverter::convertColumn); } @@ -138,9 +138,9 @@ private static boolean isNullOrEmpty(List list) } public static final class GluePartitionConverter - implements Function + implements Function { - private final Function, List> columnsConverter = memoizeLast(GlueToPrestoConverter::convertColumns); + private final Function, List> columnsConverter = 
memoizeLast(GlueToPrestoConverter::convertColumns); private final Function, Map> parametersConverter = parametersConverter(); private final StorageConverter storageConverter = new StorageConverter(); private final String databaseName; @@ -153,25 +153,25 @@ public GluePartitionConverter(String databaseName, String tableName) } @Override - public Partition apply(com.amazonaws.services.glue.model.Partition gluePartition) + public Partition apply(software.amazon.awssdk.services.glue.model.Partition gluePartition) { - requireNonNull(gluePartition.getStorageDescriptor(), "Partition StorageDescriptor is null"); - StorageDescriptor sd = gluePartition.getStorageDescriptor(); + requireNonNull(gluePartition.storageDescriptor(), "Partition StorageDescriptor is null"); + StorageDescriptor sd = gluePartition.storageDescriptor(); - if (!databaseName.equals(gluePartition.getDatabaseName())) { - throw new IllegalArgumentException(format("Unexpected databaseName, expected: %s, but found: %s", databaseName, gluePartition.getDatabaseName())); + if (!databaseName.equals(gluePartition.databaseName())) { + throw new IllegalArgumentException(format("Unexpected databaseName, expected: %s, but found: %s", databaseName, gluePartition.databaseName())); } - if (!tableName.equals(gluePartition.getTableName())) { - throw new IllegalArgumentException(format("Unexpected tableName, expected: %s, but found: %s", tableName, gluePartition.getTableName())); + if (!tableName.equals(gluePartition.tableName())) { + throw new IllegalArgumentException(format("Unexpected tableName, expected: %s, but found: %s", tableName, gluePartition.tableName())); } Partition.Builder partitionBuilder = Partition.builder() .setCatalogName(Optional.empty()) .setDatabaseName(databaseName) .setTableName(tableName) - .setValues(gluePartition.getValues()) // No memoization benefit - .setColumns(columnsConverter.apply(sd.getColumns())) - .setParameters(parametersConverter.apply(gluePartition.getParameters())); + 
.setValues(gluePartition.values()) // No memoization benefit + .setColumns(columnsConverter.apply(sd.columns())) + .setParameters(parametersConverter.apply(gluePartition.parameters())); storageConverter.setConvertedStorage(sd, partitionBuilder.getStorageBuilder()); @@ -182,7 +182,7 @@ public Partition apply(com.amazonaws.services.glue.model.Partition gluePartition private static final class StorageConverter { private final Function, List> bucketColumns = memoizeLast(ImmutableList::copyOf); - private final Function, List> sortColumns = memoizeLast(StorageConverter::createSortingColumns); + private final Function, List> sortColumns = memoizeLast(StorageConverter::createSortingColumns); private final UnaryOperator> bucketProperty = memoizeLast(); private final Function, Map> serdeParametersConverter = parametersConverter(); private final Function, Map> partitionParametersConverter = parametersConverter(); @@ -190,36 +190,36 @@ private static final class StorageConverter public void setConvertedStorage(StorageDescriptor sd, Storage.Builder storageBuilder) { - requireNonNull(sd.getSerdeInfo(), "StorageDescriptor SerDeInfo is null"); - SerDeInfo serdeInfo = sd.getSerdeInfo(); + requireNonNull(sd.serdeInfo(), "StorageDescriptor SerDeInfo is null"); + SerDeInfo serdeInfo = sd.serdeInfo(); - storageBuilder.setLocation(nullToEmpty(sd.getLocation())) + storageBuilder.setLocation(nullToEmpty(sd.location())) .setBucketProperty(createBucketProperty(sd)) - .setSkewed(sd.getSkewedInfo() != null && !isNullOrEmpty(sd.getSkewedInfo().getSkewedColumnNames())) - .setSerdeParameters(serdeParametersConverter.apply(serdeInfo.getParameters())) - .setParameters(partitionParametersConverter.apply(sd.getParameters())) + .setSkewed(sd.skewedInfo() != null && !isNullOrEmpty(sd.skewedInfo().skewedColumnNames())) + .setSerdeParameters(serdeParametersConverter.apply(serdeInfo.parameters())) + .setParameters(partitionParametersConverter.apply(sd.parameters())) 
.setStorageFormat(storageFormatConverter.createStorageFormat(serdeInfo, sd)); } private Optional createBucketProperty(StorageDescriptor sd) { - if (sd.getNumberOfBuckets() > 0) { - if (isNullOrEmpty(sd.getBucketColumns())) { + if (sd.numberOfBuckets() > 0) { + if (isNullOrEmpty(sd.bucketColumns())) { throw new PrestoException(HIVE_INVALID_METADATA, "Table/partition metadata has 'numBuckets' set, but 'bucketCols' is not set"); } - List bucketColumns = this.bucketColumns.apply(sd.getBucketColumns()); - List sortedBy = this.sortColumns.apply(sd.getSortColumns()); - return bucketProperty.apply(Optional.of(new HiveBucketProperty(bucketColumns, sd.getNumberOfBuckets(), sortedBy, HIVE_COMPATIBLE, Optional.empty()))); + List bucketColumns = this.bucketColumns.apply(sd.bucketColumns()); + List sortedBy = this.sortColumns.apply(sd.sortColumns()); + return bucketProperty.apply(Optional.of(new HiveBucketProperty(bucketColumns, sd.numberOfBuckets(), sortedBy, HIVE_COMPATIBLE, Optional.empty()))); } return Optional.empty(); } - private static List createSortingColumns(List sortColumns) + private static List createSortingColumns(List sortColumns) { if (isNullOrEmpty(sortColumns)) { return ImmutableList.of(); } - return mappedCopy(sortColumns, column -> new SortingColumn(column.getColumn(), Order.fromMetastoreApiOrder(column.getSortOrder(), "unknown"))); + return mappedCopy(sortColumns, column -> new SortingColumn(column.column(), Order.fromMetastoreApiOrder(column.sortOrder(), "unknown"))); } } @@ -234,9 +234,9 @@ private static final class StorageFormatConverter public StorageFormat createStorageFormat(SerDeInfo serdeInfo, StorageDescriptor storageDescriptor) { - String serializationLib = this.serializationLib.apply(serdeInfo.getSerializationLibrary()); - String inputFormat = this.inputFormat.apply(storageDescriptor.getInputFormat()); - String outputFormat = this.outputFormat.apply(storageDescriptor.getOutputFormat()); + String serializationLib = 
this.serializationLib.apply(serdeInfo.serializationLibrary()); + String inputFormat = this.inputFormat.apply(storageDescriptor.inputFormat()); + String outputFormat = this.outputFormat.apply(storageDescriptor.outputFormat()); if (serializationLib == null && inputFormat == null && outputFormat == null) { return ALL_NULLS; } diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/BridgingHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/BridgingHiveMetastore.java index 61c76f9e4cace..e3787b7ebe91a 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/BridgingHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/BridgingHiveMetastore.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.metastore.thrift; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.type.Type; import com.facebook.presto.hive.HiveTableHandle; @@ -44,16 +45,18 @@ import com.facebook.presto.spi.statistics.ColumnStatisticType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; +import com.google.common.collect.Maps; +import jakarta.inject.Inject; +import org.apache.hadoop.hive.common.StatsSetupConst; +import org.apache.hadoop.hive.metastore.api.EnvironmentContext; import org.apache.hadoop.hive.metastore.api.FieldSchema; -import javax.inject.Inject; - import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import static com.facebook.presto.hive.metastore.MetastoreUtil.getPartitionNamesWithEmptyVersion; @@ -277,6 +280,21 @@ private MetastoreOperationResult alterTable(MetastoreContext metastoreContext, S return 
delegate.alterTable(metastoreContext, databaseName, tableName, table); } + private MetastoreOperationResult alterTableWithEnvironmentContext(MetastoreContext metastoreContext, String databaseName, String tableName, org.apache.hadoop.hive.metastore.api.Table table, EnvironmentContext environmentContext) + { + return delegate.alterTableWithEnvironmentContext(metastoreContext, databaseName, tableName, table, environmentContext); + } + + @Override + public MetastoreOperationResult persistTable(MetastoreContext metastoreContext, String databaseName, String tableName, Table newTable, PrincipalPrivileges principalPrivileges, Supplier update, Map additionalParameters) + { + checkArgument(!newTable.getTableType().equals(TEMPORARY_TABLE), "temporary tables must never be stored in the metastore"); + Map env = Maps.newHashMapWithExpectedSize(additionalParameters.size() + 1); + env.putAll(additionalParameters); + env.put(StatsSetupConst.DO_NOT_UPDATE_STATS, StatsSetupConst.TRUE); + return alterTableWithEnvironmentContext(metastoreContext, databaseName, tableName, toMetastoreApiTable(newTable, principalPrivileges, metastoreContext.getColumnConverter()), new EnvironmentContext(env)); + } + @Override public Optional getPartition(MetastoreContext metastoreContext, String databaseName, String tableName, List partitionValues) { diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastore.java index cf23a1e85aadf..f66484a7bdcba 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastore.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.metastore.thrift; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.Domain; import 
com.facebook.presto.common.type.Type; import com.facebook.presto.hive.HiveTableHandle; @@ -34,8 +35,8 @@ import com.facebook.presto.spi.security.RoleGrant; import com.facebook.presto.spi.statistics.ColumnStatisticType; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; import org.apache.hadoop.hive.metastore.api.Database; +import org.apache.hadoop.hive.metastore.api.EnvironmentContext; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.metastore.api.Partition; import org.apache.hadoop.hive.metastore.api.Table; @@ -62,6 +63,8 @@ public interface HiveMetastore MetastoreOperationResult alterTable(MetastoreContext metastoreContext, String databaseName, String tableName, Table table); + MetastoreOperationResult alterTableWithEnvironmentContext(MetastoreContext metastoreContext, String databaseName, String tableName, Table table, EnvironmentContext environmentContext); + default List getDatabases(MetastoreContext metastoreContext, String pattern) { return getAllDatabases(metastoreContext); diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreApiStats.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreApiStats.java index 88730f0ba0d04..c7e7ede3d7510 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreApiStats.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreApiStats.java @@ -15,14 +15,13 @@ import com.facebook.airlift.stats.CounterStat; import com.facebook.airlift.stats.TimeStat; +import com.google.errorprone.annotations.ThreadSafe; import org.apache.hadoop.hive.metastore.api.MetaException; import org.apache.thrift.TBase; import org.apache.thrift.TException; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.concurrent.ThreadSafe; - import java.util.concurrent.Callable; 
import static java.util.concurrent.TimeUnit.MILLISECONDS; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreClient.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreClient.java index 4880eedcdd6ee..caf27aeed8619 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreClient.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreClient.java @@ -16,6 +16,7 @@ import org.apache.hadoop.hive.metastore.api.CheckLockRequest; import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; +import org.apache.hadoop.hive.metastore.api.EnvironmentContext; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.metastore.api.HiveObjectPrivilege; import org.apache.hadoop.hive.metastore.api.HiveObjectRef; @@ -86,6 +87,9 @@ void dropTable(String databaseName, String name, boolean deleteData) void alterTable(String databaseName, String tableName, Table newTable) throws TException; + void alterTableWithEnvironmentContext(String databaseName, String tableName, Table newTable, EnvironmentContext context) + throws TException; + Table getTable(String databaseName, String tableName) throws TException; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreClientFactory.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreClientFactory.java index 19787c95fbad1..402aabf814ecb 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreClientFactory.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/HiveMetastoreClientFactory.java @@ -14,15 +14,15 @@ package com.facebook.presto.hive.metastore.thrift; import 
com.facebook.airlift.security.pem.PemReader; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.HiveCommonClientConfig; import com.facebook.presto.hive.MetastoreClientConfig; import com.facebook.presto.hive.authentication.HiveMetastoreAuthentication; import com.facebook.presto.spi.PrestoException; import com.google.common.net.HostAndPort; -import io.airlift.units.Duration; +import jakarta.inject.Inject; import org.apache.thrift.transport.TTransportException; -import javax.inject.Inject; import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; @@ -77,10 +77,10 @@ public HiveMetastoreClientFactory( public HiveMetastoreClientFactory(MetastoreClientConfig metastoreClientConfig, ThriftHiveMetastoreConfig thriftHiveMetastoreConfig, HiveMetastoreAuthentication metastoreAuthentication, HiveCommonClientConfig hiveCommonClientConfig) { this(buildSslContext(thriftHiveMetastoreConfig.isTlsEnabled(), - Optional.ofNullable(thriftHiveMetastoreConfig.getKeystorePath()), - Optional.ofNullable(thriftHiveMetastoreConfig.getKeystorePassword()), - Optional.ofNullable(thriftHiveMetastoreConfig.getTruststorePath()), - Optional.ofNullable(thriftHiveMetastoreConfig.getTrustStorePassword())), + Optional.ofNullable(thriftHiveMetastoreConfig.getKeystorePath()), + Optional.ofNullable(thriftHiveMetastoreConfig.getKeystorePassword()), + Optional.ofNullable(thriftHiveMetastoreConfig.getTruststorePath()), + Optional.ofNullable(thriftHiveMetastoreConfig.getTrustStorePassword())), Optional.ofNullable(metastoreClientConfig.getMetastoreSocksProxy()), metastoreClientConfig.getMetastoreTimeout(), metastoreAuthentication, hiveCommonClientConfig.getCatalogName()); } @@ -93,6 +93,7 @@ public HiveMetastoreClient create(HostAndPort address, Optional token) /** * Reads the truststore and keystore and returns the SSLContext + * * @param tlsEnabled * @param keystorePath * @param keystorePassword @@ -101,10 +102,10 @@ public 
HiveMetastoreClient create(HostAndPort address, Optional token) * @return SSLContext */ private static Optional buildSslContext(boolean tlsEnabled, - Optional keystorePath, - Optional keystorePassword, - Optional truststorePath, - Optional trustStorePassword) + Optional keystorePath, + Optional keystorePassword, + Optional truststorePath, + Optional trustStorePassword) { if (!tlsEnabled || (!keystorePath.isPresent() && !truststorePath.isPresent())) { return Optional.empty(); @@ -163,6 +164,7 @@ private static Optional buildSslContext(boolean tlsEnabled, /** * Reads the truststore certificate and returns it + * * @param trustStorePath * @param trustStorePassword * @throws IOException @@ -195,10 +197,12 @@ private static KeyStore getTrustStore(File trustStorePath, Optional trus /** * Validate keystore certificate + * * @param keyStore * @throws GeneralSecurityException */ - private static void validateKeyStoreCertificates(KeyStore keyStore) throws GeneralSecurityException + private static void validateKeyStoreCertificates(KeyStore keyStore) + throws GeneralSecurityException { for (String alias : list(keyStore.aliases())) { if (!keyStore.isKeyEntry(alias)) { diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/StaticHiveCluster.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/StaticHiveCluster.java index 86a037f9644f2..139b9d6b771a1 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/StaticHiveCluster.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/StaticHiveCluster.java @@ -14,10 +14,9 @@ package com.facebook.presto.hive.metastore.thrift; import com.google.common.net.HostAndPort; +import jakarta.inject.Inject; import org.apache.thrift.TException; -import javax.inject.Inject; - import java.net.URI; import java.util.ArrayList; import java.util.Collections; diff --git 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/StaticMetastoreConfig.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/StaticMetastoreConfig.java index f1f2f6d5d4394..94daf6fba5498 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/StaticMetastoreConfig.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/StaticMetastoreConfig.java @@ -17,8 +17,7 @@ import com.facebook.airlift.configuration.ConfigDescription; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.net.URI; import java.util.List; diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastore.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastore.java index 24056038648c9..c3bb696979e17 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastore.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastore.java @@ -56,12 +56,15 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.metastore.TableType; import org.apache.hadoop.hive.metastore.api.AlreadyExistsException; import org.apache.hadoop.hive.metastore.api.CheckLockRequest; import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; +import org.apache.hadoop.hive.metastore.api.EnvironmentContext; import org.apache.hadoop.hive.metastore.api.FieldSchema; import 
org.apache.hadoop.hive.metastore.api.HiveObjectPrivilege; import org.apache.hadoop.hive.metastore.api.HiveObjectRef; @@ -93,9 +96,6 @@ import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.io.Closeable; import java.io.IOException; import java.net.InetAddress; @@ -1161,6 +1161,35 @@ public MetastoreOperationResult alterTable(MetastoreContext metastoreContext, St } } + @Override + public MetastoreOperationResult alterTableWithEnvironmentContext(MetastoreContext metastoreContext, String databaseName, String tableName, Table table, EnvironmentContext environmentContext) + { + try { + retry() + .stopOn(InvalidOperationException.class, MetaException.class) + .stopOnIllegalExceptions() + .run("alterTableWithEnvironmentContext", stats.getAlterTableWithEnvironmentContext().wrap(() -> + getMetastoreClientThenCall(metastoreContext, client -> { + Optional
source = getTable(metastoreContext, databaseName, tableName); + if (!source.isPresent()) { + throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); + } + client.alterTableWithEnvironmentContext(databaseName, tableName, table, environmentContext); + return null; + }))); + return EMPTY_RESULT; + } + catch (NoSuchObjectException e) { + throw new TableNotFoundException(new SchemaTableName(databaseName, tableName)); + } + catch (TException e) { + throw new PrestoException(HIVE_METASTORE_ERROR, e); + } + catch (Exception e) { + throw propagate(e); + } + } + @Override public Optional> getPartitionNames(MetastoreContext metastoreContext, String databaseName, String tableName) { diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastoreClient.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastoreClient.java index f34955d09432e..93ed690cc87a6 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastoreClient.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastoreClient.java @@ -23,6 +23,7 @@ import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; import org.apache.hadoop.hive.metastore.api.DropConstraintRequest; +import org.apache.hadoop.hive.metastore.api.EnvironmentContext; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.metastore.api.GetRoleGrantsForPrincipalRequest; import org.apache.hadoop.hive.metastore.api.GetRoleGrantsForPrincipalResponse; @@ -199,6 +200,13 @@ public void alterTable(String databaseName, String tableName, Table newTable) client.alter_table(constructSchemaName(catalogName, databaseName), tableName, newTable); } + @Override + public void alterTableWithEnvironmentContext(String databaseName, String tableName, Table newTable, 
EnvironmentContext context) + throws TException + { + client.alter_table_with_environment_context(constructSchemaName(catalogName, databaseName), tableName, newTable, context); + } + @Override public Table getTable(String databaseName, String tableName) throws TException diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastoreStats.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastoreStats.java index 543c1a2ba27da..aadc7f225b097 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastoreStats.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftHiveMetastoreStats.java @@ -40,6 +40,7 @@ public class ThriftHiveMetastoreStats private final HiveMetastoreApiStats createTableWithConstraints = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats dropTable = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats alterTable = new HiveMetastoreApiStats(); + private final HiveMetastoreApiStats alterTableWithEnvironmentContext = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats addPartitions = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats dropPartition = new HiveMetastoreApiStats(); private final HiveMetastoreApiStats alterPartition = new HiveMetastoreApiStats(); @@ -216,6 +217,13 @@ public HiveMetastoreApiStats getAlterTable() return alterTable; } + @Managed + @Nested + public HiveMetastoreApiStats getAlterTableWithEnvironmentContext() + { + return alterTableWithEnvironmentContext; + } + @Managed @Nested public HiveMetastoreApiStats getAddPartitions() diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftMetastoreModule.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftMetastoreModule.java index 6893f1f1dbd28..094dd9d448c73 100644 --- 
a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftMetastoreModule.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftMetastoreModule.java @@ -19,6 +19,7 @@ import com.facebook.presto.hive.MetastoreClientConfig; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore; +import com.facebook.presto.hive.metastore.MetastoreCacheSpecProvider; import com.facebook.presto.hive.metastore.RecordingHiveMetastore; import com.facebook.presto.spi.ConnectorId; import com.google.inject.Binder; @@ -69,6 +70,7 @@ protected void setup(Binder binder) .in(Scopes.SINGLETON); } + binder.bind(MetastoreCacheSpecProvider.class).in(Scopes.SINGLETON); binder.bind(ExtendedHiveMetastore.class).to(InMemoryCachingHiveMetastore.class).in(Scopes.SINGLETON); newExporter(binder).export(HiveMetastore.class) .as(generatedNameOf(ThriftHiveMetastore.class, connectorId)); diff --git a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftMetastoreUtil.java b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftMetastoreUtil.java index 74858baba829c..4167f8ebeabe4 100644 --- a/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftMetastoreUtil.java +++ b/presto-hive-metastore/src/main/java/com/facebook/presto/hive/metastore/thrift/ThriftMetastoreUtil.java @@ -43,6 +43,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import com.google.common.primitives.Shorts; +import jakarta.annotation.Nullable; import org.apache.hadoop.hive.metastore.TableType; import org.apache.hadoop.hive.metastore.api.BinaryColumnStatsData; import org.apache.hadoop.hive.metastore.api.BooleanColumnStatsData; @@ -64,8 +65,6 @@ import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; 
-import javax.annotation.Nullable; - import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; diff --git a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/TestMetastoreClientConfig.java b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/TestMetastoreClientConfig.java index bc0d719e746cd..6f2630b47e9d2 100644 --- a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/TestMetastoreClientConfig.java +++ b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/TestMetastoreClientConfig.java @@ -14,17 +14,21 @@ package com.facebook.presto.hive.metastore; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.MetastoreClientConfig; import com.facebook.presto.hive.MetastoreClientConfig.HiveMetastoreAuthenticationType; import com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheScope; import com.google.common.collect.ImmutableMap; import com.google.common.net.HostAndPort; -import io.airlift.units.Duration; +import com.google.inject.ConfigurationException; import org.testng.annotations.Test; import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.expectThrows; + public class TestMetastoreClientConfig { @Test @@ -35,8 +39,12 @@ public void testDefaults() .setMetastoreTimeout(new Duration(10, TimeUnit.SECONDS)) .setVerifyChecksum(true) .setRequireHadoopNative(true) - .setMetastoreCacheTtl(new Duration(0, TimeUnit.SECONDS)) - .setMetastoreRefreshInterval(new Duration(0, TimeUnit.SECONDS)) + .setEnabledCaches(null) + .setDisabledCaches(null) + .setDefaultMetastoreCacheTtl(new Duration(0, TimeUnit.SECONDS)) + .setDefaultMetastoreCacheRefreshInterval(new Duration(0, TimeUnit.SECONDS)) + .setMetastoreCacheTtlByType(null) + .setMetastoreCacheRefreshIntervalByType(null) 
.setMetastoreCacheMaximumSize(10000) .setPerTransactionMetastoreCacheMaximumSize(1000) .setMaxMetastoreRefreshThreads(100) @@ -61,8 +69,12 @@ public void testExplicitPropertyMappings() .put("hive.metastore-timeout", "20s") .put("hive.dfs.verify-checksum", "false") .put("hive.dfs.require-hadoop-native", "false") - .put("hive.metastore-cache-ttl", "2h") - .put("hive.metastore-refresh-interval", "30m") + .put("hive.metastore.cache.enabled-caches", "TABLE,TABLE_NAMES") + .put("hive.metastore.cache.disabled-caches", "TABLE,TABLE_NAMES") + .put("hive.metastore.cache.ttl.default", "2h") + .put("hive.metastore.cache.refresh-interval.default", "30m") + .put("hive.metastore.cache.ttl-by-type", "TABLE:10m") + .put("hive.metastore.cache.refresh-interval-by-type", "TABLE:5m") .put("hive.metastore-cache-maximum-size", "5000") .put("hive.per-transaction-metastore-cache-maximum-size", "500") .put("hive.metastore-refresh-max-threads", "2500") @@ -84,8 +96,12 @@ public void testExplicitPropertyMappings() .setMetastoreTimeout(new Duration(20, TimeUnit.SECONDS)) .setVerifyChecksum(false) .setRequireHadoopNative(false) - .setMetastoreCacheTtl(new Duration(2, TimeUnit.HOURS)) - .setMetastoreRefreshInterval(new Duration(30, TimeUnit.MINUTES)) + .setEnabledCaches("TABLE,TABLE_NAMES") + .setDisabledCaches("TABLE,TABLE_NAMES") + .setDefaultMetastoreCacheTtl(new Duration(2, TimeUnit.HOURS)) + .setDefaultMetastoreCacheRefreshInterval(new Duration(30, TimeUnit.MINUTES)) + .setMetastoreCacheTtlByType("TABLE:10m") + .setMetastoreCacheRefreshIntervalByType("TABLE:5m") .setMetastoreCacheMaximumSize(5000) .setPerTransactionMetastoreCacheMaximumSize(500) .setMaxMetastoreRefreshThreads(2500) @@ -103,4 +119,20 @@ public void testExplicitPropertyMappings() ConfigAssertions.assertFullMapping(properties, expected); } + + @Test + public void testInvalidConfiguration() + { + MetastoreClientConfig config = new MetastoreClientConfig(); + config.setEnabledCaches("TABLE,TABLE_NAMES"); + 
config.setDisabledCaches("TABLE,TABLE_NAMES"); + + ConfigurationException exception = expectThrows( + ConfigurationException.class, + config::isBothEnabledAndDisabledConfigured); + + assertEquals(exception.getErrorMessages().iterator().next().getMessage(), + "Only one of 'hive.metastore.cache.enabled-caches' or 'hive.metastore.cache.disabled-caches' can be set. " + + "These configs are mutually exclusive."); + } } diff --git a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/TestRecordingHiveMetastore.java b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/TestRecordingHiveMetastore.java index d98968c73af62..6424716ce743c 100644 --- a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/TestRecordingHiveMetastore.java +++ b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/TestRecordingHiveMetastore.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.metastore; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.type.Type; import com.facebook.presto.hive.HiveBasicStatistics; @@ -31,7 +32,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.io.File; diff --git a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/UnimplementedHiveMetastore.java b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/UnimplementedHiveMetastore.java index 6a6926aa4561b..2e3bbe18952f5 100644 --- a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/UnimplementedHiveMetastore.java +++ b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/UnimplementedHiveMetastore.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.metastore; +import 
com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.type.Type; import com.facebook.presto.hive.HiveType; @@ -21,13 +22,13 @@ import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.RoleGrant; import com.facebook.presto.spi.statistics.ColumnStatisticType; -import io.airlift.units.Duration; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.Function; +import java.util.function.Supplier; public class UnimplementedHiveMetastore implements ExtendedHiveMetastore @@ -128,6 +129,12 @@ public MetastoreOperationResult replaceTable(MetastoreContext metastoreContext, throw new UnsupportedOperationException(); } + @Override + public MetastoreOperationResult persistTable(MetastoreContext metastoreContext, String databaseName, String tableName, Table newTable, PrincipalPrivileges principalPrivileges, Supplier update, Map additionalParameters) + { + throw new UnsupportedOperationException(); + } + @Override public MetastoreOperationResult renameTable(MetastoreContext metastoreContext, String databaseName, String tableName, String newDatabaseName, String newTableName) { diff --git a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/glue/TestGlueHiveMetastoreConfig.java b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/glue/TestGlueHiveMetastoreConfig.java index fbc22461393ef..cea988de5c96b 100644 --- a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/glue/TestGlueHiveMetastoreConfig.java +++ b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/glue/TestGlueHiveMetastoreConfig.java @@ -30,7 +30,8 @@ public void testDefaults() assertRecordedDefaults(recordDefaults(GlueHiveMetastoreConfig.class) .setGlueRegion(null) .setGlueEndpointUrl(null) - .setPinGlueClientToCurrentRegion(false) + .setGlueStsRegion(null) + 
.setGlueStsEndpointUrl(null) .setMaxGlueConnections(50) .setMaxGlueErrorRetries(10) .setDefaultWarehouseDir(null) @@ -48,7 +49,8 @@ public void testExplicitPropertyMapping() Map properties = new ImmutableMap.Builder() .put("hive.metastore.glue.region", "us-east-1") .put("hive.metastore.glue.endpoint-url", "http://foo.bar") - .put("hive.metastore.glue.pin-client-to-current-region", "true") + .put("hive.metastore.glue.sts.region", "us-east-1") + .put("hive.metastore.glue.sts.endpoint-url", "http://foo.bar") .put("hive.metastore.glue.max-connections", "10") .put("hive.metastore.glue.max-error-retries", "20") .put("hive.metastore.glue.default-warehouse-dir", "/location") @@ -63,7 +65,8 @@ public void testExplicitPropertyMapping() GlueHiveMetastoreConfig expected = new GlueHiveMetastoreConfig() .setGlueRegion("us-east-1") .setGlueEndpointUrl("http://foo.bar") - .setPinGlueClientToCurrentRegion(true) + .setGlueStsRegion("us-east-1") + .setGlueStsEndpointUrl("http://foo.bar") .setMaxGlueConnections(10) .setMaxGlueErrorRetries(20) .setDefaultWarehouseDir("/location") diff --git a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/InMemoryHiveMetastore.java b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/InMemoryHiveMetastore.java index 81bcf00059b8f..866697dd088df 100644 --- a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/InMemoryHiveMetastore.java +++ b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/InMemoryHiveMetastore.java @@ -36,16 +36,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.metastore.TableType; import org.apache.hadoop.hive.metastore.api.Database; +import org.apache.hadoop.hive.metastore.api.EnvironmentContext; 
import org.apache.hadoop.hive.metastore.api.Partition; import org.apache.hadoop.hive.metastore.api.PrincipalPrivilegeSet; import org.apache.hadoop.hive.metastore.api.PrincipalType; import org.apache.hadoop.hive.metastore.api.Table; -import javax.annotation.concurrent.GuardedBy; - import java.io.File; import java.io.IOException; import java.io.UncheckedIOException; @@ -284,6 +284,12 @@ public synchronized MetastoreOperationResult alterTable(MetastoreContext metasto return EMPTY_RESULT; } + @Override + public synchronized MetastoreOperationResult alterTableWithEnvironmentContext(MetastoreContext metastoreContext, String databaseName, String tableName, Table newTable, EnvironmentContext environmentContext) + { + return alterTable(metastoreContext, databaseName, tableName, newTable); + } + @Override public synchronized Optional> getAllTables(MetastoreContext metastoreContext, String databaseName) { diff --git a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/MockHiveMetastoreClient.java b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/MockHiveMetastoreClient.java index 4ee92ab0e2b47..ffb7a77670d65 100644 --- a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/MockHiveMetastoreClient.java +++ b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/MockHiveMetastoreClient.java @@ -29,6 +29,7 @@ import org.apache.hadoop.hive.metastore.api.CheckLockRequest; import org.apache.hadoop.hive.metastore.api.ColumnStatisticsObj; import org.apache.hadoop.hive.metastore.api.Database; +import org.apache.hadoop.hive.metastore.api.EnvironmentContext; import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.metastore.api.HiveObjectPrivilege; import org.apache.hadoop.hive.metastore.api.HiveObjectRef; @@ -365,6 +366,13 @@ public void alterTable(String databaseName, String tableName, Table newTable) throw new UnsupportedOperationException(); } + 
@Override + public void alterTableWithEnvironmentContext(String databaseName, String tableName, Table newTable, EnvironmentContext context) + throws TException + { + throw new UnsupportedOperationException(); + } + @Override public int addPartitions(List newPartitions) { diff --git a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/MockHiveMetastoreClientFactory.java b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/MockHiveMetastoreClientFactory.java index 196a72018ca72..4e060e6873bcb 100644 --- a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/MockHiveMetastoreClientFactory.java +++ b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/MockHiveMetastoreClientFactory.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.hive.metastore.thrift; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.authentication.NoHiveMetastoreAuthentication; import com.google.common.net.HostAndPort; -import io.airlift.units.Duration; import org.apache.thrift.transport.TTransportException; import java.util.ArrayList; diff --git a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/TestStaticHiveCluster.java b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/TestStaticHiveCluster.java index f8fbdc75f1f57..8b6a083d099ea 100644 --- a/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/TestStaticHiveCluster.java +++ b/presto-hive-metastore/src/test/java/com/facebook/presto/hive/metastore/thrift/TestStaticHiveCluster.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive.metastore.thrift; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; import org.apache.thrift.TException; import org.testng.annotations.Test; diff --git a/presto-hive/pom.xml b/presto-hive/pom.xml index e744a031789f2..3cf76943ab149 100644 --- a/presto-hive/pom.xml +++ 
b/presto-hive/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-hive @@ -15,26 +15,50 @@ ${project.parent.basedir} 5g + 17 + true + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + + + + + software.amazon.awssdk + utils + + + + software.amazon.awssdk + glue + + com.facebook.airlift http-client - com.facebook.drift + com.facebook.airlift.drift drift-codec - com.facebook.drift + com.facebook.airlift.drift drift-codec-utils - com.facebook.drift + com.facebook.airlift.drift drift-transport-netty @@ -100,7 +124,7 @@ com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache provided @@ -160,11 +184,6 @@ test - - com.facebook.drift - drift-api - - com.google.guava guava @@ -176,8 +195,8 @@ - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true @@ -192,8 +211,8 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -211,11 +230,6 @@ aws-java-sdk-core - - com.amazonaws - aws-java-sdk-glue - - com.amazonaws aws-java-sdk-s3 @@ -257,6 +271,11 @@ javax.inject + + jakarta.inject + jakarta.inject-api + + com.fasterxml.jackson.core jackson-core @@ -268,8 +287,8 @@ - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api @@ -311,7 +330,7 @@ - io.airlift + com.facebook.airlift units @@ -360,6 +379,12 @@ test + + com.facebook.presto + presto-main-tests + test + + com.facebook.presto presto-hive-metastore @@ -473,7 +498,7 @@ org.objenesis objenesis - 2.6 + 3.4 test @@ -482,6 +507,12 @@ + + com.facebook.presto + presto-sql-invoked-functions-plugin + ${project.version} + test + @@ -518,7 +549,13 @@ org.glassfish.jersey.core:jersey-common:jar org.eclipse.jetty:jetty-server:jar + org.apache.httpcomponents:httpclient + + com.fasterxml.jackson.core:jackson-core + software.amazon.awssdk:glue + software.amazon.awssdk:utils + @@ -549,6 +586,9 @@ **/TestParquetDistributedQueries.java **/TestHive2InsertOverwrite.java 
**/TestHive3InsertOverwrite.java + **/TestHiveSslWithKeyStore.java + **/TestHiveSslWithTrustStore.java + **/TestHiveSslWithTrustStoreKeyStore.java diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/CachingDirectoryLister.java b/presto-hive/src/main/java/com/facebook/presto/hive/CachingDirectoryLister.java index 7090493afe580..bb86d107efce2 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/CachingDirectoryLister.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/CachingDirectoryLister.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.hive.filesystem.ExtendedFileSystem; import com.facebook.presto.hive.metastore.Partition; @@ -24,14 +26,11 @@ import com.google.common.cache.Weigher; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import jakarta.inject.Inject; import org.apache.hadoop.fs.Path; import org.openjdk.jol.info.ClassLayout; import org.weakref.jmx.Managed; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -142,6 +141,12 @@ public HiveFileInfo next() }; } + public boolean isPathCached(Path path) + { + ValueHolder value = Optional.ofNullable(cache.getIfPresent(path.toString())).orElse(null); + return value != null; + } + public void invalidateDirectoryListCache(Optional directoryPath) { if (directoryPath.isPresent()) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ConcurrentLazyQueue.java b/presto-hive/src/main/java/com/facebook/presto/hive/ConcurrentLazyQueue.java index 05dd95d6d3824..9d7746b5a50d7 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ConcurrentLazyQueue.java +++ 
b/presto-hive/src/main/java/com/facebook/presto/hive/ConcurrentLazyQueue.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.Iterator; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/CreateEmptyPartitionProcedure.java b/presto-hive/src/main/java/com/facebook/presto/hive/CreateEmptyPartitionProcedure.java index 7e967e2b97f26..cf3e0f943e93b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/CreateEmptyPartitionProcedure.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/CreateEmptyPartitionProcedure.java @@ -26,9 +26,9 @@ import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slices; +import jakarta.inject.Inject; import org.apache.hadoop.hive.common.FileUtils; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; @@ -120,16 +120,16 @@ private void doCreateEmptyPartition(ConnectorSession session, String schema, Str .collect(toImmutableList()); if (metastore.getPartition(new MetastoreContext( - session.getIdentity(), - session.getQueryId(), - session.getClientInfo(), - session.getClientTags(), - session.getSource(), - getMetastoreHeaders(session), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - session.getWarningCollector(), - session.getRuntimeStats()), + session.getIdentity(), + session.getQueryId(), + session.getClientInfo(), + session.getClientTags(), + session.getSource(), + getMetastoreHeaders(session), + false, + HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, + session.getWarningCollector(), + session.getRuntimeStats()), schema, table, partitionStringValues).isPresent()) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryListCacheInvalidationProcedure.java 
b/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryListCacheInvalidationProcedure.java index 983cffbbb3b57..e6a931dcd2922 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryListCacheInvalidationProcedure.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryListCacheInvalidationProcedure.java @@ -19,8 +19,8 @@ import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryLister.java b/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryLister.java index fa9dfd2dae119..d9d0c141b094f 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryLister.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/DirectoryLister.java @@ -30,4 +30,9 @@ Iterator list( Optional partition, NamenodeStats namenodeStats, HiveDirectoryContext hiveDirectoryContext); + + default boolean isPathCached(Path path) + { + return false; + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ForCachingDirectoryLister.java b/presto-hive/src/main/java/com/facebook/presto/hive/ForCachingDirectoryLister.java index 94672ba011c2e..014968340b1af 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ForCachingDirectoryLister.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/ForCachingDirectoryLister.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ForFileRename.java b/presto-hive/src/main/java/com/facebook/presto/hive/ForFileRename.java index 9c59f672fd861..988b7dec0e247 100644 --- 
a/presto-hive/src/main/java/com/facebook/presto/hive/ForFileRename.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/ForFileRename.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ForHdfs.java b/presto-hive/src/main/java/com/facebook/presto/hive/ForHdfs.java index d0bfa7424abb2..c7d566138f51b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ForHdfs.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/ForHdfs.java @@ -17,7 +17,7 @@ */ package com.facebook.presto.hive; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ForHiveClient.java b/presto-hive/src/main/java/com/facebook/presto/hive/ForHiveClient.java index 6297c8ef5ff17..21329783da407 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ForHiveClient.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/ForHiveClient.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ForHiveMetastore.java b/presto-hive/src/main/java/com/facebook/presto/hive/ForHiveMetastore.java index b8a1b33d1a8df..ad698b62a4618 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ForHiveMetastore.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/ForHiveMetastore.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git 
a/presto-hive/src/main/java/com/facebook/presto/hive/ForZeroRowFileCreator.java b/presto-hive/src/main/java/com/facebook/presto/hive/ForZeroRowFileCreator.java index 8f27319c3c13f..5222f0675044c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ForZeroRowFileCreator.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/ForZeroRowFileCreator.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursorProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursorProvider.java index 7edaba22e16b7..48613955a6ae4 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursorProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/GenericHiveRecordCursorProvider.java @@ -19,14 +19,13 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.RecordCursor; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.Writable; import org.apache.hadoop.mapred.RecordReader; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.io.IOException; import java.util.List; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationInitializer.java b/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationInitializer.java index 991b364614cd5..f956978787d39 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationInitializer.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HdfsConfigurationInitializer.java @@ -13,19 +13,20 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.Duration; 
+import com.facebook.presto.hadoop.SocksSocketFactory; import com.facebook.presto.hive.gcs.GcsConfigurationInitializer; import com.facebook.presto.hive.s3.S3ConfigurationUpdater; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.net.HostAndPort; -import io.airlift.units.Duration; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hdfs.client.HdfsClientConfigKeys; import org.apache.hadoop.mapreduce.lib.input.LineRecordReader; import org.apache.hadoop.net.DNSToSwitchMapping; -import org.apache.hadoop.net.SocksSocketFactory; -import javax.inject.Inject; import javax.net.SocketFactory; import java.util.List; @@ -41,9 +42,8 @@ import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.IPC_CLIENT_CONNECT_MAX_RETRIES_KEY; import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.IPC_CLIENT_CONNECT_TIMEOUT_KEY; import static org.apache.hadoop.fs.CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY; -import static org.apache.hadoop.hdfs.DFSConfigKeys.DFS_CLIENT_READ_SHORTCIRCUIT_KEY; -import static org.apache.hadoop.hdfs.DFSConfigKeys.DFS_CLIENT_SOCKET_TIMEOUT_KEY; -import static org.apache.hadoop.hdfs.DFSConfigKeys.DFS_DOMAIN_SOCKET_PATH_KEY; +import static org.apache.hadoop.hdfs.client.HdfsClientConfigKeys.DFS_CLIENT_SOCKET_TIMEOUT_KEY; +import static org.apache.hadoop.hdfs.client.HdfsClientConfigKeys.DFS_DOMAIN_SOCKET_PATH_KEY; public class HdfsConfigurationInitializer { @@ -119,7 +119,7 @@ public void updateConfiguration(Configuration config) // only enable short circuit reads if domain socket path is properly configured if (!config.get(DFS_DOMAIN_SOCKET_PATH_KEY, "").trim().isEmpty()) { - config.setBooleanIfUnset(DFS_CLIENT_READ_SHORTCIRCUIT_KEY, true); + config.setBooleanIfUnset(HdfsClientConfigKeys.Read.ShortCircuit.KEY, true); } config.setInt(DFS_CLIENT_SOCKET_TIMEOUT_KEY, 
toIntExact(dfsTimeout.toMillis())); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveAnalyzeProperties.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveAnalyzeProperties.java index 6b759cac964e2..fbd32604a6070 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveAnalyzeProperties.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveAnalyzeProperties.java @@ -17,8 +17,7 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java index 6c569bb7ea255..7f56998606165 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientConfig.java @@ -17,26 +17,28 @@ import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.DefunctConfig; import com.facebook.airlift.configuration.LegacyConfig; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MaxDataSize; +import com.facebook.airlift.units.MinDataSize; +import com.facebook.airlift.units.MinDuration; import com.facebook.drift.transport.netty.codec.Protocol; import com.facebook.presto.hive.s3.S3FileSystemType; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; -import io.airlift.units.MaxDataSize; -import io.airlift.units.MinDataSize; -import io.airlift.units.MinDuration; +import jakarta.validation.constraints.DecimalMax; +import jakarta.validation.constraints.DecimalMin; +import 
jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import org.joda.time.DateTimeZone; -import javax.validation.constraints.DecimalMax; -import javax.validation.constraints.DecimalMin; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; - import java.util.List; import java.util.TimeZone; import java.util.concurrent.TimeUnit; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.hive.BucketFunctionType.HIVE_COMPATIBLE; import static com.facebook.presto.hive.BucketFunctionType.PRESTO_NATIVE; import static com.facebook.presto.hive.HiveClientConfig.InsertExistingPartitionsBehavior.APPEND; @@ -45,9 +47,6 @@ import static com.facebook.presto.hive.HiveSessionProperties.INSERT_EXISTING_PARTITIONS_BEHAVIOR; import static com.facebook.presto.hive.HiveStorageFormat.ORC; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.String.format; import static java.util.Locale.ENGLISH; @@ -108,7 +107,6 @@ public class HiveClientConfig private DataSize textMaxLineLength = new DataSize(100, MEGABYTE); private boolean assumeCanonicalPartitionKeys; - private boolean useOrcColumnNames; private double orcDefaultBloomFilterFpp = 0.05; private boolean rcfileOptimizedWriterEnabled = true; private boolean rcfileWriterValidate; @@ -169,8 +167,6 @@ public class HiveClientConfig private DataSize pageFileStripeMaxSize = new DataSize(24, MEGABYTE); private boolean parquetDereferencePushdownEnabled; - private int maxMetadataUpdaterThreads = 100; - private boolean isPartialAggregationPushdownEnabled; private boolean 
isPartialAggregationPushdownForVariableLengthDatatypesEnabled; @@ -202,7 +198,7 @@ public class HiveClientConfig private Protocol thriftProtocol = Protocol.BINARY; private DataSize thriftBufferSize = new DataSize(128, BYTE); - private boolean copyOnFirstWriteConfigurationEnabled = true; + private boolean copyOnFirstWriteConfigurationEnabled; private boolean partitionFilteringFromMetastoreEnabled = true; @@ -222,6 +218,9 @@ public class HiveClientConfig private int parquetQuickStatsMaxConcurrentCalls = 500; private int quickStatsMaxConcurrentCalls = 100; private boolean legacyTimestampBucketing; + private boolean optimizeParsingOfPartitionValues; + private int optimizeParsingOfPartitionValuesThreshold = 500; + private boolean symlinkOptimizedReaderEnabled = true; @Min(0) public int getMaxInitialSplits() @@ -680,6 +679,7 @@ public HiveClientConfig setMaxPartitionsPerWriter(int maxPartitionsPerWriter) this.maxPartitionsPerWriter = maxPartitionsPerWriter; return this; } + public int getWriteValidationThreads() { return writeValidationThreads; @@ -719,19 +719,6 @@ public HiveClientConfig setS3FileSystemType(S3FileSystemType s3FileSystemType) return this; } - public boolean isUseOrcColumnNames() - { - return useOrcColumnNames; - } - - @Config("hive.orc.use-column-names") - @ConfigDescription("Access ORC columns using names from the file first, and fallback to Hive schema column names if not found to ensure backward compatibility with old data") - public HiveClientConfig setUseOrcColumnNames(boolean useOrcColumnNames) - { - this.useOrcColumnNames = useOrcColumnNames; - return this; - } - public double getOrcDefaultBloomFilterFpp() { return orcDefaultBloomFilterFpp; @@ -1369,19 +1356,6 @@ public boolean isParquetDereferencePushdownEnabled() return this.parquetDereferencePushdownEnabled; } - @Min(1) - public int getMaxMetadataUpdaterThreads() - { - return maxMetadataUpdaterThreads; - } - - @Config("hive.max-metadata-updater-threads") - public HiveClientConfig 
setMaxMetadataUpdaterThreads(int maxMetadataUpdaterThreads) - { - this.maxMetadataUpdaterThreads = maxMetadataUpdaterThreads; - return this; - } - @Config("hive.partial_aggregation_pushdown_enabled") @ConfigDescription("enable partial aggregation pushdown") public HiveClientConfig setPartialAggregationPushdownEnabled(boolean partialAggregationPushdownEnabled) @@ -1831,4 +1805,44 @@ public HiveClientConfig setLegacyTimestampBucketing(boolean legacyTimestampBucke this.legacyTimestampBucketing = legacyTimestampBucketing; return this; } + + @Config("hive.optimize-parsing-of-partition-values-enabled") + @ConfigDescription("Enables optimization of parsing partition values when number of candidate partitions is large") + public HiveClientConfig setOptimizeParsingOfPartitionValues(boolean optimizeParsingOfPartitionValues) + { + this.optimizeParsingOfPartitionValues = optimizeParsingOfPartitionValues; + return this; + } + + public boolean isOptimizeParsingOfPartitionValues() + { + return optimizeParsingOfPartitionValues; + } + + @Config("hive.optimize-parsing-of-partition-values-threshold") + @ConfigDescription("Enables optimization of parsing partition values when number of candidate partitions exceed the threshold set here") + public HiveClientConfig setOptimizeParsingOfPartitionValuesThreshold(int optimizeParsingOfPartitionValuesThreshold) + { + this.optimizeParsingOfPartitionValuesThreshold = optimizeParsingOfPartitionValuesThreshold; + return this; + } + + @Min(1) + public int getOptimizeParsingOfPartitionValuesThreshold() + { + return optimizeParsingOfPartitionValuesThreshold; + } + + public boolean isSymlinkOptimizedReaderEnabled() + { + return symlinkOptimizedReaderEnabled; + } + + @Config("hive.experimental.symlink.optimized-reader.enabled") + @ConfigDescription("Experimental: Enable optimized SymlinkTextInputFormat reader") + public HiveClientConfig setSymlinkOptimizedReaderEnabled(boolean symlinkOptimizedReaderEnabled) + { + this.symlinkOptimizedReaderEnabled = 
symlinkOptimizedReaderEnabled; + return this; + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java index ee448345cdfb1..7eb58d9bbd9fe 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveClientModule.java @@ -69,13 +69,11 @@ import com.facebook.presto.parquet.cache.ParquetCacheConfig; import com.facebook.presto.parquet.cache.ParquetFileMetadata; import com.facebook.presto.parquet.cache.ParquetMetadataSource; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdaterProvider; import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; -import com.facebook.presto.spi.connector.ConnectorTypeSerdeProvider; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableList; @@ -86,10 +84,9 @@ import com.google.inject.Scopes; import com.google.inject.TypeLiteral; import com.google.inject.multibindings.Multibinder; +import jakarta.inject.Singleton; import org.weakref.jmx.MBeanExporter; -import javax.inject.Singleton; - import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.function.Supplier; @@ -98,7 +95,6 @@ import static com.facebook.airlift.configuration.ConfigBinder.configBinder; import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; import static com.facebook.airlift.json.smile.SmileCodecBinder.smileCodecBinder; -import static com.facebook.drift.codec.guice.ThriftCodecBinder.thriftCodecBinder; import static 
com.facebook.presto.orc.StripeMetadataSource.CacheableRowGroupIndices; import static com.facebook.presto.orc.StripeMetadataSource.CacheableSlice; import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; @@ -152,9 +148,6 @@ public void configure(Binder binder) binder.bind(HiveWriterStats.class).in(Scopes.SINGLETON); newExporter(binder).export(HiveWriterStats.class).as(generatedNameOf(HiveWriterStats.class, connectorId)); - binder.bind(HiveFileRenamer.class).in(Scopes.SINGLETON); - newExporter(binder).export(HiveFileRenamer.class).as(generatedNameOf(HiveFileRenamer.class, connectorId)); - newSetBinder(binder, EventClient.class).addBinding().to(HiveEventClient.class).in(Scopes.SINGLETON); binder.bind(HivePartitionManager.class).in(Scopes.SINGLETON); newExporter(binder).export(HivePartitionManager.class).withGeneratedName(); @@ -177,11 +170,8 @@ public void configure(Binder binder) binder.bind(ConnectorPageSinkProvider.class).to(HivePageSinkProvider.class).in(Scopes.SINGLETON); binder.bind(ConnectorNodePartitioningProvider.class).to(HiveNodePartitioningProvider.class).in(Scopes.SINGLETON); binder.bind(ConnectorPlanOptimizerProvider.class).to(HivePlanOptimizerProvider.class).in(Scopes.SINGLETON); - binder.bind(ConnectorMetadataUpdaterProvider.class).to(HiveMetadataUpdaterProvider.class).in(Scopes.SINGLETON); - binder.bind(ConnectorTypeSerdeProvider.class).to(HiveConnectorTypeSerdeProvider.class).in(Scopes.SINGLETON); binder.install(new ThriftCodecModule()); binder.install(new DefaultThriftCodecsModule()); - thriftCodecBinder(binder).bindThriftCodec(HiveMetadataUpdateHandle.class); jsonCodecBinder(binder).bindJsonCodec(PartitionUpdate.class); smileCodecBinder(binder).bindSmileCodec(PartitionUpdate.class); @@ -261,14 +251,6 @@ public ExecutorService createCachingHiveMetastoreExecutor(HiveConnectorId hiveCl daemonThreadsNamed("hive-metastore-" + hiveClientId + "-%s")); } - @ForUpdatingHiveMetadata - @Singleton - @Provides - public ExecutorService 
createUpdatingHiveMetadataExecutor(HiveConnectorId hiveClientId) - { - return newCachedThreadPool(daemonThreadsNamed("hive-metadata-updater-" + hiveClientId + "-%s")); - } - @ForFileRename @Singleton @Provides diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionPolicy.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionPolicy.java index da9e8ec98c392..0594f71d3e69c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionPolicy.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveCoercionPolicy.java @@ -16,13 +16,12 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.common.type.VarcharType; +import jakarta.inject.Inject; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; -import javax.inject.Inject; - import java.util.List; import static com.facebook.presto.hive.HiveType.HIVE_BYTE; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveColumnHandle.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveColumnHandle.java index 1e0bcb904cf07..ae0619790c6a1 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveColumnHandle.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveColumnHandle.java @@ -145,6 +145,15 @@ public ColumnMetadata getColumnMetadata(TypeManager typeManager) .build(); } + public ColumnMetadata getColumnMetadata(TypeManager typeManager, String name) + { + return ColumnMetadata.builder() + .setName(name) + .setType(typeManager.getType(typeName)) + .setHidden(isHidden()) + .build(); + } + @JsonProperty public Optional getPartialAggregation() { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveCompressionCodec.java 
b/presto-hive/src/main/java/com/facebook/presto/hive/HiveCompressionCodec.java index 020007c885382..e352603559517 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveCompressionCodec.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveCompressionCodec.java @@ -16,6 +16,7 @@ import com.facebook.presto.orc.metadata.CompressionKind; import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.io.compress.GzipCodec; +import org.apache.hadoop.io.compress.Lz4Codec; import org.apache.hadoop.io.compress.SnappyCodec; import org.apache.parquet.hadoop.metadata.CompressionCodecName; @@ -25,6 +26,7 @@ import static com.facebook.presto.hive.HiveStorageFormat.DWRF; import static com.facebook.presto.hive.HiveStorageFormat.ORC; import static com.facebook.presto.hive.HiveStorageFormat.PAGEFILE; +import static com.facebook.presto.hive.HiveStorageFormat.PARQUET; import static java.util.Objects.requireNonNull; public enum HiveCompressionCodec @@ -32,12 +34,12 @@ public enum HiveCompressionCodec NONE(null, CompressionKind.NONE, CompressionCodecName.UNCOMPRESSED, f -> true), SNAPPY(SnappyCodec.class, CompressionKind.SNAPPY, CompressionCodecName.SNAPPY, f -> true), GZIP(GzipCodec.class, CompressionKind.ZLIB, CompressionCodecName.GZIP, f -> true), - LZ4(null, CompressionKind.NONE, null, f -> f == PAGEFILE), - ZSTD(null, CompressionKind.ZSTD, null, f -> f == ORC || f == DWRF || f == PAGEFILE); + LZ4(Lz4Codec.class, CompressionKind.LZ4, CompressionCodecName.UNCOMPRESSED, f -> f == PAGEFILE || f == ORC), + ZSTD(null, CompressionKind.ZSTD, CompressionCodecName.ZSTD, f -> f == ORC || f == DWRF || f == PAGEFILE || f == PARQUET); private final Optional> codec; private final CompressionKind orcCompressionKind; - private final Optional parquetCompressionCodec; + private final CompressionCodecName parquetCompressionCodec; private final Predicate supportedStorageFormats; HiveCompressionCodec( @@ -48,7 +50,7 @@ public enum HiveCompressionCodec { 
this.codec = Optional.ofNullable(codec); this.orcCompressionKind = requireNonNull(orcCompressionKind, "orcCompressionKind is null"); - this.parquetCompressionCodec = Optional.ofNullable(parquetCompressionCodec); + this.parquetCompressionCodec = requireNonNull(parquetCompressionCodec, "parquetCompressionCodec is null"); this.supportedStorageFormats = requireNonNull(supportedStorageFormats, "supportedStorageFormats is null"); } @@ -62,7 +64,7 @@ public CompressionKind getOrcCompressionKind() return orcCompressionKind; } - public Optional getParquetCompressionCodec() + public CompressionCodecName getParquetCompressionCodec() { return parquetCompressionCodec; } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnector.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnector.java index c25ca457d523e..32003811d971e 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnector.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnector.java @@ -22,14 +22,12 @@ import com.facebook.presto.spi.connector.ConnectorCapabilities; import com.facebook.presto.spi.connector.ConnectorCommitHandle; import com.facebook.presto.spi.connector.ConnectorMetadata; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdaterProvider; import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; -import com.facebook.presto.spi.connector.ConnectorTypeSerdeProvider; import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorMetadata; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.session.PropertyMetadata; @@ -73,8 +71,6 
@@ public class HiveConnector private final ConnectorAccessControl accessControl; private final ClassLoader classLoader; private final ConnectorPlanOptimizerProvider planOptimizerProvider; - private final ConnectorMetadataUpdaterProvider metadataUpdaterProvider; - private final ConnectorTypeSerdeProvider connectorTypeSerdeProvider; private final HiveTransactionManager transactionManager; @@ -94,8 +90,6 @@ public HiveConnector( List> analyzeProperties, ConnectorAccessControl accessControl, ConnectorPlanOptimizerProvider planOptimizerProvider, - ConnectorMetadataUpdaterProvider metadataUpdaterProvider, - ConnectorTypeSerdeProvider connectorTypeSerdeProvider, ClassLoader classLoader) { this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); @@ -114,8 +108,6 @@ public HiveConnector( this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.classLoader = requireNonNull(classLoader, "classLoader is null"); this.planOptimizerProvider = requireNonNull(planOptimizerProvider, "planOptimizerProvider is null"); - this.metadataUpdaterProvider = requireNonNull(metadataUpdaterProvider, "metadataUpdaterProvider is null"); - this.connectorTypeSerdeProvider = requireNonNull(connectorTypeSerdeProvider, "connectorTypeSerdeProvider is null"); } @Override @@ -156,18 +148,6 @@ public ConnectorPlanOptimizerProvider getConnectorPlanOptimizerProvider() return planOptimizerProvider; } - @Override - public ConnectorMetadataUpdaterProvider getConnectorMetadataUpdaterProvider() - { - return metadataUpdaterProvider; - } - - @Override - public ConnectorTypeSerdeProvider getConnectorTypeSerdeProvider() - { - return connectorTypeSerdeProvider; - } - @Override public Set getSystemTables() { @@ -204,6 +184,12 @@ public List> getTableProperties() return tableProperties; } + @Override + public List> getMaterializedViewProperties() + { + return tableProperties; + } + @Override public ConnectorAccessControl getAccessControl() { diff --git 
a/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnectorFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnectorFactory.java index eec08e2a792db..a68d371ee0712 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnectorFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnectorFactory.java @@ -38,13 +38,11 @@ import com.facebook.presto.spi.connector.ConnectorAccessControl; import com.facebook.presto.spi.connector.ConnectorContext; import com.facebook.presto.spi.connector.ConnectorFactory; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdaterProvider; import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; -import com.facebook.presto.spi.connector.ConnectorTypeSerdeProvider; import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorPageSinkProvider; import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorPageSourceProvider; import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorSplitManager; @@ -158,8 +156,6 @@ public Connector create(String catalogName, Map config, Connecto ConnectorAccessControl accessControl = new SystemTableAwareAccessControl(injector.getInstance(ConnectorAccessControl.class)); Set procedures = injector.getInstance(Key.get(new TypeLiteral>() {})); ConnectorPlanOptimizerProvider planOptimizerProvider = injector.getInstance(ConnectorPlanOptimizerProvider.class); - ConnectorMetadataUpdaterProvider metadataUpdaterProvider = injector.getInstance(ConnectorMetadataUpdaterProvider.class); - ConnectorTypeSerdeProvider connectorTypeSerdeProvider = injector.getInstance(ConnectorTypeSerdeProvider.class); List> 
allSessionProperties = new ArrayList<>(hiveSessionProperties.getSessionProperties()); allSessionProperties.addAll(hiveCommonSessionProperties.getSessionProperties()); @@ -180,8 +176,6 @@ public Connector create(String catalogName, Map config, Connecto hiveAnalyzeProperties.getAnalyzeProperties(), accessControl, planOptimizerProvider, - metadataUpdaterProvider, - connectorTypeSerdeProvider, classLoader); } catch (Exception e) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnectorTypeSerdeProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnectorTypeSerdeProvider.java deleted file mode 100644 index 2c1127f9051c7..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveConnectorTypeSerdeProvider.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.hive; - -import com.facebook.drift.codec.ThriftCodecManager; -import com.facebook.drift.transport.netty.codec.Protocol; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.ConnectorTypeSerde; -import com.facebook.presto.spi.connector.ConnectorTypeSerdeProvider; - -import javax.inject.Inject; - -import static java.lang.Math.toIntExact; -import static java.util.Objects.requireNonNull; - -public class HiveConnectorTypeSerdeProvider - implements ConnectorTypeSerdeProvider -{ - private final ThriftCodecManager thriftCodecManager; - private final Protocol thriftProtocol; - private final int bufferSize; - - @Inject - public HiveConnectorTypeSerdeProvider(HiveClientConfig hiveClientConfig, ThriftCodecManager thriftCodecManager) - { - this.thriftCodecManager = requireNonNull(thriftCodecManager, "thriftCodecManager is null"); - requireNonNull(hiveClientConfig, "hiveClientConfig is null"); - this.thriftProtocol = hiveClientConfig.getThriftProtocol(); - this.bufferSize = toIntExact(hiveClientConfig.getThriftBufferSize().toBytes()); - } - - @Override - public ConnectorTypeSerde getConnectorMetadataUpdateHandleSerde() - { - return new HiveMetadataUpdateHandleThriftSerde(thriftCodecManager, thriftProtocol, bufferSize); - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveEncryptionInformationProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveEncryptionInformationProvider.java index ab60903a16b76..b3e3ca4637ff4 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveEncryptionInformationProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveEncryptionInformationProvider.java @@ -19,8 +19,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import 
java.util.Map; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveFileRenamer.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveFileRenamer.java deleted file mode 100644 index 688aaee144fd7..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveFileRenamer.java +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.QueryId; -import com.facebook.presto.spi.SchemaTableName; -import com.google.common.collect.ImmutableList; -import org.weakref.jmx.Managed; - -import javax.annotation.PreDestroy; - -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -import static com.google.common.base.Verify.verify; - -public class HiveFileRenamer -{ - private final Map> queryPartitionFileCounterMap = new ConcurrentHashMap<>(); - private final Map> queryHiveMetadataResultMap = new ConcurrentHashMap<>(); - - public List getMetadataUpdateResults(List metadataUpdateRequests, QueryId queryId) - { - ImmutableList.Builder metadataUpdateResults = ImmutableList.builder(); - - for (ConnectorMetadataUpdateHandle connectorMetadataUpdateHandle : metadataUpdateRequests) { - HiveMetadataUpdateHandle request = (HiveMetadataUpdateHandle) 
connectorMetadataUpdateHandle; - String fileName = getFileName(request, queryId); - metadataUpdateResults.add(new HiveMetadataUpdateHandle(request.getRequestId(), request.getSchemaTableName(), request.getPartitionName(), Optional.of(fileName))); - } - return metadataUpdateResults.build(); - } - - public void cleanup(QueryId queryId) - { - queryPartitionFileCounterMap.remove(queryId); - queryHiveMetadataResultMap.remove(queryId); - } - - private String getFileName(HiveMetadataUpdateHandle request, QueryId queryId) - { - if (!queryPartitionFileCounterMap.containsKey(queryId) || !queryHiveMetadataResultMap.containsKey(queryId)) { - queryPartitionFileCounterMap.putIfAbsent(queryId, new ConcurrentHashMap<>()); - queryHiveMetadataResultMap.putIfAbsent(queryId, new ConcurrentHashMap<>()); - } - - // To keep track of the file counter per query per partition - Map partitionFileCounterMap = queryPartitionFileCounterMap.get(queryId); - - // To keep track of the file name result per query per request - // This is to make sure that request - fileName mapping is 1:1 - Map hiveMetadataResultMap = queryHiveMetadataResultMap.get(queryId); - - // If we have seen this request before then directly return the result. - if (hiveMetadataResultMap.containsKey(request)) { - // We come here if for some reason the worker did not receive the fileName and it retried the request. - return hiveMetadataResultMap.get(request); - } - - HiveMetadataUpdateKey key = new HiveMetadataUpdateKey(request); - // File names start from 0 - partitionFileCounterMap.putIfAbsent(key, new AtomicLong(0)); - - AtomicLong fileCount = partitionFileCounterMap.get(key); - String fileName = Long.valueOf(fileCount.getAndIncrement()).toString(); - - // Store the request - fileName mapping - hiveMetadataResultMap.put(request, fileName); - - return fileName; - } - - @PreDestroy - public void stop() - { - // Mappings should be deleted when query finishes. So verify that map is empty before its closed. 
- verify(queryPartitionFileCounterMap.isEmpty(), "Query partition file counter map has %s entries left behind", queryPartitionFileCounterMap.size()); - verify(queryHiveMetadataResultMap.isEmpty(), "Query hive metadata result map has %s entries left behind", queryHiveMetadataResultMap.size()); - } - - @Managed - public int getQueryPartitionFileCounterMapSize() - { - return queryPartitionFileCounterMap.size(); - } - - private static class HiveMetadataUpdateKey - { - private final SchemaTableName schemaTableName; - private final Optional partitionName; - - private HiveMetadataUpdateKey(HiveMetadataUpdateHandle hiveMetadataUpdateHandle) - { - this.schemaTableName = hiveMetadataUpdateHandle.getSchemaTableName(); - this.partitionName = hiveMetadataUpdateHandle.getPartitionName(); - } - - @Override - public boolean equals(Object obj) - { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - HiveMetadataUpdateKey o = (HiveMetadataUpdateKey) obj; - return schemaTableName.equals(o.schemaTableName) && - partitionName.equals(o.partitionName); - } - - @Override - public int hashCode() - { - return Objects.hash(schemaTableName, partitionName); - } - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveHandleResolver.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveHandleResolver.java index e726cbe9d5c2a..da668f5a0fd9f 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveHandleResolver.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveHandleResolver.java @@ -16,7 +16,6 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.ConnectorInsertTableHandle; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; @@ -74,10 
+73,4 @@ public Class getPartitioningHandleClass() { return HivePartitioningHandle.class; } - - @Override - public Class getMetadataUpdateHandleClass() - { - return HiveMetadataUpdateHandle.class; - } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java index 67b4241d25ebd..8c0097ae0bdbf 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveHdfsConfiguration.java @@ -14,10 +14,9 @@ package com.facebook.presto.hive; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; -import javax.inject.Inject; - import java.net.URI; import java.util.Set; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveInputInfo.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveInputInfo.java index 384cf52ad06ee..6e8a985b6dfe6 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveInputInfo.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveInputInfo.java @@ -24,14 +24,17 @@ public class HiveInputInfo // Code that serialize HiveInputInfo into log would often need the ability to limit the length of log entries. // This boolean field allows such code to mark the log entry as length limited. 
private final boolean truncated; + private final String tableLocation; @JsonCreator public HiveInputInfo( @JsonProperty("partitionIds") List partitionIds, - @JsonProperty("truncated") boolean truncated) + @JsonProperty("truncated") boolean truncated, + @JsonProperty("tableLocation") String tableLocation) { this.partitionIds = partitionIds; this.truncated = truncated; + this.tableLocation = tableLocation; } @JsonProperty @@ -45,4 +48,10 @@ public boolean isTruncated() { return truncated; } + + @JsonProperty + public String getTableLocation() + { + return tableLocation; + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveLocationService.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveLocationService.java index 2bddcbf2928da..bc854615d4d30 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveLocationService.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveLocationService.java @@ -20,10 +20,9 @@ import com.facebook.presto.hive.metastore.Table; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; +import jakarta.inject.Inject; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; - import java.util.Optional; import static com.facebook.presto.hive.HiveErrorCode.HIVE_PATH_ALREADY_EXISTS; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java index f20393703e9c3..f9cab97400b3e 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java @@ -15,8 +15,11 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.smile.SmileCodec; +import com.facebook.airlift.log.Logger; import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.Page; import com.facebook.presto.common.Subfield; +import 
com.facebook.presto.common.block.Block; import com.facebook.presto.common.predicate.NullableValue; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.ArrayType; @@ -46,7 +49,6 @@ import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorDeleteTableHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; import com.facebook.presto.spi.ConnectorNewTableLayout; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSession; @@ -65,7 +67,6 @@ import com.facebook.presto.spi.MaterializedViewNotFoundException; import com.facebook.presto.spi.MaterializedViewStatus; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.RecordCursor; import com.facebook.presto.spi.SchemaNotFoundException; import com.facebook.presto.spi.SchemaTableName; @@ -108,7 +109,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedMap; -import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -184,6 +184,7 @@ import static com.facebook.presto.hive.HiveColumnHandle.FILE_SIZE_COLUMN_NAME; import static com.facebook.presto.hive.HiveColumnHandle.PATH_COLUMN_NAME; import static com.facebook.presto.hive.HiveColumnHandle.ROW_ID_COLUMN_NAME; +import static com.facebook.presto.hive.HiveColumnHandle.rowIdColumnHandle; import static com.facebook.presto.hive.HiveColumnHandle.updateRowIdHandle; import static com.facebook.presto.hive.HiveErrorCode.HIVE_COLUMN_ORDER_MISMATCH; import static com.facebook.presto.hive.HiveErrorCode.HIVE_CONCURRENT_MODIFICATION_DETECTED; @@ -233,12 +234,13 @@ import static 
com.facebook.presto.hive.HiveSessionProperties.isUsePageFileForHiveUnsupportedType; import static com.facebook.presto.hive.HiveSessionProperties.shouldCreateEmptyBucketFilesForTemporaryTable; import static com.facebook.presto.hive.HiveStatisticsUtil.createPartitionStatistics; -import static com.facebook.presto.hive.HiveStatisticsUtil.getColumnStatistics; import static com.facebook.presto.hive.HiveStorageFormat.AVRO; +import static com.facebook.presto.hive.HiveStorageFormat.CSV; import static com.facebook.presto.hive.HiveStorageFormat.DWRF; import static com.facebook.presto.hive.HiveStorageFormat.ORC; import static com.facebook.presto.hive.HiveStorageFormat.PAGEFILE; import static com.facebook.presto.hive.HiveStorageFormat.PARQUET; +import static com.facebook.presto.hive.HiveStorageFormat.TEXTFILE; import static com.facebook.presto.hive.HiveStorageFormat.values; import static com.facebook.presto.hive.HiveTableProperties.AVRO_SCHEMA_URL; import static com.facebook.presto.hive.HiveTableProperties.BUCKETED_BY_PROPERTY; @@ -255,6 +257,8 @@ import static com.facebook.presto.hive.HiveTableProperties.ORC_BLOOM_FILTER_FPP; import static com.facebook.presto.hive.HiveTableProperties.PARTITIONED_BY_PROPERTY; import static com.facebook.presto.hive.HiveTableProperties.PREFERRED_ORDERING_COLUMNS; +import static com.facebook.presto.hive.HiveTableProperties.SKIP_FOOTER_LINE_COUNT; +import static com.facebook.presto.hive.HiveTableProperties.SKIP_HEADER_LINE_COUNT; import static com.facebook.presto.hive.HiveTableProperties.SORTED_BY_PROPERTY; import static com.facebook.presto.hive.HiveTableProperties.STORAGE_FORMAT_PROPERTY; import static com.facebook.presto.hive.HiveTableProperties.getAvroSchemaUrl; @@ -265,6 +269,8 @@ import static com.facebook.presto.hive.HiveTableProperties.getEncryptColumns; import static com.facebook.presto.hive.HiveTableProperties.getEncryptTable; import static com.facebook.presto.hive.HiveTableProperties.getExternalLocation; +import static 
com.facebook.presto.hive.HiveTableProperties.getFooterSkipCount; +import static com.facebook.presto.hive.HiveTableProperties.getHeaderSkipCount; import static com.facebook.presto.hive.HiveTableProperties.getHiveStorageFormat; import static com.facebook.presto.hive.HiveTableProperties.getOrcBloomFilterColumns; import static com.facebook.presto.hive.HiveTableProperties.getOrcBloomFilterFpp; @@ -281,6 +287,7 @@ import static com.facebook.presto.hive.HiveUtil.encodeViewData; import static com.facebook.presto.hive.HiveUtil.getPartitionKeyColumnHandles; import static com.facebook.presto.hive.HiveUtil.hiveColumnHandles; +import static com.facebook.presto.hive.HiveUtil.parsePartitionValue; import static com.facebook.presto.hive.HiveUtil.translateHiveUnsupportedTypeForTemporaryTable; import static com.facebook.presto.hive.HiveUtil.translateHiveUnsupportedTypesForTemporaryTable; import static com.facebook.presto.hive.HiveUtil.verifyPartitionTypeSupported; @@ -307,17 +314,20 @@ import static com.facebook.presto.hive.metastore.MetastoreUtil.TABLE_COMMENT; import static com.facebook.presto.hive.metastore.MetastoreUtil.buildInitialPrivilegeSet; import static com.facebook.presto.hive.metastore.MetastoreUtil.checkIfNullView; +import static com.facebook.presto.hive.metastore.MetastoreUtil.createPartitionValues; import static com.facebook.presto.hive.metastore.MetastoreUtil.createTableObjectForViewCreation; import static com.facebook.presto.hive.metastore.MetastoreUtil.createViewProperties; import static com.facebook.presto.hive.metastore.MetastoreUtil.extractPartitionValues; import static com.facebook.presto.hive.metastore.MetastoreUtil.getHiveSchema; import static com.facebook.presto.hive.metastore.MetastoreUtil.getMetastoreHeaders; +import static com.facebook.presto.hive.metastore.MetastoreUtil.getPartitionNames; import static com.facebook.presto.hive.metastore.MetastoreUtil.getPartitionNamesWithEmptyVersion; import static 
com.facebook.presto.hive.metastore.MetastoreUtil.getProtectMode; import static com.facebook.presto.hive.metastore.MetastoreUtil.isDeltaLakeTable; import static com.facebook.presto.hive.metastore.MetastoreUtil.isIcebergTable; import static com.facebook.presto.hive.metastore.MetastoreUtil.isPrestoView; import static com.facebook.presto.hive.metastore.MetastoreUtil.isUserDefinedTypeEncodingEnabled; +import static com.facebook.presto.hive.metastore.MetastoreUtil.makePartName; import static com.facebook.presto.hive.metastore.MetastoreUtil.toPartitionValues; import static com.facebook.presto.hive.metastore.MetastoreUtil.verifyAndPopulateViews; import static com.facebook.presto.hive.metastore.MetastoreUtil.verifyOnline; @@ -386,6 +396,7 @@ public class HiveMetadata implements TransactionalMetadata { + private static final Logger log = Logger.get(HiveMetadata.class); public static final Set RESERVED_ROLES = ImmutableSet.of("all", "default", "none"); public static final String REFERENCED_MATERIALIZED_VIEWS = "referenced_materialized_views"; @@ -407,6 +418,9 @@ public class HiveMetadata private static final String CSV_QUOTE_KEY = OpenCSVSerde.QUOTECHAR; private static final String CSV_ESCAPE_KEY = OpenCSVSerde.ESCAPECHAR; + public static final String SKIP_HEADER_COUNT_KEY = "skip.header.line.count"; + public static final String SKIP_FOOTER_COUNT_KEY = "skip.footer.line.count"; + private static final JsonCodec MATERIALIZED_VIEW_JSON_CODEC = jsonCodec(MaterializedViewDefinition.class); private final boolean allowCorruptWritesForTesting; @@ -432,7 +446,6 @@ public class HiveMetadata private final PartitionObjectBuilder partitionObjectBuilder; private final HiveEncryptionInformationProvider encryptionInformationProvider; private final HivePartitionStats hivePartitionStats; - private final HiveFileRenamer hiveFileRenamer; private final TableWritabilityChecker tableWritabilityChecker; public HiveMetadata( @@ -459,7 +472,6 @@ public HiveMetadata( PartitionObjectBuilder 
partitionObjectBuilder, HiveEncryptionInformationProvider encryptionInformationProvider, HivePartitionStats hivePartitionStats, - HiveFileRenamer hiveFileRenamer, TableWritabilityChecker tableWritabilityChecker) { this.allowCorruptWritesForTesting = allowCorruptWritesForTesting; @@ -486,7 +498,6 @@ public HiveMetadata( this.partitionObjectBuilder = requireNonNull(partitionObjectBuilder, "partitionObjectBuilder is null"); this.encryptionInformationProvider = requireNonNull(encryptionInformationProvider, "encryptionInformationProvider is null"); this.hivePartitionStats = requireNonNull(hivePartitionStats, "hivePartitionStats is null"); - this.hiveFileRenamer = requireNonNull(hiveFileRenamer, "hiveFileRenamer is null"); this.tableWritabilityChecker = requireNonNull(tableWritabilityChecker, "tableWritabilityChecker is null"); } @@ -587,7 +598,7 @@ private Optional getPropertiesSystemTable(ConnectorSession session, } Map sortedTableParameters = ImmutableSortedMap.copyOf(table.get().getParameters()); List columns = sortedTableParameters.keySet().stream() - .map(key -> ColumnMetadata.builder().setName(key).setType(VARCHAR).build()) + .map(key -> ColumnMetadata.builder().setName(normalizeIdentifier(session, key)).setType(VARCHAR).build()) .collect(toImmutableList()); List types = columns.stream() .map(ColumnMetadata::getType) @@ -620,7 +631,7 @@ private Optional getPartitionsSystemTable(ConnectorSession session, List partitionSystemTableColumns = partitionColumns.stream() .map(column -> ColumnMetadata.builder() - .setName(column.getName()) + .setName(normalizeIdentifier(session, column.getName())) .setType(typeManager.getType(column.getTypeSignature())) .setComment(column.getComment().orElse(null)) .setHidden(column.isHidden()) @@ -674,7 +685,7 @@ private ConnectorTableMetadata getTableMetadata(Optional
table, SchemaTab } if (isIcebergTable(table.get()) || isDeltaLakeTable(table.get())) { - throw new PrestoException(HIVE_UNSUPPORTED_FORMAT, format("Not a Hive table '%s'", tableName)); + throw new UnknownTableTypeException("Not a Hive table: " + tableName); } List> tableConstraints = metastore.getTableConstraints(metastoreContext, tableName.getSchemaName(), tableName.getTableName()); @@ -751,6 +762,12 @@ private ConnectorTableMetadata getTableMetadata(Optional
table, SchemaTab properties.put(AVRO_SCHEMA_URL, avroSchemaUrl); } + // Textfile and CSV specific properties + getSerdeProperty(table.get(), SKIP_HEADER_COUNT_KEY) + .ifPresent(skipHeaderCount -> properties.put(SKIP_HEADER_LINE_COUNT, Integer.valueOf(skipHeaderCount))); + getSerdeProperty(table.get(), SKIP_FOOTER_COUNT_KEY) + .ifPresent(skipFooterCount -> properties.put(SKIP_FOOTER_LINE_COUNT, Integer.valueOf(skipFooterCount))); + // CSV specific property getCsvSerdeProperty(table.get(), CSV_SEPARATOR_KEY) .ifPresent(csvSeparator -> properties.put(CSV_SEPARATOR, csvSeparator)); @@ -805,10 +822,10 @@ public Optional getInfo(ConnectorTableLayoutHandle layoutHandle) HiveTableLayoutHandle tableLayoutHandle = (HiveTableLayoutHandle) layoutHandle; if (tableLayoutHandle.getPartitions().isPresent()) { return Optional.of(new HiveInputInfo( - tableLayoutHandle.getPartitions().get().stream() + tableLayoutHandle.getPartitions().map(PartitionSet::getFullyLoadedPartitions).get().stream() .map(hivePartition -> hivePartition.getPartitionId().getPartitionName()) .collect(toList()), - false)); + false, tableLayoutHandle.getTablePath())); } return Optional.empty(); @@ -862,6 +879,9 @@ public Map> listTableColumns(ConnectorSess catch (TableNotFoundException e) { // table disappeared during listing operation } + catch (UnknownTableTypeException e) { + log.warn(String.format("%s: Unknown table type of table %s", e.getMessage(), tableName)); + } } return columns.build(); } @@ -1062,6 +1082,9 @@ private Table prepareTable(ConnectorSession session, ConnectorTableMetadata tabl else if (tableType.equals(MANAGED_TABLE) || tableType.equals(MATERIALIZED_VIEW)) { LocationHandle locationHandle = locationService.forNewTable(metastore, session, schemaName, tableName, isTempPathRequired(session, bucketProperty, preferredOrderingColumns)); targetPath = locationService.getQueryWriteInfo(locationHandle).getTargetPath(); + if (getFooterSkipCount(tableMetadata.getProperties()).isPresent()) { + throw new 
PrestoException(NOT_SUPPORTED, format("Cannot create non external table with %s property", SKIP_FOOTER_COUNT_KEY)); + } } else { throw new IllegalStateException(format("%s is not a valid table type to be created.", tableType)); @@ -1289,20 +1312,42 @@ private Map getEmptyTableProperties( tableProperties.put(AVRO_SCHEMA_URL_KEY, validateAndNormalizeAvroSchemaUrl(avroSchemaUrl, hdfsContext)); } + // Textfile and CSV specific properties + Set csvAndTextFile = ImmutableSet.of(TEXTFILE, CSV); + getHeaderSkipCount(tableMetadata.getProperties()).ifPresent(headerSkipCount -> { + if (headerSkipCount > 0) { + checkFormatForProperty(hiveStorageFormat, csvAndTextFile, SKIP_HEADER_LINE_COUNT); + tableProperties.put(SKIP_HEADER_COUNT_KEY, String.valueOf(headerSkipCount)); + } + if (headerSkipCount < 0) { + throw new PrestoException(HIVE_INVALID_METADATA, format("Invalid value for %s property: %s", SKIP_HEADER_LINE_COUNT, headerSkipCount)); + } + }); + + getFooterSkipCount(tableMetadata.getProperties()).ifPresent(footerSkipCount -> { + if (footerSkipCount > 0) { + checkFormatForProperty(hiveStorageFormat, csvAndTextFile, SKIP_FOOTER_LINE_COUNT); + tableProperties.put(SKIP_FOOTER_COUNT_KEY, String.valueOf(footerSkipCount)); + } + if (footerSkipCount < 0) { + throw new PrestoException(HIVE_INVALID_METADATA, format("Invalid value for %s property: %s", SKIP_FOOTER_LINE_COUNT, footerSkipCount)); + } + }); + // CSV specific properties getCsvProperty(tableMetadata.getProperties(), CSV_ESCAPE) .ifPresent(escape -> { - checkFormatForProperty(hiveStorageFormat, HiveStorageFormat.CSV, CSV_ESCAPE); + checkFormatForProperty(hiveStorageFormat, CSV, CSV_ESCAPE); tableProperties.put(CSV_ESCAPE_KEY, escape.toString()); }); getCsvProperty(tableMetadata.getProperties(), CSV_QUOTE) .ifPresent(quote -> { - checkFormatForProperty(hiveStorageFormat, HiveStorageFormat.CSV, CSV_QUOTE); + checkFormatForProperty(hiveStorageFormat, CSV, CSV_QUOTE); tableProperties.put(CSV_QUOTE_KEY, quote.toString()); }); 
getCsvProperty(tableMetadata.getProperties(), CSV_SEPARATOR) .ifPresent(separator -> { - checkFormatForProperty(hiveStorageFormat, HiveStorageFormat.CSV, CSV_SEPARATOR); + checkFormatForProperty(hiveStorageFormat, CSV, CSV_SEPARATOR); tableProperties.put(CSV_SEPARATOR_KEY, separator.toString()); }); @@ -1322,6 +1367,13 @@ private static void checkFormatForProperty(HiveStorageFormat actualStorageFormat } } + private static void checkFormatForProperty(HiveStorageFormat actualStorageFormat, Set expectedStorageFormats, String propertyName) + { + if (!expectedStorageFormats.contains(actualStorageFormat)) { + throw new PrestoException(INVALID_TABLE_PROPERTY, format("Cannot specify %s table property for storage format: %s", propertyName, actualStorageFormat)); + } + } + private String validateAndNormalizeAvroSchemaUrl(String url, HdfsContext context) { try { @@ -1545,15 +1597,20 @@ public void finishStatisticsCollection(ConnectorSession session, ConnectorTableH metastore.setTableStatistics(metastoreContext, table, createPartitionStatistics(session, columnTypes, computedStatisticsMap.get(ImmutableList.of()), timeZone)); } else { + List partitionNames; List> partitionValuesList; if (handle.getAnalyzePartitionValues().isPresent()) { partitionValuesList = handle.getAnalyzePartitionValues().get(); + partitionNames = partitionValuesList.stream() + .map(partitionValues -> makePartName(partitionColumns, partitionValues)) + .collect(toImmutableList()); } else { - partitionValuesList = metastore.getPartitionNames(metastoreContext, handle.getSchemaName(), handle.getTableName()) - .orElseThrow(() -> new TableNotFoundException(((HiveTableHandle) tableHandle).getSchemaTableName())) + partitionNames = getPartitionNames(metastore.getPartitionNames(metastoreContext, handle.getSchemaName(), handle.getTableName()) + .orElseThrow(() -> new TableNotFoundException(tableName))); + partitionValuesList = partitionNames .stream() - .map(partitionNameWithVersion -> 
MetastoreUtil.toPartitionValues(partitionNameWithVersion.getPartitionName())) + .map(MetastoreUtil::toPartitionValues) .collect(toImmutableList()); } @@ -1565,8 +1622,18 @@ public void finishStatisticsCollection(ConnectorSession session, ConnectorTableH Supplier emptyPartitionStatistics = Suppliers.memoize(() -> createEmptyPartitionStatistics(columnTypes, columnStatisticTypes)); int usedComputedStatistics = 0; - for (List partitionValues : partitionValuesList) { - ComputedStatistics collectedStatistics = computedStatisticsMap.get(partitionValues); + List partitionedBy = table.getPartitionColumns().stream() + .map(Column::getName) + .collect(toImmutableList()); + List partitionTypes = partitionedBy.stream() + .map(columnTypes::get) + .collect(toImmutableList()); + for (int i = 0; i < partitionNames.size(); i++) { + String partitionName = partitionNames.get(i); + List partitionValues = partitionValuesList.get(i); + ComputedStatistics collectedStatistics = computedStatisticsMap.containsKey(partitionValues) + ? computedStatisticsMap.get(partitionValues) + : computedStatisticsMap.get(canonicalizePartitionValues(partitionName, partitionValues, partitionTypes)); if (collectedStatistics == null) { partitionStatistics.put(partitionValues, emptyPartitionStatistics.get()); } @@ -1575,11 +1642,46 @@ public void finishStatisticsCollection(ConnectorSession session, ConnectorTableH partitionStatistics.put(partitionValues, createPartitionStatistics(session, columnTypes, collectedStatistics, timeZone)); } } - verify(usedComputedStatistics == computedStatistics.size(), "All computed statistics must be used"); + verify(usedComputedStatistics == computedStatistics.size(), + usedComputedStatistics > computedStatistics.size() ? + "There are multiple variants of the same partition, e.g. p=1, p=01, p=001. 
All partitions must follow the same key=value representation" : + "All computed statistics must be used"); metastore.setPartitionStatistics(metastoreContext, table, partitionStatistics.build()); } } + private static Map getColumnStatistics(Map, ComputedStatistics> statistics, List partitionValues) + { + return Optional.ofNullable(statistics.get(partitionValues)) + .map(ComputedStatistics::getColumnStatistics) + .orElse(ImmutableMap.of()); + } + + private Map getColumnStatistics( + Map, ComputedStatistics> statistics, + String partitionName, + List partitionValues, + List partitionTypes) + { + Optional> columnStatistics = Optional.ofNullable(statistics.get(partitionValues)) + .map(ComputedStatistics::getColumnStatistics); + return columnStatistics + .orElseGet(() -> getColumnStatistics(statistics, canonicalizePartitionValues(partitionName, partitionValues, partitionTypes))); + } + + private List canonicalizePartitionValues(String partitionName, List partitionValues, List partitionTypes) + { + verify(partitionValues.size() == partitionTypes.size(), "Expected partitionTypes size to be %s but got %s", partitionValues.size(), partitionTypes.size()); + Block[] parsedPartitionValuesBlocks = new Block[partitionValues.size()]; + for (int i = 0; i < partitionValues.size(); i++) { + String partitionValue = partitionValues.get(i); + Type partitionType = partitionTypes.get(i); + parsedPartitionValuesBlocks[i] = parsePartitionValue(partitionName, partitionValue, partitionType, timeZone).asBlock(); + } + + return createPartitionValues(partitionTypes, new Page(parsedPartitionValuesBlocks), 0); + } + @Override public HiveOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, Optional layout) { @@ -1592,7 +1694,15 @@ public HiveOutputTableHandle beginCreateTable(ConnectorSession session, Connecto if (getAvroSchemaUrl(tableMetadata.getProperties()) != null) { throw new PrestoException(NOT_SUPPORTED, "CREATE TABLE AS not supported when Avro 
schema url is set"); } + getHeaderSkipCount(tableMetadata.getProperties()).ifPresent(headerSkipCount -> { + if (headerSkipCount > 1) { + throw new PrestoException(NOT_SUPPORTED, format("CREATE TABLE AS not supported when the value of %s property is greater than 1", SKIP_HEADER_COUNT_KEY)); + } + }); + getFooterSkipCount(tableMetadata.getProperties()).ifPresent(footerSkipCount -> { + throw new PrestoException(NOT_SUPPORTED, format("Property %s is not supported with CREATE TABLE AS", SKIP_FOOTER_COUNT_KEY)); + }); HiveStorageFormat tableStorageFormat = getHiveStorageFormat(tableMetadata.getProperties()); List partitionedBy = getPartitionedBy(tableMetadata.getProperties()); Optional bucketProperty = getBucketProperty(tableMetadata.getProperties()); @@ -1756,7 +1866,7 @@ public Optional finishCreateTable(ConnectorSession sess metastore.createTable(session, table, principalPrivileges, Optional.of(writeInfo.getWritePath()), false, tableStatistics, emptyList()); if (handle.getPartitionedBy().isEmpty()) { - return Optional.of(new HiveWrittenPartitions(ImmutableList.of(UNPARTITIONED_ID.getPartitionName()))); + return Optional.of(new HiveOutputMetadata(new HiveOutputInfo(ImmutableList.of(UNPARTITIONED_ID.getPartitionName()), writeInfo.getTargetPath().toString()))); } if (isRespectTableFormat(session)) { @@ -1785,10 +1895,10 @@ public Optional finishCreateTable(ConnectorSession sess partitionStatistics); } - return Optional.of(new HiveWrittenPartitions( + return Optional.of(new HiveOutputMetadata(new HiveOutputInfo( partitionUpdates.stream() .map(PartitionUpdate::getName) - .collect(toList()))); + .collect(toList()), writeInfo.getTargetPath().toString()))); } public static boolean shouldCreateFilesForMissingBuckets(Table table, ConnectorSession session) @@ -1961,6 +2071,15 @@ private HiveInsertTableHandle beginInsertInternal(ConnectorSession session, Conn locationHandle = locationService.forExistingTable(metastore, session, table, tempPathRequired); } + 
Optional.ofNullable(table.getParameters().get(SKIP_HEADER_COUNT_KEY)).map(Integer::parseInt).ifPresent(headerSkipCount -> { + if (headerSkipCount > 1) { + throw new PrestoException(NOT_SUPPORTED, format("INSERT into %s Hive table with value of %s property greater than 1 is not supported", tableName, SKIP_HEADER_COUNT_KEY)); + } + }); + if (table.getParameters().containsKey(SKIP_FOOTER_COUNT_KEY)) { + throw new PrestoException(NOT_SUPPORTED, format("INSERT into %s Hive table with %s property not supported", tableName, SKIP_FOOTER_COUNT_KEY)); + } + Optional tableEncryptionProperties = getTableEncryptionPropertiesFromHiveProperties(table.getParameters(), tableStorageFormat); HiveStorageFormat partitionStorageFormat = isRespectTableFormat(session) ? tableStorageFormat : getHiveStorageFormat(session); @@ -2110,12 +2229,16 @@ else if (partitionUpdate.getUpdateMode() == APPEND) { throw new PrestoException(HIVE_UNSUPPORTED_ENCRYPTION_OPERATION, "Inserting into an existing partition with encryption enabled is not supported yet"); } // insert into existing partition - List partitionValues = toPartitionValues(partitionUpdate.getName()); + String partitionName = partitionUpdate.getName(); + List partitionValues = toPartitionValues(partitionName); + List partitionTypes = partitionedBy.stream() + .map(columnTypes::get) + .collect(toImmutableList()); PartitionStatistics partitionStatistics = createPartitionStatistics( session, partitionUpdate.getStatistics(), columnTypes, - getColumnStatistics(partitionComputedStatistics, partitionValues), timeZone); + getColumnStatistics(partitionComputedStatistics, partitionName, partitionValues, partitionTypes), timeZone); metastore.finishInsertIntoExistingPartition( session, handle.getSchemaName(), @@ -2171,11 +2294,16 @@ else if (partitionUpdate.getUpdateMode() == NEW || partitionUpdate.getUpdateMode throw new PrestoException(HIVE_CONCURRENT_MODIFICATION_DETECTED, "Partition format changed during insert"); } + String partitionName = 
partitionUpdate.getName(); + List partitionValues = partition.getValues(); + List partitionTypes = partitionedBy.stream() + .map(columnTypes::get) + .collect(toImmutableList()); PartitionStatistics partitionStatistics = createPartitionStatistics( session, partitionUpdate.getStatistics(), columnTypes, - getColumnStatistics(partitionComputedStatistics, partition.getValues()), + getColumnStatistics(partitionComputedStatistics, partitionName, partitionValues, partitionTypes), timeZone); // New partition or overwriting existing partition by staging and moving the new partition @@ -2196,11 +2324,11 @@ else if (partitionUpdate.getUpdateMode() == NEW || partitionUpdate.getUpdateMode } } - return Optional.of(new HiveWrittenPartitions( + return Optional.of(new HiveOutputMetadata(new HiveOutputInfo( partitionUpdates.stream() .map(PartitionUpdate::getName) .map(name -> name.isEmpty() ? UNPARTITIONED_ID.getPartitionName() : name) - .collect(toList()))); + .collect(toList()), table.getStorage().getLocation()))); } /** @@ -2208,8 +2336,8 @@ else if (partitionUpdate.getUpdateMode() == NEW || partitionUpdate.getUpdateMode * This is required when we are overwriting the partitions by directly writing the new * files to the existing directory, where files written by older queries may be present too. 
* - * @param session the ConnectorSession object - * @param partitionPath the path of the partition from where the older files are to be deleted + * @param session the ConnectorSession object + * @param partitionPath the path of the partition from where the older files are to be deleted */ private void removeNonCurrentQueryFiles(ConnectorSession session, Path partitionPath) { @@ -2405,9 +2533,6 @@ public MaterializedViewStatus getMaterializedViewStatus(ConnectorSession session Map> viewToBasePartitionMap = getViewToBasePartitionMap(materializedViewTable, baseTables, directColumnMappings); MaterializedDataPredicates materializedDataPredicates = getMaterializedDataPredicates(metastore, metastoreContext, typeManager, materializedViewTable, timeZone); - if (materializedDataPredicates.getPredicateDisjuncts().isEmpty()) { - return new MaterializedViewStatus(NOT_MATERIALIZED); - } // Partitions to keep track of for materialized view freshness are the partitions of every base table // that are not available/updated to the materialized view yet. @@ -2430,13 +2555,15 @@ public MaterializedViewStatus getMaterializedViewStatus(ConnectorSession session for (MaterializedDataPredicates dataPredicates : partitionsFromBaseTables.values()) { if (!dataPredicates.getPredicateDisjuncts().isEmpty()) { - missingPartitions += dataPredicates.getPredicateDisjuncts().stream() + missingPartitions += (int) dataPredicates.getPredicateDisjuncts().stream() .filter(baseQueryDomain::overlaps) - .mapToInt(tupleDomain -> tupleDomain.getDomains().isPresent() ? 
tupleDomain.getDomains().get().size() : 0) - .sum(); + .count(); } } + if (materializedDataPredicates.getPredicateDisjuncts().isEmpty()) { + return new MaterializedViewStatus(NOT_MATERIALIZED, partitionsFromBaseTables); + } if (missingPartitions > HiveSessionProperties.getMaterializedViewMissingPartitionsThreshold(session)) { return new MaterializedViewStatus(TOO_MANY_PARTITIONS_MISSING, partitionsFromBaseTables); } @@ -2461,6 +2588,7 @@ public void createMaterializedView(ConnectorSession session, ConnectorTableMetad viewDefinition.getTable(), viewDefinition.getBaseTables(), viewDefinition.getOwner(), + viewDefinition.getSecurityMode(), viewDefinition.getColumnMappings(), viewDefinition.getBaseTablesOnOuterJoinSide(), Optional.of(getPartitionedBy(viewMetadata.getProperties()))); @@ -2547,9 +2675,9 @@ public ConnectorDeleteTableHandle beginDelete(ConnectorSession session, Connecto } @Override - public ColumnHandle getDeleteRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + public Optional getDeleteRowIdColumn(ConnectorSession session, ConnectorTableHandle tableHandle) { - return updateRowIdHandle(); + return Optional.of(updateRowIdHandle()); } @Override @@ -2579,13 +2707,13 @@ public OptionalLong metadataDelete(ConnectorSession session, ConnectorTableHandl private List getOrComputePartitions(HiveTableLayoutHandle layoutHandle, ConnectorSession session, ConnectorTableHandle tableHandle) { if (layoutHandle.getPartitions().isPresent()) { - return layoutHandle.getPartitions().get(); + return layoutHandle.getPartitions().map(PartitionSet::getFullyLoadedPartitions).get(); } else { TupleDomain partitionColumnPredicate = layoutHandle.getPartitionColumnPredicate(); Predicate> predicate = convertToPredicate(partitionColumnPredicate); - List tableLayoutResults = getTableLayouts(session, tableHandle, new Constraint<>(partitionColumnPredicate, predicate), Optional.empty()); - return ((HiveTableLayoutHandle) 
Iterables.getOnlyElement(tableLayoutResults).getTableLayout().getHandle()).getPartitions().get(); + ConnectorTableLayoutResult tableLayoutResult = getTableLayoutForConstraint(session, tableHandle, new Constraint<>(partitionColumnPredicate, predicate), Optional.empty()); + return ((HiveTableLayoutHandle) tableLayoutResult.getTableLayout().getHandle()).getPartitions().map(PartitionSet::getFullyLoadedPartitions).get(); } } @@ -2662,7 +2790,11 @@ private String createTableLayoutString( } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle tableHandle, + Constraint constraint, + Optional> desiredColumns) { HiveTableHandle handle = (HiveTableHandle) tableHandle; HivePartitionResult hivePartitionResult; @@ -2690,7 +2822,7 @@ public List getTableLayouts(ConnectorSession session String layoutString = createTableLayoutString(session, handle.getSchemaTableName(), hivePartitionResult.getBucketHandle(), hivePartitionResult.getBucketFilter(), TRUE_CONSTANT, domainPredicate); Optional> requestedColumns = desiredColumns.map(columns -> columns.stream().map(column -> (HiveColumnHandle) column).collect(toImmutableSet())); - return ImmutableList.of(new ConnectorTableLayoutResult( + return new ConnectorTableLayoutResult( getTableLayout( session, new HiveTableLayoutHandle.Builder() @@ -2713,7 +2845,7 @@ public List getTableLayouts(ConnectorSession session .setAppendRowNumberEnabled(false) .setHiveTableHandle(handle) .build()), - hivePartitionResult.getUnenforcedConstraint())); + hivePartitionResult.getUnenforcedConstraint()); } private static Subfield toSubfield(ColumnHandle columnHandle) @@ -2745,7 +2877,7 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa { HiveTableLayoutHandle hiveLayoutHandle = (HiveTableLayoutHandle) layoutHandle; List 
partitionColumns = ImmutableList.copyOf(hiveLayoutHandle.getPartitionColumns()); - List partitions = hiveLayoutHandle.getPartitions().get(); + List partitions = hiveLayoutHandle.getPartitions().map(PartitionSet::getFullyLoadedPartitions).get(); Optional discretePredicates = getDiscretePredicates(partitionColumns, partitions); @@ -2810,7 +2942,8 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa // capture subfields from domainPredicate to add to remainingPredicate // so those filters don't get lost Map columnTypes = hiveColumnHandles(table).stream() - .collect(toImmutableMap(HiveColumnHandle::getName, columnHandle -> columnHandle.getColumnMetadata(typeManager).getType())); + .collect(toImmutableMap(HiveColumnHandle::getName, columnHandle -> columnHandle.getColumnMetadata(typeManager, + normalizeIdentifier(session, columnHandle.getName())).getType())); subfieldPredicate = getSubfieldPredicate(session, hiveLayoutHandle, columnTypes, functionResolution, rowExpressionService); } @@ -2859,7 +2992,8 @@ && isOrderBasedExecutionEnabled(session)) { streamPartitionColumns, discretePredicates, localPropertyBuilder.build(), - Optional.of(combinedRemainingPredicate)); + Optional.of(combinedRemainingPredicate), + Optional.of(rowIdColumnHandle())); } @Override @@ -2967,6 +3101,7 @@ public ConnectorTableLayoutHandle getAlternativeLayoutHandle(ConnectorSession se .setBucketHandle(Optional.of(updatedBucketHandle)) .build(); } + @Override public ConnectorPartitioningHandle getPartitioningHandleForExchange(ConnectorSession session, int partitionCount, List partitionTypes) { @@ -3308,18 +3443,6 @@ public CompletableFuture commitPageSinkAsync(ConnectorSession session, Con getPartitionUpdates(session, fragments))); } - @Override - public List getMetadataUpdateResults(List metadataUpdateRequests, QueryId queryId) - { - return hiveFileRenamer.getMetadataUpdateResults(metadataUpdateRequests, queryId); - } - - @Override - public void 
doMetadataUpdateCleanup(QueryId queryId) - { - hiveFileRenamer.cleanup(queryId); - } - private List buildGrants(ConnectorSession session, SchemaTableName tableName, PrestoPrincipal principal) { ImmutableList.Builder result = ImmutableList.builder(); @@ -3617,7 +3740,7 @@ else if (column.isHidden()) { private static void validateCsvColumns(ConnectorTableMetadata tableMetadata) { - if (getHiveStorageFormat(tableMetadata.getProperties()) != HiveStorageFormat.CSV) { + if (getHiveStorageFormat(tableMetadata.getProperties()) != CSV) { return; } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java index 9d75a447c8a8c..78de6facb91f6 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataFactory.java @@ -26,10 +26,9 @@ import com.facebook.presto.spi.plan.FilterStatsCalculatorService; import com.facebook.presto.spi.relation.RowExpressionService; import com.google.common.util.concurrent.ListeningExecutorService; +import jakarta.inject.Inject; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.util.function.Supplier; import static java.util.Objects.requireNonNull; @@ -68,7 +67,6 @@ public class HiveMetadataFactory private final PartitionObjectBuilder partitionObjectBuilder; private final HiveEncryptionInformationProvider encryptionInformationProvider; private final HivePartitionStats hivePartitionStats; - private final HiveFileRenamer hiveFileRenamer; private final ColumnConverterProvider columnConverterProvider; private final QuickStatsProvider quickStatsProvider; private final TableWritabilityChecker tableWritabilityChecker; @@ -97,7 +95,6 @@ public HiveMetadataFactory( PartitionObjectBuilder partitionObjectBuilder, HiveEncryptionInformationProvider encryptionInformationProvider, HivePartitionStats hivePartitionStats, - HiveFileRenamer 
hiveFileRenamer, ColumnConverterProvider columnConverterProvider, QuickStatsProvider quickStatsProvider, TableWritabilityChecker tableWritabilityChecker) @@ -132,7 +129,6 @@ public HiveMetadataFactory( partitionObjectBuilder, encryptionInformationProvider, hivePartitionStats, - hiveFileRenamer, columnConverterProvider, quickStatsProvider, tableWritabilityChecker); @@ -168,7 +164,6 @@ public HiveMetadataFactory( PartitionObjectBuilder partitionObjectBuilder, HiveEncryptionInformationProvider encryptionInformationProvider, HivePartitionStats hivePartitionStats, - HiveFileRenamer hiveFileRenamer, ColumnConverterProvider columnConverterProvider, QuickStatsProvider quickStatsProvider, TableWritabilityChecker tableWritabilityChecker) @@ -202,7 +197,6 @@ public HiveMetadataFactory( this.partitionObjectBuilder = requireNonNull(partitionObjectBuilder, "partitionObjectBuilder is null"); this.encryptionInformationProvider = requireNonNull(encryptionInformationProvider, "encryptionInformationProvider is null"); this.hivePartitionStats = requireNonNull(hivePartitionStats, "hivePartitionStats is null"); - this.hiveFileRenamer = requireNonNull(hiveFileRenamer, "hiveFileRenamer is null"); this.columnConverterProvider = requireNonNull(columnConverterProvider, "columnConverterProvider is null"); this.quickStatsProvider = requireNonNull(quickStatsProvider, "quickStatsProvider is null"); this.tableWritabilityChecker = requireNonNull(tableWritabilityChecker, "tableWritabilityChecker is null"); @@ -251,7 +245,6 @@ public HiveMetadata get() partitionObjectBuilder, encryptionInformationProvider, hivePartitionStats, - hiveFileRenamer, tableWritabilityChecker); } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdateHandle.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdateHandle.java deleted file mode 100644 index 386bddeadbee5..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdateHandle.java +++ 
/dev/null @@ -1,106 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.drift.annotations.ThriftConstructor; -import com.facebook.drift.annotations.ThriftField; -import com.facebook.drift.annotations.ThriftStruct; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.SchemaTableName; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.Objects; -import java.util.Optional; -import java.util.UUID; - -import static java.util.Objects.requireNonNull; - -@ThriftStruct -public class HiveMetadataUpdateHandle - implements ConnectorMetadataUpdateHandle -{ - private final UUID requestId; - private final SchemaTableName schemaTableName; - - // partitionName will be null for un-partitioned tables - private final Optional partitionName; - - // fileName will be null when this class is used to represent metadata request - private final Optional fileName; - - @JsonCreator - @ThriftConstructor - public HiveMetadataUpdateHandle( - @JsonProperty("requestId") UUID requestId, - @JsonProperty("schemaTableName") SchemaTableName schemaTableName, - @JsonProperty("partitionName") Optional partitionName, - @JsonProperty("fileName") Optional fileName) - { - this.requestId = requireNonNull(requestId, "requestId is null"); - this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); 
- this.partitionName = requireNonNull(partitionName, "partitionName is null"); - this.fileName = requireNonNull(fileName, "fileName is null"); - } - - @JsonProperty - @ThriftField(1) - public UUID getRequestId() - { - return requestId; - } - - @JsonProperty - @ThriftField(2) - public SchemaTableName getSchemaTableName() - { - return schemaTableName; - } - - @JsonProperty - @ThriftField(3) - public Optional getPartitionName() - { - return partitionName; - } - - @JsonProperty("fileName") - @ThriftField(value = 4, name = "fileName") - public Optional getMetadataUpdate() - { - return fileName; - } - - @Override - public boolean equals(Object obj) - { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - HiveMetadataUpdateHandle o = (HiveMetadataUpdateHandle) obj; - return requestId.equals(o.requestId) && - schemaTableName.equals(o.schemaTableName) && - partitionName.equals(o.partitionName) && - fileName.equals(o.fileName); - } - - @Override - public int hashCode() - { - return Objects.hash(requestId, schemaTableName, partitionName, fileName); - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdateHandleThriftSerde.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdateHandleThriftSerde.java deleted file mode 100644 index 471f76abe86f3..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdateHandleThriftSerde.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.airlift.http.client.thrift.ThriftProtocolException; -import com.facebook.airlift.http.client.thrift.ThriftProtocolUtils; -import com.facebook.drift.codec.ThriftCodec; -import com.facebook.drift.codec.ThriftCodecManager; -import com.facebook.drift.transport.netty.codec.Protocol; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.ConnectorTypeSerde; -import io.airlift.slice.DynamicSliceOutput; -import io.airlift.slice.SliceOutput; -import io.airlift.slice.Slices; - -import static java.lang.String.format; -import static java.util.Objects.requireNonNull; - -public class HiveMetadataUpdateHandleThriftSerde - implements ConnectorTypeSerde -{ - private final ThriftCodecManager thriftCodecManager; - private final Protocol thriftProtocol; - private final int bufferSize; - - public HiveMetadataUpdateHandleThriftSerde( - ThriftCodecManager thriftCodecManager, - Protocol thriftProtocol, - int bufferSize) - { - this.thriftCodecManager = requireNonNull(thriftCodecManager, "thriftCodecManager is null"); - this.thriftProtocol = requireNonNull(thriftProtocol, "thriftProtocol is null"); - this.bufferSize = bufferSize; - } - - @Override - public byte[] serialize(ConnectorMetadataUpdateHandle value) - { - ThriftCodec codec = thriftCodecManager.getCodec(value.getClass()); - SliceOutput dynamicSliceOutput = new DynamicSliceOutput(bufferSize); - try { - ThriftProtocolUtils.write(value, codec, thriftProtocol, dynamicSliceOutput); - return dynamicSliceOutput.slice().getBytes(); - } - catch (ThriftProtocolException e) { - throw new IllegalArgumentException(format("%s could not be converted to Thrift", value.getClass().getName()), e); - } - } - - @Override - public ConnectorMetadataUpdateHandle deserialize(Class connectorTypeClass, byte[] bytes) - { - try { - ThriftCodec codec = 
thriftCodecManager.getCodec(connectorTypeClass); - return ThriftProtocolUtils.read(codec, thriftProtocol, Slices.wrappedBuffer(bytes).getInput()); - } - catch (ThriftProtocolException e) { - throw new IllegalArgumentException(format("Invalid Thrift bytes for %s", connectorTypeClass), e); - } - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdater.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdater.java deleted file mode 100644 index 35a43c1f2f872..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdater.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.hive; - -import com.facebook.airlift.concurrent.MoreFutures; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.SchemaTableName; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdater; -import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.SettableFuture; - -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Queue; -import java.util.UUID; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.Executor; - -import static java.util.Objects.requireNonNull; - -// TODO: Revisit and make this class more robust -public class HiveMetadataUpdater - implements ConnectorMetadataUpdater -{ - private final Executor boundedExecutor; - - // Stores writerIndex <-> requestId mapping - private final Map writerRequestMap = new ConcurrentHashMap<>(); - - // Stores requestId <-> fileNameFuture mapping - private final Map> requestFutureMap = new ConcurrentHashMap<>(); - - // Queue of pending requests - private final Queue hiveMetadataRequestQueue = new ConcurrentLinkedQueue<>(); - - HiveMetadataUpdater(Executor boundedExecutor) - { - this.boundedExecutor = requireNonNull(boundedExecutor, "boundedExecutor is null"); - } - - @Override - public List getPendingMetadataUpdateRequests() - { - ImmutableList.Builder result = ImmutableList.builder(); - for (HiveMetadataUpdateHandle request : hiveMetadataRequestQueue) { - result.add(request); - } - return result.build(); - } - - @Override - public void setMetadataUpdateResults(List results) - { - boundedExecutor.execute(() -> updateResultAsync(results)); - } - - private void updateResultAsync(List results) - { - for (ConnectorMetadataUpdateHandle connectorMetadataUpdateHandle : results) { - HiveMetadataUpdateHandle updateResult = (HiveMetadataUpdateHandle) 
connectorMetadataUpdateHandle; - UUID requestId = updateResult.getRequestId(); - - if (!requestFutureMap.containsKey(requestId)) { - continue; - } - - Optional fileName = updateResult.getMetadataUpdate(); - if (fileName.isPresent()) { - // remove the request from queue - hiveMetadataRequestQueue.removeIf(metadataUpdateRequest -> metadataUpdateRequest.getRequestId().equals(requestId)); - - // Set the fileName future - requestFutureMap.get(requestId).set(fileName.get()); - } - } - } - - public void addMetadataUpdateRequest(String schemaName, String tableName, Optional partitionName, int writerIndex) - { - UUID requestId = UUID.randomUUID(); - requestFutureMap.put(requestId, SettableFuture.create()); - writerRequestMap.put(writerIndex, requestId); - - // create a request and add it to the queue - hiveMetadataRequestQueue.add(new HiveMetadataUpdateHandle(requestId, new SchemaTableName(schemaName, tableName), partitionName, Optional.empty())); - } - - public void removeResultFuture(int writerIndex) - { - UUID requestId = writerRequestMap.get(writerIndex); - requestFutureMap.remove(requestId); - writerRequestMap.remove(writerIndex); - } - - public CompletableFuture getMetadataResult(int writerIndex) - { - UUID requestId = writerRequestMap.get(writerIndex); - return MoreFutures.toCompletableFuture(requestFutureMap.get(requestId)); - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdaterProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdaterProvider.java deleted file mode 100644 index 5ce8ddc9f72e6..0000000000000 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadataUpdaterProvider.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.airlift.concurrent.BoundedExecutor; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdater; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdaterProvider; - -import javax.annotation.PreDestroy; -import javax.inject.Inject; - -import java.util.concurrent.Executor; -import java.util.concurrent.ExecutorService; - -import static java.util.Objects.requireNonNull; - -public class HiveMetadataUpdaterProvider - implements ConnectorMetadataUpdaterProvider -{ - private final ExecutorService executorService; - private final Executor boundedExecutor; - - @Inject - public HiveMetadataUpdaterProvider( - @ForUpdatingHiveMetadata ExecutorService executorService, - HiveClientConfig hiveClientConfig) - { - this.executorService = requireNonNull(executorService, "executorService is null"); - int maxMetadataUpdaterThreads = requireNonNull(hiveClientConfig, "hiveClientConfig is null").getMaxMetadataUpdaterThreads(); - this.boundedExecutor = new BoundedExecutor(executorService, maxMetadataUpdaterThreads); - } - - @Override - public ConnectorMetadataUpdater getMetadataUpdater() - { - return new HiveMetadataUpdater(boundedExecutor); - } - - @PreDestroy - public void stop() - { - executorService.shutdownNow(); - } -} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSink.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSink.java index 77e8a163f51ef..f551fda1a21d0 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSink.java +++ 
b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSink.java @@ -22,7 +22,6 @@ import com.facebook.presto.common.block.IntArrayBlockBuilder; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; -import com.facebook.presto.hive.filesystem.ExtendedFileSystem; import com.facebook.presto.spi.ConnectorPageSink; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PageIndexer; @@ -33,13 +32,10 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; -import com.google.common.util.concurrent.SettableFuture; import io.airlift.slice.Slice; import it.unimi.dsi.fastutil.objects.Object2IntMap; import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; -import org.apache.hadoop.fs.Path; -import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -49,24 +45,18 @@ import java.util.OptionalInt; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executors; -import static com.facebook.airlift.concurrent.MoreFutures.addSuccessCallback; -import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.hive.HiveBucketFunction.createHiveCompatibleBucketFunction; import static com.facebook.presto.hive.HiveBucketFunction.createPrestoNativeBucketFunction; -import static com.facebook.presto.hive.HiveErrorCode.HIVE_FILESYSTEM_ERROR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_TOO_MANY_OPEN_PARTITIONS; import static com.facebook.presto.hive.HiveErrorCode.HIVE_WRITER_CLOSE_ERROR; import static com.facebook.presto.hive.HiveSessionProperties.isFileRenamingEnabled; import static 
com.facebook.presto.hive.HiveSessionProperties.isLegacyTimestampBucketing; import static com.facebook.presto.hive.HiveSessionProperties.isOptimizedPartitionUpdateSerializationEnabled; import static com.facebook.presto.hive.HiveUtil.serializeZstdCompressed; -import static com.facebook.presto.hive.PartitionUpdate.FileWriteInfo; import static com.facebook.presto.hive.PartitionUpdate.mergePartitionUpdates; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; @@ -105,7 +95,6 @@ public class HivePageSink private final List writers = new ArrayList<>(); private final ConnectorSession session; - private final HiveMetadataUpdater hiveMetadataUpdater; private final boolean fileRenamingEnabled; private long writtenBytes; @@ -127,8 +116,7 @@ public HivePageSink( ListeningExecutorService writeVerificationExecutor, JsonCodec partitionUpdateCodec, SmileCodec partitionUpdateSmileCodec, - ConnectorSession session, - HiveMetadataUpdater hiveMetadataUpdater) + ConnectorSession session) { this.writerFactory = requireNonNull(writerFactory, "writerFactory is null"); @@ -199,7 +187,6 @@ public HivePageSink( } this.session = requireNonNull(session, "session is null"); - this.hiveMetadataUpdater = requireNonNull(hiveMetadataUpdater, "hiveMetadataUpdater is null"); this.fileRenamingEnabled = isFileRenamingEnabled(session); } @@ -281,20 +268,6 @@ private ListenableFuture> doFinish() .mapToLong(HiveWriter::getValidationCpuNanos) .sum(); - if (waitForFileRenaming && verificationTasks.isEmpty()) { - // Use CopyOnWriteArrayList to prevent race condition when callbacks try to add partitionUpdates to this list - List partitionUpdatesWithRenamedFileNames = new CopyOnWriteArrayList<>(); - List> futures = new ArrayList<>(); - for (int i = 0; i < writers.size(); i++) { - int writerIndex = i; 
- ListenableFuture fileNameFuture = toListenableFuture(hiveMetadataUpdater.getMetadataResult(writerIndex)); - SettableFuture renamingFuture = SettableFuture.create(); - futures.add(renamingFuture); - addSuccessCallback(fileNameFuture, obj -> renameFiles((String) obj, writerIndex, renamingFuture, partitionUpdatesWithRenamedFileNames)); - } - return Futures.transform(Futures.allAsList(futures), input -> partitionUpdatesWithRenamedFileNames, directExecutor()); - } - if (verificationTasks.isEmpty()) { return Futures.immediateFuture(serializedPartitionUpdates); } @@ -414,62 +387,6 @@ private void writePage(Page page) } } - private void sendMetadataUpdateRequest(Optional partitionName, int writerIndex, boolean writeTempData) - { - // Bucketed tables already have unique bucket number as part of fileName. So no need to rename. - if (writeTempData || !fileRenamingEnabled || bucketFunction != null) { - return; - } - hiveMetadataUpdater.addMetadataUpdateRequest(schemaName, tableName, partitionName, writerIndex); - waitForFileRenaming = true; - } - - private void renameFiles(String fileName, int writerIndex, SettableFuture renamingFuture, List partitionUpdatesWithRenamedFileNames) - { - HdfsContext context = new HdfsContext( - session, - schemaName, - tableName, - writerFactory.getLocationHandle().getTargetPath().toString(), - writerFactory.isCreateTable()); - HiveWriter writer = writers.get(writerIndex); - PartitionUpdate partitionUpdate = writer.getPartitionUpdate(); - - // Check that only one file is written by a writer - checkArgument(partitionUpdate.getFileWriteInfos().size() == 1, "HiveWriter wrote data to more than one file"); - - FileWriteInfo fileWriteInfo = partitionUpdate.getFileWriteInfos().get(0); - Path fromPath = new Path(partitionUpdate.getWritePath(), fileWriteInfo.getWriteFileName()); - Path toPath = new Path(partitionUpdate.getWritePath(), fileName); - try { - ExtendedFileSystem fileSystem = hdfsEnvironment.getFileSystem(context, fromPath); - 
ListenableFuture asyncFuture = fileSystem.renameFileAsync(fromPath, toPath); - addSuccessCallback(asyncFuture, () -> updateFileInfo(partitionUpdatesWithRenamedFileNames, renamingFuture, partitionUpdate, fileName, fileWriteInfo, writerIndex)); - } - catch (IOException e) { - throw new PrestoException(HIVE_FILESYSTEM_ERROR, format("Error renaming file. fromPath: %s toPath: %s", fromPath, toPath), e); - } - } - - private void updateFileInfo(List partitionUpdatesWithRenamedFileNames, SettableFuture renamingFuture, PartitionUpdate partitionUpdate, String fileName, FileWriteInfo fileWriteInfo, int writerIndex) - { - // Update the file info in partitionUpdate with new filename - FileWriteInfo fileInfoWithRenamedFileName = new FileWriteInfo(fileName, fileName, fileWriteInfo.getFileSize()); - PartitionUpdate partitionUpdateWithRenamedFileName = new PartitionUpdate(partitionUpdate.getName(), - partitionUpdate.getUpdateMode(), - partitionUpdate.getWritePath(), - partitionUpdate.getTargetPath(), - ImmutableList.of(fileInfoWithRenamedFileName), - partitionUpdate.getRowCount(), - partitionUpdate.getInMemoryDataSizeInBytes(), - partitionUpdate.getOnDiskDataSizeInBytes(), - true); - partitionUpdatesWithRenamedFileNames.add(wrappedBuffer(partitionUpdateCodec.toJsonBytes(partitionUpdateWithRenamedFileName))); - - hiveMetadataUpdater.removeResultFuture(writerIndex); - renamingFuture.set(null); - } - private int[] getWriterIndexes(Page page) { Page partitionColumns = extractColumns(page, partitionColumnsInputIndex); @@ -497,9 +414,6 @@ private int[] getWriterIndexes(Page page) } HiveWriter writer = writerFactory.createWriter(partitionColumns, position, bucketNumber); writers.set(writerIndex, writer); - - // Send metadata update request if needed - sendMetadataUpdateRequest(writer.getPartitionName(), writerIndex, writer.isWriteTempData()); } verify(writers.size() == pagePartitioner.getMaxIndex() + 1); verify(!writers.contains(null)); diff --git 
a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSinkProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSinkProvider.java index 01800fbed2e13..07d37e4671adf 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSinkProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSinkProvider.java @@ -16,6 +16,7 @@ import com.facebook.airlift.event.client.EventClient; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.smile.SmileCodec; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.HivePageSinkMetadataProvider; @@ -29,19 +30,15 @@ import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSinkContext; import com.facebook.presto.spi.PageSorter; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdater; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListeningExecutorService; -import io.airlift.units.DataSize; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -49,7 +46,6 @@ import static com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore.memoizeMetastore; import static com.facebook.presto.hive.metastore.MetastoreUtil.getMetastoreHeaders; import static com.facebook.presto.hive.metastore.MetastoreUtil.isUserDefinedTypeEncodingEnabled; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; import static 
java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newFixedThreadPool; @@ -133,21 +129,17 @@ public HivePageSinkProvider( public ConnectorPageSink createPageSink(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorOutputTableHandle tableHandle, PageSinkContext pageSinkContext) { HiveOutputTableHandle handle = (HiveOutputTableHandle) tableHandle; - Optional hiveMetadataUpdater = pageSinkContext.getMetadataUpdater(); - checkArgument(hiveMetadataUpdater.isPresent(), "Metadata Updater for HivePageSink is null"); - return createPageSink(handle, true, session, (HiveMetadataUpdater) hiveMetadataUpdater.get(), pageSinkContext.isCommitRequired(), handle.getAdditionalTableParameters()); + return createPageSink(handle, true, session, pageSinkContext.isCommitRequired(), handle.getAdditionalTableParameters()); } @Override public ConnectorPageSink createPageSink(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorInsertTableHandle tableHandle, PageSinkContext pageSinkContext) { HiveInsertTableHandle handle = (HiveInsertTableHandle) tableHandle; - Optional hiveMetadataUpdater = pageSinkContext.getMetadataUpdater(); - checkArgument(hiveMetadataUpdater.isPresent(), "Metadata Updater for HivePageSink is null"); - return createPageSink(handle, false, session, (HiveMetadataUpdater) hiveMetadataUpdater.get(), pageSinkContext.isCommitRequired(), ImmutableMap.of()); + return createPageSink(handle, false, session, pageSinkContext.isCommitRequired(), ImmutableMap.of()); } - private ConnectorPageSink createPageSink(HiveWritableTableHandle handle, boolean isCreateTable, ConnectorSession session, HiveMetadataUpdater hiveMetadataUpdater, boolean commitRequired, Map additionalTableParameters) + private ConnectorPageSink createPageSink(HiveWritableTableHandle handle, boolean isCreateTable, ConnectorSession session, boolean commitRequired, Map additionalTableParameters) { OptionalInt bucketCount = OptionalInt.empty(); List 
sortedBy; @@ -208,7 +200,6 @@ private ConnectorPageSink createPageSink(HiveWritableTableHandle handle, boolean writeVerificationExecutor, partitionUpdateCodec, partitionUpdateSmileCodec, - session, - hiveMetadataUpdater); + session); } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSource.java index 63c78ded9b53f..2a8ab38f2e422 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSource.java @@ -27,11 +27,10 @@ import com.google.common.annotations.VisibleForTesting; import io.airlift.slice.Slices; import it.unimi.dsi.fastutil.ints.IntArrayList; +import jakarta.annotation.Nullable; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.joda.time.DateTimeZone; -import javax.annotation.Nullable; - import java.io.IOException; import java.io.UncheckedIOException; import java.nio.file.Paths; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java index 85d00b8b9571c..1d8b557cf7dc0 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePageSourceProvider.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.Subfield.NestedField; @@ -45,13 +46,11 @@ import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.joda.time.DateTimeZone; -import 
javax.inject.Inject; - import java.util.HashSet; import java.util.List; import java.util.Map; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionKey.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionKey.java index 6fbe2472da620..fd8564919d30e 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionKey.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionKey.java @@ -15,10 +15,9 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Objects; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionManager.java b/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionManager.java index f6170becee95c..0829404178f31 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionManager.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HivePartitionManager.java @@ -36,6 +36,9 @@ import com.facebook.presto.spi.TableNotFoundException; import com.google.common.base.Predicates; import com.google.common.base.VerifyException; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; @@ -43,12 +46,11 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; +import jakarta.inject.Inject; import org.joda.time.DateTimeZone; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ 
-69,8 +71,10 @@ import static com.facebook.presto.hive.HiveErrorCode.HIVE_EXCEEDED_PARTITION_LIMIT; import static com.facebook.presto.hive.HiveSessionProperties.getMaxBucketsForGroupedExecution; import static com.facebook.presto.hive.HiveSessionProperties.getMinBucketCountToNotIgnoreTableBucketing; +import static com.facebook.presto.hive.HiveSessionProperties.getOptimizeParsingOfPartitionValuesThreshold; import static com.facebook.presto.hive.HiveSessionProperties.isLegacyTimestampBucketing; import static com.facebook.presto.hive.HiveSessionProperties.isOfflineDataDebugModeEnabled; +import static com.facebook.presto.hive.HiveSessionProperties.isOptimizeParsingOfPartitionValues; import static com.facebook.presto.hive.HiveSessionProperties.isParallelParsingOfPartitionValuesEnabled; import static com.facebook.presto.hive.HiveSessionProperties.shouldIgnoreTableBucketing; import static com.facebook.presto.hive.HiveUtil.getPartitionKeyColumnHandles; @@ -86,6 +90,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Predicates.not; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; import static java.lang.String.format; @@ -180,7 +185,7 @@ public List getPartitionsList( ConcurrentLinkedQueue result = new ConcurrentLinkedQueue<>(); List> futures = new ArrayList<>(); try { - partitionNameBatches.forEach(batch -> futures.add(executorService.submit(() -> result.addAll(getPartitionListFromPartitionNames(batch, tableName, partitionColumns, partitionTypes, constraint))))); + partitionNameBatches.forEach(batch -> futures.add(executorService.submit(() -> result.addAll(getPartitionListFromPartitionNames(batch, tableName, partitionColumns, partitionTypes, constraint, session))))); 
Futures.transform(Futures.allAsList(futures), input -> result, directExecutor()).get(); return Arrays.asList(result.toArray(new HivePartition[0])); } @@ -188,7 +193,7 @@ public List getPartitionsList( log.error(e, "Parallel parsing of partition values failed"); } } - return getPartitionListFromPartitionNames(partitionNames, tableName, partitionColumns, partitionTypes, constraint); + return getPartitionListFromPartitionNames(partitionNames, tableName, partitionColumns, partitionTypes, constraint, session); } } @@ -197,14 +202,61 @@ private List getPartitionListFromPartitionNames( SchemaTableName tableName, List partitionColumns, List partitionTypes, - Constraint constraint) + Constraint constraint, + ConnectorSession session) { - return partitionNames.stream() - // Apply extra filters which could not be done by getFilteredPartitionNames - .map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionTypes, constraint)) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(toImmutableList()); + if (isOptimizeParsingOfPartitionValues(session) && partitionNames.size() >= getOptimizeParsingOfPartitionValuesThreshold(session)) { + List partitionList = partitionNames.stream() + .map(partitionNameWithVersion -> parsePartition(tableName, partitionNameWithVersion, partitionColumns, partitionTypes, timeZone)) + .collect(toImmutableList()); + + Map domains = constraint.getSummary().getDomains().get(); + List usedPartitionColumnsInDomain = partitionColumns.stream().filter(x -> domains.containsKey(x)).collect(toImmutableList()); + partitionList = partitionList.stream().filter( + partition -> { + for (HiveColumnHandle column : usedPartitionColumnsInDomain) { + NullableValue value = partition.getKeys().get(column); + Domain allowedDomain = domains.get(column); + if (allowedDomain != null && !allowedDomain.includesNullableValue(value.getValue())) { + return false; + } + } + return true; + }).collect(toImmutableList()); + + if 
(!constraint.predicate().isPresent()) { + return partitionList; + } + Optional> predicateInputs = constraint.getPredicateInputs(); + if (!predicateInputs.isPresent()) { + return partitionList.stream().filter(x -> constraint.predicate().get().test(x.getKeys())).collect(toList()); + } + List usedPartitionColumnsInPredicate = partitionColumns.stream().filter(x -> predicateInputs.get().contains(x)).collect(toList()); + if (usedPartitionColumnsInPredicate.size() == partitionColumns.size()) { + return partitionList.stream().filter(x -> constraint.predicate().get().test(x.getKeys())).collect(toList()); + } + + ImmutableList.Builder resultBuilder = ImmutableList.builder(); + LoadingCache, Boolean> cacheResult = CacheBuilder.newBuilder() + .maximumSize(10_000) + .build(CacheLoader.from(cacheKey -> constraint.predicate().get().test(cacheKey))); + for (HivePartition partition : partitionList) { + Map filteredMap = partition.getKeys().entrySet().stream().filter(x -> usedPartitionColumnsInPredicate.contains(x.getKey())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + if (cacheResult.getUnchecked(filteredMap)) { + resultBuilder.add(partition); + } + } + return resultBuilder.build(); + } + else { + return partitionNames.stream() + // Apply extra filters which could not be done by getFilteredPartitionNames + .map(partitionName -> parseValuesAndFilterPartition(tableName, partitionName, partitionColumns, partitionTypes, constraint)) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(toImmutableList()); + } } private Map createPartitionPredicates( diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java index 12249c8dd5186..4879d5cb0d4ad 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSessionProperties.java @@ -13,17 +13,16 @@ */ 
package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.cache.CacheConfig; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import jakarta.inject.Inject; import org.apache.parquet.column.ParquetProperties; -import javax.inject.Inject; - import java.util.List; import java.util.Optional; import java.util.OptionalInt; @@ -60,7 +59,7 @@ public final class HiveSessionProperties private static final String ORC_OPTIMIZED_WRITER_COMPRESSION_LEVEL = "orc_optimized_writer_compression_level"; private static final String PAGEFILE_WRITER_MAX_STRIPE_SIZE = "pagefile_writer_max_stripe_size"; public static final String HIVE_STORAGE_FORMAT = "hive_storage_format"; - private static final String COMPRESSION_CODEC = "compression_codec"; + static final String COMPRESSION_CODEC = "compression_codec"; private static final String ORC_COMPRESSION_CODEC = "orc_compression_codec"; public static final String RESPECT_TABLE_FORMAT = "respect_table_format"; private static final String CREATE_EMPTY_BUCKET_FILES = "create_empty_bucket_files"; @@ -70,6 +69,7 @@ public final class HiveSessionProperties private static final String PARQUET_WRITER_VERSION = "parquet_writer_version"; private static final String MAX_SPLIT_SIZE = "max_split_size"; private static final String MAX_INITIAL_SPLIT_SIZE = "max_initial_split_size"; + private static final String SYMLINK_OPTIMIZED_READER_ENABLED = "symlink_optimized_reader_enabled"; public static final String RCFILE_OPTIMIZED_WRITER_ENABLED = "rcfile_optimized_writer_enabled"; private static final String RCFILE_OPTIMIZED_WRITER_VALIDATE = "rcfile_optimized_writer_validate"; private static final String SORTED_WRITING_ENABLED = "sorted_writing_enabled"; 
@@ -134,6 +134,8 @@ public final class HiveSessionProperties public static final String DYNAMIC_SPLIT_SIZES_ENABLED = "dynamic_split_sizes_enabled"; public static final String SKIP_EMPTY_FILES = "skip_empty_files"; public static final String LEGACY_TIMESTAMP_BUCKETING = "legacy_timestamp_bucketing"; + public static final String OPTIMIZE_PARSING_OF_PARTITION_VALUES = "optimize_parsing_of_partition_values"; + public static final String OPTIMIZE_PARSING_OF_PARTITION_VALUES_THRESHOLD = "optimize_parsing_of_partition_values_threshold"; public static final String NATIVE_STATS_BASED_FILTER_REORDER_DISABLED = "native_stats_based_filter_reorder_disabled"; @@ -625,6 +627,11 @@ public HiveSessionProperties(HiveClientConfig hiveClientConfig, OrcFileWriterCon "Use quick stats to resolve stats", hiveClientConfig.isQuickStatsEnabled(), false), + booleanProperty( + SYMLINK_OPTIMIZED_READER_ENABLED, + "Experimental: Enable optimized SymlinkTextInputFormat reader", + hiveClientConfig.isSymlinkOptimizedReaderEnabled(), + false), new PropertyMetadata<>( QUICK_STATS_INLINE_BUILD_TIMEOUT, "Duration that the first query that initiated a quick stats call should wait before failing and returning EMPTY stats. 
" + @@ -656,6 +663,15 @@ public HiveSessionProperties(HiveClientConfig hiveClientConfig, OrcFileWriterCon "Use legacy timestamp bucketing algorithm (which is not Hive compatible) for table bucketed by timestamp type.", hiveClientConfig.isLegacyTimestampBucketing(), false), + booleanProperty( + OPTIMIZE_PARSING_OF_PARTITION_VALUES, + "Optimize partition values parsing when number of candidates are large", + hiveClientConfig.isOptimizeParsingOfPartitionValues(), + false), + integerProperty(OPTIMIZE_PARSING_OF_PARTITION_VALUES_THRESHOLD, + "When OPTIMIZE_PARSING_OF_PARTITION_VALUES is set to true, enable this optimizations when number of partitions exceed the threshold here", + hiveClientConfig.getOptimizeParsingOfPartitionValuesThreshold(), + false), booleanProperty( NATIVE_STATS_BASED_FILTER_REORDER_DISABLED, "Native Execution only. Disable stats based filter reordering.", @@ -1148,4 +1164,19 @@ public static boolean isLegacyTimestampBucketing(ConnectorSession session) { return session.getProperty(LEGACY_TIMESTAMP_BUCKETING, Boolean.class); } + + public static boolean isOptimizeParsingOfPartitionValues(ConnectorSession session) + { + return session.getProperty(OPTIMIZE_PARSING_OF_PARTITION_VALUES, Boolean.class); + } + + public static int getOptimizeParsingOfPartitionValuesThreshold(ConnectorSession session) + { + return session.getProperty(OPTIMIZE_PARSING_OF_PARTITION_VALUES_THRESHOLD, Integer.class); + } + + public static boolean isSymlinkOptimizedReaderEnabled(ConnectorSession session) + { + return session.getProperty(SYMLINK_OPTIMIZED_READER_ENABLED, Boolean.class); + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java index 62e620bee3e2f..b7a34e746a50b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitManager.java @@ -15,6 +15,7 @@ import 
com.facebook.airlift.concurrent.BoundedExecutor; import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.predicate.Range; @@ -57,7 +58,7 @@ import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Ordering; -import io.airlift.units.DataSize; +import jakarta.inject.Inject; import org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hudi.hadoop.HoodieParquetInputFormat; @@ -65,8 +66,6 @@ import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -260,6 +259,7 @@ public ConnectorSplitSource getSplits( // get partitions List partitions = layout.getPartitions() + .map(PartitionSet::getFullyLoadedPartitions) .orElseThrow(() -> new PrestoException(GENERIC_INTERNAL_ERROR, "Layout does not contain partitions")); // short circuit if we don't have any partitions diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitSource.java index 98843199442fb..6264ea5eab69e 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveSplitSource.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.InternalHiveSplit.InternalHiveBlock; import com.facebook.presto.hive.util.AsyncQueue; import com.facebook.presto.hive.util.AsyncQueue.BorrowResult; @@ -30,9 +31,7 @@ import com.google.common.util.concurrent.Futures; import 
com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.io.FileNotFoundException; import java.util.ArrayList; @@ -52,6 +51,7 @@ import static com.facebook.airlift.concurrent.MoreFutures.failedFuture; import static com.facebook.airlift.concurrent.MoreFutures.toCompletableFuture; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.hive.HiveCommonSessionProperties.getAffinitySchedulingFileSectionSize; import static com.facebook.presto.hive.HiveErrorCode.HIVE_EXCEEDED_SPLIT_BUFFERING_LIMIT; import static com.facebook.presto.hive.HiveErrorCode.HIVE_FILE_NOT_FOUND; @@ -70,7 +70,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.String.format; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveStagingFileCommitter.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveStagingFileCommitter.java index ecc203a7ebde7..e53042ee7cef9 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveStagingFileCommitter.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveStagingFileCommitter.java @@ -17,11 +17,10 @@ import com.facebook.presto.spi.ConnectorSession; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; - import java.util.ArrayList; import 
java.util.List; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableLayoutHandle.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableLayoutHandle.java index 9b882b7685078..7ae9306aec376 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableLayoutHandle.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableLayoutHandle.java @@ -65,7 +65,6 @@ public class HiveTableLayoutHandle private final boolean footerStatsUnreliable; // coordinator-only properties - private final Optional> partitions; private final Optional hiveTableHandle; /** @@ -143,7 +142,7 @@ protected HiveTableLayoutHandle( remainingPredicate, pushdownFilterEnabled, partitionColumnPredicate, - partitions); + partitions.map(PartitionSet::new)); this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); this.tablePath = requireNonNull(tablePath, "tablePath is null"); @@ -165,7 +164,6 @@ else if (predicateColumns.values().stream().anyMatch(column -> isRowIdColumnHand this.appendRowId = false; } this.appendRowNumberEnabled = appendRowNumberEnabled; - this.partitions = requireNonNull(partitions, "partitions is null"); this.footerStatsUnreliable = footerStatsUnreliable; this.hiveTableHandle = requireNonNull(hiveTableHandle, "hiveTableHandle is null"); } @@ -300,7 +298,7 @@ private TupleDomain getConstraint(PlanCanonicalizationStrategy can // Constants are only removed from point checks, and not range checks. Example: // `x = 1` is equivalent to `x = 1000` // `x > 1` is NOT equivalent to `x > 1000` - TupleDomain constraint = createPredicate(ImmutableList.copyOf(getPartitionColumns()), partitions.get()); + TupleDomain constraint = createPredicate(ImmutableList.copyOf(getPartitionColumns()), getPartitions().map(PartitionSet::getFullyLoadedPartitions).get()); constraint = getDomainPredicate() .transform(subfield -> subfield.getPath().isEmpty() ? 
subfield.getRootName() : null) .transform(getPredicateColumns()::get) @@ -363,7 +361,7 @@ public Builder builder() .setRequestedColumns(getRequestedColumns()) .setPartialAggregationsPushedDown(isPartialAggregationsPushedDown()) .setAppendRowNumberEnabled(isAppendRowNumberEnabled()) - .setPartitions(getPartitions()) + .setPartitions(getPartitions().map(PartitionSet::getFullyLoadedPartitions)) .setFooterStatsUnreliable(isFooterStatsUnreliable()) .setHiveTableHandle(getHiveTableHandle()); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableProperties.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableProperties.java index 88b5d9bb88bf6..5c7eb8153a034 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableProperties.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableProperties.java @@ -19,8 +19,7 @@ import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; @@ -58,6 +57,8 @@ public class HiveTableProperties public static final String CSV_SEPARATOR = "csv_separator"; public static final String CSV_QUOTE = "csv_quote"; public static final String CSV_ESCAPE = "csv_escape"; + public static final String SKIP_HEADER_LINE_COUNT = "skip_header_line_count"; + public static final String SKIP_FOOTER_LINE_COUNT = "skip_footer_line_count"; private final List> tableProperties; @@ -156,6 +157,8 @@ public HiveTableProperties(TypeManager typeManager, HiveClientConfig config) stringProperty(CSV_SEPARATOR, "CSV separator character", null, false), stringProperty(CSV_QUOTE, "CSV quote character", null, false), stringProperty(CSV_ESCAPE, "CSV escape character", null, false), + integerProperty(SKIP_HEADER_LINE_COUNT, "Number of header lines", null, false), + integerProperty(SKIP_FOOTER_LINE_COUNT, "Number of 
footer lines", null, false), new PropertyMetadata<>( ENCRYPT_COLUMNS, "List of key references and columns being encrypted. Example: ARRAY['key1:col1,col2', 'key2:col3,col4']", @@ -291,4 +294,14 @@ public static ColumnEncryptionInformation getEncryptColumns(Map return tableProperties.containsKey(ENCRYPT_COLUMNS) ? (ColumnEncryptionInformation) tableProperties.get(ENCRYPT_COLUMNS) : ColumnEncryptionInformation.fromMap(ImmutableMap.of()); } + + public static Optional getHeaderSkipCount(Map tableProperties) + { + return Optional.ofNullable((Integer) tableProperties.get(SKIP_HEADER_LINE_COUNT)); + } + + public static Optional getFooterSkipCount(Map tableProperties) + { + return Optional.ofNullable((Integer) tableProperties.get(SKIP_FOOTER_LINE_COUNT)); + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableWritabilityChecker.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableWritabilityChecker.java index c83b20b12e3c3..9c6b1fec98190 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableWritabilityChecker.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveTableWritabilityChecker.java @@ -16,8 +16,7 @@ import com.facebook.presto.hive.metastore.PrestoTableType; import com.facebook.presto.hive.metastore.Table; import com.facebook.presto.spi.PrestoException; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveTransactionManager.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveTransactionManager.java index 6b2c20871154d..90adfd6eb0144 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveTransactionManager.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveTransactionManager.java @@ -15,8 +15,7 @@ import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; 
+import jakarta.inject.Inject; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java index b9c79fb6b4370..1f255ec03e01b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveUtil.java @@ -28,7 +28,9 @@ import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.hadoop.TextLineLengthLimitExceededException; import com.facebook.presto.hive.avro.PrestoAvroSerDe; +import com.facebook.presto.hive.filesystem.ExtendedFileSystem; import com.facebook.presto.hive.metastore.Column; +import com.facebook.presto.hive.metastore.Partition; import com.facebook.presto.hive.metastore.Storage; import com.facebook.presto.hive.metastore.Table; import com.facebook.presto.hive.pagefile.PageInputFormat; @@ -51,6 +53,7 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceUtf8; import io.airlift.slice.Slices; +import jakarta.annotation.Nullable; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -89,12 +92,12 @@ import org.joda.time.format.DateTimePrinter; import org.joda.time.format.ISODateTimeFormat; -import javax.annotation.Nullable; - +import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.io.InputStreamReader; import java.io.UncheckedIOException; import java.lang.annotation.Annotation; import java.lang.reflect.Field; @@ -105,6 +108,8 @@ import java.time.ZoneId; import java.util.Arrays; import java.util.Base64; +import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -145,12 +150,15 @@ import static 
com.facebook.presto.hive.HiveColumnHandle.rowIdColumnHandle; import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA; import static com.facebook.presto.hive.HiveErrorCode.HIVE_CANNOT_OPEN_SPLIT; +import static com.facebook.presto.hive.HiveErrorCode.HIVE_FILE_NOT_FOUND; import static com.facebook.presto.hive.HiveErrorCode.HIVE_INVALID_METADATA; import static com.facebook.presto.hive.HiveErrorCode.HIVE_INVALID_PARTITION_VALUE; import static com.facebook.presto.hive.HiveErrorCode.HIVE_INVALID_VIEW_DATA; import static com.facebook.presto.hive.HiveErrorCode.HIVE_SERDE_NOT_FOUND; import static com.facebook.presto.hive.HiveErrorCode.HIVE_TABLE_BUCKETING_IS_IGNORED; import static com.facebook.presto.hive.HiveErrorCode.HIVE_UNSUPPORTED_FORMAT; +import static com.facebook.presto.hive.HiveSessionProperties.isUseListDirectoryCache; +import static com.facebook.presto.hive.HiveStorageFormat.TEXTFILE; import static com.facebook.presto.hive.metastore.MetastoreUtil.HIVE_DEFAULT_DYNAMIC_PARTITION; import static com.facebook.presto.hive.metastore.MetastoreUtil.checkCondition; import static com.facebook.presto.hive.metastore.MetastoreUtil.getMetastoreHeaders; @@ -165,6 +173,7 @@ import static com.google.common.collect.Iterables.filter; import static com.google.common.collect.Lists.newArrayList; import static com.google.common.collect.Lists.transform; +import static com.google.common.io.CharStreams.readLines; import static java.lang.Byte.parseByte; import static java.lang.Double.parseDouble; import static java.lang.Float.floatToRawIntBits; @@ -178,6 +187,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; +import static org.apache.hadoop.fs.Path.getPathWithoutSchemeAndAuthority; import static org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.FILE_INPUT_FORMAT; import static org.apache.hadoop.hive.serde.serdeConstants.COLLECTION_DELIM; import static 
org.apache.hadoop.hive.serde.serdeConstants.DECIMAL_TYPE_NAME; @@ -263,7 +273,7 @@ private HiveUtil() // Only propagate serialization schema configs by default Predicate schemaFilter = schemaProperty -> schemaProperty.startsWith("serialization."); - InputFormat inputFormat = getInputFormat(configuration, getInputFormatName(schema), true); + InputFormat inputFormat = getInputFormat(configuration, getInputFormatName(schema), getDeserializerClassName(schema), true); JobConf jobConf = toJobConf(configuration); FileSplit fileSplit = new FileSplit(path, start, length, (String[]) null); if (!customSplitInfo.isEmpty()) { @@ -346,15 +356,39 @@ public static Optional getCompressionCodec(TextInputFormat inp return Optional.ofNullable(compressionCodecFactory.getCodec(file)); } - public static InputFormat getInputFormat(Configuration configuration, String inputFormatName, boolean symlinkTarget) + public static InputFormat getInputFormat(Configuration configuration, String inputFormatName, String serDe, boolean symlinkTarget) { try { JobConf jobConf = toJobConf(configuration); Class> inputFormatClass = getInputFormatClass(jobConf, inputFormatName); if (symlinkTarget && (inputFormatClass == SymlinkTextInputFormat.class)) { - // symlink targets are always TextInputFormat - inputFormatClass = TextInputFormat.class; + if (serDe == null) { + throw new PrestoException(HIVE_UNSUPPORTED_FORMAT, "Missing SerDe for SymlinkTextInputFormat"); + } + + /* + * https://github.com/apache/hive/blob/b240eb3266d4736424678d6c71c3c6f6a6fdbf38/ql/src/java/org/apache/hadoop/hive/ql/io/SymlinkTextInputFormat.java#L47-L52 + * According to Hive implementation of SymlinkInputFormat, The target input data should be in TextInputFormat. + * + * But Delta Lake provides an integration with Presto using Symlink Tables with target input data as MapredParquetInputFormat. 
+ * https://docs.delta.io/latest/presto-integration.html + * + * To comply with Hive implementation, we will keep the default value here as TextInputFormat unless serde is not LazySimpleSerDe + */ + if (serDe.equals(TEXTFILE.getSerDe())) { + inputFormatClass = TextInputFormat.class; + return ReflectionUtils.newInstance(inputFormatClass, jobConf); + } + + for (HiveStorageFormat hiveStorageFormat : HiveStorageFormat.values()) { + if (serDe.equals(hiveStorageFormat.getSerDe())) { + inputFormatClass = getInputFormatClass(jobConf, hiveStorageFormat.getInputFormat()); + return ReflectionUtils.newInstance(inputFormatClass, jobConf); + } + } + + throw new PrestoException(HIVE_UNSUPPORTED_FORMAT, format("Unsupported SerDe for SymlinkTextInputFormat: %s", serDe)); } return ReflectionUtils.newInstance(inputFormatClass, jobConf); @@ -395,7 +429,7 @@ static boolean shouldUseRecordReaderFromInputFormat(Configuration configuration, return false; } - InputFormat inputFormat = HiveUtil.getInputFormat(configuration, storage.getStorageFormat().getInputFormat(), false); + InputFormat inputFormat = HiveUtil.getInputFormat(configuration, storage.getStorageFormat().getInputFormat(), storage.getStorageFormat().getSerDe(), false); return Arrays.stream(inputFormat.getClass().getAnnotations()) .map(Annotation::annotationType) .map(Class::getSimpleName) @@ -1311,4 +1345,98 @@ public static Map buildDirectoryContextProperties(ConnectorSessi } return directoryContextProperties.build(); } + + public static List readSymlinkPaths(ExtendedFileSystem fileSystem, Iterator manifestFileInfos) + throws IOException + { + ImmutableList.Builder targets = ImmutableList.builder(); + while (manifestFileInfos.hasNext()) { + HiveFileInfo symlink = manifestFileInfos.next(); + + try (BufferedReader reader = new BufferedReader(new InputStreamReader(fileSystem.open(new Path(symlink.getPath())), UTF_8))) { + readLines(reader).stream() + .map(Path::new) + .forEach(targets::add); + } + } + return targets.build(); + } 
+ + public static List getTargetPathsHiveFileInfos( + Path path, + Optional partition, + Path targetParent, + List currentTargetPaths, + HiveDirectoryContext hiveDirectoryContext, + ExtendedFileSystem targetFilesystem, + DirectoryLister directoryLister, + Table table, + NamenodeStats namenodeStats, + ConnectorSession session) + { + boolean parentPathCached = directoryLister.isPathCached(targetParent); + + Map targetParentHiveFileInfos = new HashMap<>(getTargetParentHiveFileInfoMap( + partition, + targetParent, + hiveDirectoryContext, + targetFilesystem, + directoryLister, + table, + namenodeStats)); + + // If caching is enabled and the parent path was cached, we verify that all target paths exist in the listing. + // If any target path is missing (likely due to stale cache), we invalidate the cache for that directory + // and re-fetch the listing to ensure we don't miss any files. + if (parentPathCached && isUseListDirectoryCache(session)) { + boolean allPathsExist = currentTargetPaths.stream() + .map(Path::getPathWithoutSchemeAndAuthority) + .map(Path::toString) + .allMatch(targetParentHiveFileInfos::containsKey); + + if (!allPathsExist) { + ((CachingDirectoryLister) directoryLister).invalidateDirectoryListCache(Optional.of(targetParent.toString())); + + targetParentHiveFileInfos.clear(); + targetParentHiveFileInfos.putAll(getTargetParentHiveFileInfoMap( + partition, + targetParent, + hiveDirectoryContext, + targetFilesystem, + directoryLister, + table, + namenodeStats)); + } + } + + return currentTargetPaths.stream().map(targetPath -> { + HiveFileInfo hiveFileInfo = targetParentHiveFileInfos.get(getPathWithoutSchemeAndAuthority(targetPath).toString()); + + if (hiveFileInfo == null) { + throw new PrestoException(HIVE_FILE_NOT_FOUND, String.format("Invalid path in Symlink manifest file %s: %s does not exist", path, targetPath)); + } + + return hiveFileInfo; + }).collect(toImmutableList()); + } + + private static Map getTargetParentHiveFileInfoMap( + Optional 
partition, + Path targetParent, + HiveDirectoryContext hiveDirectoryContext, + ExtendedFileSystem targetFilesystem, + DirectoryLister directoryLister, + Table table, + NamenodeStats namenodeStats) + { + Map targetParentHiveFileInfos = new HashMap<>(); + Iterator hiveFileInfoIterator = directoryLister.list(targetFilesystem, table, targetParent, partition, namenodeStats, hiveDirectoryContext); + + // We will use the path without the scheme and authority since the manifest file may contain entries both with and without them + hiveFileInfoIterator.forEachRemaining(hiveFileInfo -> targetParentHiveFileInfos.put( + getPathWithoutSchemeAndAuthority(new Path(hiveFileInfo.getPath())).toString(), + hiveFileInfo)); + + return targetParentHiveFileInfos; + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriteUtils.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriteUtils.java index 9d0ddf2581960..84f4c75bf4345 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriteUtils.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriteUtils.java @@ -45,6 +45,7 @@ import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.HadoopExtendedFileSystem; import org.apache.hadoop.fs.Path; @@ -53,10 +54,13 @@ import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.metastore.ProtectMode; import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter; +import org.apache.hadoop.hive.ql.exec.TextRecordWriter; +import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat; import org.apache.hadoop.hive.ql.io.HiveOutputFormat; import org.apache.hadoop.hive.ql.io.RCFile; import org.apache.hadoop.hive.ql.io.RCFileOutputFormat; import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat; +import 
org.apache.hadoop.hive.serde.serdeConstants; import org.apache.hadoop.hive.serde2.SerDeException; import org.apache.hadoop.hive.serde2.Serializer; import org.apache.hadoop.hive.serde2.io.DateWritable; @@ -75,6 +79,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.BinaryComparable; import org.apache.hadoop.io.BooleanWritable; import org.apache.hadoop.io.ByteWritable; import org.apache.hadoop.io.BytesWritable; @@ -90,6 +95,7 @@ import org.apache.hive.common.util.ReflectionUtil; import java.io.IOException; +import java.io.OutputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -117,6 +123,7 @@ import static com.facebook.presto.hive.metastore.MetastoreUtil.verifyOnline; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static java.lang.Float.intBitsToFloat; +import static java.lang.Integer.parseInt; import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -124,6 +131,7 @@ import static java.util.stream.Collectors.toList; import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.COMPRESSRESULT; import static org.apache.hadoop.hive.metastore.api.hive_metastoreConstants.META_TABLE_COLUMNS; +import static org.apache.hadoop.hive.ql.exec.Utilities.createCompressedStream; import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector; import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector; import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaBooleanObjectInspector; @@ -159,6 +167,11 @@ private HiveWriteUtils() } public static RecordWriter createRecordWriter(Path target, 
JobConf conf, Properties properties, String outputFormatName, ConnectorSession session) + { + return createRecordWriter(target, conf, properties, outputFormatName, session, Optional.empty()); + } + + public static RecordWriter createRecordWriter(Path target, JobConf conf, Properties properties, String outputFormatName, ConnectorSession session, Optional textCSVHeaderWriter) { try { boolean compress = HiveConf.getBoolVar(conf, COMPRESSRESULT); @@ -168,6 +181,9 @@ public static RecordWriter createRecordWriter(Path target, JobConf conf, Propert if (outputFormatName.equals(MapredParquetOutputFormat.class.getName())) { return createParquetWriter(target, conf, properties, compress, session); } + if (outputFormatName.equals(HiveIgnoreKeyTextOutputFormat.class.getName())) { + return createTextCsvFileWriter(target, conf, properties, compress, textCSVHeaderWriter); + } Object writer = Class.forName(outputFormatName).getConstructor().newInstance(); return ((HiveOutputFormat) writer).getHiveRecordWriter(conf, target, Text.class, compress, properties, Reporter.NULL); } @@ -218,6 +234,63 @@ public void close(boolean abort) }; } + private static RecordWriter createTextCsvFileWriter(Path target, JobConf conf, Properties properties, boolean compress, Optional textCSVHeaderWriter) + throws IOException + { + String rowSeparatorString = properties.getProperty(serdeConstants.LINE_DELIM, "\n"); + + int rowSeparatorByte; + try { + rowSeparatorByte = Byte.parseByte(rowSeparatorString); + } + catch (NumberFormatException e) { + rowSeparatorByte = rowSeparatorString.charAt(0); + } + + FSDataOutputStream output = target.getFileSystem(conf).create(target, Reporter.NULL); + OutputStream compressedOutput = createCompressedStream(conf, output, compress); + TextRecordWriter writer = new TextRecordWriter(); + writer.initialize(compressedOutput, conf); + Optional skipHeaderLine = Optional.ofNullable(properties.getProperty("skip.header.line.count")); + if (skipHeaderLine.isPresent()) { + if 
(parseInt(skipHeaderLine.get()) == 1) { + textCSVHeaderWriter + .orElseThrow(() -> new IllegalArgumentException("TextHeaderWriter must not be empty when skip.header.line.count is set to 1")) + .write(compressedOutput, rowSeparatorByte); + } + } + int finalRowSeparatorByte = rowSeparatorByte; + return new ExtendedRecordWriter() + { + private long length; + + @Override + public long getWrittenBytes() + { + return length; + } + + @Override + public void write(Writable value) + throws IOException + { + BinaryComparable binary = (BinaryComparable) value; + compressedOutput.write(binary.getBytes(), 0, binary.getLength()); + compressedOutput.write(finalRowSeparatorByte); + } + + @Override + public void close(boolean abort) + throws IOException + { + writer.close(); + if (!abort) { + length = target.getFileSystem(conf).getFileStatus(target).getLen(); + } + } + }; + } + public static Serializer initializeSerializer(Configuration conf, Properties properties, String serializerName) { try { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java index 4ac892cb21678..8606fcf4cbb72 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveWriterFactory.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; import com.facebook.airlift.event.client.EventClient; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.type.Type; @@ -39,7 +40,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.common.FileUtils; import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat; diff --git 
a/presto-hive/src/main/java/com/facebook/presto/hive/HiveZeroRowFileCreator.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveZeroRowFileCreator.java index 081551c7010f7..faaeb2cb9c500 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveZeroRowFileCreator.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveZeroRowFileCreator.java @@ -24,13 +24,12 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import io.airlift.slice.Slices; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter; import org.apache.hadoop.mapred.JobConf; -import javax.inject.Inject; - import java.io.IOException; import java.io.UncheckedIOException; import java.nio.file.Paths; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/InternalHiveSplit.java b/presto-hive/src/main/java/com/facebook/presto/hive/InternalHiveSplit.java index 2cbec5c1d85fc..2ecf44843cb17 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/InternalHiveSplit.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/InternalHiveSplit.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.hive.HiveSplit.BucketConversion; import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.schedule.NodeSelectionStrategy; @@ -20,8 +21,6 @@ import com.google.common.collect.ImmutableMap; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.List; import java.util.Map; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ManifestPartitionLoader.java b/presto-hive/src/main/java/com/facebook/presto/hive/ManifestPartitionLoader.java index b904e21a83abe..c4614a6f95b1d 100644 --- 
a/presto-hive/src/main/java/com/facebook/presto/hive/ManifestPartitionLoader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/ManifestPartitionLoader.java @@ -155,6 +155,7 @@ private InternalHiveSplitFactory createInternalHiveSplitFactory( String partitionName = partition.getHivePartition().getPartitionId().getPartitionName(); Storage storage = partition.getPartition().map(Partition::getStorage).orElse(table.getStorage()); String inputFormatName = storage.getStorageFormat().getInputFormat(); + String serDe = storage.getStorageFormat().getSerDe(); int partitionDataColumnCount = partition.getPartition() .map(p -> p.getColumns().size()) .orElseGet(table.getDataColumns()::size); @@ -162,7 +163,7 @@ private InternalHiveSplitFactory createInternalHiveSplitFactory( String location = getPartitionLocation(table, partition.getPartition()); Path path = new Path(location); Configuration configuration = hdfsEnvironment.getConfiguration(hdfsContext, path); - InputFormat inputFormat = getInputFormat(configuration, inputFormatName, false); + InputFormat inputFormat = getInputFormat(configuration, inputFormatName, serDe, false); ExtendedFileSystem fileSystem = hdfsEnvironment.getFileSystem(hdfsContext, path); return new InternalHiveSplitFactory( diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/OrcFileWriterConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/OrcFileWriterConfig.java index df2ae97535845..51c90630d12d3 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/OrcFileWriterConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/OrcFileWriterConfig.java @@ -15,13 +15,12 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.orc.DefaultOrcWriterFlushPolicy; import com.facebook.presto.orc.OrcWriterOptions; import com.facebook.presto.orc.metadata.DwrfStripeCacheMode; import 
com.facebook.presto.orc.writer.StreamLayoutFactory; -import io.airlift.units.DataSize; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.util.OptionalInt; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/OrcFileWriterFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/OrcFileWriterFactory.java index ba8f130569839..89d36f3cc44a3 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/OrcFileWriterFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/OrcFileWriterFactory.java @@ -40,6 +40,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat; @@ -49,8 +50,6 @@ import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; -import javax.inject.Inject; - import java.io.IOException; import java.net.InetAddress; import java.util.List; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/ParquetFileWriterConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/ParquetFileWriterConfig.java index bef328563f843..541fdd3add710 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/ParquetFileWriterConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/ParquetFileWriterConfig.java @@ -14,11 +14,11 @@ package com.facebook.presto.hive; import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.parquet.writer.ParquetWriterOptions; -import io.airlift.units.DataSize; import org.apache.parquet.hadoop.ParquetWriter; -import static io.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static org.apache.parquet.column.ParquetProperties.WriterVersion; public class ParquetFileWriterConfig 
diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriterFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriterFactory.java index 2d1a11d7e32fb..b453930934ea9 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriterFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/RcFileFileWriterFactory.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.PrestoException; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.io.RCFileOutputFormat; @@ -34,8 +35,6 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.io.IOException; import java.io.OutputStream; import java.util.List; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/RebindSafeMBeanServer.java b/presto-hive/src/main/java/com/facebook/presto/hive/RebindSafeMBeanServer.java index 21abfc5d29566..c1af771e83253 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/RebindSafeMBeanServer.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/RebindSafeMBeanServer.java @@ -14,8 +14,8 @@ package com.facebook.presto.hive; import com.facebook.airlift.log.Logger; +import com.google.errorprone.annotations.ThreadSafe; -import javax.annotation.concurrent.ThreadSafe; import javax.management.Attribute; import javax.management.AttributeList; import javax.management.AttributeNotFoundException; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/RecordFileWriter.java b/presto-hive/src/main/java/com/facebook/presto/hive/RecordFileWriter.java index 476b6e186ce4e..04edc2c71ffa1 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/RecordFileWriter.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/RecordFileWriter.java @@ 
-13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.Type; @@ -23,9 +24,9 @@ import com.facebook.presto.spi.PrestoException; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter; +import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat; import org.apache.hadoop.hive.serde2.SerDeException; import org.apache.hadoop.hive.serde2.Serializer; import org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe; @@ -100,11 +101,18 @@ public RecordFileWriter( serDe = OptimizedLazyBinaryColumnarSerde.class.getName(); } serializer = initializeSerializer(conf, schema, serDe); - recordWriter = createRecordWriter(path, conf, schema, storageFormat.getOutputFormat(), session); List objectInspectors = getRowColumnInspectors(fileColumnTypes); tableInspector = getStandardStructObjectInspector(fileColumnNames, objectInspectors); + if (storageFormat.getOutputFormat().equals(HiveIgnoreKeyTextOutputFormat.class.getName())) { + Optional textHeaderWriter = Optional.of(new TextCSVHeaderWriter(serializer, typeManager, session, fileColumnNames)); + recordWriter = createRecordWriter(path, conf, schema, storageFormat.getOutputFormat(), session, textHeaderWriter); + } + else { + recordWriter = createRecordWriter(path, conf, schema, storageFormat.getOutputFormat(), session, Optional.empty()); + } + // reorder (and possibly reduce) struct fields to match input structFields = ImmutableList.copyOf(inputColumnNames.stream() .map(tableInspector::getStructFieldRef) diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/SortingFileWriter.java b/presto-hive/src/main/java/com/facebook/presto/hive/SortingFileWriter.java index 
2ceac1b992813..4c74193f71b43 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/SortingFileWriter.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/SortingFileWriter.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.io.DataSink; @@ -29,7 +30,6 @@ import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; import com.google.common.io.Closer; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.openjdk.jol.info.ClassLayout; @@ -46,12 +46,12 @@ import java.util.function.Consumer; import java.util.stream.IntStream; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.hive.HiveErrorCode.HIVE_WRITER_CLOSE_ERROR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_WRITER_DATA_ERROR; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.Math.min; import static java.util.Comparator.comparing; import static java.util.Objects.requireNonNull; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/SortingFileWriterFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/SortingFileWriterFactory.java index 84dc4fef5c0b3..dfb7511bcc6c5 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/SortingFileWriterFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/SortingFileWriterFactory.java @@ -14,13 +14,13 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; import 
com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PageSorter; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.JobConf; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/StoragePartitionLoader.java b/presto-hive/src/main/java/com/facebook/presto/hive/StoragePartitionLoader.java index 110b7cb415505..8ba0c27a18efe 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/StoragePartitionLoader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/StoragePartitionLoader.java @@ -24,11 +24,11 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterators; import com.google.common.collect.ListMultimap; -import com.google.common.io.CharStreams; import com.google.common.util.concurrent.ListenableFuture; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; @@ -40,10 +40,7 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapred.TextInputFormat; -import java.io.BufferedReader; import java.io.IOException; -import java.io.InputStreamReader; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Comparator; import java.util.Deque; @@ -54,6 +51,7 @@ import java.util.OptionalInt; import java.util.Properties; import java.util.function.IntPredicate; +import java.util.stream.Collectors; import static com.facebook.presto.hive.HiveBucketing.getVirtualBucketNumber; import static 
com.facebook.presto.hive.HiveColumnHandle.pathColumnHandle; @@ -68,12 +66,15 @@ import static com.facebook.presto.hive.HiveSessionProperties.isFileSplittable; import static com.facebook.presto.hive.HiveSessionProperties.isOrderBasedExecutionEnabled; import static com.facebook.presto.hive.HiveSessionProperties.isSkipEmptyFilesEnabled; +import static com.facebook.presto.hive.HiveSessionProperties.isSymlinkOptimizedReaderEnabled; import static com.facebook.presto.hive.HiveSessionProperties.isUseListDirectoryCache; import static com.facebook.presto.hive.HiveUtil.buildDirectoryContextProperties; import static com.facebook.presto.hive.HiveUtil.getFooterCount; import static com.facebook.presto.hive.HiveUtil.getHeaderCount; import static com.facebook.presto.hive.HiveUtil.getInputFormat; +import static com.facebook.presto.hive.HiveUtil.getTargetPathsHiveFileInfos; import static com.facebook.presto.hive.HiveUtil.isHudiParquetInputFormat; +import static com.facebook.presto.hive.HiveUtil.readSymlinkPaths; import static com.facebook.presto.hive.HiveUtil.shouldUseFileSplitsFromInputFormat; import static com.facebook.presto.hive.HiveWriterFactory.getBucketNumber; import static com.facebook.presto.hive.NestedDirectoryPolicy.FAIL; @@ -143,7 +144,11 @@ public StoragePartitionLoader( if (!isNullOrEmpty(table.getStorage().getLocation())) { Configuration configuration = hdfsEnvironment.getConfiguration(hdfsContext, new Path(table.getStorage().getLocation())); try { - InputFormat inputFormat = getInputFormat(configuration, table.getStorage().getStorageFormat().getInputFormat(), false); + InputFormat inputFormat = getInputFormat( + configuration, + table.getStorage().getStorageFormat().getInputFormat(), + table.getStorage().getStorageFormat().getSerDe(), + false); if (isHudiParquetInputFormat(inputFormat)) { directoryListerOverride = Optional.of(new HudiDirectoryLister(configuration, session, table)); } @@ -159,7 +164,8 @@ public StoragePartitionLoader( this.directoryLister = 
directoryListerOverride.orElseGet(() -> requireNonNull(directoryLister, "directoryLister is null")); } - private ListenableFuture handleSymlinkTextInputFormat(ExtendedFileSystem fs, + private ListenableFuture handleSymlinkTextInputFormat( + ExtendedFileSystem fs, Path path, InputFormat inputFormat, boolean s3SelectPushdownEnabled, @@ -169,30 +175,166 @@ private ListenableFuture handleSymlinkTextInputFormat(ExtendedFileSystem fs, int partitionDataColumnCount, boolean stopped, HivePartitionMetadata partition, - HiveSplitSource hiveSplitSource) + HiveSplitSource hiveSplitSource, + Configuration configuration, + boolean splittable) throws IOException { if (tableBucketInfo.isPresent()) { throw new PrestoException(NOT_SUPPORTED, "Bucketed table in SymlinkTextInputFormat is not yet supported"); } - // TODO: This should use an iterator like the HiveFileIterator + List targetPaths = getTargetPathsFromSymlink(fs, path, partition.getPartition()); + + if (isSymlinkOptimizedReaderEnabled(session)) { + Map> parentToTargets = targetPaths.stream().collect(Collectors.groupingBy(Path::getParent)); + + InputFormat targetInputFormat = getInputFormat( + configuration, + storage.getStorageFormat().getInputFormat(), + storage.getStorageFormat().getSerDe(), + true); + + HiveDirectoryContext hiveDirectoryContext = new HiveDirectoryContext( + IGNORED, + isUseListDirectoryCache(session), + isSkipEmptyFilesEnabled(session), + hdfsContext.getIdentity(), + buildDirectoryContextProperties(session), + session.getRuntimeStats()); + + for (Map.Entry> entry : parentToTargets.entrySet()) { + Iterator symlinkIterator = getSymlinkIterator( + path, + s3SelectPushdownEnabled, + storage, + partitionKeys, + partitionName, + partitionDataColumnCount, + partition, + splittable, + entry.getKey(), + entry.getValue(), + targetInputFormat, + hiveDirectoryContext); + + fileIterators.addLast(symlinkIterator); + } + return COMPLETED_FUTURE; + } + + return getSymlinkSplits( + path, + inputFormat, + 
s3SelectPushdownEnabled, + storage, + partitionKeys, + partitionName, + partitionDataColumnCount, + stopped, + partition, + hiveSplitSource, + targetPaths); + } + + @VisibleForTesting + Iterator getSymlinkIterator( + Path path, + boolean s3SelectPushdownEnabled, + Storage storage, + List partitionKeys, + String partitionName, + int partitionDataColumnCount, + HivePartitionMetadata partition, + boolean splittable, + Path targetParent, + List currentTargetPaths, + InputFormat targetInputFormat, + HiveDirectoryContext hiveDirectoryContext) + throws IOException + { + ExtendedFileSystem targetFilesystem = hdfsEnvironment.getFileSystem(hdfsContext, targetParent); + + List targetPathsHiveFileInfos = getTargetPathsHiveFileInfos( + path, + partition.getPartition(), + targetParent, + currentTargetPaths, + hiveDirectoryContext, + targetFilesystem, + directoryLister, + table, + namenodeStats, + session); + + InternalHiveSplitFactory splitFactory = getHiveSplitFactory( + targetFilesystem, + targetInputFormat, + s3SelectPushdownEnabled, + storage, + targetParent.toUri().toString(), + partitionName, + partitionKeys, + partitionDataColumnCount, + partition, + Optional.empty()); + + return targetPathsHiveFileInfos.stream() + .map(hiveFileInfo -> splitFactory.createInternalHiveSplit(hiveFileInfo, splittable)) + .filter(Optional::isPresent) + .map(Optional::get) + .iterator(); + } + + private ListenableFuture getSymlinkSplits( + Path path, + InputFormat inputFormat, + boolean s3SelectPushdownEnabled, + Storage storage, + List partitionKeys, + String partitionName, + int partitionDataColumnCount, + boolean stopped, + HivePartitionMetadata partition, + HiveSplitSource hiveSplitSource, + List targetPaths) + throws IOException + { ListenableFuture lastResult = COMPLETED_FUTURE; - for (Path targetPath : getTargetPathsFromSymlink(fs, path, partition.getPartition())) { + for (Path targetPath : targetPaths) { // The input should be in TextInputFormat. 
TextInputFormat targetInputFormat = new TextInputFormat(); // the splits must be generated using the file system for the target path // get the configuration for the target path -- it may be a different hdfs instance ExtendedFileSystem targetFilesystem = hdfsEnvironment.getFileSystem(hdfsContext, targetPath); - JobConf targetJob = toJobConf(targetFilesystem.getConf()); + + Configuration targetConfiguration = targetFilesystem.getConf(); + if (targetConfiguration instanceof HiveCachingHdfsConfiguration.CachingJobConf) { + targetConfiguration = ((HiveCachingHdfsConfiguration.CachingJobConf) targetConfiguration).getConfig(); + } + if (targetConfiguration instanceof CopyOnFirstWriteConfiguration) { + targetConfiguration = ((CopyOnFirstWriteConfiguration) targetConfiguration).getConfig(); + } + + JobConf targetJob = toJobConf(targetConfiguration); + targetJob.setInputFormat(TextInputFormat.class); targetInputFormat.configure(targetJob); targetJob.set(SPLIT_MINSIZE, Long.toString(getMaxSplitSize(session).toBytes())); FileInputFormat.setInputPaths(targetJob, targetPath); InputSplit[] targetSplits = targetInputFormat.getSplits(targetJob, 0); - InternalHiveSplitFactory splitFactory = getHiveSplitFactory(fs, inputFormat, s3SelectPushdownEnabled, storage, path.toUri().toString(), partitionName, - partitionKeys, partitionDataColumnCount, partition, Optional.empty()); + InternalHiveSplitFactory splitFactory = getHiveSplitFactory( + targetFilesystem, + inputFormat, + s3SelectPushdownEnabled, + storage, + path.toUri().toString(), + partitionName, + partitionKeys, + partitionDataColumnCount, + partition, + Optional.empty()); lastResult = addSplitsToSource(targetSplits, splitFactory, hiveSplitSource, stopped); if (stopped) { return COMPLETED_FUTURE; @@ -263,6 +405,7 @@ public ListenableFuture loadPartition(HivePartitionMetadata partition, HiveSp Storage storage = partition.getPartition().map(Partition::getStorage).orElse(table.getStorage()); Properties schema = 
getPartitionSchema(table, partition.getPartition()); String inputFormatName = storage.getStorageFormat().getInputFormat(); + String serDe = storage.getStorageFormat().getSerDe(); int partitionDataColumnCount = partition.getPartition() .map(p -> p.getColumns().size()) .orElseGet(table.getDataColumns()::size); @@ -284,13 +427,22 @@ public ListenableFuture loadPartition(HivePartitionMetadata partition, HiveSp configuration = ((CopyOnFirstWriteConfiguration) configuration).getConfig(); } } - InputFormat inputFormat = getInputFormat(configuration, inputFormatName, false); + InputFormat inputFormat = getInputFormat(configuration, inputFormatName, serDe, false); ExtendedFileSystem fs = hdfsEnvironment.getFileSystem(hdfsContext.getIdentity().getUser(), path, configuration); boolean s3SelectPushdownEnabled = shouldEnablePushdownForTable(session, table, path.toString(), partition.getPartition()); + // Streaming aggregation works at the granularity of individual files + // Partial aggregation pushdown works at the granularity of individual files + // therefore we must not split files when either is enabled. 
+ // Skip header / footer lines are not splittable except for a special case when skip.header.line.count=1 + boolean splittable = isFileSplittable(session) && + !isOrderBasedExecutionEnabled(session) && + !partialAggregationsPushedDown && + getFooterCount(schema) == 0 && getHeaderCount(schema) <= 1; + if (inputFormat instanceof SymlinkTextInputFormat) { return handleSymlinkTextInputFormat(fs, path, inputFormat, s3SelectPushdownEnabled, storage, partitionKeys, partitionName, - partitionDataColumnCount, stopped, partition, hiveSplitSource); + partitionDataColumnCount, stopped, partition, hiveSplitSource, configuration, splittable); } Optional bucketConversion = Optional.empty(); @@ -326,15 +478,6 @@ public ListenableFuture loadPartition(HivePartitionMetadata partition, HiveSp return handleGetSplitsFromInputFormat(configuration, path, schema, inputFormat, stopped, hiveSplitSource, splitFactory); } - // Streaming aggregation works at the granularity of individual files - // Partial aggregation pushdown works at the granularity of individual files - // therefore we must not split files when either is enabled. 
- // Skip header / footer lines are not splittable except for a special case when skip.header.line.count=1 - boolean splittable = isFileSplittable(session) && - !isOrderBasedExecutionEnabled(session) && - !partialAggregationsPushedDown && - getFooterCount(schema) == 0 && getHeaderCount(schema) <= 1; - // Bucketed partitions are fully loaded immediately since all files must be loaded to determine the file to bucket mapping if (tableBucketInfo.isPresent()) { if (tableBucketInfo.get().isVirtuallyBucketed()) { @@ -584,7 +727,6 @@ private List getVirtuallyBucketedSplits(Path path, ExtendedFi private List getTargetPathsFromSymlink(ExtendedFileSystem fileSystem, Path symlinkDir, Optional partition) { try { - List targets = new ArrayList<>(); HiveDirectoryContext hiveDirectoryContext = new HiveDirectoryContext( IGNORED, isUseListDirectoryCache(session), @@ -592,16 +734,8 @@ private List getTargetPathsFromSymlink(ExtendedFileSystem fileSystem, Path hdfsContext.getIdentity(), buildDirectoryContextProperties(session), session.getRuntimeStats()); - List manifestFileInfos = ImmutableList.copyOf(directoryLister.list(fileSystem, table, symlinkDir, partition, namenodeStats, hiveDirectoryContext)); - - for (HiveFileInfo symlink : manifestFileInfos) { - try (BufferedReader reader = new BufferedReader(new InputStreamReader(fileSystem.open(new Path(symlink.getPath())), StandardCharsets.UTF_8))) { - CharStreams.readLines(reader).stream() - .map(Path::new) - .forEach(targets::add); - } - } - return targets; + Iterator manifestFileInfos = directoryLister.list(fileSystem, table, symlinkDir, partition, namenodeStats, hiveDirectoryContext); + return readSymlinkPaths(fileSystem, manifestFileInfos); } catch (IOException e) { throw new PrestoException(HIVE_BAD_DATA, "Error parsing symlinks from: " + symlinkDir, e); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/SyncPartitionMetadataProcedure.java 
b/presto-hive/src/main/java/com/facebook/presto/hive/SyncPartitionMetadataProcedure.java index 3bdbb0f0f7b8a..7f248470483e3 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/SyncPartitionMetadataProcedure.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/SyncPartitionMetadataProcedure.java @@ -29,11 +29,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; import javax.inject.Provider; import java.io.IOException; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/TextCSVHeaderWriter.java b/presto-hive/src/main/java/com/facebook/presto/hive/TextCSVHeaderWriter.java new file mode 100644 index 0000000000000..32f010c7ce105 --- /dev/null +++ b/presto-hive/src/main/java/com/facebook/presto/hive/TextCSVHeaderWriter.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.hive; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.ConnectorSession; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.Serializer; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector; +import org.apache.hadoop.io.BinaryComparable; +import org.apache.hadoop.io.Text; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; +import java.util.stream.IntStream; + +import static com.facebook.presto.hive.HiveWriteUtils.getRowColumnInspector; +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TextCSVHeaderWriter +{ + private final Serializer serializer; + private final Type headerType; + private final List fileColumnNames; + public TextCSVHeaderWriter(Serializer serializer, TypeManager typeManager, ConnectorSession session, List fileColumnNames) + { + this.serializer = serializer; + this.fileColumnNames = fileColumnNames; + this.headerType = HiveType.valueOf("string").getType(typeManager); + } + + public void write(OutputStream compressedOutput, int rowSeparator) + throws IOException + { + try { + ObjectInspector stringObjectInspector = getRowColumnInspector(headerType); + List headers = fileColumnNames.stream().map(Text::new).collect(toImmutableList()); + List inspectors = IntStream.range(0, fileColumnNames.size()).mapToObj(ignored -> stringObjectInspector).collect(toImmutableList()); + StandardStructObjectInspector headerStructObjectInspectors = ObjectInspectorFactory.getStandardStructObjectInspector(fileColumnNames, inspectors); + BinaryComparable binary = (BinaryComparable) serializer.serialize(headers, headerStructObjectInspectors); + 
compressedOutput.write(binary.getBytes(), 0, binary.getLength()); + compressedOutput.write(rowSeparator); + } + catch (SerDeException e) { + throw new IOException(e); + } + } +} diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/WriteCompletedEvent.java b/presto-hive/src/main/java/com/facebook/presto/hive/WriteCompletedEvent.java index ff5fcc933e3c0..7ce318af36cf6 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/WriteCompletedEvent.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/WriteCompletedEvent.java @@ -16,9 +16,8 @@ import com.facebook.airlift.event.client.EventField; import com.facebook.airlift.event.client.EventField.EventFieldMapping; import com.facebook.airlift.event.client.EventType; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.time.Instant; import java.util.Map; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/AuthenticationModules.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/AuthenticationModules.java index 32e1fca17ba91..239d89b94e17d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/AuthenticationModules.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/AuthenticationModules.java @@ -21,8 +21,7 @@ import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Singleton; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; import static com.facebook.presto.hive.authentication.KerberosHadoopAuthentication.createKerberosHadoopAuthentication; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/CachingKerberosHadoopAuthentication.java 
b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/CachingKerberosHadoopAuthentication.java index 424383c197ad4..baec2af035515 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/CachingKerberosHadoopAuthentication.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/CachingKerberosHadoopAuthentication.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.hive.authentication; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.apache.hadoop.security.UserGroupInformation; -import javax.annotation.concurrent.GuardedBy; import javax.security.auth.Subject; import javax.security.auth.kerberos.KerberosTicket; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/DirectHdfsAuthentication.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/DirectHdfsAuthentication.java index 25f5a3c3efa87..6aec90dc9952d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/DirectHdfsAuthentication.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/DirectHdfsAuthentication.java @@ -14,8 +14,7 @@ package com.facebook.presto.hive.authentication; import com.facebook.presto.hive.ForHdfs; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.hive.authentication.UserGroupInformationUtils.executeActionInDoAs; import static java.util.Objects.requireNonNull; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/HdfsKerberosConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/HdfsKerberosConfig.java index 02444edb614b3..a29f27afff4e0 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/HdfsKerberosConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/HdfsKerberosConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import 
com.facebook.airlift.configuration.ConfigDescription; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class HdfsKerberosConfig { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/ImpersonatingHdfsAuthentication.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/ImpersonatingHdfsAuthentication.java index af4d1bbd3e16f..cb60faac2f980 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/ImpersonatingHdfsAuthentication.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/ImpersonatingHdfsAuthentication.java @@ -14,10 +14,9 @@ package com.facebook.presto.hive.authentication; import com.facebook.presto.hive.ForHdfs; +import jakarta.inject.Inject; import org.apache.hadoop.security.UserGroupInformation; -import javax.inject.Inject; - import static com.facebook.presto.hive.authentication.UserGroupInformationUtils.executeActionInDoAs; import static java.util.Objects.requireNonNull; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/KerberosHiveMetastoreAuthentication.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/KerberosHiveMetastoreAuthentication.java index 5d3d2840f11fb..6b75d8a08495a 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/KerberosHiveMetastoreAuthentication.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/KerberosHiveMetastoreAuthentication.java @@ -17,6 +17,7 @@ import com.facebook.presto.hive.HiveClientConfig; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableMap; +import jakarta.inject.Inject; import org.apache.hadoop.hive.metastore.security.DelegationTokenIdentifier; import org.apache.hadoop.hive.thrift.client.TUGIAssumingTransport; import org.apache.hadoop.security.SaslRpcServer; @@ -27,7 +28,6 @@ import org.apache.thrift.transport.TTransport; import 
org.apache.thrift.transport.TTransportException; -import javax.inject.Inject; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.NameCallback; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/MetastoreKerberosConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/MetastoreKerberosConfig.java index a78c835071ca1..06216824a5944 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/authentication/MetastoreKerberosConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/authentication/MetastoreKerberosConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class MetastoreKerberosConfig { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/cache/HiveCachingHdfsConfiguration.java b/presto-hive/src/main/java/com/facebook/presto/hive/cache/HiveCachingHdfsConfiguration.java index 937d2fa6887c3..7ceec635f68c3 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/cache/HiveCachingHdfsConfiguration.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/cache/HiveCachingHdfsConfiguration.java @@ -24,12 +24,11 @@ import com.facebook.presto.hive.WrapperJobConf; import com.facebook.presto.hive.filesystem.ExtendedFileSystem; import com.facebook.presto.spi.PrestoException; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; - import java.io.IOException; import java.net.URI; import java.util.function.BiFunction; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/gcs/HiveGcsConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/gcs/HiveGcsConfig.java index 
f4c2f3df498e4..4c66f49e87600 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/gcs/HiveGcsConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/gcs/HiveGcsConfig.java @@ -15,7 +15,7 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; public class HiveGcsConfig { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/gcs/HiveGcsConfigurationInitializer.java b/presto-hive/src/main/java/com/facebook/presto/hive/gcs/HiveGcsConfigurationInitializer.java index ce75fee27b5ff..4285c5197ee96 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/gcs/HiveGcsConfigurationInitializer.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/gcs/HiveGcsConfigurationInitializer.java @@ -14,10 +14,9 @@ package com.facebook.presto.hive.gcs; import com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; -import javax.inject.Inject; - import static com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystemBase.AUTHENTICATION_PREFIX; import static com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystemConfiguration.AUTH_SERVICE_ACCOUNT_ENABLE; import static com.google.cloud.hadoop.util.AccessTokenProviderClassFromConfigFactory.ACCESS_TOKEN_PROVIDER_IMPL_SUFFIX; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfAggregatedPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfAggregatedPageSourceFactory.java index 19ba823ef895a..723759a0b162e 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfAggregatedPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfAggregatedPageSourceFactory.java @@ -19,7 +19,6 @@ import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.hive.HdfsEnvironment; import 
com.facebook.presto.hive.HiveAggregatedPageSourceFactory; -import com.facebook.presto.hive.HiveClientConfig; import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.hive.HiveFileContext; import com.facebook.presto.hive.HiveFileSplit; @@ -30,13 +29,13 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.StandardFunctionResolution; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; -import javax.inject.Inject; - import java.util.List; import java.util.Optional; +import static com.facebook.presto.hive.HiveCommonSessionProperties.isUseOrcColumnNames; import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA; import static com.facebook.presto.hive.orc.OrcAggregatedPageSourceFactory.createOrcPageSource; import static com.facebook.presto.orc.DwrfEncryptionProvider.NO_ENCRYPTION; @@ -48,7 +47,6 @@ public class DwrfAggregatedPageSourceFactory { private final TypeManager typeManager; private final StandardFunctionResolution functionResolution; - private final boolean useOrcColumnNames; private final HdfsEnvironment hdfsEnvironment; private final FileFormatDataSourceStats stats; private final OrcFileTailSource orcFileTailSource; @@ -58,26 +56,6 @@ public class DwrfAggregatedPageSourceFactory public DwrfAggregatedPageSourceFactory( TypeManager typeManager, StandardFunctionResolution functionResolution, - HiveClientConfig config, - HdfsEnvironment hdfsEnvironment, - FileFormatDataSourceStats stats, - OrcFileTailSource orcFileTailSource, - StripeMetadataSourceFactory stripeMetadataSourceFactory) - { - this( - typeManager, - functionResolution, - requireNonNull(config, "hiveClientConfig is null").isUseOrcColumnNames(), - hdfsEnvironment, - stats, - orcFileTailSource, - stripeMetadataSourceFactory); - } - - public DwrfAggregatedPageSourceFactory( - TypeManager typeManager, - StandardFunctionResolution functionResolution, - boolean 
useOrcColumnNames, HdfsEnvironment hdfsEnvironment, FileFormatDataSourceStats stats, OrcFileTailSource orcFileTailSource, @@ -85,7 +63,6 @@ public DwrfAggregatedPageSourceFactory( { this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); - this.useOrcColumnNames = useOrcColumnNames; this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.stats = requireNonNull(stats, "stats is null"); this.orcFileTailSource = requireNonNull(orcFileTailSource, "orcFileTailCache is null"); @@ -117,7 +94,7 @@ public Optional createPageSource( configuration, fileSplit, columns, - useOrcColumnNames, + isUseOrcColumnNames(session), typeManager, functionResolution, stats, diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfBatchPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfBatchPageSourceFactory.java index 7babcc020c82a..894acab035b44 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfBatchPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfBatchPageSourceFactory.java @@ -35,11 +35,10 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.function.StandardFunctionResolution; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.util.List; import java.util.Map; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfSelectivePageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfSelectivePageSourceFactory.java index 124efccd8f3f3..045793f40320a 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfSelectivePageSourceFactory.java +++ 
b/presto-hive/src/main/java/com/facebook/presto/hive/orc/DwrfSelectivePageSourceFactory.java @@ -38,11 +38,10 @@ import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.RowExpressionService; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.util.List; import java.util.Map; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/HdfsOrcDataSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/HdfsOrcDataSource.java index 9d122d93c6332..9053233d6ac69 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/HdfsOrcDataSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/HdfsOrcDataSource.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.hive.orc; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.orc.AbstractOrcDataSource; import com.facebook.presto.orc.OrcDataSourceId; import com.facebook.presto.spi.PrestoException; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.FSDataInputStream; import java.io.IOException; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcAggregatedPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcAggregatedPageSourceFactory.java index ab1533eac7d7f..cd6313d484755 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcAggregatedPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcAggregatedPageSourceFactory.java @@ -13,12 +13,12 @@ */ package com.facebook.presto.hive.orc; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.hive.EncryptionInformation; import 
com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.hive.HdfsEnvironment; import com.facebook.presto.hive.HiveAggregatedPageSourceFactory; -import com.facebook.presto.hive.HiveClientConfig; import com.facebook.presto.hive.HiveColumnHandle; import com.facebook.presto.hive.HiveFileContext; import com.facebook.presto.hive.HiveFileSplit; @@ -35,13 +35,11 @@ import com.facebook.presto.spi.FixedPageSource; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.io.orc.OrcSerde; -import javax.inject.Inject; - import java.io.IOException; import java.util.List; import java.util.Optional; @@ -50,6 +48,7 @@ import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcMaxReadBlockSize; import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcTinyStripeThreshold; import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcZstdJniDecompressionEnabled; +import static com.facebook.presto.hive.HiveCommonSessionProperties.isUseOrcColumnNames; import static com.facebook.presto.hive.HiveUtil.getPhysicalHiveColumnHandles; import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcDataSource; import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcReader; @@ -63,7 +62,6 @@ public class OrcAggregatedPageSourceFactory { private final TypeManager typeManager; private final StandardFunctionResolution functionResolution; - private final boolean useOrcColumnNames; private final HdfsEnvironment hdfsEnvironment; private final FileFormatDataSourceStats stats; private final OrcFileTailSource orcFileTailSource; @@ -73,26 +71,6 @@ public class OrcAggregatedPageSourceFactory public OrcAggregatedPageSourceFactory( TypeManager typeManager, StandardFunctionResolution 
functionResolution, - HiveClientConfig config, - HdfsEnvironment hdfsEnvironment, - FileFormatDataSourceStats stats, - OrcFileTailSource orcFileTailSource, - StripeMetadataSourceFactory stripeMetadataSourceFactory) - { - this( - typeManager, - functionResolution, - requireNonNull(config, "hiveClientConfig is null").isUseOrcColumnNames(), - hdfsEnvironment, - stats, - orcFileTailSource, - stripeMetadataSourceFactory); - } - - public OrcAggregatedPageSourceFactory( - TypeManager typeManager, - StandardFunctionResolution functionResolution, - boolean useOrcColumnNames, HdfsEnvironment hdfsEnvironment, FileFormatDataSourceStats stats, OrcFileTailSource orcFileTailSource, @@ -100,7 +78,6 @@ public OrcAggregatedPageSourceFactory( { this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); - this.useOrcColumnNames = useOrcColumnNames; this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.stats = requireNonNull(stats, "stats is null"); this.orcFileTailSource = requireNonNull(orcFileTailSource, "orcFileTailCache is null"); @@ -133,7 +110,7 @@ public Optional createPageSource( configuration, fileSplit, columns, - useOrcColumnNames, + isUseOrcColumnNames(session), typeManager, functionResolution, stats, diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcBatchPageSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcBatchPageSource.java index 9573c25b3cb40..7fd753773077d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcBatchPageSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcBatchPageSource.java @@ -32,12 +32,12 @@ import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; -import com.google.common.primitives.Booleans; import java.io.IOException; import 
java.io.UncheckedIOException; import java.util.List; import java.util.Optional; +import java.util.OptionalInt; import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA; @@ -72,15 +72,16 @@ public class OrcBatchPageSource private final RuntimeStats runtimeStats; - private final boolean[] isRowNumberList; + private final OptionalInt rowNumberColumnIndex; private final RowIDCoercer coercer; /** * @param columns an ordered list of the fields to read - * @param isRowNumberList list of indices of columns. If true, then the column then the column - * at the same position in {@code columns} is a row number. If false, it isn't. - * This should have the same length as {@code columns}. + * @param rowNumberColumnIndex specifies the index of the row number column. Its value should + * be less than the length of {@code columns}. Set to OptionalInt.empty() if no row number + * column is present. + * * #throws IllegalArgumentException if columns and isRowNumberList do not have the same size */ // TODO(elharo) HiveColumnHandle should know whether it's a row number or not. 
Alternatively, @@ -93,8 +94,7 @@ public OrcBatchPageSource( OrcAggregatedMemoryContext systemMemoryContext, FileFormatDataSourceStats stats, RuntimeStats runtimeStats, - // TODO avoid conversion; just pass a boolean array here - List isRowNumberList, + OptionalInt rowNumberColumnIndex, byte[] rowIDPartitionComponent, String rowGroupId) { @@ -105,9 +105,9 @@ public OrcBatchPageSource( this.stats = requireNonNull(stats, "stats is null"); this.runtimeStats = requireNonNull(runtimeStats, "runtimeStats is null"); - requireNonNull(isRowNumberList, "isRowNumberList is null"); - checkArgument(isRowNumberList.size() == numColumns, "row number list size %s does not match columns size %s", isRowNumberList.size(), columns.size()); - this.isRowNumberList = Booleans.toArray(isRowNumberList); + checkArgument(rowNumberColumnIndex.isEmpty() || + (rowNumberColumnIndex.getAsInt() >= 0 && rowNumberColumnIndex.getAsInt() < numColumns), "row number column index is incorrect"); + this.rowNumberColumnIndex = rowNumberColumnIndex; this.coercer = new RowIDCoercer(rowIDPartitionComponent, rowGroupId); this.constantBlocks = new Block[numColumns]; @@ -264,7 +264,7 @@ protected void closeWithSuppression(Throwable throwable) private boolean isRowPositionColumn(int column) { - return isRowNumberList[column]; + return rowNumberColumnIndex.isPresent() && rowNumberColumnIndex.getAsInt() == column; } private boolean isRowIDColumn(int column) diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcBatchPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcBatchPageSourceFactory.java index 0e3aef9254f71..35669ab368b07 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcBatchPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcBatchPageSourceFactory.java @@ -45,17 +45,17 @@ import com.facebook.presto.spi.function.StandardFunctionResolution; import com.google.common.collect.ImmutableList; import 
com.google.common.collect.ImmutableMap; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.io.orc.OrcSerde; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalInt; import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR; import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcMaxMergeDistance; @@ -63,6 +63,7 @@ import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcTinyStripeThreshold; import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcBloomFiltersEnabled; import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcZstdJniDecompressionEnabled; +import static com.facebook.presto.hive.HiveCommonSessionProperties.isUseOrcColumnNames; import static com.facebook.presto.hive.HiveUtil.checkRowIDPartitionComponent; import static com.facebook.presto.hive.HiveUtil.getPhysicalHiveColumnHandles; import static com.facebook.presto.hive.orc.OrcPageSourceFactoryUtils.getOrcDataSource; @@ -72,14 +73,12 @@ import static com.facebook.presto.orc.OrcEncoding.ORC; import static com.facebook.presto.orc.OrcReader.INITIAL_BATCH_SIZE; import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Collections.nCopies; import static java.util.Objects.requireNonNull; public class OrcBatchPageSourceFactory implements HiveBatchPageSourceFactory { private final TypeManager typeManager; - private final boolean useOrcColumnNames; private final HdfsEnvironment hdfsEnvironment; private final FileFormatDataSourceStats stats; private final int domainCompactionThreshold; @@ -98,7 +97,6 @@ public OrcBatchPageSourceFactory( { this( typeManager, - requireNonNull(config, "hiveClientConfig is null").isUseOrcColumnNames(), hdfsEnvironment, stats, 
config.getDomainCompactionThreshold(), @@ -108,7 +106,6 @@ public OrcBatchPageSourceFactory( public OrcBatchPageSourceFactory( TypeManager typeManager, - boolean useOrcColumnNames, HdfsEnvironment hdfsEnvironment, FileFormatDataSourceStats stats, int domainCompactionThreshold, @@ -116,7 +113,6 @@ public OrcBatchPageSourceFactory( StripeMetadataSourceFactory stripeMetadataSourceFactory) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); - this.useOrcColumnNames = useOrcColumnNames; this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.stats = requireNonNull(stats, "stats is null"); this.domainCompactionThreshold = domainCompactionThreshold; @@ -154,7 +150,7 @@ public Optional createPageSource( configuration, fileSplit, columns, - useOrcColumnNames, + isUseOrcColumnNames(session), effectivePredicate, hiveStorageTimeZone, typeManager, @@ -245,7 +241,6 @@ public static ConnectorPageSource createOrcPageSource( String rowGroupID = path.getName(); // none of the columns are row numbers - List isRowNumberList = nCopies(physicalColumns.size(), false); return new OrcBatchPageSource( recordReader, reader.getOrcDataSource(), @@ -254,7 +249,7 @@ public static ConnectorPageSource createOrcPageSource( systemMemoryUsage, stats, hiveFileContext.getStats(), - isRowNumberList, + OptionalInt.empty(), partitionID, rowGroupID); } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactoryUtils.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactoryUtils.java index 5051cf97ca8ab..bbe56caa1c3d6 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactoryUtils.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcPageSourceFactoryUtils.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive.orc; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.EncryptionInformation; import 
com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.hive.HdfsEnvironment; @@ -33,7 +34,6 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.google.common.util.concurrent.UncheckedExecutionException; -import io.airlift.units.DataSize; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.Path; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcSelectivePageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcSelectivePageSourceFactory.java index 0fd0fc6244471..5fcdc9c89ad7c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcSelectivePageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/OrcSelectivePageSourceFactory.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.orc; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.function.SqlFunctionProperties; @@ -68,15 +69,13 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; -import io.airlift.units.DataSize; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.io.orc.OrcSerde; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -105,6 +104,7 @@ import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcTinyStripeThreshold; import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcBloomFiltersEnabled; import static com.facebook.presto.hive.HiveCommonSessionProperties.isOrcZstdJniDecompressionEnabled; +import 
static com.facebook.presto.hive.HiveCommonSessionProperties.isUseOrcColumnNames; import static com.facebook.presto.hive.HiveErrorCode.HIVE_INVALID_BUCKET_FILES; import static com.facebook.presto.hive.HiveSessionProperties.isAdaptiveFilterReorderingEnabled; import static com.facebook.presto.hive.HiveSessionProperties.isLegacyTimestampBucketing; @@ -133,7 +133,6 @@ public class OrcSelectivePageSourceFactory private final TypeManager typeManager; private final StandardFunctionResolution functionResolution; private final RowExpressionService rowExpressionService; - private final boolean useOrcColumnNames; private final HdfsEnvironment hdfsEnvironment; private final FileFormatDataSourceStats stats; private final int domainCompactionThreshold; @@ -157,7 +156,6 @@ public OrcSelectivePageSourceFactory( typeManager, functionResolution, rowExpressionService, - requireNonNull(config, "hiveClientConfig is null").isUseOrcColumnNames(), hdfsEnvironment, stats, config.getDomainCompactionThreshold(), @@ -170,7 +168,6 @@ public OrcSelectivePageSourceFactory( TypeManager typeManager, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, - boolean useOrcColumnNames, HdfsEnvironment hdfsEnvironment, FileFormatDataSourceStats stats, int domainCompactionThreshold, @@ -181,7 +178,6 @@ public OrcSelectivePageSourceFactory( this.typeManager = requireNonNull(typeManager, "typeManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); - this.useOrcColumnNames = useOrcColumnNames; this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.stats = requireNonNull(stats, "stats is null"); this.domainCompactionThreshold = domainCompactionThreshold; @@ -231,7 +227,7 @@ public Optional createPageSource( outputColumns, domainPredicate, remainingPredicate, - useOrcColumnNames, + 
isUseOrcColumnNames(session), hiveStorageTimeZone, typeManager, functionResolution, diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/orc/TupleDomainFilterCache.java b/presto-hive/src/main/java/com/facebook/presto/hive/orc/TupleDomainFilterCache.java index 2878071127c4a..665960d80abde 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/orc/TupleDomainFilterCache.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/orc/TupleDomainFilterCache.java @@ -19,11 +19,10 @@ import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; +import jakarta.annotation.Nullable; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.Nullable; - import static com.facebook.presto.common.predicate.TupleDomainFilterUtils.toFilter; import static java.lang.System.identityHashCode; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFilePageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFilePageSourceFactory.java index 57eeb18676ae8..aac36611af58e 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFilePageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFilePageSourceFactory.java @@ -26,13 +26,12 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.Path; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.io.FileNotFoundException; import java.io.IOException; import java.util.List; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFileWriter.java b/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFileWriter.java index 
5d270bbd99835..76f676d1f1770 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFileWriter.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFileWriter.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.hive.pagefile; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.io.DataSink; import com.facebook.presto.hive.HiveCompressionCodec; import com.facebook.presto.hive.HiveFileWriter; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.page.PagesSerde; -import io.airlift.units.DataSize; import org.openjdk.jol.info.ClassLayout; import java.io.IOException; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFileWriterFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFileWriterFactory.java index 8381270ce2cc3..1ab143216fc32 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFileWriterFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageFileWriterFactory.java @@ -36,12 +36,11 @@ import io.airlift.compress.lz4.Lz4Decompressor; import io.airlift.compress.snappy.SnappyCompressor; import io.airlift.compress.snappy.SnappyDecompressor; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.JobConf; -import javax.inject.Inject; - import java.io.IOException; import java.util.List; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageWriter.java b/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageWriter.java index f5387f67fbb91..807685aac053b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageWriter.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/pagefile/PageWriter.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.hive.pagefile; 
+import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.io.DataOutput; import com.facebook.presto.common.io.DataSink; import com.facebook.presto.hive.HiveCompressionCodec; import com.facebook.presto.spi.page.PageDataOutput; import com.facebook.presto.spi.page.SerializedPage; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.openjdk.jol.info.ClassLayout; import java.io.Closeable; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetAggregatedPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetAggregatedPageSourceFactory.java index 05b3814e04a96..0a1a274d01d75 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetAggregatedPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetAggregatedPageSourceFactory.java @@ -28,14 +28,13 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.Path; import org.apache.parquet.crypto.InternalFileDecryptor; import org.apache.parquet.hadoop.metadata.ParquetMetadata; -import javax.inject.Inject; - import java.io.IOException; import java.util.List; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java index 1e9259fb59e8b..b7196dec561ea 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetPageSourceFactory.java @@ -45,6 +45,7 @@ import com.google.common.collect.ImmutableList; import 
com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.Path; @@ -64,8 +65,6 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -89,6 +88,7 @@ import static com.facebook.presto.common.type.StandardTypes.ROW; import static com.facebook.presto.common.type.StandardTypes.SMALLINT; import static com.facebook.presto.common.type.StandardTypes.TIMESTAMP; +import static com.facebook.presto.common.type.StandardTypes.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.common.type.StandardTypes.TINYINT; import static com.facebook.presto.common.type.StandardTypes.VARBINARY; import static com.facebook.presto.common.type.StandardTypes.VARCHAR; @@ -451,7 +451,8 @@ public static boolean checkSchemaMatch(org.apache.parquet.schema.Type parquetTyp PrimitiveTypeName parquetTypeName = parquetType.asPrimitiveType().getPrimitiveTypeName(); switch (parquetTypeName) { case INT64: - return prestoType.equals(BIGINT) || prestoType.equals(DECIMAL) || prestoType.equals(TIMESTAMP) || prestoType.equals(StandardTypes.REAL) || prestoType.equals(StandardTypes.DOUBLE); + return prestoType.equals(BIGINT) || prestoType.equals(DECIMAL) || prestoType.equals(TIMESTAMP) || prestoType.equals(StandardTypes.REAL) || prestoType.equals(StandardTypes.DOUBLE) + || prestoType.equals(TIMESTAMP_WITH_TIME_ZONE); case INT32: return prestoType.equals(INTEGER) || prestoType.equals(BIGINT) || prestoType.equals(SMALLINT) || prestoType.equals(DATE) || prestoType.equals(DECIMAL) || prestoType.equals(TINYINT) || prestoType.equals(REAL) || prestoType.equals(StandardTypes.DOUBLE); @@ -464,7 +465,7 @@ public static boolean checkSchemaMatch(org.apache.parquet.schema.Type 
parquetTyp case BINARY: return prestoType.equals(VARBINARY) || prestoType.equals(VARCHAR) || prestoType.startsWith(CHAR) || prestoType.equals(DECIMAL); case INT96: - return prestoType.equals(TIMESTAMP); + return prestoType.equals(TIMESTAMP) || prestoType.equals(TIMESTAMP_WITH_TIME_ZONE); case FIXED_LEN_BYTE_ARRAY: return prestoType.equals(DECIMAL); default: diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetSelectivePageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetSelectivePageSourceFactory.java index 08922058d363b..450e80d3f3e83 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetSelectivePageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/parquet/ParquetSelectivePageSourceFactory.java @@ -28,11 +28,10 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.relation.RowExpression; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.util.List; import java.util.Map; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSource.java b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSource.java index a37427cc52fac..ae0899dea1f61 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSource.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSource.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.rcfile; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; @@ -28,7 +29,6 @@ import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.PrestoException; import 
com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.io.IOException; import java.util.List; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java index e9106bd542fc8..4b1e18e4647dc 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/rcfile/RcFilePageSourceFactory.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.hive.rcfile; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; @@ -38,8 +40,7 @@ import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.Path; @@ -47,8 +48,6 @@ import org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.io.FileNotFoundException; import java.io.IOException; import java.util.Arrays; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3/HiveS3Config.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3/HiveS3Config.java index 5c835615cf57e..a49850a4c692b 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/s3/HiveS3Config.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/s3/HiveS3Config.java @@ -16,19 +16,18 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; import 
com.facebook.airlift.configuration.ConfigSecuritySensitive; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDataSize; +import com.facebook.airlift.units.MinDuration; import com.google.common.base.StandardSystemProperty; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; -import io.airlift.units.MinDataSize; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.io.File; import java.util.concurrent.TimeUnit; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class HiveS3Config { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3ClientFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3ClientFactory.java index 37102d6be6e67..8008d36a2db32 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3ClientFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3ClientFactory.java @@ -26,12 +26,11 @@ import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.AmazonS3Builder; import com.amazonaws.services.s3.AmazonS3Client; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.HiveClientConfig; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.apache.hadoop.conf.Configuration; -import javax.annotation.concurrent.GuardedBy; - import java.net.URI; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3ConfigurationUpdater.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3ConfigurationUpdater.java index d68a21dd58045..5323f39fdb2cc 100644 --- 
a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3ConfigurationUpdater.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3ConfigurationUpdater.java @@ -13,12 +13,11 @@ */ package com.facebook.presto.hive.s3; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; -import javax.inject.Inject; - import java.io.File; public class PrestoS3ConfigurationUpdater diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystem.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystem.java index bc71a852de713..f402478b1d7b7 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystem.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystem.java @@ -54,6 +54,8 @@ import com.amazonaws.services.s3.transfer.TransferManagerBuilder; import com.amazonaws.services.s3.transfer.Upload; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.filesystem.ExtendedFileSystem; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.AbstractSequentialIterator; @@ -61,8 +63,6 @@ import com.google.common.collect.Iterators; import com.google.common.io.Closer; import com.google.common.net.MediaType; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.apache.hadoop.conf.Configurable; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.BlockLocation; @@ -105,6 +105,7 @@ import static com.amazonaws.services.s3.Headers.UNENCRYPTED_CONTENT_LENGTH; import static com.amazonaws.services.s3.model.StorageClass.DeepArchive; import static com.amazonaws.services.s3.model.StorageClass.Glacier; +import static 
com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.hive.RetryDriver.retry; import static com.facebook.presto.hive.s3.S3ConfigurationUpdater.S3_ACCESS_KEY; import static com.facebook.presto.hive.s3.S3ConfigurationUpdater.S3_ACL_TYPE; @@ -146,7 +147,6 @@ import static com.google.common.base.Throwables.throwIfUnchecked; import static com.google.common.base.Verify.verify; import static com.google.common.collect.Iterables.toArray; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.Math.max; import static java.lang.Math.toIntExact; import static java.lang.String.format; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystemMetricCollector.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystemMetricCollector.java index 97c5fe0c220ba..8b0ca6e75c631 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystemMetricCollector.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystemMetricCollector.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.hive.s3; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.aws.AbstractSdkMetricsCollector; -import io.airlift.units.Duration; import static java.util.Objects.requireNonNull; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystemStats.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystemStats.java index 880182aba154f..ef429fc973993 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystemStats.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/s3/PrestoS3FileSystemStats.java @@ -16,7 +16,7 @@ import com.amazonaws.AbortedException; import com.facebook.airlift.stats.CounterStat; import com.facebook.airlift.stats.TimeStat; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; import org.weakref.jmx.Managed; import 
org.weakref.jmx.Nested; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3/security/AWSS3SecurityMappingConfigurationProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3/security/AWSS3SecurityMappingConfigurationProvider.java index a09e391377376..7bff58e38dd08 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/s3/security/AWSS3SecurityMappingConfigurationProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/s3/security/AWSS3SecurityMappingConfigurationProvider.java @@ -19,10 +19,9 @@ import com.facebook.presto.hive.aws.security.AWSSecurityMappings; import com.facebook.presto.hive.aws.security.AWSSecurityMappingsSupplier; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; -import javax.inject.Inject; - import java.net.URI; import java.util.Set; import java.util.function.Supplier; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3/security/ForAWSS3DynamicConfigurationProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3/security/ForAWSS3DynamicConfigurationProvider.java index 446032ade96c0..05311a04ac87e 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/s3/security/ForAWSS3DynamicConfigurationProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/s3/security/ForAWSS3DynamicConfigurationProvider.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive.s3.security; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3select/S3SelectLineRecordReader.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3select/S3SelectLineRecordReader.java index 24d394b2d132d..f271566e193b1 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/s3select/S3SelectLineRecordReader.java +++ 
b/presto-hive/src/main/java/com/facebook/presto/hive/s3select/S3SelectLineRecordReader.java @@ -20,6 +20,7 @@ import com.amazonaws.services.s3.model.OutputSerialization; import com.amazonaws.services.s3.model.ScanRange; import com.amazonaws.services.s3.model.SelectObjectContentRequest; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.HiveClientConfig; import com.facebook.presto.hive.s3.HiveS3Config; import com.facebook.presto.hive.s3.PrestoS3ClientFactory; @@ -28,7 +29,6 @@ import com.facebook.presto.spi.PrestoException; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closer; -import io.airlift.units.Duration; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/s3select/S3SelectRecordCursorProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/s3select/S3SelectRecordCursorProvider.java index 4552e71944b71..1a3497fcf44e4 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/s3select/S3SelectRecordCursorProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/s3select/S3SelectRecordCursorProvider.java @@ -26,12 +26,11 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.RecordCursor; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.joda.time.DateTimeZone; -import javax.inject.Inject; - import java.io.IOException; import java.util.List; import java.util.Optional; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java index 59c2fe40a9a3a..3ccb4db413a69 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java +++ 
b/presto-hive/src/main/java/com/facebook/presto/hive/security/LegacyAccessControl.java @@ -30,8 +30,7 @@ import com.facebook.presto.spi.security.ViewExpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -102,6 +101,11 @@ public Set filterSchemas(ConnectorTransactionHandle transactionHandle, C return schemaNames; } + @Override + public void checkCanShowCreateTable(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + } + @Override public void checkCanCreateTable(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { @@ -157,6 +161,17 @@ public Set filterTables(ConnectorTransactionHandle transactionH return tableNames; } + @Override + public void checkCanShowColumnsMetadata(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + } + + @Override + public List filterColumns(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + return columns; + } + @Override public void checkCanAddColumn(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { @@ -186,6 +201,11 @@ public void checkCanSelectFromColumns(ConnectorTransactionHandle transactionHand { } + @Override + public void checkCanCallProcedure(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName procedureName) + { + } + @Override public void checkCanInsertIntoTable(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { @@ 
-276,6 +296,16 @@ public void checkCanShowRoleGrants(ConnectorTransactionHandle transactionHandle, { } + @Override + public void checkCanDropBranch(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + } + + @Override + public void checkCanDropTag(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + } + @Override public void checkCanDropConstraint(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/SecurityConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/SecurityConfig.java index 9e6ed10b97f61..0f4bbd16bfbf7 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/SecurityConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/SecurityConfig.java @@ -14,12 +14,12 @@ package com.facebook.presto.hive.security; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class SecurityConfig { private String securitySystem = "legacy"; + private boolean restrictProcedureCall = true; @NotNull public String getSecuritySystem() @@ -33,4 +33,16 @@ public SecurityConfig setSecuritySystem(String securitySystem) this.securitySystem = securitySystem; return this; } + + public boolean isRestrictProcedureCall() + { + return restrictProcedureCall; + } + + @Config("hive.restrict-procedure-call") + public SecurityConfig setRestrictProcedureCall(boolean restrictProcedureCall) + { + this.restrictProcedureCall = restrictProcedureCall; + return this; + } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java 
b/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java index 5f225ba0a094c..b230ffc0d8249 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/SqlStandardAccessControl.java @@ -34,8 +34,7 @@ import com.facebook.presto.spi.security.ViewExpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -57,17 +56,20 @@ import static com.facebook.presto.hive.metastore.thrift.ThriftMetastoreUtil.listEnabledTablePrivileges; import static com.facebook.presto.spi.security.AccessDeniedException.denyAddColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyAddConstraint; +import static com.facebook.presto.spi.security.AccessDeniedException.denyCallProcedure; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateRole; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateViewWithSelect; import static com.facebook.presto.spi.security.AccessDeniedException.denyDeleteTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropBranch; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropConstraint; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropRole; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropSchema; import static 
com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTag; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropView; import static com.facebook.presto.spi.security.AccessDeniedException.denyGrantRoles; import static com.facebook.presto.spi.security.AccessDeniedException.denyGrantTablePrivilege; @@ -82,6 +84,8 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denySetCatalogSessionProperty; import static com.facebook.presto.spi.security.AccessDeniedException.denySetRole; import static com.facebook.presto.spi.security.AccessDeniedException.denySetTableProperties; +import static com.facebook.presto.spi.security.AccessDeniedException.denyShowColumnsMetadata; +import static com.facebook.presto.spi.security.AccessDeniedException.denyShowCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyShowRoles; import static com.facebook.presto.spi.security.AccessDeniedException.denyTruncateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyUpdateTableColumns; @@ -100,30 +104,23 @@ public class SqlStandardAccessControl private final String connectorId; private final HiveTransactionManager hiveTransactionManager; + private final boolean restrictProcedureCall; @Inject public SqlStandardAccessControl( HiveConnectorId connectorId, - HiveTransactionManager hiveTransactionManager) + HiveTransactionManager hiveTransactionManager, + SecurityConfig securityConfig) { this.connectorId = requireNonNull(connectorId, "connectorId is null").toString(); this.hiveTransactionManager = requireNonNull(hiveTransactionManager, "hiveTransactionManager is null"); + this.restrictProcedureCall = requireNonNull(securityConfig, "securityConfig is null").isRestrictProcedureCall(); } @Override public void checkCanCreateSchema(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext 
context, String schemaName) { - // TODO: Refactor code to inject metastore headers using AccessControlContext instead of empty() - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isAdmin(transaction, identity, metastoreContext)) { denyCreateSchema(schemaName); } @@ -132,16 +129,7 @@ public void checkCanCreateSchema(ConnectorTransactionHandle transaction, Connect @Override public void checkCanDropSchema(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, String schemaName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isDatabaseOwner(transaction, identity, metastoreContext, schemaName)) { denyDropSchema(schemaName); } @@ -150,16 +138,7 @@ public void checkCanDropSchema(ConnectorTransactionHandle transaction, Connector @Override public void checkCanRenameSchema(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, String schemaName, String newSchemaName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - 
context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isDatabaseOwner(transaction, identity, metastoreContext, schemaName)) { denyRenameSchema(schemaName, newSchemaName); } @@ -176,19 +155,20 @@ public Set filterSchemas(ConnectorTransactionHandle transactionHandle, C return schemaNames; } + @Override + public void checkCanShowCreateTable(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + MetastoreContext metastoreContext = createMetastoreContext(identity, context); + + if (!checkTablePermission(transactionHandle, identity, metastoreContext, tableName, SELECT, true)) { + denyShowCreateTable(tableName.toString()); + } + } + @Override public void checkCanCreateTable(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isDatabaseOwner(transaction, identity, metastoreContext, tableName.getSchemaName())) { denyCreateTable(tableName.toString()); } @@ -197,16 +177,7 @@ public void checkCanCreateTable(ConnectorTransactionHandle transaction, Connecto @Override public void checkCanSetTableProperties(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, Map properties) { - MetastoreContext metastoreContext = new MetastoreContext(identity, - context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - 
context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isTableOwner(transactionHandle, identity, metastoreContext, tableName)) { denySetTableProperties(tableName.toString()); } @@ -215,16 +186,7 @@ public void checkCanSetTableProperties(ConnectorTransactionHandle transactionHan @Override public void checkCanDropTable(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isTableOwner(transaction, identity, metastoreContext, tableName)) { denyDropTable(tableName.toString()); } @@ -233,16 +195,7 @@ public void checkCanDropTable(ConnectorTransactionHandle transaction, ConnectorI @Override public void checkCanRenameTable(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, SchemaTableName newTableName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isTableOwner(transaction, identity, metastoreContext, tableName)) { 
denyRenameTable(tableName.toString(), newTableName.toString()); } @@ -259,19 +212,31 @@ public Set filterTables(ConnectorTransactionHandle transactionH return tableNames; } + @Override + public void checkCanShowColumnsMetadata(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + MetastoreContext metastoreContext = createMetastoreContext(identity, context); + + if (!hasAnyTablePermission(transactionHandle, identity, metastoreContext, tableName)) { + denyShowColumnsMetadata(tableName.toString()); + } + } + + @Override + public List filterColumns(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + MetastoreContext metastoreContext = createMetastoreContext(identity, context); + + if (!hasAnyTablePermission(transactionHandle, identity, metastoreContext, tableName)) { + return ImmutableList.of(); + } + return columns; + } + @Override public void checkCanAddColumn(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isTableOwner(transaction, identity, metastoreContext, tableName)) { denyAddColumn(tableName.toString()); } @@ -279,6 +244,15 @@ public void checkCanAddColumn(ConnectorTransactionHandle transaction, ConnectorI @Override public void checkCanDropColumn(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + MetastoreContext 
metastoreContext = createMetastoreContext(identity, context); + if (!isTableOwner(transaction, identity, metastoreContext, tableName)) { + denyDropColumn(tableName.toString()); + } + } + + @Override + public void checkCanDropBranch(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { MetastoreContext metastoreContext = new MetastoreContext( identity, context.getQueryId().getId(), @@ -291,12 +265,12 @@ public void checkCanDropColumn(ConnectorTransactionHandle transaction, Connector context.getWarningCollector(), context.getRuntimeStats()); if (!isTableOwner(transaction, identity, metastoreContext, tableName)) { - denyDropColumn(tableName.toString()); + denyDropBranch(tableName.toString()); } } @Override - public void checkCanDropConstraint(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + public void checkCanDropTag(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { MetastoreContext metastoreContext = new MetastoreContext( identity, context.getQueryId().getId(), @@ -308,6 +282,15 @@ public void checkCanDropConstraint(ConnectorTransactionHandle transaction, Conne HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, context.getWarningCollector(), context.getRuntimeStats()); + if (!isTableOwner(transaction, identity, metastoreContext, tableName)) { + denyDropTag(tableName.toString()); + } + } + + @Override + public void checkCanDropConstraint(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isTableOwner(transaction, identity, metastoreContext, tableName)) { denyDropConstraint(tableName.toString()); } @@ -316,16 +299,7 @@ public void 
checkCanDropConstraint(ConnectorTransactionHandle transaction, Conne @Override public void checkCanAddConstraint(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isTableOwner(transaction, identity, metastoreContext, tableName)) { denyAddConstraint(tableName.toString()); } @@ -334,16 +308,7 @@ public void checkCanAddConstraint(ConnectorTransactionHandle transaction, Connec @Override public void checkCanRenameColumn(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isTableOwner(transaction, identity, metastoreContext, tableName)) { denyRenameColumn(tableName.toString()); } @@ -352,34 +317,24 @@ public void checkCanRenameColumn(ConnectorTransactionHandle transaction, Connect @Override public void checkCanSelectFromColumns(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, Set columnOrSubfieldNames) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - 
context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); // TODO: Implement column level access control + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!checkTablePermission(transaction, identity, metastoreContext, tableName, SELECT, false)) { denySelectTable(tableName.toString()); } } + @Override + public void checkCanCallProcedure(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName procedureName) + { + if (restrictProcedureCall) { + denyCallProcedure(procedureName.toString()); + } + } + @Override public void checkCanInsertIntoTable(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!checkTablePermission(transaction, identity, metastoreContext, tableName, INSERT, false)) { denyInsertTable(tableName.toString()); } @@ -388,16 +343,7 @@ public void checkCanInsertIntoTable(ConnectorTransactionHandle transaction, Conn @Override public void checkCanDeleteFromTable(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - 
HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!checkTablePermission(transaction, identity, metastoreContext, tableName, DELETE, false)) { denyDeleteTable(tableName.toString()); } @@ -406,16 +352,7 @@ public void checkCanDeleteFromTable(ConnectorTransactionHandle transaction, Conn @Override public void checkCanTruncateTable(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!checkTablePermission(transaction, identity, metastoreContext, tableName, DELETE, false)) { denyTruncateTable(tableName.toString()); } @@ -424,16 +361,7 @@ public void checkCanTruncateTable(ConnectorTransactionHandle transaction, Connec @Override public void checkCanUpdateTableColumns(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, Set updatedColumns) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!checkTablePermission(transaction, identity, metastoreContext, tableName, UPDATE, false)) { 
denyUpdateTableColumns(tableName.toString(), updatedColumns); } @@ -442,16 +370,7 @@ public void checkCanUpdateTableColumns(ConnectorTransactionHandle transaction, C @Override public void checkCanCreateView(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName viewName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isDatabaseOwner(transaction, identity, metastoreContext, viewName.getSchemaName())) { denyCreateView(viewName.toString()); } @@ -460,16 +379,7 @@ public void checkCanCreateView(ConnectorTransactionHandle transaction, Connector @Override public void checkCanRenameView(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName viewName, SchemaTableName newViewName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isTableOwner(transaction, identity, metastoreContext, viewName)) { denyRenameView(viewName.toString(), newViewName.toString()); } @@ -478,16 +388,7 @@ public void checkCanRenameView(ConnectorTransactionHandle transaction, Connector @Override public void checkCanDropView(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName viewName) { - 
MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isTableOwner(transaction, identity, metastoreContext, viewName)) { denyDropView(viewName.toString()); } @@ -497,16 +398,7 @@ public void checkCanDropView(ConnectorTransactionHandle transaction, ConnectorId public void checkCanCreateViewWithSelectFromColumns(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, Set columnNames) { checkCanSelectFromColumns(transaction, identity, context, tableName, columnNames.stream().map(column -> new Subfield(column)).collect(toImmutableSet())); - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); // TODO implement column level access control + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!checkTablePermission(transaction, identity, metastoreContext, tableName, SELECT, true)) { denyCreateViewWithSelect(tableName.toString(), identity); } @@ -515,16 +407,7 @@ public void checkCanCreateViewWithSelectFromColumns(ConnectorTransactionHandle t @Override public void checkCanSetCatalogSessionProperty(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, String propertyName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - 
context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isAdmin(transaction, identity, metastoreContext)) { denySetCatalogSessionProperty(connectorId, propertyName); } @@ -533,16 +416,7 @@ public void checkCanSetCatalogSessionProperty(ConnectorTransactionHandle transac @Override public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, Privilege privilege, SchemaTableName tableName, PrestoPrincipal grantee, boolean withGrantOption) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (isTableOwner(transaction, identity, metastoreContext, tableName)) { return; } @@ -555,16 +429,7 @@ public void checkCanGrantTablePrivilege(ConnectorTransactionHandle transaction, @Override public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transaction, ConnectorIdentity identity, AccessControlContext context, Privilege privilege, SchemaTableName tableName, PrestoPrincipal revokee, boolean grantOptionFor) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext 
= createMetastoreContext(identity, context); if (isTableOwner(transaction, identity, metastoreContext, tableName)) { return; } @@ -581,16 +446,7 @@ public void checkCanCreateRole(ConnectorTransactionHandle transactionHandle, Con if (grantor.isPresent()) { throw new AccessDeniedException("Hive Connector does not support WITH ADMIN statement"); } - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isAdmin(transactionHandle, identity, metastoreContext)) { denyCreateRole(role); } @@ -599,16 +455,7 @@ public void checkCanCreateRole(ConnectorTransactionHandle transactionHandle, Con @Override public void checkCanDropRole(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, String role) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isAdmin(transactionHandle, identity, metastoreContext)) { denyDropRole(role); } @@ -621,16 +468,7 @@ public void checkCanGrantRoles(ConnectorTransactionHandle transactionHandle, Con if (grantor.isPresent()) { throw new AccessDeniedException("Hive Connector does not support GRANTED BY statement"); } - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - 
context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!hasAdminOptionForRoles(transactionHandle, identity, metastoreContext, roles)) { denyGrantRoles(roles, grantees); } @@ -643,16 +481,7 @@ public void checkCanRevokeRoles(ConnectorTransactionHandle transactionHandle, Co if (grantor.isPresent()) { throw new AccessDeniedException("Hive Connector does not support GRANTED BY statement"); } - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!hasAdminOptionForRoles(transactionHandle, identity, metastoreContext, roles)) { denyRevokeRoles(roles, grantees); } @@ -663,16 +492,7 @@ public void checkCanSetRole(ConnectorTransactionHandle transaction, ConnectorIde { Optional metastoreOptional = getMetastore(transaction); metastoreOptional.ifPresent(metastore -> { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isRoleApplicable(metastore, identity, new PrestoPrincipal(USER, identity.getUser()), metastoreContext, role)) { denySetRole(role); } @@ -682,16 +502,7 @@ public void checkCanSetRole(ConnectorTransactionHandle transaction, ConnectorIde 
@Override public void checkCanShowRoles(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, String catalogName) { - MetastoreContext metastoreContext = new MetastoreContext( - identity, context.getQueryId().getId(), - context.getClientInfo(), - context.getClientTags(), - context.getSource(), - Optional.empty(), - false, - HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, - context.getWarningCollector(), - context.getRuntimeStats()); + MetastoreContext metastoreContext = createMetastoreContext(identity, context); if (!isAdmin(transactionHandle, identity, metastoreContext)) { denyShowRoles(catalogName); } @@ -719,6 +530,22 @@ public Map getColumnMasks(ConnectorTransactionHa return ImmutableMap.of(); } + private static MetastoreContext createMetastoreContext(ConnectorIdentity identity, AccessControlContext context) + { + // TODO: Refactor code to inject metastore headers using AccessControlContext instead of empty() + return new MetastoreContext( + identity, + context.getQueryId().getId(), + context.getClientInfo(), + context.getClientTags(), + context.getSource(), + Optional.empty(), + false, + HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, + context.getWarningCollector(), + context.getRuntimeStats()); + } + private boolean isAdmin(ConnectorTransactionHandle transaction, ConnectorIdentity identity, MetastoreContext metastoreContext) { return getMetastore(transaction) @@ -823,6 +650,26 @@ private boolean hasAdminOptionForRoles(ConnectorTransactionHandle transaction, C .orElse(false); } + private boolean hasAnyTablePermission(ConnectorTransactionHandle transaction, ConnectorIdentity identity, MetastoreContext metastoreContext, SchemaTableName tableName) + { + if (isAdmin(transaction, identity, metastoreContext)) { + return true; + } + + if (tableName.equals(ROLES)) { + return false; + } + + if (INFORMATION_SCHEMA_NAME.equals(tableName.getSchemaName())) { + return true; + } + + return 
getMetastore(transaction) + .map(metastore -> listEnabledTablePrivileges(metastore, tableName.getSchemaName(), tableName.getTableName(), identity, metastoreContext) + .anyMatch(privilegeInfo -> true)) + .orElse(false); + } + private Optional getMetastore(ConnectorTransactionHandle transaction) { TransactionalMetadata metadata = hiveTransactionManager.get(transaction); diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/SystemTableAwareAccessControl.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/SystemTableAwareAccessControl.java index 74a433e5d6356..f22bb65c01c30 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/SystemTableAwareAccessControl.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/SystemTableAwareAccessControl.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.Subfield; import com.facebook.presto.plugin.base.security.ForwardingConnectorAccessControl; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorAccessControl; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -25,12 +26,14 @@ import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import static com.facebook.presto.hive.HiveMetadata.getSourceTableNameFromSystemTable; import static com.facebook.presto.spi.security.AccessDeniedException.denySelectTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyShowColumnsMetadata; import static java.util.Objects.requireNonNull; public class SystemTableAwareAccessControl @@ -114,6 +117,34 @@ public Set filterTables(ConnectorTransactionHandle transactionH return delegate.filterTables(transactionHandle, identity, context, tableNames); } + @Override + public void 
checkCanShowColumnsMetadata(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + Optional sourceTableName = getSourceTableNameFromSystemTable(tableName); + if (sourceTableName.isPresent()) { + try { + checkCanShowColumnsMetadata(transactionHandle, identity, context, sourceTableName.get()); + return; + } + catch (AccessDeniedException e) { + denyShowColumnsMetadata(tableName.toString()); + } + } + + delegate.checkCanShowColumnsMetadata(transactionHandle, identity, context, tableName); + } + + @Override + public List filterColumns(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + Optional sourceTableName = getSourceTableNameFromSystemTable(tableName); + if (sourceTableName.isPresent()) { + return filterColumns(transactionHandle, identity, context, sourceTableName.get(), columns); + } + + return delegate.filterColumns(transactionHandle, identity, context, tableName, columns); + } + @Override public void checkCanAddColumn(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { @@ -149,6 +180,12 @@ public void checkCanSelectFromColumns(ConnectorTransactionHandle transactionHand delegate.checkCanSelectFromColumns(transactionHandle, identity, context, tableName, columnOrSubfieldNames); } + @Override + public void checkCanCallProcedure(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName procedureName) + { + delegate.checkCanCallProcedure(transactionHandle, identity, context, procedureName); + } + @Override public void checkCanInsertIntoTable(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { @@ -257,6 +294,18 @@ public void 
checkCanShowRoleGrants(ConnectorTransactionHandle transactionHandle, delegate.checkCanShowRoleGrants(transactionHandle, identity, context, catalogName); } + @Override + public void checkCanDropBranch(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + delegate.checkCanDropBranch(transactionHandle, identity, context, tableName); + } + + @Override + public void checkCanDropTag(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + delegate.checkCanDropTag(transactionHandle, identity, context, tableName); + } + @Override public void checkCanDropConstraint(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/ForRangerInfo.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/ForRangerInfo.java index 080910adc927f..0835fb4cbb13d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/ForRangerInfo.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/ForRangerInfo.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.hive.security.ranger; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasedAccessControl.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasedAccessControl.java index 29f5548553d15..bad44d4b92704 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasedAccessControl.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasedAccessControl.java @@ -17,6 +17,8 @@ import 
com.facebook.airlift.http.client.Request; import com.facebook.airlift.json.JsonCodec; import com.facebook.presto.common.Subfield; +import com.facebook.presto.hive.security.SecurityConfig; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorAccessControl; @@ -50,6 +52,7 @@ import static com.facebook.presto.hive.security.ranger.RangerBasedAccessControlConfig.RANGER_REST_USER_GROUP_URL; import static com.facebook.presto.hive.security.ranger.RangerBasedAccessControlConfig.RANGER_REST_USER_ROLES_URL; import static com.facebook.presto.spi.security.AccessDeniedException.denyAddColumn; +import static com.facebook.presto.spi.security.AccessDeniedException.denyCallProcedure; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; @@ -63,6 +66,9 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyRenameColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyRenameTable; import static com.facebook.presto.spi.security.AccessDeniedException.denySelectColumns; +import static com.facebook.presto.spi.security.AccessDeniedException.denyShowColumnsMetadata; +import static com.facebook.presto.spi.security.AccessDeniedException.denyShowCreateTable; +import static com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES; import static com.google.common.base.Suppliers.memoizeWithExpiration; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.lang.String.format; @@ -77,7 +83,8 @@ public class RangerBasedAccessControl implements ConnectorAccessControl { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static 
final ObjectMapper OBJECT_MAPPER = new ObjectMapper() + .configure(FAIL_ON_UNKNOWN_PROPERTIES, false); private static final JsonCodec USER_INFO_CODEC = jsonCodec(Users.class); private static final JsonCodec> ROLES_INFO_CODEC = listJsonCodec(String.class); @@ -86,14 +93,17 @@ public class RangerBasedAccessControl private final Supplier>> userGroupsMapping; private final Supplier servicePolicies; private final HttpClient httpClient; + private final boolean restrictProcedureCall; @Inject - public RangerBasedAccessControl(RangerBasedAccessControlConfig config, @ForRangerInfo HttpClient httpClient) + public RangerBasedAccessControl(RangerBasedAccessControlConfig config, SecurityConfig securityConfig, @ForRangerInfo HttpClient httpClient) { requireNonNull(config, "config is null"); requireNonNull(config.getRangerHttpEndPoint(), "Ranger service http end point is null"); requireNonNull(config.getRangerHiveServiceName(), "Ranger hive service name is null"); + requireNonNull(securityConfig, "securityConfig is null"); this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.restrictProcedureCall = securityConfig.isRestrictProcedureCall(); try { servicePolicies = memoizeWithExpiration( @@ -130,7 +140,7 @@ private ServicePolicies getHiveServicePolicies(RangerBasedAccessControlConfig co return OBJECT_MAPPER.readValue(httpClient.execute(request, createStringResponseHandler()).getBody(), ServicePolicies.class); } catch (IOException e) { - throw new PrestoException(HIVE_RANGER_SERVER_ERROR, format("Unable to fetch policies from %s hive service end point", config.getRangerHiveServiceName())); + throw new PrestoException(HIVE_RANGER_SERVER_ERROR, format("Unable to fetch policies from %s hive service end point", config.getRangerHiveServiceName()), e); } } @@ -277,6 +287,20 @@ public Set filterSchemas(ConnectorTransactionHandle transactionHandle, C return allowedSchemas; } + /** + * Check if identity is allowed to execute SHOW CREATE TABLE or SHOW CREATE VIEW. 
+ * + * @throws AccessDeniedException if not allowed + */ + @Override + public void checkCanShowCreateTable(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + if (!checkAccess(identity, tableName, null, HiveAccessType.SELECT)) { + denyShowCreateTable(tableName.getTableName(), format("Access denied - User [ %s ] does not have [SELECT] " + + "privilege on [ %s/%s ] ", identity.getUser(), tableName.getSchemaName(), tableName.getTableName())); + } + } + /** * Check if identity is allowed to create the specified table in this catalog. * @@ -309,6 +333,33 @@ public Set filterTables(ConnectorTransactionHandle transactionH return allowedTables; } + /** + * Check if identity is allowed to show columns of tables by executing SHOW COLUMNS, DESCRIBE etc. + *

+ * NOTE: This method is only present to give users an error message when listing is not allowed. + * The {@link #filterColumns} method must filter all results for unauthorized users, + * since there are multiple ways to list columns. + * + * @throws AccessDeniedException if not allowed + */ + @Override + public void checkCanShowColumnsMetadata(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + if (!checkAccess(identity, tableName, null, HiveAccessType.SELECT)) { + denyShowColumnsMetadata(tableName.getTableName(), format("Access denied - User [ %s ] does not have [SELECT] " + + "privilege on [ %s/%s ] ", identity.getUser(), tableName.getSchemaName(), tableName.getTableName())); + } + } + + /** + * Filter the list of columns to those visible to the identity. + */ + @Override + public List filterColumns(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName, List columns) + { + return columns; + } + /** * Check if identity is allowed to add columns to the specified table in this catalog. * @@ -371,6 +422,14 @@ public void checkCanSelectFromColumns(ConnectorTransactionHandle transactionHand } } + @Override + public void checkCanCallProcedure(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName procedureName) + { + if (restrictProcedureCall) { + denyCallProcedure(procedureName.toString()); + } + } + /** * Check if identity is allowed to drop the specified table in this catalog. 
* diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasedAccessControlConfig.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasedAccessControlConfig.java index 36c83ff498061..bd9dd7dc16c6d 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasedAccessControlConfig.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasedAccessControlConfig.java @@ -15,10 +15,9 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigSecuritySensitive; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.NotNull; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.NotNull; import java.util.concurrent.TimeUnit; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasicAuthHttpRequestFilter.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasicAuthHttpRequestFilter.java index a25994fca0eea..0900067555013 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasicAuthHttpRequestFilter.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/RangerBasicAuthHttpRequestFilter.java @@ -16,8 +16,7 @@ import com.facebook.airlift.http.client.BasicAuthRequestFilter; import com.facebook.airlift.http.client.HttpRequestFilter; import com.facebook.airlift.http.client.Request; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/Users.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/Users.java index 45f85d8f141c1..83c08815ee832 100644 --- 
a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/Users.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/Users.java @@ -16,8 +16,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/VXUser.java b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/VXUser.java index c3a16c80cd4b6..9b235ec8d0e33 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/VXUser.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/security/ranger/VXUser.java @@ -17,8 +17,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/statistics/QuickStatsProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/QuickStatsProvider.java index fff00dd53f654..02f0aa44d7c54 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/statistics/QuickStatsProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/statistics/QuickStatsProvider.java @@ -32,8 +32,10 @@ import com.facebook.presto.hive.metastore.MetastoreContext; import com.facebook.presto.hive.metastore.Partition; import com.facebook.presto.hive.metastore.PartitionStatistics; +import com.facebook.presto.hive.metastore.StorageFormat; import com.facebook.presto.hive.metastore.Table; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; import 
com.facebook.presto.spi.SchemaTableName; import com.google.common.base.Stopwatch; import com.google.common.cache.Cache; @@ -41,6 +43,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.ql.io.SymlinkTextInputFormat; +import org.apache.hadoop.mapred.InputFormat; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; @@ -61,8 +65,10 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.presto.hive.HiveErrorCode.HIVE_BAD_DATA; import static com.facebook.presto.hive.HivePartition.UNPARTITIONED_ID; import static com.facebook.presto.hive.HiveSessionProperties.getQuickStatsBackgroundBuildTimeout; import static com.facebook.presto.hive.HiveSessionProperties.getQuickStatsInlineBuildTimeout; @@ -70,6 +76,9 @@ import static com.facebook.presto.hive.HiveSessionProperties.isSkipEmptyFilesEnabled; import static com.facebook.presto.hive.HiveSessionProperties.isUseListDirectoryCache; import static com.facebook.presto.hive.HiveUtil.buildDirectoryContextProperties; +import static com.facebook.presto.hive.HiveUtil.getInputFormat; +import static com.facebook.presto.hive.HiveUtil.getTargetPathsHiveFileInfos; +import static com.facebook.presto.hive.HiveUtil.readSymlinkPaths; import static com.facebook.presto.hive.NestedDirectoryPolicy.IGNORED; import static com.facebook.presto.hive.NestedDirectoryPolicy.RECURSE; import static com.facebook.presto.hive.metastore.PartitionStatistics.empty; @@ -323,15 +332,18 @@ private PartitionStatistics buildQuickStats(String partitionKey, String partitio Table resolvedTable = metastore.getTable(metastoreContext, table.getSchemaName(), table.getTableName()).get(); Optional partition; Path path; + 
StorageFormat storageFormat; if (UNPARTITIONED_ID.getPartitionName().equals(partitionId)) { partition = Optional.empty(); path = new Path(resolvedTable.getStorage().getLocation()); + storageFormat = resolvedTable.getStorage().getStorageFormat(); } else { partition = metastore.getPartitionsByNames(metastoreContext, table.getSchemaName(), table.getTableName(), ImmutableList.of(new PartitionNameWithVersion(partitionId, Optional.empty()))).get(partitionId); checkState(partition.isPresent(), "getPartitionsByNames returned no partitions for partition with name [%s]", partitionId); path = new Path(partition.get().getStorage().getLocation()); + storageFormat = partition.get().getStorage().getStorageFormat(); } HdfsContext hdfsContext = new HdfsContext(session, table.getSchemaName(), table.getTableName(), partitionId, false); @@ -347,6 +359,37 @@ private PartitionStatistics buildQuickStats(String partitionKey, String partitio Iterator fileList = directoryLister.list(fs, resolvedTable, path, partition, nameNodeStats, hiveDirectoryContext); + InputFormat inputFormat = getInputFormat(hdfsEnvironment.getConfiguration(hdfsContext, path), storageFormat.getInputFormat(), storageFormat.getSerDe(), false); + if (inputFormat instanceof SymlinkTextInputFormat) { + // For symlinks, follow the paths in the manifest file and create a new iterator of the target files + try { + List targetPaths = readSymlinkPaths(fs, fileList); + + Map> parentToTargets = targetPaths.stream().collect(Collectors.groupingBy(Path::getParent)); + + ImmutableList.Builder targetFileInfoList = ImmutableList.builder(); + + for (Map.Entry> entry : parentToTargets.entrySet()) { + targetFileInfoList.addAll(getTargetPathsHiveFileInfos( + path, + partition, + entry.getKey(), + entry.getValue(), + hiveDirectoryContext, + fs, + directoryLister, + resolvedTable, + nameNodeStats, + session)); + } + + fileList = targetFileInfoList.build().iterator(); + } + catch (IOException e) { + throw new PrestoException(HIVE_BAD_DATA, 
"Error parsing symlinks", e); + } + } + PartitionQuickStats partitionQuickStats = PartitionQuickStats.EMPTY; Stopwatch buildStopwatch = Stopwatch.createStarted(); // Build quick stats one by one from statsBuilderStrategies. Do this until we get a non-empty PartitionQuickStats diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/util/AsyncQueue.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/AsyncQueue.java index c02417b01c003..0d114d590cf14 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/util/AsyncQueue.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/AsyncQueue.java @@ -17,9 +17,8 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayDeque; import java.util.ArrayList; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java index cdd38ffdda171..5bacde49d52d8 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/ConfigurationUtils.java @@ -120,7 +120,7 @@ private static void setCompressionProperties(Configuration config, HiveCompressi config.unset(FileOutputFormat.COMPRESS_CODEC); } // For Parquet - compression.getParquetCompressionCodec().ifPresent(codec -> config.set(ParquetOutputFormat.COMPRESSION, codec.name())); + config.set(ParquetOutputFormat.COMPRESSION, compression.getParquetCompressionCodec().name()); // For SequenceFile config.set(FileOutputFormat.COMPRESS_TYPE, BLOCK.toString()); // For PageFile diff --git 
a/presto-hive/src/main/java/com/facebook/presto/hive/util/InternalHiveSplitFactory.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/InternalHiveSplitFactory.java index 67a04948f7508..24424a95362d3 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/util/InternalHiveSplitFactory.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/InternalHiveSplitFactory.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.util; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.hive.BlockLocation; import com.facebook.presto.hive.EncryptionInformation; @@ -24,7 +25,6 @@ import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.mapred.FileSplit; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/util/SizeBasedSplitWeightProvider.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/SizeBasedSplitWeightProvider.java index 97f135689b762..84cc35e37a565 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/util/SizeBasedSplitWeightProvider.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/SizeBasedSplitWeightProvider.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.hive.util; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.HiveSplitWeightProvider; import com.facebook.presto.spi.SplitWeight; -import io.airlift.units.DataSize; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/util/SortBuffer.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/SortBuffer.java index 8095ffe9e82df..2cf344f57b219 100644 --- 
a/presto-hive/src/main/java/com/facebook/presto/hive/util/SortBuffer.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/SortBuffer.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.util; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.block.Block; @@ -21,7 +22,6 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.PageSorter; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.openjdk.jol.info.ClassLayout; import java.util.ArrayList; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/util/TempFileReader.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/TempFileReader.java index f1a1c8f2cc5ea..3039318e43feb 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/util/TempFileReader.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/TempFileReader.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.util; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.block.Block; @@ -28,7 +29,6 @@ import com.facebook.presto.orc.cache.StorageOrcFileTailSource; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.AbstractIterator; -import io.airlift.units.DataSize; import java.io.IOException; import java.io.InterruptedIOException; @@ -36,11 +36,11 @@ import java.util.List; import java.util.Map; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.hive.HiveErrorCode.HIVE_WRITER_DATA_ERROR; import static com.facebook.presto.orc.DwrfEncryptionProvider.NO_ENCRYPTION; import static com.facebook.presto.orc.OrcEncoding.ORC; import static com.facebook.presto.orc.OrcReader.INITIAL_BATCH_SIZE; -import static 
io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; import static org.joda.time.DateTimeZone.UTC; diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/util/TempFileWriter.java b/presto-hive/src/main/java/com/facebook/presto/hive/util/TempFileWriter.java index ae9a7071b8e44..ed57ac9dd2a8c 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/util/TempFileWriter.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/util/TempFileWriter.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.util; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.NotSupportedException; import com.facebook.presto.common.Page; import com.facebook.presto.common.io.DataSink; @@ -23,7 +24,6 @@ import com.facebook.presto.orc.OrcWriterOptions; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import java.io.Closeable; import java.io.IOException; @@ -32,14 +32,14 @@ import java.util.Optional; import java.util.stream.IntStream; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.orc.DwrfEncryptionProvider.NO_ENCRYPTION; import static com.facebook.presto.orc.NoOpOrcWriterStats.NOOP_WRITER_STATS; import static com.facebook.presto.orc.OrcEncoding.ORC; import static com.facebook.presto.orc.metadata.CompressionKind.LZ4; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static org.joda.time.DateTimeZone.UTC; public class TempFileWriter diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractHiveSslTest.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractHiveSslTest.java index 
84f58ad904d3a..84db4a79b163f 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractHiveSslTest.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractHiveSslTest.java @@ -26,6 +26,7 @@ import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; import static com.facebook.presto.hive.containers.HiveHadoopContainer.HIVE3_IMAGE; +import static com.facebook.presto.tests.SslKeystoreManager.initializeKeystoreAndTruststore; import static com.facebook.presto.tests.sql.TestTable.randomTableSuffix; import static java.lang.String.format; @@ -40,6 +41,7 @@ public abstract class AbstractHiveSslTest AbstractHiveSslTest(Map sslConfig) { + initializeKeystoreAndTruststore(); this.sslConfig = sslConfig; } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestCteExecution.java similarity index 92% rename from presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java rename to presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestCteExecution.java index 1e4f13b963905..587150e09eca5 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestCteExecution.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestCteExecution.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Optional; +import java.util.UUID; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -46,20 +47,19 @@ import static java.lang.String.format; import static org.testng.Assert.assertTrue; -@Test(singleThreaded = true) -public class TestCteExecution +public abstract class AbstractTestCteExecution extends AbstractTestQueryFramework { private static final Pattern CTE_INFO_MATCHER = Pattern.compile("CTEInfo.*"); - @Override - protected QueryRunner createQueryRunner() + protected QueryRunner createQueryRunner(boolean singleNode) throws Exception { return HiveQueryRunner.createQueryRunner( 
ImmutableList.of(ORDERS, CUSTOMER, LINE_ITEM, PART_SUPPLIER, NATION, REGION, PART, SUPPLIER), ImmutableMap.of( - "query.cte-partitioning-provider-catalog", "hive"), + "query.cte-partitioning-provider-catalog", "hive", + "single-node-execution-enabled", "" + singleNode), "sql-standard", ImmutableMap.of("hive.pushdown-filter-enabled", "true", "hive.enable-parquet-dereference-pushdown", "true", @@ -465,6 +465,7 @@ public void testPersistentCteWithVarbinary() QueryRunner queryRunner = getQueryRunner(); verifyResults(queryRunner, testQuery, ImmutableList.of(generateMaterializedCTEInformation("dataset", 1, false, true))); } + @Test public void testComplexRefinedCtesOutsideScope() { @@ -802,41 +803,45 @@ public void testComplexQuery3() public void testSimplePersistentCteForCtasQueries() { QueryRunner queryRunner = getQueryRunner(); + String persistentTableName = generateRandomTableName("persistent_table"); + String nonPersistentTableName = generateRandomTableName("non_persistent_table"); try { // Create tables with Ctas Session materializedSession = getMaterializedSession(); - String testQuery = "CREATE TABLE persistent_table as (WITH temp as (SELECT orderkey FROM ORDERS) " + - "SELECT * FROM temp t1 )"; + String testQuery = format("CREATE TABLE %s as (WITH temp as (SELECT orderkey FROM ORDERS) " + + "SELECT * FROM temp t1 )", persistentTableName); verifyCTEExplainPlan(materializedSession, testQuery, ImmutableList.of(generateMaterializedCTEInformation("temp", 1, false, true))); queryRunner.execute(materializedSession, testQuery); queryRunner.execute(getSession(), - "CREATE TABLE non_persistent_table as (WITH temp as (SELECT orderkey FROM ORDERS) " + - "SELECT * FROM temp t1) "); + format("CREATE TABLE %s as (WITH temp as (SELECT orderkey FROM ORDERS) " + + "SELECT * FROM temp t1) ", nonPersistentTableName)); // Compare contents with a select compareResults(queryRunner.execute(getSession(), - "SELECT * FROM persistent_table"), + "SELECT * FROM " + persistentTableName), 
queryRunner.execute(getSession(), - "SELECT * FROM non_persistent_table")); + "SELECT * FROM " + nonPersistentTableName)); } finally { // drop tables queryRunner.execute(getSession(), - "DROP TABLE persistent_table"); + "DROP TABLE " + persistentTableName); queryRunner.execute(getSession(), - "DROP TABLE non_persistent_table"); + "DROP TABLE " + nonPersistentTableName); } } @Test public void testComplexPersistentCteForCtasQueries() { + String persistentTableName = generateRandomTableName("persistent_table"); + String nonPersistentTableName = generateRandomTableName("non_persistent_table"); QueryRunner queryRunner = getQueryRunner(); try { // Create tables with Ctas Session materializedSession = getMaterializedSession(); - String testQuery = "CREATE TABLE persistent_table as ( " + + String testQuery = format("CREATE TABLE %s as ( " + "WITH supplier_region AS (" + " SELECT s.suppkey, s.name AS supplier_name, n.name AS nation_name, r.name AS region_name " + " FROM SUPPLIER s " + @@ -856,8 +861,8 @@ public void testComplexPersistentCteForCtasQueries() " JOIN NATION n ON pi.nation_name = n.name " + " JOIN REGION r ON pi.region_name = r.name) " + "SELECT * FROM full_supplier_part_info " + - "WHERE part_type LIKE '%BRASS' " + - "ORDER BY region_name, supplier_name)"; + "WHERE part_type LIKE '%%BRASS' " + + "ORDER BY region_name, supplier_name)", persistentTableName); verifyCTEExplainPlan(materializedSession, testQuery, ImmutableList.of(generateMaterializedCTEInformation("supplier_region", 1, false, true), generateMaterializedCTEInformation("supplier_parts", 1, false, true), @@ -866,7 +871,7 @@ public void testComplexPersistentCteForCtasQueries() queryRunner.execute(materializedSession, testQuery); queryRunner.execute(getSession(), - "CREATE TABLE non_persistent_table as ( " + + format("CREATE TABLE %s as ( " + "WITH supplier_region AS (" + " SELECT s.suppkey, s.name AS supplier_name, n.name AS nation_name, r.name AS region_name " + " FROM SUPPLIER s " + @@ -886,67 +891,73 
@@ public void testComplexPersistentCteForCtasQueries() " JOIN NATION n ON pi.nation_name = n.name " + " JOIN REGION r ON pi.region_name = r.name) " + "SELECT * FROM full_supplier_part_info " + - "WHERE part_type LIKE '%BRASS' " + - "ORDER BY region_name, supplier_name)"); + "WHERE part_type LIKE '%%BRASS' " + + "ORDER BY region_name, supplier_name)", nonPersistentTableName)); // Compare contents with a select compareResults(queryRunner.execute(getSession(), - "SELECT * FROM persistent_table"), + "SELECT * FROM " + persistentTableName), queryRunner.execute(getSession(), - "SELECT * FROM non_persistent_table")); + "SELECT * FROM " + nonPersistentTableName)); } finally { // drop tables queryRunner.execute(getSession(), - "DROP TABLE persistent_table"); + "DROP TABLE " + persistentTableName); queryRunner.execute(getSession(), - "DROP TABLE non_persistent_table"); + "DROP TABLE " + nonPersistentTableName); } } @Test public void testSimplePersistentCteForInsertQueries() { + String persistentTableName = generateRandomTableName("persistent_table"); + String nonPersistentTableName = generateRandomTableName("non_persistent_table"); + QueryRunner queryRunner = getQueryRunner(); try { // Create tables without data queryRunner.execute(getSession(), - "CREATE TABLE persistent_table (orderkey BIGINT)"); + format("CREATE TABLE %s (orderkey BIGINT)", persistentTableName)); queryRunner.execute(getSession(), - "CREATE TABLE non_persistent_table (orderkey BIGINT)"); + format("CREATE TABLE %s (orderkey BIGINT)", nonPersistentTableName)); // Insert data into tables using CTEs Session materializedSession = getMaterializedSession(); - String testQuery = "INSERT INTO persistent_table " + + String testQuery = format("INSERT INTO %s " + "WITH temp AS (SELECT orderkey FROM ORDERS) " + - "SELECT * FROM temp"; + "SELECT * FROM temp", persistentTableName); queryRunner.execute(materializedSession, testQuery); queryRunner.execute(getSession(), - "INSERT INTO non_persistent_table " + + 
format("INSERT INTO %s " + "WITH temp AS (SELECT orderkey FROM ORDERS) " + - "SELECT * FROM temp"); + "SELECT * FROM temp", nonPersistentTableName)); // Compare contents with a select compareResults(queryRunner.execute(getSession(), - "SELECT * FROM persistent_table"), + "SELECT * FROM " + persistentTableName), queryRunner.execute(getSession(), - "SELECT * FROM non_persistent_table")); + "SELECT * FROM " + nonPersistentTableName)); verifyCTEExplainPlan(materializedSession, testQuery, ImmutableList.of(generateMaterializedCTEInformation("temp", 1, false, true))); } finally { // drop tables queryRunner.execute(getSession(), - "DROP TABLE persistent_table"); + "DROP TABLE " + persistentTableName); queryRunner.execute(getSession(), - "DROP TABLE non_persistent_table"); + "DROP TABLE " + nonPersistentTableName); } } @Test public void testComplexPersistentCteForInsertQueries() { + String persistentTableName = generateRandomTableName("persistent_table"); + String nonPersistentTableName = generateRandomTableName("non_persistent_table"); + QueryRunner queryRunner = getQueryRunner(); // Create tables without data // Create tables @@ -957,13 +968,13 @@ public void testComplexPersistentCteForInsertQueries() "nation_comment VARCHAR, region_comment VARCHAR)"; queryRunner.execute(getSession(), - "CREATE TABLE persistent_table" + createTableBase); + "CREATE TABLE " + persistentTableName + " " + createTableBase); queryRunner.execute(getSession(), - "CREATE TABLE non_persistent_table" + createTableBase); + "CREATE TABLE " + nonPersistentTableName + " " + createTableBase); Session materializedSession = getMaterializedSession(); - String testQuery = "INSERT INTO persistent_table " + + String testQuery = format("INSERT INTO %s " + "WITH supplier_region AS (" + " SELECT s.suppkey, s.name AS supplier_name, n.name AS nation_name, r.name AS region_name " + " FROM SUPPLIER s " + @@ -983,12 +994,12 @@ public void testComplexPersistentCteForInsertQueries() " JOIN NATION n ON pi.nation_name = 
n.name " + " JOIN REGION r ON pi.region_name = r.name) " + "SELECT * FROM full_supplier_part_info " + - "WHERE part_type LIKE '%BRASS' " + - "ORDER BY region_name, supplier_name"; + "WHERE part_type LIKE '%%BRASS' " + + "ORDER BY region_name, supplier_name", persistentTableName); queryRunner.execute(materializedSession, testQuery); queryRunner.execute(getSession(), - "INSERT INTO non_persistent_table " + + format("INSERT INTO %s " + "WITH supplier_region AS (" + " SELECT s.suppkey, s.name AS supplier_name, n.name AS nation_name, r.name AS region_name " + " FROM SUPPLIER s " + @@ -1008,14 +1019,14 @@ public void testComplexPersistentCteForInsertQueries() " JOIN NATION n ON pi.nation_name = n.name " + " JOIN REGION r ON pi.region_name = r.name) " + "SELECT * FROM full_supplier_part_info " + - "WHERE part_type LIKE '%BRASS' " + - "ORDER BY region_name, supplier_name"); + "WHERE part_type LIKE '%%BRASS' " + + "ORDER BY region_name, supplier_name", nonPersistentTableName)); // Compare contents with a select compareResults(queryRunner.execute(getSession(), - "SELECT * FROM persistent_table"), + "SELECT * FROM " + persistentTableName), queryRunner.execute(getSession(), - "SELECT * FROM non_persistent_table")); + "SELECT * FROM " + nonPersistentTableName)); verifyCTEExplainPlan(materializedSession, testQuery, ImmutableList.of(generateMaterializedCTEInformation("supplier_region", 1, false, true), generateMaterializedCTEInformation("supplier_parts", 1, false, true), @@ -1025,48 +1036,54 @@ public void testComplexPersistentCteForInsertQueries() finally { // drop tables queryRunner.execute(getSession(), - "DROP TABLE persistent_table"); + "DROP TABLE " + persistentTableName); queryRunner.execute(getSession(), - "DROP TABLE non_persistent_table"); + "DROP TABLE " + nonPersistentTableName); } } @Test public void testSimplePersistentCteForViewQueries() { + String persistentViewName = generateRandomTableName("persistent_view"); + String nonPersistentViewName = 
generateRandomTableName("non_persistent_view"); + QueryRunner queryRunner = getQueryRunner(); try { // Create views Session materializedSession = getMaterializedSession(); queryRunner.execute(materializedSession, - "CREATE VIEW persistent_view AS WITH temp AS (SELECT orderkey FROM ORDERS) " + - "SELECT * FROM temp"); + format("CREATE VIEW %s AS WITH temp AS (SELECT orderkey FROM ORDERS) " + + "SELECT * FROM temp", persistentViewName)); queryRunner.execute(getSession(), - "CREATE VIEW non_persistent_view AS WITH temp AS (SELECT orderkey FROM ORDERS) " + - "SELECT * FROM temp"); + format("CREATE VIEW %s AS WITH temp AS (SELECT orderkey FROM ORDERS) " + + "SELECT * FROM temp", nonPersistentViewName)); // Compare contents of views with a select - String testQuery = "SELECT * FROM persistent_view"; + String testQuery = "SELECT * FROM " + persistentViewName; compareResults(queryRunner.execute(getMaterializedSession(), testQuery), - queryRunner.execute(getSession(), "SELECT * FROM non_persistent_view")); + queryRunner.execute(getSession(), "SELECT * FROM " + nonPersistentViewName)); verifyCTEExplainPlan(materializedSession, testQuery, ImmutableList.of(generateMaterializedCTEInformation("temp", 1, false, true))); } finally { // Drop views - queryRunner.execute(getSession(), "DROP VIEW persistent_view"); - queryRunner.execute(getSession(), "DROP VIEW non_persistent_view"); + queryRunner.execute(getSession(), "DROP VIEW " + persistentViewName); + queryRunner.execute(getSession(), "DROP VIEW " + nonPersistentViewName); } } @Test public void testComplexPersistentCteForViewQueries() { + String persistentViewName = generateRandomTableName("persistent_view"); + String nonPersistentViewName = generateRandomTableName("non_persistent_view"); + QueryRunner queryRunner = getQueryRunner(); try { // Create Views Session materializedSession = getMaterializedSession(); queryRunner.execute(materializedSession, - "CREATE View persistent_view as " + + format("CREATE View %s as " + "WITH 
supplier_region AS (" + " SELECT s.suppkey, s.name AS supplier_name, n.name AS nation_name, r.name AS region_name " + " FROM SUPPLIER s " + @@ -1086,10 +1103,10 @@ public void testComplexPersistentCteForViewQueries() " JOIN NATION n ON pi.nation_name = n.name " + " JOIN REGION r ON pi.region_name = r.name) " + "SELECT * FROM full_supplier_part_info " + - "WHERE part_type LIKE '%BRASS' " + - "ORDER BY region_name, supplier_name"); + "WHERE part_type LIKE '%%BRASS' " + + "ORDER BY region_name, supplier_name", persistentViewName)); queryRunner.execute(getSession(), - "CREATE View non_persistent_view as " + + format("CREATE View %s as " + "WITH supplier_region AS (" + " SELECT s.suppkey, s.name AS supplier_name, n.name AS nation_name, r.name AS region_name " + " FROM SUPPLIER s " + @@ -1109,15 +1126,15 @@ public void testComplexPersistentCteForViewQueries() " JOIN NATION n ON pi.nation_name = n.name " + " JOIN REGION r ON pi.region_name = r.name) " + "SELECT * FROM full_supplier_part_info " + - "WHERE part_type LIKE '%BRASS' " + - "ORDER BY region_name, supplier_name"); + "WHERE part_type LIKE '%%BRASS' " + + "ORDER BY region_name, supplier_name", nonPersistentViewName)); // Compare contents with a select - String testQuery = "SELECT * FROM persistent_view"; + String testQuery = "SELECT * FROM " + persistentViewName; compareResults(queryRunner.execute(getMaterializedSession(), testQuery), queryRunner.execute(getSession(), - "SELECT * FROM non_persistent_view")); + "SELECT * FROM " + nonPersistentViewName)); verifyCTEExplainPlan(materializedSession, testQuery, ImmutableList.of(generateMaterializedCTEInformation("supplier_region", 1, false, true), generateMaterializedCTEInformation("supplier_parts", 1, false, true), @@ -1127,9 +1144,9 @@ public void testComplexPersistentCteForViewQueries() finally { // drop views queryRunner.execute(getSession(), - "DROP View persistent_view"); + "DROP View " + persistentViewName); queryRunner.execute(getSession(), - "DROP View 
non_persistent_view"); + "DROP View " + nonPersistentViewName); } } @@ -1269,6 +1286,7 @@ protected Session getSession() .setSystemProperty(CTE_MATERIALIZATION_STRATEGY, "NONE") .build(); } + protected Session getMaterializedSession() { return Session.builder(super.getSession()) @@ -1283,4 +1301,9 @@ private CTEInformation generateMaterializedCTEInformation(String name, int frequ { return new CTEInformation(name, name, frequency, isView, isMaterialized); } + + private String generateRandomTableName(String prefix) + { + return prefix + "_" + UUID.randomUUID().toString().replace("-", ""); + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java index 9035d5b5de72d..9d8dc8d05985d 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveClient.java @@ -15,6 +15,8 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.GroupByHashPageIndexerFactory; import com.facebook.presto.cache.CacheConfig; import com.facebook.presto.common.Page; @@ -27,6 +29,7 @@ import com.facebook.presto.common.predicate.Range; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.predicate.ValueSet; +import com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.NamedTypeSignature; @@ -41,7 +44,6 @@ import com.facebook.presto.hive.LocationService.WriteInfo; import com.facebook.presto.hive.authentication.NoHdfsAuthentication; import com.facebook.presto.hive.datasink.OutputStreamDataSinkFactory; -import 
com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheScope; import com.facebook.presto.hive.metastore.Column; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.HiveColumnStatistics; @@ -49,6 +51,7 @@ import com.facebook.presto.hive.metastore.HivePrivilegeInfo; import com.facebook.presto.hive.metastore.HivePrivilegeInfo.HivePrivilege; import com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore; +import com.facebook.presto.hive.metastore.MetastoreCacheSpecProvider; import com.facebook.presto.hive.metastore.MetastoreContext; import com.facebook.presto.hive.metastore.Partition; import com.facebook.presto.hive.metastore.PartitionStatistics; @@ -139,8 +142,6 @@ import com.google.common.net.HostAndPort; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -188,6 +189,8 @@ import static com.facebook.airlift.testing.Assertions.assertGreaterThanOrEqual; import static com.facebook.airlift.testing.Assertions.assertInstanceOf; import static com.facebook.airlift.testing.Assertions.assertLessThanOrEqual; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; import static com.facebook.presto.common.predicate.TupleDomain.withColumnDomains; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; @@ -281,6 +284,7 @@ import static com.facebook.presto.hive.HiveType.toHiveType; import static com.facebook.presto.hive.HiveUtil.columnExtraInfo; import static com.facebook.presto.hive.LocationHandle.WriteMode.STAGE_AND_MOVE_TO_TARGET_DIRECTORY; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.ALL; import 
static com.facebook.presto.hive.metastore.HiveColumnStatistics.createBinaryColumnStatistics; import static com.facebook.presto.hive.metastore.HiveColumnStatistics.createBooleanColumnStatistics; import static com.facebook.presto.hive.metastore.HiveColumnStatistics.createDateColumnStatistics; @@ -324,8 +328,6 @@ import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator; import static io.airlift.slice.Slices.utf8Slice; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.KILOBYTE; import static java.lang.Float.floatToRawIntBits; import static java.lang.Math.toIntExact; import static java.lang.String.format; @@ -361,7 +363,7 @@ public abstract class AbstractTestHiveClient protected static final String TEST_SERVER_VERSION = "test_version"; protected static final Executor EXECUTOR = Executors.newFixedThreadPool(5); - protected static final PageSinkContext TEST_HIVE_PAGE_SINK_CONTEXT = PageSinkContext.builder().setCommitRequired(false).setConnectorMetadataUpdater(new HiveMetadataUpdater(EXECUTOR)).build(); + protected static final PageSinkContext TEST_HIVE_PAGE_SINK_CONTEXT = PageSinkContext.builder().setCommitRequired(false).build(); private static final Type ARRAY_TYPE = arrayType(createUnboundedVarcharType()); private static final Type MAP_TYPE = mapType(createUnboundedVarcharType(), BIGINT); @@ -385,6 +387,12 @@ public abstract class AbstractTestHiveClient .add(ColumnMetadata.builder().setName("t_row").setType(ROW_TYPE).build()) .build(); + private static final List CREATE_TABLE_COLUMNS_FOR_DROP = ImmutableList.builder() + .add(ColumnMetadata.builder().setName("id").setType(BIGINT).build()) + .add(ColumnMetadata.builder().setName("t_string").setType(createUnboundedVarcharType()).build()) + .add(ColumnMetadata.builder().setName("t_double").setType(DOUBLE).build()) + .build(); + private static final MaterializedResult 
CREATE_TABLE_DATA = MaterializedResult.resultBuilder(SESSION, BIGINT, createUnboundedVarcharType(), TINYINT, SMALLINT, INTEGER, BIGINT, REAL, DOUBLE, BOOLEAN, ARRAY_TYPE, MAP_TYPE, ROW_TYPE) .row(1L, "hello", (byte) 45, (short) 345, 234, 123L, -754.1985f, 43.5, true, ImmutableList.of("apple", "banana"), ImmutableMap.of("one", 1L, "two", 2L), ImmutableList.of("true", 1L, true)) @@ -978,6 +986,12 @@ protected final void setup(String host, int port, String databaseName, String ti HiveClientConfig hiveClientConfig = getHiveClientConfig(); CacheConfig cacheConfig = getCacheConfig(); MetastoreClientConfig metastoreClientConfig = getMetastoreClientConfig(); + // Configure Metastore Cache + metastoreClientConfig.setDefaultMetastoreCacheTtl(Duration.valueOf("1m")); + metastoreClientConfig.setDefaultMetastoreCacheRefreshInterval(Duration.valueOf("15s")); + metastoreClientConfig.setMetastoreCacheMaximumSize(10000); + metastoreClientConfig.setEnabledCaches(ALL.name()); + ThriftHiveMetastoreConfig thriftHiveMetastoreConfig = getThriftHiveMetastoreConfig(); hiveClientConfig.setTimeZone(timeZone); String proxy = System.getProperty("hive.metastore.thrift.client.socks-proxy"); @@ -991,14 +1005,12 @@ protected final void setup(String host, int port, String databaseName, String ti new BridgingHiveMetastore(new ThriftHiveMetastore(hiveCluster, metastoreClientConfig, hdfsEnvironment), new HivePartitionMutator()), executor, false, - Duration.valueOf("1m"), - Duration.valueOf("15s"), 10000, false, - MetastoreCacheScope.ALL, 0.0, metastoreClientConfig.getPartitionCacheColumnCountLimit(), - NOOP_METASTORE_CACHE_STATS); + NOOP_METASTORE_CACHE_STATS, + new MetastoreCacheSpecProvider(metastoreClientConfig)); setup(databaseName, hiveClientConfig, cacheConfig, metastoreClientConfig, metastore); } @@ -1044,7 +1056,6 @@ protected final void setup(String databaseName, HiveClientConfig hiveClientConfi new HivePartitionObjectBuilder(), new HiveEncryptionInformationProvider(ImmutableList.of()), new 
HivePartitionStats(), - new HiveFileRenamer(), DEFAULT_COLUMN_CONVERTER_PROVIDER, new QuickStatsProvider(metastoreClient, HDFS_ENVIRONMENT, DO_NOTHING_DIRECTORY_LISTER, new HiveClientConfig(), new NamenodeStats(), ImmutableList.of()), new HiveTableWritabilityChecker(false)); @@ -1245,6 +1256,12 @@ public RuntimeStats getRuntimeStats() return session.getRuntimeStats(); } + @Override + public Optional getQueryType() + { + return session.getQueryType(); + } + @Override public ConnectorSession forConnectorId(ConnectorId connectorId) { @@ -1646,7 +1663,7 @@ protected void assertExpectedTableLayoutHandle(ConnectorTableLayoutHandle actual assertInstanceOf(expectedTableLayoutHandle, HiveTableLayoutHandle.class); HiveTableLayoutHandle actual = (HiveTableLayoutHandle) actualTableLayoutHandle; HiveTableLayoutHandle expected = (HiveTableLayoutHandle) expectedTableLayoutHandle; - assertExpectedPartitions(actual.getPartitions().get(), expected.getPartitions().get()); + assertExpectedPartitions(actual.getPartitions().map(PartitionSet::getFullyLoadedPartitions).get(), expected.getPartitions().get()); } protected void assertExpectedPartitions(List actualPartitions, Iterable expectedPartitions) @@ -2329,7 +2346,7 @@ private void doTestBucketedTableEvolution(HiveStorageFormat storageFormat, Schem Optional.empty()).getLayout().getHandle(); } else { - layoutHandle = getOnlyElement(metadata.getTableLayouts(session, tableHandle, new Constraint<>(TupleDomain.fromFixedValues(ImmutableMap.of(bucketColumnHandle(), singleBucket))), Optional.empty())).getTableLayout().getHandle(); + layoutHandle = metadata.getTableLayoutForConstraint(session, tableHandle, new Constraint<>(TupleDomain.fromFixedValues(ImmutableMap.of(bucketColumnHandle(), singleBucket))), Optional.empty()).getTableLayout().getHandle(); } result = readTable( @@ -2423,9 +2440,9 @@ private static void assertBucketTableEvolutionResult(MaterializedResult result, private void assertTableIsBucketed(Transaction transaction, 
ConnectorTableHandle tableHandle) { - // the bucketed test tables should have exactly 32 splits + // the bucketed test tables should have ~32 splits List splits = getAllSplits(transaction, tableHandle, TupleDomain.all()); - assertEquals(splits.size(), 32); + assertThat(splits.size()).as("splits.size()").isBetween(31, 32); // verify all paths are unique Set paths = new HashSet<>(); @@ -2687,8 +2704,8 @@ protected ConnectorTableLayout getTableLayout(ConnectorSession session, Connecto Optional.empty()).getLayout(); } - List tableLayoutResults = metadata.getTableLayouts(session, tableHandle, constraint, Optional.empty()); - return getOnlyElement(tableLayoutResults).getTableLayout(); + ConnectorTableLayoutResult tableLayoutResult = metadata.getTableLayoutForConstraint(session, tableHandle, constraint, Optional.empty()); + return tableLayoutResult.getTableLayout(); } @Test @@ -2748,7 +2765,7 @@ private void checkSupportedStorageFormat(HiveStorageFormat storageFormat) } } - @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "Error opening Hive split .*SequenceFile.*EOFException") + @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "Error opening Hive split .*SequenceFile") public void testEmptySequenceFile() throws Exception { @@ -2894,7 +2911,7 @@ public void testTableCreation() SchemaTableName temporaryCreateTableForPageSinkCommit = temporaryTable("create_table_page_sink_commit"); try { doCreateTable(temporaryCreateTable, storageFormat, TEST_HIVE_PAGE_SINK_CONTEXT); - doCreateTable(temporaryCreateTableForPageSinkCommit, storageFormat, PageSinkContext.builder().setCommitRequired(true).setConnectorMetadataUpdater(new HiveMetadataUpdater(EXECUTOR)).build()); + doCreateTable(temporaryCreateTableForPageSinkCommit, storageFormat, PageSinkContext.builder().setCommitRequired(true).build()); } finally { dropTable(temporaryCreateTable); @@ -3227,7 +3244,7 @@ public void testInsert() SchemaTableName 
temporaryInsertTableForPageSinkCommit = temporaryTable("insert_table_page_sink_commit"); try { doInsert(storageFormat, temporaryInsertTable, TEST_HIVE_PAGE_SINK_CONTEXT); - doInsert(storageFormat, temporaryInsertTableForPageSinkCommit, PageSinkContext.builder().setCommitRequired(true).setConnectorMetadataUpdater(new HiveMetadataUpdater(EXECUTOR)).build()); + doInsert(storageFormat, temporaryInsertTableForPageSinkCommit, PageSinkContext.builder().setCommitRequired(true).build()); } finally { dropTable(temporaryInsertTable); @@ -3245,7 +3262,7 @@ public void testInsertIntoNewPartition() SchemaTableName temporaryInsertIntoNewPartitionTableForPageSinkCommit = temporaryTable("insert_new_partitioned_page_sink_commit"); try { doInsertIntoNewPartition(storageFormat, temporaryInsertIntoNewPartitionTable, TEST_HIVE_PAGE_SINK_CONTEXT); - doInsertIntoNewPartition(storageFormat, temporaryInsertIntoNewPartitionTableForPageSinkCommit, PageSinkContext.builder().setCommitRequired(true).setConnectorMetadataUpdater(new HiveMetadataUpdater(EXECUTOR)).build()); + doInsertIntoNewPartition(storageFormat, temporaryInsertIntoNewPartitionTableForPageSinkCommit, PageSinkContext.builder().setCommitRequired(true).build()); } finally { dropTable(temporaryInsertIntoNewPartitionTable); @@ -3263,7 +3280,7 @@ public void testInsertIntoExistingPartition() SchemaTableName temporaryInsertIntoExistingPartitionTableForPageSinkCommit = temporaryTable("insert_existing_partitioned_page_sink_commit"); try { doInsertIntoExistingPartition(storageFormat, temporaryInsertIntoExistingPartitionTable, TEST_HIVE_PAGE_SINK_CONTEXT); - doInsertIntoExistingPartition(storageFormat, temporaryInsertIntoExistingPartitionTableForPageSinkCommit, PageSinkContext.builder().setCommitRequired(true).setConnectorMetadataUpdater(new HiveMetadataUpdater(EXECUTOR)).build()); + doInsertIntoExistingPartition(storageFormat, temporaryInsertIntoExistingPartitionTableForPageSinkCommit, 
PageSinkContext.builder().setCommitRequired(true).build()); } finally { dropTable(temporaryInsertIntoExistingPartitionTable); @@ -3968,14 +3985,14 @@ public void testDropColumn() { SchemaTableName tableName = temporaryTable("test_drop_column"); try { - doCreateEmptyTable(tableName, ORC, CREATE_TABLE_COLUMNS); + doCreateEmptyTable(tableName, ORC, CREATE_TABLE_COLUMNS_FOR_DROP); ExtendedHiveMetastore metastoreClient = getMetastoreClient(); - metastoreClient.dropColumn(METASTORE_CONTEXT, tableName.getSchemaName(), tableName.getTableName(), CREATE_TABLE_COLUMNS.get(0).getName()); + metastoreClient.dropColumn(METASTORE_CONTEXT, tableName.getSchemaName(), tableName.getTableName(), CREATE_TABLE_COLUMNS_FOR_DROP.get(0).getName()); Optional

table = metastoreClient.getTable(METASTORE_CONTEXT, tableName.getSchemaName(), tableName.getTableName()); assertTrue(table.isPresent()); List columns = table.get().getDataColumns(); - assertEquals(columns.get(0).getName(), CREATE_TABLE_COLUMNS.get(1).getName()); - assertFalse(columns.stream().map(Column::getName).anyMatch(colName -> colName.equals(CREATE_TABLE_COLUMNS.get(0).getName()))); + assertEquals(columns.get(0).getName(), CREATE_TABLE_COLUMNS_FOR_DROP.get(1).getName()); + assertFalse(columns.stream().map(Column::getName).anyMatch(colName -> colName.equals(CREATE_TABLE_COLUMNS_FOR_DROP.get(0).getName()))); } finally { dropTable(tableName); @@ -5352,6 +5369,7 @@ protected static List getAllSplits(ConnectorSplitSource splitSou protected List getAllPartitions(ConnectorTableLayoutHandle layoutHandle) { return ((HiveTableLayoutHandle) layoutHandle).getPartitions() + .map(PartitionSet::getFullyLoadedPartitions) .orElseThrow(() -> new AssertionError("layout has no partitions")); } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileSystem.java b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileSystem.java index e27b042b04a66..f2636f98ace9a 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileSystem.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/AbstractTestHiveFileSystem.java @@ -26,6 +26,7 @@ import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.HivePartitionMutator; import com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore; +import com.facebook.presto.hive.metastore.MetastoreCacheSpecProvider; import com.facebook.presto.hive.metastore.MetastoreContext; import com.facebook.presto.hive.metastore.MetastoreOperationResult; import com.facebook.presto.hive.metastore.PrincipalPrivileges; @@ -225,7 +226,6 @@ protected void setup(String host, int port, String databaseName, BiFunction tableLayoutResults = 
metadata.getTableLayouts(session, hiveTableHandle, Constraint.alwaysTrue(), Optional.empty()); - HiveTableLayoutHandle layoutHandle = (HiveTableLayoutHandle) getOnlyElement(tableLayoutResults).getTableLayout().getHandle(); - assertEquals(layoutHandle.getPartitions().get().size(), 1); + ConnectorTableLayoutResult tableLayoutResult = metadata.getTableLayoutForConstraint(session, hiveTableHandle, Constraint.alwaysTrue(), Optional.empty()); + HiveTableLayoutHandle layoutHandle = (HiveTableLayoutHandle) tableLayoutResult.getTableLayout().getHandle(); + assertEquals(layoutHandle.getPartitions().map(PartitionSet::getFullyLoadedPartitions).get().size(), 1); ConnectorSplitSource splitSource = splitManager.getSplits(transaction.getTransactionHandle(), session, layoutHandle, SPLIT_SCHEDULING_CONTEXT); ConnectorSplit split = getOnlyElement(getAllSplits(splitSource)); @@ -509,7 +509,7 @@ public static class TestingHiveMetastore public TestingHiveMetastore(ExtendedHiveMetastore delegate, ExecutorService executor, MetastoreClientConfig metastoreClientConfig, Path basePath, HdfsEnvironment hdfsEnvironment) { - super(delegate, executor, NOOP_METASTORE_CACHE_STATS, metastoreClientConfig); + super(delegate, executor, NOOP_METASTORE_CACHE_STATS, metastoreClientConfig, new MetastoreCacheSpecProvider(metastoreClientConfig)); this.basePath = basePath; this.hdfsEnvironment = hdfsEnvironment; } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/HiveFileSystemTestUtils.java b/presto-hive/src/test/java/com/facebook/presto/hive/HiveFileSystemTestUtils.java index 0d509285ed16a..41f5000b446be 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/HiveFileSystemTestUtils.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/HiveFileSystemTestUtils.java @@ -50,7 +50,6 @@ import static com.facebook.presto.testing.MaterializedResult.materializeSourceDataStream; import static com.google.common.base.Preconditions.checkArgument; import static 
com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.getOnlyElement; public class HiveFileSystemTestUtils { @@ -74,8 +73,8 @@ public static MaterializedResult readTable(SchemaTableName tableName, ConnectorTableHandle table = getTableHandle(metadata, tableName, session); List columnHandles = ImmutableList.copyOf(metadata.getColumnHandles(session, table).values()); - List tableLayoutResults = metadata.getTableLayouts(session, table, Constraint.alwaysTrue(), Optional.empty()); - HiveTableLayoutHandle layoutHandle = (HiveTableLayoutHandle) getOnlyElement(tableLayoutResults).getTableLayout().getHandle(); + ConnectorTableLayoutResult tableLayoutResult = metadata.getTableLayoutForConstraint(session, table, Constraint.alwaysTrue(), Optional.empty()); + HiveTableLayoutHandle layoutHandle = (HiveTableLayoutHandle) tableLayoutResult.getTableLayout().getHandle(); TableHandle tableHandle = new TableHandle(new ConnectorId(tableName.getSchemaName()), table, transaction.getTransactionHandle(), Optional.of(layoutHandle)); metadata.beginQuery(session); @@ -134,8 +133,8 @@ public static MaterializedResult filterTable(SchemaTableName tableName, session = newSession(config); ConnectorTableHandle table = getTableHandle(metadata, tableName, session); - List tableLayoutResults = metadata.getTableLayouts(session, table, Constraint.alwaysTrue(), Optional.empty()); - HiveTableLayoutHandle layoutHandle = (HiveTableLayoutHandle) getOnlyElement(tableLayoutResults).getTableLayout().getHandle(); + ConnectorTableLayoutResult tableLayoutResult = metadata.getTableLayoutForConstraint(session, table, Constraint.alwaysTrue(), Optional.empty()); + HiveTableLayoutHandle layoutHandle = (HiveTableLayoutHandle) tableLayoutResult.getTableLayout().getHandle(); TableHandle tableHandle = new TableHandle(new ConnectorId(tableName.getSchemaName()), table, transaction.getTransactionHandle(), Optional.of(layoutHandle)); metadata.beginQuery(session); @@ -190,8 
+189,8 @@ public static int getSplitsCount(SchemaTableName tableName, session = newSession(config); ConnectorTableHandle table = getTableHandle(metadata, tableName, session); - List tableLayoutResults = metadata.getTableLayouts(session, table, Constraint.alwaysTrue(), Optional.empty()); - HiveTableLayoutHandle layoutHandle = (HiveTableLayoutHandle) getOnlyElement(tableLayoutResults).getTableLayout().getHandle(); + ConnectorTableLayoutResult tableLayoutResult = metadata.getTableLayoutForConstraint(session, table, Constraint.alwaysTrue(), Optional.empty()); + HiveTableLayoutHandle layoutHandle = (HiveTableLayoutHandle) tableLayoutResult.getTableLayout().getHandle(); TableHandle tableHandle = new TableHandle(new ConnectorId(tableName.getSchemaName()), table, transaction.getTransactionHandle(), Optional.of(layoutHandle)); metadata.beginQuery(session); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java b/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java index 37b3b35139a27..14d00eee6dc81 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/HiveTestUtils.java @@ -17,7 +17,9 @@ import com.facebook.airlift.json.smile.SmileCodec; import com.facebook.airlift.log.Logger; import com.facebook.presto.PagesIndexPageSorter; +import com.facebook.presto.Session; import com.facebook.presto.cache.CacheConfig; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.MapType; @@ -51,7 +53,9 @@ import com.facebook.presto.hive.s3.PrestoS3ConfigurationUpdater; import com.facebook.presto.hive.s3select.S3SelectRecordCursorProvider; import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import 
com.facebook.presto.metadata.TableLayout; import com.facebook.presto.operator.PagesIndex; import com.facebook.presto.orc.StorageStripeMetadataSource; import com.facebook.presto.orc.StripeMetadataSourceFactory; @@ -59,7 +63,9 @@ import com.facebook.presto.parquet.cache.MetadataReader; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; import com.facebook.presto.spi.relation.DeterminismEvaluator; @@ -75,7 +81,9 @@ import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.facebook.presto.sql.relational.RowExpressionDomainTranslator; import com.facebook.presto.sql.relational.RowExpressionOptimizer; +import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.testing.TestingConnectorSession; +import com.facebook.presto.tests.DistributedQueryRunner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; @@ -87,13 +95,17 @@ import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.airlift.json.smile.SmileCodec.smileCodec; import static com.facebook.presto.common.type.Decimals.encodeScaledValue; import static com.facebook.presto.hive.HiveDwrfEncryptionProvider.NO_ENCRYPTION; +import static com.facebook.presto.hive.HiveQueryRunner.TPCH_SCHEMA; +import static com.facebook.presto.transaction.TransactionBuilder.transaction; import static java.lang.String.format; import static java.util.stream.Collectors.toList; +import static org.testng.Assert.assertTrue; public final class HiveTestUtils { @@ -188,8 +200,8 @@ public static 
Set getDefaultHiveAggregatedPageS FileFormatDataSourceStats stats = new FileFormatDataSourceStats(); HdfsEnvironment testHdfsEnvironment = createTestHdfsEnvironment(hiveClientConfig, metastoreClientConfig); return ImmutableSet.builder() - .add(new OrcAggregatedPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, FUNCTION_RESOLUTION, hiveClientConfig, testHdfsEnvironment, stats, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))) - .add(new DwrfAggregatedPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, FUNCTION_RESOLUTION, hiveClientConfig, testHdfsEnvironment, stats, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))) + .add(new OrcAggregatedPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, FUNCTION_RESOLUTION, testHdfsEnvironment, stats, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))) + .add(new DwrfAggregatedPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, FUNCTION_RESOLUTION, testHdfsEnvironment, stats, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))) .add(new ParquetAggregatedPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, FUNCTION_RESOLUTION, testHdfsEnvironment, stats, new MetadataReader())) .build(); } @@ -374,4 +386,20 @@ public static List> getAllSessionProperties(HiveClientConfig allSessionProperties.addAll(hiveCommonSessionProperties.getSessionProperties()); return allSessionProperties; } + + public static Object getHiveTableProperty(QueryRunner queryRunner, Session session, String tableName, Function propertyGetter) + { + Metadata metadata = ((DistributedQueryRunner) queryRunner).getCoordinator().getMetadata(); + + return transaction(queryRunner.getTransactionManager(), queryRunner.getAccessControl()) + .readOnly() + .execute(session, transactionSession -> { + Optional tableHandle = metadata.getMetadataResolver(transactionSession).getTableHandle(new QualifiedObjectName("hive", 
TPCH_SCHEMA, tableName)); + assertTrue(tableHandle.isPresent()); + + TableLayout layout = metadata.getLayout(transactionSession, tableHandle.get(), Constraint.alwaysTrue(), Optional.empty()) + .getLayout(); + return propertyGetter.apply((HiveTableLayoutHandle) layout.getNewTableHandle().getLayout().get()); + }); + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestBackgroundHiveSplitLoader.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestBackgroundHiveSplitLoader.java index 5633478e9b10b..0bc9692b58223 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestBackgroundHiveSplitLoader.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestBackgroundHiveSplitLoader.java @@ -14,6 +14,8 @@ package com.facebook.presto.hive; import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType; import com.facebook.presto.hive.HiveBucketing.HiveBucketFilter; @@ -33,8 +35,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.BlockLocation; import org.apache.hadoop.fs.FSDataInputStream; @@ -67,6 +67,9 @@ import java.util.stream.Collectors; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static 
com.facebook.presto.hive.BucketFunctionType.HIVE_COMPATIBLE; @@ -86,9 +89,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.slice.Slices.utf8Slice; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; @@ -123,7 +123,7 @@ public class TestBackgroundHiveSplitLoader private static final Optional BUCKET_PROPERTY = Optional.of( new HiveBucketProperty(ImmutableList.of("col1"), BUCKET_COUNT, ImmutableList.of(), HIVE_COMPATIBLE, Optional.empty())); - private static final Table SIMPLE_TABLE = table(ImmutableList.of(), Optional.empty()); + public static final Table SIMPLE_TABLE = table(ImmutableList.of(), Optional.empty()); private static final Table PARTITIONED_TABLE = table(PARTITION_COLUMNS, BUCKET_PROPERTY); @Test @@ -541,7 +541,7 @@ private static BackgroundHiveSplitLoader backgroundHiveSplitLoader( false); } - private static List samplePartitionMetadatas() + public static List samplePartitionMetadatas() { return ImmutableList.of( new HivePartitionMetadata( diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestDistributedCteExecution.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestDistributedCteExecution.java new file mode 100644 index 0000000000000..cd6f3d4f8a79a --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestDistributedCteExecution.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive; + +import com.facebook.presto.testing.QueryRunner; +import org.testng.annotations.Test; + +@Test(singleThreaded = true) +public class TestDistributedCteExecution + extends AbstractTestCteExecution +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createQueryRunner(false); + } +} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestDistributedQueriesSingleNode.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestDistributedQueriesSingleNode.java index bfaec5f22852d..e3328a940c71b 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestDistributedQueriesSingleNode.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestDistributedQueriesSingleNode.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestDistributedQueries; import com.google.common.collect.ImmutableMap; @@ -32,11 +33,13 @@ protected QueryRunner createQueryRunner() { ImmutableMap.Builder coordinatorProperties = ImmutableMap.builder(); coordinatorProperties.put("single-node-execution-enabled", "true"); - return HiveQueryRunner.createQueryRunner( + QueryRunner queryRunner = HiveQueryRunner.createQueryRunner( getTables(), ImmutableMap.of(), coordinatorProperties.build(), Optional.empty()); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } @Override diff --git 
a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveAnalyze.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveAnalyze.java new file mode 100644 index 0000000000000..ee4f0564597b8 --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveAnalyze.java @@ -0,0 +1,126 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive; + +import com.facebook.presto.Session; +import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.Map; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.String.format; + +@Test +public class TestHiveAnalyze + extends AbstractTestQueryFramework +{ + private static final String CATALOG = "hive"; + private static final String SCHEMA = "test_analyze_schema"; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder().setCatalog(CATALOG).setSchema(SCHEMA).setTimeZoneKey(TimeZoneKey.UTC_KEY).build(); + DistributedQueryRunner queryRunner = 
DistributedQueryRunner.builder(session).setExtraProperties(ImmutableMap.builder().build()).build(); + + queryRunner.installPlugin(new HivePlugin(CATALOG)); + Path catalogDirectory = queryRunner.getCoordinator().getDataDirectory().resolve("hive_data").getParent().resolve("catalog"); + Map properties = ImmutableMap.builder() + .put("hive.metastore", "file") + .put("hive.metastore.catalog.dir", catalogDirectory.toFile().toURI().toString()) + .put("hive.allow-drop-table", "true") + .put("hive.non-managed-table-writes-enabled", "true") + .put("hive.parquet.use-column-names", "true") + .build(); + + queryRunner.createCatalog(CATALOG, CATALOG, properties); + queryRunner.execute(format("CREATE SCHEMA %s.%s", CATALOG, SCHEMA)); + + return queryRunner; + } + + @Test + public void testAnalyzePartitionedTableWithNonCanonicalValues() + throws IOException + { + String tableName = "test_analyze_table_canonicalization"; + assertUpdate(format("CREATE TABLE %s (a_varchar varchar, month varchar) WITH (partitioned_by = ARRAY['month'], external_location='%s')", tableName, com.google.common.io.Files.createTempDir().getPath())); + + assertUpdate(format("INSERT INTO %s VALUES ('A', '01'), ('B', '01'), ('C', '02'), ('D', '03')", tableName), 4); + + String tableLocation = (String) computeActual(format("SELECT DISTINCT regexp_replace(\"$path\", '/[^/]*/[^/]*$', '') FROM %s", tableName)).getOnlyValue(); + + String externalTableName = "external_" + tableName; + assertUpdate(format( + "CREATE TABLE %s (a_varchar varchar, month integer) WITH (partitioned_by = ARRAY['month'], external_location='%s')", externalTableName, tableLocation)); + + assertUpdate(format("CALL system.sync_partition_metadata('%s', '%s', 'ADD')", SCHEMA, externalTableName)); + assertQuery(format("SELECT * FROM \"%s$partitions\"", externalTableName), "SELECT * FROM VALUES 1, 2, 3"); + assertUpdate(format("ANALYZE %s", externalTableName), 4); + assertQuery(format("SHOW STATS FOR %s", externalTableName), + "SELECT * FROM VALUES 
" + + "('a_varchar', 4.0, 2.0, 0.0, null, null, null, null), " + + "('month', null, 3.0, 0.0, null, 1, 3, null), " + + "(null, null, null, null, 4.0, null, null, null)"); + + assertUpdate(format("INSERT INTO %s VALUES ('E', '04')", tableName), 1); + assertUpdate(format("CALL system.sync_partition_metadata('%s', '%s', 'ADD')", SCHEMA, externalTableName)); + assertQuery(format("SELECT * FROM \"%s$partitions\"", externalTableName), "SELECT * FROM VALUES 1, 2, 3, 4"); + assertUpdate(format("ANALYZE %s WITH (partitions = ARRAY[ARRAY['04']])", externalTableName), 1); + assertQuery(format("SHOW STATS FOR %s", externalTableName), + "SELECT * FROM VALUES " + + "('a_varchar', 5.0, 2.0, 0.0, null, null, null, null), " + + "('month', null, 4.0, 0.0, null, 1, 4, null), " + + "(null, null, null, null, 5.0, null, null, null)"); + // TODO fix selective ANALYZE for table with non-canonical partition values + assertQueryFails(format("ANALYZE %s WITH (partitions = ARRAY[ARRAY['4']])", externalTableName), + format("Partition no longer exists: %s.%s/month=4", SCHEMA, externalTableName)); + + assertUpdate(format("DROP TABLE %s", tableName)); + assertUpdate(format("DROP TABLE %s", externalTableName)); + } + + @Test + public void testAnalyzePartitionedTableWithNonCanonicalValuesUnsupportedScenario() + throws IOException + { + String tableName = "test_analyze_table_canonicalization_unsupported"; + assertUpdate(format("CREATE TABLE %s (a_varchar varchar, month varchar) WITH (partitioned_by = ARRAY['month'], external_location='%s')", tableName, com.google.common.io.Files.createTempDir().getPath())); + + assertUpdate(format("INSERT INTO %s VALUES ('A', '1'), ('B', '01'), ('C', '001'), ('D', '02')", tableName), 4); + + String tableLocation = (String) computeActual(format("SELECT DISTINCT regexp_replace(\"$path\", '/[^/]*/[^/]*$', '') FROM %s", tableName)).getOnlyValue(); + + String externalTableName = "external_" + tableName; + assertUpdate(format( + "CREATE TABLE %s (a_varchar varchar, month 
integer) WITH (partitioned_by = ARRAY['month'], external_location='%s')", externalTableName, tableLocation)); + + assertUpdate(format("CALL system.sync_partition_metadata('%s', '%s', 'ADD')", SCHEMA, externalTableName)); + assertQuery(format("SELECT * FROM \"%s$partitions\"", externalTableName), "SELECT * FROM VALUES 1, 1, 1, 2"); + + assertQueryFails(format("ANALYZE %s", externalTableName), + "There are multiple variants of the same partition, e.g. p=1, p=01, p=001. All partitions must follow the same key=value representation"); + + assertUpdate(format("DROP TABLE %s", tableName)); + assertUpdate(format("DROP TABLE %s", externalTableName)); + } +} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveBucketedTablesWithRowId.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveBucketedTablesWithRowId.java new file mode 100644 index 0000000000000..c32fa22ac1116 --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveBucketedTablesWithRowId.java @@ -0,0 +1,240 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.hive; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.airlift.tpch.TpchTable.CUSTOMER; +import static io.airlift.tpch.TpchTable.ORDERS; + +@Test(singleThreaded = true) +public class TestHiveBucketedTablesWithRowId + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return HiveQueryRunner.createQueryRunner( + ImmutableList.of(ORDERS, CUSTOMER), + ImmutableMap.of(), + Optional.empty()); + } + + @BeforeClass + public void setUp() + { + // Create bucketed customer table + assertUpdate("CREATE TABLE customer_bucketed WITH " + + "(bucketed_by = ARRAY['custkey'], bucket_count = 13) " + + "AS SELECT * FROM customer", 1500); + + // Create bucketed orders table + assertUpdate("CREATE TABLE orders_bucketed WITH " + + "(bucketed_by = ARRAY['orderkey'], bucket_count = 11) " + + "AS SELECT * FROM orders", 15000); + + // Verify 
tables are created + assertQuery("SELECT count(*) FROM customer_bucketed", "SELECT count(*) FROM customer"); + assertQuery("SELECT count(*) FROM orders_bucketed", "SELECT count(*) FROM orders"); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + try { + assertUpdate("DROP TABLE IF EXISTS customer_bucketed"); + assertUpdate("DROP TABLE IF EXISTS orders_bucketed"); + } + catch (Exception e) { + // Ignore cleanup errors + } + } + + @Test + public void testRowIdWithBucketColumn() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING, "true") + .build(); + + // Test basic query with both $row_id and $bucket + String sql = "SELECT \"$row_id\", \"$bucket\", custkey, name " + + "FROM customer_bucketed " + + "WHERE \"$bucket\" = 5"; + + assertPlan(session, sql, anyTree( + project(filter(tableScan("customer_bucketed"))))); + + // Test aggregation grouping by both $row_id and $bucket + sql = "SELECT \"$row_id\", \"$bucket\", COUNT(*) " + + "FROM customer_bucketed " + + "GROUP BY \"$row_id\", \"$bucket\""; + + assertPlan(session, sql, anyTree( + aggregation(ImmutableMap.of(), + project(tableScan("customer_bucketed"))))); + + // Test join between bucketed tables using both $row_id and $bucket + sql = "SELECT c.\"$row_id\" AS customer_row_id, " + + "c.\"$bucket\" AS customer_bucket, " + + "o.\"$row_id\" AS order_row_id, " + + "o.\"$bucket\" AS order_bucket, " + + "c.name, o.orderkey " + + "FROM customer_bucketed c " + + "JOIN orders_bucketed o " + + "ON c.custkey = o.custkey " + + "WHERE c.\"$bucket\" IN (1, 3, 5) " + + "AND o.\"$bucket\" IN (2, 4, 6)"; + + assertPlan(session, sql, anyTree( + join( + project(filter(tableScan("customer_bucketed"))), + exchange(anyTree(tableScan("orders_bucketed")))))); + } + + @Test + public void testRowIdUniquePropertyWithBucketing() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + 
.setSystemProperty(UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING, "true") + .build(); + + // Test unique grouping by $row_id with bucket filtering + String sql = "SELECT " + + "customer_row_id, " + + "ARBITRARY(name) AS customer_name, " + + "ARBITRARY(bucket_num) AS customer_bucket, " + + "ARRAY_AGG(orderkey) AS orders_info " + + "FROM (" + + " SELECT " + + " c.\"$row_id\" AS customer_row_id, " + + " c.\"$bucket\" AS bucket_num, " + + " c.name, " + + " o.orderkey " + + " FROM customer_bucketed c " + + " LEFT JOIN orders_bucketed o " + + " ON c.custkey = o.custkey " + + " AND o.orderstatus IN ('O', 'F') " + + " WHERE c.\"$bucket\" < 5 " + + " AND c.nationkey IN (1, 2, 3) " + + ") " + + "GROUP BY customer_row_id"; + + assertPlan(session, sql, anyTree( + aggregation(ImmutableMap.of(), + join( + anyTree(tableScan("customer_bucketed")), + anyTree(tableScan("orders_bucketed")))))); + } + + @Test + public void testRowIdAndBucketInComplexQuery() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING, "true") + .build(); + + // Complex query with both row_id and bucket columns + String sql = "SELECT " + + " unique_id, " + + " bucket_group, " + + " COUNT(*) AS order_count, " + + " AVG(totalprice) AS avg_price " + + "FROM (" + + " SELECT " + + " c.\"$row_id\" AS unique_id, " + + " CASE " + + " WHEN c.\"$bucket\" < 5 THEN 'low' " + + " WHEN c.\"$bucket\" < 10 THEN 'medium' " + + " ELSE 'high' " + + " END AS bucket_group, " + + " o.totalprice " + + " FROM customer_bucketed c " + + " JOIN orders_bucketed o " + + " ON c.custkey = o.custkey " + + " WHERE o.\"$bucket\" % 2 = 0 " + + ") t " + + "GROUP BY unique_id, bucket_group"; + + assertPlan(session, sql, anyTree( + aggregation(ImmutableMap.of(), + project(project(join( + project(tableScan("customer_bucketed")), + anyTree(tableScan("orders_bucketed")))))))); + } + + @Test + public void testDistinctRowIdWithBucketFilter() + { + Session session = 
Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING, "true") + .build(); + + // Test DISTINCT with row_id and bucket filtering + String sql = "SELECT " + + " DISTINCT c.\"$row_id\" AS unique_id, " + + " c.\"$bucket\" AS bucket_num, " + + " c.name " + + "FROM customer_bucketed c " + + "WHERE c.\"$bucket\" BETWEEN 3 AND 8 " + + " AND c.nationkey = 1"; + + assertPlan(session, sql, anyTree( + aggregation(ImmutableMap.of(), + project(filter(tableScan("customer_bucketed")))))); + } + + @Test + public void testRowIdJoinOnBucketColumn() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING, "true") + .build(); + + // Test joining on bucket column while selecting row_id + String sql = "SELECT " + + " c.\"$row_id\" AS customer_row_id, " + + " o.\"$row_id\" AS order_row_id, " + + " c.\"$bucket\" AS shared_bucket, " + + " c.name, " + + " o.orderkey " + + "FROM customer_bucketed c " + + "JOIN orders_bucketed o " + + " ON c.\"$bucket\" = o.\"$bucket\" " + + "WHERE c.\"$bucket\" < 5"; + + assertPlan(session, sql, anyTree( + join( + exchange(anyTree(tableScan("customer_bucketed"))), + exchange(anyTree(tableScan("orders_bucketed")))))); + } +} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java index e51a67fd43502..1d09a4993dff6 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveClientConfig.java @@ -14,14 +14,14 @@ package com.facebook.presto.hive; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; +import com.facebook.airlift.units.Duration; import com.facebook.drift.transport.netty.codec.Protocol; 
import com.facebook.presto.hive.HiveClientConfig.HdfsAuthenticationType; import com.facebook.presto.hive.s3.S3FileSystemType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.time.ZoneId; @@ -29,6 +29,9 @@ import java.util.TimeZone; import java.util.concurrent.TimeUnit; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.hive.BucketFunctionType.HIVE_COMPATIBLE; import static com.facebook.presto.hive.BucketFunctionType.PRESTO_NATIVE; import static com.facebook.presto.hive.HiveClientConfig.InsertExistingPartitionsBehavior.APPEND; @@ -37,9 +40,6 @@ import static com.facebook.presto.hive.HiveStorageFormat.DWRF; import static com.facebook.presto.hive.HiveStorageFormat.ORC; import static com.facebook.presto.hive.TestHiveUtil.nonDefaultTimeZone; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; public class TestHiveClientConfig { @@ -83,7 +83,6 @@ public void testDefaults() .setMaxPartitionsPerWriter(100) .setWriteValidationThreads(16) .setTextMaxLineLength(new DataSize(100, Unit.MEGABYTE)) - .setUseOrcColumnNames(false) .setAssumeCanonicalPartitionKeys(false) .setOrcDefaultBloomFilterFpp(0.05) .setRcfileOptimizedWriterEnabled(true) @@ -129,7 +128,6 @@ public void testDefaults() .setBucketFunctionTypeForCteMaterialization(PRESTO_NATIVE) .setParquetDereferencePushdownEnabled(false) .setIgnoreUnreadablePartition(false) - .setMaxMetadataUpdaterThreads(100) .setPartialAggregationPushdownEnabled(false) .setPartialAggregationPushdownForVariableLengthDatatypesEnabled(false) 
.setFileRenamingEnabled(false) @@ -152,7 +150,7 @@ public void testDefaults() .setHudiTablesUseMergedView(null) .setThriftProtocol(Protocol.BINARY) .setThriftBufferSize(new DataSize(128, BYTE)) - .setCopyOnFirstWriteConfigurationEnabled(true) + .setCopyOnFirstWriteConfigurationEnabled(false) .setPartitionFilteringFromMetastoreEnabled(true) .setParallelParsingOfPartitionValuesEnabled(false) .setMaxParallelParsingConcurrency(100) @@ -166,7 +164,10 @@ public void testDefaults() .setMaxConcurrentParquetQuickStatsCalls(500) .setCteVirtualBucketCount(128) .setSkipEmptyFilesEnabled(false) - .setLegacyTimestampBucketing(false)); + .setOptimizeParsingOfPartitionValues(false) + .setOptimizeParsingOfPartitionValuesThreshold(500) + .setLegacyTimestampBucketing(false) + .setSymlinkOptimizedReaderEnabled(true)); } @Test @@ -209,7 +210,6 @@ public void testExplicitPropertyMappings() .put("hive.max-concurrent-zero-row-file-creations", "100") .put("hive.assume-canonical-partition-keys", "true") .put("hive.text.max-line-length", "13MB") - .put("hive.orc.use-column-names", "true") .put("hive.orc.default-bloom-filter-fpp", "0.96") .put("hive.rcfile-optimized-writer.enabled", "false") .put("hive.rcfile.writer.validate", "true") @@ -255,7 +255,6 @@ public void testExplicitPropertyMappings() .put("hive.bucket-function-type-for-cte-materialization", "HIVE_COMPATIBLE") .put("hive.enable-parquet-dereference-pushdown", "true") .put("hive.ignore-unreadable-partition", "true") - .put("hive.max-metadata-updater-threads", "1000") .put("hive.partial_aggregation_pushdown_enabled", "true") .put("hive.partial_aggregation_pushdown_for_variable_length_datatypes_enabled", "true") .put("hive.file_renaming_enabled", "true") @@ -278,7 +277,7 @@ public void testExplicitPropertyMappings() .put("hive.hudi-tables-use-merged-view", "default.user") .put("hive.internal-communication.thrift-transport-protocol", "COMPACT") .put("hive.internal-communication.thrift-transport-buffer-size", "256B") - 
.put("hive.copy-on-first-write-configuration-enabled", "false") + .put("hive.copy-on-first-write-configuration-enabled", "true") .put("hive.partition-filtering-from-metastore-enabled", "false") .put("hive.parallel-parsing-of-partition-values-enabled", "true") .put("hive.max-parallel-parsing-concurrency", "200") @@ -292,7 +291,10 @@ public void testExplicitPropertyMappings() .put("hive.quick-stats.max-concurrent-calls", "101") .put("hive.cte-virtual-bucket-count", "256") .put("hive.skip-empty-files", "true") + .put("hive.optimize-parsing-of-partition-values-enabled", "true") + .put("hive.optimize-parsing-of-partition-values-threshold", "100") .put("hive.legacy-timestamp-bucketing", "true") + .put("hive.experimental.symlink.optimized-reader.enabled", "false") .build(); HiveClientConfig expected = new HiveClientConfig() @@ -330,7 +332,6 @@ public void testExplicitPropertyMappings() .setDomainSocketPath("/foo") .setS3FileSystemType(S3FileSystemType.EMRFS) .setTextMaxLineLength(new DataSize(13, Unit.MEGABYTE)) - .setUseOrcColumnNames(true) .setAssumeCanonicalPartitionKeys(true) .setOrcDefaultBloomFilterFpp(0.96) .setRcfileOptimizedWriterEnabled(false) @@ -377,7 +378,6 @@ public void testExplicitPropertyMappings() .setBucketFunctionTypeForCteMaterialization(HIVE_COMPATIBLE) .setParquetDereferencePushdownEnabled(true) .setIgnoreUnreadablePartition(true) - .setMaxMetadataUpdaterThreads(1000) .setPartialAggregationPushdownEnabled(true) .setPartialAggregationPushdownForVariableLengthDatatypesEnabled(true) .setFileRenamingEnabled(true) @@ -400,7 +400,7 @@ public void testExplicitPropertyMappings() .setHudiTablesUseMergedView("default.user") .setThriftProtocol(Protocol.COMPACT) .setThriftBufferSize(new DataSize(256, BYTE)) - .setCopyOnFirstWriteConfigurationEnabled(false) + .setCopyOnFirstWriteConfigurationEnabled(true) .setPartitionFilteringFromMetastoreEnabled(false) .setParallelParsingOfPartitionValuesEnabled(true) .setMaxParallelParsingConcurrency(200) @@ -414,7 +414,10 @@ 
public void testExplicitPropertyMappings() .setMaxConcurrentQuickStatsCalls(101) .setSkipEmptyFilesEnabled(true) .setCteVirtualBucketCount(256) - .setLegacyTimestampBucketing(true); + .setOptimizeParsingOfPartitionValues(true) + .setOptimizeParsingOfPartitionValuesThreshold(100) + .setLegacyTimestampBucketing(true) + .setSymlinkOptimizedReaderEnabled(false); ConfigAssertions.assertFullMapping(properties, expected); } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveCommitHandleOutput.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveCommitHandleOutput.java index ead56db6ac6c1..30260dfcf0622 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveCommitHandleOutput.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveCommitHandleOutput.java @@ -206,7 +206,31 @@ public void testCommitOutputForPartitions() assertEquals(handle.getSerializedCommitOutputForRead(new SchemaTableName(TEST_SCHEMA, TEST_TABLE)), serializedCommitOutput); assertTrue(handle.getSerializedCommitOutputForWrite(new SchemaTableName(TEST_SCHEMA, TEST_TABLE)).isEmpty()); + // Add the same partition with the same location, the metastore will generate same commit output. 
+ hiveMeta = getHiveMetadata(metastore, hiveClientConfig, listeningExecutor); + hiveMeta.getMetastore().addPartition( + connectorSession, + TEST_SCHEMA, + TEST_TABLE, + "random_table_path", + false, + createPartition(partitionName, "location1"), + new Path("/" + TEST_TABLE), + PartitionStatistics.empty()); + handle = hiveMeta.commit(); + + assertEquals(handle.getSerializedCommitOutputForRead(new SchemaTableName(TEST_SCHEMA, TEST_TABLE)), ""); + assertFalse(handle.getSerializedCommitOutputForWrite(new SchemaTableName(TEST_SCHEMA, TEST_TABLE)).isEmpty()); + assertEquals(handle.getSerializedCommitOutputForWrite(new SchemaTableName(TEST_SCHEMA, TEST_TABLE)), serializedCommitOutput); + // Add the same partition with different location, it should trigger the metastore to generate different commit output. + // Wait for 1000ms to make sure the commit time changes + try { + Thread.sleep(1000); + } + catch (InterruptedException e) { + // ignored + } hiveMeta = getHiveMetadata(metastore, hiveClientConfig, listeningExecutor); hiveMeta.getMetastore().addPartition( connectorSession, @@ -221,7 +245,7 @@ public void testCommitOutputForPartitions() assertEquals(handle.getSerializedCommitOutputForRead(new SchemaTableName(TEST_SCHEMA, TEST_TABLE)), ""); assertFalse(handle.getSerializedCommitOutputForWrite(new SchemaTableName(TEST_SCHEMA, TEST_TABLE)).isEmpty()); - assertEquals(handle.getSerializedCommitOutputForWrite(new SchemaTableName(TEST_SCHEMA, TEST_TABLE)), serializedCommitOutput); + assertTrue(Long.parseLong(handle.getSerializedCommitOutputForWrite(new SchemaTableName(TEST_SCHEMA, TEST_TABLE))) > Long.parseLong(serializedCommitOutput)); } private HiveMetadata getHiveMetadata(TestingExtendedHiveMetastore metastore, HiveClientConfig hiveClientConfig, ListeningExecutorService listeningExecutor) @@ -257,7 +281,6 @@ private HiveMetadata getHiveMetadata(TestingExtendedHiveMetastore metastore, Hiv new HivePartitionObjectBuilder(), new 
HiveEncryptionInformationProvider(ImmutableList.of()), new HivePartitionStats(), - new HiveFileRenamer(), HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, new QuickStatsProvider(metastore, HDFS_ENVIRONMENT, DO_NOTHING_DIRECTORY_LISTER, new HiveClientConfig(), new NamenodeStats(), ImmutableList.of()), new HiveTableWritabilityChecker(false)); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedNanQueries.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedNanQueries.java index fa2738c63fd72..fc74dbb2f4433 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedNanQueries.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedNanQueries.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestNanQueries; import com.google.common.collect.ImmutableList; @@ -28,6 +29,9 @@ public class TestHiveDistributedNanQueries protected QueryRunner createQueryRunner() throws Exception { - return HiveQueryRunner.createQueryRunner(ImmutableList.of(), ImmutableMap.of("use-new-nan-definition", "true"), ImmutableMap.of(), Optional.empty()); + QueryRunner queryRunner = + HiveQueryRunner.createQueryRunner(ImmutableList.of(), ImmutableMap.of("use-new-nan-definition", "true"), ImmutableMap.of(), Optional.empty()); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueries.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueries.java index 89c4dbf955c96..ae1ce36d23ecc 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueries.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueries.java @@ -14,23 +14,40 @@ 
package com.facebook.presto.hive; import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.hive.TestHiveEventListenerPlugin.TestingHiveEventListener; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestDistributedQueries; import com.google.common.collect.ImmutableSet; +import com.google.common.io.Files; +import org.intellij.lang.annotations.Language; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; import java.util.Set; import static com.facebook.presto.SystemSessionProperties.CTE_MATERIALIZATION_STRATEGY; import static com.facebook.presto.SystemSessionProperties.CTE_PARTITIONING_PROVIDER_CATALOG; +import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE; +import static com.facebook.presto.SystemSessionProperties.PUSHDOWN_SUBFIELDS_ENABLED; +import static com.facebook.presto.SystemSessionProperties.PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS; +import static com.facebook.presto.SystemSessionProperties.PUSHDOWN_SUBFIELDS_FROM_LAMBDA_ENABLED; import static com.facebook.presto.SystemSessionProperties.VERBOSE_OPTIMIZER_INFO_ENABLED; +import static com.facebook.presto.hive.HiveCommonSessionProperties.ORC_USE_COLUMN_NAMES; +import static com.facebook.presto.hive.HiveTestUtils.getHiveTableProperty; import static com.facebook.presto.sql.tree.ExplainType.Type.LOGICAL; +import static com.facebook.presto.tests.QueryAssertions.assertEqualsIgnoreOrder; import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.common.io.MoreFiles.deleteRecursively; 
+import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; import static io.airlift.tpch.TpchTable.getTables; import static java.lang.String.format; import static java.util.stream.Collectors.joining; @@ -46,7 +63,9 @@ public class TestHiveDistributedQueries protected QueryRunner createQueryRunner() throws Exception { - return HiveQueryRunner.createQueryRunner(getTables()); + QueryRunner queryRunner = HiveQueryRunner.createQueryRunner(getTables()); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } @Override @@ -116,5 +135,398 @@ public void testTrackMaterializedCTEs() checkCTEInfo(explain, "tbl2", 1, false, true); } + @Test + public void testPushdownSubfieldForMapFunctionsInLambda() + { + Session enabled = Session.builder(getSession()) + .setSystemProperty(PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS, "true") + .setSystemProperty(PUSHDOWN_SUBFIELDS_ENABLED, "true") + .setSystemProperty(PUSHDOWN_SUBFIELDS_FROM_LAMBDA_ENABLED, "true") + .build(); + Session disabled = Session.builder(getSession()) + .setSystemProperty(PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS, "false") + .setSystemProperty(PUSHDOWN_SUBFIELDS_ENABLED, "false") + .setSystemProperty(PUSHDOWN_SUBFIELDS_FROM_LAMBDA_ENABLED, "false") + .build(); + try { + getQueryRunner().execute( + "CREATE TABLE test_pushdown_subfields AS\n" + + "SELECT * FROM (\n" + + " VALUES \n" + + " (3, '2025-01-08', \n" + + " ARRAY[\n" + + " MAP(ARRAY[-2, 1], ARRAY[0.34, 0.92]),\n" + + " MAP(ARRAY[3, 4], ARRAY[0.12, 0.88])\n" + + " ],\n" + + " ARRAY[\n" + + " MAP(ARRAY['a', 'b'], ARRAY[0.56, 0.44]),\n" + + " MAP(ARRAY['c', 'd'], ARRAY[0.90, 0.10])\n" + + " ]\n" + + " ),\n" + + " (1, '2025-01-02', \n" + + " ARRAY[\n" + + " MAP(ARRAY[1, 2], ARRAY[0.23, 0.45]),\n" + + " MAP(ARRAY[5, 6], ARRAY[0.67, 0.89])\n" + + " ],\n" + + " ARRAY[\n" + + " MAP(ARRAY['x', 'y'], ARRAY[0.78, 0.22]),\n" + + " MAP(ARRAY['z', 'w'], ARRAY[0.11, 0.99])\n" + + " ]\n" + + " ),\n" + + " (7, '2025-01-17', \n" + + " ARRAY[\n" + 
+ " MAP(ARRAY[-1, 0], ARRAY[0.60, 0.70]),\n" + + " MAP(ARRAY[2, 3], ARRAY[0.21, 0.79])\n" + + " ],\n" + + " ARRAY[\n" + + " MAP(ARRAY['m', 'n'], ARRAY[0.43, 0.57]),\n" + + " MAP(ARRAY['o', 'p'], ARRAY[0.25, 0.75])\n" + + " ]\n" + + " ),\n" + + " (2, '2025-01-06', \n" + + " ARRAY[\n" + + " MAP(ARRAY[4, 5], ARRAY[0.75, 0.32]),\n" + + " MAP(ARRAY[6, 7], ARRAY[0.19, 0.46])\n" + + " ],\n" + + " ARRAY[\n" + + " MAP(ARRAY['q', 'r'], ARRAY[0.98, 0.02]),\n" + + " MAP(ARRAY['s', 't'], ARRAY[0.49, 0.51])\n" + + " ]\n" + + " ),\n" + + " (5, '2025-01-14', \n" + + " ARRAY[\n" + + " MAP(ARRAY[8, 9], ARRAY[0.88, 0.99]),\n" + + " MAP(ARRAY[10, 11], ARRAY[0.00, 0.33])\n" + + " ],\n" + + " ARRAY[\n" + + " MAP(ARRAY['u', 'v'], ARRAY[0.66, 0.34]),\n" + + " MAP(ARRAY['w', 'x'], ARRAY[0.17, 0.83])\n" + + " ]\n" + + " )\n" + + ") t(id, ds, array_of_maps_int, array_of_maps_str)"); + + @Language("SQL") String sql = "select transform(array_of_maps_int, item -> map_filter(item, (k, v) -> k = 1)), " + + "transform(array_of_maps_str, item -> map_filter(item, (k, v) -> k = 'x')) from test_pushdown_subfields"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT \n" + + " transform(array_of_maps_int, item -> map_filter(item, (k, v) -> contains(array[-2, 1, 0], k))),\n" + + " transform(array_of_maps_str, item -> map_filter(item, (k, v) -> contains(array['a', 'x'], k)))\n" + + "FROM test_pushdown_subfields"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT \n" + + " transform(array_of_maps_int, item -> map_subset(item, array[-2, 1, 0])),\n" + + " transform(array_of_maps_str, item -> map_subset(item, array['a', 'x']))\n" + + "FROM test_pushdown_subfields"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT \n" + + " transform(array_of_maps_int, item -> map_filter(item, (k, v) -> k in (-2, 1, 0))),\n" + + " transform(array_of_maps_str, item -> map_filter(item, (k, v) -> k in ('a', 'x')))\n" + + "FROM test_pushdown_subfields"; + 
assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT \n" + + " transform(array_of_maps_int, item -> map_filter(item, (k, v) -> k = 1)),\n" + + " transform(array_of_maps_str, item -> map_filter(item, (k, v) -> k = 'a'))\n" + + "FROM test_pushdown_subfields"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT \n" + + " transform(array_of_maps_int, item -> map_filter(item, (k, v) -> contains(array[-2, 1, id], k))),\n" + + " transform(array_of_maps_str, item -> map_filter(item, (k, v) -> contains(array['a', 'x'], k)))\n" + + "FROM test_pushdown_subfields"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT \n" + + " transform(array_of_maps_int, item -> map_subset(item, array[-2, 1, id])),\n" + + " transform(array_of_maps_str, item -> map_subset(item, array['a', 'x']))\n" + + "FROM test_pushdown_subfields"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT \n" + + " transform(array_of_maps_int, item -> map_filter(item, (k, v) -> k in (-2, 1, null))),\n" + + " transform(array_of_maps_str, item -> map_filter(item, (k, v) -> k in ('a', 'x')))\n" + + "FROM test_pushdown_subfields"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT \n" + + " transform(array_of_maps_int, item -> map_filter(item, (k, v) -> k = id)),\n" + + " transform(array_of_maps_str, item -> map_filter(item, (k, v) -> k = 'a'))\n" + + "FROM test_pushdown_subfields"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + } + finally { + getQueryRunner().execute("DROP TABLE IF EXISTS test_pushdown_subfields"); + } + } + + @Test + public void testPushdownSubfieldForMapFunctions() + { + Session enabled = Session.builder(getSession()) + .setSystemProperty(PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS, "true") + .setSystemProperty(PUSHDOWN_SUBFIELDS_ENABLED, "true") + .build(); + Session disabled = Session.builder(getSession()) + .setSystemProperty(PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS, "false") + 
.setSystemProperty(PUSHDOWN_SUBFIELDS_ENABLED, "false") + .build(); + try { + getQueryRunner().execute( + "CREATE TABLE test_pushdown_subfields_map_functions AS\n" + + "SELECT * FROM (\n" + + " VALUES \n" + + " (3, '2025-01-08', MAP(ARRAY[2, 1], ARRAY[0.34, 0.92]), MAP(ARRAY['a', 'b'], ARRAY[0.12, 0.88])),\n" + + " (1, '2025-01-02', MAP(ARRAY[1, 3], ARRAY[0.23, 0.5]), MAP(ARRAY['x', 'y'], ARRAY[0.45, 0.55])),\n" + + " (7, '2025-01-17', MAP(ARRAY[6, 8], ARRAY[0.60, 0.70]), MAP(ARRAY['m', 'n'], ARRAY[0.21, 0.79])),\n" + + " (2, '2025-01-06', MAP(ARRAY[2, 3, 5, 7], ARRAY[0.75, 0.32, 0.19, 0.46]), MAP(ARRAY['p', 'q', 'r'], ARRAY[0.11, 0.22, 0.67])),\n" + + " (5, '2025-01-14', MAP(ARRAY[8, 4, 6], ARRAY[0.88, 0.99, 0.00]), MAP(ARRAY['s', 't', 'u'], ARRAY[0.33, 0.44, 0.23])),\n" + + " (4, '2025-01-12', MAP(ARRAY[7, 3, 2], ARRAY[0.33, 0.44, 0.55]), MAP(ARRAY['v', 'w'], ARRAY[0.66, 0.34])),\n" + + " (8, '2025-01-20', MAP(ARRAY[1, 7, 6], ARRAY[0.35, 0.45, 0.55]), MAP(ARRAY['i', 'j', 'k'], ARRAY[0.78, 0.89, 0.12])),\n" + + " (6, '2025-01-16', MAP(ARRAY[9, 1, 3], ARRAY[0.30, 0.40, 0.50]), MAP(ARRAY['c', 'd'], ARRAY[0.90, 0.10])),\n" + + " (2, '2025-01-05', MAP(ARRAY[3, 4], ARRAY[0.98, 0.21]), MAP(ARRAY['e', 'f'], ARRAY[0.56, 0.44])),\n" + + " (1, '2025-01-04', MAP(ARRAY[1, 2], ARRAY[0.45, 0.67]), MAP(ARRAY['g', 'h'], ARRAY[0.23, 0.77]))\n" + + ") AS t(id, ds, feature, extra_feature)"); + + @Language("SQL") String sql = "select map_filter(feature, (k, v) -> k in (-2, 1, 0)), map_filter(extra_feature, (k, v) -> k in ('a', 'x')) from test_pushdown_subfields_map_functions"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "select map_filter(feature, (k, v) -> contains(array[-2, 1, 0], k)), map_filter(extra_feature, (k, v) -> contains(array['a', 'x'], k)) from test_pushdown_subfields_map_functions"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "select map_filter(feature, (k, v) -> k = 0), map_filter(extra_feature, (k, v) -> k = 'a') from 
test_pushdown_subfields_map_functions"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "select map_subset(feature, array[-2, 1, 0]), map_subset(extra_feature, array['a', 'x']) from test_pushdown_subfields_map_functions"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + + sql = "select map_filter(feature, (k, v) -> k in (-2, 1, id)), map_filter(extra_feature, (k, v) -> k in ('a', 'x')) from test_pushdown_subfields_map_functions"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "select map_filter(feature, (k, v) -> contains(array[-2, 1, id], k)), map_filter(extra_feature, (k, v) -> contains(array['a', 'x', null], k)) from test_pushdown_subfields_map_functions"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "select map_filter(feature, (k, v) -> k = id), map_filter(extra_feature, (k, v) -> k = 'a') from test_pushdown_subfields_map_functions"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "select map_subset(feature, array[id]), map_subset(extra_feature, array['a', null]) from test_pushdown_subfields_map_functions"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + } + finally { + getQueryRunner().execute("DROP TABLE IF EXISTS test_pushdown_subfields_map_functions"); + } + } + + @Test + public void testOrcUseColumnNames() throws IOException, URISyntaxException + { + File externalTableDataDirectory = Files.createTempDir(); + String externalTableDataLocationUri = externalTableDataDirectory.toURI().toString(); + + try { + @Language("SQL") String createManagedTableSql = "" + + "create table test_orc_use_column_names (\n" + + " \"c1\" int,\n" + + " \"c2\" varchar\n" + + ")\n" + + "WITH (\n" + + " format = 'ORC'\n" + + ")"; + + assertUpdate(createManagedTableSql); + assertUpdate(format("insert into test_orc_use_column_names values (1, 'one')"), 1); + String tablePath = (String) getHiveTableProperty(getQueryRunner(), getSession(), "test_orc_use_column_names", (HiveTableLayoutHandle 
table) -> table.getTablePath()); + File managedTableDataDirectory = new File(new URI(tablePath).getRawPath()); + + assertTrue(managedTableDataDirectory.isDirectory(), "Source managed table data directory does not exist: " + managedTableDataDirectory); + File[] orcFiles = managedTableDataDirectory.listFiles(file -> file.isFile() + && !file.getName().contains(".")); + + if (orcFiles != null) { + for (File orcFile : orcFiles) { + File destinationFile = new File(externalTableDataDirectory, orcFile.getName()); + Files.copy(orcFile, destinationFile); + } + } + else { + throw new IllegalStateException("No ORC files found in managed table data directory: " + managedTableDataDirectory); + } + + @Language("SQL") String createMisMatchingExternalTableSql = format("" + + "CREATE TABLE test_orc_use_column_names_mismatching_ext (\n" + + " \"c2\" varchar,\n" + + " \"c1\" int\n" + + ")\n" + + "WITH (\n" + + " external_location = '%s',\n" + + " format = 'ORC'\n" + + ")", + externalTableDataLocationUri); + assertUpdate(createMisMatchingExternalTableSql); + + assertQueryFails(getSession(), + "select * from test_orc_use_column_names_mismatching_ext", + ".*java.io.IOException: Malformed ORC file. 
Can not read SQL type varchar from ORC stream .c1 of type INT.*"); + + Session useOrcColumnNamesSession = Session.builder(getQueryRunner().getDefaultSession()) + .setCatalogSessionProperty("hive", ORC_USE_COLUMN_NAMES, "true").build(); + assertQuerySucceeds(useOrcColumnNamesSession, "select * from test_orc_use_column_names_mismatching_ext"); + + @Language("SQL") String createDifferentColumnNameExternalTableSql = format("" + + "CREATE TABLE test_orc_use_column_names_different_column_name_ext (\n" + + " \"c1\" int,\n" + + " \"c3\" varchar\n" + + ")\n" + + "WITH (\n" + + " external_location = '%s',\n" + + " format = 'ORC'\n" + + ")", + externalTableDataLocationUri); + assertUpdate(createDifferentColumnNameExternalTableSql); + + assertQuery(useOrcColumnNamesSession, "select * from test_orc_use_column_names_different_column_name_ext", "VALUES (1, NULL)"); + + @Language("SQL") String createAdditionalColumnExternalTableSql = format("" + + "CREATE TABLE test_orc_use_column_names_additional_column_ext (\n" + + " \"c1\" int,\n" + + " \"c2\" varchar,\n" + + " \"c3\" varchar\n" + + ")\n" + + "WITH (\n" + + " external_location = '%s',\n" + + " format = 'ORC'\n" + + ")", + externalTableDataLocationUri); + assertUpdate(createAdditionalColumnExternalTableSql); + + assertQuery(useOrcColumnNamesSession, "select * from test_orc_use_column_names_additional_column_ext", "VALUES (1, 'one', NULL)"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS test_orc_use_column_names"); + assertUpdate("DROP TABLE IF EXISTS test_orc_use_column_names_mismatching_ext"); + assertUpdate("DROP TABLE IF EXISTS test_orc_use_column_names_different_column_name_ext"); + assertUpdate("DROP TABLE IF EXISTS test_orc_use_column_names_additional_column_ext"); + + if (externalTableDataDirectory != null) { + deleteRecursively(externalTableDataDirectory.toPath(), ALLOW_INSECURE); + } + } + } + + @Test + public void testCombineMultipleApproxDistinctSameType() + { + Session enabled = Session.builder(getSession()) + 
.setSystemProperty(OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .build(); + Session disabled = Session.builder(getSession()) + .setSystemProperty(OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "false") + .build(); + + // two distinct expressions of same type + assertQueryWithSameQueryRunner(enabled, "SELECT approx_distinct(name), approx_distinct(comment) FROM nation", disabled); + // three distinct expressions of same type + assertQueryWithSameQueryRunner(enabled, "SELECT approx_distinct(name), approx_distinct(comment), approx_distinct(CAST(nationkey AS VARCHAR)) FROM nation", disabled); + // multiple distinct expressions with GROUP BY + assertQueryWithSameQueryRunner(enabled, "SELECT regionkey, approx_distinct(name), approx_distinct(comment) FROM nation GROUP BY regionkey", disabled); + + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(name), approx_distinct(comment), count(*) FROM nation", + disabled); + + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(name), approx_distinct(comment), approx_distinct(nationkey), approx_distinct(regionkey) FROM nation", + disabled); + + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(name, 0.01), approx_distinct(comment, 0.01) FROM nation", + disabled); + + // different error tolerances should NOT be merged - these should still produce correct results + // but won't benefit from the optimization (each approx_distinct runs separately) + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(name, 0.01), approx_distinct(comment, 0.02) FROM nation", + disabled); + + // mix of default and explicit error tolerance should NOT be merged + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(name), approx_distinct(comment, 0.01) FROM nation", + disabled); + + // three columns with different error tolerances - none should be merged + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(name, 0.01), approx_distinct(comment, 0.02), 
approx_distinct(CAST(nationkey AS VARCHAR), 0.05) FROM nation", + disabled); + + // same error tolerance with GROUP BY - should be merged + assertQueryWithSameQueryRunner(enabled, + "SELECT regionkey, approx_distinct(name, 0.01), approx_distinct(comment, 0.01) FROM nation GROUP BY regionkey", + disabled); + + // different error tolerances with GROUP BY - should NOT be merged + assertQueryWithSameQueryRunner(enabled, + "SELECT regionkey, approx_distinct(name, 0.01), approx_distinct(comment, 0.05) FROM nation GROUP BY regionkey", + disabled); + + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(name), approx_distinct(comment) FROM nation WHERE regionkey > 1", + disabled); + + assertQueryWithSameQueryRunner(enabled, + "SELECT regionkey, approx_distinct(name), approx_distinct(comment) FROM nation GROUP BY regionkey HAVING count(*) > 3", + disabled); + + MaterializedResult resultDisabled = computeActual(disabled, "SELECT approx_distinct(name), approx_distinct(comment) FROM nation"); + MaterializedResult resultEnabled = computeActual(enabled, "SELECT approx_distinct(name), approx_distinct(comment) FROM nation"); + assertEqualsIgnoreOrder(resultDisabled, resultEnabled); + + computeActual(enabled, "SELECT approx_distinct(orderstatus), approx_distinct(orderpriority) FROM orders"); + computeActual(enabled, "SELECT orderstatus, approx_distinct(orderpriority), approx_distinct(clerk) FROM orders GROUP BY orderstatus"); + } + + @Test + public void testCombineApproxDistinctWithOtherAggregations() + { + Session enabled = Session.builder(getSession()) + .setSystemProperty(SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .build(); + Session disabled = Session.builder(getSession()) + .setSystemProperty(SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "false") + .build(); + + // approx_distinct with sum + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(orderstatus), 
approx_distinct(orderpriority), sum(totalprice) FROM orders", + disabled); + + // approx_distinct with count + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(orderstatus), approx_distinct(orderpriority), count(*) FROM orders", + disabled); + + // approx_distinct with count distinct + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(orderstatus), approx_distinct(orderpriority), count(DISTINCT custkey) FROM orders", + disabled); + + // approx_distinct with multiple other aggregations + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(orderstatus), approx_distinct(orderpriority), sum(totalprice), count(*), avg(totalprice) FROM orders", + disabled); + + // approx_distinct with GROUP BY and multiple aggregations + assertQueryWithSameQueryRunner(enabled, + "SELECT custkey, approx_distinct(orderstatus), approx_distinct(orderpriority), sum(totalprice), count(*) FROM orders WHERE custkey < 100 GROUP BY custkey", + disabled); + + // approx_distinct with min and max + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(orderstatus), approx_distinct(orderpriority), min(totalprice), max(totalprice), avg(totalprice) FROM orders", + disabled); + + // approx_distinct with max_by + assertQueryWithSameQueryRunner(enabled, + "SELECT custkey, approx_distinct(orderstatus), approx_distinct(orderpriority), max_by(orderkey, totalprice), sum(totalprice) FROM orders WHERE custkey < 100 GROUP BY custkey", + disabled); + + // approx_distinct with multiple count distincts + assertQueryWithSameQueryRunner(enabled, + "SELECT approx_distinct(orderstatus), approx_distinct(orderpriority), count(DISTINCT custkey), count(DISTINCT orderdate) FROM orders", + disabled); + + // approx_distinct with HAVING clause + assertQueryWithSameQueryRunner(enabled, + "SELECT custkey, approx_distinct(orderstatus), approx_distinct(orderpriority), sum(totalprice) FROM orders WHERE custkey < 100 GROUP BY custkey HAVING sum(totalprice) > 10000", + 
disabled); + } + // Hive specific tests should normally go in TestHiveIntegrationSmokeTest } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithExchangeMaterialization.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithExchangeMaterialization.java index 226bbf3888b53..a878243e32b91 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithExchangeMaterialization.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithExchangeMaterialization.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; import com.facebook.presto.Session; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestDistributedQueries; @@ -40,7 +41,9 @@ public class TestHiveDistributedQueriesWithExchangeMaterialization protected QueryRunner createQueryRunner() throws Exception { - return createMaterializingQueryRunner(getTables()); + QueryRunner queryRunner = createMaterializingQueryRunner(getTables()); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } @Test diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithOptimizedRepartitioning.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithOptimizedRepartitioning.java index d0bf8f6499428..2477f9c31cf46 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithOptimizedRepartitioning.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithOptimizedRepartitioning.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.testing.QueryRunner; import 
com.facebook.presto.tests.AbstractTestDistributedQueries; import com.google.common.collect.ImmutableMap; @@ -30,13 +31,15 @@ public class TestHiveDistributedQueriesWithOptimizedRepartitioning protected QueryRunner createQueryRunner() throws Exception { - return HiveQueryRunner.createQueryRunner( + QueryRunner queryRunner = HiveQueryRunner.createQueryRunner( getTables(), ImmutableMap.of( "experimental.optimized-repartitioning", "true", // Use small SerializedPages to force flushing "driver.max-page-partitioning-buffer-size", "10000B"), Optional.empty()); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } @Override diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithThriftRpc.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithThriftRpc.java index 6f877f35a4ce3..1bbc9590dcd50 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithThriftRpc.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveDistributedQueriesWithThriftRpc.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestDistributedQueries; import com.google.common.collect.ImmutableMap; @@ -30,13 +31,15 @@ public class TestHiveDistributedQueriesWithThriftRpc protected QueryRunner createQueryRunner() throws Exception { - return HiveQueryRunner.createQueryRunner( + QueryRunner queryRunner = HiveQueryRunner.createQueryRunner( getTables(), ImmutableMap.of( "internal-communication.task-communication-protocol", "THRIFT", "internal-communication.server-info-communication-protocol", "THRIFT"), ImmutableMap.of(), Optional.empty()); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } @Override diff --git 
a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileBasedSecurity.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileBasedSecurity.java index d895de3dd5855..7e768787f9f8b 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileBasedSecurity.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileBasedSecurity.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; import com.facebook.presto.Session; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.security.Identity; import com.facebook.presto.testing.QueryRunner; import com.google.common.collect.ImmutableList; @@ -25,8 +26,11 @@ import java.util.Optional; import static com.facebook.presto.hive.HiveQueryRunner.createQueryRunner; +import static com.facebook.presto.hive.HiveSessionProperties.USE_LIST_DIRECTORY_CACHE; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static io.airlift.tpch.TpchTable.NATION; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.ThrowingRunnable; public class TestHiveFileBasedSecurity { @@ -61,6 +65,31 @@ public void testNonAdminCannotRead() queryRunner.execute(bob, "SELECT * FROM nation"); } + @Test + public void testCallProcedures() + { + Session admin = Session.builder(getSession("hive")) + .setConnectionProperty(new ConnectorId("hive"), USE_LIST_DIRECTORY_CACHE, "true") + .build(); + queryRunner.execute(admin, "call hive.system.invalidate_directory_list_cache()"); + + Session alice = Session.builder(getSession("alice")) + .setConnectionProperty(new ConnectorId("hive"), USE_LIST_DIRECTORY_CACHE, "true") + .build(); + queryRunner.execute(alice, "call hive.system.invalidate_directory_list_cache()"); + + Session bob = Session.builder(getSession("bob")) + .setConnectionProperty(new ConnectorId("hive"), USE_LIST_DIRECTORY_CACHE, "true") + .build(); + queryRunner.execute(bob, "call 
hive.system.invalidate_directory_list_cache()"); + + Session joe = Session.builder(getSession("joe")) + .setConnectionProperty(new ConnectorId("hive"), USE_LIST_DIRECTORY_CACHE, "true") + .build(); + assertDenied(() -> queryRunner.execute(joe, "call hive.system.invalidate_directory_list_cache()"), + "Access Denied: Cannot call procedure system.invalidate_directory_list_cache"); + } + private Session getSession(String user) { return testSessionBuilder() @@ -68,4 +97,11 @@ private Session getSession(String user) .setSchema(queryRunner.getDefaultSession().getSchema().get()) .setIdentity(new Identity(user, Optional.empty())).build(); } + + private static void assertDenied(ThrowingRunnable runnable, String message) + { + assertThatThrownBy(runnable::run) + .isInstanceOf(RuntimeException.class) + .hasMessageMatching(message); + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java index 8c117033a686b..f9fed4fdf2024 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileFormats.java @@ -339,7 +339,7 @@ public void testOrc(int rowCount) assertThatFileFormat(ORC) .withColumns(TEST_COLUMNS) .withRowsCount(rowCount) - .isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, false, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); + .isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); } @Test(dataProvider = "rowCount") @@ -362,7 +362,7 @@ public void testOrcOptimizedWriter(int rowCount) .withSession(session) .withFileWriterFactory(new OrcFileWriterFactory(HDFS_ENVIRONMENT, new 
OutputStreamDataSinkFactory(), FUNCTION_AND_TYPE_MANAGER, new NodeVersion("test"), HIVE_STORAGE_TIME_ZONE, STATS, new OrcFileWriterConfig(), NO_ENCRYPTION)) .isReadableByRecordCursor(new GenericHiveRecordCursorProvider(HDFS_ENVIRONMENT)) - .isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, false, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); + .isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); } @Test(dataProvider = "rowCount") @@ -394,14 +394,15 @@ public void testOrcUseColumnNames(int rowCount) { TestingConnectorSession session = new TestingConnectorSession(getAllSessionProperties( new HiveClientConfig(), - new HiveCommonClientConfig())); + new HiveCommonClientConfig() + .setUseOrcColumnNames(true))); assertThatFileFormat(ORC) .withWriteColumns(TEST_COLUMNS) .withRowsCount(rowCount) .withReadColumns(Lists.reverse(TEST_COLUMNS)) .withSession(session) - .isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, true, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); + .isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); } @Test(dataProvider = "rowCount") @@ -419,7 +420,7 @@ public void testOrcUseColumnNamesCompatibility(int rowCount) .withRowsCount(rowCount) .withReadColumns(TEST_COLUMNS) .withSession(session) - .isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, true, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); + 
.isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); } private static List getHiveColumnNameColumns() @@ -628,7 +629,7 @@ public void testTruncateVarcharColumn() assertThatFileFormat(ORC) .withWriteColumns(ImmutableList.of(writeColumn)) .withReadColumns(ImmutableList.of(readColumn)) - .isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, false, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); + .isReadableByPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource()))); assertThatFileFormat(PARQUET) .withWriteColumns(ImmutableList.of(writeColumn)) @@ -676,7 +677,7 @@ public void testFailForLongVarcharPartitionColumn() assertThatFileFormat(ORC) .withColumns(columns) - .isFailingForPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, false, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource())), expectedErrorCode, expectedMessage); + .isFailingForPageSource(new OrcBatchPageSourceFactory(FUNCTION_AND_TYPE_MANAGER, HDFS_ENVIRONMENT, STATS, 100, new StorageOrcFileTailSource(), StripeMetadataSourceFactory.of(new StorageStripeMetadataSource())), expectedErrorCode, expectedMessage); assertThatFileFormat(PARQUET) .withColumns(columns) diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileRenamer.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileRenamer.java deleted file mode 100644 index fe86044638acd..0000000000000 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveFileRenamer.java +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Licensed under 
the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.QueryId; -import com.facebook.presto.spi.SchemaTableName; -import com.google.common.collect.ImmutableList; -import org.testng.annotations.Test; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.List; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.stream.Collectors; - -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_HIVE_METADATA_UPDATE_REQUEST; -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_PARTITION_NAME; -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_REQUEST_ID; -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_SCHEMA_NAME; -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_SCHEMA_TABLE_NAME; -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_TABLE_NAME; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertTrue; - -@Test -public class TestHiveFileRenamer -{ - private static final QueryId TEST_QUERY_ID = new QueryId("test"); - private static final int REQUEST_COUNT = 10; - private static 
final int PARTITION_COUNT = 10; - private static final int TABLE_COUNT = 10; - private static final int THREAD_COUNT = 100; - private static final int THREAD_POOL_SIZE = 10; - - @Test - public void testHiveFileRenamer() - { - HiveFileRenamer hiveFileRenamer = new HiveFileRenamer(); - List requests = ImmutableList.of(TEST_HIVE_METADATA_UPDATE_REQUEST); - List results = hiveFileRenamer.getMetadataUpdateResults(requests, TEST_QUERY_ID); - - // Assert # of requests is equal to # of results - assertEquals(requests.size(), results.size()); - - HiveMetadataUpdateHandle result = (HiveMetadataUpdateHandle) results.get(0); - - assertEquals(result.getRequestId(), TEST_REQUEST_ID); - assertEquals(result.getSchemaTableName(), TEST_SCHEMA_TABLE_NAME); - assertEquals(result.getPartitionName(), Optional.of(TEST_PARTITION_NAME)); - - // Assert file name returned is "1" - assertEquals(result.getMetadataUpdate(), Optional.of("0")); - } - - @Test - public void testFileNamesForSinglePartition() - { - HiveFileRenamer hiveFileRenamer = new HiveFileRenamer(); - List requests = createHiveMetadataUpdateRequests(TEST_SCHEMA_NAME, TEST_TABLE_NAME, TEST_PARTITION_NAME); - List fileNames = getFileNames(hiveFileRenamer, requests); - List aggregatedFileNames = new ArrayList<>(fileNames); - - assertTrue(areFileNamesIncreasingSequentially(fileNames)); - - // Send more requests - requests = createHiveMetadataUpdateRequests(TEST_SCHEMA_NAME, TEST_TABLE_NAME, TEST_PARTITION_NAME); - fileNames = getFileNames(hiveFileRenamer, requests); - aggregatedFileNames.addAll(fileNames); - - // Assert that the # of filenames is equal to # of requests - assertEquals(fileNames.size(), REQUEST_COUNT); - - // Assert that the file names are continuous increasing numbers - assertTrue(areFileNamesIncreasingSequentially(aggregatedFileNames)); - } - - @Test - public void testFileNamesForMultiplePartitions() - { - HiveFileRenamer hiveFileRenamer = new HiveFileRenamer(); - for (int partitionNumber = 1; partitionNumber <= 
PARTITION_COUNT; partitionNumber++) { - List requests = createHiveMetadataUpdateRequests(TEST_SCHEMA_NAME, TEST_TABLE_NAME, "partition_" + partitionNumber); - List fileNames = getFileNames(hiveFileRenamer, requests); - - // Assert that the # of filenames is equal to # of requests - assertEquals(fileNames.size(), REQUEST_COUNT); - - assertTrue(areFileNamesIncreasingSequentially(fileNames)); - } - } - - @Test - public void testFileNamesForMultipleTables() - { - HiveFileRenamer hiveFileRenamer = new HiveFileRenamer(); - for (int tableNumber = 1; tableNumber <= TABLE_COUNT; tableNumber++) { - List requests = createHiveMetadataUpdateRequests(TEST_SCHEMA_NAME, "table_" + tableNumber, TEST_PARTITION_NAME); - List fileNames = getFileNames(hiveFileRenamer, requests); - - // Assert that the # of filenames is equal to # of requests - assertEquals(fileNames.size(), REQUEST_COUNT); - - assertTrue(areFileNamesIncreasingSequentially(fileNames)); - } - } - - @Test - public void testFileNameResultCache() - { - HiveFileRenamer hiveFileRenamer = new HiveFileRenamer(); - List requests = createHiveMetadataUpdateRequests(TEST_SCHEMA_NAME, TEST_TABLE_NAME, TEST_PARTITION_NAME); - - // Get the file names 1st time - List fileNames = getFileNames(hiveFileRenamer, requests); - - // Get the file names 2nd time, for the same set of requests. This should be served from cache. - List fileNamesList = getFileNames(hiveFileRenamer, requests); - - // Assert that the file name is same for a given request. This is to imitate retries from workers. 
- assertEquals(fileNames, fileNamesList); - assertEquals(fileNames, getFileNames(hiveFileRenamer, requests)); - } - - @Test - public void testMultiThreadedRequests() - throws InterruptedException - { - ExecutorService service = Executors.newFixedThreadPool(THREAD_POOL_SIZE); - CountDownLatch latch = new CountDownLatch(THREAD_COUNT); - - List fileNames = new CopyOnWriteArrayList<>(); - HiveFileRenamer hiveFileRenamer = new HiveFileRenamer(); - - // Spawn THREAD_COUNT threads. And each thread will send REQUEST_COUNT requests to HiveFileRenamer - for (int i = 0; i < THREAD_COUNT; i++) { - service.execute(() -> { - List requests = createHiveMetadataUpdateRequests(TEST_SCHEMA_NAME, TEST_TABLE_NAME, TEST_PARTITION_NAME); - fileNames.addAll(getFileNames(hiveFileRenamer, requests)); - latch.countDown(); - }); - } - - // wait for all threads to finish - latch.await(); - - // Assert the # of filenames - assertEquals(fileNames.size(), THREAD_COUNT * REQUEST_COUNT); - - // Assert that the filenames are an increasing sequence - assertTrue(areFileNamesIncreasingSequentially(fileNames)); - } - - @Test - public void testCleanup() - { - HiveFileRenamer hiveFileRenamer = new HiveFileRenamer(); - List requests = ImmutableList.of(TEST_HIVE_METADATA_UPDATE_REQUEST); - List results = hiveFileRenamer.getMetadataUpdateResults(requests, TEST_QUERY_ID); - assertEquals(results.size(), 1); - HiveMetadataUpdateHandle result = (HiveMetadataUpdateHandle) results.get(0); - assertEquals(result.getMetadataUpdate(), Optional.of("0")); - - hiveFileRenamer.cleanup(TEST_QUERY_ID); - - requests = ImmutableList.of(TEST_HIVE_METADATA_UPDATE_REQUEST); - results = hiveFileRenamer.getMetadataUpdateResults(requests, TEST_QUERY_ID); - assertEquals(results.size(), 1); - result = (HiveMetadataUpdateHandle) results.get(0); - assertEquals(result.getMetadataUpdate(), Optional.of("0")); - } - - private List getFileNames(HiveFileRenamer hiveFileRenamer, List requests) - { - List results = 
hiveFileRenamer.getMetadataUpdateResults(requests, TEST_QUERY_ID); - return results.stream() - .map(result -> { - Optional fileName = ((HiveMetadataUpdateHandle) result).getMetadataUpdate(); - assertTrue(fileName.isPresent()); - return fileName.get(); - }) - .collect(Collectors.toList()); - } - - private List createHiveMetadataUpdateRequests(String schemaName, String tableName, String partitionName) - { - List requests = new ArrayList<>(); - for (int i = 1; i <= REQUEST_COUNT; i++) { - requests.add(new HiveMetadataUpdateHandle(UUID.randomUUID(), new SchemaTableName(schemaName, tableName), Optional.of(partitionName), Optional.empty())); - } - return requests; - } - - private boolean areFileNamesIncreasingSequentially(List fileNames) - { - // Sort the filenames - fileNames.sort(Comparator.comparingInt(Integer::valueOf)); - - long start = 0; - - for (String fileName : fileNames) { - if (!fileName.equals(String.valueOf(start))) { - return false; - } - start++; - } - return true; - } -} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java index 0ca908038c861..fe8e7f6a225e6 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveIntegrationSmokeTest.java @@ -23,12 +23,10 @@ import com.facebook.presto.hive.HiveClientConfig.InsertExistingPartitionsBehavior; import com.facebook.presto.metadata.InsertTableHandle; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.TableLayout; import com.facebook.presto.spi.CatalogSchemaTableName; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSession; -import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.TableMetadata; import 
com.facebook.presto.spi.plan.MarkDistinctNode; @@ -68,17 +66,15 @@ import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Function; import java.util.stream.LongStream; +import java.util.stream.Stream; import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.SystemSessionProperties.COLOCATED_JOIN; @@ -89,6 +85,8 @@ import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static com.facebook.presto.SystemSessionProperties.LOG_INVOKED_FUNCTION_NAMES_ENABLED; +import static com.facebook.presto.SystemSessionProperties.MATERIALIZED_VIEW_ALLOW_FULL_REFRESH_ENABLED; +import static com.facebook.presto.SystemSessionProperties.MATERIALIZED_VIEW_DATA_CONSISTENCY_ENABLED; import static com.facebook.presto.SystemSessionProperties.PARTIAL_MERGE_PUSHDOWN_STRATEGY; import static com.facebook.presto.SystemSessionProperties.PARTITIONING_PROVIDER_CATALOG; import static com.facebook.presto.common.predicate.Marker.Bound.EXACTLY; @@ -111,6 +109,7 @@ import static com.facebook.presto.hive.HiveQueryRunner.TPCH_SCHEMA; import static com.facebook.presto.hive.HiveQueryRunner.createBucketedSession; import static com.facebook.presto.hive.HiveQueryRunner.createMaterializeExchangesSession; +import static com.facebook.presto.hive.HiveSessionProperties.COMPRESSION_CODEC; import static com.facebook.presto.hive.HiveSessionProperties.FILE_RENAMING_ENABLED; import static com.facebook.presto.hive.HiveSessionProperties.MANIFEST_VERIFICATION_ENABLED; import static 
com.facebook.presto.hive.HiveSessionProperties.OPTIMIZED_PARTITION_UPDATE_SERIALIZATION_ENABLED; @@ -132,6 +131,7 @@ import static com.facebook.presto.hive.HiveTableProperties.PARTITIONED_BY_PROPERTY; import static com.facebook.presto.hive.HiveTableProperties.STORAGE_FORMAT_PROPERTY; import static com.facebook.presto.hive.HiveTestUtils.FUNCTION_AND_TYPE_MANAGER; +import static com.facebook.presto.hive.HiveTestUtils.getHiveTableProperty; import static com.facebook.presto.hive.HiveUtil.columnExtraInfo; import static com.facebook.presto.spi.security.SelectedRole.Type.ROLE; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.BROADCAST; @@ -142,6 +142,8 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_MATERIALIZED; import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.textLogicalPlan; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN; +import static com.facebook.presto.testing.TestingAccessControlManager.privilege; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static com.facebook.presto.tests.QueryAssertions.assertEqualsIgnoreOrder; @@ -161,6 +163,8 @@ import static io.airlift.tpch.TpchTable.PART_SUPPLIER; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Locale.ENGLISH; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; import static org.assertj.core.api.Assertions.assertThat; @@ -207,12 +211,16 @@ protected TestHiveIntegrationSmokeTest( protected QueryRunner createQueryRunner() throws Exception { - return HiveQueryRunner.createQueryRunner(ORDERS, CUSTOMER, LINE_ITEM, PART_SUPPLIER, NATION); + return 
HiveQueryRunner.createQueryRunner(ImmutableList.of(ORDERS, CUSTOMER, LINE_ITEM, PART_SUPPLIER, NATION), + ImmutableMap.of(), + "sql-standard", + ImmutableMap.of("hive.restrict-procedure-call", "false"), + Optional.empty()); } private List getPartitions(HiveTableLayoutHandle tableLayoutHandle) { - return tableLayoutHandle.getPartitions().get(); + return tableLayoutHandle.getPartitions().map(PartitionSet::getFullyLoadedPartitions).get(); } @Test @@ -351,6 +359,43 @@ public void testIOExplain() assertUpdate("DROP TABLE test_orders"); } + @Test + public void testIOExplainWithTemporalTypes() + { + computeActual("CREATE TABLE test_temporal_io " + + "WITH (partitioned_by = ARRAY['dt', 'ts']) " + + "AS SELECT orderkey, " + + "CAST('2020-03-25' AS DATE) AS dt, " + + "CAST('2020-01-15 10:30:45.000' AS TIMESTAMP) AS ts " + + "FROM orders WHERE orderkey = 1"); + + try { + MaterializedResult result = computeActual("EXPLAIN (TYPE IO, FORMAT JSON) SELECT * FROM test_temporal_io " + + "WHERE dt = DATE '2020-03-25' " + + "AND ts = TIMESTAMP '2020-01-15 10:30:45.000'"); + IOPlan ioPlan = jsonCodec(IOPlan.class).fromJson((String) getOnlyElement(result.getOnlyColumnAsSet())); + assertEquals(ioPlan.getInputTableColumnInfos().size(), 1); + TableColumnInfo tableInfo = ioPlan.getInputTableColumnInfos().iterator().next(); + + Optional dtConstraint = tableInfo.getColumnConstraints().stream() + .filter(c -> c.getColumnName().equals("dt")) + .findFirst(); + assertTrue(dtConstraint.isPresent(), "Expected date column constraint"); + assertEquals(dtConstraint.get().getDomain().getRanges().iterator().next().getLow().getValue().get(), "2020-03-25"); + + Optional tsConstraint = tableInfo.getColumnConstraints().stream() + .filter(c -> c.getColumnName().equals("ts")) + .findFirst(); + assertTrue(tsConstraint.isPresent(), "Expected timestamp column constraint"); + String tsValue = tsConstraint.get().getDomain().getRanges().iterator().next().getLow().getValue().get(); + 
assertTrue(tsValue.matches("^\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}$"), + "Timestamp should be formatted as yyyy-MM-dd HH:mm:ss.SSS but was: " + tsValue); + } + finally { + assertUpdate("DROP TABLE test_temporal_io"); + } + } + @Test public void testReadNoColumns() { @@ -1905,7 +1950,7 @@ public void testShowColumnsFromPartitions() assertQuery( getSession(), "SHOW COLUMNS FROM \"" + tableName + "$partitions\"", - "VALUES ('part1', 'bigint', '', ''), ('part2', 'varchar', '', '')"); + "VALUES ('part1', 'bigint', '', '', 19L, null, null), ('part2', 'varchar', '', '', null, null, 2147483647L)"); assertQueryFails( getSession(), @@ -2050,10 +2095,12 @@ public void testMetadataDelete() assertEquals(e.getMessage(), "This connector only supports delete where one or more partitions are deleted entirely"); } + // Test successful metadata delete on partition columns assertUpdate("DELETE FROM test_metadata_delete WHERE LINE_STATUS='O'"); assertQuery("SELECT * from test_metadata_delete", "SELECT orderkey, linenumber, linestatus FROM lineitem WHERE linestatus<>'O' and linenumber<>3"); + // Test delete on non-partition column - should fail try { getQueryRunner().execute("DELETE FROM test_metadata_delete WHERE ORDER_KEY=1"); fail("expected exception"); @@ -2064,6 +2111,17 @@ public void testMetadataDelete() assertQuery("SELECT * from test_metadata_delete", "SELECT orderkey, linenumber, linestatus FROM lineitem WHERE linestatus<>'O' and linenumber<>3"); + // Test delete with partition column AND RAND() - should fail because RAND() requires row-level filtering + try { + getQueryRunner().execute("DELETE FROM test_metadata_delete WHERE LINE_STATUS='F' AND rand() <= 0.1"); + fail("expected exception"); + } + catch (RuntimeException e) { + assertEquals(e.getMessage(), "This connector only supports delete where one or more partitions are deleted entirely"); + } + + assertQuery("SELECT * from test_metadata_delete", "SELECT orderkey, linenumber, linestatus FROM lineitem WHERE 
linestatus<>'O' and linenumber<>3"); + assertUpdate("DROP TABLE test_metadata_delete"); assertFalse(getQueryRunner().tableExists(getSession(), "test_metadata_delete")); @@ -2083,31 +2141,14 @@ private TableMetadata getTableMetadata(String catalog, String schema, String tab }); } - private Object getHiveTableProperty(String tableName, Function propertyGetter) - { - Session session = getSession(); - Metadata metadata = ((DistributedQueryRunner) getQueryRunner()).getCoordinator().getMetadata(); - - return transaction(getQueryRunner().getTransactionManager(), getQueryRunner().getAccessControl()) - .readOnly() - .execute(session, transactionSession -> { - Optional tableHandle = metadata.getMetadataResolver(transactionSession).getTableHandle(new QualifiedObjectName(catalog, TPCH_SCHEMA, tableName)); - assertTrue(tableHandle.isPresent()); - - TableLayout layout = metadata.getLayout(transactionSession, tableHandle.get(), Constraint.alwaysTrue(), Optional.empty()) - .getLayout(); - return propertyGetter.apply((HiveTableLayoutHandle) layout.getNewTableHandle().getLayout().get()); - }); - } - private List getPartitions(String tableName) { - return (List) getHiveTableProperty(tableName, (HiveTableLayoutHandle table) -> getPartitions(table)); + return (List) getHiveTableProperty(getQueryRunner(), getSession(), tableName, (HiveTableLayoutHandle table) -> getPartitions(table)); } private int getBucketCount(String tableName) { - return (int) getHiveTableProperty(tableName, (HiveTableLayoutHandle table) -> table.getBucketHandle().get().getTableBucketCount()); + return (int) getHiveTableProperty(getQueryRunner(), getSession(), tableName, (HiveTableLayoutHandle table) -> table.getBucketHandle().get().getTableBucketCount()); } @Test @@ -2120,15 +2161,15 @@ public void testShowColumnsPartitionKey() MaterializedResult actual = computeActual("SHOW COLUMNS FROM test_show_columns_partition_key"); Type unboundedVarchar = canonicalizeType(VARCHAR); - MaterializedResult expected = 
resultBuilder(getSession(), unboundedVarchar, unboundedVarchar, unboundedVarchar, unboundedVarchar) - .row("grape", canonicalizeTypeName("bigint"), "", "") - .row("orange", canonicalizeTypeName("bigint"), "", "") - .row("pear", canonicalizeTypeName("varchar(65535)"), "", "") - .row("mango", canonicalizeTypeName("integer"), "", "") - .row("lychee", canonicalizeTypeName("smallint"), "", "") - .row("kiwi", canonicalizeTypeName("tinyint"), "", "") - .row("apple", canonicalizeTypeName("varchar"), "partition key", "") - .row("pineapple", canonicalizeTypeName("varchar(65535)"), "partition key", "") + MaterializedResult expected = resultBuilder(getSession(), unboundedVarchar, unboundedVarchar, unboundedVarchar, unboundedVarchar, BIGINT, BIGINT, BIGINT) + .row("grape", canonicalizeTypeName("bigint"), "", "", 19L, null, null) + .row("orange", canonicalizeTypeName("bigint"), "", "", 19L, null, null) + .row("pear", canonicalizeTypeName("varchar(65535)"), "", "", null, null, 65535L) + .row("mango", canonicalizeTypeName("integer"), "", "", 10L, null, null) + .row("lychee", canonicalizeTypeName("smallint"), "", "", 5L, null, null) + .row("kiwi", canonicalizeTypeName("tinyint"), "", "", 3L, null, null) + .row("apple", canonicalizeTypeName("varchar"), "partition key", "", null, null, 2147483647L) + .row("pineapple", canonicalizeTypeName("varchar(65535)"), "partition key", "", null, null, 65535L) .build(); assertEquals(actual, expected); } @@ -2630,96 +2671,6 @@ public void testPreferManifestsToListFilesForUnPartitionedTable() } } - @Test - public void testFileRenamingForPartitionedTable() - { - try { - // Create partitioned table - assertUpdate( - Session.builder(getSession()) - .setCatalogSessionProperty(catalog, FILE_RENAMING_ENABLED, "true") - .setSystemProperty("scale_writers", "false") - .setSystemProperty("writer_min_size", "1MB") - .setSystemProperty("task_writer_count", "1") - .build(), - "CREATE TABLE partitioned_ordering_table (orderkey, custkey, totalprice, orderdate, 
orderpriority, clerk, shippriority, comment, orderstatus)\n" + - "WITH (partitioned_by = ARRAY['orderstatus'], preferred_ordering_columns = ARRAY['orderkey']) AS\n" + - "SELECT orderkey, custkey, totalprice, orderdate, orderpriority, clerk, shippriority, comment, orderstatus FROM tpch.tiny.orders", - (long) computeActual("SELECT count(*) FROM tpch.tiny.orders").getOnlyValue()); - - // Collect all file names - Map> partitionFileNamesMap = new HashMap<>(); - MaterializedResult partitionedResults = computeActual("SELECT DISTINCT \"$path\" FROM partitioned_ordering_table"); - for (int i = 0; i < partitionedResults.getRowCount(); i++) { - MaterializedRow row = partitionedResults.getMaterializedRows().get(i); - Path pathName = new Path((String) row.getField(0)); - String partitionName = pathName.getParent().toString(); - String fileName = pathName.getName(); - partitionFileNamesMap.putIfAbsent(partitionName, new ArrayList<>()); - partitionFileNamesMap.get(partitionName).add(Integer.valueOf(fileName)); - } - - // Assert that file names are a continuous increasing sequence for all partitions - for (String partitionName : partitionFileNamesMap.keySet()) { - List partitionedTableFileNames = partitionFileNamesMap.get(partitionName); - assertTrue(partitionedTableFileNames.size() > 0); - assertTrue(isIncreasingSequence(partitionedTableFileNames)); - } - } - finally { - assertUpdate("DROP TABLE IF EXISTS partitioned_ordering_table"); - } - } - - @Test - public void testFileRenamingForUnpartitionedTable() - { - try { - // Create un-partitioned table - assertUpdate( - Session.builder(getSession()) - .setCatalogSessionProperty(catalog, FILE_RENAMING_ENABLED, "true") - .setSystemProperty("scale_writers", "false") - .setSystemProperty("writer_min_size", "1MB") - .setSystemProperty("task_writer_count", "1") - .build(), - "CREATE TABLE unpartitioned_ordering_table AS SELECT * FROM tpch.tiny.orders", - (long) computeActual("SELECT count(*) FROM tpch.tiny.orders").getOnlyValue()); - - // 
Collect file names of the table - List fileNames = new ArrayList<>(); - MaterializedResult results = computeActual("SELECT DISTINCT \"$path\" FROM unpartitioned_ordering_table"); - for (int i = 0; i < results.getRowCount(); i++) { - MaterializedRow row = results.getMaterializedRows().get(i); - String pathName = (String) row.getField(0); - String fileName = new Path(pathName).getName(); - fileNames.add(Integer.valueOf(fileName)); - } - - assertTrue(fileNames.size() > 0); - - // Assert that file names are continuous increasing sequence - assertTrue(isIncreasingSequence(fileNames)); - } - finally { - assertUpdate("DROP TABLE IF EXISTS unpartitioned_ordering_table"); - } - } - - boolean isIncreasingSequence(List fileNames) - { - Collections.sort(fileNames); - - int i = 0; - for (int fileName : fileNames) { - if (i != fileName) { - return false; - } - i++; - } - return true; - } - @Test public void testShowCreateTable() { @@ -2781,6 +2732,7 @@ public void testShowCreateTable() actualResult = computeActual("SHOW CREATE TABLE \"test_show_create_table'2\""); assertEquals(getOnlyElement(actualResult.getOnlyColumnAsSet()), createTableSql); } + @Test public void testShowCreateSchema() { @@ -3218,7 +3170,7 @@ public void testAddColumn() assertUpdate("ALTER TABLE test_add_column ADD COLUMN b bigint COMMENT 'test comment BBB'"); assertQueryFails("ALTER TABLE test_add_column ADD COLUMN a varchar", ".* Column 'a' already exists"); assertQueryFails("ALTER TABLE test_add_column ADD COLUMN c bad_type", ".* Unknown type 'bad_type' for column 'c'"); - assertQuery("SHOW COLUMNS FROM test_add_column", "VALUES ('a', 'bigint', '', 'test comment AAA'), ('b', 'bigint', '', 'test comment BBB')"); + assertQuery("SHOW COLUMNS FROM test_add_column", "VALUES ('a', 'bigint', '', 'test comment AAA', 19, NULL, NULL), ('b', 'bigint', '', 'test comment BBB', 19, NULL, NULL)"); assertUpdate("DROP TABLE test_add_column"); } @@ -4442,6 +4394,48 @@ public void testInvalidPartitionValue() "\\QHive 
partition keys can only contain printable ASCII characters (0x20 - 0x7E). Invalid value: E2 98 83\\E"); } + @Test + public void testShowColumnMetadata() + { + String tableName = "test_show_column_table"; + + @Language("SQL") String createTable = "CREATE TABLE " + tableName + " (a bigint, b varchar, c double)"; + + Session testSession = testSessionBuilder() + .setIdentity(new Identity("test_access_owner", Optional.empty())) + .setCatalog(getSession().getCatalog().get()) + .setSchema(getSession().getSchema().get()) + .build(); + + assertUpdate(createTable); + + // verify showing columns over a table requires SELECT privileges for the table + assertAccessAllowed("SHOW COLUMNS FROM " + tableName); + assertAccessDenied(testSession, + "SHOW COLUMNS FROM " + tableName, + "Cannot show columns of table .*." + tableName + ".*", + privilege(tableName, SELECT_COLUMN)); + + @Language("SQL") String getColumnsSql = "" + + "SELECT lower(column_name) " + + "FROM information_schema.columns " + + "WHERE table_name = '" + tableName + "'"; + assertEquals(computeActual(getColumnsSql).getOnlyColumnAsSet(), ImmutableSet.of("a", "b", "c")); + + // verify with no SELECT privileges on table, querying information_schema will return empty columns + executeExclusively(() -> { + try { + getQueryRunner().getAccessControl().deny(privilege(tableName, SELECT_COLUMN)); + assertQueryReturnsEmptyResult(testSession, getColumnsSql); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + }); + + assertUpdate("DROP TABLE " + tableName); + } + @Test public void testCurrentUserInView() { @@ -5333,17 +5327,6 @@ public void testPageFileFormatSmallSplitSize() assertUpdate("DROP TABLE test_pagefile_small_split"); } - @Test - public void testPageFileCompression() - { - for (HiveCompressionCodec compression : HiveCompressionCodec.values()) { - if (!compression.isSupportedStorageFormat(PAGEFILE)) { - continue; - } - testPageFileCompression(compression.name()); - } - } - @Test public void 
testPartialAggregatePushdownORC() { @@ -5769,31 +5752,35 @@ public void testParquetSelectivePageSourceFails() assertQueryFails(parquetFilterPushdownSession, "SELECT a FROM test_parquet_filter_pushdoown WHERE b = false", "Parquet reader doesn't support filter pushdown yet"); } - private void testPageFileCompression(String compression) + @DataProvider(name = "testFormatAndCompressionCodecs") + public Object[][] compressionCodecs() { - Session testSession = Session.builder(getQueryRunner().getDefaultSession()) - .setCatalogSessionProperty(catalog, "compression_codec", compression) - .setCatalogSessionProperty(catalog, "pagefile_writer_max_stripe_size", "100B") - .setCatalogSessionProperty(catalog, "max_split_size", "1kB") - .setCatalogSessionProperty(catalog, "max_initial_split_size", "1kB") - .build(); - - assertUpdate( - testSession, - "CREATE TABLE test_pagefile_compression\n" + - "WITH (\n" + - "format = 'PAGEFILE'\n" + - ") AS\n" + - "SELECT\n" + - "*\n" + - "FROM tpch.orders", - "SELECT count(*) FROM orders"); - - assertQuery(testSession, "SELECT count(*) FROM test_pagefile_compression", "SELECT count(*) FROM orders"); - - assertQuery(testSession, "SELECT sum(custkey) FROM test_pagefile_compression", "SELECT sum(custkey) FROM orders"); + return Stream.of(PARQUET, ORC, PAGEFILE) + .flatMap(format -> Arrays.stream(HiveCompressionCodec.values()) + .map(codec -> new Object[] {codec, format})) + .toArray(Object[][]::new); + } - assertUpdate("DROP TABLE test_pagefile_compression"); + @Test(dataProvider = "testFormatAndCompressionCodecs") + public void testFormatAndCompressionCodecs(HiveCompressionCodec codec, HiveStorageFormat format) + { + String tableName = "test_" + format.name().toLowerCase(ROOT) + "_compression_codec_" + codec.name().toLowerCase(ROOT); + Session session = Session.builder(getSession()) + .setCatalogSessionProperty("hive", COMPRESSION_CODEC, codec.name()).build(); + if (codec.isSupportedStorageFormat(format == PARQUET ? 
HiveStorageFormat.PARQUET : HiveStorageFormat.ORC)) { + assertUpdate(session, + format("CREATE TABLE %s WITH (format = '%s') AS SELECT * FROM orders", + tableName, format.name()), + "SELECT count(*) FROM orders"); + assertQuery(format("SELECT count(*) FROM %s", tableName), "SELECT count(*) FROM orders"); + assertQuery(format("SELECT sum(custkey) FROM %s", tableName), "SELECT sum(custkey) FROM orders"); + assertQuerySucceeds(format("DROP TABLE %s", tableName)); + } + else { + assertQueryFails(session, format("CREATE TABLE %s WITH (format = '%s') AS SELECT * FROM orders", + tableName, format.name()), + format("%s compression is not supported with %s", codec, format)); + } } private static Consumer assertTableWriterMergeNodeIsPresent() @@ -6123,11 +6110,130 @@ public void testRefreshMaterializedView() "SELECT COUNT(*) FROM ( " + expectedInsertQuery + " )", false, true); + refreshSql = "REFRESH MATERIALIZED VIEW test_customer_view_5 WHERE regionkey = 1 OR nationkey = 24"; + expectedInsertQuery = "SELECT nation.name AS nationname, customer.custkey, customer.name AS customername, UPPER(customer.mktsegment) AS marketsegment, customer.nationkey, regionkey " + + "FROM test_nation_base_5 nation JOIN test_customer_base_5 customer ON (nation.nationkey = customer.nationkey) " + + "WHERE regionkey = 1 OR customer.nationkey = 24"; + QueryAssertions.assertQuery( + queryRunner, + session, + refreshSql, + queryRunner, + "SELECT COUNT(*) FROM ( " + expectedInsertQuery + " )", + false, true); + + // Test InPredicate + refreshSql = "REFRESH MATERIALIZED VIEW test_customer_view_5 WHERE regionkey IN (1, 2)"; + expectedInsertQuery = "SELECT nation.name AS nationname, customer.custkey, customer.name AS customername, UPPER(customer.mktsegment) AS marketsegment, customer.nationkey, regionkey " + + "FROM test_nation_base_5 nation JOIN test_customer_base_5 customer ON (nation.nationkey = customer.nationkey) " + + "WHERE regionkey IN (1, 2)"; + QueryAssertions.assertQuery( + queryRunner, + 
session, + refreshSql, + queryRunner, + "SELECT COUNT(*) FROM ( " + expectedInsertQuery + " )", + false, true); + + refreshSql = "REFRESH MATERIALIZED VIEW test_customer_view_5 WHERE nationkey IN (1, 5, 10) AND regionkey = 1"; + expectedInsertQuery = "SELECT nation.name AS nationname, customer.custkey, customer.name AS customername, UPPER(customer.mktsegment) AS marketsegment, customer.nationkey, regionkey " + + "FROM test_nation_base_5 nation JOIN test_customer_base_5 customer ON (nation.nationkey = customer.nationkey) " + + "WHERE nation.nationkey IN (1, 5, 10) AND regionkey = 1"; + QueryAssertions.assertQuery( + queryRunner, + session, + refreshSql, + queryRunner, + "SELECT COUNT(*) FROM ( " + expectedInsertQuery + " )", + false, true); + // Test invalid predicates assertQueryFails("REFRESH MATERIALIZED VIEW test_customer_view_5 WHERE nationname = 'UNITED STATES'", ".*Refresh materialized view by column nationname is not supported.*"); - assertQueryFails("REFRESH MATERIALIZED VIEW test_customer_view_5 WHERE regionkey = 1 OR nationkey = 24", ".*Only logical AND is supported in WHERE clause.*"); - assertQueryFails("REFRESH MATERIALIZED VIEW test_customer_view_5 WHERE regionkey + nationkey = 25", ".*Only columns specified on literals are supported in WHERE clause.*"); - assertQueryFails("REFRESH MATERIALIZED VIEW test_customer_view_5", ".*mismatched input ''\\. 
Expecting: '\\.', 'WHERE'.*"); + assertQueryFails("REFRESH MATERIALIZED VIEW test_customer_view_5 WHERE regionkey + nationkey = 25", ".*Only column references are supported on the left side of comparison expressions in WHERE clause.*"); + } + + @Test + public void testAutoRefreshMaterializedViewFailsWithoutFlag() + { + computeActual("CREATE TABLE test_orders_no_flag WITH (partitioned_by = ARRAY['orderstatus']) " + + "AS SELECT orderkey, totalprice, orderstatus FROM orders WHERE orderkey < 100"); + computeActual( + "CREATE MATERIALIZED VIEW test_orders_no_flag_view WITH (partitioned_by = ARRAY['orderstatus']" + retentionDays(30) + ") " + + "AS SELECT SUM(totalprice) AS total, orderstatus FROM test_orders_no_flag GROUP BY orderstatus"); + + assertQueryFails( + "REFRESH MATERIALIZED VIEW test_orders_no_flag_view", + ".*misses too many partitions or is never refreshed and may incur high cost.*"); + + computeActual("DROP MATERIALIZED VIEW test_orders_no_flag_view"); + computeActual("DROP TABLE test_orders_no_flag"); + } + + @Test + public void testMaterializedViewBaseTableDropped() + { + Session stitchingEnabledSession = Session.builder(getSession()) + .setSystemProperty(MATERIALIZED_VIEW_ALLOW_FULL_REFRESH_ENABLED, "true") + .build(); + Session stichingDisabledSession = Session.builder(getSession()) + .setSystemProperty(MATERIALIZED_VIEW_DATA_CONSISTENCY_ENABLED, "false") + .build(); + + assertUpdate("CREATE TABLE drop_table_test (id BIGINT, partkey VARCHAR) " + + "WITH (partitioned_by=ARRAY['partkey'])"); + assertUpdate("INSERT INTO drop_table_test VALUES (1, 'p1'), (2, 'p2')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_drop " + + "with (partitioned_by=ARRAY['partkey']) " + + "AS SELECT id, partkey FROM drop_table_test"); + + assertUpdate(stitchingEnabledSession, "REFRESH MATERIALIZED VIEW mv_drop", 2); + + assertUpdate("DROP TABLE drop_table_test"); + + assertQueryFails(stitchingEnabledSession, "SELECT COUNT(*) FROM mv_drop", + ".*Table .* not found.*"); + 
assertQueryFails(stichingDisabledSession, "SELECT * FROM mv_drop ORDER BY id", + ".*Table .* does not exist.*"); + + assertQueryFails("REFRESH MATERIALIZED VIEW mv_drop", + ".*Table .* not found.*"); + + computeActual("DROP MATERIALIZED VIEW mv_drop"); + } + + @Test + public void testMaterializedViewBaseTableRenamed() + { + Session stitchingEnabledSession = Session.builder(getSession()) + .setSystemProperty(MATERIALIZED_VIEW_ALLOW_FULL_REFRESH_ENABLED, "true") + .build(); + Session stichingDisabledSession = Session.builder(getSession()) + .setSystemProperty(MATERIALIZED_VIEW_DATA_CONSISTENCY_ENABLED, "false") + .build(); + + assertUpdate("CREATE TABLE rename_table_test (id BIGINT, partkey VARCHAR) " + + "WITH (partitioned_by=ARRAY['partkey'])"); + assertUpdate("INSERT INTO rename_table_test VALUES (1, 'p1'), (2, 'p2')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_rename_test " + + "with (partitioned_by=ARRAY['partkey']) " + + "AS SELECT id, partkey FROM rename_table_test"); + + assertUpdate(stitchingEnabledSession, "REFRESH MATERIALIZED VIEW mv_rename_test", 2); + + assertUpdate("ALTER TABLE rename_table_test RENAME TO rename_table_test_new"); + + assertQueryFails(stitchingEnabledSession, "SELECT COUNT(*) FROM mv_rename_test", + ".*Table .* not found.*"); + assertQueryFails(stichingDisabledSession, "SELECT * FROM mv_rename_test ORDER BY id", + ".*Table .* does not exist.*"); + + assertQueryFails("REFRESH MATERIALIZED VIEW mv_rename_test", + ".*Table .* not found.*"); + + computeActual("DROP TABLE rename_table_test_new"); + computeActual("DROP MATERIALIZED VIEW mv_rename_test"); } @Test @@ -6192,32 +6298,6 @@ public void testInvokedFunctionNamesLog() assertEqualsNoOrder(queryInfo.getScalarFunctions(), ImmutableList.of("presto.default.transform", "presto.default.abs")); } - @Test - public void testGroupByLimitPartitionKeys() - { - Session prefilter = Session.builder(getSession()) - .setSystemProperty("prefilter_for_groupby_limit", "true") - .build(); - - 
@Language("SQL") String createTable = "" + - "CREATE TABLE test_create_partitioned_table_as " + - "WITH (" + - "partitioned_by = ARRAY[ 'orderstatus' ]" + - ") " + - "AS " + - "SELECT custkey, orderkey, orderstatus FROM tpch.tiny.orders"; - - assertUpdate(prefilter, createTable, 15000); - prefilter = Session.builder(prefilter) - .setSystemProperty("prefilter_for_groupby_limit", "true") - .build(); - - MaterializedResult plan = computeActual(prefilter, "explain(type distributed) select count(custkey), orderstatus from test_create_partitioned_table_as group by orderstatus limit 1000"); - assertFalse(((String) plan.getOnlyValue()).toUpperCase().indexOf("MAP_AGG") >= 0); - plan = computeActual(prefilter, "explain(type distributed) select count(custkey), orderkey from test_create_partitioned_table_as group by orderkey limit 1000"); - assertTrue(((String) plan.getOnlyValue()).toUpperCase().indexOf("MAP_AGG") >= 0); - } - @Test public void testJoinPrefilterPartitionKeys() { @@ -6875,6 +6955,141 @@ public void testAlterColumnNotNull() assertUpdate(getSession(), dropTableStmt); } + private void testCreateTableWithHeaderAndFooter(String format) + { + String name = format.toLowerCase(ENGLISH); + String catalog = getSession().getCatalog().get(); + String schema = getSession().getSchema().get(); + @Language("SQL") String createTableSql = + format("CREATE TABLE %s.%s.%s_table_skip_header (\n" + + " \"name\" varchar\n" + + ")\n" + + "WITH (\n" + + " format = '%s',\n" + + " skip_header_line_count = 1\n" + + ")", + catalog, schema, name, format); + + assertUpdate(createTableSql); + MaterializedResult actual = + computeActual(format("SHOW CREATE TABLE %s_table_skip_header", name)); + assertEquals(actual.getOnlyValue(), createTableSql); + assertUpdate(format("DROP TABLE %s_table_skip_header", name)); + + @Language("SQL") String createFooter = + format("CREATE TABLE %s.%s.%s_table_skip_footer (\n" + + " \"name\" varchar\n" + + ")\n" + + "WITH (\n" + + " format = '%s',\n" + + " 
skip_footer_line_count = 1\n" + + ")", + catalog, schema, name, format); + + assertThatThrownBy(() -> assertUpdate(createFooter)) + .hasMessageContaining("Cannot create non external table with skip.footer.line.count property"); + + @Language("SQL") String createHeaderFooter = + format("CREATE TABLE %s.%s.%s_table_skip_header_footer (\n" + + " \"name\" varchar\n" + + ")\n" + + "WITH (\n" + + " format = '%s',\n" + + " skip_footer_line_count = 1,\n" + + " skip_header_line_count = 1\n" + + ")", + catalog, schema, name, format); + + assertThatThrownBy(() -> assertUpdate(createHeaderFooter)) + .hasMessageContaining("Cannot create non external table with skip.footer.line.count property"); + + createTableSql = + format("CREATE TABLE %s.%s.%s_table_skip_header " + + "WITH (\n" + + " format = '%s',\n" + + " skip_header_line_count = 1\n" + + ") AS SELECT CAST(1 AS VARCHAR) AS col_name1, CAST(2 AS VARCHAR) AS col_name2", + catalog, schema, name, format); + + assertUpdate(createTableSql, 1); + assertUpdate( + format("INSERT INTO %s.%s.%s_table_skip_header VALUES('3', '4')", + catalog, schema, name), + 1); + + MaterializedResult materializedRows = + computeActual(format("SELECT * FROM %s_table_skip_header", name)); + + assertEqualsIgnoreOrder( + materializedRows, + resultBuilder(getSession(), VARCHAR, VARCHAR) + .row("1", "2") + .row("3", "4") + .build() + .getMaterializedRows()); + + assertUpdate(format("DROP TABLE %s_table_skip_header", name)); + } + + @Test + public void testCreateTableWithHeaderAndFooterForCsv() + { + testCreateTableWithHeaderAndFooter("CSV"); + } + @Test + public void testInsertTableWithHeaderAndFooterForCsv() + { + String catalog = getSession().getCatalog().get(); + String schema = getSession().getSchema().get(); + + @Language("SQL") String createHeader = + format("CREATE TABLE %s.%s.csv_table_skip_header (\n" + + " name VARCHAR\n" + + ")\n" + + "WITH (\n" + + " format = 'CSV',\n" + + " skip_header_line_count = 2\n" + + ")", + catalog, schema); + + 
assertUpdate(createHeader); + + assertThatThrownBy(() -> + assertUpdate(format( + "INSERT INTO %s.%s.csv_table_skip_header VALUES ('name')", + catalog, schema))) + .hasMessageMatching("INSERT into .* skip.header.line.count property greater than 1 is not supported"); + + assertUpdate("DROP TABLE csv_table_skip_header"); + + createHeader = + format("CREATE TABLE %s.%s.csv_table_skip_header (\n" + + " name VARCHAR\n" + + ")\n" + + "WITH (\n" + + " format = 'CSV',\n" + + " skip_header_line_count = 1\n" + + ")", + catalog, schema); + + assertUpdate(createHeader); + + assertUpdate(format( + "INSERT INTO %s.%s.csv_table_skip_header VALUES ('name')", catalog, schema), 1); + + MaterializedResult materializedRows = + computeActual(format("SELECT * FROM %s.%s.csv_table_skip_header", catalog, schema)); + + assertEqualsIgnoreOrder( + materializedRows, + resultBuilder(getSession(), VARCHAR) + .row("name") + .build() + .getMaterializedRows()); + + assertUpdate("DROP TABLE csv_table_skip_header"); + } + protected String retentionDays(int days) { return ""; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java index c1d9c11660993..3fec92f66e088 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveLogicalPlanner.java @@ -84,6 +84,9 @@ import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_METADATA_QUERIES_IGNORE_STATS; import static com.facebook.presto.SystemSessionProperties.PUSHDOWN_DEREFERENCE_ENABLED; import static com.facebook.presto.SystemSessionProperties.PUSHDOWN_SUBFIELDS_ENABLED; +import static com.facebook.presto.SystemSessionProperties.PUSHDOWN_SUBFIELDS_FOR_CARDINALITY; +import static com.facebook.presto.SystemSessionProperties.PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS; +import static 
com.facebook.presto.SystemSessionProperties.UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING; import static com.facebook.presto.common.function.OperatorType.EQUAL; import static com.facebook.presto.common.predicate.Domain.create; import static com.facebook.presto.common.predicate.Domain.multipleValues; @@ -129,6 +132,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictTableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; @@ -166,6 +170,13 @@ protected QueryRunner createQueryRunner() Optional.empty()); } + @Override + protected QueryRunner createExpectedQueryRunner() + throws Exception + { + return getQueryRunner(); + } + @Test public void testMetadataQueryOptimizationWithLimit() { @@ -1366,6 +1377,25 @@ public void testPushdownSubfields() assertPushdownSubfields("SELECT x.a FROM test_pushdown_struct_subfields WHERE x.a > 10 AND x.b LIKE 'abc%'", "test_pushdown_struct_subfields", ImmutableMap.of("x", toSubfields("x.a", "x.b"))); + assertQuery("SELECT struct.b FROM (SELECT CUSTOM_STRUCT_WITH_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)"); + assertQuery("SELECT struct.b FROM (SELECT CUSTOM_STRUCT_WITHOUT_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)"); + assertQuery("SELECT struct FROM (SELECT CUSTOM_STRUCT_WITH_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)"); + assertQuery("SELECT struct FROM (SELECT CUSTOM_STRUCT_WITHOUT_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)"); + + assertPushdownSubfields("SELECT struct.b FROM (SELECT x 
AS struct FROM test_pushdown_struct_subfields)", "test_pushdown_struct_subfields", + ImmutableMap.of("x", toSubfields("x.b"))); + assertPushdownSubfields("SELECT struct.b FROM (SELECT CUSTOM_STRUCT_WITH_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)", "test_pushdown_struct_subfields", + ImmutableMap.of("x", toSubfields("x.b"))); + assertPushdownSubfields("SELECT struct.b FROM (SELECT CUSTOM_STRUCT_WITHOUT_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)", "test_pushdown_struct_subfields", + ImmutableMap.of()); + + assertPushdownSubfields("SELECT struct FROM (SELECT x AS struct FROM test_pushdown_struct_subfields)", "test_pushdown_struct_subfields", + ImmutableMap.of()); + assertPushdownSubfields("SELECT struct FROM (SELECT CUSTOM_STRUCT_WITH_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)", "test_pushdown_struct_subfields", + ImmutableMap.of()); + assertPushdownSubfields("SELECT struct FROM (SELECT CUSTOM_STRUCT_WITHOUT_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)", "test_pushdown_struct_subfields", + ImmutableMap.of()); + // Join assertPlan("SELECT l.orderkey, x.a, mod(x.d.d1, 2) FROM lineitem l, test_pushdown_struct_subfields a WHERE l.linenumber = a.id", anyTree( @@ -1442,6 +1472,222 @@ public void testPushdownSubfields() assertUpdate("DROP TABLE test_pushdown_struct_subfields"); } + @Test + public void testPushdownNegativeSubfiels() + { + assertUpdate("CREATE TABLE test_pushdown_subfields_negative_key(id bigint, arr array(bigint), mp map(integer, varchar))"); + assertPushdownSubfields("select element_at(arr, -1) from test_pushdown_subfields_negative_key", "test_pushdown_subfields_negative_key", ImmutableMap.of("arr", toSubfields())); + assertPushdownSubfields("select element_at(mp, -1) from test_pushdown_subfields_negative_key", "test_pushdown_subfields_negative_key", ImmutableMap.of("mp", toSubfields("mp[-1]"))); + assertPushdownSubfields("select element_at(arr, -1), element_at(arr, 2) from 
test_pushdown_subfields_negative_key", "test_pushdown_subfields_negative_key", ImmutableMap.of("arr", toSubfields())); + assertPushdownSubfields("select element_at(mp, -1), element_at(mp, 2) from test_pushdown_subfields_negative_key", "test_pushdown_subfields_negative_key", ImmutableMap.of("mp", toSubfields("mp[-1]", "mp[2]"))); + + assertUpdate("DROP TABLE test_pushdown_subfields_negative_key"); + } + + @Test + public void testPushdownSubfieldsForMapSubset() + { + Session mapSubset = Session.builder(getSession()).setSystemProperty(PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS, "true").build(); + assertUpdate("CREATE TABLE test_pushdown_map_subfields(id integer, x map(integer, double))"); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array[1, 2, 3]) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[1]", "x[2]", "x[3]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array[-1, -2, 3]) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[-1]", "x[-2]", "x[3]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array[1, 2, null]) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array[1, 2, 3, id]) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array[id]) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertUpdate("DROP TABLE test_pushdown_map_subfields"); + + assertUpdate("CREATE TABLE test_pushdown_map_subfields(id integer, x array(map(integer, double)))"); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_subset(mp, array[1, 2, 3])) FROM test_pushdown_map_subfields t", 
"test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[*][1]", "x[*][2]", "x[*][3]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_subset(mp, array[1, 2, null])) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_subset(mp, array[1, 2, id])) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertUpdate("DROP TABLE test_pushdown_map_subfields"); + + assertUpdate("CREATE TABLE test_pushdown_map_subfields(id varchar, x map(varchar, double))"); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array['ab', 'c', 'd']) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[\"ab\"]", "x[\"c\"]", "x[\"d\"]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array['ab', 'c', null]) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array['ab', 'c', 'd', id]) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_subset(x, array[id]) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertUpdate("DROP TABLE test_pushdown_map_subfields"); + } + + @Test + public void testPushdownSubfieldsForMapFilter() + { + Session mapSubset = Session.builder(getSession()).setSystemProperty(PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS, "true").build(); + assertUpdate("CREATE TABLE test_pushdown_map_subfields(id integer, x map(integer, double))"); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k in (1, 2, 3)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", 
+ ImmutableMap.of("x", toSubfields("x[1]", "x[2]", "x[3]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> v in (1, 2, 3)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array[1, 2, 3], k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[1]", "x[2]", "x[3]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array[1, 2, 3], v)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k = 1) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[1]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> v = 1) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> 1 = k) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[1]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> 1 = v) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k in (-1, -2, 3)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[-1]", "x[-2]", "x[3]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array[-1, -2, 3], k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[-1]", "x[-2]", "x[3]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, 
map_filter(x, (k, v) -> k=-2) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[-2]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> -2=k) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[-2]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k in (1, 2, null)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k is null) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array[1, 2, null], k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k in (1, 2, 3, id)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array[1, 2, 3, id], k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k = id) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> id = k) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k in (id)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, 
map_filter(x, (k, v) -> contains(array[id], k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertUpdate("DROP TABLE test_pushdown_map_subfields"); + + assertUpdate("CREATE TABLE test_pushdown_map_subfields(id integer, x array(map(integer, double)))"); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> k in (1, 2, 3))) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[*][1]", "x[*][2]", "x[*][3]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> v in (1, 2, 3))) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> contains(array[1, 2, 3], k))) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[*][1]", "x[*][2]", "x[*][3]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> contains(array[1, 2, 3], v))) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> k=2)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[*][2]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> v=2)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> 2=k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[*][2]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp 
-> map_filter(mp, (k, v) -> 2=v)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> k in (1, 2, null))) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> contains(array[1, 2, null], k))) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> k in (1, 2, id))) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> k=id)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> id=k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, transform(x, mp -> map_filter(mp, (k, v) -> contains(array[1, 2, id], k))) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertUpdate("DROP TABLE test_pushdown_map_subfields"); + + assertUpdate("CREATE TABLE test_pushdown_map_subfields(id varchar, x map(varchar, varchar))"); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k in ('ab', 'c', 'd')) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[\"ab\"]", "x[\"c\"]", "x[\"d\"]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array['ab', 'c', 'd'], k)) FROM 
test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[\"ab\"]", "x[\"c\"]", "x[\"d\"]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k='d') FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[\"d\"]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> 'd'=k) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields("x[\"d\"]"))); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> v in ('ab', 'c', 'd')) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array['ab', 'c', 'd'], v)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> v='d') FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> 'd'=v) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k in ('ab', 'c', null)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array['ab', 'c', null], k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k in ('ab', 'c', 'd', id)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + 
assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k = id) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> id = k) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array['ab', 'c', 'd', id], k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> k in (id)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertPushdownSubfields(mapSubset, "SELECT t.id, map_filter(x, (k, v) -> contains(array[id], k)) FROM test_pushdown_map_subfields t", "test_pushdown_map_subfields", + ImmutableMap.of("x", toSubfields())); + assertUpdate("DROP TABLE test_pushdown_map_subfields"); + } + + @Test + public void testPushdownSubfieldsForCardinality() + { + Session cardinalityPushdown = Session.builder(getSession()) + .setSystemProperty(PUSHDOWN_SUBFIELDS_FOR_CARDINALITY, "true") + .build(); + + // Test simple cardinality pushdown for MAP + assertUpdate("CREATE TABLE test_pushdown_cardinality_map(id integer, x map(integer, double))"); + assertPushdownSubfields(cardinalityPushdown, "SELECT t.id, cardinality(x) FROM test_pushdown_cardinality_map t", "test_pushdown_cardinality_map", + ImmutableMap.of("x", toSubfields("x[$]"))); + assertUpdate("DROP TABLE test_pushdown_cardinality_map"); + + // Test cardinality pushdown for ARRAY + assertUpdate("CREATE TABLE test_pushdown_cardinality_array(id integer, arr array(bigint))"); + assertPushdownSubfields(cardinalityPushdown, "SELECT t.id, cardinality(arr) FROM test_pushdown_cardinality_array t", "test_pushdown_cardinality_array", + ImmutableMap.of("arr", 
toSubfields("arr[$]"))); + assertUpdate("DROP TABLE test_pushdown_cardinality_array"); + + // Test cardinality in WHERE clause + assertUpdate("CREATE TABLE test_pushdown_cardinality_where(id integer, features map(varchar, double))"); + assertPushdownSubfields(cardinalityPushdown, "SELECT t.id FROM test_pushdown_cardinality_where t WHERE cardinality(features) > 10", "test_pushdown_cardinality_where", + ImmutableMap.of("features", toSubfields("features[$]"))); + assertUpdate("DROP TABLE test_pushdown_cardinality_where"); + + // Test cardinality in aggregation + assertUpdate("CREATE TABLE test_pushdown_cardinality_agg(id integer, data map(integer, varchar))"); + assertPushdownSubfields(cardinalityPushdown, "SELECT AVG(cardinality(data)) FROM test_pushdown_cardinality_agg", "test_pushdown_cardinality_agg", + ImmutableMap.of("data", toSubfields("data[$]"))); + assertUpdate("DROP TABLE test_pushdown_cardinality_agg"); + + // Test multiple cardinalities + assertUpdate("CREATE TABLE test_pushdown_cardinality_multi(id integer, map1 map(integer, double), map2 map(varchar, integer))"); + assertPushdownSubfields(cardinalityPushdown, "SELECT cardinality(map1), cardinality(map2) FROM test_pushdown_cardinality_multi", "test_pushdown_cardinality_multi", + ImmutableMap.of("map1", toSubfields("map1[$]"), "map2", toSubfields("map2[$]"))); + assertUpdate("DROP TABLE test_pushdown_cardinality_multi"); + + // Test cardinality with complex expression + assertUpdate("CREATE TABLE test_pushdown_cardinality_expr(id integer, tags map(varchar, varchar))"); + assertPushdownSubfields(cardinalityPushdown, "SELECT cardinality(tags) * 2 FROM test_pushdown_cardinality_expr", "test_pushdown_cardinality_expr", + ImmutableMap.of("tags", toSubfields("tags[$]"))); + assertUpdate("DROP TABLE test_pushdown_cardinality_expr"); + + // Test cardinality on ARRAY of maps + assertUpdate("CREATE TABLE test_pushdown_cardinality_nested(id integer, arr_of_maps array(map(integer, varchar)))"); + 
assertPushdownSubfields(cardinalityPushdown, "SELECT transform(arr_of_maps, m -> cardinality(m)) FROM test_pushdown_cardinality_nested", "test_pushdown_cardinality_nested", + ImmutableMap.of("arr_of_maps", toSubfields("arr_of_maps[*][$]"))); + assertUpdate("DROP TABLE test_pushdown_cardinality_nested"); + } + @Test public void testPushdownSubfieldsAssorted() { @@ -1960,6 +2206,336 @@ public void testPartialAggregatePushdown() } } + @Test + public void testRowId() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING, "true") + .build(); + String sql; + sql = "SELECT\n" + + " unique_id AS unique_id,\n" + + " ARBITRARY(name) AS customer_name,\n" + + " ARRAY_AGG(orderkey) AS orders_info\n" + + "FROM (\n" + + " SELECT\n" + + " customer.name,\n" + + " orders.orderkey,\n" + + " customer.\"$row_id\" AS unique_id\n" + + " FROM customer\n" + + " LEFT JOIN orders\n" + + " ON customer.custkey = orders.custkey\n" + + " AND orders.orderstatus IN ('O', 'F')\n" + + " AND orders.orderdate BETWEEN DATE '1995-01-01' AND DATE '1995-12-31'\n" + + " WHERE\n" + + " customer.nationkey IN (1, 2, 3, 4, 5)\n" + + ")\n" + + "GROUP BY\n" + + " unique_id"; + + assertPlan(session, + sql, + anyTree( + aggregation( + ImmutableMap.of(), + join( + anyTree(tableScan("customer")), + anyTree(tableScan("orders")))))); + + sql = "select \"$row_id\", count(*) from orders group by 1"; + assertPlan(sql, anyTree(aggregation(ImmutableMap.of(), tableScan("orders")))); + sql = "SELECT\n" + + " customer.\"$row_id\" AS unique_id,\n" + + " COUNT(orders.orderkey) AS order_count,\n" + + " ARRAY_AGG(orders.orderkey) AS order_keys\n" + + "FROM customer\n" + + "LEFT JOIN orders\n" + + " ON customer.custkey = orders.custkey\n" + + "WHERE customer.acctbal > 1000\n" + + "GROUP BY customer.\"$row_id\""; + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + join( + anyTree(tableScan("customer")), + 
anyTree(tableScan("orders")))))); + + sql = "SELECT\n" + + " unique_id,\n" + + " MAX(orderdate) AS latest_order_date\n" + + "FROM (\n" + + " SELECT\n" + + " customer.\"$row_id\" AS unique_id,\n" + + " orders.orderdate\n" + + " FROM customer\n" + + " JOIN orders\n" + + " ON customer.custkey = orders.custkey\n" + + " WHERE orders.orderstatus = 'O'\n" + + ") t\n" + + "GROUP BY unique_id"; + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + join( + anyTree(tableScan("customer")), + anyTree(tableScan("orders")))))); + + sql = "SELECT\n" + + " DISTINCT customer.\"$row_id\" AS unique_id,\n" + + " customer.name\n" + + "FROM customer\n" + + "WHERE customer.nationkey = 1"; + + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + project(filter(tableScan("customer")))))); + + sql = "SELECT\n" + + " c.name,\n" + + " o.orderkey,\n" + + " c.\"$row_id\" AS customer_row_id,\n" + + " o.\"$row_id\" AS order_row_id\n" + + "FROM customer c\n" + + "JOIN orders o\n" + + " ON c.\"$row_id\" = o.\"$row_id\"\n" + + "WHERE o.totalprice > 10000"; + assertPlan(sql, anyTree( + join( + exchange(anyTree(tableScan("customer"))), + exchange(anyTree(tableScan("orders")))))); + + sql = "SELECT\n" + + " \"$row_id\" AS unique_id,\n" + + " orderstatus,\n" + + " COUNT(*) AS cnt\n" + + "FROM orders\n" + + "GROUP BY \"$row_id\", orderstatus"; + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + project(tableScan("orders"))))); + + sql = "SELECT\n" + + " o1.orderkey,\n" + + " o2.totalprice\n" + + "FROM orders o1\n" + + "JOIN orders o2\n" + + " ON o1.\"$row_id\" = o2.\"$row_id\"\n" + + "WHERE o1.orderstatus = 'O'"; + assertPlan(sql, anyTree( + join( + exchange(anyTree(tableScan("orders"))), + exchange(anyTree(tableScan("orders")))))); + + sql = "SELECT\n" + + " orderkey,\n" + + " totalprice\n" + + "FROM orders o1\n" + + "WHERE EXISTS (\n" + + " SELECT 1\n" + + " FROM orders o2\n" + + " WHERE o1.\"$row_id\" = o2.\"$row_id\"\n" + + " AND o2.orderstatus = 'F'\n" + + ")"; + 
assertPlan(sql, anyTree( + join( + exchange(anyTree(tableScan("orders"))), + exchange(anyTree(aggregation(ImmutableMap.of(), project(filter(tableScan("orders"))))))))); + + sql = "SELECT\n" + + " custkey,\n" + + " name\n" + + "FROM customer\n" + + "WHERE \"$row_id\" IN (\n" + + " SELECT \"$row_id\"\n" + + " FROM customer\n" + + " WHERE acctbal > 5000\n" + + ")"; + assertPlan(sql, anyTree( + semiJoin( + anyTree(tableScan("customer")), + anyTree(tableScan("customer"))))); + + sql = "SELECT \"$row_id\", orderkey FROM orders WHERE orderstatus = 'O'\n" + + "UNION\n" + + "SELECT \"$row_id\", orderkey FROM orders WHERE orderstatus = 'F'"; + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + project(filter(tableScan("orders")))))); + + sql = "SELECT\n" + + " c.\"$row_id\",\n" + + " COUNT(o.orderkey) AS order_count,\n" + + " SUM(o.totalprice) AS total_spent,\n" + + " MAX(o.orderdate) AS latest_order,\n" + + " MIN(o.orderdate) AS first_order\n" + + "FROM customer c\n" + + "LEFT JOIN orders o ON c.custkey = o.custkey\n" + + "GROUP BY c.\"$row_id\""; + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + join( + anyTree(tableScan("customer")), + anyTree(tableScan("orders")))))); + + sql = "SELECT\n" + + " \"$row_id\",\n" + + " COUNT(*) AS cnt\n" + + "FROM orders\n" + + "GROUP BY \"$row_id\"\n" + + "HAVING COUNT(*) = 1"; + assertPlan(sql, anyTree( + project(filter( + aggregation(ImmutableMap.of(), + tableScan("orders")))))); + + sql = "SELECT\n" + + " c.custkey,\n" + + " c.name,\n" + + " (\n" + + " SELECT COUNT(*)\n" + + " FROM orders o\n" + + " WHERE o.custkey = c.custkey\n" + + " AND o.\"$row_id\" IS NOT NULL\n" + + " ) AS order_count\n" + + "FROM customer c"; + assertPlan(sql, anyTree( + project( + join( + anyTree(tableScan("customer")), + anyTree(tableScan("orders")))))); + + sql = "SELECT\n" + + " outer_query.customer_id,\n" + + " outer_query.order_count\n" + + "FROM (\n" + + " SELECT\n" + + " c.\"$row_id\" AS customer_id,\n" + + " COUNT(DISTINCT 
o.\"$row_id\") AS order_count\n" + + " FROM customer c\n" + + " LEFT JOIN orders o ON c.custkey = o.custkey\n" + + " WHERE c.nationkey IN (1, 2, 3)\n" + + " GROUP BY c.\"$row_id\"\n" + + ") outer_query\n" + + "WHERE outer_query.order_count > 2"; + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + project( + join( + anyTree(tableScan("customer")), + anyTree(tableScan("orders"))))))); + + sql = "SELECT\n" + + " c.custkey,\n" + + " c.name\n" + + "FROM customer c\n" + + "WHERE NOT EXISTS (\n" + + " SELECT 1\n" + + " FROM orders o\n" + + " WHERE c.\"$row_id\" = o.\"$row_id\"\n" + + ")"; + assertPlan(sql, anyTree( + join( + exchange(anyTree(tableScan("customer"))), + exchange(anyTree(aggregation(ImmutableMap.of(), tableScan("orders"))))))); + + sql = "SELECT\n" + + " c.name,\n" + + " o.orderkey,\n" + + " l.linenumber\n" + + "FROM customer c\n" + + "JOIN orders o ON c.custkey = o.custkey\n" + + "JOIN lineitem l ON o.orderkey = l.orderkey\n" + + "WHERE c.\"$row_id\" IS NOT NULL\n" + + " AND o.\"$row_id\" IS NOT NULL\n" + + " AND l.\"$row_id\" IS NOT NULL"; + assertPlan(sql, anyTree( + join( + anyTree( + join( + anyTree(tableScan("customer")), + anyTree(tableScan("orders")))), + anyTree(tableScan("lineitem"))))); + + sql = "SELECT\n" + + " c.\"$row_id\" AS customer_row_id,\n" + + " o.\"$row_id\" AS order_row_id,\n" + + " c.name,\n" + + " o.orderkey\n" + + "FROM customer c\n" + + "CROSS JOIN orders o\n" + + "WHERE c.nationkey = 1 AND o.orderstatus = 'O'"; + assertPlan(sql, anyTree( + join( + anyTree(tableScan("customer")), + anyTree(tableScan("orders"))))); + + sql = "SELECT\n" + + " orderstatus,\n" + + " COUNT(\"$row_id\") AS row_count,\n" + + " MIN(\"$row_id\") AS min_row_id\n" + + "FROM orders\n" + + "GROUP BY orderstatus"; + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + exchange(anyTree(tableScan("orders")))))); + + sql = "SELECT\n" + + " xxhash64(\"$row_id\") AS bucket,\n" + + " COUNT(*) AS cnt\n" + + "FROM orders\n" + + "GROUP BY 
xxhash64(\"$row_id\")"; + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + exchange(anyTree(tableScan("orders")))))); + + sql = "SELECT \"$row_id\", 'customer' AS source FROM customer\n" + + "UNION ALL\n" + + "SELECT \"$row_id\", 'orders' AS source FROM orders"; + assertPlan(sql, anyTree( + exchange( + anyTree(tableScan("customer")), + anyTree(tableScan("orders"))))); + + sql = "SELECT\n" + + " o.\"$row_id\" AS order_row_id,\n" + + " o.orderkey,\n" + + " l.linenumber\n" + + "FROM orders o\n" + + "JOIN lineitem l ON o.orderkey = l.orderkey\n" + + "WHERE o.orderstatus = 'O'"; + assertPlan(sql, anyTree( + join( + exchange(anyTree(tableScan("lineitem"))), + exchange(anyTree(tableScan("orders")))))); + + sql = "SELECT\n" + + " orderkey,\n" + + " totalprice\n" + + "FROM orders o1\n" + + "WHERE \"$row_id\" > (\n" + + " SELECT MIN(\"$row_id\")\n" + + " FROM orders o2\n" + + " WHERE o2.orderstatus = 'F'\n" + + ")"; + assertPlan(sql, anyTree( + join( + tableScan("orders"), + exchange(anyTree(tableScan("orders")))))); + + sql = "SELECT\n" + + " CASE \n" + + " WHEN nationkey = 1 THEN \"$row_id\"\n" + + " ELSE cast('default' as varbinary)\n" + + " END AS conditional_row_id,\n" + + " COUNT(*) AS cnt\n" + + "FROM customer\n" + + "GROUP BY CASE \n" + + " WHEN nationkey = 1 THEN \"$row_id\"\n" + + " ELSE cast('default' as varbinary)\n" + + " END"; + assertPlan(sql, anyTree( + aggregation(ImmutableMap.of(), + exchange(anyTree(tableScan("customer")))))); + } + private static Set toSubfields(String... 
subfieldPaths) { return Arrays.stream(subfieldPaths) @@ -1977,6 +2553,11 @@ private void assertPushdownSubfields(Session session, String query, String table assertPlan(session, query, anyTree(tableScan(tableName, requiredSubfields))); } + private static PlanMatchPattern tableScan(String expectedTableName) + { + return PlanMatchPattern.tableScan(expectedTableName); + } + private static PlanMatchPattern tableScan(String expectedTableName, Map> expectedRequiredSubfields) { return PlanMatchPattern.tableScan(expectedTableName).with(new HiveTableScanMatcher(expectedRequiredSubfields)); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMaterializedViewLogicalPlanner.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMaterializedViewLogicalPlanner.java index 6b53172eff527..6767fc4d96b9c 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMaterializedViewLogicalPlanner.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMaterializedViewLogicalPlanner.java @@ -37,11 +37,11 @@ import java.util.Collections; import java.util.List; import java.util.Optional; -import java.util.function.Consumer; import static com.facebook.presto.SystemSessionProperties.CONSIDER_QUERY_FILTERS_FOR_MATERIALIZED_VIEW_PARTITIONS; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; +import static com.facebook.presto.SystemSessionProperties.MATERIALIZED_VIEW_DATA_CONSISTENCY_ENABLED; import static com.facebook.presto.SystemSessionProperties.PREFER_PARTIAL_AGGREGATION; import static com.facebook.presto.SystemSessionProperties.QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED; import static com.facebook.presto.SystemSessionProperties.SIMPLIFY_PLAN_WITH_EMPTY_INPUT; @@ -72,6 +72,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.unnest; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.INSERT_TABLE; @@ -87,6 +88,7 @@ import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertTrue; @@ -100,10 +102,237 @@ protected QueryRunner createQueryRunner() { return HiveQueryRunner.createQueryRunner( ImmutableList.of(ORDERS, LINE_ITEM, CUSTOMER, NATION, SUPPLIER), - ImmutableMap.of(), + ImmutableMap.of("experimental.allow-legacy-materialized-views-toggle", "true"), Optional.empty()); } + @Test + public void testMaterializedViewPartitionFilteringThroughLogicalView() + { + QueryRunner queryRunner = getQueryRunner(); + String table = "orders_partitioned_lv_test"; + String materializedView = "orders_mv_lv_test"; + String logicalView = "orders_lv_test"; + + try { + // Create a table partitioned by 'ds' (date string) + queryRunner.execute(format("CREATE TABLE %s WITH (partitioned_by = ARRAY['ds']) AS " + + "SELECT orderkey, totalprice, '2025-11-10' AS ds FROM orders WHERE orderkey < 1000 " + + "UNION ALL " + + "SELECT orderkey, totalprice, '2025-11-11' AS ds FROM orders WHERE orderkey >= 1000 AND orderkey < 2000 " + + "UNION ALL " + + "SELECT orderkey, totalprice, '2025-11-12' AS ds FROM orders WHERE orderkey >= 2000 AND orderkey < 3000", table)); + + // Create a materialized view partitioned by 'ds' + queryRunner.execute(format("CREATE MATERIALIZED VIEW %s 
WITH (partitioned_by = ARRAY['ds']) AS " + + "SELECT max(totalprice) as max_price, orderkey, ds FROM %s GROUP BY orderkey, ds", materializedView, table)); + + assertTrue(getQueryRunner().tableExists(getSession(), materializedView)); + + // Only refresh partition for '2025-11-10', leaving '2025-11-11' and '2025-11-12' missing + assertUpdate(format("REFRESH MATERIALIZED VIEW %s WHERE ds='2025-11-10'", materializedView), 255); + + // Create a logical view on top of the materialized view + queryRunner.execute(format("CREATE VIEW %s AS SELECT * FROM %s", logicalView, materializedView)); + + setReferencedMaterializedViews((DistributedQueryRunner) queryRunner, table, ImmutableList.of(materializedView)); + + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(CONSIDER_QUERY_FILTERS_FOR_MATERIALIZED_VIEW_PARTITIONS, "true") + .setCatalogSessionProperty(HIVE_CATALOG, MATERIALIZED_VIEW_MISSING_PARTITIONS_THRESHOLD, Integer.toString(1)) + .build(); + + // Query the logical view with a predicate + // The predicate should be pushed down to the materialized view + // Since only ds='2025-11-10' is refreshed and that's what we're querying, + // the materialized view should be used (not fall back to base table) + String logicalViewQuery = format("SELECT max_price, orderkey FROM %s WHERE ds='2025-11-10' ORDER BY orderkey", logicalView); + String directMvQuery = format("SELECT max_price, orderkey FROM %s WHERE ds='2025-11-10' ORDER BY orderkey", materializedView); + String baseTableQuery = format("SELECT max(totalprice) as max_price, orderkey FROM %s " + + "WHERE ds='2025-11-10' " + + "GROUP BY orderkey ORDER BY orderkey", table); + + MaterializedResult baseQueryResult = computeActual(session, baseTableQuery); + MaterializedResult logicalViewResult = computeActual(session, logicalViewQuery); + MaterializedResult directMvResult = computeActual(session, directMvQuery); + + // All three queries should return the same results + 
assertEquals(baseQueryResult, logicalViewResult); + assertEquals(baseQueryResult, directMvResult); + + // The plan for the logical view query should use the materialized view + // (not fall back to base table) because we're only querying the refreshed partition + assertPlan(session, logicalViewQuery, anyTree( + constrainedTableScan( + materializedView, + ImmutableMap.of("ds", singleValue(createVarcharType(10), utf8Slice("2025-11-10"))), + ImmutableMap.of()))); + + // Test query for a missing partition through logical view + // This should fall back to base table because ds='2025-11-11' is not refreshed + String logicalViewQueryMissing = format("SELECT max_price, orderkey FROM %s WHERE ds='2025-11-11' ORDER BY orderkey", logicalView); + String baseTableQueryMissing = format("SELECT max(totalprice) as max_price, orderkey FROM %s " + + "WHERE ds='2025-11-11' " + + "GROUP BY orderkey ORDER BY orderkey", table); + + MaterializedResult baseQueryResultMissing = computeActual(session, baseTableQueryMissing); + MaterializedResult logicalViewResultMissing = computeActual(session, logicalViewQueryMissing); + + assertEquals(baseQueryResultMissing, logicalViewResultMissing); + + // Should fall back to base table for missing partition + assertPlan(session, logicalViewQueryMissing, anyTree( + constrainedTableScan(table, ImmutableMap.of(), ImmutableMap.of()))); + } + finally { + queryRunner.execute("DROP VIEW IF EXISTS " + logicalView); + queryRunner.execute("DROP MATERIALIZED VIEW IF EXISTS " + materializedView); + queryRunner.execute("DROP TABLE IF EXISTS " + table); + } + } + + @Test + public void testMaterializedViewPartitionFilteringThroughLogicalViewWithCTE() + { + QueryRunner queryRunner = getQueryRunner(); + String table = "orders_partitioned_cte_test"; + String materializedView = "orders_mv_cte_test"; + String logicalView = "orders_lv_cte_test"; + + try { + // Create a table partitioned by 'ds' (date string) + queryRunner.execute(format("CREATE TABLE %s WITH 
(partitioned_by = ARRAY['ds']) AS " + + "SELECT orderkey, totalprice, '2025-11-10' AS ds FROM orders WHERE orderkey < 1000 " + + "UNION ALL " + + "SELECT orderkey, totalprice, '2025-11-11' AS ds FROM orders WHERE orderkey >= 1000 AND orderkey < 2000 " + + "UNION ALL " + + "SELECT orderkey, totalprice, '2025-11-12' AS ds FROM orders WHERE orderkey >= 2000 AND orderkey < 3000", table)); + + // Create a materialized view partitioned by 'ds' + queryRunner.execute(format("CREATE MATERIALIZED VIEW %s WITH (partitioned_by = ARRAY['ds']) AS " + + "SELECT max(totalprice) as max_price, orderkey, ds FROM %s GROUP BY orderkey, ds", materializedView, table)); + + assertTrue(getQueryRunner().tableExists(getSession(), materializedView)); + + // Only refresh partition for '2025-11-11', leaving '2025-11-10' and '2025-11-12' missing + assertUpdate(format("REFRESH MATERIALIZED VIEW %s WHERE ds='2025-11-11'", materializedView), 248); + + // Create a logical view on top of the materialized view + queryRunner.execute(format("CREATE VIEW %s AS SELECT * FROM %s", logicalView, materializedView)); + + setReferencedMaterializedViews((DistributedQueryRunner) queryRunner, table, ImmutableList.of(materializedView)); + + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(CONSIDER_QUERY_FILTERS_FOR_MATERIALIZED_VIEW_PARTITIONS, "true") + .setCatalogSessionProperty(HIVE_CATALOG, MATERIALIZED_VIEW_MISSING_PARTITIONS_THRESHOLD, Integer.toString(1)) + .build(); + + // Query the logical view through a CTE with a predicate + // The predicate should be pushed down to the materialized view + String cteQuery = format("WITH PreQuery AS (SELECT * FROM %s WHERE ds='2025-11-11') " + + "SELECT max_price, orderkey FROM PreQuery ORDER BY orderkey", logicalView); + String baseTableQuery = format("SELECT max(totalprice) as max_price, orderkey FROM %s " + + "WHERE ds='2025-11-11' " + + "GROUP BY orderkey ORDER BY orderkey", table); + + MaterializedResult baseQueryResult = 
computeActual(session, baseTableQuery); + MaterializedResult cteQueryResult = computeActual(session, cteQuery); + + // Both queries should return the same results + assertEquals(baseQueryResult, cteQueryResult); + + // The plan for the CTE query should use the materialized view + // (not fall back to base table) because we're only querying the refreshed partition + assertPlan(session, cteQuery, anyTree( + constrainedTableScan( + materializedView, + ImmutableMap.of("ds", singleValue(createVarcharType(10), utf8Slice("2025-11-11"))), + ImmutableMap.of()))); + } + finally { + queryRunner.execute("DROP VIEW IF EXISTS " + logicalView); + queryRunner.execute("DROP MATERIALIZED VIEW IF EXISTS " + materializedView); + queryRunner.execute("DROP TABLE IF EXISTS " + table); + } + } + + @Test + public void testMaterializedViewPartitionFilteringInCTE() + { + QueryRunner queryRunner = getQueryRunner(); + String table = "orders_partitioned_mv_cte_test"; + String materializedView = "orders_mv_direct_cte_test"; + + try { + // Create a table partitioned by 'ds' (date string) + queryRunner.execute(format("CREATE TABLE %s WITH (partitioned_by = ARRAY['ds']) AS " + + "SELECT orderkey, totalprice, '2025-11-10' AS ds FROM orders WHERE orderkey < 1000 " + + "UNION ALL " + + "SELECT orderkey, totalprice, '2025-11-11' AS ds FROM orders WHERE orderkey >= 1000 AND orderkey < 2000 " + + "UNION ALL " + + "SELECT orderkey, totalprice, '2025-11-12' AS ds FROM orders WHERE orderkey >= 2000 AND orderkey < 3000", table)); + + // Create a materialized view partitioned by 'ds' + queryRunner.execute(format("CREATE MATERIALIZED VIEW %s WITH (partitioned_by = ARRAY['ds']) AS " + + "SELECT max(totalprice) as max_price, orderkey, ds FROM %s GROUP BY orderkey, ds", materializedView, table)); + + assertTrue(getQueryRunner().tableExists(getSession(), materializedView)); + + // Only refresh partition for '2025-11-10', leaving '2025-11-11' and '2025-11-12' missing + assertUpdate(format("REFRESH MATERIALIZED VIEW 
%s WHERE ds='2025-11-10'", materializedView), 255); + + setReferencedMaterializedViews((DistributedQueryRunner) queryRunner, table, ImmutableList.of(materializedView)); + + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(CONSIDER_QUERY_FILTERS_FOR_MATERIALIZED_VIEW_PARTITIONS, "true") + .setCatalogSessionProperty(HIVE_CATALOG, MATERIALIZED_VIEW_MISSING_PARTITIONS_THRESHOLD, Integer.toString(1)) + .build(); + + // Query the materialized view directly through a CTE with a predicate + // The predicate should be used to determine which partitions are needed + String cteQuery = format("WITH PreQuery AS (SELECT * FROM %s WHERE ds='2025-11-10') " + + "SELECT max_price, orderkey FROM PreQuery ORDER BY orderkey", materializedView); + String baseTableQuery = format("SELECT max(totalprice) as max_price, orderkey FROM %s " + + "WHERE ds='2025-11-10' " + + "GROUP BY orderkey ORDER BY orderkey", table); + + MaterializedResult baseQueryResult = computeActual(session, baseTableQuery); + MaterializedResult cteQueryResult = computeActual(session, cteQuery); + + // Both queries should return the same results + assertEquals(baseQueryResult, cteQueryResult); + + // The plan for the CTE query should use the materialized view + // (not fall back to base table) because we're only querying the refreshed partition + assertPlan(session, cteQuery, anyTree( + constrainedTableScan( + materializedView, + ImmutableMap.of("ds", singleValue(createVarcharType(10), utf8Slice("2025-11-10"))), + ImmutableMap.of()))); + + // Test query for a missing partition through CTE + // This should fall back to base table because ds='2025-11-11' is not refreshed + String cteQueryMissing = format("WITH PreQuery AS (SELECT * FROM %s WHERE ds='2025-11-11') " + + "SELECT max_price, orderkey FROM PreQuery ORDER BY orderkey", materializedView); + String baseTableQueryMissing = format("SELECT max(totalprice) as max_price, orderkey FROM %s " + + "WHERE ds='2025-11-11' " + + 
"GROUP BY orderkey ORDER BY orderkey", table); + + MaterializedResult baseQueryResultMissing = computeActual(session, baseTableQueryMissing); + MaterializedResult cteQueryResultMissing = computeActual(session, cteQueryMissing); + + assertEquals(baseQueryResultMissing, cteQueryResultMissing); + + // Should fall back to base table for missing partition + assertPlan(session, cteQueryMissing, anyTree( + constrainedTableScan(table, ImmutableMap.of(), ImmutableMap.of()))); + } + finally { + queryRunner.execute("DROP MATERIALIZED VIEW IF EXISTS " + materializedView); + queryRunner.execute("DROP TABLE IF EXISTS " + table); + } + } + @Test public void testMaterializedViewOptimization() { @@ -2174,6 +2403,139 @@ public void testMaterializedViewAvgRewrite() } } + @Test + public void testMaterializedViewPartitionFilteringCaseInsensitive() + { + QueryRunner queryRunner = getQueryRunner(); + String table = "orders_partitioned_case_test"; + String view = "orders_view_case_test"; + + try { + // Create a table partitioned by 'country' (lowercase) + queryRunner.execute(format("CREATE TABLE %s WITH (partitioned_by = ARRAY['country']) AS " + + "SELECT orderkey, totalprice, 'US' AS country FROM orders WHERE orderkey < 1000 " + + "UNION ALL " + + "SELECT orderkey, totalprice, 'UK' AS country FROM orders WHERE orderkey >= 1000 AND orderkey < 2000 " + + "UNION ALL " + + "SELECT orderkey, totalprice, 'CA' AS country FROM orders WHERE orderkey >= 2000 AND orderkey < 3000", table)); + + // Create a materialized view partitioned by 'country' + queryRunner.execute(format("CREATE MATERIALIZED VIEW %s WITH (partitioned_by = ARRAY['country']) AS " + + "SELECT max(totalprice) as max_price, orderkey, country FROM %s GROUP BY orderkey, country", view, table)); + + assertTrue(getQueryRunner().tableExists(getSession(), view)); + + // Only refresh partitions for 'US' and 'UK', leaving 'CA' missing + assertUpdate(format("REFRESH MATERIALIZED VIEW %s WHERE country='US'", view), 255); + 
assertUpdate(format("REFRESH MATERIALIZED VIEW %s WHERE country='UK'", view), 248); + + setReferencedMaterializedViews((DistributedQueryRunner) queryRunner, table, ImmutableList.of(view)); + + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(CONSIDER_QUERY_FILTERS_FOR_MATERIALIZED_VIEW_PARTITIONS, "true") + .setCatalogSessionProperty(HIVE_CATALOG, MATERIALIZED_VIEW_MISSING_PARTITIONS_THRESHOLD, Integer.toString(1)) + .build(); + + // Query with UPPERCASE column name, filtering for countries that ARE in the MV + // This tests that case-insensitive lookup works correctly + String viewQueryWithUpperCaseFilter = format("SELECT max_price, orderkey FROM %s WHERE COUNTRY >= 'UK' AND COUNTRY <= 'US' ORDER BY orderkey", view); + String viewQueryWithLowerCaseFilter = format("SELECT max_price, orderkey FROM %s WHERE country >= 'UK' AND country <= 'US' ORDER BY orderkey", view); + String baseQueryWithFilter = format("SELECT max(totalprice) as max_price, orderkey FROM %s " + + "WHERE country >= 'UK' AND country <= 'US' " + + "GROUP BY orderkey ORDER BY orderkey", table); + + MaterializedResult baseQueryResult = computeActual(session, baseQueryWithFilter); + MaterializedResult viewQueryUpperCaseResult = computeActual(session, viewQueryWithUpperCaseFilter); + MaterializedResult viewQueryLowerCaseResult = computeActual(session, viewQueryWithLowerCaseFilter); + + // Both queries should return the same results + assertEquals(baseQueryResult, viewQueryUpperCaseResult); + assertEquals(baseQueryResult, viewQueryLowerCaseResult); + + // The plan should use the materialized view for countries UK and US (both are refreshed) + // and should NOT count the missing 'CA' partition because the query filter excludes it + assertPlan(session, viewQueryWithUpperCaseFilter, anyTree( + constrainedTableScan( + view, + ImmutableMap.of("country", multipleValues(createVarcharType(2), utf8Slices("UK", "US"))), + ImmutableMap.of()))); + assertPlan(session, 
viewQueryWithLowerCaseFilter, anyTree( + constrainedTableScan( + view, + ImmutableMap.of("country", multipleValues(createVarcharType(2), utf8Slices("UK", "US"))), + ImmutableMap.of()))); + } + finally { + queryRunner.execute("DROP MATERIALIZED VIEW IF EXISTS " + view); + queryRunner.execute("DROP TABLE IF EXISTS " + table); + } + } + + @Test + public void testMaterializedViewMissingPartitionsCountWithMultiplePartitionColumns() + { + QueryRunner queryRunner = getQueryRunner(); + String table = "orders_multi_partition_count_test"; + String view = "orders_view_multi_partition_count_test"; + + try { + // Create a table partitioned by TWO columns: 'country' and 'region' + queryRunner.execute(format("CREATE TABLE %s (id BIGINT, price DOUBLE, country VARCHAR, region VARCHAR) " + + "WITH (partitioned_by = ARRAY['country', 'region'])", table)); + + // Insert data into 4 different partitions + assertUpdate(format("INSERT INTO %s VALUES (1, 100.0, 'US', 'West'), (2, 200.0, 'US', 'West')", table), 2); + assertUpdate(format("INSERT INTO %s VALUES (3, 300.0, 'US', 'East'), (4, 400.0, 'US', 'East')", table), 2); + assertUpdate(format("INSERT INTO %s VALUES (5, 500.0, 'UK', 'North'), (6, 600.0, 'UK', 'North')", table), 2); + assertUpdate(format("INSERT INTO %s VALUES (7, 700.0, 'UK', 'South'), (8, 800.0, 'UK', 'South')", table), 2); + + // Create a materialized view partitioned by both columns + queryRunner.execute(format("CREATE MATERIALIZED VIEW %s WITH (partitioned_by = ARRAY['country', 'region']) AS " + + "SELECT max(price) as max_price, id, country, region FROM %s GROUP BY id, country, region", view, table)); + + assertTrue(getQueryRunner().tableExists(getSession(), view)); + + // Only refresh 2 out of 4 partitions, leaving 2 missing (UK/North and UK/South are missing) + assertUpdate(format("REFRESH MATERIALIZED VIEW %s WHERE country='US' AND region='West'", view), 2); + assertUpdate(format("REFRESH MATERIALIZED VIEW %s WHERE country='US' AND region='East'", view), 2); + + 
setReferencedMaterializedViews((DistributedQueryRunner) queryRunner, table, ImmutableList.of(view)); + + // Set the threshold to 2 missing partitions to test that the counted missingPartitions is 2 + Session sessionWithThreshold2 = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED, "true") + .setCatalogSessionProperty(HIVE_CATALOG, MATERIALIZED_VIEW_MISSING_PARTITIONS_THRESHOLD, Integer.toString(2)) + .build(); + + String baseQuery = format("SELECT max(price) as max_price, id FROM %s GROUP BY id ORDER BY id", table); + + // With threshold = 2 and 2 missing partitions, the materialized view should still be used + assertPlan(sessionWithThreshold2, baseQuery, anyTree(exchange( + anyTree(constrainedTableScan( + table, + ImmutableMap.of(), + ImmutableMap.of())), + anyTree(constrainedTableScan( + view, + ImmutableMap.of(), + ImmutableMap.of()))))); + + // Now set threshold to 1 - should fall back to base table since we have 2 missing partitions + Session sessionWithThreshold1 = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED, "true") + .setCatalogSessionProperty(HIVE_CATALOG, MATERIALIZED_VIEW_MISSING_PARTITIONS_THRESHOLD, Integer.toString(1)) + .build(); + + // With threshold = 1 and 2 missing partitions, should use only the base table + assertPlan(sessionWithThreshold1, baseQuery, anyTree( + constrainedTableScan(table, ImmutableMap.of(), ImmutableMap.of()))); + } + finally { + queryRunner.execute("DROP MATERIALIZED VIEW IF EXISTS " + view); + queryRunner.execute("DROP TABLE IF EXISTS " + table); + } + } + @Test public void testMaterializedViewApproxDistinctRewrite() { @@ -2592,12 +2954,25 @@ public void testInsertBySelectingFromMaterializedView() public void testMaterializedViewQueryAccessControl() { QueryRunner queryRunner = getQueryRunner(); - Session invokerSession = Session.builder(getSession()) + + Session 
invokerStichingSession = Session.builder(getSession()) .setIdentity(new Identity("test_view_invoker", Optional.empty())) .setCatalog(getSession().getCatalog().get()) .setSchema(getSession().getSchema().get()) .setSystemProperty(QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED, "true") .build(); + + /* Non-stitching session test is needed as the query is not rewritten + with base table. In this case the analyzer should process the materialized view + definition sql to check all the base tables permissions. + */ + Session invokerNonStichingSession = Session.builder(getSession()) + .setIdentity(new Identity("test_view_invoker2", Optional.empty())) + .setCatalog(getSession().getCatalog().get()) + .setSchema(getSession().getSchema().get()) + .setSystemProperty(MATERIALIZED_VIEW_DATA_CONSISTENCY_ENABLED, "false") + .build(); + Session ownerSession = getSession(); queryRunner.execute( @@ -2611,37 +2986,60 @@ public void testMaterializedViewQueryAccessControl() "AS SELECT SUM(totalprice) AS totalprice, orderstatus FROM test_orders_base GROUP BY orderstatus"); setReferencedMaterializedViews((DistributedQueryRunner) getQueryRunner(), "test_orders_base", ImmutableList.of("test_orders_view")); - Consumer testQueryWithDeniedPrivilege = query -> { - // Verify checking the base table instead of the materialized view for SELECT permission + try { + // Check for both the direct materialized view query and the base table query optimization with materialized view + // for direct materialized view query, check when stitching is enabled/disabled + String queryMaterializedView = "SELECT totalprice, orderstatus FROM test_orders_view"; + String queryBaseTable = "SELECT SUM(totalprice) AS totalprice, orderstatus FROM test_orders_base GROUP BY orderstatus"; + assertAccessDenied( - invokerSession, - query, + invokerNonStichingSession, + queryBaseTable, "Cannot select from columns \\[.*\\] in table .*test_orders_base.*", - privilege(invokerSession.getUser(), "test_orders_base", 
SELECT_COLUMN)); - assertAccessAllowed( - invokerSession, - query, - privilege(invokerSession.getUser(), "test_orders_view", SELECT_COLUMN)); - }; + privilege(invokerNonStichingSession.getUser(), "test_orders_base", SELECT_COLUMN)); - try { - // Check for both the direct materialized view query and the base table query optimization with materialized view - String directMaterializedViewQuery = "SELECT totalprice, orderstatus FROM test_orders_view"; - String queryWithMaterializedViewOptimization = "SELECT SUM(totalprice) AS totalprice, orderstatus FROM test_orders_base GROUP BY orderstatus"; + assertAccessDenied( + invokerStichingSession, + queryBaseTable, + "Cannot select from columns \\[.*\\] in table .*test_orders_base.*", + privilege(invokerStichingSession.getUser(), "test_orders_base", SELECT_COLUMN)); + + assertAccessAllowed( + invokerStichingSession, + queryMaterializedView, + privilege(invokerStichingSession.getUser(), "test_orders_view", SELECT_COLUMN)); - // Test when the materialized view is not materialized yet - testQueryWithDeniedPrivilege.accept(directMaterializedViewQuery); - testQueryWithDeniedPrivilege.accept(queryWithMaterializedViewOptimization); + assertAccessDenied( + invokerNonStichingSession, + queryMaterializedView, + "Cannot select from columns \\[.*\\] in table .*test_orders_base.*", + privilege(invokerNonStichingSession.getUser(), "test_orders_base", SELECT_COLUMN)); // Test when the materialized view is partially materialized queryRunner.execute(ownerSession, "REFRESH MATERIALIZED VIEW test_orders_view WHERE orderstatus = 'F'"); - testQueryWithDeniedPrivilege.accept(directMaterializedViewQuery); - testQueryWithDeniedPrivilege.accept(queryWithMaterializedViewOptimization); + assertAccessAllowed( + invokerStichingSession, + queryMaterializedView, + privilege(invokerStichingSession.getUser(), "test_orders_view", SELECT_COLUMN)); + + assertAccessDenied( + invokerNonStichingSession, + queryMaterializedView, + "Cannot select from columns 
\\[.*\\] in table .*test_orders_base.*", + privilege(invokerNonStichingSession.getUser(), "test_orders_base", SELECT_COLUMN)); // Test when the materialized view is fully materialized queryRunner.execute(ownerSession, "REFRESH MATERIALIZED VIEW test_orders_view WHERE orderstatus <> 'F'"); - testQueryWithDeniedPrivilege.accept(directMaterializedViewQuery); - testQueryWithDeniedPrivilege.accept(queryWithMaterializedViewOptimization); + assertAccessAllowed( + invokerStichingSession, + queryMaterializedView, + privilege(invokerStichingSession.getUser(), "test_orders_view", SELECT_COLUMN)); + + assertAccessDenied( + invokerNonStichingSession, + queryMaterializedView, + "Cannot select from columns \\[.*\\] in table .*test_orders_base.*", + privilege(invokerNonStichingSession.getUser(), "test_orders_base", SELECT_COLUMN)); } finally { queryRunner.execute(ownerSession, "DROP MATERIALIZED VIEW test_orders_view"); @@ -2725,6 +3123,307 @@ public void testRefreshMaterializedViewAccessControl() } } + @Test + public void testAutoRefreshMaterializedViewWithoutPredicates() + { + QueryRunner queryRunner = getQueryRunner(); + String table = "test_orders_auto_refresh_source"; + String view = "test_orders_auto_refresh_target_mv"; + String view2 = "test_orders_auto_refresh_target_mv2"; + + Session nonFullRefreshSession = getSession(); + + Session fullRefreshSession = Session.builder(getSession()) + .setSystemProperty("materialized_view_allow_full_refresh_enabled", "true") + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + + queryRunner.execute( + fullRefreshSession, + format("CREATE TABLE %s WITH (partitioned_by = ARRAY['orderstatus']) " + + "AS SELECT orderkey, custkey, totalprice, orderstatus FROM orders WHERE orderkey < 100", table)); + + queryRunner.execute( + fullRefreshSession, + format("CREATE MATERIALIZED VIEW %s " + + "WITH (partitioned_by = ARRAY['orderstatus']) " + + "AS SELECT SUM(totalprice) AS total, COUNT(*) AS cnt, orderstatus " + 
+ "FROM %s GROUP BY orderstatus", view, table)); + + queryRunner.execute( + nonFullRefreshSession, + format("CREATE MATERIALIZED VIEW %s " + + "WITH (partitioned_by = ARRAY['orderstatus']) " + + "AS SELECT SUM(totalprice) AS total, COUNT(*) AS cnt, orderstatus " + + "FROM %s GROUP BY orderstatus", view2, table)); + + try { + // Test that refresh without predicates succeeds when flag is enabled + queryRunner.execute(fullRefreshSession, format("REFRESH MATERIALIZED VIEW %s", view)); + + // Verify all partitions are refreshed + MaterializedResult result = queryRunner.execute(fullRefreshSession, + format("SELECT COUNT(DISTINCT orderstatus) FROM %s", view)); + assertTrue(((Long) result.getOnlyValue()) > 0, "Materialized view should contain data after auto-refresh"); + + // Test that refresh without predicates fails when flag is not enabled + assertQueryFails( + nonFullRefreshSession, + format("REFRESH MATERIALIZED VIEW %s", view2), + ".*misses too many partitions or is never refreshed and may incur high cost.*"); + } + finally { + queryRunner.execute(fullRefreshSession, format("DROP MATERIALIZED VIEW %s", view)); + queryRunner.execute(fullRefreshSession, format("DROP TABLE %s", table)); + } + } + + @Test + public void testAutoRefreshMaterializedViewWithJoinWithoutPredicates() + { + QueryRunner queryRunner = getQueryRunner(); + + String table1 = "test_customer_auto_refresh"; + String table2 = "test_orders_join_auto_refresh"; + String view = "test_auto_refresh_join_target_mv"; + + Session fullRefreshSession = Session.builder(getSession()) + .setSystemProperty("materialized_view_allow_full_refresh_enabled", "true") + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + Session ownerSession = getSession(); + + queryRunner.execute( + fullRefreshSession, + format("CREATE TABLE %s WITH (partitioned_by = ARRAY['nationkey']) " + + "AS SELECT custkey, name, nationkey FROM customer WHERE custkey < 100", table1)); + queryRunner.execute( + 
fullRefreshSession, + format("CREATE TABLE %s WITH (partitioned_by = ARRAY['orderstatus']) " + + "AS SELECT orderkey, custkey, totalprice, orderstatus FROM orders WHERE orderkey < 100", table2)); + queryRunner.execute( + fullRefreshSession, + format("CREATE MATERIALIZED VIEW %s " + + "WITH (partitioned_by = ARRAY['nationkey', 'orderstatus']) " + + "AS SELECT c.name, SUM(o.totalprice) AS total, c.nationkey, o.orderstatus " + + "FROM %s c JOIN %s o ON c.custkey = o.custkey " + + "GROUP BY c.name, c.nationkey, o.orderstatus", view, table1, table2)); + + try { + queryRunner.execute(fullRefreshSession, format("REFRESH MATERIALIZED VIEW %s", view)); + + MaterializedResult result = queryRunner.execute(fullRefreshSession, + format("SELECT COUNT(*) FROM %s", view)); + assertTrue(((Long) result.getOnlyValue()) > 0, + "Materialized view with join should contain data after auto-refresh"); + } + finally { + queryRunner.execute(ownerSession, format("DROP MATERIALIZED VIEW %s", view)); + queryRunner.execute(ownerSession, format("DROP TABLE %s", table1)); + queryRunner.execute(ownerSession, format("DROP TABLE %s", table2)); + } + } + + @Test + public void testAutoRefreshMaterializedViewFullyRefreshed() + { + QueryRunner queryRunner = getQueryRunner(); + + String table = "test_customer_auto_refresh"; + String view = "test_auto_refresh_join_target_mv"; + + Session fullRefreshSession = Session.builder(getSession()) + .setSystemProperty("materialized_view_allow_full_refresh_enabled", "true") + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + Session ownerSession = getSession(); + + queryRunner.execute( + fullRefreshSession, + format("CREATE TABLE %s WITH (partitioned_by = ARRAY['nationkey']) " + + "AS SELECT custkey, name, nationkey FROM customer WHERE custkey < 100", table)); + + queryRunner.execute( + fullRefreshSession, + format("CREATE MATERIALIZED VIEW %s " + + "WITH (partitioned_by = ARRAY['nationkey']) " + + "AS SELECT custkey, nationkey 
FROM %s", view, table)); + + try { + queryRunner.execute(fullRefreshSession, format("REFRESH MATERIALIZED VIEW %s", view)); + + MaterializedResult result = queryRunner.execute(fullRefreshSession, + format("REFRESH MATERIALIZED VIEW %s", view)); + + assertEquals(result.getWarnings().size(), 1); + assertTrue(result.getWarnings().get(0).getMessage().matches("Materialized view .* is already fully refreshed")); + } + finally { + queryRunner.execute(ownerSession, format("DROP MATERIALIZED VIEW %s", view)); + queryRunner.execute(ownerSession, format("DROP TABLE %s", table)); + } + } + + @Test + public void testAutoRefreshMaterializedViewAfterInsertion() + { + QueryRunner queryRunner = getQueryRunner(); + + String table = "test_auto_refresh"; + String view = "test_auto_refresh_mv"; + + Session fullRefreshSession = Session.builder(getSession()) + .setSystemProperty("materialized_view_allow_full_refresh_enabled", "true") + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + Session ownerSession = getSession(); + + queryRunner.execute( + fullRefreshSession, + format("CREATE TABLE %s (col1 bigint, col2 varchar, part_key varchar) " + + "WITH (partitioned_by = ARRAY['part_key'])", table)); + + queryRunner.execute( + fullRefreshSession, + format("INSERT INTO %s VALUES (1, 'aaa', 'p1'), " + + "(2, 'bbb', 'p2'), (3, 'aaa', 'p1')", table)); + + queryRunner.execute( + fullRefreshSession, + format("CREATE MATERIALIZED VIEW %s " + + "WITH (partitioned_by = ARRAY['part_key']) " + + "AS SELECT col1, part_key FROM %s", view, table)); + + try { + queryRunner.execute(fullRefreshSession, format("REFRESH MATERIALIZED VIEW %s", view)); + + MaterializedResult result = queryRunner.execute(fullRefreshSession, + format("SELECT COUNT(DISTINCT part_key) FROM %s", view)); + assertEquals((long) ((Long) result.getOnlyValue()), 2, "Materialized view should contain all data after refreshes"); + + queryRunner.execute( + fullRefreshSession, + format("INSERT INTO %s 
VALUES (1, 'aaa', 'p3'), " + + "(2, 'bbb', 'p4'), (3, 'aaa', 'p5')", table)); + + queryRunner.execute(fullRefreshSession, + format("REFRESH MATERIALIZED VIEW %s", view)); + + result = queryRunner.execute(fullRefreshSession, + format("SELECT COUNT(DISTINCT part_key) FROM %s", view)); + assertEquals((long) ((Long) result.getOnlyValue()), 5, "Materialized view should contain all data after refreshes"); + } + finally { + queryRunner.execute(ownerSession, format("DROP MATERIALIZED VIEW %s", view)); + queryRunner.execute(ownerSession, format("DROP TABLE %s", table)); + } + } + + @Test + public void testMVJoinQueryWithOtherTableColumnFiltering() + { + QueryRunner queryRunner = getQueryRunner(); + Session session = getSession(); + + assertUpdate("CREATE TABLE mv_base (mv_col1 int, mv_col2 varchar, mv_col3 varchar) " + + "WITH (partitioned_by=ARRAY['mv_col3'])"); + assertUpdate("CREATE TABLE join_table (table_col1 int, table_col2 varchar, table_col3 varchar) " + + " WITH (partitioned_by=ARRAY['table_col3'])"); + + assertUpdate("INSERT INTO mv_base VALUES (1, 'Alice', 'A'), (2, 'Bob', 'B'), (3, 'Charlie', 'C')", 3); + assertUpdate("INSERT INTO join_table VALUES (1, 'CityA', 'A'), (21, 'CityA', 'B'), (32, 'CityB', 'C')", 3); + + assertUpdate("CREATE MATERIALIZED VIEW mv " + + "WITH (partitioned_by=ARRAY['mv_col3']) " + + "AS SELECT mv_col1, mv_col2, mv_col3 FROM mv_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv WHERE mv_col3>='A'", 3); + + // Query MV with JOIN and WHERE clause on column from joined table (not in MV) + MaterializedResult result = queryRunner.execute(session, + "SELECT mv_col2 FROM mv " + + "JOIN join_table ON mv_col3=table_col3 " + + "WHERE table_col1>10 ORDER BY mv_col1"); + assertEquals(result.getRowCount(), 2, "Materialized view join produced unexpected row counts"); + + List expectedResults = List.of("Bob", "Charlie"); + List actualResults = result.getMaterializedRows().stream() + .map(row -> row.getField(0)) + .collect(toList()); + 
assertEquals(actualResults, expectedResults, "Materialized view join returned unexpected row values"); + + // WHERE clause on MV column + result = queryRunner.execute(session, "SELECT mv_col2 FROM mv JOIN join_table " + + "ON mv_col3=table_col3 WHERE mv_col2>'Alice' ORDER BY mv_col2"); + assertEquals(result.getRowCount(), 2, "Materialized view join produced unexpected row counts"); + + expectedResults = List.of("Bob", "Charlie"); + actualResults = result.getMaterializedRows().stream() + .map(row -> row.getField(0)) + .collect(toList()); + assertEquals(actualResults, expectedResults, "Materialized view join returned unexpected row values"); + + // Test with multiple conditions in WHERE clause (non-partition column) + result = queryRunner.execute(session, "SELECT mv_col1 FROM mv JOIN join_table ON mv_col3=table_col3 " + + "WHERE table_col1>10 AND table_col3='B' AND mv_col1>1"); + assertEquals(result.getRowCount(), 1, "Materialized view join produced unexpected row counts"); + + expectedResults = List.of(2); + actualResults = result.getMaterializedRows().stream() + .map(row -> row.getField(0)) + .collect(toList()); + assertEquals(actualResults, expectedResults, "Materialized view join returned unexpected row values"); + + // Test with multiple conditions in WHERE clause (partition column) + result = queryRunner.execute(session, "SELECT mv_col1 FROM mv JOIN join_table ON mv_col3=table_col3 " + + "WHERE table_col1>10 AND table_col3='B' AND mv_col3='C'"); + assertEquals(result.getRowCount(), 0, "Materialized view join produced wrong results"); + + assertUpdate("DROP MATERIALIZED VIEW mv"); + assertUpdate("DROP TABLE join_table"); + assertUpdate("DROP TABLE mv_base"); + } + + public void testMaterializedViewNotRefreshedInNonLegacyMode() + { + Session nonLegacySession = Session.builder(getSession()) + .setSystemProperty("legacy_materialized_views", "false") + .build(); + try { + assertUpdate("CREATE TABLE base_table (id BIGINT, name VARCHAR, part_key BIGINT) WITH 
(partitioned_by = ARRAY['part_key'])"); + assertUpdate("INSERT INTO base_table VALUES (1, 'Alice', 100), (2, 'Bob', 200)", 2); + assertUpdate("CREATE MATERIALIZED VIEW simple_mv WITH (partitioned_by = ARRAY['part_key']) AS SELECT id, name, part_key FROM base_table"); + + assertPlan(nonLegacySession, "SELECT * FROM simple_mv", + anyTree(tableScan("base_table"))); + } + finally { + assertUpdate("DROP TABLE base_table"); + assertUpdate("DROP MATERIALIZED VIEW simple_mv"); + } + } + + @Test + public void testMaterializedViewRefreshedInNonLegacyMode() + { + Session nonLegacySession = Session.builder(getSession()) + .setSystemProperty("legacy_materialized_views", "false") + .build(); + try { + assertUpdate("CREATE TABLE base_table (id BIGINT, name VARCHAR, part_key BIGINT) WITH (partitioned_by = ARRAY['part_key'])"); + assertUpdate("INSERT INTO base_table VALUES (1, 'Alice', 100), (2, 'Bob', 200)", 2); + assertUpdate("CREATE MATERIALIZED VIEW simple_mv WITH (partitioned_by = ARRAY['part_key']) AS SELECT id, name, part_key FROM base_table"); + assertUpdate("REFRESH MATERIALIZED VIEW simple_mv where part_key > 0", 2); + + assertPlan(nonLegacySession, "SELECT * FROM simple_mv", + anyTree(tableScan("simple_mv"))); + } + finally { + assertUpdate("DROP TABLE base_table"); + assertUpdate("DROP MATERIALIZED VIEW simple_mv"); + } + } + private void setReferencedMaterializedViews(DistributedQueryRunner queryRunner, String tableName, List referencedMaterializedViews) { appendTableParameter(replicateHiveMetastore(queryRunner), diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMaterializedViewUtils.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMaterializedViewUtils.java index 85a6c471bb373..02cb66177e137 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMaterializedViewUtils.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMaterializedViewUtils.java @@ -659,6 +659,6 @@ private static 
MaterializedViewDefinition getConnectorMaterializedViewDefinition List tables, Map> originalColumnMapping) { - return new MaterializedViewDefinition(SQL, SCHEMA_NAME, TABLE_NAME, tables, Optional.empty(), originalColumnMapping, originalColumnMapping, ImmutableList.of(), Optional.empty()); + return new MaterializedViewDefinition(SQL, SCHEMA_NAME, TABLE_NAME, tables, Optional.empty(), Optional.empty(), originalColumnMapping, originalColumnMapping, ImmutableList.of(), Optional.empty()); } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataFileFormatEncryptionSettings.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataFileFormatEncryptionSettings.java index ad735f3010873..eaef20e120656 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataFileFormatEncryptionSettings.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataFileFormatEncryptionSettings.java @@ -138,7 +138,6 @@ public void setup() new HivePartitionObjectBuilder(), new HiveEncryptionInformationProvider(ImmutableList.of(new TestDwrfEncryptionInformationSource())), new HivePartitionStats(), - new HiveFileRenamer(), HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, new QuickStatsProvider(metastore, HDFS_ENVIRONMENT, HiveTestUtils.DO_NOTHING_DIRECTORY_LISTER, new HiveClientConfig(), new NamenodeStats(), ImmutableList.of()), new HiveTableWritabilityChecker(false)); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataUpdateHandle.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataUpdateHandle.java deleted file mode 100644 index 44b2cbecff22b..0000000000000 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataUpdateHandle.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.hive; - -import com.facebook.airlift.json.JsonCodec; -import com.facebook.presto.spi.SchemaTableName; -import org.testng.annotations.Test; - -import java.util.Optional; -import java.util.UUID; - -import static com.facebook.presto.hive.TestHiveMetadataUpdater.TEST_FILE_NAME; -import static org.testng.Assert.assertEquals; - -public class TestHiveMetadataUpdateHandle -{ - public static final UUID TEST_REQUEST_ID = UUID.randomUUID(); - public static final String TEST_SCHEMA_NAME = "schema"; - public static final String TEST_TABLE_NAME = "table"; - public static final String TEST_PARTITION_NAME = "partition_name"; - - public static final SchemaTableName TEST_SCHEMA_TABLE_NAME = new SchemaTableName(TEST_SCHEMA_NAME, TEST_TABLE_NAME); - public static final HiveMetadataUpdateHandle TEST_HIVE_METADATA_UPDATE_REQUEST = new HiveMetadataUpdateHandle(TEST_REQUEST_ID, TEST_SCHEMA_TABLE_NAME, Optional.of(TEST_PARTITION_NAME), Optional.empty()); - - private final JsonCodec codec = JsonCodec.jsonCodec(HiveMetadataUpdateHandle.class); - - @Test - public void testHiveMetadataUpdateRequest() - { - testRoundTrip(TEST_HIVE_METADATA_UPDATE_REQUEST); - } - - @Test - public void testHiveMetadataUpdateResult() - { - HiveMetadataUpdateHandle request = TEST_HIVE_METADATA_UPDATE_REQUEST; - HiveMetadataUpdateHandle expectedHiveMetadataUpdateResult = new HiveMetadataUpdateHandle(request.getRequestId(), request.getSchemaTableName(), request.getPartitionName(), Optional.of(TEST_FILE_NAME)); - testRoundTrip(expectedHiveMetadataUpdateResult); - } - 
- private void testRoundTrip(HiveMetadataUpdateHandle expected) - { - String json = codec.toJson(expected); - HiveMetadataUpdateHandle actual = codec.fromJson(json); - - assertEquals(actual.getRequestId(), expected.getRequestId()); - assertEquals(actual.getSchemaTableName(), expected.getSchemaTableName()); - assertEquals(actual.getPartitionName(), expected.getPartitionName()); - assertEquals(actual.getMetadataUpdate(), expected.getMetadataUpdate()); - } -} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataUpdater.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataUpdater.java deleted file mode 100644 index 31b11963a1b02..0000000000000 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveMetadataUpdater.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.hive; - -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.google.common.collect.ImmutableList; -import org.testng.annotations.Test; - -import java.util.List; -import java.util.Optional; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; - -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_PARTITION_NAME; -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_SCHEMA_NAME; -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_SCHEMA_TABLE_NAME; -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_TABLE_NAME; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.fail; - -public class TestHiveMetadataUpdater -{ - public static final String TEST_FILE_NAME = "fileName"; - - private static final int TEST_WRITER_INDEX = 0; - private static final Executor EXECUTOR = Executors.newFixedThreadPool(5); - - @Test - public void testEmptyMetadataUpdateRequestQueue() - { - HiveMetadataUpdater hiveMetadataUpdater = new HiveMetadataUpdater(EXECUTOR); - assertEquals(hiveMetadataUpdater.getPendingMetadataUpdateRequests().size(), 0); - } - - @Test - public void testAddMetadataUpdateRequest() - { - HiveMetadataUpdater hiveMetadataUpdater = new HiveMetadataUpdater(EXECUTOR); - - // add Request - hiveMetadataUpdater.addMetadataUpdateRequest(TEST_SCHEMA_NAME, TEST_TABLE_NAME, Optional.of(TEST_PARTITION_NAME), TEST_WRITER_INDEX); - List hiveMetadataUpdateRequests = hiveMetadataUpdater.getPendingMetadataUpdateRequests(); - - // assert the pending request queue size - assertEquals(hiveMetadataUpdateRequests.size(), 1); - - // assert that request in queue is same as the request added - HiveMetadataUpdateHandle request = (HiveMetadataUpdateHandle) hiveMetadataUpdateRequests.get(0); - assertEquals(request.getSchemaTableName(), TEST_SCHEMA_TABLE_NAME); - 
assertEquals(request.getPartitionName(), Optional.of(TEST_PARTITION_NAME)); - } - - @Test - public void testSetMetadataUpdateResults() - { - HiveMetadataUpdater hiveMetadataUpdater = new HiveMetadataUpdater(EXECUTOR); - - // add Request - hiveMetadataUpdater.addMetadataUpdateRequest(TEST_SCHEMA_NAME, TEST_TABLE_NAME, Optional.of(TEST_PARTITION_NAME), TEST_WRITER_INDEX); - List hiveMetadataUpdateRequests = hiveMetadataUpdater.getPendingMetadataUpdateRequests(); - assertEquals(hiveMetadataUpdateRequests.size(), 1); - HiveMetadataUpdateHandle request = (HiveMetadataUpdateHandle) hiveMetadataUpdateRequests.get(0); - - // create Result - HiveMetadataUpdateHandle hiveMetadataUpdateResult = new HiveMetadataUpdateHandle(request.getRequestId(), request.getSchemaTableName(), request.getPartitionName(), Optional.of(TEST_FILE_NAME)); - - // set the result - hiveMetadataUpdater.setMetadataUpdateResults(ImmutableList.of(hiveMetadataUpdateResult)); - - try { - // get the fileName - String fileName = hiveMetadataUpdater.getMetadataResult(TEST_WRITER_INDEX).get(); - - // assert the fileName - assertEquals(fileName, TEST_FILE_NAME); - - // assert the pending request queue size is zero - assertEquals(hiveMetadataUpdater.getPendingMetadataUpdateRequests().size(), 0); - } - catch (InterruptedException | ExecutionException e) { - fail("Expected to succeed and get the fileName metadata result"); - } - } -} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSourceProvider.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSourceProvider.java index 42a5a02d919f3..e27ddea274ea0 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSourceProvider.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePageSourceProvider.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.cache.CacheConfig; import com.facebook.presto.common.Page; 
import com.facebook.presto.common.RuntimeStats; @@ -41,7 +42,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; -import io.airlift.units.DataSize; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.ql.io.orc.OrcSerde; import org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe; @@ -75,9 +75,9 @@ import static com.facebook.presto.hive.HiveTestUtils.getDefaultHiveRecordCursorProvider; import static com.facebook.presto.hive.HiveType.HIVE_LONG; import static com.facebook.presto.hive.HiveUtil.CUSTOM_FILE_SPLIT_CLASS_KEY; -import static com.facebook.presto.hive.TestHiveMetadataUpdateHandle.TEST_TABLE_NAME; import static com.facebook.presto.hive.TestHivePageSink.getColumnHandles; import static com.facebook.presto.hive.metastore.thrift.MockHiveMetastoreClient.TEST_DATABASE; +import static com.facebook.presto.hive.metastore.thrift.MockHiveMetastoreClient.TEST_TABLE; import static com.facebook.presto.hive.util.HudiRealtimeSplitConverter.HUDI_BASEPATH_KEY; import static com.facebook.presto.hive.util.HudiRealtimeSplitConverter.HUDI_DELTA_FILEPATHS_KEY; import static com.facebook.presto.hive.util.HudiRealtimeSplitConverter.HUDI_MAX_COMMIT_TIME_KEY; @@ -385,7 +385,7 @@ public void testAggregatedPageSource() } @Test(expectedExceptions = PrestoException.class, - expectedExceptionsMessageRegExp = "Table testdb.table has file of format org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe that does not support partial aggregation pushdown. " + + expectedExceptionsMessageRegExp = "Table testdb.testtbl has file of format org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe that does not support partial aggregation pushdown. 
" + "Set session property \\[catalog\\-name\\].pushdown_partial_aggregations_into_scan=false and execute query again.") public void testFailsWhenNoAggregatedPageSourceAvailable() { @@ -402,7 +402,7 @@ public void testFailsWhenNoAggregatedPageSourceAvailable() @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = "Partial aggregation pushdown is not supported when footer stats are unreliable. " + - "Table testdb.table has file file://test with unreliable footer stats. " + + "Table testdb.testtbl has file file://test with unreliable footer stats. " + "Set session property \\[catalog\\-name\\].pushdown_partial_aggregations_into_scan=false and execute query again.") public void testFailsWhenFooterStatsUnreliable() { @@ -465,8 +465,8 @@ public void testCreatePageSource_withRowIDMissingPartitionComponent() private static ConnectorTableLayoutHandle getHiveTableLayout(boolean pushdownFilterEnabled, boolean partialAggregationsPushedDown, boolean footerStatsUnreliable) { return new HiveTableLayoutHandle( - new SchemaTableName(TEST_DATABASE, TEST_TABLE_NAME), - TEST_TABLE_NAME, + new SchemaTableName(TEST_DATABASE, TEST_TABLE), + TEST_TABLE, ImmutableList.of(), ImmutableList.of(), // TODO fill out columns ImmutableMap.of(), @@ -718,9 +718,9 @@ public Optional createPageSource( Storage storage, List columns, Map prefilledValues, + String> prefilledValues, Map coercers, + HiveCoercer> coercers, Optional bucketAdaptation, List outputColumns, TupleDomain domainPredicate, diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePartitionManager.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePartitionManager.java index c63f5e9d36edb..342e0d01623a4 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePartitionManager.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePartitionManager.java @@ -16,6 +16,7 @@ import com.facebook.presto.cache.CacheConfig; import 
com.facebook.presto.common.predicate.Domain; +import com.facebook.presto.common.predicate.NullableValue; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.TestingTypeManager; @@ -23,6 +24,7 @@ import com.facebook.presto.hive.metastore.PrestoTableType; import com.facebook.presto.hive.metastore.Storage; import com.facebook.presto.hive.metastore.Table; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.Constraint; import com.facebook.presto.testing.TestingConnectorSession; @@ -32,7 +34,9 @@ import org.testng.annotations.Test; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.function.Predicate; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; @@ -42,11 +46,13 @@ import static com.facebook.presto.hive.BucketFunctionType.HIVE_COMPATIBLE; import static com.facebook.presto.hive.HiveColumnHandle.MAX_PARTITION_KEY_COLUMN_INDEX; import static com.facebook.presto.hive.HiveColumnHandle.bucketColumnHandle; +import static com.facebook.presto.hive.HiveMetadata.convertToPredicate; import static com.facebook.presto.hive.HiveStorageFormat.ORC; import static com.facebook.presto.hive.HiveType.HIVE_INT; import static com.facebook.presto.hive.HiveType.HIVE_STRING; import static com.facebook.presto.hive.metastore.StorageFormat.fromHiveStorageFormat; import static io.airlift.slice.Slices.utf8Slice; +import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -54,9 +60,12 @@ public class TestHivePartitionManager { private static final String SCHEMA_NAME = "schema"; private static final String TABLE_NAME = "table"; + private static final String TABLE_NAME_LARGE_PARTITIONS = "table_large_partitions"; 
private static final String USER_NAME = "user"; private static final String LOCATION = "somewhere/over/the/rainbow"; + private static final String LOCATION_LARGE_PARTITIONS = "large/partitions/over/the/rainbow"; private static final Column PARTITION_COLUMN = new Column("ds", HIVE_STRING, Optional.empty(), Optional.empty()); + private static final Column PARTITION_COLUMN_TS = new Column("ts", HIVE_STRING, Optional.empty(), Optional.empty()); private static final Column BUCKET_COLUMN = new Column("c1", HIVE_INT, Optional.empty(), Optional.empty()); private static final Table TABLE = new Table( Optional.of("catalogName"), @@ -80,8 +89,32 @@ public class TestHivePartitionManager ImmutableMap.of(), Optional.empty(), Optional.empty()); + private static final Table TABLE_LARGE_PARTITIONS = new Table( + Optional.of("catalogName"), + SCHEMA_NAME, + TABLE_NAME_LARGE_PARTITIONS, + USER_NAME, + PrestoTableType.MANAGED_TABLE, + new Storage(fromHiveStorageFormat(ORC), + LOCATION_LARGE_PARTITIONS, + Optional.of(new HiveBucketProperty( + ImmutableList.of(BUCKET_COLUMN.getName()), + 100, + ImmutableList.of(), + HIVE_COMPATIBLE, + Optional.empty())), + false, + ImmutableMap.of(), + ImmutableMap.of()), + ImmutableList.of(BUCKET_COLUMN), + ImmutableList.of(PARTITION_COLUMN, PARTITION_COLUMN_TS), + ImmutableMap.of(), + Optional.empty(), + Optional.empty()); private static final List PARTITIONS = ImmutableList.of("ds=2019-07-23", "ds=2019-08-23"); + private static final List PARTITIONS_LARGE_PARTITIONS = ImmutableList.of("ds=2019-07-23/ts=2019-07-23:01:00:00", + "ds=2019-07-23/ts=2019-07-23:10:00:00", "ds=2019-08-23/ts=2019-07-23:01:00:00", "ds=2019-08-23/ts=2019-08-23:05:00:00"); private HivePartitionManager hivePartitionManager = new HivePartitionManager(new TestingTypeManager(), new HiveClientConfig()); private final TestingSemiTransactionalHiveMetastore metastore = TestingSemiTransactionalHiveMetastore.create(); @@ -90,6 +123,7 @@ public class TestHivePartitionManager public void 
setUp() { metastore.addTable(SCHEMA_NAME, TABLE_NAME, TABLE, PARTITIONS); + metastore.addTable(SCHEMA_NAME, TABLE_NAME_LARGE_PARTITIONS, TABLE_LARGE_PARTITIONS, PARTITIONS_LARGE_PARTITIONS); } @Test @@ -219,4 +253,34 @@ public void testIgnoresBucketingWhenConfigured() assertFalse(result.getBucketHandle().isPresent(), "bucketHandle is present"); assertFalse(result.getBucketFilter().isPresent(), "bucketFilter is present"); } + + @Test + public void testMultiplePartitions() + { + ConnectorSession session = new TestingConnectorSession( + new HiveSessionProperties( + new HiveClientConfig().setIgnoreTableBucketing(true).setOptimizeParsingOfPartitionValues(true).setOptimizeParsingOfPartitionValuesThreshold(2), + new OrcFileWriterConfig(), + new ParquetFileWriterConfig(), + new CacheConfig()) + .getSessionProperties()); + ColumnHandle columnHandle = new HiveColumnHandle( + PARTITION_COLUMN.getName(), + PARTITION_COLUMN.getType(), + parseTypeSignature(StandardTypes.VARCHAR), + MAX_PARTITION_KEY_COLUMN_INDEX, + PARTITION_KEY, + Optional.empty(), + Optional.empty()); + TupleDomain tupleDomain = + TupleDomain.withColumnDomains( + ImmutableMap.of( + columnHandle, + Domain.singleValue(VARCHAR, utf8Slice("2019-07-23")))); + Predicate> predicate = convertToPredicate(tupleDomain); + List predicateInput = ImmutableList.of(columnHandle); + Constraint constraint = new Constraint<>(TupleDomain.all(), Optional.of(predicate), Optional.of(predicateInput)); + HivePartitionResult result = hivePartitionManager.getPartitions(metastore, new HiveTableHandle(SCHEMA_NAME, TABLE_NAME_LARGE_PARTITIONS), constraint, session); + assertEquals(result.getPartitions().size(), 2); + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownDistributedQueries.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownDistributedQueries.java index 28ec3c9ef3491..dc816175ffb12 100644 --- 
a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownDistributedQueries.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownDistributedQueries.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestDistributedQueries; @@ -34,7 +35,7 @@ public class TestHivePushdownDistributedQueries protected QueryRunner createQueryRunner() throws Exception { - return HiveQueryRunner.createQueryRunner( + QueryRunner queryRunner = HiveQueryRunner.createQueryRunner( getTables(), ImmutableMap.of("experimental.pushdown-subfields-enabled", "true", "experimental.pushdown-dereference-enabled", "true"), @@ -44,6 +45,8 @@ protected QueryRunner createQueryRunner() "hive.partial_aggregation_pushdown_enabled", "true", "hive.partial_aggregation_pushdown_for_variable_length_datatypes_enabled", "true"), Optional.empty()); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } @Override diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownFilterQueries.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownFilterQueries.java index 5b0844835b8af..1e0f61f2becdc 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownFilterQueries.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownFilterQueries.java @@ -1163,6 +1163,32 @@ public void testArraySubscriptPushdown() } } + @Test + public void testMapSubscriptPushdown() + { + Session session = enablePushdownFilterAndSubfield(getQueryRunner().getDefaultSession()); + getQueryRunner().execute(session, + "CREATE TABLE test_neg_map_sub_pushdown AS \n" + + "select map(ARRAY[-10,20,30,0], array['a', 'b', 'c', 'd']) numbers"); + + try { + assertQuery("select element_at(numbers,-10) 
as number from test_neg_map_sub_pushdown", "SELECT 'a'"); + assertQuery("select element_at(numbers,20) as number from test_neg_map_sub_pushdown", "SELECT 'b'"); + assertQuery("select element_at(numbers,30) as number from test_neg_map_sub_pushdown", "SELECT 'c'"); + assertQuery("select element_at(numbers,0) as number from test_neg_map_sub_pushdown", "SELECT 'd'"); + assertQuery("select element_at(numbers,40) as number from test_neg_map_sub_pushdown", "SELECT cast(NULL as varchar)"); + + assertQuery("select numbers[-10] as number from test_neg_map_sub_pushdown", "SELECT 'a'"); + assertQuery("select numbers[20] as number from test_neg_map_sub_pushdown", "SELECT 'b'"); + assertQuery("select numbers[30] as number from test_neg_map_sub_pushdown", "SELECT 'c'"); + assertQuery("select numbers[0] as number from test_neg_map_sub_pushdown", "SELECT 'd'"); + assertQueryFails("select numbers[40] as number from test_neg_map_sub_pushdown", "Key not present in map: 40"); + } + finally { + getQueryRunner().execute("DROP TABLE test_neg_map_sub_pushdown"); + } + } + @Test public void testArraySubscriptPushdownEmptyArray() { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownIntegrationSmokeTest.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownIntegrationSmokeTest.java index 38a943872104d..d81d48329b624 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownIntegrationSmokeTest.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHivePushdownIntegrationSmokeTest.java @@ -54,7 +54,8 @@ protected QueryRunner createQueryRunner() "hive.enable-parquet-dereference-pushdown", "true", "hive.partial_aggregation_pushdown_enabled", "true", "hive.partial_aggregation_pushdown_for_variable_length_datatypes_enabled", "true", - "hive.orc.writer.string-statistics-limit", "128B"), + "hive.orc.writer.string-statistics-limit", "128B", + "hive.restrict-procedure-call", "false"), Optional.empty()); } } diff --git 
a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveRecoverableExecution.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveRecoverableExecution.java index 3fa1507441d48..9bface0dc0145 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveRecoverableExecution.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveRecoverableExecution.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.metadata.AllNodes; import com.facebook.presto.server.testing.TestingPrestoServer; @@ -26,7 +27,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; -import io.airlift.units.Duration; import org.intellij.lang.annotations.Language; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java index b5d7f6b3c8f43..395ef65a24096 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplit.java @@ -16,6 +16,7 @@ import com.facebook.airlift.bootstrap.Bootstrap; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.JsonModule; +import com.facebook.drift.codec.guice.ThriftCodecModule; import com.facebook.presto.block.BlockJsonSerde; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockEncoding; @@ -23,6 +24,7 @@ import com.facebook.presto.common.block.BlockEncodingSerde; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorManager; import 
com.facebook.presto.hive.metastore.Column; import com.facebook.presto.hive.metastore.Storage; import com.facebook.presto.hive.metastore.StorageFormat; @@ -153,8 +155,10 @@ private JsonCodec getJsonCodec() { Module module = binder -> { binder.install(new JsonModule()); + binder.install(new ThriftCodecModule()); binder.install(new HandleJsonModule()); configBinder(binder).bindConfig(FeaturesConfig.class); + binder.bind(ConnectorManager.class).toProvider(() -> null); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitManager.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitManager.java index 9d1ca4454779f..1a56aaa24214a 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitManager.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitManager.java @@ -531,7 +531,6 @@ private void assertRedundantColumnDomains(Range predicateRange, PartitionStatist new HivePartitionObjectBuilder(), new HiveEncryptionInformationProvider(ImmutableList.of()), new HivePartitionStats(), - new HiveFileRenamer(), HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, new QuickStatsProvider(metastore, HDFS_ENVIRONMENT, DO_NOTHING_DIRECTORY_LISTER, new HiveClientConfig(), new NamenodeStats(), ImmutableList.of()), new HiveTableWritabilityChecker(false)); @@ -679,7 +678,6 @@ public void testEncryptionInformation() new HivePartitionObjectBuilder(), encryptionInformationProvider, new HivePartitionStats(), - new HiveFileRenamer(), HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, new QuickStatsProvider(metastore, HDFS_ENVIRONMENT, DO_NOTHING_DIRECTORY_LISTER, new HiveClientConfig(), new NamenodeStats(), ImmutableList.of()), new 
HiveTableWritabilityChecker(false)); @@ -814,11 +812,11 @@ public Iterator list(ExtendedFileSystem fileSystem, Table table, P { try { return ImmutableList.of( - createHiveFileInfo( - new LocatedFileStatus( - new FileStatus(0, false, 1, 0, 0, new Path(path.toString() + "/" + "test_file_name")), - new BlockLocation[] {}), - Optional.empty())) + createHiveFileInfo( + new LocatedFileStatus( + new FileStatus(0, false, 1, 0, 0, new Path(path.toString() + "/" + "test_file_name")), + new BlockLocation[] {}), + Optional.empty())) .iterator(); } catch (IOException e) { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitSource.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitSource.java index 0113e717e7765..8d1340d489180 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitSource.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSplitSource.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.metastore.Storage; import com.facebook.presto.hive.metastore.StorageFormat; import com.facebook.presto.spi.ConnectorSplit; @@ -24,7 +25,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.time.Instant; @@ -39,6 +39,9 @@ import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; import static com.facebook.airlift.testing.Assertions.assertContains; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.hive.CacheQuotaScope.GLOBAL; import static com.facebook.presto.hive.CacheQuotaScope.PARTITION; import 
static com.facebook.presto.hive.CacheQuotaScope.TABLE; @@ -48,9 +51,6 @@ import static com.facebook.presto.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.SOFT_AFFINITY; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.Math.toIntExact; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithKeyStore.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithKeyStore.java index 055024a2d2437..3b01b14525b18 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithKeyStore.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithKeyStore.java @@ -14,20 +14,22 @@ package com.facebook.presto.hive; import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; -import java.net.URISyntaxException; -import java.nio.file.Paths; +import static com.facebook.presto.tests.SslKeystoreManager.SSL_STORE_PASSWORD; +import static com.facebook.presto.tests.SslKeystoreManager.getKeystorePath; +@Test public class TestHiveSslWithKeyStore extends AbstractHiveSslTest { - TestHiveSslWithKeyStore() throws URISyntaxException + TestHiveSslWithKeyStore() { super(ImmutableMap.builder() // This is required when connecting to ssl enabled hms .put("hive.metastore.thrift.client.tls.enabled", "true") - .put("hive.metastore.thrift.client.tls.keystore-path", Paths.get((TestHiveSslWithKeyStore.class.getResource("/hive_ssl_enable/hive-metastore.jks")).toURI()).toFile().toString()) - .put("hive.metastore.thrift.client.tls.keystore-password", "123456") + 
.put("hive.metastore.thrift.client.tls.keystore-path", getKeystorePath()) + .put("hive.metastore.thrift.client.tls.keystore-password", SSL_STORE_PASSWORD) .build()); } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithTrustStore.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithTrustStore.java index 4315c7616eb83..c392283d93bfe 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithTrustStore.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithTrustStore.java @@ -14,20 +14,22 @@ package com.facebook.presto.hive; import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; -import java.net.URISyntaxException; -import java.nio.file.Paths; +import static com.facebook.presto.tests.SslKeystoreManager.SSL_STORE_PASSWORD; +import static com.facebook.presto.tests.SslKeystoreManager.getTruststorePath; +@Test public class TestHiveSslWithTrustStore extends AbstractHiveSslTest { - TestHiveSslWithTrustStore() throws URISyntaxException + TestHiveSslWithTrustStore() { super(ImmutableMap.builder() // This is required when connecting to ssl enabled hms .put("hive.metastore.thrift.client.tls.enabled", "true") - .put("hive.metastore.thrift.client.tls.truststore-path", Paths.get((TestHiveSslWithTrustStore.class.getResource("/hive_ssl_enable/hive-metastore-truststore.jks")).toURI()).toFile().toString()) - .put("hive.metastore.thrift.client.tls.truststore-password", "123456") + .put("hive.metastore.thrift.client.tls.truststore-path", getTruststorePath()) + .put("hive.metastore.thrift.client.tls.truststore-password", SSL_STORE_PASSWORD) .build()); } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithTrustStoreKeyStore.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithTrustStoreKeyStore.java index 62787a1f6bc2e..3bc978a711872 100644 --- 
a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithTrustStoreKeyStore.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveSslWithTrustStoreKeyStore.java @@ -14,22 +14,25 @@ package com.facebook.presto.hive; import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; -import java.net.URISyntaxException; -import java.nio.file.Paths; +import static com.facebook.presto.tests.SslKeystoreManager.SSL_STORE_PASSWORD; +import static com.facebook.presto.tests.SslKeystoreManager.getKeystorePath; +import static com.facebook.presto.tests.SslKeystoreManager.getTruststorePath; +@Test public class TestHiveSslWithTrustStoreKeyStore extends AbstractHiveSslTest { - TestHiveSslWithTrustStoreKeyStore() throws URISyntaxException + TestHiveSslWithTrustStoreKeyStore() { super(ImmutableMap.builder() // This is required when connecting to ssl enabled hms .put("hive.metastore.thrift.client.tls.enabled", "true") - .put("hive.metastore.thrift.client.tls.keystore-path", Paths.get((TestHiveSslWithTrustStoreKeyStore.class.getResource("/hive_ssl_enable/hive-metastore.jks")).toURI()).toFile().toString()) - .put("hive.metastore.thrift.client.tls.keystore-password", "123456") - .put("hive.metastore.thrift.client.tls.truststore-path", Paths.get((TestHiveSslWithTrustStoreKeyStore.class.getResource("/hive_ssl_enable/hive-metastore-truststore.jks")).toURI()).toFile().toString()) - .put("hive.metastore.thrift.client.tls.truststore-password", "123456") + .put("hive.metastore.thrift.client.tls.keystore-path", getKeystorePath()) + .put("hive.metastore.thrift.client.tls.keystore-password", SSL_STORE_PASSWORD) + .put("hive.metastore.thrift.client.tls.truststore-path", getTruststorePath()) + .put("hive.metastore.thrift.client.tls.truststore-password", SSL_STORE_PASSWORD) .build()); } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveUtil.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveUtil.java index 
3e7c3cc163c10..e7e7b2d394c54 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveUtil.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestHiveUtil.java @@ -17,6 +17,7 @@ import com.facebook.presto.hive.metastore.Storage; import com.facebook.presto.hive.metastore.StorageFormat; import com.facebook.presto.hive.metastore.file.FileHiveMetastore; +import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -24,8 +25,10 @@ import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hive.metastore.Warehouse; import org.apache.hadoop.hive.metastore.api.MetaException; +import org.apache.hadoop.hive.ql.io.SymlinkTextInputFormat; import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer; import org.apache.hadoop.hive.serde2.thrift.test.IntString; +import org.apache.hadoop.mapred.InputFormat; import org.apache.hudi.hadoop.realtime.HoodieRealtimeFileSplit; import org.apache.thrift.protocol.TBinaryProtocol; import org.joda.time.DateTime; @@ -49,6 +52,9 @@ import static com.facebook.presto.common.type.DateType.DATE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.hive.HiveStorageFormat.ORC; +import static com.facebook.presto.hive.HiveStorageFormat.PARQUET; +import static com.facebook.presto.hive.HiveStorageFormat.TEXTFILE; import static com.facebook.presto.hive.HiveTestUtils.SESSION; import static com.facebook.presto.hive.HiveUtil.CLIENT_TAGS_DELIMITER; import static com.facebook.presto.hive.HiveUtil.CUSTOM_FILE_SPLIT_CLASS_KEY; @@ -75,8 +81,10 @@ import static org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_CLASS; import static org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_FORMAT; import static org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_LIB; 
+import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; public class TestHiveUtil @@ -215,6 +223,84 @@ public void testParsePartitionValue() assertEquals(prestoValue, Slices.utf8Slice("USA")); } + @Test + public void testGetInputFormatValidInput() + { + Configuration configuration = new Configuration(); + String inputFormatName = ORC.getInputFormat(); + String serDe = ORC.getSerDe(); + boolean symlinkTarget = false; + + InputFormat inputFormat = HiveUtil.getInputFormat(configuration, inputFormatName, serDe, symlinkTarget); + assertNotNull(inputFormat, "InputFormat should not be null for valid input"); + assertEquals(inputFormat.getClass().getName(), ORC.getInputFormat()); + } + + @Test + public void testGetInputFormatInvalidInputFormatName() + { + Configuration configuration = new Configuration(); + String inputFormatName = "invalid.InputFormatName"; + String serDe = ORC.getSerDe(); + boolean symlinkTarget = false; + + assertThatThrownBy(() -> HiveUtil.getInputFormat(configuration, inputFormatName, serDe, symlinkTarget)) + .isInstanceOf(PrestoException.class) + .hasStackTraceContaining("Unable to create input format invalid.InputFormatName"); + } + + @Test + public void testGetInputFormatMissingSerDeForSymlinkTextInputFormat() + { + Configuration configuration = new Configuration(); + String inputFormatName = SymlinkTextInputFormat.class.getName(); + String serDe = null; + boolean symlinkTarget = true; + + assertThatThrownBy(() -> HiveUtil.getInputFormat(configuration, inputFormatName, serDe, symlinkTarget)) + .isInstanceOf(PrestoException.class) + .hasStackTraceContaining("Missing SerDe for SymlinkTextInputFormat"); + } + + @Test + public void testGetInputFormatUnsupportedSerDeForSymlinkTextInputFormat() + { + Configuration configuration = new Configuration(); + String 
inputFormatName = SymlinkTextInputFormat.class.getName(); + String serDe = "unsupported.SerDe"; + boolean symlinkTarget = true; + + assertThatThrownBy(() -> HiveUtil.getInputFormat(configuration, inputFormatName, serDe, symlinkTarget)) + .isInstanceOf(PrestoException.class) + .hasStackTraceContaining("Unsupported SerDe for SymlinkTextInputFormat: unsupported.SerDe"); + } + + @Test + public void testGetInputFormatForAllSupportedSerDesForSymlinkTextInputFormat() + { + Configuration configuration = new Configuration(); + boolean symlinkTarget = true; + + /* + * https://github.com/apache/hive/blob/b240eb3266d4736424678d6c71c3c6f6a6fdbf38/ql/src/java/org/apache/hadoop/hive/ql/io/SymlinkTextInputFormat.java#L47-L52 + * According to Hive implementation of SymlinkInputFormat, The target input data should be in TextInputFormat. + * + * But another common use-case of Symlink Tables is to read Delta Lake Symlink Tables with target input data as MapredParquetInputFormat + * https://docs.delta.io/latest/presto-integration.html + */ + List supportedFormats = ImmutableList.of(PARQUET, TEXTFILE); + + for (HiveStorageFormat hiveStorageFormat : supportedFormats) { + String inputFormatName = SymlinkTextInputFormat.class.getName(); + String serDe = hiveStorageFormat.getSerDe(); + + InputFormat inputFormat = HiveUtil.getInputFormat(configuration, inputFormatName, serDe, symlinkTarget); + + assertNotNull(inputFormat, "InputFormat should not be null for valid SerDe: " + serDe); + assertEquals(inputFormat.getClass().getName(), hiveStorageFormat.getInputFormat()); + } + } + private static void assertToPartitionValues(String partitionName) throws MetaException { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestLambdaSubfieldPruning.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestLambdaSubfieldPruning.java index a0559b03b0302..b99eff80743ab 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestLambdaSubfieldPruning.java +++ 
b/presto-hive/src/test/java/com/facebook/presto/hive/TestLambdaSubfieldPruning.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; import com.facebook.presto.Session; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; import com.facebook.presto.tests.DistributedQueryRunner; @@ -126,6 +127,7 @@ private static DistributedQueryRunner createLineItemExTable(DistributedQueryRunn "FROM lineitem \n"); } + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); return queryRunner; } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestMergeJoinPlan.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestMergeJoinPlan.java index dc1a093fe1cdd..4bd71b46d8177 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestMergeJoinPlan.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestMergeJoinPlan.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; @@ -56,6 +57,12 @@ protected QueryRunner createQueryRunner() Optional.empty()); } + @Override + protected FeaturesConfig createFeaturesConfig() + { + return new FeaturesConfig().setNativeExecutionEnabled(true); + } + @Test public void testJoinType() { @@ -83,19 +90,19 @@ public void testJoinType() assertPlan( mergeJoinEnabled(), "select * from test_join_customer_join_type left join test_join_order_join_type on test_join_customer_join_type.custkey = test_join_order_join_type.custkey", - joinPlan("test_join_customer_join_type", "test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), LEFT, false)); + joinPlan("test_join_customer_join_type", 
"test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), LEFT, true)); // Right join assertPlan( mergeJoinEnabled(), "select * from test_join_customer_join_type right join test_join_order_join_type on test_join_customer_join_type.custkey = test_join_order_join_type.custkey", - joinPlan("test_join_customer_join_type", "test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), RIGHT, false)); + joinPlan("test_join_customer_join_type", "test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), RIGHT, true)); // Outer join assertPlan( mergeJoinEnabled(), "select * from test_join_customer_join_type full join test_join_order_join_type on test_join_customer_join_type.custkey = test_join_order_join_type.custkey", - joinPlan("test_join_customer_join_type", "test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), FULL, false)); + joinPlan("test_join_customer_join_type", "test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), FULL, true)); } finally { queryRunner.execute("DROP TABLE IF EXISTS test_join_customer_join_type"); diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestMergeJoinPlanPrestoOnSpark.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestMergeJoinPlanPrestoOnSpark.java new file mode 100644 index 0000000000000..a1b8acc60c933 --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestMergeJoinPlanPrestoOnSpark.java @@ -0,0 +1,429 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.GROUPED_EXECUTION; +import static com.facebook.presto.SystemSessionProperties.PREFER_MERGE_JOIN_FOR_SORTED_INPUTS; +import static com.facebook.presto.hive.HiveQueryRunner.HIVE_CATALOG; +import static com.facebook.presto.hive.HiveSessionProperties.ORDER_BASED_EXECUTION_ENABLED; +import static com.facebook.presto.spi.plan.JoinDistributionType.PARTITIONED; +import static com.facebook.presto.spi.plan.JoinType.FULL; +import static com.facebook.presto.spi.plan.JoinType.INNER; +import static com.facebook.presto.spi.plan.JoinType.LEFT; +import static com.facebook.presto.spi.plan.JoinType.RIGHT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.mergeJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sort; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.tree.SortItem.NullOrdering.FIRST; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; +import static io.airlift.tpch.TpchTable.CUSTOMER; +import static io.airlift.tpch.TpchTable.LINE_ITEM; +import static io.airlift.tpch.TpchTable.NATION; +import static io.airlift.tpch.TpchTable.ORDERS; + +public class TestMergeJoinPlanPrestoOnSpark + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return HiveQueryRunner.createQueryRunner( + ImmutableList.of(ORDERS, LINE_ITEM, CUSTOMER, NATION), + ImmutableMap.of(), + Optional.empty()); + } + + @Override + protected FeaturesConfig createFeaturesConfig() + { + return new FeaturesConfig().setNativeExecutionEnabled(true).setPrestoSparkExecutionEnvironment(true); + } + + @Test + public void testJoinType() + { + QueryRunner queryRunner = getQueryRunner(); + + try { + queryRunner.execute("CREATE TABLE test_join_customer_join_type WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.customer LIMIT 1000"); + + queryRunner.execute("CREATE TABLE test_join_order_join_type WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.\"orders\" LIMIT 1000"); + + // When merge join session property is turned on and data properties requirements for merge join are met + // Inner join + assertPlan( + mergeJoinEnabled(), + "select * from 
test_join_customer_join_type join test_join_order_join_type on test_join_customer_join_type.custkey = test_join_order_join_type.custkey", + joinPlan("test_join_customer_join_type", "test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), INNER, true)); + + // Left join + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer_join_type left join test_join_order_join_type on test_join_customer_join_type.custkey = test_join_order_join_type.custkey", + joinPlan("test_join_customer_join_type", "test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), LEFT, true)); + + // Right join + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer_join_type right join test_join_order_join_type on test_join_customer_join_type.custkey = test_join_order_join_type.custkey", + joinPlan("test_join_customer_join_type", "test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), RIGHT, true)); + + // Outer join + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer_join_type full join test_join_order_join_type on test_join_customer_join_type.custkey = test_join_order_join_type.custkey", + joinPlan("test_join_customer_join_type", "test_join_order_join_type", ImmutableList.of("custkey"), ImmutableList.of("custkey"), FULL, true)); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_join_customer_join_type"); + queryRunner.execute("DROP TABLE IF EXISTS test_join_order_join_type"); + } + } + + @Test + public void testSessionProperty() + { + QueryRunner queryRunner = getQueryRunner(); + + try { + queryRunner.execute("CREATE TABLE test_join_customer WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.customer LIMIT 1000"); + + queryRunner.execute("CREATE TABLE test_join_order WITH ( \n" + + " bucket_count = 4, 
bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.\"orders\" LIMIT 1000"); + + // By default, we can't enable merge join + assertPlan( + "select * from test_join_customer join test_join_order on test_join_customer.custkey = test_join_order.custkey", + joinPlan("test_join_customer", "test_join_order", ImmutableList.of("custkey"), ImmutableList.of("custkey"), INNER, false)); + + // when we miss session property, we can't enable merge join + assertPlan( + "select * from test_join_customer join test_join_order on test_join_customer.custkey = test_join_order.custkey", + joinPlan("test_join_customer", "test_join_order", ImmutableList.of("custkey"), ImmutableList.of("custkey"), INNER, false)); + + // When merge join session property is turned on and data properties requirements for merge join are met + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer join test_join_order on test_join_customer.custkey = test_join_order.custkey", + joinPlan("test_join_customer", "test_join_order", ImmutableList.of("custkey"), ImmutableList.of("custkey"), INNER, true)); + + // Presto on spark does not need the grouped execution to be enabled + assertPlan( + groupedExecutionDisabled(), + "select * from test_join_customer join test_join_order on test_join_customer.custkey = test_join_order.custkey", + joinPlan("test_join_customer", "test_join_order", ImmutableList.of("custkey"), ImmutableList.of("custkey"), INNER, true)); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_join_customer"); + queryRunner.execute("DROP TABLE IF EXISTS test_join_order"); + } + } + + @Test + public void testDifferentBucketedByKey() + { + QueryRunner queryRunner = getQueryRunner(); + + try { + queryRunner.execute("CREATE TABLE test_join_customer2 WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['name'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS 
\n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.customer LIMIT 1000"); + + queryRunner.execute("CREATE TABLE test_join_order2 WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.\"orders\" LIMIT 1000"); + + // merge join can't be enabled + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer2 join test_join_order2 on test_join_customer2.custkey = test_join_order2.custkey", + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("custkey_l", "custkey_r")), + Optional.empty(), + exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER, ImmutableList.of(sort("custkey_l", ASCENDING, FIRST)), + sort(anyTree(tableScan("test_join_customer2", ImmutableMap.of("custkey_l", "custkey"))))), + tableScan("test_join_order2", ImmutableMap.of("custkey_r", "custkey"))))); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_join_customer2"); + queryRunner.execute("DROP TABLE IF EXISTS test_join_order2"); + } + } + + @Test + public void testDifferentSortByKey() + { + QueryRunner queryRunner = getQueryRunner(); + + try { + queryRunner.execute("CREATE TABLE test_join_customer3 WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['name'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.customer LIMIT 1000"); + + queryRunner.execute("CREATE TABLE test_join_order3 WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.\"orders\" LIMIT 1000"); + + // merge join can be enabled when only one side is sorted (Presto on Spark behavior) + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer3 join test_join_order3 on test_join_customer3.custkey = test_join_order3.custkey", + anyTree( + 
mergeJoin(INNER, ImmutableList.of(equiJoinClause("custkey_l", "custkey_r")), + Optional.empty(), + exchange(ExchangeNode.Scope.LOCAL, ExchangeNode.Type.GATHER, ImmutableList.of(sort("custkey_l", ASCENDING, FIRST)), + sort(tableScan("test_join_customer3", ImmutableMap.of("custkey_l", "custkey")))), + tableScan("test_join_order3", ImmutableMap.of("custkey_r", "custkey"))))); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_join_customer3"); + queryRunner.execute("DROP TABLE IF EXISTS test_join_order3"); + } + } + + @Test + public void testMultipleSortByKeys() + { + QueryRunner queryRunner = getQueryRunner(); + + try { + queryRunner.execute("CREATE TABLE test_join_customer4 WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey', 'name'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.customer LIMIT 1000"); + + queryRunner.execute("CREATE TABLE test_join_order4 WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.\"orders\" LIMIT 1000"); + + // merge join can be enabled + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer4 join test_join_order4 on test_join_customer4.custkey = test_join_order4.custkey", + joinPlan("test_join_customer4", "test_join_order4", ImmutableList.of("custkey"), ImmutableList.of("custkey"), INNER, true)); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_join_customer4"); + queryRunner.execute("DROP TABLE IF EXISTS test_join_order4"); + } + } + + @Test + public void testMultipleJoinKeys() + { + QueryRunner queryRunner = getQueryRunner(); + + try { + queryRunner.execute("CREATE TABLE test_join_customer5(" + + " \"custkey\" bigint, \"name\" varchar(25), \"address\" varchar(40), \"orderkey\" bigint, \"phone\" varchar(15), \n" + + " \"acctbal\" double, \"mktsegment\" varchar(10), 
\"comment\" varchar(117), \"ds\" varchar(10)) WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey', 'orderkey'], \n" + + " sorted_by = ARRAY['custkey', 'orderkey'], partitioned_by=array['ds'], \n" + + " format = 'DWRF' )"); + queryRunner.execute("INSERT INTO test_join_customer5 \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.customer LIMIT 1000"); + + queryRunner.execute("CREATE TABLE test_join_order5(" + + " \"orderkey\" bigint, \"custkey\" bigint, \"orderstatus\" varchar(1), \"totalprice\" double, \"orderdate\" date," + + " \"orderpriority\" varchar(15), \"clerk\" varchar(15), \"shippriority\" integer, \"comment\" varchar(79), \"ds\" varchar(10)) WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey', 'orderkey'], \n" + + " sorted_by = ARRAY['custkey', 'orderkey'], partitioned_by=array['ds'])"); + queryRunner.execute("INSERT INTO test_join_order5 \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.orders LIMIT 1000"); + + // merge join can be enabled + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer5 join test_join_order5 on test_join_customer5.custkey = test_join_order5.custkey and test_join_customer5.orderkey = test_join_order5.orderkey", + joinPlan("test_join_customer5", "test_join_order5", ImmutableList.of("custkey", "orderkey"), ImmutableList.of("custkey", "orderkey"), INNER, true)); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_join_customer5"); + queryRunner.execute("DROP TABLE IF EXISTS test_join_order5"); + } + } + + @Test + public void testMultiplePartitions() + { + QueryRunner queryRunner = getQueryRunner(); + + try { + queryRunner.execute("CREATE TABLE test_join_customer_multi_partitions WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.customer LIMIT 1000"); + queryRunner.execute("INSERT INTO test_join_customer_multi_partitions \n" + + "SELECT *, 
'2021-07-12' as ds FROM tpch.sf1.customer LIMIT 1000"); + + queryRunner.execute("CREATE TABLE test_join_order_multi_partitions WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.\"orders\" LIMIT 1000"); + queryRunner.execute("INSERT INTO test_join_order_multi_partitions \n" + + "SELECT *, '2021-07-12' as ds FROM tpch.sf1.orders LIMIT 1000"); + + // When partition key does not appear in join keys and we query multiple partitions, we can't enable merge join + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer_multi_partitions join test_join_order_multi_partitions on test_join_customer_multi_partitions.custkey = test_join_order_multi_partitions.custkey", + joinPlan("test_join_customer_multi_partitions", "test_join_order_multi_partitions", ImmutableList.of("custkey"), ImmutableList.of("custkey"), INNER, false)); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_join_customer_multi_partitions"); + queryRunner.execute("DROP TABLE IF EXISTS test_join_order_multi_partitions"); + } + } + + @Test + public void testBothSidesNotBucketed() + { + QueryRunner queryRunner = getQueryRunner(); + + try { + queryRunner.execute("CREATE TABLE test_join_customer_not_bucketed WITH ( \n" + + " partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.customer LIMIT 1000"); + + queryRunner.execute("CREATE TABLE test_join_order_not_bucketed WITH ( \n" + + " partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.\"orders\" LIMIT 1000"); + + // merge join can't be enabled when both sides are not bucketed + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer_not_bucketed join test_join_order_not_bucketed on test_join_customer_not_bucketed.custkey = test_join_order_not_bucketed.custkey", + joinPlan("test_join_customer_not_bucketed", 
"test_join_order_not_bucketed", ImmutableList.of("custkey"), ImmutableList.of("custkey"), INNER, false)); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_join_customer_not_bucketed"); + queryRunner.execute("DROP TABLE IF EXISTS test_join_order_not_bucketed"); + } + } + + @Test + public void testOnlyOneSideBucketedAndSorted() + { + QueryRunner queryRunner = getQueryRunner(); + + try { + queryRunner.execute("CREATE TABLE test_join_customer_not_bucketed_sorted WITH ( \n" + + " partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.customer LIMIT 1000"); + + queryRunner.execute("CREATE TABLE test_join_order_bucketed_sorted WITH ( \n" + + " bucket_count = 4, bucketed_by = ARRAY['custkey'], \n" + + " sorted_by = ARRAY['custkey'], partitioned_by=array['ds']) AS \n" + + "SELECT *, '2021-07-11' as ds FROM tpch.sf1.\"orders\" LIMIT 1000"); + + // merge join can be enabled when only one side is bucketed and sorted + assertPlan( + mergeJoinEnabled(), + "select * from test_join_customer_not_bucketed_sorted join test_join_order_bucketed_sorted on test_join_customer_not_bucketed_sorted.custkey = test_join_order_bucketed_sorted.custkey", + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("custkey_l", "custkey_r")), + Optional.empty(), + anyTree(sort(anyTree(tableScan("test_join_customer_not_bucketed_sorted", ImmutableMap.of("custkey_l", "custkey"))))), + tableScan("test_join_order_bucketed_sorted", ImmutableMap.of("custkey_r", "custkey"))))); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS test_join_customer_not_bucketed_sorted"); + queryRunner.execute("DROP TABLE IF EXISTS test_join_order_bucketed_sorted"); + } + } + + private Session groupedExecutionDisabled() + { + return Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(PREFER_MERGE_JOIN_FOR_SORTED_INPUTS, "true") + .setSystemProperty(GROUPED_EXECUTION, "false") + .setCatalogSessionProperty(HIVE_CATALOG, ORDER_BASED_EXECUTION_ENABLED, 
"true") + .build(); + } + + private Session mergeJoinEnabled() + { + return Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(PREFER_MERGE_JOIN_FOR_SORTED_INPUTS, "true") + .setSystemProperty(GROUPED_EXECUTION, "true") + .setCatalogSessionProperty(HIVE_CATALOG, ORDER_BASED_EXECUTION_ENABLED, "true") + .build(); + } + + private PlanMatchPattern joinPlan(String leftTableName, String rightTableName, List leftJoinKeys, List rightJoinKeys, JoinType joinType, boolean mergeJoinEnabled) + { + int suffix1 = 0; + int suffix2 = 1; + ImmutableMap.Builder leftColumnReferencesBuilder = ImmutableMap.builder(); + ImmutableMap.Builder rightColumnReferencesBuilder = ImmutableMap.builder(); + ImmutableList.Builder joinClauses = ImmutableList.builder(); + for (int i = 0; i < leftJoinKeys.size(); i++) { + leftColumnReferencesBuilder.put(leftJoinKeys.get(i) + suffix1, leftJoinKeys.get(i)); + rightColumnReferencesBuilder.put(rightJoinKeys.get(i) + suffix2, rightJoinKeys.get(i)); + joinClauses.add(equiJoinClause(leftJoinKeys.get(i) + suffix1, rightJoinKeys.get(i) + suffix2)); + suffix1 = suffix1 + 2; + suffix2 = suffix2 + 2; + } + + return mergeJoinEnabled ? 
+ anyTree(mergeJoin( + joinType, + joinClauses.build(), + Optional.empty(), + PlanMatchPattern.tableScan(leftTableName, leftColumnReferencesBuilder.build()), + PlanMatchPattern.tableScan(rightTableName, rightColumnReferencesBuilder.build()))) : + anyTree(join( + joinType, + joinClauses.build(), + Optional.empty(), + Optional.of(PARTITIONED), + anyTree(PlanMatchPattern.tableScan(leftTableName, leftColumnReferencesBuilder.build())), + anyTree(PlanMatchPattern.tableScan(rightTableName, rightColumnReferencesBuilder.build())))); + } +} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcBatchPageSourceMemoryTracking.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcBatchPageSourceMemoryTracking.java index 30f02bb1a0477..6144871f10040 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcBatchPageSourceMemoryTracking.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcBatchPageSourceMemoryTracking.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; import com.facebook.airlift.stats.Distribution; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.predicate.TupleDomain; @@ -55,7 +56,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; -import io.airlift.units.DataSize; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -102,6 +102,7 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.testing.Assertions.assertBetweenInclusive; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.PARTITION_KEY; import static 
com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR; @@ -119,7 +120,6 @@ import static com.google.common.base.Predicates.not; import static com.google.common.collect.Iterables.filter; import static com.google.common.collect.Iterables.transform; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; @@ -456,7 +456,6 @@ public ConnectorPageSource newPageSource(FileFormatDataSourceStats stats, Connec OrcBatchPageSourceFactory orcPageSourceFactory = new OrcBatchPageSourceFactory( FUNCTION_AND_TYPE_MANAGER, - false, HDFS_ENVIRONMENT, stats, 100, diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcFileWriter.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcFileWriter.java index 15029cb135563..8a0a058718f55 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcFileWriter.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcFileWriter.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.ErrorCode; import com.facebook.presto.common.io.DataOutput; import com.facebook.presto.common.io.DataSink; @@ -21,13 +22,13 @@ import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.io.IOException; import java.util.List; import java.util.Optional; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.common.ErrorType.EXTERNAL; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.hive.HiveErrorCode.HIVE_WRITER_CLOSE_ERROR; @@ -36,7 +37,6 @@ import static 
com.facebook.presto.orc.NoOpOrcWriterStats.NOOP_WRITER_STATS; import static com.facebook.presto.orc.OrcEncoding.ORC; import static com.facebook.presto.orc.metadata.CompressionKind.NONE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static org.testng.Assert.assertEquals; public class TestOrcFileWriter diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcFileWriterConfig.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcFileWriterConfig.java index 83d4fb359ec23..ecb01c6fc082b 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcFileWriterConfig.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestOrcFileWriterConfig.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.OrcFileWriterConfig.StreamLayoutType; import com.facebook.presto.orc.OrcWriterOptions; import com.facebook.presto.orc.metadata.DwrfStripeCacheMode; import com.facebook.presto.orc.writer.StreamLayoutFactory.ColumnSizeLayoutFactory; import com.facebook.presto.orc.writer.StreamLayoutFactory.StreamSizeLayoutFactory; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Map; @@ -29,14 +29,14 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.hive.OrcFileWriterConfig.StreamLayoutType.BY_COLUMN_SIZE; import static com.facebook.presto.hive.OrcFileWriterConfig.StreamLayoutType.BY_STREAM_SIZE; import static 
com.facebook.presto.orc.metadata.DwrfStripeCacheMode.FOOTER; import static com.facebook.presto.orc.metadata.DwrfStripeCacheMode.INDEX; import static com.facebook.presto.orc.metadata.DwrfStripeCacheMode.INDEX_AND_FOOTER; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.Math.toIntExact; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetDistributedQueries.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetDistributedQueries.java index d8bb271cebcd0..0aef10a5b4bf0 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetDistributedQueries.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetDistributedQueries.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive; import com.facebook.presto.Session; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestDistributedQueries; @@ -46,7 +47,7 @@ protected QueryRunner createQueryRunner() .put("hive.partial_aggregation_pushdown_enabled", "true") .put("hive.partial_aggregation_pushdown_for_variable_length_datatypes_enabled", "true") .build(); - return HiveQueryRunner.createQueryRunner( + QueryRunner queryRunner = HiveQueryRunner.createQueryRunner( getTables(), ImmutableMap.of( "experimental.pushdown-subfields-enabled", "true", @@ -54,6 +55,8 @@ protected QueryRunner createQueryRunner() "sql-standard", parquetProperties, Optional.empty()); + queryRunner.installPlugin(new SqlInvokedFunctionsPlugin()); + return queryRunner; } @Test diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetFileWriterConfig.java 
b/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetFileWriterConfig.java index c1932253b6650..a520d5f9c6114 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetFileWriterConfig.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestParquetFileWriterConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.apache.parquet.column.ParquetProperties; import org.apache.parquet.hadoop.ParquetWriter; import org.testng.annotations.Test; @@ -24,8 +24,8 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class TestParquetFileWriterConfig { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestSingleNodeCteExecution.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestSingleNodeCteExecution.java new file mode 100644 index 0000000000000..54e902a147156 --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestSingleNodeCteExecution.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive; + +import com.facebook.presto.testing.QueryRunner; +import org.testng.annotations.Test; + +@Test(singleThreaded = true) +public class TestSingleNodeCteExecution + extends AbstractTestCteExecution +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createQueryRunner(true); + } +} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestSortingFileWriterConfig.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestSortingFileWriterConfig.java index b1da10e5467f8..f3899312571f8 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/TestSortingFileWriterConfig.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestSortingFileWriterConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.hive; +import com.facebook.airlift.units.DataSize; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Map; @@ -22,8 +22,8 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class 
TestSortingFileWriterConfig { diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/TestStoragePartitionLoader.java b/presto-hive/src/test/java/com/facebook/presto/hive/TestStoragePartitionLoader.java new file mode 100644 index 0000000000000..b24c7746ea803 --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/TestStoragePartitionLoader.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.hive; + +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.presto.hive.TestBackgroundHiveSplitLoader.TestingHdfsEnvironment; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.testing.TestingConnectorSession; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.BlockLocation; +import org.apache.hadoop.fs.LocatedFileStatus; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.ql.io.SymlinkTextInputFormat; +import org.apache.hadoop.mapred.InputFormat; +import org.testng.annotations.Test; + +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.TimeUnit; + +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static 
com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.presto.hive.HiveSessionProperties.isSkipEmptyFilesEnabled; +import static com.facebook.presto.hive.HiveSessionProperties.isUseListDirectoryCache; +import static com.facebook.presto.hive.HiveStorageFormat.PARQUET; +import static com.facebook.presto.hive.HiveTestUtils.getAllSessionProperties; +import static com.facebook.presto.hive.HiveUtil.buildDirectoryContextProperties; +import static com.facebook.presto.hive.NestedDirectoryPolicy.IGNORED; +import static com.facebook.presto.hive.StoragePartitionLoader.BucketSplitInfo.createBucketSplitInfo; +import static com.facebook.presto.hive.TestBackgroundHiveSplitLoader.SIMPLE_TABLE; +import static com.facebook.presto.hive.TestBackgroundHiveSplitLoader.samplePartitionMetadatas; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static org.testng.Assert.assertEquals; + +public class TestStoragePartitionLoader +{ + @Test + public void testGetSymlinkIterator() + throws Exception + { + CachingDirectoryLister directoryLister = new CachingDirectoryLister( + new HadoopDirectoryLister(), + new Duration(5, TimeUnit.MINUTES), + new DataSize(100, KILOBYTE), + ImmutableList.of()); + + Configuration configuration = new Configuration(false); + + InputFormat inputFormat = HiveUtil.getInputFormat( + configuration, + SymlinkTextInputFormat.class.getName(), + PARQUET.getSerDe(), + true); + + Path firstFilePath = new Path("hdfs://hadoop:9000/db_name/table_name/file1"); + Path secondFilePath = new Path("hdfs://hadoop:9000/db_name/table_name/file2"); + List paths = ImmutableList.of(firstFilePath, secondFilePath); + List files = paths.stream() + .map(path -> locatedFileStatus(path, 0L)) + .collect(toImmutableList()); + + ConnectorSession connectorSession = new TestingConnectorSession(getAllSessionProperties( + new HiveClientConfig().setMaxSplitSize(new DataSize(1.0, 
GIGABYTE)) + .setFileStatusCacheTables(""), + new HiveCommonClientConfig())); + + StoragePartitionLoader storagePartitionLoader = storagePartitionLoader(files, directoryLister, connectorSession); + + HdfsContext hdfsContext = new HdfsContext( + connectorSession, + SIMPLE_TABLE.getDatabaseName(), + SIMPLE_TABLE.getTableName(), + SIMPLE_TABLE.getStorage().getLocation(), + false); + + HiveDirectoryContext hiveDirectoryContext = new HiveDirectoryContext( + IGNORED, + isUseListDirectoryCache(connectorSession), + isSkipEmptyFilesEnabled(connectorSession), + hdfsContext.getIdentity(), + buildDirectoryContextProperties(connectorSession), + connectorSession.getRuntimeStats()); + + Iterator symlinkIterator = storagePartitionLoader.getSymlinkIterator( + new Path("hdfs://hadoop:9000/db_name/table_name/symlink_manifest"), + false, + SIMPLE_TABLE.getStorage(), + ImmutableList.of(), + "UNPARTITIONED", + SIMPLE_TABLE.getDataColumns().size(), + getOnlyElement(samplePartitionMetadatas()), + true, + new Path("hdfs://hadoop:9000/db_name/table_name/"), + paths, + inputFormat, + hiveDirectoryContext); + + List splits = ImmutableList.copyOf(symlinkIterator); + assertEquals(splits.size(), 2); + assertEquals(splits.get(0).getPath(), firstFilePath.toString()); + assertEquals(splits.get(1).getPath(), secondFilePath.toString()); + } + + private static LocatedFileStatus locatedFileStatus(Path path, long fileSize) + { + return new LocatedFileStatus( + fileSize, + false, + 0, + 0L, + 0L, + 0L, + null, + null, + null, + null, + path, + new org.apache.hadoop.fs.BlockLocation[]{new BlockLocation(new String[1], new String[]{"localhost"}, 0, fileSize)}); + } + + private static StoragePartitionLoader storagePartitionLoader( + List files, + DirectoryLister directoryLister, + ConnectorSession connectorSession) + { + return new StoragePartitionLoader( + SIMPLE_TABLE, + ImmutableMap.of(), + createBucketSplitInfo(Optional.empty(), Optional.empty()), + connectorSession, + new TestingHdfsEnvironment(files), 
+ new NamenodeStats(), + directoryLister, + new ConcurrentLinkedDeque<>(), + false, + false, + false); + } +} diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java index e0c1a052830a0..41a5844cdf0ae 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/FileFormat.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.benchmark; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.common.io.OutputStreamDataSink; @@ -66,7 +67,6 @@ import com.facebook.presto.spi.page.PagesSerde; import com.google.common.collect.ImmutableMap; import io.airlift.slice.OutputStreamSliceOutput; -import io.airlift.units.DataSize; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.JobConf; @@ -83,7 +83,6 @@ import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR; import static com.facebook.presto.hive.CacheQuota.NO_CACHE_CONSTRAINTS; -import static com.facebook.presto.hive.HiveCompressionCodec.NONE; import static com.facebook.presto.hive.HiveStorageFormat.PAGEFILE; import static com.facebook.presto.hive.HiveTestUtils.FUNCTION_AND_TYPE_MANAGER; import static com.facebook.presto.hive.HiveTestUtils.FUNCTION_RESOLUTION; @@ -161,7 +160,6 @@ public ConnectorPageSource createFileFormatReader(ConnectorSession session, Hdfs { HiveBatchPageSourceFactory pageSourceFactory = new OrcBatchPageSourceFactory( FUNCTION_AND_TYPE_MANAGER, - false, hdfsEnvironment, new FileFormatDataSourceStats(), 100, @@ -247,9 +245,6 @@ public FormatWriter createFileFormatWriter( HiveCompressionCodec compressionCodec) throws IOException { - if (!compressionCodec.isSupportedStorageFormat(PAGEFILE)) { - 
compressionCodec = NONE; - } return new PrestoPageFormatWriter(targetFile, compressionCodec); } }, @@ -697,7 +692,7 @@ public PrestoParquetFormatWriter(File targetFile, List columnNames, List columnNames, types, ParquetWriterOptions.builder().build(), - compressionCodec.getParquetCompressionCodec().get().getHadoopCompressionCodecClassName()); + compressionCodec.getParquetCompressionCodec().getHadoopCompressionCodecClassName()); } @Override diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java index 7ded16aeed836..e17c895057f89 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/HiveFileFormatBenchmark.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.benchmark; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.block.BlockBuilder; @@ -32,7 +33,6 @@ import io.airlift.tpch.TpchColumn; import io.airlift.tpch.TpchEntity; import io.airlift.tpch.TpchTable; -import io.airlift.units.DataSize; import it.unimi.dsi.fastutil.ints.IntArrays; import org.openjdk.jmh.annotations.AuxCounters; import org.openjdk.jmh.annotations.Benchmark; @@ -61,6 +61,7 @@ import java.util.UUID; import java.util.concurrent.TimeUnit; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DateType.DATE; import static com.facebook.presto.common.type.DoubleType.DOUBLE; @@ -74,7 +75,6 @@ import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; import static io.airlift.tpch.TpchTable.LINE_ITEM; import static io.airlift.tpch.TpchTable.ORDERS; -import static 
io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.String.format; import static java.nio.file.Files.createTempDirectory; import static java.util.stream.Collectors.toList; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/containers/HiveHadoopContainer.java b/presto-hive/src/test/java/com/facebook/presto/hive/containers/HiveHadoopContainer.java index 3cd03d55fe6dc..c6b2eab091617 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/containers/HiveHadoopContainer.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/containers/HiveHadoopContainer.java @@ -30,9 +30,9 @@ public class HiveHadoopContainer { private static final Logger log = Logger.get(HiveHadoopContainer.class); - private static final String IMAGE_VERSION = "10"; - public static final String DEFAULT_IMAGE = "prestodb/hdp2.6-hive:" + IMAGE_VERSION; - public static final String HIVE3_IMAGE = "prestodb/hive3.1-hive:" + IMAGE_VERSION; + private static final String IMAGE_VERSION = "11"; + public static final String DEFAULT_IMAGE = "prestodb/hdp3.1-hive:" + IMAGE_VERSION; + public static final String HIVE3_IMAGE = "prestodb/hive3.1-hive:10"; public static final String HOST_NAME = "hadoop-master"; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/containers/HiveMinIODataLake.java b/presto-hive/src/test/java/com/facebook/presto/hive/containers/HiveMinIODataLake.java index a3fda22a07c6c..0a0e5bf7a67f9 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/containers/HiveMinIODataLake.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/containers/HiveMinIODataLake.java @@ -25,10 +25,18 @@ import java.io.Closeable; import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import static 
com.facebook.presto.hive.containers.HiveHadoopContainer.HIVE3_IMAGE; +import static com.facebook.presto.tests.SslKeystoreManager.getKeystorePath; +import static com.facebook.presto.tests.SslKeystoreManager.getTruststorePath; +import static java.nio.file.StandardCopyOption.REPLACE_EXISTING; import static java.util.Objects.requireNonNull; import static org.testcontainers.containers.Network.newNetwork; @@ -37,6 +45,7 @@ public class HiveMinIODataLake { public static final String ACCESS_KEY = "accesskey"; public static final String SECRET_KEY = "secretkey"; + private static final Object SSL_LOCK = new Object(); private final String bucketName; private final MinIOContainer minIOContainer; @@ -67,15 +76,38 @@ public HiveMinIODataLake(String bucketName, Map hiveHadoopFilesT .putAll(hiveHadoopFilesToMount); String hadoopCoreSitePath = "/etc/hadoop/conf/core-site.xml"; - if (hiveHadoopImage == HIVE3_IMAGE) { + + if (Objects.equals(hiveHadoopImage, HIVE3_IMAGE)) { hadoopCoreSitePath = "/opt/hadoop/etc/hadoop/core-site.xml"; filesToMount.put("hive_s3_insert_overwrite/hive-site.xml", "/opt/hive/conf/hive-site.xml"); } filesToMount.put("hive_s3_insert_overwrite/hadoop-core-site.xml", hadoopCoreSitePath); if (isSslEnabledTest) { - filesToMount.put("hive_ssl_enable/hive-site.xml", "/opt/hive/conf/hive-site.xml"); - filesToMount.put("hive_ssl_enable/hive-metastore.jks", "/opt/hive/conf/hive-metastore.jks"); - filesToMount.put("hive_ssl_enable/hive-metastore-truststore.jks", "/opt/hive/conf/hive-metastore-truststore.jks"); + try { + // Copy dynamically generated keystore files into target/test-classes so that + // Testcontainers can resolve them. + // Without this step, the files would only exist on the filesystem and not + // on the test runtime classpath, causing classpath lookups to fail. 
+ Path targetDir = Paths.get("target", "test-classes", "ssl_enable"); + Files.createDirectories(targetDir); + + Path keyStoreTarget = targetDir.resolve("keystore.jks"); + Path trustStoreTarget = targetDir.resolve("truststore.jks"); + + synchronized (SSL_LOCK) { + // Copy freshly generated keystores, replacing if they exist + Files.copy(Paths.get(getKeystorePath()), keyStoreTarget, REPLACE_EXISTING); + Files.copy(Paths.get(getTruststorePath()), trustStoreTarget, REPLACE_EXISTING); + + filesToMount.put("ssl_enable/keystore.jks", "/opt/hive/conf/hive-metastore.jks"); + filesToMount.put("ssl_enable/truststore.jks", "/opt/hive/conf/hive-metastore-truststore.jks"); + } + + filesToMount.put("hive_ssl_enable/hive-site.xml", "/opt/hive/conf/hive-site.xml"); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to prepare keystore files for Testcontainers", e); + } } this.hiveHadoopContainer = closer.register( HiveHadoopContainer.builder() diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/hudi/TestHudiIntegration.java b/presto-hive/src/test/java/com/facebook/presto/hive/hudi/TestHudiIntegration.java index 581156518e984..c492e54b0d757 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/hudi/TestHudiIntegration.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/hudi/TestHudiIntegration.java @@ -34,6 +34,7 @@ import static com.facebook.presto.hive.hudi.HudiTestingDataGenerator.HUDI_META_COLUMNS; import static com.facebook.presto.hive.hudi.HudiTestingDataGenerator.PARTITION_COLUMNS; import static java.lang.String.format; +import static java.util.Locale.ENGLISH; public class TestHudiIntegration extends AbstractTestQueryFramework @@ -151,10 +152,43 @@ public void testQueryOnUnavailablePartition() private static String generateDescribeIdenticalQuery(TypeManager typeManager, List metaColumns, List dataColumns, List partitionColumns) { Stream regularRows = Streams.concat(metaColumns.stream(), dataColumns.stream()) - .map(column -> 
format("('%s', '%s', '', '')", column.getName(), column.getType().getType(typeManager).getDisplayName())); + .map(column -> { + return format("('%s', '%s', '', '', %s, NULL, %s)", column.getName(), column.getType().getType(typeManager).getDisplayName(), + getColumnSize(column.getType().getType(typeManager).getDisplayName()), + (column.getType().getType(typeManager).getDisplayName()).toLowerCase(ENGLISH).equals("varchar") ? "2147483647" : "NULL"); + }); + Stream partitionRows = partitionColumns.stream() - .map(column -> format("('%s', '%s', 'partition key', '')", column.getName(), column.getType().getType(typeManager).getDisplayName())); - String rows = Streams.concat(regularRows, partitionRows).collect(Collectors.joining(",")); + .map(column -> { + return format("('%s', '%s', 'partition key', '', %s, NULL, %s)", column.getName(), + column.getType().getType(typeManager).getDisplayName(), + getColumnSize(column.getType().getType(typeManager).getDisplayName()), + (column.getType().getType(typeManager).getDisplayName()).toLowerCase(ENGLISH).equals("varchar") ? 
"2147483647" : "NULL"); + }); + + String rows = Streams.concat(regularRows, partitionRows) + .collect(Collectors.joining(", ")); + return "SELECT * FROM VALUES " + rows; } + + private static String getColumnSize(String type) + { + switch (type.toLowerCase(ENGLISH)) { + case "bigint": + return "19"; + case "integer": + return "10"; + case "smallint": + return "5"; + case "tinyint": + return "3"; + case "double": + return "53"; + case "real": + return "24"; + default: + return "NULL"; + } + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestInMemoryCachingHiveMetastore.java b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestInMemoryCachingHiveMetastore.java index 8b131d8c2618a..f3fe7dee3b7c4 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestInMemoryCachingHiveMetastore.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/TestInMemoryCachingHiveMetastore.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.hive.metastore; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.MetastoreClientConfig; import com.facebook.presto.hive.MockHiveMetastore; import com.facebook.presto.hive.PartitionMutator; import com.facebook.presto.hive.PartitionNameWithVersion; -import com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheScope; import com.facebook.presto.hive.metastore.thrift.BridgingHiveMetastore; import com.facebook.presto.hive.metastore.thrift.HiveCluster; import com.facebook.presto.hive.metastore.thrift.HiveMetastoreClient; @@ -32,7 +32,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListeningExecutorService; -import io.airlift.units.Duration; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -45,6 +44,10 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static 
com.facebook.presto.hive.HiveTestUtils.HDFS_ENVIRONMENT; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.ALL; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.PARTITION; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.PARTITION_STATISTICS; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.TABLE; import static com.facebook.presto.hive.metastore.NoopMetastoreCacheStats.NOOP_METASTORE_CACHE_STATS; import static com.facebook.presto.hive.metastore.Partition.Builder; import static com.facebook.presto.hive.metastore.thrift.MockHiveMetastoreClient.BAD_DATABASE; @@ -77,7 +80,8 @@ public class TestInMemoryCachingHiveMetastore private static final ImmutableList EXPECTED_PARTITIONS = ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2); private MockHiveMetastoreClient mockClient; - private InMemoryCachingHiveMetastore metastore; + private InMemoryCachingHiveMetastore metastoreWithAllCachesEnabled; + private InMemoryCachingHiveMetastore metastoreWithSelectiveCachesEnabled; private ThriftHiveMetastoreStats stats; @BeforeMethod @@ -87,20 +91,44 @@ public void setUp() MockHiveCluster mockHiveCluster = new MockHiveCluster(mockClient); ListeningExecutorService executor = listeningDecorator(newCachedThreadPool(daemonThreadsNamed("test-%s"))); MetastoreClientConfig metastoreClientConfig = new MetastoreClientConfig(); + // Configure Metastore Cache + metastoreClientConfig.setDefaultMetastoreCacheTtl(new Duration(5, TimeUnit.MINUTES)); + metastoreClientConfig.setDefaultMetastoreCacheRefreshInterval(new Duration(1, TimeUnit.MINUTES)); + metastoreClientConfig.setMetastoreCacheMaximumSize(1000); + metastoreClientConfig.setEnabledCaches(ALL.name()); + ThriftHiveMetastore thriftHiveMetastore = new ThriftHiveMetastore(mockHiveCluster, metastoreClientConfig, 
HDFS_ENVIRONMENT); PartitionMutator hivePartitionMutator = new HivePartitionMutator(); - metastore = new InMemoryCachingHiveMetastore( + metastoreWithAllCachesEnabled = new InMemoryCachingHiveMetastore( new BridgingHiveMetastore(thriftHiveMetastore, hivePartitionMutator), executor, false, - new Duration(5, TimeUnit.MINUTES), - new Duration(1, TimeUnit.MINUTES), 1000, false, - MetastoreCacheScope.ALL, 0.0, metastoreClientConfig.getPartitionCacheColumnCountLimit(), - NOOP_METASTORE_CACHE_STATS); + NOOP_METASTORE_CACHE_STATS, + new MetastoreCacheSpecProvider(metastoreClientConfig)); + + MetastoreClientConfig metastoreClientConfigWithSelectiveCaching = new MetastoreClientConfig(); + // Configure Metastore Cache + metastoreClientConfigWithSelectiveCaching.setDefaultMetastoreCacheTtl(new Duration(5, TimeUnit.MINUTES)); + metastoreClientConfigWithSelectiveCaching.setDefaultMetastoreCacheRefreshInterval(new Duration(1, TimeUnit.MINUTES)); + metastoreClientConfigWithSelectiveCaching.setMetastoreCacheMaximumSize(1000); + metastoreClientConfigWithSelectiveCaching.setDisabledCaches(TABLE.name()); + + ThriftHiveMetastore thriftHiveMetastoreWithSelectiveCaching = new ThriftHiveMetastore(mockHiveCluster, metastoreClientConfigWithSelectiveCaching, HDFS_ENVIRONMENT); + metastoreWithSelectiveCachesEnabled = new InMemoryCachingHiveMetastore( + new BridgingHiveMetastore(thriftHiveMetastoreWithSelectiveCaching, hivePartitionMutator), + executor, + false, + 1000, + false, + 0.0, + metastoreClientConfigWithSelectiveCaching.getPartitionCacheColumnCountLimit(), + NOOP_METASTORE_CACHE_STATS, + new MetastoreCacheSpecProvider(metastoreClientConfigWithSelectiveCaching)); + stats = thriftHiveMetastore.getStats(); } @@ -108,19 +136,19 @@ public void setUp() public void testGetAllDatabases() { assertEquals(mockClient.getAccessCount(), 0); - assertEquals(metastore.getAllDatabases(TEST_METASTORE_CONTEXT), ImmutableList.of(TEST_DATABASE)); + 
assertEquals(metastoreWithAllCachesEnabled.getAllDatabases(TEST_METASTORE_CONTEXT), ImmutableList.of(TEST_DATABASE)); assertEquals(mockClient.getAccessCount(), 1); - assertEquals(metastore.getAllDatabases(TEST_METASTORE_CONTEXT), ImmutableList.of(TEST_DATABASE)); + assertEquals(metastoreWithAllCachesEnabled.getAllDatabases(TEST_METASTORE_CONTEXT), ImmutableList.of(TEST_DATABASE)); assertEquals(mockClient.getAccessCount(), 1); - metastore.invalidateAll(); + metastoreWithAllCachesEnabled.invalidateAll(); - assertEquals(metastore.getAllDatabases(TEST_METASTORE_CONTEXT), ImmutableList.of(TEST_DATABASE)); + assertEquals(metastoreWithAllCachesEnabled.getAllDatabases(TEST_METASTORE_CONTEXT), ImmutableList.of(TEST_DATABASE)); assertEquals(mockClient.getAccessCount(), 2); // Test invalidate a specific database - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE); - assertEquals(metastore.getAllDatabases(TEST_METASTORE_CONTEXT), ImmutableList.of(TEST_DATABASE)); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE); + assertEquals(metastoreWithAllCachesEnabled.getAllDatabases(TEST_METASTORE_CONTEXT), ImmutableList.of(TEST_DATABASE)); assertEquals(mockClient.getAccessCount(), 3); } @@ -128,69 +156,94 @@ public void testGetAllDatabases() public void testGetAllTable() { assertEquals(mockClient.getAccessCount(), 0); - assertEquals(metastore.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); + assertEquals(metastoreWithAllCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); assertEquals(mockClient.getAccessCount(), 1); - assertEquals(metastore.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); + assertEquals(metastoreWithAllCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), 
ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); assertEquals(mockClient.getAccessCount(), 1); - metastore.invalidateAll(); + metastoreWithAllCachesEnabled.invalidateAll(); - assertEquals(metastore.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); + assertEquals(metastoreWithAllCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); assertEquals(mockClient.getAccessCount(), 2); // Test invalidate a specific database which will also invalidate all table caches mapped to that database - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE); - assertEquals(metastore.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE); + assertEquals(metastoreWithAllCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); assertEquals(mockClient.getAccessCount(), 3); - assertEquals(metastore.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); + assertEquals(metastoreWithAllCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); assertEquals(mockClient.getAccessCount(), 3); // Test invalidate a specific database.table which also invalidates the tablesNamesCache for that database - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); - assertEquals(metastore.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); + 
assertEquals(metastoreWithAllCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); assertEquals(mockClient.getAccessCount(), 4); } + @Test + public void testGetAllTableWithSelectiveCaching() + { + assertEquals(mockClient.getAccessCount(), 0); + assertEquals(metastoreWithSelectiveCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); + assertEquals(mockClient.getAccessCount(), 1); + assertEquals(metastoreWithSelectiveCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); + assertEquals(mockClient.getAccessCount(), 1); + + metastoreWithSelectiveCachesEnabled.invalidateAll(); + + assertEquals(metastoreWithSelectiveCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, TEST_DATABASE).get(), ImmutableList.of(TEST_TABLE, TEST_TABLE_WITH_CONSTRAINTS)); + assertEquals(mockClient.getAccessCount(), 2); + } + public void testInvalidDbGetAllTAbles() { - assertFalse(metastore.getAllTables(TEST_METASTORE_CONTEXT, BAD_DATABASE).isPresent()); + assertFalse(metastoreWithAllCachesEnabled.getAllTables(TEST_METASTORE_CONTEXT, BAD_DATABASE).isPresent()); } @Test public void testGetTable() { assertEquals(mockClient.getAccessCount(), 0); - assertNotNull(metastore.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); + assertNotNull(metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); assertEquals(mockClient.getAccessCount(), 1); - assertNotNull(metastore.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); + assertNotNull(metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); assertEquals(mockClient.getAccessCount(), 1); - metastore.invalidateAll(); + metastoreWithAllCachesEnabled.invalidateAll(); - assertNotNull(metastore.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, 
TEST_TABLE)); + assertNotNull(metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); assertEquals(mockClient.getAccessCount(), 2); // Test invalidate a specific database which will also invalidate all table caches mapped to that database - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE); - assertNotNull(metastore.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE); + assertNotNull(metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); assertEquals(mockClient.getAccessCount(), 3); - assertNotNull(metastore.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); + assertNotNull(metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); assertEquals(mockClient.getAccessCount(), 3); - assertNotNull(metastore.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS)); + assertNotNull(metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS)); assertEquals(mockClient.getAccessCount(), 4); // Test invalidate a specific table - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); - assertNotNull(metastore.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS)); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); + assertNotNull(metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS)); assertEquals(mockClient.getAccessCount(), 4); - assertNotNull(metastore.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); + assertNotNull(metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); assertEquals(mockClient.getAccessCount(), 5); } + @Test + public void testGetTableWithSelectiveCaching() + { + 
assertEquals(mockClient.getAccessCount(), 0); + assertNotNull(metastoreWithSelectiveCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); + assertEquals(mockClient.getAccessCount(), 1); + assertNotNull(metastoreWithSelectiveCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE)); + assertEquals(mockClient.getAccessCount(), 2); + } + public void testInvalidDbGetTable() { - assertFalse(metastore.getTable(TEST_METASTORE_CONTEXT, BAD_DATABASE, TEST_TABLE).isPresent()); + assertFalse(metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, BAD_DATABASE, TEST_TABLE).isPresent()); assertEquals(stats.getGetTable().getThriftExceptions().getTotalCount(), 0); assertEquals(stats.getGetTable().getTotalFailures().getTotalCount(), 0); @@ -202,33 +255,33 @@ public void testGetPartitionNames() { ImmutableList expectedPartitions = ImmutableList.of(TEST_PARTITION_NAME_WITHOUT_VERSION1, TEST_PARTITION_NAME_WITHOUT_VERSION2); assertEquals(mockClient.getAccessCount(), 0); - assertEquals(metastore.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); assertEquals(mockClient.getAccessCount(), 1); - assertEquals(metastore.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); assertEquals(mockClient.getAccessCount(), 1); - metastore.invalidateAll(); + metastoreWithAllCachesEnabled.invalidateAll(); - assertEquals(metastore.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); 
assertEquals(mockClient.getAccessCount(), 2); // Test invalidate the database which will also invalidate all linked table and partition caches - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE); - assertEquals(metastore.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); assertEquals(mockClient.getAccessCount(), 3); - assertEquals(metastore.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); assertEquals(mockClient.getAccessCount(), 3); // Test invalidate a specific table which will also invalidate all linked partition caches - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); - assertEquals(metastore.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); assertEquals(mockClient.getAccessCount(), 4); - assertEquals(metastore.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); assertEquals(mockClient.getAccessCount(), 4); // Test invalidate a specific partition - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of("key"), ImmutableList.of("testpartition1")); - 
assertEquals(metastore.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of("key"), ImmutableList.of("testpartition1")); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE).get(), expectedPartitions); assertEquals(mockClient.getAccessCount(), 5); } @@ -236,27 +289,27 @@ public void testGetPartitionNames() public void testInvalidInvalidateCache() { // Test invalidate cache with null/empty database name - assertThatThrownBy(() -> metastore.invalidateCache(TEST_METASTORE_CONTEXT, null)) + assertThatThrownBy(() -> metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("databaseName cannot be null or empty"); // Test invalidate cache with null/empty table name - assertThatThrownBy(() -> metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, null)) + assertThatThrownBy(() -> metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("tableName cannot be null or empty"); // Test invalidate cache with invalid/empty partition columns list - assertThatThrownBy(() -> metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(), ImmutableList.of())) + assertThatThrownBy(() -> metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(), ImmutableList.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("partitionColumnNames cannot be null or empty"); // Test invalidate cache with invalid/empty partition values list - assertThatThrownBy(() -> metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of("key"), ImmutableList.of())) + 
assertThatThrownBy(() -> metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of("key"), ImmutableList.of())) .isInstanceOf(IllegalArgumentException.class) .hasMessage("partitionValues cannot be null or empty"); // Test invalidate cache with mismatched partition columns and values list - assertThatThrownBy(() -> metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of("key1", "key2"), ImmutableList.of("testpartition1"))) + assertThatThrownBy(() -> metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of("key1", "key2"), ImmutableList.of("testpartition1"))) .isInstanceOf(IllegalArgumentException.class) .hasMessage("partitionColumnNames and partitionValues should be of same length"); } @@ -264,7 +317,7 @@ public void testInvalidInvalidateCache() @Test public void testInvalidGetPartitionNames() { - assertEquals(metastore.getPartitionNames(TEST_METASTORE_CONTEXT, BAD_DATABASE, TEST_TABLE).get(), ImmutableList.of()); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNames(TEST_METASTORE_CONTEXT, BAD_DATABASE, TEST_TABLE).get(), ImmutableList.of()); } @Test @@ -273,14 +326,14 @@ public void testGetPartitionNamesByParts() ImmutableList expectedPartitions = ImmutableList.of(TEST_PARTITION_NAME_WITHOUT_VERSION1, TEST_PARTITION_NAME_WITHOUT_VERSION2); assertEquals(mockClient.getAccessCount(), 0); - assertEquals(metastore.getPartitionNamesByFilter(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableMap.of()), expectedPartitions); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNamesByFilter(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableMap.of()), expectedPartitions); assertEquals(mockClient.getAccessCount(), 1); - assertEquals(metastore.getPartitionNamesByFilter(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableMap.of()), expectedPartitions); + 
assertEquals(metastoreWithAllCachesEnabled.getPartitionNamesByFilter(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableMap.of()), expectedPartitions); assertEquals(mockClient.getAccessCount(), 1); - metastore.invalidateAll(); + metastoreWithAllCachesEnabled.invalidateAll(); - assertEquals(metastore.getPartitionNamesByFilter(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableMap.of()), expectedPartitions); + assertEquals(metastoreWithAllCachesEnabled.getPartitionNamesByFilter(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableMap.of()), expectedPartitions); assertEquals(mockClient.getAccessCount(), 2); } @@ -311,18 +364,23 @@ public void testCachingWithPartitionVersioning() ListeningExecutorService executor = listeningDecorator(newCachedThreadPool(daemonThreadsNamed("partition-versioning-test-%s"))); MockHiveMetastore mockHiveMetastore = new MockHiveMetastore(mockHiveCluster); PartitionMutator mockPartitionMutator = new MockPartitionMutator(identity()); + MetastoreClientConfig metastoreClientConfig = new MetastoreClientConfig(); + // Configure Metastore Cache + metastoreClientConfig.setDefaultMetastoreCacheTtl(new Duration(5, TimeUnit.MINUTES)); + metastoreClientConfig.setDefaultMetastoreCacheRefreshInterval(new Duration(1, TimeUnit.MINUTES)); + metastoreClientConfig.setMetastoreCacheMaximumSize(1000); + metastoreClientConfig.setEnabledCaches(String.join(",", PARTITION.name(), PARTITION_STATISTICS.name())); + InMemoryCachingHiveMetastore partitionCachingEnabledmetastore = new InMemoryCachingHiveMetastore( new BridgingHiveMetastore(mockHiveMetastore, mockPartitionMutator), executor, false, - new Duration(5, TimeUnit.MINUTES), - new Duration(1, TimeUnit.MINUTES), 1000, true, - MetastoreCacheScope.PARTITION, 0.0, 10_000, - NOOP_METASTORE_CACHE_STATS); + NOOP_METASTORE_CACHE_STATS, + new MetastoreCacheSpecProvider(metastoreClientConfig)); assertEquals(mockClient.getAccessCount(), 0); 
assertEquals(partitionCachingEnabledmetastore.getPartitionNamesByFilter(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableMap.of()), EXPECTED_PARTITIONS); @@ -361,18 +419,23 @@ private void assertInvalidateCache(MockPartitionMutator partitionMutator, Functi MockHiveCluster mockHiveCluster = new MockHiveCluster(mockClient); ListeningExecutorService executor = listeningDecorator(newCachedThreadPool(daemonThreadsNamed("partition-versioning-test-%s"))); MockHiveMetastore mockHiveMetastore = new MockHiveMetastore(mockHiveCluster); + MetastoreClientConfig metastoreClientConfig = new MetastoreClientConfig(); + // Configure Metastore Cache + metastoreClientConfig.setDefaultMetastoreCacheTtl(new Duration(5, TimeUnit.MINUTES)); + metastoreClientConfig.setDefaultMetastoreCacheRefreshInterval(new Duration(1, TimeUnit.MINUTES)); + metastoreClientConfig.setMetastoreCacheMaximumSize(1000); + metastoreClientConfig.setEnabledCaches(String.join(",", PARTITION.name(), PARTITION_STATISTICS.name())); + InMemoryCachingHiveMetastore partitionCachingEnabledmetastore = new InMemoryCachingHiveMetastore( new BridgingHiveMetastore(mockHiveMetastore, partitionMutator), executor, false, - new Duration(5, TimeUnit.MINUTES), - new Duration(1, TimeUnit.MINUTES), 1000, true, - MetastoreCacheScope.PARTITION, 0.0, 10_000, - NOOP_METASTORE_CACHE_STATS); + NOOP_METASTORE_CACHE_STATS, + new MetastoreCacheSpecProvider(metastoreClientConfig)); int clientAccessCount = 0; for (int i = 0; i < 100; i++) { @@ -388,7 +451,7 @@ private void assertInvalidateCache(MockPartitionMutator partitionMutator, Functi public void testInvalidGetPartitionNamesByParts() { - assertTrue(metastore.getPartitionNamesByFilter(TEST_METASTORE_CONTEXT, BAD_DATABASE, TEST_TABLE, ImmutableMap.of()).isEmpty()); + assertTrue(metastoreWithAllCachesEnabled.getPartitionNamesByFilter(TEST_METASTORE_CONTEXT, BAD_DATABASE, TEST_TABLE, ImmutableMap.of()).isEmpty()); } @Test @@ -399,18 +462,23 @@ public void 
testPartitionCacheValidation() ListeningExecutorService executor = listeningDecorator(newCachedThreadPool(daemonThreadsNamed("partition-versioning-test-%s"))); MockHiveMetastore mockHiveMetastore = new MockHiveMetastore(mockHiveCluster); PartitionMutator mockPartitionMutator = new MockPartitionMutator(identity()); + MetastoreClientConfig metastoreClientConfig = new MetastoreClientConfig(); + // Configure Metastore Cache + metastoreClientConfig.setDefaultMetastoreCacheTtl(new Duration(5, TimeUnit.MINUTES)); + metastoreClientConfig.setDefaultMetastoreCacheRefreshInterval(new Duration(1, TimeUnit.MINUTES)); + metastoreClientConfig.setMetastoreCacheMaximumSize(1000); + metastoreClientConfig.setEnabledCaches(String.join(",", PARTITION.name(), PARTITION_STATISTICS.name())); + InMemoryCachingHiveMetastore partitionCacheVerificationEnabledMetastore = new InMemoryCachingHiveMetastore( new BridgingHiveMetastore(mockHiveMetastore, mockPartitionMutator), executor, false, - new Duration(5, TimeUnit.MINUTES), - new Duration(1, TimeUnit.MINUTES), 1000, true, - MetastoreCacheScope.PARTITION, 100.0, 10_000, - NOOP_METASTORE_CACHE_STATS); + NOOP_METASTORE_CACHE_STATS, + new MetastoreCacheSpecProvider(metastoreClientConfig)); // Warmup the cache partitionCacheVerificationEnabledMetastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2)); @@ -430,19 +498,24 @@ public void testPartitionCacheColumnCountLimit() ListeningExecutorService executor = listeningDecorator(newCachedThreadPool(daemonThreadsNamed("partition-versioning-test-%s"))); MockHiveMetastore mockHiveMetastore = new MockHiveMetastore(mockHiveCluster); PartitionMutator mockPartitionMutator = new MockPartitionMutator(identity()); + MetastoreClientConfig metastoreClientConfig = new MetastoreClientConfig(); + // Configure Metastore Cache + metastoreClientConfig.setDefaultMetastoreCacheTtl(new Duration(5, 
TimeUnit.MINUTES)); + metastoreClientConfig.setDefaultMetastoreCacheRefreshInterval(new Duration(1, TimeUnit.MINUTES)); + metastoreClientConfig.setMetastoreCacheMaximumSize(1000); + metastoreClientConfig.setEnabledCaches(String.join(",", PARTITION.name(), PARTITION_STATISTICS.name())); + InMemoryCachingHiveMetastore partitionCachingEnabledMetastore = new InMemoryCachingHiveMetastore( new BridgingHiveMetastore(mockHiveMetastore, mockPartitionMutator), executor, false, - new Duration(5, TimeUnit.MINUTES), - new Duration(1, TimeUnit.MINUTES), 1000, true, - MetastoreCacheScope.PARTITION, 0.0, // set the cached partition column count limit as 1 for testing purpose 1, - NOOP_METASTORE_CACHE_STATS); + NOOP_METASTORE_CACHE_STATS, + new MetastoreCacheSpecProvider(metastoreClientConfig)); // Select all of the available partitions. Normally they would have been loaded into the cache. But because of column count limit, they will not be cached assertEquals(partitionCachingEnabledMetastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2)).size(), 2); @@ -461,43 +534,43 @@ public void testPartitionCacheColumnCountLimit() public void testGetPartitionsByNames() { assertEquals(mockClient.getAccessCount(), 0); - metastore.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); + metastoreWithAllCachesEnabled.getTable(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); assertEquals(mockClient.getAccessCount(), 1); // Select half of the available partitions and load them into the cache - assertEquals(metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1)).size(), 1); + assertEquals(metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1)).size(), 1); assertEquals(mockClient.getAccessCount(), 2); // 
Now select all of the partitions - assertEquals(metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2)).size(), 2); + assertEquals(metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2)).size(), 2); // There should be one more access to fetch the remaining partition assertEquals(mockClient.getAccessCount(), 3); // Now if we fetch any or both of them, they should not hit the client - assertEquals(metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1)).size(), 1); - assertEquals(metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION2)).size(), 1); - assertEquals(metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2)).size(), 2); + assertEquals(metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1)).size(), 1); + assertEquals(metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION2)).size(), 1); + assertEquals(metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2)).size(), 2); assertEquals(mockClient.getAccessCount(), 3); - metastore.invalidateAll(); + metastoreWithAllCachesEnabled.invalidateAll(); // Fetching both should only result in one batched access - assertEquals(metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, 
TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2)).size(), 2); + assertEquals(metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2)).size(), 2); assertEquals(mockClient.getAccessCount(), 4); // Test invalidate a specific partition - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of("key"), ImmutableList.of("testpartition1")); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of("key"), ImmutableList.of("testpartition1")); // This should still be a cache hit - assertEquals(metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION2)).size(), 1); + assertEquals(metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION2)).size(), 1); assertEquals(mockClient.getAccessCount(), 4); // This should be a cache miss - assertEquals(metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1)).size(), 1); + assertEquals(metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1)).size(), 1); assertEquals(mockClient.getAccessCount(), 5); // This should be a cache hit - assertEquals(metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, TEST_PARTITION_NAME_WITH_VERSION2)).size(), 2); + assertEquals(metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1, 
TEST_PARTITION_NAME_WITH_VERSION2)).size(), 2); assertEquals(mockClient.getAccessCount(), 5); } @@ -507,31 +580,31 @@ public void testListRoles() { assertEquals(mockClient.getAccessCount(), 0); - assertEquals(metastore.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); + assertEquals(metastoreWithAllCachesEnabled.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); assertEquals(mockClient.getAccessCount(), 1); - assertEquals(metastore.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); + assertEquals(metastoreWithAllCachesEnabled.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); assertEquals(mockClient.getAccessCount(), 1); - metastore.invalidateAll(); + metastoreWithAllCachesEnabled.invalidateAll(); - assertEquals(metastore.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); + assertEquals(metastoreWithAllCachesEnabled.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); assertEquals(mockClient.getAccessCount(), 2); - metastore.createRole(TEST_METASTORE_CONTEXT, "role", "grantor"); + metastoreWithAllCachesEnabled.createRole(TEST_METASTORE_CONTEXT, "role", "grantor"); - assertEquals(metastore.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); + assertEquals(metastoreWithAllCachesEnabled.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); assertEquals(mockClient.getAccessCount(), 3); - metastore.dropRole(TEST_METASTORE_CONTEXT, "testrole"); + metastoreWithAllCachesEnabled.dropRole(TEST_METASTORE_CONTEXT, "testrole"); - assertEquals(metastore.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); + assertEquals(metastoreWithAllCachesEnabled.listRoles(TEST_METASTORE_CONTEXT), TEST_ROLES); assertEquals(mockClient.getAccessCount(), 4); } public void testInvalidGetPartitionsByNames() { - Map> partitionsByNames = metastore.getPartitionsByNames(TEST_METASTORE_CONTEXT, BAD_DATABASE, TEST_TABLE, ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1)); + Map> partitionsByNames = metastoreWithAllCachesEnabled.getPartitionsByNames(TEST_METASTORE_CONTEXT, BAD_DATABASE, TEST_TABLE, 
ImmutableList.of(TEST_PARTITION_NAME_WITH_VERSION1)); assertEquals(partitionsByNames.size(), 1); Optional onlyElement = Iterables.getOnlyElement(partitionsByNames.values()); assertFalse(onlyElement.isPresent()); @@ -543,7 +616,7 @@ public void testNoCacheExceptions() // Throw exceptions on usage mockClient.setThrowException(true); try { - metastore.getAllDatabases(TEST_METASTORE_CONTEXT); + metastoreWithAllCachesEnabled.getAllDatabases(TEST_METASTORE_CONTEXT); } catch (RuntimeException ignored) { } @@ -551,7 +624,7 @@ public void testNoCacheExceptions() // Second try should hit the client again try { - metastore.getAllDatabases(TEST_METASTORE_CONTEXT); + metastoreWithAllCachesEnabled.getAllDatabases(TEST_METASTORE_CONTEXT); } catch (RuntimeException ignored) { } @@ -562,25 +635,25 @@ public void testNoCacheExceptions() public void testTableConstraints() { assertEquals(mockClient.getAccessCount(), 0); - List> tableConstraints = metastore.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); + List> tableConstraints = metastoreWithAllCachesEnabled.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); assertEquals(tableConstraints.get(0), new PrimaryKeyConstraint<>(Optional.of("pk"), new LinkedHashSet<>(ImmutableList.of("c1")), true, true, false)); assertEquals(tableConstraints.get(1), new UniqueConstraint<>(Optional.of("uk"), new LinkedHashSet<>(ImmutableList.of("c2")), true, true, false)); assertEquals(tableConstraints.get(2), new NotNullConstraint<>("c3")); assertEquals(mockClient.getAccessCount(), 3); - metastore.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); + metastoreWithAllCachesEnabled.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); assertEquals(mockClient.getAccessCount(), 3); - metastore.invalidateAll(); - metastore.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); + 
metastoreWithAllCachesEnabled.invalidateAll(); + metastoreWithAllCachesEnabled.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); assertEquals(mockClient.getAccessCount(), 6); // Test invalidate TEST_TABLE, which should not affect any entries linked to TEST_TABLE_WITH_CONSTRAINTS - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); - metastore.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE); + metastoreWithAllCachesEnabled.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); assertEquals(mockClient.getAccessCount(), 6); // Test invalidate TEST_TABLE_WITH_CONSTRAINTS - metastore.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); - metastore.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); + metastoreWithAllCachesEnabled.invalidateCache(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); + metastoreWithAllCachesEnabled.getTableConstraints(TEST_METASTORE_CONTEXT, TEST_DATABASE, TEST_TABLE_WITH_CONSTRAINTS); assertEquals(mockClient.getAccessCount(), 9); } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/glue/TestGlueInputConverter.java b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/glue/TestGlueInputConverter.java index ef5d5213bd9f2..967908b8c83c8 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/glue/TestGlueInputConverter.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/glue/TestGlueInputConverter.java @@ -13,10 +13,6 @@ */ package com.facebook.presto.hive.metastore.glue; -import com.amazonaws.services.glue.model.DatabaseInput; -import com.amazonaws.services.glue.model.PartitionInput; -import com.amazonaws.services.glue.model.StorageDescriptor; -import 
com.amazonaws.services.glue.model.TableInput; import com.facebook.presto.hive.HiveBucketProperty; import com.facebook.presto.hive.metastore.Column; import com.facebook.presto.hive.metastore.Database; @@ -26,6 +22,10 @@ import com.facebook.presto.hive.metastore.glue.converter.GlueInputConverter; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; +import software.amazon.awssdk.services.glue.model.DatabaseInput; +import software.amazon.awssdk.services.glue.model.PartitionInput; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.TableInput; import java.util.List; @@ -46,10 +46,10 @@ public void testConvertDatabase() { DatabaseInput dbInput = GlueInputConverter.convertDatabase(testDb); - assertEquals(dbInput.getName(), testDb.getDatabaseName()); - assertEquals(dbInput.getDescription(), testDb.getComment().get()); - assertEquals(dbInput.getLocationUri(), testDb.getLocation().get()); - assertEquals(dbInput.getParameters(), testDb.getParameters()); + assertEquals(dbInput.name(), testDb.getDatabaseName()); + assertEquals(dbInput.description(), testDb.getComment().get()); + assertEquals(dbInput.locationUri(), testDb.getLocation().get()); + assertEquals(dbInput.parameters(), testDb.getParameters()); } @Test @@ -57,15 +57,15 @@ public void testConvertTable() { TableInput tblInput = GlueInputConverter.convertTable(testTbl); - assertEquals(tblInput.getName(), testTbl.getTableName()); - assertEquals(tblInput.getOwner(), testTbl.getOwner()); - assertEquals(tblInput.getTableType(), testTbl.getTableType().toString()); - assertEquals(tblInput.getParameters(), testTbl.getParameters()); - assertColumnList(tblInput.getStorageDescriptor().getColumns(), testTbl.getDataColumns()); - assertColumnList(tblInput.getPartitionKeys(), testTbl.getPartitionColumns()); - assertStorage(tblInput.getStorageDescriptor(), testTbl.getStorage()); - assertEquals(tblInput.getViewExpandedText(), 
testTbl.getViewExpandedText().get()); - assertEquals(tblInput.getViewOriginalText(), testTbl.getViewOriginalText().get()); + assertEquals(tblInput.name(), testTbl.getTableName()); + assertEquals(tblInput.owner(), testTbl.getOwner()); + assertEquals(tblInput.tableType(), testTbl.getTableType().toString()); + assertEquals(tblInput.parameters(), testTbl.getParameters()); + assertColumnList(tblInput.storageDescriptor().columns(), testTbl.getDataColumns()); + assertColumnList(tblInput.partitionKeys(), testTbl.getPartitionColumns()); + assertStorage(tblInput.storageDescriptor(), testTbl.getStorage()); + assertEquals(tblInput.viewExpandedText(), testTbl.getViewExpandedText().get()); + assertEquals(tblInput.viewOriginalText(), testTbl.getViewOriginalText().get()); } @Test @@ -73,12 +73,12 @@ public void testConvertPartition() { PartitionInput partitionInput = GlueInputConverter.convertPartition(testPartition); - assertEquals(partitionInput.getParameters(), testPartition.getParameters()); - assertStorage(partitionInput.getStorageDescriptor(), testPartition.getStorage()); - assertEquals(partitionInput.getValues(), testPartition.getValues()); + assertEquals(partitionInput.parameters(), testPartition.getParameters()); + assertStorage(partitionInput.storageDescriptor(), testPartition.getStorage()); + assertEquals(partitionInput.values(), testPartition.getValues()); } - private static void assertColumnList(List actual, List expected) + private static void assertColumnList(List actual, List expected) { if (expected == null) { assertNull(actual); @@ -90,24 +90,24 @@ private static void assertColumnList(List) null).build(); + com.facebook.presto.hive.metastore.Table prestoTbl = GlueToPrestoConverter.convertTable(testTbl, testDb.name()); assertTrue(prestoTbl.getPartitionColumns().isEmpty()); } @Test public void testConvertTableUppercaseColumnType() { - com.amazonaws.services.glue.model.Column uppercaseCol = getGlueTestColumn().withType("String"); - 
testTbl.getStorageDescriptor().setColumns(ImmutableList.of(uppercaseCol)); - GlueToPrestoConverter.convertTable(testTbl, testDb.getName()); + software.amazon.awssdk.services.glue.model.Column uppercaseCol = getGlueTestColumn().toBuilder().type("String").build(); + + StorageDescriptor sd = testTbl.storageDescriptor(); + testTbl = testTbl.toBuilder().storageDescriptor(sd.toBuilder().columns(ImmutableList.of(uppercaseCol)).build()).build(); + GlueToPrestoConverter.convertTable(testTbl, testDb.name()); } @Test public void testConvertPartition() { - GluePartitionConverter converter = new GluePartitionConverter(testPartition.getDatabaseName(), testPartition.getTableName()); + GluePartitionConverter converter = new GluePartitionConverter(testPartition.databaseName(), testPartition.tableName()); com.facebook.presto.hive.metastore.Partition prestoPartition = converter.apply(testPartition); - assertEquals(prestoPartition.getDatabaseName(), testPartition.getDatabaseName()); - assertEquals(prestoPartition.getTableName(), testPartition.getTableName()); - assertColumnList(prestoPartition.getColumns(), testPartition.getStorageDescriptor().getColumns()); - assertEquals(prestoPartition.getValues(), testPartition.getValues()); - assertStorage(prestoPartition.getStorage(), testPartition.getStorageDescriptor()); - assertEquals(prestoPartition.getParameters(), testPartition.getParameters()); + assertEquals(prestoPartition.getDatabaseName(), testPartition.databaseName()); + assertEquals(prestoPartition.getTableName(), testPartition.tableName()); + assertColumnList(prestoPartition.getColumns(), testPartition.storageDescriptor().columns()); + assertEquals(prestoPartition.getValues(), testPartition.values()); + assertStorage(prestoPartition.getStorage(), testPartition.storageDescriptor()); + assertEquals(prestoPartition.getParameters(), testPartition.parameters()); } @Test public void testPartitionConversionMemoization() { String fakeS3Location = "s3://some-fake-location"; - 
testPartition.getStorageDescriptor().setLocation(fakeS3Location); + + StorageDescriptor sdPartition = testPartition.storageDescriptor(); + testPartition = testPartition.toBuilder().storageDescriptor(sdPartition.toBuilder().location(fakeS3Location).build()).build(); + // Second partition to convert with equal (but not aliased) values - Partition partitionTwo = getGlueTestPartition(testPartition.getDatabaseName(), testPartition.getTableName(), new ArrayList<>(testPartition.getValues())); + Partition partitionTwo = getGlueTestPartition(testPartition.databaseName(), testPartition.tableName(), new ArrayList<>(testPartition.values())); // Ensure storage fields are equal but not aliased as well - partitionTwo.getStorageDescriptor().setColumns(new ArrayList<>(testPartition.getStorageDescriptor().getColumns())); - partitionTwo.getStorageDescriptor().setBucketColumns(new ArrayList<>(testPartition.getStorageDescriptor().getBucketColumns())); - partitionTwo.getStorageDescriptor().setLocation("" + fakeS3Location); - partitionTwo.getStorageDescriptor().setInputFormat("" + testPartition.getStorageDescriptor().getInputFormat()); - partitionTwo.getStorageDescriptor().setOutputFormat("" + testPartition.getStorageDescriptor().getOutputFormat()); - partitionTwo.getStorageDescriptor().setParameters(new HashMap<>(testPartition.getStorageDescriptor().getParameters())); - - GluePartitionConverter converter = new GluePartitionConverter(testDb.getName(), testTbl.getName()); + StorageDescriptor sdPartitionTwo = partitionTwo.storageDescriptor(); + partitionTwo = partitionTwo.toBuilder().storageDescriptor( + sdPartitionTwo.toBuilder() + .columns(new ArrayList<>(testPartition.storageDescriptor().columns())) + .bucketColumns(new ArrayList<>(testPartition.storageDescriptor().bucketColumns())) + .location("" + fakeS3Location) + .inputFormat("" + testPartition.storageDescriptor().inputFormat()) + .outputFormat("" + testPartition.storageDescriptor().outputFormat()) + .parameters(new 
HashMap<>(testPartition.storageDescriptor().parameters())) + .build()).build(); + + GluePartitionConverter converter = new GluePartitionConverter(testDb.name(), testTbl.name()); com.facebook.presto.hive.metastore.Partition prestoPartition = converter.apply(testPartition); com.facebook.presto.hive.metastore.Partition prestoPartition2 = converter.apply(partitionTwo); @@ -161,16 +172,20 @@ public void testPartitionConversionMemoization() @Test public void testDatabaseNullParameters() { - testDb.setParameters(null); + testDb = testDb.toBuilder().parameters(null).build(); assertNotNull(GlueToPrestoConverter.convertDatabase(testDb).getParameters()); } @Test public void testTableNullParameters() { - testTbl.setParameters(null); - testTbl.getStorageDescriptor().getSerdeInfo().setParameters(null); - com.facebook.presto.hive.metastore.Table prestoTable = GlueToPrestoConverter.convertTable(testTbl, testDb.getName()); + StorageDescriptor sd = testTbl.storageDescriptor(); + SerDeInfo serDeInfo = sd.serdeInfo(); + testTbl = testTbl.toBuilder() + .parameters(null) + .storageDescriptor(sd.toBuilder().serdeInfo(serDeInfo.toBuilder().parameters(null).build()).build()) + .build(); + com.facebook.presto.hive.metastore.Table prestoTable = GlueToPrestoConverter.convertTable(testTbl, testDb.name()); assertNotNull(prestoTable.getParameters()); assertNotNull(prestoTable.getStorage().getSerdeParameters()); } @@ -178,38 +193,37 @@ public void testTableNullParameters() @Test public void testPartitionNullParameters() { - testPartition.setParameters(null); - assertNotNull(new GluePartitionConverter(testDb.getName(), testTbl.getName()).apply(testPartition).getParameters()); + testPartition = testPartition.toBuilder().parameters(null).build(); + assertNotNull(new GluePartitionConverter(testDb.name(), testTbl.name()).apply(testPartition).getParameters()); } @Test public void testConvertTableWithoutTableType() { - Table table = getGlueTestTable(testDb.getName()); - table.setTableType(null); - 
com.facebook.presto.hive.metastore.Table prestoTable = GlueToPrestoConverter.convertTable(table, testDb.getName()); + Table table = getGlueTestTable(testDb.name()).toBuilder().tableType(null).build(); + com.facebook.presto.hive.metastore.Table prestoTable = GlueToPrestoConverter.convertTable(table, testDb.name()); assertEquals(prestoTable.getTableType(), EXTERNAL_TABLE); } @Test public void testIcebergTableNonNullStorageDescriptor() { - testTbl.setParameters(ImmutableMap.of(ICEBERG_TABLE_TYPE_NAME, ICEBERG_TABLE_TYPE_VALUE)); - assertNotNull(testTbl.getStorageDescriptor()); - com.facebook.presto.hive.metastore.Table prestoTable = GlueToPrestoConverter.convertTable(testTbl, testDb.getName()); + testTbl = testTbl.toBuilder().parameters(ImmutableMap.of(ICEBERG_TABLE_TYPE_NAME, ICEBERG_TABLE_TYPE_VALUE)).build(); + assertNotNull(testTbl.storageDescriptor()); + com.facebook.presto.hive.metastore.Table prestoTable = GlueToPrestoConverter.convertTable(testTbl, testDb.name()); assertEquals(prestoTable.getDataColumns().size(), 1); } @Test public void testDeltaTableNonNullStorageDescriptor() { - testTbl.setParameters(ImmutableMap.of(SPARK_TABLE_PROVIDER_KEY, DELTA_LAKE_PROVIDER)); - assertNotNull(testTbl.getStorageDescriptor()); - com.facebook.presto.hive.metastore.Table prestoTable = GlueToPrestoConverter.convertTable(testTbl, testDb.getName()); + testTbl = testTbl.toBuilder().parameters(ImmutableMap.of(SPARK_TABLE_PROVIDER_KEY, DELTA_LAKE_PROVIDER)).build(); + assertNotNull(testTbl.storageDescriptor()); + com.facebook.presto.hive.metastore.Table prestoTable = GlueToPrestoConverter.convertTable(testTbl, testDb.name()); assertEquals(prestoTable.getDataColumns().size(), 1); } - private static void assertColumnList(List actual, List expected) + private static void assertColumnList(List actual, List expected) { if (expected == null) { assertNull(actual); @@ -221,23 +235,23 @@ private static void assertColumnList(List actual, List getMetastoreClient().getTable(metastoreContext, 
table.getSchemaName(), table.getTableName())) .hasMessageStartingWith("Table StorageDescriptor is null for table"); - glueClient.deleteTable(deleteTableRequest); + awsSyncRequest(glueClient::deleteTable, deleteTableRequest, null); // Iceberg table - tableInput = tableInput.withParameters(ImmutableMap.of(ICEBERG_TABLE_TYPE_NAME, ICEBERG_TABLE_TYPE_VALUE)); - glueClient.createTable(new CreateTableRequest() - .withDatabaseName(database) - .withTableInput(tableInput)); + tableInput = tableInput.toBuilder().parameters(ImmutableMap.of(ICEBERG_TABLE_TYPE_NAME, ICEBERG_TABLE_TYPE_VALUE)).build(); + awsSyncRequest( + glueClient::createTable, + CreateTableRequest.builder() + .databaseName(database) + .tableInput(tableInput) + .build(), + null); assertTrue(isIcebergTable(getMetastoreClient().getTable(metastoreContext, table.getSchemaName(), table.getTableName()).orElseThrow(() -> new NoSuchElementException()))); - glueClient.deleteTable(deleteTableRequest); + awsSyncRequest(glueClient::deleteTable, deleteTableRequest, null); // Delta Lake table - tableInput = tableInput.withParameters(ImmutableMap.of(SPARK_TABLE_PROVIDER_KEY, DELTA_LAKE_PROVIDER)); - glueClient.createTable(new CreateTableRequest() - .withDatabaseName(database) - .withTableInput(tableInput)); + tableInput = tableInput.toBuilder().parameters(ImmutableMap.of(SPARK_TABLE_PROVIDER_KEY, DELTA_LAKE_PROVIDER)).build(); + awsSyncRequest( + glueClient::createTable, + CreateTableRequest.builder() + .databaseName(database) + .tableInput(tableInput) + .build(), + null); assertTrue(isDeltaLakeTable(getMetastoreClient().getTable(metastoreContext, table.getSchemaName(), table.getTableName()).orElseThrow(() -> new NoSuchElementException()))); } finally { // Table cannot be dropped through HiveMetastore since a TableHandle cannot be created - glueClient.deleteTable(new DeleteTableRequest() - .withDatabaseName(table.getSchemaName()) - .withName(table.getTableName())); + awsSyncRequest(glueClient::deleteTable, 
deleteTableRequest, null); } } @@ -351,12 +366,16 @@ public void testGetPartitionsWithFilterUsingReservedKeywordsAsColumnName() .addBigintValues(regularColumnPartitionName, 2L) .build(); - List partitionNames = metastoreClient.getPartitionNamesByFilter( + List partitionNamesWithVersion = metastoreClient.getPartitionNamesByFilter( METASTORE_CONTEXT, tableName.getSchemaName(), tableName.getTableName(), predicates); + List partitionNames = partitionNamesWithVersion.stream() + .map(PartitionNameWithVersion::getPartitionName) + .collect(toImmutableList()); + assertFalse(partitionNames.isEmpty()); assertEquals(partitionNames, ImmutableList.of("key=value2/int_partition=2")); @@ -366,11 +385,16 @@ public void testGetPartitionsWithFilterUsingReservedKeywordsAsColumnName() .addStringValues(reservedKeywordPartitionColumnName, "value1") .build(); - partitionNames = metastoreClient.getPartitionNamesByFilter( + partitionNamesWithVersion = metastoreClient.getPartitionNamesByFilter( METASTORE_CONTEXT, tableName.getSchemaName(), tableName.getTableName(), predicates); + + partitionNames = partitionNamesWithVersion.stream() + .map(PartitionNameWithVersion::getPartitionName) + .collect(toImmutableList()); + assertFalse(partitionNames.isEmpty()); assertEquals(partitionNames, ImmutableList.of("key=value1/int_partition=1", "key=value2/int_partition=2")); } @@ -898,11 +922,16 @@ private void doGetPartitionsFilterTest( .map(expectedPartitionValues -> makePartName(partitionColumnNames, expectedPartitionValues.getValues())) .collect(toImmutableList()); - List partitionNames = metastoreClient.getPartitionNamesByFilter( + List partitionNamesWithVersion = metastoreClient.getPartitionNamesByFilter( METASTORE_CONTEXT, tableName.getSchemaName(), tableName.getTableName(), filter); + + List partitionNames = partitionNamesWithVersion.stream() + .map(PartitionNameWithVersion::getPartitionName) + .collect(toImmutableList()); + assertEquals( partitionNames, expectedResults, diff --git 
a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/glue/TestingMetastoreObjects.java b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/glue/TestingMetastoreObjects.java index be540ec76279a..c18057870b95f 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/metastore/glue/TestingMetastoreObjects.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/metastore/glue/TestingMetastoreObjects.java @@ -13,18 +13,18 @@ */ package com.facebook.presto.hive.metastore.glue; -import com.amazonaws.services.glue.model.Column; -import com.amazonaws.services.glue.model.Database; -import com.amazonaws.services.glue.model.Partition; -import com.amazonaws.services.glue.model.SerDeInfo; -import com.amazonaws.services.glue.model.StorageDescriptor; -import com.amazonaws.services.glue.model.Table; import com.facebook.presto.hive.HiveType; import com.facebook.presto.hive.metastore.Storage; import com.facebook.presto.hive.metastore.StorageFormat; import com.facebook.presto.spi.security.PrincipalType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import software.amazon.awssdk.services.glue.model.Column; +import software.amazon.awssdk.services.glue.model.Database; +import software.amazon.awssdk.services.glue.model.Partition; +import software.amazon.awssdk.services.glue.model.SerDeInfo; +import software.amazon.awssdk.services.glue.model.StorageDescriptor; +import software.amazon.awssdk.services.glue.model.Table; import java.util.List; import java.util.Optional; @@ -41,58 +41,64 @@ private TestingMetastoreObjects() {} public static Database getGlueTestDatabase() { - return new Database() - .withName("test-db" + generateRandom()) - .withDescription("database desc") - .withLocationUri("/db") - .withParameters(ImmutableMap.of()); + return Database.builder() + .name("test-db" + generateRandom()) + .description("database desc") + .locationUri("/db") + .parameters(ImmutableMap.of()) + .build(); } public 
static Table getGlueTestTable(String dbName) { - return new Table() - .withDatabaseName(dbName) - .withName("test-tbl" + generateRandom()) - .withOwner("owner") - .withParameters(ImmutableMap.of()) - .withPartitionKeys(ImmutableList.of(getGlueTestColumn())) - .withStorageDescriptor(getGlueTestStorageDescriptor()) - .withTableType(EXTERNAL_TABLE.name()) - .withViewOriginalText("originalText") - .withViewExpandedText("expandedText"); + return Table.builder() + .databaseName(dbName) + .name("test-tbl" + generateRandom()) + .owner("owner") + .parameters(ImmutableMap.of()) + .partitionKeys(ImmutableList.of(getGlueTestColumn())) + .storageDescriptor(getGlueTestStorageDescriptor()) + .tableType(EXTERNAL_TABLE.name()) + .viewOriginalText("originalText") + .viewExpandedText("expandedText") + .build(); } public static Column getGlueTestColumn() { - return new Column() - .withName("test-col" + generateRandom()) - .withType("string") - .withComment("column comment"); + return Column.builder() + .name("test-col" + generateRandom()) + .type("string") + .comment("column comment") + .build(); } public static StorageDescriptor getGlueTestStorageDescriptor() { - return new StorageDescriptor() - .withBucketColumns(ImmutableList.of("test-bucket-col")) - .withColumns(ImmutableList.of(getGlueTestColumn())) - .withParameters(ImmutableMap.of()) - .withSerdeInfo(new SerDeInfo() - .withSerializationLibrary("SerdeLib") - .withParameters(ImmutableMap.of())) - .withInputFormat("InputFormat") - .withOutputFormat("OutputFormat") - .withLocation("/test-tbl") - .withNumberOfBuckets(1); + return StorageDescriptor.builder() + .bucketColumns(ImmutableList.of("test-bucket-col")) + .columns(ImmutableList.of(getGlueTestColumn())) + .parameters(ImmutableMap.of()) + .serdeInfo(SerDeInfo.builder() + .serializationLibrary("SerdeLib") + .parameters(ImmutableMap.of()) + .build()) + .inputFormat("InputFormat") + .outputFormat("OutputFormat") + .location("/test-tbl") + .numberOfBuckets(1) + .build(); } public 
static Partition getGlueTestPartition(String dbName, String tblName, List values) { - return new Partition() - .withDatabaseName(dbName) - .withTableName(tblName) - .withValues(values) - .withParameters(ImmutableMap.of()) - .withStorageDescriptor(getGlueTestStorageDescriptor()); + return Partition.builder() + .databaseName(dbName) + .tableName(tblName) + .values(values) + .parameters(ImmutableMap.of()) + .storageDescriptor(getGlueTestStorageDescriptor()) + .build(); } // --------------- Presto Objects --------------- diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java index e8c88af69a7df..7b06cb102e9ae 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/AbstractTestParquetReader.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.hive.parquet; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.SqlDate; @@ -35,8 +37,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Range; import com.google.common.primitives.Shorts; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector; @@ -72,6 +72,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static 
com.facebook.presto.common.type.DateType.DATE; @@ -102,7 +103,6 @@ import static com.google.common.collect.Iterables.cycle; import static com.google.common.collect.Iterables.limit; import static com.google.common.collect.Iterables.transform; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Arrays.asList; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/BenchmarkParquetPageSource.java b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/BenchmarkParquetPageSource.java index 79fe3806192da..e1a2b0b852417 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/BenchmarkParquetPageSource.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/BenchmarkParquetPageSource.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.parquet; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.Page; import com.facebook.presto.common.RuntimeStats; @@ -33,7 +34,7 @@ import com.facebook.presto.sql.gen.PageFunctionCompiler; import com.facebook.presto.testing.TestingSession; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; +import com.google.errorprone.annotations.Immutable; import org.apache.parquet.column.ParquetProperties; import org.apache.parquet.hadoop.metadata.CompressionCodecName; import org.apache.parquet.hadoop.metadata.ParquetMetadata; @@ -58,8 +59,6 @@ import org.openjdk.jmh.runner.options.VerboseMode; import org.testng.annotations.Test; -import javax.annotation.concurrent.Immutable; - import java.io.File; import java.io.IOException; import java.util.ArrayList; @@ -70,6 +69,7 @@ import java.util.Random; import java.util.stream.IntStream; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.common.function.OperatorType.EQUAL; import static 
com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL; import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL; @@ -95,7 +95,6 @@ import static com.google.common.io.MoreFiles.deleteRecursively; import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; import static io.airlift.slice.Slices.utf8Slice; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; import static java.util.UUID.randomUUID; import static java.util.concurrent.TimeUnit.MILLISECONDS; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/ParquetTester.java b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/ParquetTester.java index 53236fab53134..acd6d69e2aa12 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/ParquetTester.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/ParquetTester.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.parquet; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.block.Block; @@ -58,7 +59,6 @@ import com.google.common.collect.Lists; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.Path; import org.apache.hadoop.hive.ql.exec.FileSinkOperator.RecordWriter; import org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe; @@ -97,6 +97,7 @@ import java.util.Properties; import java.util.Set; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.Chars.truncateToLengthAndTrimSpaces; @@ -128,7 +129,6 @@ import static com.google.common.base.Preconditions.checkState; import static 
com.google.common.collect.Iterables.transform; import static io.airlift.slice.Slices.utf8Slice; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Math.toIntExact; import static java.util.Arrays.stream; import static java.util.Collections.singletonList; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/TestParquetReaderMemoryTracking.java b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/TestParquetReaderMemoryTracking.java index e8932c3db63c2..b55c835a693f4 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/parquet/TestParquetReaderMemoryTracking.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/parquet/TestParquetReaderMemoryTracking.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.hive.parquet; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.Type; import com.facebook.presto.parquet.Field; @@ -21,7 +22,6 @@ import com.facebook.presto.parquet.reader.ParquetReader; import com.facebook.presto.parquet.writer.ParquetWriterOptions; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.apache.parquet.hadoop.metadata.CompressionCodecName; import org.apache.parquet.hadoop.metadata.ParquetMetadata; import org.apache.parquet.io.ColumnIOConverter; @@ -38,6 +38,7 @@ import java.util.concurrent.ThreadLocalRandom; import static com.facebook.airlift.testing.Assertions.assertBetweenInclusive; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.hive.parquet.ParquetTester.writeParquetFileFromPresto; import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; @@ -45,7 +46,6 @@ import static com.google.common.io.Files.createTempDir; import static com.google.common.io.MoreFiles.deleteRecursively; import static 
com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.UUID.randomUUID; import static org.apache.parquet.hadoop.metadata.CompressionCodecName.UNCOMPRESSED; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/s3/TestHiveS3Config.java b/presto-hive/src/test/java/com/facebook/presto/hive/s3/TestHiveS3Config.java index 0dc652710bc42..e042a36cd5af8 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/s3/TestHiveS3Config.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/s3/TestHiveS3Config.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.hive.s3; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; +import com.facebook.airlift.units.Duration; import com.google.common.base.StandardSystemProperty; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.io.File; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/security/ranger/TestRangerBasedAccessControl.java b/presto-hive/src/test/java/com/facebook/presto/hive/security/ranger/TestRangerBasedAccessControl.java index 8558f69cf9057..7f2604debfa4d 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/security/ranger/TestRangerBasedAccessControl.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/security/ranger/TestRangerBasedAccessControl.java @@ -19,6 +19,7 @@ import com.facebook.airlift.http.client.testing.TestingResponse; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.Subfield; +import com.facebook.presto.hive.security.SecurityConfig; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.WarningCollector; @@ -75,11 +76,33 @@ public void testDefaultAccessAllowedNotChecked() 
accessControl.checkCanShowSchemas(TRANSACTION_HANDLE, user("anyuser"), CONTEXT); } + @Test + public void testDefaultProcedureCallNotAllowed() + { + // `restrictProcedureCall` default to `true` + ConnectorAccessControl accessControl = createRangerAccessControl("default-allow-all.json", "user_groups.json"); + assertDenied(() -> accessControl.checkCanCallProcedure(TRANSACTION_HANDLE, user("admin"), CONTEXT, new SchemaTableName("system", "procedure1"))); + assertDenied(() -> accessControl.checkCanCallProcedure(TRANSACTION_HANDLE, user("bob"), CONTEXT, new SchemaTableName("system", "procedure1"))); + assertDenied(() -> accessControl.checkCanCallProcedure(TRANSACTION_HANDLE, user("alice"), CONTEXT, new SchemaTableName("system", "procedure1"))); + assertDenied(() -> accessControl.checkCanCallProcedure(TRANSACTION_HANDLE, user("anyuser"), CONTEXT, new SchemaTableName("system", "procedure1"))); + } + + @Test + public void testProcedureCallAllowedWithSpecificConfiguration() + { + ConnectorAccessControl accessControl = createRangerAccessControl("default-allow-all.json", "user_groups.json", false); + accessControl.checkCanCallProcedure(TRANSACTION_HANDLE, user("admin"), CONTEXT, new SchemaTableName("system", "procedure1")); + accessControl.checkCanCallProcedure(TRANSACTION_HANDLE, user("bob"), CONTEXT, new SchemaTableName("system", "procedure1")); + accessControl.checkCanCallProcedure(TRANSACTION_HANDLE, user("alice"), CONTEXT, new SchemaTableName("system", "procedure1")); + accessControl.checkCanCallProcedure(TRANSACTION_HANDLE, user("anyuser"), CONTEXT, new SchemaTableName("system", "procedure1")); + } + @Test public void testDefaultTableAccessIfNotDefined() { ConnectorAccessControl accessControl = createRangerAccessControl("default-allow-all.json", "user_groups.json"); accessControl.checkCanCreateTable(TRANSACTION_HANDLE, user("admin"), CONTEXT, new SchemaTableName("test", "test")); + accessControl.checkCanShowCreateTable(TRANSACTION_HANDLE, user("bob"), CONTEXT, new 
SchemaTableName("bobschema", "bobtable")); accessControl.checkCanSelectFromColumns(TRANSACTION_HANDLE, user("alice"), CONTEXT, new SchemaTableName("test", "test"), ImmutableSet.of()); accessControl.checkCanSelectFromColumns(TRANSACTION_HANDLE, user("bob"), CONTEXT, new SchemaTableName("bobschema", "bobtable"), ImmutableSet.of()); accessControl.checkCanRenameTable(TRANSACTION_HANDLE, user("admin"), CONTEXT, new SchemaTableName("test", "test"), new SchemaTableName("test1", "test1")); @@ -93,6 +116,7 @@ public void testTableOperations() { ConnectorAccessControl accessControl = createRangerAccessControl("default-schema-level-access.json", "user_groups.json"); // 'etladmin' group have all access {group - etladmin, user - alice} + accessControl.checkCanShowCreateTable(TRANSACTION_HANDLE, user("alice"), CONTEXT, new SchemaTableName("foodmart", "test")); accessControl.checkCanCreateTable(TRANSACTION_HANDLE, user("alice"), CONTEXT, new SchemaTableName("foodmart", "test")); accessControl.checkCanRenameTable(TRANSACTION_HANDLE, user("alice"), CONTEXT, new SchemaTableName("foodmart", "test"), new SchemaTableName("foodmart", "test1")); accessControl.checkCanDropTable(TRANSACTION_HANDLE, user("alice"), CONTEXT, new SchemaTableName("foodmart", "test")); @@ -111,6 +135,7 @@ public void testTableOperations() assertDenied(() -> accessControl.checkCanRenameColumn(TRANSACTION_HANDLE, user("joe"), CONTEXT, new SchemaTableName("foodmart", "test"))); // Access denied to others {group - readall, user - bob} + assertDenied(() -> accessControl.checkCanShowCreateTable(TRANSACTION_HANDLE, user("bob"), CONTEXT, new SchemaTableName("foodmart", "test"))); assertDenied(() -> accessControl.checkCanSelectFromColumns(TRANSACTION_HANDLE, user("bob"), CONTEXT, new SchemaTableName("foodmart", "test"), ImmutableSet.of(new Subfield("column1")))); assertDenied(() -> accessControl.checkCanCreateTable(TRANSACTION_HANDLE, user("bob"), CONTEXT, new SchemaTableName("foodmart", "test"))); assertDenied(() -> 
accessControl.checkCanRenameTable(TRANSACTION_HANDLE, user("bob"), CONTEXT, new SchemaTableName("foodmart", "test"), new SchemaTableName("foodmart", "test1"))); @@ -180,6 +205,11 @@ private static ConnectorIdentity user(String name) } private ConnectorAccessControl createRangerAccessControl(String policyFile, String usersFile) + { + return createRangerAccessControl(policyFile, usersFile, new SecurityConfig().isRestrictProcedureCall()); + } + + private ConnectorAccessControl createRangerAccessControl(String policyFile, String usersFile, boolean restrictProcedureCall) { String policyFilePath = "com.facebook.presto.hive.security.ranger/" + policyFile; String usersFilePath = "com.facebook.presto.hive.security.ranger/" + usersFile; @@ -212,7 +242,9 @@ else if (uriPath.contains("sam")) { RangerBasedAccessControlConfig config = new RangerBasedAccessControlConfig() .setRangerHttpEndPoint("http://test") .setRangerHiveServiceName("dummy"); - RangerBasedAccessControl rangerBasedAccessControl = new RangerBasedAccessControl(config, httpClient); + RangerBasedAccessControl rangerBasedAccessControl = new RangerBasedAccessControl(config, + new SecurityConfig().setRestrictProcedureCall(restrictProcedureCall), + httpClient); return rangerBasedAccessControl; } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/security/ranger/TestRangerBasedAccessControlConfig.java b/presto-hive/src/test/java/com/facebook/presto/hive/security/ranger/TestRangerBasedAccessControlConfig.java index c1903b146a8a0..9ef96f279cadc 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/security/ranger/TestRangerBasedAccessControlConfig.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/security/ranger/TestRangerBasedAccessControlConfig.java @@ -15,9 +15,9 @@ import com.facebook.airlift.configuration.ConfigurationFactory; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; 
import com.google.inject.ConfigurationException; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; @@ -87,7 +87,7 @@ public void testValidation() RANGER_REST_USER_GROUPS_AUTH_USERNAME, "admin", RANGER_REST_USER_GROUPS_AUTH_PASSWORD, "admin"))) .isInstanceOf(ConfigurationException.class) - .hasMessageContaining("Invalid configuration property hive.ranger.rest-endpoint: may not be null"); + .hasMessageContaining("Invalid configuration property hive.ranger.rest-endpoint: must not be null"); assertThatThrownBy(() -> newInstance(ImmutableMap.of( RANGER_POLICY_REFRESH_PERIOD, "120s", @@ -95,7 +95,7 @@ public void testValidation() RANGER_REST_USER_GROUPS_AUTH_USERNAME, "admin", RANGER_REST_USER_GROUPS_AUTH_PASSWORD, "admin"))) .isInstanceOf(ConfigurationException.class) - .hasMessageContaining("Invalid configuration property hive.ranger.policy.hive-servicename: may not be null"); + .hasMessageContaining("Invalid configuration property hive.ranger.policy.hive-servicename: must not be null"); } private static RangerBasedAccessControlConfig newInstance(Map properties) diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestParquetQuickStatsBuilder.java b/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestParquetQuickStatsBuilder.java index c8cb678cfbcd1..5976c61a04f8a 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestParquetQuickStatsBuilder.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestParquetQuickStatsBuilder.java @@ -14,6 +14,7 @@ package com.facebook.presto.hive.statistics; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.FileFormatDataSourceStats; import com.facebook.presto.hive.HdfsConfigurationInitializer; import com.facebook.presto.hive.HdfsContext; @@ -40,7 +41,6 @@ import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; 
-import io.airlift.units.Duration; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.LocatedFileStatus; diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestQuickStatsProvider.java b/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestQuickStatsProvider.java index b22d980cb042f..15feedd158ea1 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestQuickStatsProvider.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/statistics/TestQuickStatsProvider.java @@ -13,12 +13,18 @@ */ package com.facebook.presto.hive.statistics; +import com.facebook.airlift.units.Duration; import com.facebook.presto.hive.DirectoryLister; +import com.facebook.presto.hive.HadoopDirectoryLister; +import com.facebook.presto.hive.HdfsContext; import com.facebook.presto.hive.HdfsEnvironment; import com.facebook.presto.hive.HiveClientConfig; +import com.facebook.presto.hive.HiveDirectoryContext; +import com.facebook.presto.hive.HiveFileInfo; import com.facebook.presto.hive.MetastoreClientConfig; import com.facebook.presto.hive.NamenodeStats; import com.facebook.presto.hive.TestingExtendedHiveMetastore; +import com.facebook.presto.hive.filesystem.ExtendedFileSystem; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.MetastoreContext; import com.facebook.presto.hive.metastore.Partition; @@ -26,6 +32,7 @@ import com.facebook.presto.hive.metastore.PartitionWithStatistics; import com.facebook.presto.hive.metastore.PrincipalPrivileges; import com.facebook.presto.hive.metastore.Storage; +import com.facebook.presto.hive.metastore.StorageFormat; import com.facebook.presto.hive.metastore.Table; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; @@ -35,12 +42,21 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import 
com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat; +import org.apache.hadoop.hive.ql.io.SymlinkTextInputFormat; +import org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe; +import org.apache.hadoop.mapred.InputFormat; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.InputStream; +import java.net.URL; import java.time.Instant; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -50,19 +66,37 @@ import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.hive.HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER; +import static com.facebook.presto.hive.HivePartition.UNPARTITIONED_ID; import static com.facebook.presto.hive.HiveSessionProperties.QUICK_STATS_BACKGROUND_BUILD_TIMEOUT; import static com.facebook.presto.hive.HiveSessionProperties.QUICK_STATS_ENABLED; import static com.facebook.presto.hive.HiveSessionProperties.QUICK_STATS_INLINE_BUILD_TIMEOUT; import static com.facebook.presto.hive.HiveSessionProperties.SKIP_EMPTY_FILES; import static com.facebook.presto.hive.HiveSessionProperties.USE_LIST_DIRECTORY_CACHE; +import static com.facebook.presto.hive.HiveSessionProperties.isSkipEmptyFilesEnabled; +import static com.facebook.presto.hive.HiveSessionProperties.isUseListDirectoryCache; import static com.facebook.presto.hive.HiveStorageFormat.PARQUET; import static com.facebook.presto.hive.HiveTestUtils.createTestHdfsEnvironment; +import static com.facebook.presto.hive.HiveUtil.buildDirectoryContextProperties; +import static com.facebook.presto.hive.HiveUtil.getInputFormat; +import static com.facebook.presto.hive.NestedDirectoryPolicy.RECURSE; import static com.facebook.presto.hive.RetryDriver.retry; 
import static com.facebook.presto.hive.metastore.PartitionStatistics.empty; +import static com.facebook.presto.hive.metastore.PrestoTableType.EXTERNAL_TABLE; import static com.facebook.presto.hive.metastore.PrestoTableType.MANAGED_TABLE; import static com.facebook.presto.hive.metastore.StorageFormat.fromHiveStorageFormat; import static com.facebook.presto.hive.statistics.PartitionQuickStats.convertToPartitionStatistics; import static com.facebook.presto.spi.session.PropertyMetadata.booleanProperty; +import static com.google.common.io.MoreFiles.deleteRecursively; +import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; +import static com.google.common.io.Resources.getResource; +import static java.nio.file.Files.copy; +import static java.nio.file.Files.createDirectory; +import static java.nio.file.Files.createFile; +import static java.nio.file.Files.createTempDirectory; +import static java.nio.file.Files.newBufferedWriter; +import static java.nio.file.StandardCopyOption.REPLACE_EXISTING; +import static java.nio.file.StandardOpenOption.CREATE; +import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; import static java.util.Collections.emptyIterator; import static java.util.concurrent.CompletableFuture.allOf; import static java.util.concurrent.CompletableFuture.supplyAsync; @@ -71,6 +105,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.testcontainers.shaded.com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @@ -119,6 +154,7 @@ public class TestQuickStatsProvider private MetastoreContext metastoreContext; private PartitionQuickStats mockPartitionQuickStats; private PartitionStatistics expectedPartitionStats; + private ColumnQuickStats mockIntegerColumnStats; private static ConnectorSession getSession(String 
inlineBuildTimeout, String backgroundBuildTimeout) { @@ -186,7 +222,7 @@ public void setUp() MetastoreClientConfig metastoreClientConfig = new MetastoreClientConfig(); hdfsEnvironment = createTestHdfsEnvironment(hiveClientConfig, metastoreClientConfig); - ColumnQuickStats mockIntegerColumnStats = new ColumnQuickStats<>("column", Integer.class); + mockIntegerColumnStats = new ColumnQuickStats<>("column", Integer.class); mockIntegerColumnStats.setMinValue(Integer.MIN_VALUE); mockIntegerColumnStats.setMaxValue(Integer.MAX_VALUE); mockIntegerColumnStats.addToRowCount(4242L); @@ -396,4 +432,110 @@ public void quickStatsBuildTimeIsBounded() }); } } + + @Test + public void testFollowSymlinkFile() + throws IOException + { + java.nio.file.Path testTempDir = createTempDirectory("test"); + java.nio.file.Path symlinkFileDir = testTempDir.resolve("symlink"); + java.nio.file.Path tableDir1 = testTempDir.resolve("table_1"); + java.nio.file.Path tableDir2 = testTempDir.resolve("table_2"); + createDirectory(symlinkFileDir); + createDirectory(tableDir1); + createDirectory(tableDir2); + + // Copy a parquet file from resources to the test/table temp dir + String fileName1 = "data_1.parquet"; + String fileName2 = "data_2.parquet"; + URL resourceUrl1 = getResource("quick_stats/tpcds_store_sales_sf_point_01/20230706_110621_00007_4uxkh_10e94cd0-1f67-4440-afd0-75cd328ea570"); + URL resourceUrl2 = getResource("quick_stats/tpcds_store_sales_sf_point_01/20230706_110621_00007_4uxkh_12b3ec73-4952-4df7-9987-2beb20cd5953"); + assertNotNull(resourceUrl1); + assertNotNull(resourceUrl2); + + java.nio.file.Path targetFilePath1 = tableDir1.resolve(fileName1); + java.nio.file.Path targetFilePath2 = tableDir2.resolve(fileName2); + + try (InputStream in = resourceUrl1.openStream()) { + copy(in, targetFilePath1, REPLACE_EXISTING); + } + try (InputStream in = resourceUrl2.openStream()) { + copy(in, targetFilePath2, REPLACE_EXISTING); + } + + // Create the symlink manifest pointing to data.parquet + 
java.nio.file.Path manifestFilePath = createFile(symlinkFileDir.resolve("manifest")); + try (BufferedWriter writer = newBufferedWriter(manifestFilePath, CREATE, TRUNCATE_EXISTING)) { + writer.write("file:" + tableDir1 + "/" + fileName1); + writer.newLine(); + writer.write("file:" + tableDir2 + "/" + fileName2); + } + + String symlinkTableName = "symlink_table"; + Table symlinkTable = new Table( + Optional.of("catalogName"), + TEST_SCHEMA, + symlinkTableName, + "owner", + EXTERNAL_TABLE, + Storage.builder() + .setStorageFormat(StorageFormat.create(ParquetHiveSerDe.class.getName(), SymlinkTextInputFormat.class.getName(), HiveIgnoreKeyTextOutputFormat.class.getName())) + .setLocation(symlinkFileDir.toString()) + .build(), + ImmutableList.of(), + ImmutableList.of(), + ImmutableMap.of(), + Optional.empty(), + Optional.empty()); + + metastoreMock.createTable(metastoreContext, symlinkTable, new PrincipalPrivileges(ImmutableMultimap.of(), ImmutableMultimap.of()), ImmutableList.of()); + + DirectoryLister directoryLister = new HadoopDirectoryLister(); + + QuickStatsBuilder quickStatsBuilder = (session1, metastore, table, metastoreContext, partitionId, files) -> { + List fileInfoList = ImmutableList.copyOf(files); + assertEquals(fileInfoList.size(), 2); + for (HiveFileInfo hiveFileInfo : fileInfoList) { + assertTrue(hiveFileInfo.getPath().equals("file:" + targetFilePath1) || hiveFileInfo.getPath().equals("file:" + targetFilePath2)); + } + return new PartitionQuickStats(UNPARTITIONED_ID.getPartitionName(), ImmutableList.of(mockIntegerColumnStats), fileInfoList.size()); + }; + + QuickStatsProvider quickStatsProvider = new QuickStatsProvider(metastoreMock, hdfsEnvironment, directoryLister, hiveClientConfig, new NamenodeStats(), + ImmutableList.of(quickStatsBuilder)); + + SchemaTableName table = new SchemaTableName(TEST_SCHEMA, symlinkTableName); + + Table resolvedTable = metastoreMock.getTable(metastoreContext, table.getSchemaName(), table.getTableName()).get(); + Path 
symlinkTablePath = new Path(resolvedTable.getStorage().getLocation()); + HdfsContext hdfsContext = new HdfsContext(SESSION, table.getSchemaName(), table.getTableName(), UNPARTITIONED_ID.getPartitionName(), false); + ExtendedFileSystem fs = hdfsEnvironment.getFileSystem(hdfsContext, symlinkTablePath); + HiveDirectoryContext hiveDirectoryContext = new HiveDirectoryContext(RECURSE, isUseListDirectoryCache(SESSION), + isSkipEmptyFilesEnabled(SESSION), hdfsContext.getIdentity(), buildDirectoryContextProperties(SESSION), SESSION.getRuntimeStats()); + + // Test directoryLister finds the manifest file in the table dir + Iterator fileInfoIterator = directoryLister.list(fs, resolvedTable, symlinkTablePath, Optional.empty(), new NamenodeStats(), hiveDirectoryContext); + ImmutableList fileInfoList = ImmutableList.copyOf(fileInfoIterator); + + assertEquals(fileInfoList.size(), 1); + assertEquals(fileInfoList.get(0).getPath(), "file:" + manifestFilePath); + assertEquals(fileInfoList.get(0).getParent(), "file:" + symlinkFileDir); + + // Test that the input format is correct + InputFormat inputFormat = getInputFormat( + hdfsEnvironment.getConfiguration(hdfsContext, symlinkTablePath), + resolvedTable.getStorage().getStorageFormat().getInputFormat(), + resolvedTable.getStorage().getStorageFormat().getSerDe(), + false); + + assertTrue(inputFormat instanceof SymlinkTextInputFormat); + + // Test entire getQuickStats and ensure file count matches fileList size in buildQuickStats + PartitionStatistics quickStats = quickStatsProvider.getQuickStats(SESSION, table, metastoreContext, UNPARTITIONED_ID.getPartitionName()); + + assertTrue(quickStats.getBasicStatistics().getFileCount().isPresent()); + assertEquals(quickStats.getBasicStatistics().getFileCount().getAsLong(), 2L); + + deleteRecursively(testTempDir, ALLOW_INSECURE); + } } diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSizeBasedSplitWeightProvider.java 
b/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSizeBasedSplitWeightProvider.java index 955c4ad204595..5bbe25c1f68b5 100644 --- a/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSizeBasedSplitWeightProvider.java +++ b/presto-hive/src/test/java/com/facebook/presto/hive/util/TestSizeBasedSplitWeightProvider.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.hive.util; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.SplitWeight; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; diff --git a/presto-hive/src/test/resources/com.facebook.presto.hive.security.ranger/default-allow-all.json b/presto-hive/src/test/resources/com.facebook.presto.hive.security.ranger/default-allow-all.json index 53e876655d20e..b68e97b531ddd 100644 --- a/presto-hive/src/test/resources/com.facebook.presto.hive.security.ranger/default-allow-all.json +++ b/presto-hive/src/test/resources/com.facebook.presto.hive.security.ranger/default-allow-all.json @@ -94,7 +94,8 @@ "impliedGrants": [], "itemId": 1, "label": "select", - "name": "select" + "name": "select", + "category": "data" }, { "impliedGrants": [], diff --git a/presto-hive/src/test/resources/com/facebook/presto/hive/security.json b/presto-hive/src/test/resources/com/facebook/presto/hive/security.json index 58bede69077a8..787024cbf4026 100644 --- a/presto-hive/src/test/resources/com/facebook/presto/hive/security.json +++ b/presto-hive/src/test/resources/com/facebook/presto/hive/security.json @@ -12,6 +12,31 @@ "user": "hive", "owner": true } + ], + "sessionProperties": [ + { + "user": ".*", + "property": ".*", + "allow": true + } + ], + "procedures": [ + { + "user": "hive", + "schema": ".*", + "privileges": ["EXECUTE"] + }, + { + "user": "alice", + "schema": "system", + "privileges": ["EXECUTE"] + }, + { + "user": "bob", + "schema": "system", + "procedure": "invalidate_directory_list_cache", + "privileges": ["EXECUTE"] + } 
] } diff --git a/presto-hive/src/test/sql/create-test.sql b/presto-hive/src/test/sql/create-test.sql index 96bf9d635d3b6..dc4f04bd45fc1 100644 --- a/presto-hive/src/test/sql/create-test.sql +++ b/presto-hive/src/test/sql/create-test.sql @@ -67,7 +67,7 @@ COMMENT 'Presto test bucketed table' PARTITIONED BY (ds STRING) CLUSTERED BY (t_string, t_int) INTO 32 BUCKETS STORED AS RCFILE -TBLPROPERTIES ('RETENTION'='-1') +TBLPROPERTIES ('bucketing_version'='1') ; CREATE TABLE presto_test_bucketed_by_bigint_boolean ( @@ -84,7 +84,7 @@ COMMENT 'Presto test bucketed table' PARTITIONED BY (ds STRING) CLUSTERED BY (t_bigint, t_boolean) INTO 32 BUCKETS STORED AS RCFILE -TBLPROPERTIES ('RETENTION'='-1') +TBLPROPERTIES ('bucketing_version'='1') ; CREATE TABLE presto_test_bucketed_by_double_float ( diff --git a/presto-hudi/pom.xml b/presto-hudi/pom.xml index f7dd5d40df8de..180511666f58d 100644 --- a/presto-hudi/pom.xml +++ b/presto-hudi/pom.xml @@ -4,16 +4,29 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT presto-hudi + presto-hudi Presto - Hudi Connector presto-plugin ${project.parent.basedir} + 17 + true + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + com.google.guava @@ -38,8 +51,8 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -119,7 +132,7 @@ com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache @@ -128,13 +141,15 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api org.apache.hudi hudi-presto-bundle + + 1.0.2 @@ -215,7 +230,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -233,7 +248,7 @@ - io.airlift + com.facebook.airlift units provided diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiColumnHandle.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiColumnHandle.java index 4d29649c7be28..4e5a1b6154392 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiColumnHandle.java +++ 
b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiColumnHandle.java @@ -123,6 +123,15 @@ public ColumnMetadata toColumnMetadata(TypeManager typeManager) .build(); } + public ColumnMetadata toColumnMetadata(TypeManager typeManager, String name) + { + return ColumnMetadata.builder() + .setName(name) + .setType(hiveType.getType(typeManager)) + .setExtraInfo(getExtraInfo().orElse(null)) + .build(); + } + @Override public String toString() { diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiConfig.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiConfig.java index c6a35943a3a5e..98e5154a108c3 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiConfig.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiConfig.java @@ -16,14 +16,13 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize; +import jakarta.validation.constraints.DecimalMax; +import jakarta.validation.constraints.DecimalMin; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; -import javax.validation.constraints.DecimalMax; -import javax.validation.constraints.DecimalMin; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; - -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class HudiConfig { diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiMetadata.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiMetadata.java index dbbedd4feff9c..24a04a82c44e1 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiMetadata.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiMetadata.java @@ -115,7 +115,11 @@ public Optional getSystemTable(ConnectorSession session, SchemaTabl } @Override - public List 
getTableLayouts(ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle tableHandle, + Constraint constraint, + Optional> desiredColumns) { HudiTableHandle handle = (HudiTableHandle) tableHandle; Table table = getTable(session, tableHandle); @@ -127,7 +131,7 @@ public List getTableLayouts(ConnectorSession session partitionColumns, table.getParameters(), constraint.getSummary())); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -170,7 +174,7 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable @Override public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) { - List tables = prefix.getTableName() != null ? singletonList(prefix.toSchemaTableName()) : listTables(session, Optional.of(prefix.getSchemaName())); + List tables = prefix.getTableName() != null ? 
singletonList(prefix.toSchemaTableName()) : listTables(session, Optional.ofNullable(prefix.getSchemaName())); ImmutableMap.Builder> columns = ImmutableMap.builder(); for (SchemaTableName table : tables) { @@ -206,7 +210,7 @@ private ConnectorTableMetadata getTableMetadata(ConnectorSession session, Schema tableName.getTableName()).orElseThrow(() -> new TableNotFoundException(tableName)); List columnMetadatas = allColumnHandles(table) - .map(columnHandle -> columnHandle.toColumnMetadata(typeManager)) + .map(columnHandle -> columnHandle.toColumnMetadata(typeManager, normalizeIdentifier(session, columnHandle.getName()))) .collect(toList()); return new ConnectorTableMetadata(tableName, columnMetadatas); } diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiMetadataFactory.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiMetadataFactory.java index 715337abe4964..e7536568a26fb 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiMetadataFactory.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiMetadataFactory.java @@ -20,8 +20,7 @@ import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore; import com.facebook.presto.spi.connector.ConnectorMetadata; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiModule.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiModule.java index 84ed1d1b589e7..06271332a5974 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiModule.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiModule.java @@ -61,9 +61,9 @@ import com.google.inject.Provides; import com.google.inject.Scopes; import com.google.inject.multibindings.Multibinder; +import jakarta.inject.Singleton; import org.weakref.jmx.testing.TestingMBeanServer; -import javax.inject.Singleton; 
import javax.management.MBeanServer; import java.util.concurrent.ExecutorService; diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiPageSourceProvider.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiPageSourceProvider.java index a67174d8419b3..d1212ab20194f 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiPageSourceProvider.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiPageSourceProvider.java @@ -33,11 +33,10 @@ import com.facebook.presto.spi.SplitContext; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; - import java.time.ZoneId; import java.util.List; import java.util.Optional; diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiPartitionManager.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiPartitionManager.java index b7f5b6ff3127a..ba80fd2f516c7 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiPartitionManager.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiPartitionManager.java @@ -30,8 +30,7 @@ import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.time.ZoneId; import java.util.HashMap; diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiRecordCursors.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiRecordCursors.java index dd73e684fc370..fbab341ade696 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiRecordCursors.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiRecordCursors.java @@ -55,6 +55,7 @@ import static org.apache.hadoop.hive.serde2.ColumnProjectionUtils.READ_ALL_COLUMNS; 
import static org.apache.hadoop.hive.serde2.ColumnProjectionUtils.READ_COLUMN_IDS_CONF_STR; import static org.apache.hadoop.hive.serde2.ColumnProjectionUtils.READ_COLUMN_NAMES_CONF_STR; +import static org.apache.hudi.common.config.HoodieReaderConfig.FILE_GROUP_READER_ENABLED; class HudiRecordCursors { @@ -105,6 +106,7 @@ public static RecordCursor createRealtimeRecordCursor( jobConf.setBoolean(READ_ALL_COLUMNS, false); jobConf.set(READ_COLUMN_IDS_CONF_STR, join(dataColumns, HudiColumnHandle::getId)); jobConf.set(READ_COLUMN_NAMES_CONF_STR, join(dataColumns, HudiColumnHandle::getName)); + jobConf.setBoolean(FILE_GROUP_READER_ENABLED.key(), false); schema.stringPropertyNames() .forEach(name -> jobConf.set(name, schema.getProperty(name))); refineCompressionCodecs(jobConf); diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiSessionProperties.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiSessionProperties.java index 74ee5db1a6ec3..b559e55c1cee9 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiSessionProperties.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiSessionProperties.java @@ -14,13 +14,12 @@ package com.facebook.presto.hudi; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiSplitManager.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiSplitManager.java index 06d44c75784f2..c6e3cae63915f 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiSplitManager.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/HudiSplitManager.java @@ -35,6 +35,7 @@ import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Streams; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hudi.common.config.HoodieMetadataConfig; @@ -44,8 +45,7 @@ import org.apache.hudi.common.table.timeline.HoodieTimeline; import org.apache.hudi.common.table.view.HoodieTableFileSystemView; import org.apache.hudi.common.util.HoodieTimer; - -import javax.inject.Inject; +import org.apache.hudi.storage.StorageConfiguration; import java.io.IOException; import java.util.List; @@ -63,6 +63,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static org.apache.hudi.common.table.view.FileSystemViewManager.createInMemoryFileSystemViewWithTimeline; +import static org.apache.hudi.hadoop.fs.HadoopFSUtils.getStorageConfWithCopy; public class HudiSplitManager implements ConnectorSplitManager @@ -105,7 +106,7 @@ public ConnectorSplitSource getSplits( HudiTableHandle table = layout.getTable(); // Retrieve and prune partitions - HoodieTimer timer = new HoodieTimer().startTimer(); + HoodieTimer timer = HoodieTimer.start(); List partitions = hudiPartitionManager.getEffectivePartitions(session, metastore, table.getSchemaTableName(), layout.getTupleDomain()); log.debug("Took %d ms to get %d partitions", timer.endTimer(), partitions.size()); if (partitions.isEmpty()) { @@ -115,10 +116,10 @@ public ConnectorSplitSource getSplits( // Load Hudi metadata ExtendedFileSystem fs = getFileSystem(session, table); HoodieMetadataConfig metadataConfig = HoodieMetadataConfig.newBuilder().enable(isHudiMetadataTableEnabled(session)).build(); - Configuration conf = fs.getConf(); + StorageConfiguration conf = getStorageConfWithCopy(fs.getConf()); HoodieTableMetaClient metaClient = HoodieTableMetaClient.builder().setConf(conf).setBasePath(table.getPath()).build(); HoodieTimeline timeline = 
metaClient.getActiveTimeline().getCommitsTimeline().filterCompletedInstants(); - String timestamp = timeline.lastInstant().map(HoodieInstant::getTimestamp).orElse(null); + String timestamp = timeline.lastInstant().map(HoodieInstant::requestedTime).orElse(null); if (timestamp == null) { // no completed instant for current table return new FixedSplitSource(ImmutableList.of()); diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiBackgroundSplitLoader.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiBackgroundSplitLoader.java index 2b165462fdafc..a184f49128fde 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiBackgroundSplitLoader.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiBackgroundSplitLoader.java @@ -14,7 +14,7 @@ package com.facebook.presto.hudi.split; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiSplitAsyncQueue.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiSplitAsyncQueue.java index 52d0da1af9b2f..52efd9d7a7597 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiSplitAsyncQueue.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiSplitAsyncQueue.java @@ -14,7 +14,7 @@ package com.facebook.presto.hudi.split; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiSplitSource.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiSplitSource.java index 6b7333e3d74ee..f5b2eacc99c4f 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiSplitSource.java +++ 
b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/ForHudiSplitSource.java @@ -14,7 +14,7 @@ package com.facebook.presto.hudi.split; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/HudiPartitionSplitGenerator.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/HudiPartitionSplitGenerator.java index 93fb0896fada1..21ca4f7c36f16 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/HudiPartitionSplitGenerator.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/HudiPartitionSplitGenerator.java @@ -15,6 +15,7 @@ package com.facebook.presto.hudi.split; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.MetastoreContext; import com.facebook.presto.hive.util.AsyncQueue; @@ -28,9 +29,7 @@ import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.Path; -import org.apache.hudi.common.fs.FSUtils; import org.apache.hudi.common.model.FileSlice; import org.apache.hudi.common.table.view.HoodieTableFileSystemView; import org.apache.hudi.common.util.HoodieTimer; @@ -47,6 +46,8 @@ import static com.facebook.presto.hudi.HudiSplitManager.getHudiPartition; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; +import static org.apache.hudi.common.fs.FSUtils.getRelativePartitionPath; +import static org.apache.hudi.hadoop.fs.HadoopFSUtils.convertToStoragePath; /** * A runnable to take partition names from a queue of partitions to process, @@ -93,7 +94,7 @@ public HudiPartitionSplitGenerator( @Override public void 
run() { - HoodieTimer timer = new HoodieTimer().startTimer(); + HoodieTimer timer = HoodieTimer.start(); while (!concurrentPartitionQueue.isEmpty()) { String partitionName = concurrentPartitionQueue.poll(); if (partitionName != null) { @@ -107,7 +108,7 @@ private void generateSplitsFromPartition(String partitionName) { HudiPartition hudiPartition = getHudiPartition(metastore, metastoreContext, layout, partitionName); Path partitionPath = new Path(hudiPartition.getStorage().getLocation()); - String relativePartitionPath = FSUtils.getRelativePartitionPath(tablePath, partitionPath); + String relativePartitionPath = getRelativePartitionPath(convertToStoragePath(tablePath), convertToStoragePath(partitionPath)); Stream fileSlices = HudiTableType.MOR.equals(table.getTableType()) ? fsView.getLatestMergedFileSlicesBeforeOrOn(relativePartitionPath, latestInstant) : fsView.getLatestFileSlicesBeforeOrOn(relativePartitionPath, latestInstant, false); diff --git a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/SizeBasedSplitWeightProvider.java b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/SizeBasedSplitWeightProvider.java index 9c41dc68114ae..55cac196f9181 100644 --- a/presto-hudi/src/main/java/com/facebook/presto/hudi/split/SizeBasedSplitWeightProvider.java +++ b/presto-hudi/src/main/java/com/facebook/presto/hudi/split/SizeBasedSplitWeightProvider.java @@ -14,8 +14,8 @@ package com.facebook.presto.hudi.split; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.SplitWeight; -import io.airlift.units.DataSize; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.primitives.Doubles.constrainToRange; diff --git a/presto-hudi/src/test/java/com/facebook/presto/hudi/TestHudiConfig.java b/presto-hudi/src/test/java/com/facebook/presto/hudi/TestHudiConfig.java index d2051474c0353..163490005e60a 100644 --- a/presto-hudi/src/test/java/com/facebook/presto/hudi/TestHudiConfig.java +++ 
b/presto-hudi/src/test/java/com/facebook/presto/hudi/TestHudiConfig.java @@ -14,8 +14,8 @@ package com.facebook.presto.hudi; +import com.facebook.airlift.units.DataSize; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Map; @@ -23,7 +23,7 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class TestHudiConfig { diff --git a/presto-hudi/src/test/java/com/facebook/presto/hudi/TestingTypeManager.java b/presto-hudi/src/test/java/com/facebook/presto/hudi/TestingTypeManager.java index 8b9f9fb7a747a..c4a0658622fc6 100644 --- a/presto-hudi/src/test/java/com/facebook/presto/hudi/TestingTypeManager.java +++ b/presto-hudi/src/test/java/com/facebook/presto/hudi/TestingTypeManager.java @@ -63,4 +63,10 @@ public List getTypes() { return ImmutableList.of(BOOLEAN, INTEGER, BIGINT, DOUBLE, VARCHAR, VARBINARY, TIMESTAMP, DATE, HYPER_LOG_LOG); } + + @Override + public boolean hasType(TypeSignature signature) + { + return getType(signature) != null; + } } diff --git a/presto-hudi/src/test/java/com/facebook/presto/hudi/split/TestSizeBasedSplitWeightProvider.java b/presto-hudi/src/test/java/com/facebook/presto/hudi/split/TestSizeBasedSplitWeightProvider.java index a2ffa48b1abb2..f81bc2132b5c2 100644 --- a/presto-hudi/src/test/java/com/facebook/presto/hudi/split/TestSizeBasedSplitWeightProvider.java +++ b/presto-hudi/src/test/java/com/facebook/presto/hudi/split/TestSizeBasedSplitWeightProvider.java @@ -14,11 +14,11 @@ package com.facebook.presto.hudi.split; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.SplitWeight; 
-import io.airlift.units.DataSize; import org.testng.annotations.Test; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static org.testng.Assert.assertEquals; public class TestSizeBasedSplitWeightProvider diff --git a/presto-i18n-functions/pom.xml b/presto-i18n-functions/pom.xml index 405a83ada0775..1b9a7c429f32b 100644 --- a/presto-i18n-functions/pom.xml +++ b/presto-i18n-functions/pom.xml @@ -4,15 +4,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-i18n-functions + presto-i18n-functions Internationalization functions for Presto presto-plugin ${project.parent.basedir} + true @@ -40,7 +42,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -52,7 +54,7 @@ - io.airlift + com.facebook.airlift units provided @@ -88,6 +90,12 @@ test-jar test + + + com.facebook.presto + presto-main-tests + test + diff --git a/presto-i18n-functions/src/test/java/com/facebook/presto/i18n/functions/TestMyanmarFunctions.java b/presto-i18n-functions/src/test/java/com/facebook/presto/i18n/functions/TestMyanmarFunctions.java index 1f5bc00908408..f0259c46f8c5c 100644 --- a/presto-i18n-functions/src/test/java/com/facebook/presto/i18n/functions/TestMyanmarFunctions.java +++ b/presto-i18n-functions/src/test/java/com/facebook/presto/i18n/functions/TestMyanmarFunctions.java @@ -14,15 +14,23 @@ package com.facebook.presto.i18n.functions; import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.FunctionExtractor.extractFunctions; public class TestMyanmarFunctions extends AbstractTestFunctions { + 
public TestMyanmarFunctions() + { + super(TEST_SESSION, new FeaturesConfig(), new FunctionsConfig(), false); + } + @BeforeClass public void setUp() { diff --git a/presto-iceberg/pom.xml b/presto-iceberg/pom.xml index 7ada5aed5695a..6549854722328 100644 --- a/presto-iceberg/pom.xml +++ b/presto-iceberg/pom.xml @@ -4,19 +4,33 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-iceberg + presto-iceberg Presto - Iceberg Connector presto-plugin ${project.parent.basedir} - 1.5.0 - 0.77.1 + 17 + 1.10.0 + 0.103.0 + true + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + + com.facebook.airlift concurrent @@ -74,6 +88,22 @@ + + org.apache.parquet + parquet-hadoop + ${dep.parquet.version} + + + org.apache.yetus + audience-annotations + + + org.apache.hadoop + hadoop-client + + + + com.facebook.presto presto-expressions @@ -132,7 +162,7 @@ com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache @@ -178,7 +208,7 @@ - io.airlift + com.facebook.airlift units provided @@ -189,7 +219,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -207,11 +237,16 @@ - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true + + jakarta.annotation + jakarta.annotation-api + + com.google.guava guava @@ -235,8 +270,13 @@ - javax.validation - validation-api + jakarta.inject + jakarta.inject-api + + + + jakarta.validation + jakarta.validation-api @@ -263,16 +303,6 @@ - - com.amazonaws - aws-java-sdk-core - - - - com.amazonaws - aws-java-sdk-s3 - - org.apache.iceberg iceberg-core @@ -333,6 +363,55 @@ + + org.apache.iceberg + iceberg-aws + ${dep.iceberg.version} + runtime + + + + software.amazon.awssdk + s3 + runtime + + + + software.amazon.awssdk + regions + runtime + + + + software.amazon.awssdk + sdk-core + runtime + + + + software.amazon.awssdk + auth + runtime + + + + software.amazon.awssdk + aws-core + runtime + + + + software.amazon.awssdk + sts + runtime + + + + software.amazon.awssdk + kms + 
runtime + + org.apache.iceberg iceberg-parquet @@ -367,8 +446,8 @@ jackson-databind - javax.ws.rs - javax.ws.rs-api + jakarta.ws.rs + jakarta.ws.rs-api @@ -395,8 +474,8 @@ jackson-databind - javax.ws.rs - javax.ws.rs-api + jakarta.ws.rs + jakarta.ws.rs-api @@ -405,6 +484,7 @@ org.projectnessie.nessie nessie-client ${dep.nessie.version} + runtime org.slf4j @@ -423,8 +503,8 @@ jackson-databind - javax.ws.rs - javax.ws.rs-api + jakarta.ws.rs + jakarta.ws.rs-api @@ -501,20 +581,37 @@ jjwt-jackson runtime + + com.facebook.airlift + node + + org.openjdk.jmh jmh-core test + + com.facebook.airlift + testing + test + + org.openjdk.jmh jmh-generator-annprocess test + + jakarta.servlet + jakarta.servlet-api + test + + com.facebook.presto presto-cache @@ -525,6 +622,11 @@ presto-main-base test + + com.facebook.presto + presto-main-tests + test + com.facebook.presto presto-main @@ -548,6 +650,25 @@ test + + com.facebook.presto + presto-common + test-jar + test + + + + com.facebook.airlift + http-server + test + + + + org.apache.httpcomponents.core5 + httpcore5 + 5.3.1 + + com.facebook.presto presto-hive @@ -596,6 +717,7 @@ testng test + com.facebook.presto presto-testing-docker @@ -614,15 +736,15 @@ - org.testcontainers - testcontainers - test - - - org.slf4j - slf4j-api - - + org.testcontainers + testcontainers + test + + + org.slf4j + slf4j-api + + @@ -661,28 +783,38 @@ org.apache.commons commons-math3 + test + + + org.apache.maven.plugins + maven-dependency-plugin + + + com.facebook.airlift:node + org.projectnessie.nessie:nessie-model + org.apache.httpcomponents.core5:httpcore5 + com.amazonaws:aws-java-sdk-s3 + com.amazonaws:aws-java-sdk-core + + + + + + + + + + + + + + - - org.apache.maven.plugins - maven-dependency-plugin - - - - org.glassfish.jersey.core:jersey-common:jar - org.eclipse.jetty:jetty-server:jar - - com.facebook.airlift:http-server:jar - com.facebook.airlift:node:jar - javax.servlet:javax.servlet-api:jar - org.apache.httpcomponents.core5:httpcore5:jar - - - 
org.basepom.maven duplicate-finder-maven-plugin diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/FilesTable.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/FilesTable.java index b2d5a1293c819..0510602fa9600 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/FilesTable.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/FilesTable.java @@ -14,6 +14,7 @@ package com.facebook.presto.iceberg; import com.facebook.presto.common.Page; +import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.StandardTypes; @@ -117,13 +118,14 @@ public ConnectorTableMetadata getTableMetadata() @Override public ConnectorPageSource pageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, TupleDomain constraint) { - return new FixedPageSource(buildPages(tableMetadata, icebergTable, snapshotId)); + return new FixedPageSource(buildPages(tableMetadata, icebergTable, snapshotId, session)); } - private static List buildPages(ConnectorTableMetadata tableMetadata, Table icebergTable, Optional snapshotId) + private static List buildPages(ConnectorTableMetadata tableMetadata, Table icebergTable, Optional snapshotId, ConnectorSession session) { PageListBuilder pagesBuilder = forTable(tableMetadata); - TableScan tableScan = getTableScan(TupleDomain.all(), snapshotId, icebergTable).includeColumnStats(); + RuntimeStats runtimeStats = session.getRuntimeStats(); + TableScan tableScan = getTableScan(TupleDomain.all(), snapshotId, icebergTable, runtimeStats).includeColumnStats(); Map idToTypeMap = getIdToTypeMap(icebergTable.schema()); try (CloseableIterable fileScanTasks = tableScan.planFiles()) { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java 
index 688a57cff011e..f9023ee9da7b2 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java @@ -13,9 +13,11 @@ */ package com.facebook.presto.iceberg; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.airlift.log.Logger; import com.facebook.presto.hive.HdfsContext; import com.facebook.presto.hive.HdfsEnvironment; +import com.facebook.presto.hive.UnknownTableTypeException; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.HivePrivilegeInfo; import com.facebook.presto.hive.metastore.MetastoreContext; @@ -28,11 +30,15 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableNotFoundException; import com.facebook.presto.spi.security.PrestoPrincipal; +import com.google.common.annotations.VisibleForTesting; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Sets; +import jakarta.annotation.Nullable; +import org.apache.hadoop.hive.metastore.api.InvalidObjectException; import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.hadoop.mapred.FileInputFormat; import org.apache.hadoop.mapred.FileOutputFormat; @@ -42,15 +48,16 @@ import org.apache.iceberg.TableMetadataParser; import org.apache.iceberg.TableOperations; import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.AlreadyExistsException; import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.io.FileIO; import org.apache.iceberg.io.LocationProvider; import 
org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.util.PropertyUtil; import org.apache.iceberg.util.Tasks; -import javax.annotation.Nullable; -import javax.annotation.concurrent.NotThreadSafe; - import java.io.FileNotFoundException; import java.util.HashMap; import java.util.Map; @@ -60,13 +67,18 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; +import static com.facebook.presto.hive.HiveErrorCode.HIVE_METASTORE_ERROR; import static com.facebook.presto.hive.metastore.HivePrivilegeInfo.HivePrivilege.DELETE; import static com.facebook.presto.hive.metastore.HivePrivilegeInfo.HivePrivilege.INSERT; import static com.facebook.presto.hive.metastore.HivePrivilegeInfo.HivePrivilege.SELECT; import static com.facebook.presto.hive.metastore.HivePrivilegeInfo.HivePrivilege.UPDATE; import static com.facebook.presto.hive.metastore.MetastoreUtil.TABLE_COMMENT; import static com.facebook.presto.hive.metastore.MetastoreUtil.isPrestoView; +import static com.facebook.presto.iceberg.HiveTableOperations.CommitStatus.FAILED; +import static com.facebook.presto.iceberg.HiveTableOperations.CommitStatus.SUCCESS; +import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_COMMIT_ERROR; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_METADATA; import static com.facebook.presto.iceberg.IcebergUtil.isIcebergTable; import static com.facebook.presto.iceberg.IcebergUtil.toHiveColumns; @@ -78,8 +90,17 @@ import static java.util.Objects.requireNonNull; import static java.util.UUID.randomUUID; import static org.apache.iceberg.BaseMetastoreTableOperations.ICEBERG_TABLE_TYPE_VALUE; +import static org.apache.iceberg.BaseMetastoreTableOperations.METADATA_LOCATION_PROP; import static org.apache.iceberg.BaseMetastoreTableOperations.TABLE_TYPE_PROP; import static org.apache.iceberg.TableMetadataParser.getFileExtension; +import static 
org.apache.iceberg.TableProperties.COMMIT_NUM_STATUS_CHECKS; +import static org.apache.iceberg.TableProperties.COMMIT_NUM_STATUS_CHECKS_DEFAULT; +import static org.apache.iceberg.TableProperties.COMMIT_STATUS_CHECKS_MAX_WAIT_MS; +import static org.apache.iceberg.TableProperties.COMMIT_STATUS_CHECKS_MAX_WAIT_MS_DEFAULT; +import static org.apache.iceberg.TableProperties.COMMIT_STATUS_CHECKS_MIN_WAIT_MS; +import static org.apache.iceberg.TableProperties.COMMIT_STATUS_CHECKS_MIN_WAIT_MS_DEFAULT; +import static org.apache.iceberg.TableProperties.COMMIT_STATUS_CHECKS_TOTAL_WAIT_MS; +import static org.apache.iceberg.TableProperties.COMMIT_STATUS_CHECKS_TOTAL_WAIT_MS_DEFAULT; import static org.apache.iceberg.TableProperties.METADATA_COMPRESSION; import static org.apache.iceberg.TableProperties.METADATA_COMPRESSION_DEFAULT; import static org.apache.iceberg.TableProperties.WRITE_METADATA_LOCATION; @@ -106,6 +127,7 @@ public class HiveTableOperations private final Optional owner; private final Optional location; private final HdfsFileIO fileIO; + private final IcebergHiveTableOperationsConfig config; private TableMetadata currentMetadata; @@ -216,7 +238,7 @@ public TableMetadata refresh() Table table = getTable(); if (!isIcebergTable(table)) { - throw new UnknownTableTypeException(getSchemaTableName()); + throw new UnknownTableTypeException("Not an Iceberg table: " + getSchemaTableName()); } if (isPrestoView(table)) { @@ -251,14 +273,19 @@ public void commit(@Nullable TableMetadata base, TableMetadata metadata) String newMetadataLocation = writeNewMetadata(metadata, version + 1); Table table; - // getting a process-level lock per table to avoid concurrent commit attempts to the same table from the same - // JVM process, which would result in unnecessary and costly HMS lock acquisition requests Optional lockId = Optional.empty(); + boolean useHMSLock = Optional.ofNullable(metadata.property(TableProperties.HIVE_LOCK_ENABLED, null)) + .map(Boolean::parseBoolean) + 
.orElse(config.getLockingEnabled()); ReentrantLock tableLevelMutex = commitLockCache.getUnchecked(database + "." + tableName); + // getting a process-level lock per table to avoid concurrent commit attempts to the same table from the same + // JVM process, which would result in unnecessary and costly HMS lock acquisition requests tableLevelMutex.lock(); try { try { - lockId = metastore.lock(metastoreContext, database, tableName); + if (useHMSLock) { + lockId = metastore.lock(metastoreContext, database, tableName); + } if (base == null) { String tableComment = metadata.properties().get(TABLE_COMMENT); Map parameters = new HashMap<>(); @@ -313,15 +340,62 @@ public void commit(@Nullable TableMetadata base, TableMetadata metadata) .put(table.getOwner(), new HivePrivilegeInfo(DELETE, true, owner, owner)) .build(), ImmutableMultimap.of()); - if (base == null) { - metastore.createTable(metastoreContext, table, privileges, emptyList()); + try { + if (base == null) { + metastore.createTable(metastoreContext, table, privileges, emptyList()); + } + else { + PartitionStatistics tableStats = metastore.getTableStatistics(metastoreContext, database, tableName); + metastore.persistTable(metastoreContext, database, tableName, table, privileges, () -> tableStats, useHMSLock ? 
ImmutableMap.of() : hmsEnvContext(base.metadataFileLocation())); + } } - else { - PartitionStatistics tableStats = metastore.getTableStatistics(metastoreContext, database, tableName); - metastore.replaceTable(metastoreContext, database, tableName, table, privileges); - - // attempt to put back previous table statistics - metastore.updateTableStatistics(metastoreContext, database, tableName, oldStats -> tableStats); + catch (AlreadyExistsException e) { + throw new PrestoException(HIVE_METASTORE_ERROR, format("Table already exists: %s.%s", database, tableName), e); + } + catch (CommitFailedException | CommitStateUnknownException e) { + throw e; + } + catch (Throwable e) { + if (e instanceof PrestoException && e.getCause() instanceof InvalidObjectException) { + throw new ValidationException(e, "Invalid Hive object for %s.%s", database, tableName); + } + if (e.getMessage() != null + && e.getMessage().contains("Table/View 'HIVE_LOCKS' does not exist")) { + throw new PrestoException(ICEBERG_COMMIT_ERROR, + "Failed to acquire locks from metastore because the underlying metastore " + + "table 'HIVE_LOCKS' does not exist. This can occur when using an embedded metastore which does not " + + "support transactions. To fix this use an alternative metastore.", + e); + } + CommitStatus commitStatus; + if (e.getMessage() != null + && e.getMessage() + .contains( + "The table has been modified. The parameter value for key '" + + METADATA_LOCATION_PROP + + "' is")) { + // It's possible the HMS client incorrectly retries a successful operation, due to network + // issue for example, and triggers this exception. So we need double-check to make sure + // this is really a concurrent modification. 
Hitting this exception means no pending + // requests, if any, can succeed later, so it's safe to check status in strict mode + commitStatus = checkCommitStatusStrict(newMetadataLocation, metadata); + if (commitStatus == FAILED) { + throw new CommitFailedException( + e, "The table %s.%s has been modified concurrently", database, tableName); + } + } + else { + // Cannot tell if commit to succeeded, attempting to reconnect and check. + commitStatus = checkCommitStatus(newMetadataLocation, metadata); + } + switch (commitStatus) { + case SUCCESS: + break; + case FAILED: + throw e; + case UNKNOWN: + throw new CommitStateUnknownException(e); + } } deleteRemovedMetadataFiles(base, metadata); } @@ -500,4 +574,89 @@ private void deleteRemovedMetadataFiles(TableMetadata base, TableMetadata metada .run(previousMetadataFile -> io().deleteFile(previousMetadataFile.file())); } } + + private Map hmsEnvContext(String metadataLocation) + { + return ImmutableMap.of( + org.apache.iceberg.hive.HiveTableOperations.NO_LOCK_EXPECTED_KEY, + METADATA_LOCATION_PROP, + org.apache.iceberg.hive.HiveTableOperations.NO_LOCK_EXPECTED_VALUE, + metadataLocation); + } + + @VisibleForTesting + public IcebergHiveTableOperationsConfig getConfig() + { + return config; + } + + /** + * Validate if the new metadata location is the current metadata location or present within + * previous metadata files. + * + * @param newMetadataLocation newly written metadata location + * @return true if the new metadata location is the current metadata location or present within + * previous metadata files. 
+ */ + private boolean checkCurrentMetadataLocation(String newMetadataLocation) + { + TableMetadata metadata = refresh(); + String currentMetadataFileLocation = metadata.metadataFileLocation(); + return currentMetadataFileLocation.equals(newMetadataLocation) + || metadata.previousFiles().stream() + .anyMatch(log -> log.file().equals(newMetadataLocation)); + } + + protected CommitStatus checkCommitStatus(String newMetadataLocation, TableMetadata config) + { + CommitStatus strictStatus = + checkCommitStatusStrict(newMetadataLocation, config); + if (strictStatus == FAILED) { + return CommitStatus.UNKNOWN; + } + return strictStatus; + } + + protected CommitStatus checkCommitStatusStrict(String newMetadataLocation, TableMetadata config) + { + Supplier commitStatusSupplier = () -> checkCurrentMetadataLocation(newMetadataLocation); + + int maxAttempts = + PropertyUtil.propertyAsInt( + config.properties(), COMMIT_NUM_STATUS_CHECKS, COMMIT_NUM_STATUS_CHECKS_DEFAULT); + long minWaitMs = + PropertyUtil.propertyAsLong( + config.properties(), COMMIT_STATUS_CHECKS_MIN_WAIT_MS, COMMIT_STATUS_CHECKS_MIN_WAIT_MS_DEFAULT); + long maxWaitMs = + PropertyUtil.propertyAsLong( + config.properties(), COMMIT_STATUS_CHECKS_MAX_WAIT_MS, COMMIT_STATUS_CHECKS_MAX_WAIT_MS_DEFAULT); + long totalRetryMs = + PropertyUtil.propertyAsLong( + config.properties(), + COMMIT_STATUS_CHECKS_TOTAL_WAIT_MS, + COMMIT_STATUS_CHECKS_TOTAL_WAIT_MS_DEFAULT); + + AtomicReference status = new AtomicReference(CommitStatus.UNKNOWN); + + Tasks.foreach(newMetadataLocation) + .retry(maxAttempts) + .suppressFailureWhenFinished() + .exponentialBackoff(minWaitMs, maxWaitMs, totalRetryMs, 2.0) + .run( + location -> { + boolean commitSuccess = commitStatusSupplier.get(); + + if (commitSuccess) { + status.set(SUCCESS); + } + else { + status.set(FAILED); + } + }); + return status.get(); + } + + public enum CommitStatus { + SUCCESS, FAILED, UNKNOWN + } } diff --git 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java index b2e6b7431a4f6..4d071a205089b 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java @@ -15,6 +15,9 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.BigintType; @@ -23,16 +26,22 @@ import com.facebook.presto.common.type.TimestampWithTimeZoneType; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.hive.HiveOutputInfo; +import com.facebook.presto.hive.HiveOutputMetadata; import com.facebook.presto.hive.HivePartition; -import com.facebook.presto.hive.HiveWrittenPartitions; import com.facebook.presto.hive.NodeVersion; +import com.facebook.presto.hive.PartitionSet; +import com.facebook.presto.hive.UnknownTableTypeException; import com.facebook.presto.iceberg.changelog.ChangelogOperation; import com.facebook.presto.iceberg.changelog.ChangelogUtil; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; import com.facebook.presto.spi.ConnectorNewTableLayout; import 
com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSession; @@ -43,6 +52,12 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.DiscretePredicates; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.MaterializedViewDefinition.ColumnMapping; +import com.facebook.presto.spi.MaterializedViewRefreshType; +import com.facebook.presto.spi.MaterializedViewStaleReadBehavior; +import com.facebook.presto.spi.MaterializedViewStalenessConfig; +import com.facebook.presto.spi.MaterializedViewStatus; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.SchemaTableName; @@ -54,10 +69,15 @@ import com.facebook.presto.spi.connector.ConnectorTableVersion; import com.facebook.presto.spi.connector.ConnectorTableVersion.VersionOperator; import com.facebook.presto.spi.connector.ConnectorTableVersion.VersionType; +import com.facebook.presto.spi.connector.RowChangeParadigm; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.RowExpressionService; +import com.facebook.presto.spi.security.ViewSecurity; import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; import com.facebook.presto.spi.statistics.ComputedStatistics; import com.facebook.presto.spi.statistics.TableStatisticType; @@ -80,18 +100,20 @@ import org.apache.iceberg.FileFormat; import org.apache.iceberg.FileMetadata; import org.apache.iceberg.IsolationLevel; +import org.apache.iceberg.MetadataColumns; import org.apache.iceberg.MetricsConfig; 
import org.apache.iceberg.MetricsModes.None; -import org.apache.iceberg.PartitionField; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.RowDelta; import org.apache.iceberg.RowLevelOperationMode; import org.apache.iceberg.Schema; import org.apache.iceberg.SchemaParser; +import org.apache.iceberg.Snapshot; import org.apache.iceberg.SortOrder; import org.apache.iceberg.Table; import org.apache.iceberg.TableProperties; import org.apache.iceberg.Transaction; +import org.apache.iceberg.UpdatePartitionSpec; import org.apache.iceberg.UpdateProperties; import org.apache.iceberg.exceptions.NoSuchTableException; import org.apache.iceberg.exceptions.NoSuchViewException; @@ -100,12 +122,14 @@ import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.StringType; import org.apache.iceberg.util.CharSequenceSet; import org.apache.iceberg.view.View; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -114,9 +138,11 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR; @@ -130,15 +156,31 @@ import static com.facebook.presto.iceberg.ExpressionConverter.toIcebergExpression; import static com.facebook.presto.iceberg.IcebergColumnHandle.DATA_SEQUENCE_NUMBER_COLUMN_HANDLE; import static com.facebook.presto.iceberg.IcebergColumnHandle.DATA_SEQUENCE_NUMBER_COLUMN_METADATA; +import static 
com.facebook.presto.iceberg.IcebergColumnHandle.DELETE_FILE_PATH_COLUMN_HANDLE; +import static com.facebook.presto.iceberg.IcebergColumnHandle.DELETE_FILE_PATH_COLUMN_METADATA; +import static com.facebook.presto.iceberg.IcebergColumnHandle.IS_DELETED_COLUMN_HANDLE; +import static com.facebook.presto.iceberg.IcebergColumnHandle.IS_DELETED_COLUMN_METADATA; import static com.facebook.presto.iceberg.IcebergColumnHandle.PATH_COLUMN_HANDLE; import static com.facebook.presto.iceberg.IcebergColumnHandle.PATH_COLUMN_METADATA; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_COMMIT_ERROR; +import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_FORMAT_VERSION; +import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_MATERIALIZED_VIEW; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_SNAPSHOT_ID; +import static com.facebook.presto.iceberg.IcebergMaterializedViewProperties.getRefreshType; +import static com.facebook.presto.iceberg.IcebergMaterializedViewProperties.getStaleReadBehavior; +import static com.facebook.presto.iceberg.IcebergMaterializedViewProperties.getStalenessWindow; +import static com.facebook.presto.iceberg.IcebergMaterializedViewProperties.getStorageSchema; +import static com.facebook.presto.iceberg.IcebergMaterializedViewProperties.getStorageTable; import static com.facebook.presto.iceberg.IcebergMetadataColumn.DATA_SEQUENCE_NUMBER; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.DELETE_FILE_PATH; import static com.facebook.presto.iceberg.IcebergMetadataColumn.FILE_PATH; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.IS_DELETED; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.MERGE_PARTITION_DATA; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.MERGE_TARGET_ROW_ID_DATA; import static com.facebook.presto.iceberg.IcebergMetadataColumn.UPDATE_ROW_DATA; import static 
com.facebook.presto.iceberg.IcebergPartitionType.ALL; import static com.facebook.presto.iceberg.IcebergSessionProperties.getCompressionCodec; +import static com.facebook.presto.iceberg.IcebergSessionProperties.getMaterializedViewStoragePrefix; import static com.facebook.presto.iceberg.IcebergSessionProperties.isPushdownFilterEnabled; import static com.facebook.presto.iceberg.IcebergTableProperties.LOCATION_PROPERTY; import static com.facebook.presto.iceberg.IcebergTableProperties.PARTITIONING_PROPERTY; @@ -148,6 +190,7 @@ import static com.facebook.presto.iceberg.IcebergTableType.EQUALITY_DELETES; import static com.facebook.presto.iceberg.IcebergUtil.MIN_FORMAT_VERSION_FOR_DELETE; import static com.facebook.presto.iceberg.IcebergUtil.getColumns; +import static com.facebook.presto.iceberg.IcebergUtil.getColumnsForWrite; import static com.facebook.presto.iceberg.IcebergUtil.getDeleteMode; import static com.facebook.presto.iceberg.IcebergUtil.getFileFormat; import static com.facebook.presto.iceberg.IcebergUtil.getPartitionFields; @@ -183,17 +226,27 @@ import static com.facebook.presto.iceberg.optimizer.IcebergPlanOptimizer.getEnforcedColumns; import static com.facebook.presto.iceberg.util.StatisticsUtil.calculateBaseTableStatistics; import static com.facebook.presto.iceberg.util.StatisticsUtil.calculateStatisticsConsideringLayout; +import static com.facebook.presto.spi.MaterializedViewStatus.MaterializedViewState.FULLY_MATERIALIZED; +import static com.facebook.presto.spi.MaterializedViewStatus.MaterializedViewState.NOT_MATERIALIZED; +import static com.facebook.presto.spi.MaterializedViewStatus.MaterializedViewState.PARTIALLY_MATERIALIZED; +import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_VIEW; +import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static 
com.facebook.presto.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW; import static com.facebook.presto.spi.statistics.TableStatisticType.ROW_COUNT; import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Maps.transformValues; +import static java.lang.Long.parseLong; import static java.lang.String.format; import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; import static org.apache.iceberg.MetadataColumns.ROW_POSITION; +import static org.apache.iceberg.MetadataColumns.SPEC_ID; import static org.apache.iceberg.RowLevelOperationMode.MERGE_ON_READ; import static org.apache.iceberg.SnapshotSummary.DELETED_RECORDS_PROP; import static org.apache.iceberg.SnapshotSummary.REMOVED_EQ_DELETES_PROP; @@ -201,6 +254,7 @@ import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL; import static org.apache.iceberg.TableProperties.DELETE_ISOLATION_LEVEL_DEFAULT; import static org.apache.iceberg.TableProperties.WRITE_DATA_LOCATION; +import static org.apache.iceberg.expressions.Expressions.alwaysTrue; public abstract class IcebergAbstractMetadata implements ConnectorMetadata @@ -208,11 +262,32 @@ public abstract class IcebergAbstractMetadata private static final Logger log = Logger.get(IcebergAbstractMetadata.class); protected static final String INFORMATION_SCHEMA = "information_schema"; + // Materialized view metadata property keys + protected static final String PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION = "presto.materialized_view.format_version"; + protected static final String PRESTO_MATERIALIZED_VIEW_ORIGINAL_SQL = "presto.materialized_view.original_sql"; + protected static final String 
PRESTO_MATERIALIZED_VIEW_BASE_TABLES = "presto.materialized_view.base_tables"; + protected static final String PRESTO_MATERIALIZED_VIEW_BASE_SNAPSHOT_PREFIX = "presto.materialized_view.base_snapshot."; + protected static final String PRESTO_MATERIALIZED_VIEW_LAST_REFRESH_SNAPSHOT_ID = "presto.materialized_view.last_refresh_snapshot_id"; + protected static final String PRESTO_MATERIALIZED_VIEW_STORAGE_SCHEMA = "presto.materialized_view.storage_schema"; + protected static final String PRESTO_MATERIALIZED_VIEW_STORAGE_TABLE_NAME = "presto.materialized_view.storage_table_name"; + protected static final String PRESTO_MATERIALIZED_VIEW_COLUMN_MAPPINGS = "presto.materialized_view.column_mappings"; + protected static final String PRESTO_MATERIALIZED_VIEW_OWNER = "presto.materialized_view.owner"; + protected static final String PRESTO_MATERIALIZED_VIEW_SECURITY_MODE = "presto.materialized_view.security_mode"; + protected static final String PRESTO_MATERIALIZED_VIEW_STALE_READ_BEHAVIOR = "presto.materialized_view.stale_read_behavior"; + protected static final String PRESTO_MATERIALIZED_VIEW_STALENESS_WINDOW = "presto.materialized_view.staleness_window"; + protected static final String PRESTO_MATERIALIZED_VIEW_REFRESH_TYPE = "presto.materialized_view.refresh_type"; + + protected static final int CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION = 1; + protected final TypeManager typeManager; + protected final ProcedureRegistry procedureRegistry; protected final JsonCodec commitTaskCodec; + protected final JsonCodec> columnMappingsCodec; + protected final JsonCodec> schemaTableNamesCodec; protected final NodeVersion nodeVersion; protected final RowExpressionService rowExpressionService; protected final FilterStatsCalculatorService filterStatsCalculatorService; + protected Optional procedureContext = Optional.empty(); protected Transaction transaction; protected final StatisticsFileCache statisticsFileCache; protected final IcebergTableProperties tableProperties; @@ -222,16 +297,22 @@ 
public abstract class IcebergAbstractMetadata public IcebergAbstractMetadata( TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, + JsonCodec> columnMappingsCodec, + JsonCodec> schemaTableNamesCodec, NodeVersion nodeVersion, FilterStatsCalculatorService filterStatsCalculatorService, StatisticsFileCache statisticsFileCache, IcebergTableProperties tableProperties) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); + this.columnMappingsCodec = requireNonNull(columnMappingsCodec, "columnMappingsCodec is null"); + this.schemaTableNamesCodec = requireNonNull(schemaTableNamesCodec, "schemaTableNamesCodec is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); @@ -251,12 +332,31 @@ protected final Table getIcebergTable(ConnectorSession session, SchemaTableName protected abstract View getIcebergView(ConnectorSession session, SchemaTableName schemaTableName); + protected abstract void createIcebergView( + ConnectorSession session, + SchemaTableName viewName, + List columns, + String viewSql, + Map properties); + + protected abstract void dropIcebergView(ConnectorSession session, SchemaTableName viewName); + + protected abstract void updateIcebergViewProperties( + ConnectorSession session, + SchemaTableName viewName, + Map properties); + protected abstract boolean tableExists(ConnectorSession session, SchemaTableName schemaTableName); public abstract void registerTable(ConnectorSession clientSession, SchemaTableName schemaTableName, Path 
metadataLocation); public abstract void unregisterTable(ConnectorSession clientSession, SchemaTableName schemaTableName); + public Optional getProcedureContext() + { + return this.procedureContext; + } + /** * This class implements the default implementation for getTableLayoutForConstraint which will be used in the case of a Java Worker */ @@ -267,29 +367,33 @@ public ConnectorTableLayoutResult getTableLayoutForConstraint( Constraint constraint, Optional> desiredColumns) { - Map predicateColumns = constraint.getSummary().getDomains().get().keySet().stream() - .map(IcebergColumnHandle.class::cast) - .collect(toImmutableMap(IcebergColumnHandle::getName, Functions.identity())); + Map predicateColumns = constraint.getSummary().getDomains() + .map(domains -> domains.keySet().stream() + .map(IcebergColumnHandle.class::cast) + .collect(toImmutableMap(IcebergColumnHandle::getName, Functions.identity()))) + .orElse(ImmutableMap.of()); IcebergTableHandle handle = (IcebergTableHandle) table; Table icebergTable = getIcebergTable(session, handle.getSchemaTableName()); List partitionColumns = getPartitionKeyColumnHandles(handle, icebergTable, typeManager); - TupleDomain partitionColumnPredicate = TupleDomain.withColumnDomains(Maps.filterKeys(constraint.getSummary().getDomains().get(), Predicates.in(partitionColumns))); + TupleDomain partitionColumnPredicate = TupleDomain.withColumnDomains(Maps.filterKeys(constraint.getSummary().getDomains().orElse(ImmutableMap.of()), Predicates.in(partitionColumns))); Optional> requestedColumns = desiredColumns.map(columns -> columns.stream().map(column -> (IcebergColumnHandle) column).collect(toImmutableSet())); - List partitions; + PartitionSet partitions; if (handle.getIcebergTableName().getTableType() == CHANGELOG || handle.getIcebergTableName().getTableType() == EQUALITY_DELETES) { - partitions = ImmutableList.of(new HivePartition(handle.getSchemaTableName())); + partitions = new PartitionSet(ImmutableList.of(new 
HivePartition(handle.getSchemaTableName()))); } else { + RuntimeStats runtimeStats = session.getRuntimeStats(); partitions = getPartitions( typeManager, handle, icebergTable, constraint, - partitionColumns); + partitionColumns, + runtimeStats); } ConnectorTableLayout layout = getTableLayout( @@ -303,7 +407,7 @@ public ConnectorTableLayoutResult getTableLayoutForConstraint( .setRequestedColumns(requestedColumns) .setPushdownFilterEnabled(isPushdownFilterEnabled(session)) .setPartitionColumnPredicate(partitionColumnPredicate.simplify()) - .setPartitions(Optional.ofNullable(partitions.size() == 0 ? null : partitions)) + .setPartitions(Optional.ofNullable(partitions.isEmpty() ? null : partitions)) .setTable(handle) .build()); return new ConnectorTableLayoutResult(layout, constraint.getSummary()); @@ -327,7 +431,7 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa Table icebergTable = getIcebergTable(session, tableHandle.getSchemaTableName()); validateTableMode(session, icebergTable); List partitionColumns = ImmutableList.copyOf(icebergTableLayoutHandle.getPartitionColumns()); - Optional> partitions = icebergTableLayoutHandle.getPartitions(); + Optional> partitions = icebergTableLayoutHandle.getPartitions(); Optional discretePredicates = partitions.flatMap(parts -> getDiscretePredicates(partitionColumns, parts)); if (!isPushdownFilterEnabled(session)) { return new ConnectorTableLayout( @@ -343,7 +447,11 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa Map predicateColumns = icebergTableLayoutHandle.getPredicateColumns().entrySet() .stream().collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - Optional> predicate = partitions.map(parts -> getPredicate(icebergTableLayoutHandle, partitionColumns, parts, predicateColumns)); + Optional> predicate = partitions + .map(parts -> + getPredicate(icebergTableLayoutHandle, partitionColumns, + StreamSupport.stream(parts.spliterator(), false).toList(), 
+ predicateColumns)); // capture subfields from domainPredicate to add to remainingPredicate // so those filters don't get lost Map columnTypes = getColumns(icebergTable.schema(), icebergTable.spec(), typeManager).stream() @@ -404,9 +512,11 @@ protected Optional getIcebergSystemTable(SchemaTableName tableName, case FILES: return Optional.of(new FilesTable(systemTableName, table, snapshotId, typeManager)); case PROPERTIES: - return Optional.of(new PropertiesTable(systemTableName, table)); + return Optional.of(new PropertiesTable(systemTableName, table, tableProperties)); case REFS: return Optional.of(new RefsTable(systemTableName, table)); + case METADATA_LOG_ENTRIES: + return Optional.of(new MetadataLogTable(systemTableName, table)); } return Optional.empty(); } @@ -424,13 +534,15 @@ protected ConnectorTableMetadata getTableOrViewMetadata(ConnectorSession session try { Table icebergTable = getIcebergTable(session, schemaTableName); ImmutableList.Builder columns = ImmutableList.builder(); - columns.addAll(getColumnMetadata(icebergTable)); + columns.addAll(getColumnMetadata(session, icebergTable)); if (icebergTableName.getTableType() == CHANGELOG) { return ChangelogUtil.getChangelogTableMeta(table, typeManager, columns.build()); } else { columns.add(PATH_COLUMN_METADATA); columns.add(DATA_SEQUENCE_NUMBER_COLUMN_METADATA); + columns.add(IS_DELETED_COLUMN_METADATA); + columns.add(DELETE_FILE_PATH_COLUMN_METADATA); } return new ConnectorTableMetadata(table, columns.build(), createMetadataProperties(icebergTable, session), getTableComment(icebergTable)); } @@ -440,7 +552,7 @@ protected ConnectorTableMetadata getTableOrViewMetadata(ConnectorSession session // try to load it as a view when getting an `NoSuchTableException`. This will be more efficient. 
try { View icebergView = getIcebergView(session, schemaTableName); - return new ConnectorTableMetadata(table, getColumnMetadata(icebergView), createViewMetadataProperties(icebergView), getViewComment(icebergView)); + return new ConnectorTableMetadata(table, getColumnMetadata(session, icebergView), createViewMetadataProperties(icebergView), getViewComment(icebergView)); } catch (NoSuchViewException noSuchViewException) { throw new TableNotFoundException(schemaTableName); @@ -505,12 +617,13 @@ protected ConnectorInsertTableHandle beginIcebergTableInsert(ConnectorSession se table.getIcebergTableName(), toPrestoSchema(icebergTable.schema(), typeManager), toPrestoPartitionSpec(icebergTable.spec(), typeManager), - getColumns(icebergTable.schema(), icebergTable.spec(), typeManager), + getColumnsForWrite(icebergTable.schema(), icebergTable.spec(), typeManager), icebergTable.location(), getFileFormat(icebergTable), getCompressionCodec(session), icebergTable.properties(), - getSupportedSortFields(icebergTable.schema(), icebergTable.sortOrder())); + getSupportedSortFields(icebergTable.schema(), icebergTable.sortOrder()), + Optional.empty()); } public static List getSupportedSortFields(Schema schema, SortOrder sortOrder) @@ -570,9 +683,9 @@ private Optional finishInsert(ConnectorSession session, throw new PrestoException(ICEBERG_COMMIT_ERROR, "Failed to commit Iceberg update to table: " + writableTableHandle.getTableName(), e); } - return Optional.of(new HiveWrittenPartitions(commitTasks.stream() + return Optional.of(new HiveOutputMetadata(new HiveOutputInfo(commitTasks.stream() .map(CommitTaskData::getPath) - .collect(toImmutableList()))); + .collect(toImmutableList()), icebergTable.location()))); } private Optional finishWrite(ConnectorSession session, IcebergWritableTableHandle writableTableHandle, Collection fragments, ChangelogOperation operationType) @@ -617,9 +730,9 @@ private Optional finishWrite(ConnectorSession session, throw new PrestoException(ICEBERG_COMMIT_ERROR, 
"Failed to commit Iceberg update to table: " + writableTableHandle.getTableName(), e); } - return Optional.of(new HiveWrittenPartitions(commitTasks.stream() + return Optional.of(new HiveOutputMetadata(new HiveOutputInfo(commitTasks.stream() .map(CommitTaskData::getPath) - .collect(toImmutableList()))); + .collect(toImmutableList()), icebergTable.location()))); } private void handleInsertTask(CommitTaskData task, Table icebergTable, AppendFiles appendFiles, ImmutableSet.Builder writtenFiles) @@ -687,9 +800,81 @@ private void handleFinishData(CommitTaskData task, Table icebergTable, Partition } @Override - public ColumnHandle getDeleteRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + public Optional getDeleteRowIdColumn(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return Optional.of(IcebergColumnHandle.create(ROW_POSITION, typeManager, REGULAR)); + } + + /** + * Return the row change paradigm supported by the connector on the table. + */ + @Override + public RowChangeParadigm getRowChangeParadigm(ConnectorSession session, ConnectorTableHandle tableHandle) { - return IcebergColumnHandle.create(ROW_POSITION, typeManager, REGULAR); + return DELETE_ROW_AND_INSERT_ROW; + } + + @Override + public ColumnHandle getMergeTargetTableRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + { + Types.StructType type = Types.StructType.of(ImmutableList.builder() + .add(MetadataColumns.FILE_PATH) + .add(ROW_POSITION) + .add(SPEC_ID) + .add(NestedField.required(MERGE_PARTITION_DATA.getId(), MERGE_PARTITION_DATA.getColumnName(), StringType.get())) + .build()); + + NestedField field = NestedField.required(MERGE_TARGET_ROW_ID_DATA.getId(), MERGE_TARGET_ROW_ID_DATA.getColumnName(), type); + return IcebergColumnHandle.create(field, typeManager, SYNTHESIZED); + } + + @Override + public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle) + { + IcebergTableHandle 
icebergTableHandle = (IcebergTableHandle) tableHandle; + verify(icebergTableHandle.getIcebergTableName().getTableType() == DATA, "only the data table can have data merged"); + Table icebergTable = getIcebergTable(session, icebergTableHandle.getSchemaTableName()); + int formatVersion = ((BaseTable) icebergTable).operations().current().formatVersion(); + + if (formatVersion < MIN_FORMAT_VERSION_FOR_DELETE || + !Optional.ofNullable(icebergTable.properties().get(TableProperties.UPDATE_MODE)) + .map(mode -> mode.equals(MERGE_ON_READ.modeName())) + .orElse(false)) { + throw new PrestoException(ICEBERG_INVALID_FORMAT_VERSION, + "Iceberg table updates require at least format version 2 and update mode must be merge-on-read"); + } + validateTableMode(session, icebergTable); + transaction = icebergTable.newTransaction(); + + IcebergInsertTableHandle insertHandle = new IcebergInsertTableHandle( + icebergTableHandle.getSchemaName(), + icebergTableHandle.getIcebergTableName(), + toPrestoSchema(icebergTable.schema(), typeManager), + toPrestoPartitionSpec(icebergTable.spec(), typeManager), + getColumns(icebergTable.schema(), icebergTable.spec(), typeManager), + icebergTable.location(), + getFileFormat(icebergTable), + getCompressionCodec(session), + icebergTable.properties(), + getSupportedSortFields(icebergTable.schema(), icebergTable.sortOrder()), + Optional.empty()); + + Map partitionSpecs = transformValues(icebergTable.specs(), partitionSpec -> toPrestoPartitionSpec(partitionSpec, typeManager)); + + return new IcebergMergeTableHandle(icebergTableHandle, insertHandle, partitionSpecs); + } + + @Override + public void finishMerge( + ConnectorSession session, + ConnectorMergeTableHandle tableHandle, + Collection fragments, + Collection computedStatistics) + { + IcebergWritableTableHandle insertTableHandle = + ((IcebergMergeTableHandle) tableHandle).getInsertTableHandle(); + + finishWrite(session, insertTableHandle, fragments, UPDATE_AFTER); } @Override @@ -698,27 +883,29 @@ public 
boolean isLegacyGetLayoutSupported(ConnectorSession session, ConnectorTab return !isPushdownFilterEnabled(session); } - protected List getColumnMetadata(Table table) + protected List getColumnMetadata(ConnectorSession session, Table table) { Map> partitionFields = getPartitionFields(table.spec(), ALL); return table.schema().columns().stream() .map(column -> ColumnMetadata.builder() - .setName(column.name()) + .setName(normalizeIdentifier(session, column.name())) .setType(toPrestoType(column.type(), typeManager)) + .setNullable(column.isOptional()) .setComment(column.doc()) .setHidden(false) .setExtraInfo(partitionFields.containsKey(column.name()) ? - columnExtraInfo(partitionFields.get(column.name())) : - null) + columnExtraInfo(partitionFields.get(column.name())) : + null) + .setNullable(column.isOptional()) .build()) .collect(toImmutableList()); } - protected List getColumnMetadata(View view) + protected List getColumnMetadata(ConnectorSession session, View view) { return view.schema().columns().stream() .map(column -> ColumnMetadata.builder() - .setName(column.name()) + .setName(normalizeIdentifier(session, column.name())) .setType(toPrestoType(column.type(), typeManager)) .setComment(column.doc()) .setHidden(false) @@ -843,6 +1030,38 @@ public void rollback() // TODO: cleanup open transaction } + @Override + public void dropBranch(ConnectorSession session, ConnectorTableHandle tableHandle, String branchName, boolean branchExists) + { + IcebergTableHandle icebergTableHandle = (IcebergTableHandle) tableHandle; + verify(icebergTableHandle.getIcebergTableName().getTableType() == DATA, "only the data table can have branch dropped"); + Table icebergTable = getIcebergTable(session, icebergTableHandle.getSchemaTableName()); + if (icebergTable.refs().containsKey(branchName) && icebergTable.refs().get(branchName).isBranch()) { + icebergTable.manageSnapshots().removeBranch(branchName).commit(); + } + else { + if (!branchExists) { + throw new PrestoException(NOT_FOUND, 
format("Branch %s doesn't exist in table %s", branchName, icebergTableHandle.getSchemaTableName().getTableName())); + } + } + } + + @Override + public void dropTag(ConnectorSession session, ConnectorTableHandle tableHandle, String tagName, boolean tagExists) + { + IcebergTableHandle icebergTableHandle = (IcebergTableHandle) tableHandle; + verify(icebergTableHandle.getIcebergTableName().getTableType() == DATA, "only the data table can have tag dropped"); + Table icebergTable = getIcebergTable(session, icebergTableHandle.getSchemaTableName()); + if (icebergTable.refs().containsKey(tagName) && icebergTable.refs().get(tagName).isTag()) { + icebergTable.manageSnapshots().removeTag(tagName).commit(); + } + else { + if (!tagExists) { + throw new PrestoException(NOT_FOUND, format("Tag %s doesn't exist in table %s", tagName, icebergTableHandle.getSchemaTableName().getTableName())); + } + } + } + @Override public void addColumn(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnMetadata column) { @@ -858,9 +1077,13 @@ public void addColumn(ConnectorSession session, ConnectorTableHandle tableHandle Transaction transaction = icebergTable.newTransaction(); transaction.updateSchema().addColumn(column.getName(), columnType, column.getComment().orElse(null)).commit(); if (column.getProperties().containsKey(PARTITIONING_PROPERTY)) { - String transform = (String) column.getProperties().get(PARTITIONING_PROPERTY); - transaction.updateSpec().addField(getPartitionColumnName(column.getName(), transform), - getTransformTerm(column.getName(), transform)).commit(); + List partitioningTransform = (List) column.getProperties().get(PARTITIONING_PROPERTY); + UpdatePartitionSpec updatePartitionSpec = transaction.updateSpec(); + for (String transform : partitioningTransform) { + updatePartitionSpec = updatePartitionSpec.addField(getPartitionColumnName(column.getName(), transform), + getTransformTerm(column.getName(), transform)); + } + updatePartitionSpec.commit(); } 
transaction.commitTransaction(); } @@ -895,9 +1118,12 @@ public void renameColumn(ConnectorSession session, ConnectorTableHandle tableHan Table icebergTable = getIcebergTable(session, icebergTableHandle.getSchemaTableName()); Transaction transaction = icebergTable.newTransaction(); transaction.updateSchema().renameColumn(columnHandle.getName(), target).commit(); - if (icebergTable.spec().fields().stream().map(PartitionField::sourceId).anyMatch(sourceId -> sourceId == columnHandle.getId())) { - transaction.updateSpec().renameField(columnHandle.getName(), target).commit(); - } + icebergTable.spec().fields().stream() + .filter(field -> field.sourceId() == columnHandle.getId()) + .forEach(field -> { + String transform = field.transform().toString(); + transaction.updateSpec().renameField(field.name(), getPartitionColumnName(target, transform)).commit(); + }); transaction.commitTransaction(); } @@ -919,7 +1145,7 @@ public Map getColumnHandles(ConnectorSession session, Conn Table icebergTable = getIcebergTable(session, table.getSchemaTableName()); Schema schema; if (table.getIcebergTableName().getTableType() == CHANGELOG) { - schema = ChangelogUtil.changelogTableSchema(getRowTypeFromColumnMeta(getColumnMetadata(icebergTable))); + schema = ChangelogUtil.changelogTableSchema(getRowTypeFromColumnMeta(getColumnMetadata(session, icebergTable))); } else { schema = icebergTable.schema(); @@ -932,6 +1158,8 @@ public Map getColumnHandles(ConnectorSession session, Conn if (table.getIcebergTableName().getTableType() != CHANGELOG) { columnHandles.put(FILE_PATH.getColumnName(), PATH_COLUMN_HANDLE); columnHandles.put(DATA_SEQUENCE_NUMBER.getColumnName(), DATA_SEQUENCE_NUMBER_COLUMN_HANDLE); + columnHandles.put(IS_DELETED.getColumnName(), IS_DELETED_COLUMN_HANDLE); + columnHandles.put(DELETE_FILE_PATH.getColumnName(), DELETE_FILE_PATH_COLUMN_HANDLE); } return columnHandles.build(); } @@ -956,11 +1184,35 @@ public IcebergTableHandle getTableHandle(ConnectorSession session, SchemaTableNa 
verify(name.getTableType() == DATA || name.getTableType() == CHANGELOG || name.getTableType() == EQUALITY_DELETES, "Wrong table type: " + name.getTableType()); if (!tableExists(session, tableName)) { - return null; + // If table doesn't exist, check if it's a materialized view + return getMaterializedView(session, tableName).map(definition -> { + SchemaTableName storageTableName = new SchemaTableName(definition.getSchema(), definition.getTable()); + Table storageTable = getIcebergTable(session, storageTableName); + + // Time travel on the materialized view itself is not supported + if (tableVersion.isPresent() || name.getSnapshotId().isPresent()) { + throw new PrestoException(NOT_SUPPORTED, "Time travel queries on materialized views are not supported"); + } + + return new IcebergTableHandle( + storageTableName.getSchemaName(), + new IcebergTableName(storageTableName.getTableName(), name.getTableType(), Optional.empty(), Optional.empty()), + name.getSnapshotId().isPresent(), + tryGetLocation(storageTable), + tryGetProperties(storageTable), + tryGetSchema(storageTable).map(SchemaParser::toJson), + Optional.empty(), + Optional.empty(), + getSortFields(storageTable), + ImmutableList.of(), + Optional.of(tableName)); + }) + // null indicates table not found + .orElse(null); } - // use a new schema table name that omits the table type - Table table = getIcebergTable(session, new SchemaTableName(tableName.getSchemaName(), name.getTableName())); + SchemaTableName tableNameToLoad = new SchemaTableName(tableName.getSchemaName(), name.getTableName()); + Table table = getIcebergTable(session, tableNameToLoad); Optional tableSnapshotId = tableVersion .map(version -> { @@ -976,8 +1228,8 @@ public IcebergTableHandle getTableHandle(ConnectorSession session, SchemaTableNa Optional tableSchemaJson = tableSchema.map(SchemaParser::toJson); return new IcebergTableHandle( - tableName.getSchemaName(), - new IcebergTableName(name.getTableName(), name.getTableType(), tableSnapshotId, 
name.getChangelogEndSnapshot()), + tableNameToLoad.getSchemaName(), + new IcebergTableName(tableNameToLoad.getTableName(), name.getTableType(), tableSnapshotId, name.getChangelogEndSnapshot()), name.getSnapshotId().isPresent(), tryGetLocation(table), tryGetProperties(table), @@ -985,7 +1237,8 @@ public IcebergTableHandle getTableHandle(ConnectorSession session, SchemaTableNa Optional.empty(), Optional.empty(), getSortFields(table), - ImmutableList.of()); + ImmutableList.of(), + Optional.empty()); } @Override @@ -1017,6 +1270,46 @@ public void truncateTable(ConnectorSession session, ConnectorTableHandle tableHa removeScanFiles(icebergTable, TupleDomain.all()); } + @Override + public ConnectorDistributedProcedureHandle beginCallDistributedProcedure( + ConnectorSession session, + QualifiedObjectName procedureName, + ConnectorTableLayoutHandle tableLayoutHandle, + Object[] arguments) + { + IcebergTableHandle handle = ((IcebergTableLayoutHandle) tableLayoutHandle).getTable(); + Table icebergTable = getIcebergTable(session, handle.getSchemaTableName()); + + if (handle.isSnapshotSpecified()) { + throw new PrestoException(NOT_SUPPORTED, "This connector do not allow table execute at specified snapshot"); + } + + transaction = icebergTable.newTransaction(); + BaseProcedure procedure = procedureRegistry.resolve( + new ConnectorId(procedureName.getCatalogName()), + new SchemaTableName( + procedureName.getSchemaName(), + procedureName.getObjectName())); + verify(procedure instanceof DistributedProcedure, "procedure must be DistributedProcedure"); + procedureContext = Optional.of((IcebergProcedureContext) ((DistributedProcedure) procedure).createContext(icebergTable, transaction)); + return ((DistributedProcedure) procedure).begin(session, procedureContext.get(), tableLayoutHandle, arguments); + } + + @Override + public void finishCallDistributedProcedure(ConnectorSession session, ConnectorDistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection 
fragments) + { + BaseProcedure procedure = procedureRegistry.resolve( + new ConnectorId(procedureName.getCatalogName()), + new SchemaTableName( + procedureName.getSchemaName(), + procedureName.getObjectName())); + verify(procedure instanceof DistributedProcedure, "procedure must be DistributedProcedure"); + verify(procedureContext.isPresent(), "procedure context must be present"); + ((DistributedProcedure) procedure).finish(session, procedureContext.get(), procedureHandle, fragments); + transaction.commitTransaction(); + procedureContext = Optional.empty(); + } + @Override public ConnectorDeleteTableHandle beginDelete(ConnectorSession session, ConnectorTableHandle tableHandle) { @@ -1041,7 +1334,7 @@ public ConnectorDeleteTableHandle beginDelete(ConnectorSession session, Connecto } @Override - public void finishDelete(ConnectorSession session, ConnectorDeleteTableHandle tableHandle, Collection fragments) + public Optional finishDeleteWithOutput(ConnectorSession session, ConnectorDeleteTableHandle tableHandle, Collection fragments) { IcebergTableHandle handle = (IcebergTableHandle) tableHandle; Table icebergTable = getIcebergTable(session, handle.getSchemaTableName()); @@ -1084,6 +1377,7 @@ public void finishDelete(ConnectorSession session, ConnectorDeleteTableHandle ta rowDelta.commit(); transaction.commitTransaction(); + return Optional.empty(); } @Override @@ -1207,9 +1501,9 @@ private OptionalLong removeScanFiles(Table icebergTable, TupleDomain summary = icebergTable.currentSnapshot().summary(); - long deletedRecords = Long.parseLong(summary.getOrDefault(DELETED_RECORDS_PROP, "0")); - long removedPositionDeletes = Long.parseLong(summary.getOrDefault(REMOVED_POS_DELETES_PROP, "0")); - long removedEqualityDeletes = Long.parseLong(summary.getOrDefault(REMOVED_EQ_DELETES_PROP, "0")); + long deletedRecords = parseLong(summary.getOrDefault(DELETED_RECORDS_PROP, "0")); + long removedPositionDeletes = parseLong(summary.getOrDefault(REMOVED_POS_DELETES_PROP, "0")); + 
long removedEqualityDeletes = parseLong(summary.getOrDefault(REMOVED_EQ_DELETES_PROP, "0")); // Removed rows count is inaccurate when existing equality delete files return OptionalLong.of(deletedRecords - removedPositionDeletes - removedEqualityDeletes); } @@ -1261,7 +1555,7 @@ else if (tableVersion.getVersionExpressionType() instanceof VarcharType) { * @return A column handle for the Row ID update column. */ @Override - public ColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle, List updatedColumns) + public Optional getUpdateRowIdColumn(ConnectorSession session, ConnectorTableHandle tableHandle, List updatedColumns) { List unmodifiedColumns = new ArrayList<>(); unmodifiedColumns.add(ROW_POSITION); @@ -1277,7 +1571,7 @@ public ColumnHandle getUpdateRowIdColumnHandle(ConnectorSession session, Connect } } NestedField field = NestedField.required(UPDATE_ROW_DATA.getId(), UPDATE_ROW_DATA.getColumnName(), Types.StructType.of(unmodifiedColumns)); - return IcebergColumnHandle.create(field, typeManager, SYNTHESIZED); + return Optional.of(IcebergColumnHandle.create(field, typeManager, SYNTHESIZED)); } @Override @@ -1323,4 +1617,465 @@ protected Optional getDataLocationBasedOnWarehouseDataDir(SchemaTableNam { return Optional.empty(); } + + @Override + public Optional getInfo(ConnectorTableLayoutHandle tableHandle) + { + IcebergTableLayoutHandle icebergTableHandle = (IcebergTableLayoutHandle) tableHandle; + Optional outputPath = icebergTableHandle.getTable().getOutputPath(); + if (outputPath == null || !outputPath.isPresent()) { + return Optional.empty(); + } + return Optional.of(new IcebergInputInfo( + icebergTableHandle.getTable().getIcebergTableName().getSnapshotId(), + outputPath.get())); + } + + @Override + public void createMaterializedView( + ConnectorSession session, + ConnectorTableMetadata viewMetadata, + MaterializedViewDefinition viewDefinition, + boolean ignoreExisting) + { + try { + SchemaTableName viewName = 
viewMetadata.getTable(); + Map materializedViewProperties = viewMetadata.getProperties(); + + SchemaTableName storageTableName = getStorageTableName(session, viewName, materializedViewProperties); + + if (viewExists(session, viewMetadata)) { + if (ignoreExisting) { + return; + } + throw new PrestoException(ALREADY_EXISTS, "Materialized view " + viewName + " already exists"); + } + + ConnectorTableMetadata storageTableMetadata = new ConnectorTableMetadata( + storageTableName, + viewMetadata.getColumns(), + materializedViewProperties, + viewMetadata.getComment()); + createTable(session, storageTableMetadata, false); + + try { + Map properties = new HashMap<>(); + properties.put(PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION, CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION + ""); + properties.put(PRESTO_MATERIALIZED_VIEW_ORIGINAL_SQL, viewDefinition.getOriginalSql()); + properties.put(PRESTO_MATERIALIZED_VIEW_STORAGE_SCHEMA, storageTableName.getSchemaName()); + properties.put(PRESTO_MATERIALIZED_VIEW_STORAGE_TABLE_NAME, storageTableName.getTableName()); + + String baseTablesStr = serializeSchemaTableNames(viewDefinition.getBaseTables()); + properties.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, baseTablesStr); + properties.put(PRESTO_MATERIALIZED_VIEW_COLUMN_MAPPINGS, serializeColumnMappings(viewDefinition.getColumnMappings())); + properties.put(PRESTO_MATERIALIZED_VIEW_OWNER, viewDefinition.getOwner() + .orElseThrow(() -> new PrestoException(INVALID_VIEW, "Materialized view owner is required"))); + properties.put(PRESTO_MATERIALIZED_VIEW_SECURITY_MODE, viewDefinition.getSecurityMode() + .orElseThrow(() -> new PrestoException(INVALID_VIEW, "Materialized view security mode is required (set legacy_materialized_views=false)")) + .name()); + + getStaleReadBehavior(materializedViewProperties) + .ifPresent(behavior -> properties.put(PRESTO_MATERIALIZED_VIEW_STALE_READ_BEHAVIOR, behavior.name())); + getStalenessWindow(materializedViewProperties) + .ifPresent(window -> 
properties.put(PRESTO_MATERIALIZED_VIEW_STALENESS_WINDOW, window.toString())); + MaterializedViewRefreshType refreshType = getRefreshType(materializedViewProperties); + properties.put(PRESTO_MATERIALIZED_VIEW_REFRESH_TYPE, refreshType.name()); + + for (SchemaTableName baseTable : viewDefinition.getBaseTables()) { + properties.put(getBaseTableViewPropertyName(baseTable), "0"); + } + + createIcebergView(session, viewName, viewMetadata.getColumns(), viewDefinition.getOriginalSql(), properties); + } + catch (Exception e) { + try { + dropStorageTable(session, storageTableName); + } + catch (Exception cleanupException) { + e.addSuppressed(cleanupException); + } + throw e; + } + } + catch (PrestoException e) { + if (e.getErrorCode() == NOT_SUPPORTED.toErrorCode()) { + throw new PrestoException(NOT_SUPPORTED, "Materialized views are not supported with this catalog type", e); + } + throw e; + } + } + + private void dropStorageTable(ConnectorSession session, SchemaTableName storageTableName) + { + ConnectorTableHandle storageTableHandle = getTableHandle(session, storageTableName); + if (storageTableHandle != null) { + dropTable(session, storageTableHandle); + } + } + + @Override + public List listMaterializedViews(ConnectorSession session, String schemaName) + { + ImmutableList.Builder materializedViews = ImmutableList.builder(); + + List views = listViews(session, Optional.of(schemaName)); + + for (SchemaTableName viewName : views) { + View icebergView = getIcebergView(session, viewName); + Map properties = icebergView.properties(); + if (properties.containsKey(PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION)) { + materializedViews.add(viewName); + } + } + + return materializedViews.build(); + } + + @Override + public Optional getMaterializedView(ConnectorSession session, SchemaTableName viewName) + { + try { + View icebergView = getIcebergView(session, viewName); + + Map viewProperties = icebergView.properties(); + String originalSql = 
viewProperties.get(PRESTO_MATERIALIZED_VIEW_ORIGINAL_SQL); + + if (originalSql == null) { + return Optional.empty(); + } + + // Validate format version + String formatVersion = getRequiredMaterializedViewProperty(viewProperties, PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION); + int version; + try { + version = Integer.parseInt(formatVersion); + } + catch (NumberFormatException e) { + throw new PrestoException(ICEBERG_INVALID_MATERIALIZED_VIEW, + format("Invalid materialized view format version: %s", formatVersion)); + } + + if (version != CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION) { + throw new PrestoException(ICEBERG_INVALID_MATERIALIZED_VIEW, + format("Materialized view format version %d is not supported by this version of Presto (current version: %d). Please upgrade Presto.", + version, CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION)); + } + + String baseTablesStr = getRequiredMaterializedViewProperty(viewProperties, PRESTO_MATERIALIZED_VIEW_BASE_TABLES); + List baseTables; + if (baseTablesStr.isEmpty()) { + baseTables = ImmutableList.of(); + } + else { + baseTables = deserializeSchemaTableNames(baseTablesStr); + } + + String columnMappingsJson = getRequiredMaterializedViewProperty(viewProperties, PRESTO_MATERIALIZED_VIEW_COLUMN_MAPPINGS); + List columnMappings = deserializeColumnMappings(columnMappingsJson); + + String storageSchema = getRequiredMaterializedViewProperty(viewProperties, PRESTO_MATERIALIZED_VIEW_STORAGE_SCHEMA); + String storageTableName = getRequiredMaterializedViewProperty(viewProperties, PRESTO_MATERIALIZED_VIEW_STORAGE_TABLE_NAME); + + String owner = getRequiredMaterializedViewProperty(viewProperties, PRESTO_MATERIALIZED_VIEW_OWNER); + ViewSecurity securityMode; + try { + securityMode = ViewSecurity.valueOf(getRequiredMaterializedViewProperty(viewProperties, PRESTO_MATERIALIZED_VIEW_SECURITY_MODE)); + } + catch (IllegalArgumentException | NullPointerException e) { + throw new PrestoException(ICEBERG_INVALID_MATERIALIZED_VIEW, "Invalid or missing 
materialized view security mode"); + } + + // Parse staleness config - staleness window defaults to 0s if behavior is set + Optional staleReadBehavior = getOptionalEnumProperty( + viewProperties, PRESTO_MATERIALIZED_VIEW_STALE_READ_BEHAVIOR, MaterializedViewStaleReadBehavior.class); + Optional stalenessWindow = getOptionalDurationProperty(viewProperties, PRESTO_MATERIALIZED_VIEW_STALENESS_WINDOW); + + Optional stalenessConfig = Optional.empty(); + if (staleReadBehavior.isPresent()) { + stalenessConfig = Optional.of(new MaterializedViewStalenessConfig( + staleReadBehavior.get(), + stalenessWindow.orElse(new Duration(0, TimeUnit.SECONDS)))); + } + + Optional refreshType = getOptionalEnumProperty( + viewProperties, PRESTO_MATERIALIZED_VIEW_REFRESH_TYPE, MaterializedViewRefreshType.class); + + return Optional.of(new MaterializedViewDefinition( + originalSql, + storageSchema, + storageTableName, + baseTables, + Optional.of(owner), + Optional.of(securityMode), + columnMappings, + ImmutableList.of(), + Optional.empty(), + stalenessConfig, + refreshType)); + } + catch (NoSuchViewException e) { + return Optional.empty(); + } + catch (PrestoException e) { + if (e.getErrorCode() == NOT_SUPPORTED.toErrorCode()) { + return Optional.empty(); + } + throw e; + } + } + + @Override + public void dropMaterializedView(ConnectorSession session, SchemaTableName viewName) + { + Optional definition = getMaterializedView(session, viewName); + + if (definition.isPresent()) { + dropIcebergView(session, viewName); + SchemaTableName storageTableName = new SchemaTableName( + definition.get().getSchema(), + definition.get().getTable()); + ConnectorTableHandle storageTableHandle = getTableHandle(session, storageTableName); + if (storageTableHandle != null) { + dropTable(session, storageTableHandle); + } + } + } + + @Override + public MaterializedViewStatus getMaterializedViewStatus( + ConnectorSession session, + SchemaTableName materializedViewName, + TupleDomain baseQueryDomain) + { + Optional 
definition = getMaterializedView(session, materializedViewName); + if (definition.isEmpty()) { + return new MaterializedViewStatus(NOT_MATERIALIZED, ImmutableMap.of()); + } + + View icebergView = getIcebergView(session, materializedViewName); + Map props = icebergView.properties(); + String lastRefreshSnapshotStr = props.get(PRESTO_MATERIALIZED_VIEW_LAST_REFRESH_SNAPSHOT_ID); + if (lastRefreshSnapshotStr == null) { + return new MaterializedViewStatus(NOT_MATERIALIZED, ImmutableMap.of()); + } + + SchemaTableName storageTableName = new SchemaTableName(definition.get().getSchema(), definition.get().getTable()); + Table storageTable = getIcebergTable(session, storageTableName); + long lastRefreshSnapshotId = parseLong(lastRefreshSnapshotStr); + Snapshot lastRefreshSnapshot = storageTable.snapshot(lastRefreshSnapshotId); + if (lastRefreshSnapshot == null) { + throw new PrestoException(ICEBERG_INVALID_MATERIALIZED_VIEW, + format("Storage table snapshot %d not found for materialized view %s. " + + "The snapshot may have been expired. Consider refreshing the view.", + lastRefreshSnapshotId, materializedViewName)); + } + Optional lastFreshTime = Optional.of(lastRefreshSnapshot.timestampMillis()); + + boolean isStale = false; + for (SchemaTableName baseTable : definition.get().getBaseTables()) { + Table baseIcebergTable = getIcebergTable(session, baseTable); + long currentSnapshotId = baseIcebergTable.currentSnapshot() != null + ? 
baseIcebergTable.currentSnapshot().snapshotId() + : 0L; + + String key = getBaseTableViewPropertyName(baseTable); + String recordedSnapshotStr = props.get(key); + if (recordedSnapshotStr == null) { + throw new PrestoException(ICEBERG_INVALID_MATERIALIZED_VIEW, + format("Missing base table snapshot property for %s in materialized view %s", baseTable, materializedViewName)); + } + long recordedSnapshotId = parseLong(recordedSnapshotStr); + + if (currentSnapshotId != recordedSnapshotId) { + isStale = true; + break; + } + } + + if (isStale) { + return new MaterializedViewStatus( + PARTIALLY_MATERIALIZED, + ImmutableMap.of(), + lastFreshTime); + } + + return new MaterializedViewStatus( + FULLY_MATERIALIZED, + ImmutableMap.of(), + lastFreshTime); + } + + @Override + public ConnectorInsertTableHandle beginRefreshMaterializedView( + ConnectorSession session, + ConnectorTableHandle tableHandle) + { + IcebergTableHandle icebergTableHandle = (IcebergTableHandle) tableHandle; + + if (icebergTableHandle.getMaterializedViewName().isEmpty()) { + throw new IllegalStateException(format( + "beginRefreshMaterializedView called on non-materialized view table: %s", + icebergTableHandle.getSchemaTableName())); + } + + SchemaTableName storageTableName = icebergTableHandle.getSchemaTableName(); + IcebergTableHandle storageTableHandle = getTableHandle(session, storageTableName); + Table storageTable = getIcebergTable(session, storageTableName); + + transaction = storageTable.newTransaction(); + + transaction.newDelete().deleteFromRowFilter(alwaysTrue()).commit(); + + SchemaTableName materializedViewName = icebergTableHandle.getMaterializedViewName().get(); + + return new IcebergInsertTableHandle( + storageTableHandle.getSchemaName(), + storageTableHandle.getIcebergTableName(), + toPrestoSchema(storageTable.schema(), typeManager), + toPrestoPartitionSpec(storageTable.spec(), typeManager), + getColumns(storageTable.schema(), storageTable.spec(), typeManager), + storageTable.location(), + 
getFileFormat(storageTable), + getCompressionCodec(session), + storageTable.properties(), + getSupportedSortFields(storageTable.schema(), storageTable.sortOrder()), + Optional.of(materializedViewName)); + } + + @Override + public Optional finishRefreshMaterializedView( + ConnectorSession session, + ConnectorInsertTableHandle insertHandle, + Collection fragments, + Collection computedStatistics) + { + Optional result = finishInsert(session, insertHandle, fragments, computedStatistics); + + IcebergInsertTableHandle icebergInsertHandle = (IcebergInsertTableHandle) insertHandle; + + icebergInsertHandle.getMaterializedViewName().ifPresent(materializedViewName -> { + SchemaTableName storageTableName = new SchemaTableName( + icebergInsertHandle.getSchemaName(), + icebergInsertHandle.getTableName().getTableName()); + + Table storageTable = getIcebergTable(session, storageTableName); + long newSnapshotId = storageTable.currentSnapshot() != null + ? storageTable.currentSnapshot().snapshotId() + : 0L; + + Optional definition = getMaterializedView(session, materializedViewName); + Map properties = new HashMap<>(); + properties.put(PRESTO_MATERIALIZED_VIEW_LAST_REFRESH_SNAPSHOT_ID, String.valueOf(newSnapshotId)); + + if (definition.isPresent()) { + for (SchemaTableName baseTable : definition.get().getBaseTables()) { + try { + Table baseIcebergTable = getIcebergTable(session, baseTable); + long baseSnapshotId = baseIcebergTable.currentSnapshot() != null + ? 
baseIcebergTable.currentSnapshot().snapshotId() + : 0L; + String key = getBaseTableViewPropertyName(baseTable); + properties.put(key, baseSnapshotId + ""); + } + catch (Exception e) { + log.warn(e, "Failed to capture snapshot for base table %s during refresh of materialized view %s", baseTable, materializedViewName); + } + } + } + + updateIcebergViewProperties(session, materializedViewName, properties); + }); + + return result; + } + + private SchemaTableName getStorageTableName(ConnectorSession session, SchemaTableName viewName, Map properties) + { + String tableName = getStorageTable(properties).orElseGet(() -> { + // Generate default storage table name using prefix + return getMaterializedViewStoragePrefix(session) + viewName.getTableName(); + }); + String schema = getStorageSchema(properties) + .orElse(viewName.getSchemaName()); + return new SchemaTableName(schema, tableName); + } + + private String serializeColumnMappings(List columnMappings) + { + return columnMappingsCodec.toJson(columnMappings); + } + + private List deserializeColumnMappings(String json) + { + return columnMappingsCodec.fromJson(json); + } + + private String serializeSchemaTableNames(List schemaTableNames) + { + return schemaTableNamesCodec.toJson(schemaTableNames); + } + + private List deserializeSchemaTableNames(String json) + { + try { + return schemaTableNamesCodec.fromJson(json); + } + catch (IllegalArgumentException e) { + throw new PrestoException(ICEBERG_INVALID_MATERIALIZED_VIEW, + format("Invalid base table name format: %s. 
Cause: %s", json, e.getMessage()), e); + } + } + + private static String getBaseTableViewPropertyName(SchemaTableName baseTable) + { + return format("%s%s.%s", PRESTO_MATERIALIZED_VIEW_BASE_SNAPSHOT_PREFIX, baseTable.getSchemaName(), baseTable.getTableName()); + } + + private static String getRequiredMaterializedViewProperty(Map viewProperties, String propertyKey) + { + String value = viewProperties.get(propertyKey); + if (value == null) { + throw new PrestoException(ICEBERG_INVALID_MATERIALIZED_VIEW, format("Materialized view missing required property: %s", propertyKey)); + } + return value; + } + + private static > Optional getOptionalEnumProperty(Map viewProperties, String propertyKey, Class enumClass) + { + String value = viewProperties.get(propertyKey); + if (value == null) { + return Optional.empty(); + } + try { + return Optional.of(Enum.valueOf(enumClass, value)); + } + catch (IllegalArgumentException e) { + throw new PrestoException(ICEBERG_INVALID_MATERIALIZED_VIEW, format("Invalid materialized view property %s: %s", propertyKey, value)); + } + } + + private static Optional getOptionalDurationProperty(Map viewProperties, String propertyKey) + { + String value = viewProperties.get(propertyKey); + if (value == null) { + return Optional.empty(); + } + try { + return Optional.of(Duration.valueOf(value)); + } + catch (IllegalArgumentException e) { + throw new PrestoException(ICEBERG_INVALID_MATERIALIZED_VIEW, format("Invalid materialized view property %s: %s", propertyKey, value)); + } + } + + private boolean viewExists(ConnectorSession session, ConnectorTableMetadata viewMetadata) + { + try { + getIcebergView(session, viewMetadata.getTable()); + return true; + } + catch (NoSuchViewException e) { + return false; + } + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergColumnHandle.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergColumnHandle.java index fa912f60a3f8c..3afa99d710d72 100644 --- 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergColumnHandle.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergColumnHandle.java @@ -34,7 +34,10 @@ import static com.facebook.presto.iceberg.ColumnIdentity.createColumnIdentity; import static com.facebook.presto.iceberg.ColumnIdentity.primitiveColumnIdentity; import static com.facebook.presto.iceberg.IcebergMetadataColumn.DATA_SEQUENCE_NUMBER; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.DELETE_FILE_PATH; import static com.facebook.presto.iceberg.IcebergMetadataColumn.FILE_PATH; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.IS_DELETED; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.MERGE_TARGET_ROW_ID_DATA; import static com.facebook.presto.iceberg.IcebergMetadataColumn.UPDATE_ROW_DATA; import static com.facebook.presto.iceberg.TypeConverter.toPrestoType; import static com.google.common.base.Preconditions.checkArgument; @@ -50,6 +53,10 @@ public class IcebergColumnHandle public static final ColumnMetadata PATH_COLUMN_METADATA = getColumnMetadata(FILE_PATH); public static final IcebergColumnHandle DATA_SEQUENCE_NUMBER_COLUMN_HANDLE = getIcebergColumnHandle(DATA_SEQUENCE_NUMBER); public static final ColumnMetadata DATA_SEQUENCE_NUMBER_COLUMN_METADATA = getColumnMetadata(DATA_SEQUENCE_NUMBER); + public static final IcebergColumnHandle IS_DELETED_COLUMN_HANDLE = getIcebergColumnHandle(IS_DELETED); + public static final ColumnMetadata IS_DELETED_COLUMN_METADATA = getColumnMetadata(IS_DELETED); + public static final IcebergColumnHandle DELETE_FILE_PATH_COLUMN_HANDLE = getIcebergColumnHandle(DELETE_FILE_PATH); + public static final ColumnMetadata DELETE_FILE_PATH_COLUMN_METADATA = getColumnMetadata(DELETE_FILE_PATH); private final ColumnIdentity columnIdentity; private final Type type; @@ -103,6 +110,12 @@ public boolean isUpdateRowIdColumn() return columnIdentity.getId() == UPDATE_ROW_DATA.getId(); } + @JsonIgnore + public 
boolean isMergeTargetTableRowIdColumn() + { + return columnIdentity.getId() == MERGE_TARGET_ROW_ID_DATA.getId(); + } + @Override public ColumnHandle withRequiredSubfields(List subfields) { @@ -180,6 +193,16 @@ public boolean isDataSequenceNumberColumn() return getColumnIdentity().getId() == DATA_SEQUENCE_NUMBER.getId(); } + public boolean isDeletedColumn() + { + return getColumnIdentity().getId() == IS_DELETED.getId(); + } + + public boolean isDeleteFilePathColumn() + { + return getColumnIdentity().getId() == DELETE_FILE_PATH.getId(); + } + public static IcebergColumnHandle primitiveIcebergColumnHandle(int id, String name, Type type, Optional comment) { return new IcebergColumnHandle(primitiveColumnIdentity(id, name), type, comment, REGULAR); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java index 129a3fda539f8..e86b26146da86 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergCommonModule.java @@ -41,7 +41,6 @@ import com.facebook.presto.hive.gcs.GcsConfigurationInitializer; import com.facebook.presto.hive.gcs.HiveGcsConfig; import com.facebook.presto.hive.gcs.HiveGcsConfigurationInitializer; -import com.facebook.presto.hive.metastore.InvalidateMetastoreCacheProcedure; import com.facebook.presto.iceberg.nessie.IcebergNessieConfig; import com.facebook.presto.iceberg.optimizer.IcebergPlanOptimizerProvider; import com.facebook.presto.iceberg.procedure.ExpireSnapshotsProcedure; @@ -49,6 +48,8 @@ import com.facebook.presto.iceberg.procedure.ManifestFileCacheInvalidationProcedure; import com.facebook.presto.iceberg.procedure.RegisterTableProcedure; import com.facebook.presto.iceberg.procedure.RemoveOrphanFiles; +import com.facebook.presto.iceberg.procedure.RewriteDataFilesProcedure; +import 
com.facebook.presto.iceberg.procedure.RewriteManifestsProcedure; import com.facebook.presto.iceberg.procedure.RollbackToSnapshotProcedure; import com.facebook.presto.iceberg.procedure.RollbackToTimestampProcedure; import com.facebook.presto.iceberg.procedure.SetCurrentSnapshotProcedure; @@ -78,23 +79,25 @@ import com.facebook.presto.parquet.cache.ParquetCacheConfig; import com.facebook.presto.parquet.cache.ParquetFileMetadata; import com.facebook.presto.parquet.cache.ParquetMetadataSource; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; -import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.BaseProcedure; import com.facebook.presto.spi.statistics.ColumnStatistics; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.inject.Binder; import com.google.inject.Provides; import com.google.inject.Scopes; +import com.google.inject.TypeLiteral; import com.google.inject.multibindings.Multibinder; +import jakarta.inject.Singleton; import org.weakref.jmx.MBeanExporter; -import javax.inject.Singleton; - import java.nio.ByteBuffer; import java.time.Duration; import java.util.Optional; @@ -155,6 +158,7 @@ protected void setup(Binder binder) newOptionalBinder(binder, IcebergNessieConfig.class); // bind optional Nessie config to IcebergSessionProperties binder.bind(IcebergTableProperties.class).in(Scopes.SINGLETON); + binder.bind(IcebergMaterializedViewProperties.class).in(Scopes.SINGLETON); binder.bind(ConnectorSplitManager.class).to(IcebergSplitManager.class).in(Scopes.SINGLETON); 
newExporter(binder).export(ConnectorSplitManager.class).as(generatedNameOf(IcebergSplitManager.class, connectorId)); @@ -168,6 +172,8 @@ protected void setup(Binder binder) configBinder(binder).bindConfig(ParquetFileWriterConfig.class); jsonCodecBinder(binder).bindJsonCodec(CommitTaskData.class); + jsonCodecBinder(binder).bindListJsonCodec(MaterializedViewDefinition.ColumnMapping.class); + jsonCodecBinder(binder).bindListJsonCodec(SchemaTableName.class); binder.bind(FileFormatDataSourceStats.class).in(Scopes.SINGLETON); newExporter(binder).export(FileFormatDataSourceStats.class).withGeneratedName(); @@ -175,7 +181,7 @@ protected void setup(Binder binder) binder.bind(IcebergFileWriterFactory.class).in(Scopes.SINGLETON); newExporter(binder).export(IcebergFileWriterFactory.class).withGeneratedName(); - Multibinder procedures = newSetBinder(binder, Procedure.class); + Multibinder> procedures = newSetBinder(binder, new TypeLiteral>() {}); procedures.addBinding().toProvider(RollbackToSnapshotProcedure.class).in(Scopes.SINGLETON); procedures.addBinding().toProvider(RollbackToTimestampProcedure.class).in(Scopes.SINGLETON); procedures.addBinding().toProvider(RegisterTableProcedure.class).in(Scopes.SINGLETON); @@ -187,10 +193,8 @@ protected void setup(Binder binder) procedures.addBinding().toProvider(SetTablePropertyProcedure.class).in(Scopes.SINGLETON); procedures.addBinding().toProvider(StatisticsFileCacheInvalidationProcedure.class).in(Scopes.SINGLETON); procedures.addBinding().toProvider(ManifestFileCacheInvalidationProcedure.class).in(Scopes.SINGLETON); - - if (buildConfigObject(MetastoreClientConfig.class).isInvalidateMetastoreCacheProcedureEnabled()) { - procedures.addBinding().toProvider(InvalidateMetastoreCacheProcedure.class).in(Scopes.SINGLETON); - } + procedures.addBinding().toProvider(RewriteDataFilesProcedure.class).in(Scopes.SINGLETON); + procedures.addBinding().toProvider(RewriteManifestsProcedure.class).in(Scopes.SINGLETON); // for orc 
binder.bind(EncryptionLibrary.class).annotatedWith(HiveDwrfEncryptionProvider.ForCryptoService.class).to(UnsupportedEncryptionLibrary.class).in(Scopes.SINGLETON); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java index a90c7c888e6bf..346cc3c890e74 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConfig.java @@ -15,28 +15,28 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.configuration.LegacyConfig; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.hive.HiveCompressionCodec; import com.facebook.presto.spi.statistics.ColumnStatisticType; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; +import jakarta.validation.constraints.DecimalMax; +import jakarta.validation.constraints.DecimalMin; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import org.apache.iceberg.hadoop.HadoopFileIO; -import javax.validation.constraints.DecimalMax; -import javax.validation.constraints.DecimalMin; -import javax.validation.constraints.Max; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; - import java.util.EnumSet; import java.util.List; -import static com.facebook.presto.hive.HiveCompressionCodec.GZIP; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctDataSize; +import static com.facebook.presto.hive.HiveCompressionCodec.ZSTD; import static com.facebook.presto.iceberg.CatalogType.HIVE; import static com.facebook.presto.iceberg.IcebergFileFormat.PARQUET; import static 
com.facebook.presto.iceberg.util.StatisticsUtil.decodeMergeFlags; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.succinctDataSize; import static org.apache.iceberg.CatalogProperties.IO_MANIFEST_CACHE_EXPIRATION_INTERVAL_MS_DEFAULT; import static org.apache.iceberg.CatalogProperties.IO_MANIFEST_CACHE_MAX_CONTENT_LENGTH_DEFAULT; import static org.apache.iceberg.CatalogProperties.IO_MANIFEST_CACHE_MAX_TOTAL_BYTES_DEFAULT; @@ -47,7 +47,7 @@ public class IcebergConfig { private IcebergFileFormat fileFormat = PARQUET; - private HiveCompressionCodec compressionCodec = GZIP; + private HiveCompressionCodec compressionCodec = ZSTD; private CatalogType catalogType = HIVE; private String catalogWarehouse; private String catalogWarehouseDataDir; @@ -60,6 +60,7 @@ public class IcebergConfig private double statisticSnapshotRecordDifferenceWeight; private boolean pushdownFilterEnabled; private boolean deleteAsJoinRewriteEnabled = true; + private int deleteAsJoinRewriteMaxDeleteColumns = 400; private int rowsForMetadataOptimizationThreshold = 1000; private int metadataPreviousVersionsMax = METADATA_PREVIOUS_VERSIONS_MAX_DEFAULT; private boolean metadataDeleteAfterCommit = METADATA_DELETE_AFTER_COMMIT_ENABLED_DEFAULT; @@ -75,6 +76,7 @@ public class IcebergConfig private DataSize manifestCacheMaxChunkSize = succinctDataSize(2, MEGABYTE); private int splitManagerThreads = Runtime.getRuntime().availableProcessors(); private DataSize maxStatisticsFileCacheSize = succinctDataSize(256, MEGABYTE); + private String materializedViewStoragePrefix = "__mv_storage__"; @NotNull public FileFormat getFileFormat() @@ -267,19 +269,39 @@ public boolean isPushdownFilterEnabled() return pushdownFilterEnabled; } - @Config("iceberg.delete-as-join-rewrite-enabled") - @ConfigDescription("When enabled, equality delete row filtering will be implemented by rewriting the query plan to join with the delete keys.") + @LegacyConfig(value = 
"iceberg.delete-as-join-rewrite-enabled") + @Config("deprecated.iceberg.delete-as-join-rewrite-enabled") + @ConfigDescription("When enabled, equality delete row filtering will be implemented by rewriting the query plan to join with the delete keys. " + + "Deprecated: Set 'iceberg.delete-as-join-rewrite-max-delete-columns' to 0 to control the enabling of this feature. This will be removed in a future release.") + @Deprecated public IcebergConfig setDeleteAsJoinRewriteEnabled(boolean deleteAsJoinPushdownEnabled) { this.deleteAsJoinRewriteEnabled = deleteAsJoinPushdownEnabled; return this; } + @Deprecated public boolean isDeleteAsJoinRewriteEnabled() { return deleteAsJoinRewriteEnabled; } + @Config("iceberg.delete-as-join-rewrite-max-delete-columns") + @ConfigDescription("The maximum number of columns that can be used in a delete as join rewrite. " + + "If the number of columns exceeds this value, the delete as join rewrite will not be applied.") + @Min(0) + @Max(400) + public IcebergConfig setDeleteAsJoinRewriteMaxDeleteColumns(int deleteAsJoinRewriteMaxDeleteColumns) + { + this.deleteAsJoinRewriteMaxDeleteColumns = deleteAsJoinRewriteMaxDeleteColumns; + return this; + } + + public int getDeleteAsJoinRewriteMaxDeleteColumns() + { + return deleteAsJoinRewriteMaxDeleteColumns; + } + @Config("iceberg.rows-for-metadata-optimization-threshold") @ConfigDescription("The max partitions number to utilize metadata optimization. 
0 means skip the metadata optimization directly.") public IcebergConfig setRowsForMetadataOptimizationThreshold(int rowsForMetadataOptimizationThreshold) @@ -458,4 +480,20 @@ public IcebergConfig setStatisticsKllSketchKParameter(int kllSketchKParameter) this.statisticsKllSketchKParameter = kllSketchKParameter; return this; } + + @NotNull + public String getMaterializedViewStoragePrefix() + { + return materializedViewStoragePrefix; + } + + @Config("iceberg.materialized-view-storage-prefix") + @ConfigDescription("Default prefix for generated materialized view storage table names. " + + "This is only used when the storage_table table property is not explicitly set. " + + "When a custom table name is provided, it takes precedence over this prefix.") + public IcebergConfig setMaterializedViewStoragePrefix(String materializedViewStoragePrefix) + { + this.materializedViewStoragePrefix = materializedViewStoragePrefix; + return this; + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConnector.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConnector.java index b4976dee3d2d3..c398001e6a75b 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConnector.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergConnector.java @@ -15,6 +15,8 @@ import com.facebook.airlift.bootstrap.LifeCycleManager; import com.facebook.presto.hive.HiveTransactionHandle; +import com.facebook.presto.iceberg.function.IcebergBucketFunction; +import com.facebook.presto.iceberg.function.changelog.ApplyChangelogFunction; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.Connector; @@ -29,6 +31,8 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import 
com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorMetadata; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; @@ -37,6 +41,7 @@ import java.util.List; import java.util.Set; +import java.util.stream.Collectors; import static com.facebook.presto.spi.connector.ConnectorCapabilities.NOT_NULL_COLUMN_CONSTRAINT; import static com.facebook.presto.spi.connector.EmptyConnectorCommitHandle.INSTANCE; @@ -59,9 +64,10 @@ public class IcebergConnector private final List> sessionProperties; private final List> schemaProperties; private final List> tableProperties; + private final List> materializedViewProperties; private final List> columnProperties; private final ConnectorAccessControl accessControl; - private final Set procedures; + private final Set> procedures; private final ConnectorPlanOptimizerProvider planOptimizerProvider; public IcebergConnector( @@ -76,9 +82,10 @@ public IcebergConnector( List> sessionProperties, List> schemaProperties, List> tableProperties, + List> materializedViewProperties, List> columnProperties, ConnectorAccessControl accessControl, - Set procedures, + Set> procedures, ConnectorPlanOptimizerProvider planOptimizerProvider) { this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); @@ -92,6 +99,7 @@ public IcebergConnector( this.sessionProperties = ImmutableList.copyOf(requireNonNull(sessionProperties, "sessionProperties is null")); this.schemaProperties = ImmutableList.copyOf(requireNonNull(schemaProperties, "schemaProperties is null")); this.tableProperties = ImmutableList.copyOf(requireNonNull(tableProperties, "tableProperties is null")); + this.materializedViewProperties = ImmutableList.copyOf(requireNonNull(materializedViewProperties, 
"materializedViewProperties is null")); this.columnProperties = ImmutableList.copyOf(requireNonNull(columnProperties, "columnProperties is null")); this.accessControl = requireNonNull(accessControl, "accessControl is null"); this.procedures = requireNonNull(procedures, "procedures is null"); @@ -150,7 +158,13 @@ public Set getSystemTables() @Override public Set getProcedures() { - return procedures; + return getProcedures(Procedure.class); + } + + @Override + public Set getDistributedProcedures() + { + return getProcedures(DistributedProcedure.class); } @Override @@ -177,6 +191,12 @@ public List> getColumnProperties() return columnProperties; } + @Override + public List> getMaterializedViewProperties() + { + return materializedViewProperties; + } + @Override public ConnectorAccessControl getAccessControl() { @@ -218,4 +238,21 @@ public ConnectorPlanOptimizerProvider getConnectorPlanOptimizerProvider() { return planOptimizerProvider; } + + @Override + public Set> getSystemFunctions() + { + return ImmutableSet.>builder() + .add(ApplyChangelogFunction.class) + .add(IcebergBucketFunction.class) + .add(IcebergBucketFunction.Bucket.class) + .build(); + } + + private > Set getProcedures(Class clazz) + { + return procedures.stream().filter(clazz::isInstance) + .map(clazz::cast) + .collect(Collectors.toSet()); + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergDistributedProcedureHandle.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergDistributedProcedureHandle.java new file mode 100644 index 0000000000000..28eec12680077 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergDistributedProcedureHandle.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.hive.HiveCompressionCodec; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class IcebergDistributedProcedureHandle + extends IcebergWritableTableHandle + implements ConnectorDistributedProcedureHandle +{ + private final IcebergTableLayoutHandle tableLayoutHandle; + private final Map relevantData; + + @JsonCreator + public IcebergDistributedProcedureHandle( + @JsonProperty("schemaName") String schemaName, + @JsonProperty("tableName") IcebergTableName tableName, + @JsonProperty("schema") PrestoIcebergSchema schema, + @JsonProperty("partitionSpec") PrestoIcebergPartitionSpec partitionSpec, + @JsonProperty("inputColumns") List inputColumns, + @JsonProperty("outputPath") String outputPath, + @JsonProperty("fileFormat") FileFormat fileFormat, + @JsonProperty("compressionCodec") HiveCompressionCodec compressionCodec, + @JsonProperty("storageProperties") Map storageProperties, + @JsonProperty("tableLayoutHandle") IcebergTableLayoutHandle tableLayoutHandle, + @JsonProperty("sortOrder") List sortOrder, + @JsonProperty("relevantData") Map relevantData) + { + super( + schemaName, + tableName, + schema, + partitionSpec, + inputColumns, + outputPath, + fileFormat, + compressionCodec, + storageProperties, + sortOrder); + this.tableLayoutHandle = 
requireNonNull(tableLayoutHandle, "tableLayoutHandle is null"); + this.relevantData = requireNonNull(relevantData, "relevantData is null"); + } + + @JsonProperty + public IcebergTableLayoutHandle getTableLayoutHandle() + { + return tableLayoutHandle; + } + + @JsonProperty + public Map getRelevantData() + { + return relevantData; + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergErrorCode.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergErrorCode.java index c833a65eaf62f..3a5f76ab5ebad 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergErrorCode.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergErrorCode.java @@ -24,7 +24,6 @@ public enum IcebergErrorCode implements ErrorCodeSupplier { - ICEBERG_UNKNOWN_TABLE_TYPE(0, EXTERNAL), ICEBERG_INVALID_METADATA(1, EXTERNAL), ICEBERG_TOO_MANY_OPEN_PARTITIONS(2, USER_ERROR), ICEBERG_INVALID_PARTITION_VALUE(3, EXTERNAL), @@ -40,7 +39,10 @@ public enum IcebergErrorCode ICEBERG_INVALID_FORMAT_VERSION(14, USER_ERROR), ICEBERG_UNKNOWN_MANIFEST_TYPE(15, EXTERNAL), ICEBERG_COMMIT_ERROR(16, EXTERNAL), - ICEBERG_MISSING_COLUMN(17, EXTERNAL); + ICEBERG_MISSING_COLUMN(17, EXTERNAL), + ICEBERG_INVALID_MATERIALIZED_VIEW(18, EXTERNAL), + ICEBERG_INVALID_SPEC_ID(19, EXTERNAL), + /**/; private final ErrorCode errorCode; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergFileWriterFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergFileWriterFactory.java index 462deae4f146d..62c02c881a950 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergFileWriterFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergFileWriterFactory.java @@ -33,6 +33,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableMap; +import jakarta.inject.Inject; import 
org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.JobConf; @@ -40,8 +41,6 @@ import org.apache.iceberg.Schema; import org.apache.iceberg.types.Types; -import javax.inject.Inject; - import java.io.IOException; import java.util.List; import java.util.Optional; @@ -161,7 +160,7 @@ private IcebergFileWriter createParquetWriter( makeTypeMap(fileColumnTypes, fileColumnNames), parquetWriterOptions, IntStream.range(0, fileColumnNames.size()).toArray(), - getCompressionCodec(session).getParquetCompressionCodec().get(), + getCompressionCodec(session).getParquetCompressionCodec(), outputPath, hdfsEnvironment, hdfsContext, diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java index 199939c6b7985..fbb24b55b577d 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHandleResolver.java @@ -16,8 +16,10 @@ import com.facebook.presto.hive.HiveTransactionHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; @@ -63,12 +65,23 @@ public Class getInsertTableHandleClass() return IcebergInsertTableHandle.class; } + public Class getMergeTableHandleClass() + { + return IcebergMergeTableHandle.class; + } + @Override public Class getDeleteTableHandleClass() { return IcebergTableHandle.class; } + @Override + public Class getDistributedProcedureHandleClass() 
+ { + return IcebergDistributedProcedureHandle.class; + } + @Override public Class getTransactionHandleClass() { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java index 66c026595d2bf..4bf0768bf66bf 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java @@ -24,6 +24,7 @@ import com.facebook.presto.hive.HiveTypeTranslator; import com.facebook.presto.hive.NodeVersion; import com.facebook.presto.hive.TableAlreadyExistsException; +import com.facebook.presto.hive.UnknownTableTypeException; import com.facebook.presto.hive.ViewAlreadyExistsException; import com.facebook.presto.hive.metastore.Column; import com.facebook.presto.hive.metastore.Database; @@ -36,14 +37,17 @@ import com.facebook.presto.hive.metastore.Table; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorNewTableLayout; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSystemConfig; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableLayoutHandle; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.ConnectorViewDefinition; import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaNotFoundException; import com.facebook.presto.spi.SchemaTableName; @@ -52,6 +56,7 @@ import com.facebook.presto.spi.ViewNotFoundException; import com.facebook.presto.spi.function.StandardFunctionResolution; import 
com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; @@ -117,9 +122,12 @@ import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_METADATA; import static com.facebook.presto.iceberg.IcebergSessionProperties.getCompressionCodec; import static com.facebook.presto.iceberg.IcebergSessionProperties.getHiveStatisticsMergeStrategy; +import static com.facebook.presto.iceberg.IcebergTableProperties.getPartitioning; +import static com.facebook.presto.iceberg.IcebergTableProperties.getSortOrder; +import static com.facebook.presto.iceberg.IcebergTableProperties.getTableLocation; import static com.facebook.presto.iceberg.IcebergTableType.DATA; import static com.facebook.presto.iceberg.IcebergUtil.createIcebergViewProperties; -import static com.facebook.presto.iceberg.IcebergUtil.getColumns; +import static com.facebook.presto.iceberg.IcebergUtil.getColumnsForWrite; import static com.facebook.presto.iceberg.IcebergUtil.getHiveIcebergTable; import static com.facebook.presto.iceberg.IcebergUtil.isIcebergTable; import static com.facebook.presto.iceberg.IcebergUtil.populateTableProperties; @@ -156,33 +164,43 @@ public class IcebergHiveMetadata { public static final int MAXIMUM_PER_QUERY_TABLE_CACHE_SIZE = 1000; + private final IcebergCatalogName catalogName; private final ExtendedHiveMetastore metastore; private final HdfsEnvironment hdfsEnvironment; private final DateTimeZone timeZone = DateTimeZone.forTimeZone(TimeZone.getTimeZone(ZoneId.of(TimeZone.getDefault().getID()))); - private final IcebergHiveTableOperationsConfig hiveTableOeprationsConfig; + private final IcebergHiveTableOperationsConfig hiveTableOperationsConfig; + private final ConnectorSystemConfig connectorSystemConfig; private final Cache> tableCache; 
private final ManifestFileCache manifestFileCache; public IcebergHiveMetadata( + IcebergCatalogName catalogName, ExtendedHiveMetastore metastore, HdfsEnvironment hdfsEnvironment, TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, + JsonCodec> columnMappingsCodec, + JsonCodec> schemaTableNamesCodec, NodeVersion nodeVersion, FilterStatsCalculatorService filterStatsCalculatorService, - IcebergHiveTableOperationsConfig hiveTableOeprationsConfig, + IcebergHiveTableOperationsConfig hiveTableOperationsConfig, StatisticsFileCache statisticsFileCache, ManifestFileCache manifestFileCache, - IcebergTableProperties tableProperties) + IcebergTableProperties tableProperties, + ConnectorSystemConfig connectorSystemConfig) { - super(typeManager, functionResolution, rowExpressionService, commitTaskCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); + super(typeManager, procedureRegistry, functionResolution, rowExpressionService, commitTaskCodec, columnMappingsCodec, schemaTableNamesCodec, + nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); + this.catalogName = requireNonNull(catalogName, "catalogName is null"); this.metastore = requireNonNull(metastore, "metastore is null"); this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); - this.hiveTableOeprationsConfig = requireNonNull(hiveTableOeprationsConfig, "hiveTableOperationsConfig is null"); + this.hiveTableOperationsConfig = requireNonNull(hiveTableOperationsConfig, "hiveTableOperationsConfig is null"); this.tableCache = CacheBuilder.newBuilder().maximumSize(MAXIMUM_PER_QUERY_TABLE_CACHE_SIZE).build(); this.manifestFileCache = requireNonNull(manifestFileCache, "manifestFileCache is null"); + this.connectorSystemConfig = requireNonNull(connectorSystemConfig, "connectorSystemConfig is null"); } public 
ExtendedHiveMetastore getMetastore() @@ -206,7 +224,7 @@ public boolean schemaExists(ConnectorSession session, String schemaName) @Override protected org.apache.iceberg.Table getRawIcebergTable(ConnectorSession session, SchemaTableName schemaTableName) { - return getHiveIcebergTable(metastore, hdfsEnvironment, hiveTableOeprationsConfig, manifestFileCache, session, schemaTableName); + return getHiveIcebergTable(metastore, hdfsEnvironment, hiveTableOperationsConfig, manifestFileCache, session, catalogName, schemaTableName); } @Override @@ -223,7 +241,7 @@ protected boolean tableExists(ConnectorSession session, SchemaTableName schemaTa return false; } if (!isIcebergTable(hiveTable.get())) { - throw new UnknownTableTypeException(schemaTableName); + throw new UnknownTableTypeException("Not an Iceberg table: " + schemaTableName); } return true; } @@ -336,14 +354,14 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con Schema schema = toIcebergSchema(tableMetadata.getColumns()); - PartitionSpec partitionSpec = parsePartitionFields(schema, tableProperties.getPartitioning(tableMetadata.getProperties())); + PartitionSpec partitionSpec = parsePartitionFields(schema, getPartitioning(tableMetadata.getProperties())); MetastoreContext metastoreContext = getMetastoreContext(session); Database database = metastore.getDatabase(metastoreContext, schemaName) .orElseThrow(() -> new SchemaNotFoundException(schemaName)); HdfsContext hdfsContext = new HdfsContext(session, schemaName, tableName); - String targetPath = tableProperties.getTableLocation(tableMetadata.getProperties()); + String targetPath = getTableLocation(tableMetadata.getProperties()); if (targetPath == null) { Optional location = database.getLocation(); if (!location.isPresent() || location.get().isEmpty()) { @@ -360,7 +378,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con getMetastoreContext(session), hdfsEnvironment, hdfsContext, - 
hiveTableOeprationsConfig, + hiveTableOperationsConfig, manifestFileCache, schemaName, tableName, @@ -369,7 +387,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con if (operations.current() != null) { throw new TableAlreadyExistsException(schemaTableName); } - SortOrder sortOrder = parseSortFields(schema, tableProperties.getSortOrder(tableMetadata.getProperties())); + SortOrder sortOrder = parseSortFields(schema, getSortOrder(tableMetadata.getProperties())); FileFormat fileFormat = tableProperties.getFileFormat(session, tableMetadata.getProperties()); TableMetadata metadata = newTableMetadata(schema, partitionSpec, sortOrder, targetPath, populateTableProperties(this, tableMetadata, tableProperties, fileFormat, session)); transaction = createTableTransaction(tableName, operations, metadata); @@ -379,7 +397,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con new IcebergTableName(tableName, DATA, Optional.empty(), Optional.empty()), toPrestoSchema(metadata.schema(), typeManager), toPrestoPartitionSpec(metadata.spec(), typeManager), - getColumns(metadata.schema(), metadata.spec(), typeManager), + getColumnsForWrite(metadata.schema(), metadata.spec(), typeManager), targetPath, fileFormat, getCompressionCodec(session), @@ -459,18 +477,23 @@ public List listViews(ConnectorSession session, Optional listMaterializedViews(ConnectorSession session, String schemaName) + { + return ImmutableList.of(); + } + @Override public Map getViews(ConnectorSession session, SchemaTablePrefix prefix) { ImmutableMap.Builder views = ImmutableMap.builder(); List tableNames; - if (prefix.getTableName() != null) { + if (prefix.getSchemaName() != null && prefix.getTableName() != null) { tableNames = ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName())); } else { - tableNames = listViews(session, Optional.of(prefix.getSchemaName())); + tableNames = listViews(session, 
Optional.ofNullable(prefix.getSchemaName())); } - MetastoreContext metastoreContext = getMetastoreContext(session); for (SchemaTableName schemaTableName : tableNames) { Optional
table = getHiveTable(session, schemaTableName); if (table.isPresent() && isPrestoView(table.get())) { @@ -549,7 +572,8 @@ public TableStatisticsMetadata getStatisticsCollectionMetadata(ConnectorSession Set supportedStatistics = ImmutableSet.builder() .addAll(hiveColumnStatistics) // iceberg table-supported statistics - .addAll(super.getStatisticsCollectionMetadata(session, tableMetadata).getColumnStatistics()) + .addAll(!connectorSystemConfig.isNativeExecution() ? + super.getStatisticsCollectionMetadata(session, tableMetadata).getColumnStatistics() : ImmutableSet.of()) .build(); Set tableStatistics = ImmutableSet.of(ROW_COUNT); return new TableStatisticsMetadata(supportedStatistics, tableStatistics, emptyList()); @@ -681,4 +705,30 @@ public void unregisterTable(ConnectorSession clientSession, SchemaTableName sche MetastoreContext metastoreContext = getMetastoreContext(clientSession); metastore.dropTableFromMetastore(metastoreContext, schemaTableName.getSchemaName(), schemaTableName.getTableName()); } + + @Override + protected void createIcebergView( + ConnectorSession session, + SchemaTableName viewName, + List columns, + String viewSql, + Map properties) + { + throw new PrestoException(NOT_SUPPORTED, "Iceberg Hive catalog does not support native Iceberg views for materialized views."); + } + + @Override + protected void dropIcebergView(ConnectorSession session, SchemaTableName schemaTableName) + { + throw new PrestoException(NOT_SUPPORTED, "Iceberg Hive catalog does not support native Iceberg views for materialized views."); + } + + @Override + protected void updateIcebergViewProperties( + ConnectorSession session, + SchemaTableName viewName, + Map properties) + { + throw new PrestoException(NOT_SUPPORTED, "Iceberg Hive catalog does not support native Iceberg views for materialized views."); + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java 
b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java index 9747146392fb7..ca37b7910b009 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadataFactory.java @@ -19,22 +19,31 @@ import com.facebook.presto.hive.NodeVersion; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; +import com.facebook.presto.spi.ConnectorSystemConfig; +import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; +import jakarta.inject.Inject; -import javax.inject.Inject; +import java.util.List; +import static com.facebook.presto.spi.MaterializedViewDefinition.ColumnMapping; import static java.util.Objects.requireNonNull; public class IcebergHiveMetadataFactory implements IcebergMetadataFactory { + final IcebergCatalogName catalogName; final ExtendedHiveMetastore metastore; final HdfsEnvironment hdfsEnvironment; final TypeManager typeManager; + final ProcedureRegistry procedureRegistry; final JsonCodec commitTaskCodec; + final JsonCodec> columnMappingsCodec; + final JsonCodec> schemaTableNamesCodec; final StandardFunctionResolution functionResolution; final RowExpressionService rowExpressionService; final NodeVersion nodeVersion; @@ -43,50 +52,66 @@ public class IcebergHiveMetadataFactory final StatisticsFileCache statisticsFileCache; final ManifestFileCache manifestFileCache; final IcebergTableProperties tableProperties; + final ConnectorSystemConfig connectorSystemConfig; @Inject public IcebergHiveMetadataFactory( + IcebergCatalogName catalogName, 
ExtendedHiveMetastore metastore, HdfsEnvironment hdfsEnvironment, TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, + JsonCodec> columnMappingsCodec, + JsonCodec> schemaTableNamesCodec, NodeVersion nodeVersion, FilterStatsCalculatorService filterStatsCalculatorService, IcebergHiveTableOperationsConfig operationsConfig, StatisticsFileCache statisticsFileCache, ManifestFileCache manifestFileCache, - IcebergTableProperties tableProperties) + IcebergTableProperties tableProperties, + ConnectorSystemConfig connectorSystemConfig) { + this.catalogName = requireNonNull(catalogName, "catalogName is null"); this.metastore = requireNonNull(metastore, "metastore is null"); this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); + this.columnMappingsCodec = requireNonNull(columnMappingsCodec, "columnMappingsCodec is null"); + this.schemaTableNamesCodec = requireNonNull(schemaTableNamesCodec, "schemaTableNamesCodec is null"); this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); this.filterStatsCalculatorService = requireNonNull(filterStatsCalculatorService, "filterStatsCalculatorService is null"); this.operationsConfig = requireNonNull(operationsConfig, "operationsConfig is null"); this.statisticsFileCache = requireNonNull(statisticsFileCache, "statisticsFileCache is null"); this.manifestFileCache = requireNonNull(manifestFileCache, "manifestFileCache is null"); 
this.tableProperties = requireNonNull(tableProperties, "icebergTableProperties is null"); + this.connectorSystemConfig = requireNonNull(connectorSystemConfig, "connectorSystemConfig is null"); } public ConnectorMetadata create() { return new IcebergHiveMetadata( + catalogName, metastore, hdfsEnvironment, typeManager, + procedureRegistry, functionResolution, rowExpressionService, commitTaskCodec, + columnMappingsCodec, + schemaTableNamesCodec, nodeVersion, filterStatsCalculatorService, operationsConfig, statisticsFileCache, manifestFileCache, - tableProperties); + tableProperties, + connectorSystemConfig); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveModule.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveModule.java index 3823019e1dc73..38242d8ae782e 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveModule.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveModule.java @@ -16,10 +16,12 @@ import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; import com.facebook.presto.hive.MetastoreClientConfig; import com.facebook.presto.hive.PartitionMutator; +import com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheScope; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.HiveMetastoreCacheStats; import com.facebook.presto.hive.metastore.HivePartitionMutator; import com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore; +import com.facebook.presto.hive.metastore.MetastoreCacheSpecProvider; import com.facebook.presto.hive.metastore.MetastoreCacheStats; import com.facebook.presto.hive.metastore.MetastoreConfig; import com.facebook.presto.hive.metastore.thrift.ThriftHiveMetastoreConfig; @@ -30,6 +32,8 @@ import java.util.Optional; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static 
com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.ALL; +import static com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheType.TABLE; import static com.google.common.base.Preconditions.checkArgument; import static org.weakref.jmx.ObjectNames.generatedNameOf; import static org.weakref.jmx.guice.ExportBinder.newExporter; @@ -50,14 +54,15 @@ public IcebergHiveModule(String connectorId, Optional met public void setup(Binder binder) { install(new IcebergHiveMetastoreModule(this.connectorId, this.metastore)); + binder.bind(MetastoreCacheSpecProvider.class).in(Scopes.SINGLETON); binder.bind(ExtendedHiveMetastore.class).to(InMemoryCachingHiveMetastore.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(IcebergHiveTableOperationsConfig.class); configBinder(binder).bindConfig(MetastoreClientConfig.class); configBinder(binder).bindConfig(ThriftHiveMetastoreConfig.class); - long metastoreCacheTtl = buildConfigObject(MetastoreClientConfig.class).getMetastoreCacheTtl().toMillis(); - checkArgument(metastoreCacheTtl == 0, "In-memory hive metastore caching must not be enabled for Iceberg"); + checkArgument(isCachingAllowed(buildConfigObject(MetastoreClientConfig.class)), + "In-memory hive metastore caching for tables must not be enabled for Iceberg"); binder.bind(PartitionMutator.class).to(HivePartitionMutator.class).in(Scopes.SINGLETON); @@ -68,4 +73,18 @@ public void setup(Binder binder) configBinder(binder).bindConfig(MetastoreConfig.class); } + + private boolean isCachingAllowed(MetastoreClientConfig config) + { + if (!config.getEnabledCaches().isEmpty()) { + return !config.getEnabledCaches().contains(ALL) && !config.getEnabledCaches().contains(TABLE); + } + + if (!config.getDisabledCaches().isEmpty()) { + return config.getDisabledCaches().contains(ALL) || config.getDisabledCaches().contains(TABLE); + } + + return config.getMetastoreCacheScope() != MetastoreCacheScope.ALL || + 
config.getDefaultMetastoreCacheTtl().toMillis() == 0; + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveTableOperationsConfig.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveTableOperationsConfig.java index b04d677ce456c..9d9d358044030 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveTableOperationsConfig.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveTableOperationsConfig.java @@ -15,12 +15,11 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.Min; -import javax.validation.constraints.Min; - -import static io.airlift.units.Duration.succinctDuration; +import static com.facebook.airlift.units.Duration.succinctDuration; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; @@ -32,6 +31,7 @@ public class IcebergHiveTableOperationsConfig private Duration tableRefreshMaxRetryTime = succinctDuration(1, MINUTES); private double tableRefreshBackoffScaleFactor = 4.0; private int tableRefreshRetries = 20; + private boolean lockingEnabled = true; @MinDuration("1ms") public Duration getTableRefreshBackoffMinSleepTime() @@ -102,4 +102,17 @@ public int getTableRefreshRetries() { return tableRefreshRetries; } + + @Config("iceberg.engine.hive.lock-enabled") + @ConfigDescription("Whether to use HMS locks to ensure atomicity of commits") + public IcebergHiveTableOperationsConfig setLockingEnabled(boolean lockingEnabled) + { + this.lockingEnabled = lockingEnabled; + return this; + } + + public boolean getLockingEnabled() + { + return lockingEnabled; + } } diff --git 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergInputInfo.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergInputInfo.java index 048954ad7a289..839346b5c59c7 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergInputInfo.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergInputInfo.java @@ -22,11 +22,14 @@ public class IcebergInputInfo { private final Optional snapshotId; + private final String tableLocation; public IcebergInputInfo( - @JsonProperty("snapshotId") Optional snapshotId) + @JsonProperty("snapshotId") Optional snapshotId, + @JsonProperty("tableLocation") String tableLocation) { this.snapshotId = requireNonNull(snapshotId, "snapshotId is null"); + this.tableLocation = requireNonNull(tableLocation, "tableLocation is null"); } @JsonProperty @@ -34,4 +37,10 @@ public Optional getSnapshotId() { return snapshotId; } + + @JsonProperty + public String getTableLocation() + { + return tableLocation; + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergInsertTableHandle.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergInsertTableHandle.java index 0a2ad488a88d4..de6a31503c7b8 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergInsertTableHandle.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergInsertTableHandle.java @@ -15,11 +15,13 @@ import com.facebook.presto.hive.HiveCompressionCodec; import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.SchemaTableName; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; import java.util.Map; +import java.util.Optional; public class IcebergInsertTableHandle extends IcebergWritableTableHandle @@ -36,7 +38,8 @@ public IcebergInsertTableHandle( @JsonProperty("fileFormat") FileFormat fileFormat, @JsonProperty("compressionCodec") 
HiveCompressionCodec compressionCodec, @JsonProperty("storageProperties") Map storageProperties, - @JsonProperty("sortOrder") List sortOrder) + @JsonProperty("sortOrder") List sortOrder, + @JsonProperty("materializedViewName") Optional materializedViewName) { super( schemaName, @@ -48,6 +51,7 @@ public IcebergInsertTableHandle( fileFormat, compressionCodec, storageProperties, - sortOrder); + sortOrder, + materializedViewName); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMaterializedViewProperties.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMaterializedViewProperties.java new file mode 100644 index 0000000000000..a1db2e786e119 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMaterializedViewProperties.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg; + +import com.facebook.airlift.units.Duration; +import com.facebook.presto.spi.MaterializedViewRefreshType; +import com.facebook.presto.spi.MaterializedViewStaleReadBehavior; +import com.facebook.presto.spi.session.PropertyMetadata; +import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.spi.MaterializedViewRefreshType.FULL; +import static com.facebook.presto.spi.session.PropertyMetadata.durationProperty; +import static com.facebook.presto.spi.session.PropertyMetadata.stringProperty; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +/** + * Properties specific to Iceberg materialized views. + * This class provides the property definitions and accessor methods + * for materialized view-specific properties. It combines the base table + * properties (needed for the storage table) with MV-specific properties. 
+ */ +public class IcebergMaterializedViewProperties +{ + public static final String STORAGE_SCHEMA = "storage_schema"; + public static final String STORAGE_TABLE = "storage_table"; + public static final String STALE_READ_BEHAVIOR = "stale_read_behavior"; + public static final String STALENESS_WINDOW = "staleness_window"; + public static final String REFRESH_TYPE = "refresh_type"; + + private final List> materializedViewProperties; + + @Inject + public IcebergMaterializedViewProperties(IcebergTableProperties tableProperties) + { + requireNonNull(tableProperties, "tableProperties is null"); + + // MV-specific properties + List> mvOnlyProperties = ImmutableList.>builder() + .add(stringProperty( + STORAGE_SCHEMA, + "Schema for the materialized view storage table (defaults to same schema as the materialized view)", + null, + true)) + .add(stringProperty( + STORAGE_TABLE, + "Custom name for the materialized view storage table (defaults to generated name)", + null, + true)) + .add(new PropertyMetadata<>( + STALE_READ_BEHAVIOR, + "Behavior when reading from a stale materialized view (FAIL or USE_VIEW_QUERY)", + createUnboundedVarcharType(), + MaterializedViewStaleReadBehavior.class, + null, + true, + value -> value == null ? null : MaterializedViewStaleReadBehavior.valueOf(((String) value).toUpperCase(ENGLISH)), + value -> value == null ? null : ((MaterializedViewStaleReadBehavior) value).name())) + .add(durationProperty( + STALENESS_WINDOW, + "Staleness window for materialized view (e.g., '1h', '30m', '0s')", + null, + true)) + .add(new PropertyMetadata<>( + REFRESH_TYPE, + "Refresh type for materialized view", + createUnboundedVarcharType(), + MaterializedViewRefreshType.class, + FULL, + true, + value -> value == null ? FULL : MaterializedViewRefreshType.valueOf(((String) value).toUpperCase(ENGLISH)), + value -> value == null ? 
null : ((MaterializedViewRefreshType) value).name())) + .build(); + + // Combine table properties (for storage table) with MV-specific properties + materializedViewProperties = ImmutableList.>builder() + .addAll(tableProperties.getTableProperties()) + .addAll(mvOnlyProperties) + .build(); + } + + public List> getMaterializedViewProperties() + { + return materializedViewProperties; + } + + public static Optional getStorageSchema(Map properties) + { + return Optional.ofNullable((String) properties.get(STORAGE_SCHEMA)); + } + + public static Optional getStorageTable(Map properties) + { + return Optional.ofNullable((String) properties.get(STORAGE_TABLE)); + } + + public static Optional getStaleReadBehavior(Map properties) + { + return Optional.ofNullable((MaterializedViewStaleReadBehavior) properties.get(STALE_READ_BEHAVIOR)); + } + + public static Optional getStalenessWindow(Map properties) + { + return Optional.ofNullable((Duration) properties.get(STALENESS_WINDOW)); + } + + public static MaterializedViewRefreshType getRefreshType(Map properties) + { + return (MaterializedViewRefreshType) properties.getOrDefault(REFRESH_TYPE, FULL); + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java new file mode 100644 index 0000000000000..a6447da093c08 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeSink.java @@ -0,0 +1,225 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.block.ColumnarRow; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.hive.HdfsContext; +import com.facebook.presto.hive.HdfsEnvironment; +import com.facebook.presto.iceberg.delete.IcebergDeletePageSink; +import com.facebook.presto.spi.ConnectorMergeSink; +import com.facebook.presto.spi.ConnectorPageSink; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.connector.MergePage; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.io.LocationProvider; +import org.roaringbitmap.longlong.ImmutableLongBitmapDataProvider; +import org.roaringbitmap.longlong.LongBitmapDataProvider; +import org.roaringbitmap.longlong.Roaring64Bitmap; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static com.facebook.presto.common.block.ColumnarRow.toColumnarRow; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.plugin.base.util.Closables.closeAllSuppress; +import static com.facebook.presto.spi.connector.MergePage.createDeleteAndInsertPages; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.completedFuture; +import static java.util.concurrent.CompletableFuture.failedFuture; + +public class IcebergMergeSink + implements ConnectorMergeSink +{ + 
private final LocationProvider locationProvider; + private final IcebergFileWriterFactory fileWriterFactory; + private final HdfsEnvironment hdfsEnvironment; + private final JsonCodec jsonCodec; + private final ConnectorSession session; + private final FileFormat fileFormat; + private final Map partitionsSpecs; + private final ConnectorPageSink insertPageSink; + private final int columnCount; + private final Map fileDeletions = new HashMap<>(); + + public IcebergMergeSink( + LocationProvider locationProvider, + IcebergFileWriterFactory fileWriterFactory, + HdfsEnvironment hdfsEnvironment, + JsonCodec jsonCodec, + ConnectorSession session, + FileFormat fileFormat, + Map partitionsSpecs, + ConnectorPageSink insertPageSink, + int columnCount) + { + this.locationProvider = requireNonNull(locationProvider, "locationProvider is null"); + this.fileWriterFactory = requireNonNull(fileWriterFactory, "fileWriterFactory is null"); + this.hdfsEnvironment = requireNonNull(hdfsEnvironment, "hdfsEnvironment is null"); + this.jsonCodec = requireNonNull(jsonCodec, "jsonCodec is null"); + this.session = requireNonNull(session, "session is null"); + this.fileFormat = requireNonNull(fileFormat, "fileFormat is null"); + this.partitionsSpecs = requireNonNull(partitionsSpecs, "partitionsSpecs is null"); + this.insertPageSink = requireNonNull(insertPageSink, "insertPageSink is null"); + this.columnCount = columnCount; + } + + /** + * @param page It has N + 2 channels/blocks, where N is the number of columns in the source table.
+ * 1: Source table column 1.
+ * 2: Source table column 2.
+ * N: Source table column N.
+ * N + 1: Operation: INSERT(1), DELETE(2), UPDATE(3). More info: {@link ConnectorMergeSink}
+ * N + 2: Target Table Row ID (_file:varchar, _pos:bigint, partition_spec_id:integer, partition_data:varchar). + */ + @Override + public void storeMergedRows(Page page) + { + MergePage mergePage = createDeleteAndInsertPages(page, columnCount); + + mergePage.getInsertionsPage().ifPresent(insertPageSink::appendPage); + + mergePage.getDeletionsPage().ifPresent(deletions -> { + ColumnarRow rowIdRow = toColumnarRow(deletions.getBlock(deletions.getChannelCount() - 1)); + + for (int position = 0; position < rowIdRow.getPositionCount(); position++) { + Slice filePath = VarcharType.VARCHAR.getSlice(rowIdRow.getField(0), position); + long rowPosition = BIGINT.getLong(rowIdRow.getField(1), position); + + int index = position; + FileDeletion deletion = fileDeletions.computeIfAbsent(filePath, ignored -> { + int partitionSpecId = toIntExact(INTEGER.getLong(rowIdRow.getField(2), index)); + String partitionData = VarcharType.VARCHAR.getSlice(rowIdRow.getField(3), index).toStringUtf8(); + return new FileDeletion(partitionSpecId, partitionData); + }); + + deletion.rowsToDelete().addLong(rowPosition); + } + }); + } + + @Override + public CompletableFuture> finish() + { + return insertPageSink.finish().thenCompose(insertFragments -> { + List fragments = new ArrayList<>(insertFragments); + + try { + fileDeletions.forEach((dataFilePath, deletion) -> { + ConnectorPageSink sink = createPositionDeletePageSink( + dataFilePath.toStringUtf8(), + partitionsSpecs.get(deletion.partitionSpecId()), + deletion.partitionDataJson()); + fragments.addAll(writePositionDeletes(sink, deletion.rowsToDelete())); + }); + return completedFuture(fragments); + } + catch (Exception e) { + return failedFuture(e); + } + }); + } + + @Override + public void abort() + { + insertPageSink.abort(); + } + + private ConnectorPageSink createPositionDeletePageSink(String dataFilePath, PartitionSpec partitionSpec, String partitionDataJson) + { + return new IcebergDeletePageSink( + partitionSpec, + 
Optional.of(partitionDataJson), + locationProvider, + fileWriterFactory, + hdfsEnvironment, + new HdfsContext(session), + jsonCodec, + session, + dataFilePath, + fileFormat); + } + + private static Collection writePositionDeletes(ConnectorPageSink sink, ImmutableLongBitmapDataProvider rowsToDelete) + { + try { + return doWritePositionDeletes(sink, rowsToDelete); + } + catch (Throwable t) { + closeAllSuppress(t, sink::abort); + throw t; + } + } + + private static Collection doWritePositionDeletes(ConnectorPageSink sink, ImmutableLongBitmapDataProvider rowsToDelete) + { + PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(BIGINT)); + + rowsToDelete.forEach(rowPosition -> { + BIGINT.writeLong(pageBuilder.getBlockBuilder(0), rowPosition); + pageBuilder.declarePosition(); + if (pageBuilder.isFull()) { + sink.appendPage(pageBuilder.build()); + pageBuilder.reset(); + } + }); + + if (!pageBuilder.isEmpty()) { + sink.appendPage(pageBuilder.build()); + } + + return sink.finish().join(); + } + + private static class FileDeletion + { + private final int partitionSpecId; + private final String partitionDataJson; + private final LongBitmapDataProvider rowsToDelete = new Roaring64Bitmap(); + + public FileDeletion(int partitionSpecId, String partitionDataJson) + { + this.partitionSpecId = partitionSpecId; + this.partitionDataJson = requireNonNull(partitionDataJson, "partitionDataJson is null"); + } + + public int partitionSpecId() + { + return partitionSpecId; + } + + public String partitionDataJson() + { + return partitionDataJson; + } + + public LongBitmapDataProvider rowsToDelete() + { + return rowsToDelete; + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeTableHandle.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeTableHandle.java new file mode 100644 index 0000000000000..7d706cb1e2d40 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMergeTableHandle.java @@ -0,0 
+1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; +import com.facebook.presto.spi.ConnectorMergeTableHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +@ThriftStruct +public class IcebergMergeTableHandle + implements ConnectorMergeTableHandle +{ + private final IcebergTableHandle tableHandle; + private final IcebergInsertTableHandle insertTableHandle; + private final Map partitionSpecs; + + @JsonCreator + @ThriftConstructor + public IcebergMergeTableHandle( + @JsonProperty("tableHandle") IcebergTableHandle tableHandle, + @JsonProperty("insertTableHandle") IcebergInsertTableHandle insertTableHandle, + @JsonProperty("partitionSpecs") Map partitionSpecs) + { + this.tableHandle = requireNonNull(tableHandle, "tableHandle is null"); + this.insertTableHandle = requireNonNull(insertTableHandle, "insertTableHandle is null"); + this.partitionSpecs = requireNonNull(partitionSpecs, "partitionSpecs is null"); + } + + @Override + @JsonProperty + @ThriftField(1) + public IcebergTableHandle getTableHandle() + { + return tableHandle; + } + + @JsonProperty + @ThriftField(2) + public IcebergInsertTableHandle 
getInsertTableHandle() + { + return insertTableHandle; + } + + @JsonProperty + @ThriftField(3) + public Map getPartitionSpecs() + { + return partitionSpecs; + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMetadataColumn.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMetadataColumn.java index 9758f325338c7..f89488911f8cd 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMetadataColumn.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergMetadataColumn.java @@ -22,6 +22,7 @@ import java.util.stream.Stream; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.iceberg.ColumnIdentity.TypeCategory.PRIMITIVE; @@ -32,16 +33,20 @@ public enum IcebergMetadataColumn { FILE_PATH(MetadataColumns.FILE_PATH.fieldId(), "$path", VARCHAR, PRIMITIVE), DATA_SEQUENCE_NUMBER(Integer.MAX_VALUE - 1001, "$data_sequence_number", BIGINT, PRIMITIVE), + IS_DELETED(MetadataColumns.IS_DELETED.fieldId(), "$deleted", BOOLEAN, PRIMITIVE), + DELETE_FILE_PATH(MetadataColumns.DELETE_FILE_PATH.fieldId(), "$delete_file_path", VARCHAR, PRIMITIVE), /** * Iceberg reserved row ids begin at INTEGER.MAX_VALUE and count down. Starting with MIN_VALUE here to avoid conflicts. * Inner type for row is not known until runtime. 
*/ - UPDATE_ROW_DATA(Integer.MIN_VALUE, "$row_id", RowType.anonymous(ImmutableList.of(UNKNOWN)), STRUCT) + UPDATE_ROW_DATA(Integer.MIN_VALUE, "$row_id", RowType.anonymous(ImmutableList.of(UNKNOWN)), STRUCT), + MERGE_TARGET_ROW_ID_DATA(Integer.MIN_VALUE + 1, "$target_table_row_id", RowType.anonymous(ImmutableList.of(UNKNOWN)), STRUCT), + MERGE_PARTITION_DATA(Integer.MIN_VALUE + 2, "partition_data", VARCHAR, PRIMITIVE) /**/; - private static final Set COLUMN_IDS = Stream.of(values()) - .map(IcebergMetadataColumn::getId) - .collect(toImmutableSet()); + private static final Set COLUMN_IDS = Stream.concat( + Stream.of(values()).map(IcebergMetadataColumn::getId), + Stream.of(MetadataColumns.SPEC_ID.fieldId())).collect(toImmutableSet()); private final int id; private final String columnName; private final Type type; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeCatalogFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeCatalogFactory.java index 02f04f64f137b..26f4962334b83 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeCatalogFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeCatalogFactory.java @@ -21,14 +21,13 @@ import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.UncheckedExecutionException; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.iceberg.CatalogUtil; import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.SupportsNamespaces; -import javax.inject.Inject; - import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java index b35d006b827f5..9202be8e4d5c4 
100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadata.java @@ -19,17 +19,20 @@ import com.facebook.presto.hive.TableAlreadyExistsException; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; import com.facebook.presto.iceberg.util.IcebergPrestoModelConverters; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorNewTableLayout; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.ConnectorViewDefinition; +import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SchemaTablePrefix; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -48,9 +51,11 @@ import org.apache.iceberg.exceptions.NamespaceNotEmptyException; import org.apache.iceberg.exceptions.NoSuchNamespaceException; import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.exceptions.NoSuchViewException; import org.apache.iceberg.view.View; import org.apache.iceberg.view.ViewBuilder; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -61,10 +66,13 @@ import static com.facebook.presto.iceberg.CatalogType.HADOOP; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_COMMIT_ERROR; import static 
com.facebook.presto.iceberg.IcebergSessionProperties.getCompressionCodec; +import static com.facebook.presto.iceberg.IcebergTableProperties.getPartitioning; +import static com.facebook.presto.iceberg.IcebergTableProperties.getSortOrder; +import static com.facebook.presto.iceberg.IcebergTableProperties.getTableLocation; import static com.facebook.presto.iceberg.IcebergTableType.DATA; import static com.facebook.presto.iceberg.IcebergUtil.VIEW_OWNER; import static com.facebook.presto.iceberg.IcebergUtil.createIcebergViewProperties; -import static com.facebook.presto.iceberg.IcebergUtil.getColumns; +import static com.facebook.presto.iceberg.IcebergUtil.getColumnsForWrite; import static com.facebook.presto.iceberg.IcebergUtil.getNativeIcebergTable; import static com.facebook.presto.iceberg.IcebergUtil.getNativeIcebergView; import static com.facebook.presto.iceberg.IcebergUtil.populateTableProperties; @@ -79,6 +87,7 @@ import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.StandardErrorCode.SCHEMA_NOT_EMPTY; import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.base.Throwables.getRootCause; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; @@ -101,16 +110,20 @@ public class IcebergNativeMetadata public IcebergNativeMetadata( IcebergNativeCatalogFactory catalogFactory, TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, + JsonCodec> columnMappingsCodec, + JsonCodec> schemaTableNamesCodec, CatalogType catalogType, NodeVersion nodeVersion, FilterStatsCalculatorService filterStatsCalculatorService, StatisticsFileCache statisticsFileCache, IcebergTableProperties tableProperties) { - super(typeManager, functionResolution, rowExpressionService, 
commitTaskCodec, nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); + super(typeManager, procedureRegistry, functionResolution, rowExpressionService, commitTaskCodec, columnMappingsCodec, schemaTableNamesCodec, + nodeVersion, filterStatsCalculatorService, statisticsFileCache, tableProperties); this.catalogFactory = requireNonNull(catalogFactory, "catalogFactory is null"); this.catalogType = requireNonNull(catalogType, "catalogType is null"); this.warehouseDataDir = Optional.ofNullable(catalogFactory.getCatalogWarehouseDataDir()); @@ -125,9 +138,18 @@ protected Table getRawIcebergTable(ConnectorSession session, SchemaTableName sch @Override protected View getIcebergView(ConnectorSession session, SchemaTableName schemaTableName) { - return icebergViews.computeIfAbsent( - schemaTableName, - ignored -> getNativeIcebergView(catalogFactory, session, schemaTableName)); + try { + return icebergViews.computeIfAbsent( + schemaTableName, + ignored -> getNativeIcebergView(catalogFactory, session, schemaTableName)); + } + catch (RuntimeException e) { + Throwable rootCause = getRootCause(e); + if (rootCause instanceof NoSuchViewException) { + throw (NoSuchViewException) rootCause; + } + throw e; + } } @Override @@ -240,11 +262,21 @@ public List listViews(ConnectorSession session, Optional tableNames = ImmutableList.builder(); Catalog catalog = catalogFactory.getCatalog(session); if (catalog instanceof ViewCatalog) { + ViewCatalog viewCatalog = (ViewCatalog) catalog; for (String schema : listSchemas(session, schemaName.orElse(null))) { try { - for (TableIdentifier tableIdentifier : ((ViewCatalog) catalog).listViews( + for (TableIdentifier tableIdentifier : viewCatalog.listViews( toIcebergNamespace(Optional.ofNullable(schema), catalogFactory.isNestedNamespaceEnabled()))) { - tableNames.add(new SchemaTableName(schema, tableIdentifier.name())); + // Exclude materialized views from the list of views + try { + View view = 
viewCatalog.loadView(tableIdentifier); + if (!view.properties().containsKey(PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION)) { + tableNames.add(new SchemaTableName(schema, tableIdentifier.name())); + } + } + catch (IllegalArgumentException e) { + // Ignore illegal view names + } } } catch (NoSuchNamespaceException e) { @@ -270,11 +302,11 @@ public Map getViews(ConnectorSession s Catalog catalog = catalogFactory.getCatalog(session); if (catalog instanceof ViewCatalog) { List tableNames; - if (prefix.getTableName() != null) { + if (prefix.getSchemaName() != null && prefix.getTableName() != null) { tableNames = ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName())); } else { - tableNames = listViews(session, Optional.of(prefix.getSchemaName())); + tableNames = listViews(session, Optional.ofNullable(prefix.getSchemaName())); } for (SchemaTableName schemaTableName : tableNames) { @@ -282,6 +314,10 @@ public Map getViews(ConnectorSession s TableIdentifier viewIdentifier = toIcebergTableIdentifier(schemaTableName, catalogFactory.isNestedNamespaceEnabled()); if (((ViewCatalog) catalog).viewExists(viewIdentifier)) { View view = ((ViewCatalog) catalog).loadView(viewIdentifier); + // Skip materialized views + if (view.properties().containsKey(PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION)) { + continue; + } verifyAndPopulateViews(view, schemaTableName, view.sqlFor(VIEW_DIALECT).sql(), views); } } @@ -293,6 +329,34 @@ public Map getViews(ConnectorSession s return views.build(); } + @Override + public List listMaterializedViews(ConnectorSession session, String schemaName) + { + ImmutableList.Builder materializedViews = ImmutableList.builder(); + Catalog catalog = catalogFactory.getCatalog(session); + if (catalog instanceof ViewCatalog) { + ViewCatalog viewCatalog = (ViewCatalog) catalog; + try { + for (TableIdentifier tableIdentifier : viewCatalog.listViews( + toIcebergNamespace(Optional.ofNullable(schemaName), catalogFactory.isNestedNamespaceEnabled()))) { + 
try { + View view = viewCatalog.loadView(tableIdentifier); + if (view.properties().containsKey(PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION)) { + materializedViews.add(new SchemaTableName(schemaName, tableIdentifier.name())); + } + } + catch (IllegalArgumentException e) { + // Ignore illegal view names + } + } + } + catch (NoSuchNamespaceException e) { + // ignore + } + } + return materializedViews.build(); + } + @Override public void dropView(ConnectorSession session, SchemaTableName viewName) { @@ -303,6 +367,18 @@ public void dropView(ConnectorSession session, SchemaTableName viewName) ((ViewCatalog) catalog).dropView(toIcebergTableIdentifier(viewName, catalogFactory.isNestedNamespaceEnabled())); } + @Override + public void renameView(ConnectorSession session, SchemaTableName source, SchemaTableName target) + { + Catalog catalog = catalogFactory.getCatalog(session); + if (!(catalog instanceof ViewCatalog)) { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support renaming views"); + } + ((ViewCatalog) catalog).renameView( + toIcebergTableIdentifier(source, catalogFactory.isNestedNamespaceEnabled()), + toIcebergTableIdentifier(target, catalogFactory.isNestedNamespaceEnabled())); + } + private void verifyAndPopulateViews(View view, SchemaTableName schemaTableName, String viewData, ImmutableMap.Builder views) { views.put(schemaTableName, new ConnectorViewDefinition( @@ -320,12 +396,12 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con Schema schema = toIcebergSchema(tableMetadata.getColumns()); - PartitionSpec partitionSpec = parsePartitionFields(schema, tableProperties.getPartitioning(tableMetadata.getProperties())); + PartitionSpec partitionSpec = parsePartitionFields(schema, getPartitioning(tableMetadata.getProperties())); FileFormat fileFormat = tableProperties.getFileFormat(session, tableMetadata.getProperties()); try { TableIdentifier tableIdentifier = toIcebergTableIdentifier(schemaTableName, 
catalogFactory.isNestedNamespaceEnabled()); - String targetPath = tableProperties.getTableLocation(tableMetadata.getProperties()); + String targetPath = getTableLocation(tableMetadata.getProperties()); if (!isNullOrEmpty(targetPath)) { transaction = catalogFactory.getCatalog(session).newCreateTableTransaction( tableIdentifier, @@ -348,7 +424,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con Table icebergTable = transaction.table(); ReplaceSortOrder replaceSortOrder = transaction.replaceSortOrder(); - SortOrder sortOrder = parseSortFields(schema, tableProperties.getSortOrder(tableMetadata.getProperties())); + SortOrder sortOrder = parseSortFields(schema, getSortOrder(tableMetadata.getProperties())); List sortFields = getSupportedSortFields(icebergTable.schema(), sortOrder); for (SortField sortField : sortFields) { if (sortField.getSortOrder().isAscending()) { @@ -371,7 +447,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Con new IcebergTableName(tableName, DATA, Optional.empty(), Optional.empty()), toPrestoSchema(icebergTable.schema(), typeManager), toPrestoPartitionSpec(icebergTable.spec(), typeManager), - getColumns(icebergTable.schema(), icebergTable.spec(), typeManager), + getColumnsForWrite(icebergTable.schema(), icebergTable.spec(), typeManager), icebergTable.location(), fileFormat, getCompressionCodec(session), @@ -417,4 +493,77 @@ protected Optional getDataLocationBasedOnWarehouseDataDir(SchemaTableNam } return warehouseDataDir.map(base -> base + schemaTableName.getSchemaName() + "/" + schemaTableName.getTableName()); } + + @Override + protected void createIcebergView( + ConnectorSession session, + SchemaTableName viewName, + List columns, + String viewSql, + Map properties) + { + Catalog catalog = catalogFactory.getCatalog(session); + if (!(catalog instanceof ViewCatalog)) { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support creating Iceberg views for materialized 
views"); + } + ViewCatalog viewCatalog = (ViewCatalog) catalog; + + Schema schema = toIcebergSchema(columns); + + viewCatalog.buildView(toIcebergTableIdentifier(viewName, catalogFactory.isNestedNamespaceEnabled())) + .withSchema(schema) + .withDefaultNamespace(toIcebergNamespace(Optional.ofNullable(viewName.getSchemaName()), catalogFactory.isNestedNamespaceEnabled())) + .withQuery(VIEW_DIALECT, viewSql) + .withProperties(properties) + .create(); + + icebergViews.remove(viewName); + + if (!viewCatalog.viewExists(toIcebergTableIdentifier(viewName, catalogFactory.isNestedNamespaceEnabled()))) { + throw new PrestoException(ICEBERG_COMMIT_ERROR, "Failed to create Iceberg view for materialized view: " + viewName); + } + } + + @Override + protected void dropIcebergView(ConnectorSession session, SchemaTableName schemaTableName) + { + Catalog catalog = catalogFactory.getCatalog(session); + if (!(catalog instanceof ViewCatalog)) { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support dropping Iceberg views for materialized views"); + } + ViewCatalog viewCatalog = (ViewCatalog) catalog; + viewCatalog.dropView(toIcebergTableIdentifier(schemaTableName, catalogFactory.isNestedNamespaceEnabled())); + + icebergViews.remove(schemaTableName); + } + + @Override + protected void updateIcebergViewProperties( + ConnectorSession session, + SchemaTableName viewName, + Map properties) + { + Catalog catalog = catalogFactory.getCatalog(session); + if (!(catalog instanceof ViewCatalog)) { + throw new PrestoException(NOT_SUPPORTED, "This connector does not support updating Iceberg views for materialized views"); + } + ViewCatalog viewCatalog = (ViewCatalog) catalog; + + TableIdentifier viewIdentifier = toIcebergTableIdentifier(viewName, catalogFactory.isNestedNamespaceEnabled()); + View existingView = viewCatalog.loadView(viewIdentifier); + + Map tempProperties = new HashMap<>(existingView.properties()); + tempProperties.putAll(properties); + Map mergedProperties = 
ImmutableMap.copyOf(tempProperties); + + ViewBuilder viewBuilder = viewCatalog.buildView(viewIdentifier) + .withSchema(existingView.schema()) + .withDefaultNamespace(toIcebergNamespace(Optional.ofNullable(viewName.getSchemaName()), catalogFactory.isNestedNamespaceEnabled())) + .withQuery(VIEW_DIALECT, existingView.sqlFor(VIEW_DIALECT).sql()) + .withProperties(mergedProperties); + + viewBuilder.createOrReplace(); + + icebergViews.remove(viewName); + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java index 44ddf1992a73e..72f11ce078166 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergNativeMetadataFactory.java @@ -17,12 +17,16 @@ import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.hive.NodeVersion; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; +import jakarta.inject.Inject; -import javax.inject.Inject; +import java.util.List; import static java.util.Objects.requireNonNull; @@ -30,7 +34,10 @@ public class IcebergNativeMetadataFactory implements IcebergMetadataFactory { final TypeManager typeManager; + final ProcedureRegistry procedureRegistry; final JsonCodec commitTaskCodec; + final JsonCodec> columnMappingsCodec; + final JsonCodec> schemaTableNamesCodec; final IcebergNativeCatalogFactory catalogFactory; final CatalogType 
catalogType; final StandardFunctionResolution functionResolution; @@ -45,9 +52,12 @@ public IcebergNativeMetadataFactory( IcebergConfig config, IcebergNativeCatalogFactory catalogFactory, TypeManager typeManager, + ProcedureRegistry procedureRegistry, StandardFunctionResolution functionResolution, RowExpressionService rowExpressionService, JsonCodec commitTaskCodec, + JsonCodec> columnMappingsCodec, + JsonCodec> schemaTableNamesCodec, NodeVersion nodeVersion, FilterStatsCalculatorService filterStatsCalculatorService, StatisticsFileCache statisticsFileCache, @@ -55,11 +65,13 @@ public IcebergNativeMetadataFactory( { this.catalogFactory = requireNonNull(catalogFactory, "catalogFactory is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null"); this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); + this.columnMappingsCodec = requireNonNull(columnMappingsCodec, "columnMappingsCodec is null"); + this.schemaTableNamesCodec = requireNonNull(schemaTableNamesCodec, "schemaTableNamesCodec is null"); this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null"); - requireNonNull(config, "config is null"); this.catalogType = config.getCatalogType(); this.filterStatsCalculatorService = requireNonNull(filterStatsCalculatorService, "filterStatsCalculatorService is null"); this.statisticsFileCache = requireNonNull(statisticsFileCache, "statisticsFileCache is null"); @@ -68,6 +80,8 @@ public IcebergNativeMetadataFactory( public ConnectorMetadata create() { - return new IcebergNativeMetadata(catalogFactory, typeManager, functionResolution, rowExpressionService, commitTaskCodec, catalogType, nodeVersion, filterStatsCalculatorService, 
statisticsFileCache, tableProperties); + return new IcebergNativeMetadata(catalogFactory, typeManager, procedureRegistry, functionResolution, rowExpressionService, + commitTaskCodec, columnMappingsCodec, schemaTableNamesCodec, catalogType, nodeVersion, filterStatsCalculatorService, + statisticsFileCache, tableProperties); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSink.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSink.java index e526ff13332a8..1e90a006182b0 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSink.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSink.java @@ -74,7 +74,7 @@ import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_METADATA; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_TOO_MANY_OPEN_PARTITIONS; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_WRITER_OPEN_ERROR; -import static com.facebook.presto.iceberg.IcebergUtil.getColumns; +import static com.facebook.presto.iceberg.IcebergUtil.getColumnsForWrite; import static com.facebook.presto.iceberg.PartitionTransforms.getColumnTransform; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; @@ -157,7 +157,7 @@ public IcebergPageSink( this.sortOrder = requireNonNull(sortOrder, "sortOrder is null"); String tempDirectoryPath = locationProvider.newDataLocation("sort-tmp-files"); this.tempDirectory = new Path(tempDirectoryPath); - this.columnTypes = getColumns(outputSchema, partitionSpec, requireNonNull(sortParameters.getTypeManager(), "typeManager is null")).stream() + this.columnTypes = getColumnsForWrite(outputSchema, partitionSpec, requireNonNull(sortParameters.getTypeManager(), "typeManager is null")).stream() .map(IcebergColumnHandle::getType) .collect(toImmutableList()); this.sortParameters = sortParameters; diff --git 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java index 0852a9300a0d3..45e8a7164df8f 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSinkProvider.java @@ -16,7 +16,10 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.presto.hive.HdfsContext; import com.facebook.presto.hive.HdfsEnvironment; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeSink; +import com.facebook.presto.spi.ConnectorMergeTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorPageSink; import com.facebook.presto.spi.ConnectorSession; @@ -25,19 +28,20 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import jakarta.inject.Inject; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; import org.apache.iceberg.Table; import org.apache.iceberg.io.LocationProvider; -import javax.inject.Inject; - +import java.util.Map; import java.util.Optional; import static com.facebook.presto.iceberg.IcebergUtil.getLocationProvider; import static com.facebook.presto.iceberg.IcebergUtil.getShallowWrappedIcebergTable; import static com.facebook.presto.iceberg.PartitionSpecConverter.toIcebergPartitionSpec; import static com.facebook.presto.iceberg.SchemaConverter.toIcebergSchema; +import static com.google.common.collect.Maps.transformValues; import static java.util.Objects.requireNonNull; public class IcebergPageSinkProvider @@ -80,6 +84,12 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa 
return createPageSink(session, (IcebergWritableTableHandle) insertTableHandle); } + @Override + public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorDistributedProcedureHandle procedureHandle, PageSinkContext pageSinkContext) + { + return createPageSink(session, (IcebergWritableTableHandle) procedureHandle); + } + private ConnectorPageSink createPageSink(ConnectorSession session, IcebergWritableTableHandle tableHandle) { HdfsContext hdfsContext = new HdfsContext(session, tableHandle.getSchemaName(), tableHandle.getTableName().getTableName()); @@ -103,4 +113,30 @@ private ConnectorPageSink createPageSink(ConnectorSession session, IcebergWritab tableHandle.getSortOrder(), sortParameters); } + + @Override + public ConnectorMergeSink createMergeSink(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorMergeTableHandle mergeHandle) + { + IcebergMergeTableHandle merge = (IcebergMergeTableHandle) mergeHandle; + IcebergWritableTableHandle tableHandle = merge.getInsertTableHandle(); + SchemaTableName schemaTableName = new SchemaTableName(tableHandle.getSchemaName(), tableHandle.getTableName().getTableName()); + LocationProvider locationProvider = getLocationProvider(schemaTableName, tableHandle.getOutputPath(), tableHandle.getStorageProperties()); + + Schema schema = toIcebergSchema(tableHandle.getSchema()); + Map partitionSpecs = transformValues(merge.getPartitionSpecs(), + prestoIcebergPartitionSpec -> toIcebergPartitionSpec(prestoIcebergPartitionSpec).toUnbound().bind(schema)); + + ConnectorPageSink pageSink = createPageSink(session, tableHandle); + + return new IcebergMergeSink( + locationProvider, + fileWriterFactory, + hdfsEnvironment, + jsonCodec, + session, + tableHandle.getFileFormat(), + partitionSpecs, + pageSink, + tableHandle.getInputColumns().size()); + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java 
b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java index ca982e29cfc54..2b6bc0b6e9f2f 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPageSourceProvider.java @@ -14,6 +14,7 @@ package com.facebook.presto.iceberg; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.predicate.Domain; @@ -78,13 +79,12 @@ import com.facebook.presto.spi.SplitContext; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; -import com.google.common.base.Suppliers; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; -import io.airlift.units.DataSize; +import jakarta.inject.Inject; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileStatus; @@ -97,6 +97,7 @@ import org.apache.iceberg.Table; import org.apache.iceberg.io.LocationProvider; import org.apache.iceberg.types.Conversions; +import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.NestedField; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.crypto.InternalFileDecryptor; @@ -110,14 +111,13 @@ import org.roaringbitmap.longlong.LongBitmapDataProvider; import org.roaringbitmap.longlong.Roaring64Bitmap; -import javax.inject.Inject; - import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalInt; import java.util.OptionalLong; import 
java.util.Set; import java.util.function.Function; @@ -127,7 +127,6 @@ import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR; -import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.SYNTHESIZED; import static com.facebook.presto.hive.CacheQuota.NO_CACHE_CONSTRAINTS; import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcLazyReadSmallRanges; import static com.facebook.presto.hive.HiveCommonSessionProperties.getOrcMaxBufferSize; @@ -145,12 +144,14 @@ import static com.facebook.presto.hive.parquet.ParquetPageSourceFactory.createDecryptor; import static com.facebook.presto.iceberg.FileContent.EQUALITY_DELETES; import static com.facebook.presto.iceberg.FileContent.POSITION_DELETES; +import static com.facebook.presto.iceberg.IcebergColumnHandle.DELETE_FILE_PATH_COLUMN_HANDLE; import static com.facebook.presto.iceberg.IcebergColumnHandle.getPushedDownSubfield; import static com.facebook.presto.iceberg.IcebergColumnHandle.isPushedDownSubfield; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_BAD_DATA; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_CANNOT_OPEN_SPLIT; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_MISSING_COLUMN; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_MISSING_DATA; +import static com.facebook.presto.iceberg.IcebergMetadataColumn.MERGE_PARTITION_DATA; import static com.facebook.presto.iceberg.IcebergOrcColumn.ROOT_COLUMN_ID; import static com.facebook.presto.iceberg.IcebergUtil.getColumns; import static com.facebook.presto.iceberg.IcebergUtil.getLocationProvider; @@ -173,14 +174,16 @@ import static com.facebook.presto.parquet.predicate.PredicateUtils.buildPredicate; import static com.facebook.presto.parquet.predicate.PredicateUtils.predicateMatches; import static com.facebook.presto.parquet.reader.ColumnIndexFilterUtils.getColumnIndexStore; -import 
static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Predicates.not; +import static com.google.common.base.Suppliers.memoize; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Maps.uniqueIndex; +import static io.airlift.slice.Slices.EMPTY_SLICE; import static io.airlift.slice.Slices.utf8Slice; import static java.lang.String.format; import static java.time.ZoneOffset.UTC; @@ -188,7 +191,9 @@ import static java.util.Objects.requireNonNull; import static org.apache.iceberg.MetadataColumns.DELETE_FILE_PATH; import static org.apache.iceberg.MetadataColumns.DELETE_FILE_POS; +import static org.apache.iceberg.MetadataColumns.FILE_PATH; import static org.apache.iceberg.MetadataColumns.ROW_POSITION; +import static org.apache.iceberg.MetadataColumns.SPEC_ID; import static org.apache.parquet.io.ColumnIOConverter.constructField; import static org.apache.parquet.io.ColumnIOConverter.findNestedColumnIO; @@ -356,7 +361,8 @@ private static ConnectorPageSourceWithRowPositions createParquetPageSource( Type prestoType = column.getType(); prestoTypes.add(prestoType); - if (column.getColumnType() == IcebergColumnHandle.ColumnType.SYNTHESIZED && !column.isUpdateRowIdColumn()) { + if (column.getColumnType() == IcebergColumnHandle.ColumnType.SYNTHESIZED && + !column.isUpdateRowIdColumn() && !column.isMergeTargetTableRowIdColumn()) { Subfield pushedDownSubfield = getPushedDownSubfield(column); List nestedColumnPath = nestedColumnPath(pushedDownSubfield); Optional columnIO = findNestedColumnIO(lookupColumnByName(messageColumnIO, 
pushedDownSubfield.getRootName()), nestedColumnPath); @@ -425,18 +431,6 @@ public static Optional getColumnType( return Optional.ofNullable(parquetIdToField.get(column.getId())); } - private static HiveColumnHandle.ColumnType getHiveColumnHandleColumnType(IcebergColumnHandle.ColumnType columnType) - { - switch (columnType) { - case REGULAR: - return REGULAR; - case SYNTHESIZED: - return SYNTHESIZED; - } - - throw new PrestoException(GENERIC_INTERNAL_ERROR, "Unknown ColumnType: " + columnType); - } - private static TupleDomain getParquetTupleDomain(Map, RichColumnDescriptor> descriptorsByPath, TupleDomain effectivePredicate) { if (effectivePredicate.isNone()) { @@ -538,23 +532,17 @@ private static ConnectorPageSourceWithRowPositions createBatchOrcPageSource( Map fileOrcColumnsByName = uniqueIndex(fileOrcColumns, orcColumn -> orcColumn.getColumnName().toLowerCase(ENGLISH)); int nextMissingColumnIndex = fileOrcColumnsByName.size(); - List isRowPositionList = new ArrayList<>(); - for (IcebergColumnHandle column : regularColumns) { + OptionalInt rowPositionColumnIndex = OptionalInt.empty(); + for (int idx = 0; idx < regularColumns.size(); idx++) { + IcebergColumnHandle column = regularColumns.get(idx); IcebergOrcColumn icebergOrcColumn; - boolean isExcludeColumn = false; if (fileOrcColumnByIcebergId.isEmpty()) { + // This is a migrated table icebergOrcColumn = fileOrcColumnsByName.get(column.getName()); } else { icebergOrcColumn = fileOrcColumnByIcebergId.get(column.getId()); - if (icebergOrcColumn == null) { - // Cannot get orc column from 'fileOrcColumnByIcebergId', which means SchemaEvolution may have happened, so we get orc column by column name. 
- icebergOrcColumn = fileOrcColumnsByName.get(column.getName()); - if (icebergOrcColumn != null) { - isExcludeColumn = true; - } - } } if (icebergOrcColumn != null) { @@ -569,11 +557,8 @@ private static ConnectorPageSourceWithRowPositions createBatchOrcPageSource( Optional.empty()); physicalColumnHandles.add(columnHandle); - // Skip SchemaEvolution column - if (!isExcludeColumn) { - includedColumns.put(columnHandle.getHiveColumnIndex(), typeManager.getType(columnHandle.getTypeSignature())); - columnReferences.add(new TupleDomainOrcPredicate.ColumnReference<>(columnHandle, columnHandle.getHiveColumnIndex(), typeManager.getType(columnHandle.getTypeSignature()))); - } + includedColumns.put(columnHandle.getHiveColumnIndex(), typeManager.getType(columnHandle.getTypeSignature())); + columnReferences.add(new TupleDomainOrcPredicate.ColumnReference<>(columnHandle, columnHandle.getHiveColumnIndex(), typeManager.getType(columnHandle.getTypeSignature()))); } else { physicalColumnHandles.add(new HiveColumnHandle( @@ -581,12 +566,16 @@ private static ConnectorPageSourceWithRowPositions createBatchOrcPageSource( toHiveType(column.getType()), column.getType().getTypeSignature(), nextMissingColumnIndex++, - getHiveColumnHandleColumnType(column.getColumnType()), + column.getColumnType(), column.getComment(), column.getRequiredSubfields(), Optional.empty())); } - isRowPositionList.add(column.isRowPositionColumn()); + + if (column.isRowPositionColumn()) { + checkArgument(rowPositionColumnIndex.isEmpty(), "Requesting more than 1 row number columns is not allowed."); + rowPositionColumnIndex = OptionalInt.of(idx); + } } // Skip the time type columns in predicate, converted on page source level @@ -643,7 +632,7 @@ private static ConnectorPageSourceWithRowPositions createBatchOrcPageSource( systemMemoryUsage, stats, runtimeStats, - isRowPositionList, + rowPositionColumnIndex, // Iceberg doesn't support row IDs new byte[0], ""), @@ -768,10 +757,10 @@ public ConnectorPageSource 
createPageSource( Map partitionKeys = split.getPartitionKeys(); - // the update row isn't a valid column that can be read from storage. + // The update row id and merge target table row id aren't valid columns that can be read from storage. // Filter it out from columns passed to the storage page source. Set columnsToReadFromStorage = icebergColumns.stream() - .filter(not(IcebergColumnHandle::isUpdateRowIdColumn)) + .filter(not(column -> column.isUpdateRowIdColumn() || column.isMergeTargetTableRowIdColumn())) .collect(Collectors.toSet()); // add any additional columns which may need to be read from storage @@ -782,22 +771,36 @@ public ConnectorPageSource createPageSource( .filter(not(icebergColumns::contains)) .forEach(columnsToReadFromStorage::add); - // finally, add the fields that the update column requires. - Optional updateRow = icebergColumns.stream() - .filter(IcebergColumnHandle::isUpdateRowIdColumn) + // finally, add the fields that the UPDATE and MERGE column requires. + Optional rowIdColumnHandle = icebergColumns.stream() + .filter(column -> column.isUpdateRowIdColumn() || column.isMergeTargetTableRowIdColumn()) .findFirst(); - updateRow.ifPresent(updateRowIdColumn -> { + rowIdColumnHandle.ifPresent(rowIdColumn -> { Set alreadyRequiredColumnIds = columnsToReadFromStorage.stream() .map(IcebergColumnHandle::getId) .collect(toImmutableSet()); - updateRowIdColumn.getColumnIdentity().getChildren() + rowIdColumn.getColumnIdentity().getChildren() .stream() .filter(colId -> !alreadyRequiredColumnIds.contains(colId.getId())) .forEach(colId -> { - if (colId.getId() == ROW_POSITION.fieldId()) { + if (colId.getId() == FILE_PATH.fieldId()) { + IcebergColumnHandle handle = IcebergColumnHandle.create(FILE_PATH, typeManager, REGULAR); + columnsToReadFromStorage.add(handle); + } + else if (colId.getId() == ROW_POSITION.fieldId()) { IcebergColumnHandle handle = IcebergColumnHandle.create(ROW_POSITION, typeManager, REGULAR); columnsToReadFromStorage.add(handle); } + else 
if (colId.getId() == SPEC_ID.fieldId()) { + IcebergColumnHandle handle = IcebergColumnHandle.create(SPEC_ID, typeManager, REGULAR); + columnsToReadFromStorage.add(handle); + } + else if (colId.getId() == MERGE_PARTITION_DATA.getId()) { + NestedField mergePartitionData = NestedField.required(MERGE_PARTITION_DATA.getId(), + MERGE_PARTITION_DATA.getColumnName(), Types.StringType.get()); + IcebergColumnHandle handle = IcebergColumnHandle.create(mergePartitionData, typeManager, REGULAR); + columnsToReadFromStorage.add(handle); + } else { NestedField column = tableSchema.findField(colId.getId()); if (column == null) { @@ -831,6 +834,20 @@ public ConnectorPageSource createPageSource( else if (icebergColumn.isDataSequenceNumberColumn()) { metadataValues.put(icebergColumn.getColumnIdentity().getId(), split.getDataSequenceNumber()); } + else if (icebergColumn.isMergeTargetTableRowIdColumn()) { + for (ColumnIdentity subColumn : icebergColumn.getColumnIdentity().getChildren()) { + if (subColumn.getId() == FILE_PATH.fieldId()) { + metadataValues.put(subColumn.getId(), utf8Slice(split.getPath())); + } + else if (subColumn.getId() == SPEC_ID.fieldId()) { + metadataValues.put(subColumn.getId(), (long) partitionSpec.specId()); + } + else if (subColumn.getId() == MERGE_PARTITION_DATA.getId()) { + Optional partitionDataJson = split.getPartitionDataJson(); + metadataValues.put(subColumn.getId(), partitionDataJson.isPresent() ? 
utf8Slice(partitionDataJson.get()) : EMPTY_SLICE); + } + } + } } List delegateColumns = columnsToReadFromStorage.stream().collect(toImmutableList()); @@ -847,8 +864,7 @@ else if (icebergColumn.isDataSequenceNumberColumn()) { LocationProvider locationProvider = getLocationProvider(table.getSchemaTableName(), outputPath.get(), storageProperties.get()); Supplier deleteSinkSupplier = () -> new IcebergDeletePageSink( - tableSchema, - split.getPartitionSpecAsJson(), + partitionSpec, split.getPartitionDataJson(), locationProvider, fileWriterFactory, @@ -858,24 +874,26 @@ else if (icebergColumn.isDataSequenceNumberColumn()) { session, split.getPath(), split.getFileFormat()); - Supplier> deletePredicate = Suppliers.memoize(() -> { + boolean storeDeleteFilePath = icebergColumns.contains(DELETE_FILE_PATH_COLUMN_HANDLE); + Supplier> deleteFilters = memoize(() -> { // If equality deletes are optimized into a join they don't need to be applied here List deletesToApply = split .getDeletes() .stream() .filter(deleteFile -> deleteFile.content() == POSITION_DELETES || equalityDeletesRequired) .collect(toImmutableList()); - List deleteFilters = readDeletes( + return readDeletes( session, tableSchema, split.getPath(), deletesToApply, partitionInsertingPageSource.getRowPositionDelegate().getStartRowPosition(), - partitionInsertingPageSource.getRowPositionDelegate().getEndRowPosition()); - return deleteFilters.stream() - .map(filter -> filter.createPredicate(delegateColumns)) - .reduce(RowPredicate::and); + partitionInsertingPageSource.getRowPositionDelegate().getEndRowPosition(), + storeDeleteFilePath); }); + Supplier> deletePredicate = memoize(() -> deleteFilters.get().stream() + .map(filter -> filter.createPredicate(delegateColumns)) + .reduce(RowPredicate::and)); Table icebergTable = getShallowWrappedIcebergTable( tableSchema, partitionSpec, @@ -905,9 +923,10 @@ else if (icebergColumn.isDataSequenceNumberColumn()) { delegateColumns, deleteSinkSupplier, deletePredicate, + 
deleteFilters, updatedRowPageSinkSupplier, table.getUpdatedColumns(), - updateRow); + rowIdColumnHandle); if (split.getChangelogSplitInfo().isPresent()) { dataSource = new ChangelogPageSource(dataSource, split.getChangelogSplitInfo().get(), (List) (List) desiredColumns, icebergColumns); @@ -940,7 +959,8 @@ private List readDeletes( String dataFilePath, List deleteFiles, Optional startRowPosition, - Optional endRowPosition) + Optional endRowPosition, + boolean storeDeleteFilePath) { verify(startRowPosition.isPresent() == endRowPosition.isPresent(), "startRowPosition and endRowPosition must be specified together"); @@ -981,6 +1001,10 @@ private List readDeletes( catch (IOException e) { throw new PrestoException(ICEBERG_CANNOT_OPEN_SPLIT, format("Cannot open Iceberg delete file: %s", delete.path()), e); } + if (storeDeleteFilePath) { + filters.add(new PositionDeleteFilter(deletedRows, delete.path())); + deletedRows = new Roaring64Bitmap(); // Reset the deleted rows for the next file + } } else if (delete.content() == EQUALITY_DELETES) { List fieldIds = delete.equalityFieldIds(); @@ -990,7 +1014,7 @@ else if (delete.content() == EQUALITY_DELETES) { .collect(toImmutableList()); try (ConnectorPageSource pageSource = openDeletes(session, delete, columns, TupleDomain.all())) { - filters.add(readEqualityDeletes(pageSource, columns, schema)); + filters.add(readEqualityDeletes(pageSource, columns, storeDeleteFilePath ? 
delete.path() : null)); } catch (IOException e) { throw new PrestoException(ICEBERG_CANNOT_OPEN_SPLIT, format("Cannot open Iceberg delete file: %s", delete.path()), e); @@ -1001,8 +1025,8 @@ else if (delete.content() == EQUALITY_DELETES) { } } - if (!deletedRows.isEmpty()) { - filters.add(new PositionDeleteFilter(deletedRows)); + if (!deletedRows.isEmpty() && !storeDeleteFilePath) { + filters.add(new PositionDeleteFilter(deletedRows, null)); } return filters; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPartitionField.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPartitionField.java new file mode 100644 index 0000000000000..9fd48b4474488 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPartitionField.java @@ -0,0 +1,161 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.iceberg; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; +import java.util.OptionalInt; + +import static java.util.Objects.requireNonNull; + +public class IcebergPartitionField +{ + private final int sourceId; + private final int fieldId; + private final OptionalInt parameter; + private final PartitionTransformType transform; + private final String name; + + @JsonCreator + public IcebergPartitionField( + @JsonProperty("sourceId") int sourceId, + @JsonProperty("fieldId") int fieldId, + @JsonProperty("parameter") OptionalInt parameter, + @JsonProperty("transform") PartitionTransformType transform, + @JsonProperty("name") String name) + { + this.sourceId = sourceId; + this.fieldId = fieldId; + this.parameter = requireNonNull(parameter, "parameter is null"); + this.transform = requireNonNull(transform, "transform is null"); + this.name = requireNonNull(name, "name is null"); + } + + @JsonProperty + public int getSourceId() + { + return sourceId; + } + + @JsonProperty + public int getFieldId() + { + return fieldId; + } + + @JsonProperty + public OptionalInt getParameter() + { + return parameter; + } + + @JsonProperty + public PartitionTransformType getTransform() + { + return transform; + } + + @JsonProperty + public String getName() + { + return name; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + IcebergPartitionField that = (IcebergPartitionField) o; + return transform == that.transform && + Objects.equals(name, that.name) && + sourceId == that.sourceId && + fieldId == that.fieldId && + parameter == that.parameter; + } + + @Override + public int hashCode() + { + return Objects.hash(sourceId, fieldId, parameter, transform, name); + } + + @Override + public String toString() + { + return "IcebergPartitionField{" + + "sourceId=" 
+ sourceId + + ", fieldId=" + fieldId + + ", parameter=" + (parameter.isPresent() ? String.valueOf(parameter.getAsInt()) : "null") + + ", transform=" + transform + + ", name='" + name + '\'' + + '}'; + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private int sourceId; + private int fieldId; + private OptionalInt parameter; + private PartitionTransformType transform; + private String name; + + public Builder setSourceId(int sourceId) + { + this.sourceId = sourceId; + return this; + } + + public Builder setFieldId(int fieldId) + { + this.fieldId = fieldId; + return this; + } + + public Builder setTransform(PartitionTransformType transform) + { + this.transform = transform; + return this; + } + + public Builder setName(String name) + { + this.name = name; + return this; + } + + public Builder setParameter(OptionalInt parameter) + { + this.parameter = parameter; + return this; + } + + public IcebergPartitionField build() + { + return new IcebergPartitionField(sourceId, fieldId, parameter == null ? OptionalInt.empty() : parameter, transform, name); + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPartitionLoader.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPartitionLoader.java new file mode 100644 index 0000000000000..e646a64f473ba --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPartitionLoader.java @@ -0,0 +1,210 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.common.predicate.Domain; +import com.facebook.presto.common.predicate.NullableValue; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.hive.HivePartition; +import com.facebook.presto.hive.PartitionNameWithVersion; +import com.facebook.presto.hive.PartitionSet.PartitionLoader; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.SchemaTableName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.PartitionField; +import org.apache.iceberg.PartitionSpec; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.util.Base64; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.common.RuntimeUnit.NANO; +import static com.facebook.presto.common.RuntimeUnit.NONE; +import static com.facebook.presto.iceberg.ExpressionConverter.toIcebergExpression; +import static com.facebook.presto.iceberg.IcebergUtil.getFileFormat; +import static com.facebook.presto.iceberg.IcebergUtil.getIdentityPartitions; +import static com.facebook.presto.iceberg.IcebergUtil.getNonMetadataColumnConstraints; +import static com.facebook.presto.iceberg.IcebergUtil.parsePartitionValue; +import static 
com.facebook.presto.iceberg.IcebergUtil.resolveSnapshotIdByName; +import static com.facebook.presto.iceberg.TypeConverter.toPrestoType; +import static java.lang.String.format; +import static org.apache.iceberg.types.Type.TypeID.BINARY; +import static org.apache.iceberg.types.Type.TypeID.FIXED; + +public class IcebergPartitionLoader + implements PartitionLoader +{ + public static final String LAZY_LOADING_COUNT_KEY_TEMPLATE = "lazy_loading_count_%s_%s"; + public static final String LAZY_LOADING_TIME_KEY_TEMPLATE = "lazy_loading_time_%s_%s"; + private final TypeManager typeManager; + private final FileFormat fileFormat; + private final ConnectorTableHandle tableHandle; + private final Constraint constraint; + private final List partitionColumns; + private final boolean isEmptyTable; + private final TableScan tableScan; + private final RuntimeStats runtimeStats; + private final String loadingCountKey; + private final String loadingTimeKey; + + public IcebergPartitionLoader( + TypeManager typeManager, + ConnectorTableHandle tableHandle, + Table icebergTable, + Constraint constraint, + List partitionColumns, + RuntimeStats runtimeStats) + { + this.typeManager = typeManager; + this.partitionColumns = partitionColumns; + this.fileFormat = getFileFormat(icebergTable); + this.tableHandle = tableHandle; + this.constraint = constraint; + IcebergTableName name = ((IcebergTableHandle) tableHandle).getIcebergTableName(); + SchemaTableName schemaTableName = ((IcebergTableHandle) tableHandle).getSchemaTableName(); + loadingCountKey = format(LAZY_LOADING_COUNT_KEY_TEMPLATE, schemaTableName.getSchemaName(), schemaTableName.getTableName()); + loadingTimeKey = format(LAZY_LOADING_TIME_KEY_TEMPLATE, schemaTableName.getSchemaName(), schemaTableName.getTableName()); + // Empty iceberg table would cause `snapshotId` not present + Optional snapshotId = resolveSnapshotIdByName(icebergTable, name); + if (!snapshotId.isPresent()) { + this.isEmptyTable = true; + this.tableScan = null; + } + 
else { + this.isEmptyTable = false; + this.tableScan = icebergTable.newScan() + .metricsReporter(new RuntimeStatsMetricsReporter(runtimeStats)) + .filter(toIcebergExpression(getNonMetadataColumnConstraints(constraint + .getSummary() + .simplify()))) + .useSnapshot(snapshotId.get()); + } + this.runtimeStats = runtimeStats; + } + + @Override + public synchronized List loadPartitions() + { + if (isEmptyTable) { + return ImmutableList.of(); + } + + long startTime = System.nanoTime(); + Set partitions = new HashSet<>(); + try (CloseableIterable fileScanTasks = tableScan.planFiles()) { + for (FileScanTask fileScanTask : fileScanTasks) { + // If exists delete files, skip the metadata optimization based on partition values as they might become incorrect + if (!fileScanTask.deletes().isEmpty()) { + return ImmutableList.of(new HivePartition(((IcebergTableHandle) tableHandle).getSchemaTableName())); + } + + StructLike partition = fileScanTask.file().partition(); + PartitionSpec spec = fileScanTask.spec(); + Map fieldToIndex = getIdentityPartitions(spec); + ImmutableMap.Builder builder = ImmutableMap.builder(); + + fieldToIndex.forEach((field, index) -> { + int id = field.sourceId(); + org.apache.iceberg.types.Type type = spec.schema().findType(id); + Class javaClass = type.typeId().javaClass(); + Object value = partition.get(index, javaClass); + String partitionStringValue; + + if (value == null) { + partitionStringValue = null; + } + else if (type.typeId() == FIXED || type.typeId() == BINARY) { + partitionStringValue = Base64.getEncoder().encodeToString(((ByteBuffer) value).array()); + } + else { + partitionStringValue = value.toString(); + } + + NullableValue partitionValue = parsePartitionValue(fileFormat, partitionStringValue, toPrestoType(type, typeManager), partition.toString()); + Optional column = partitionColumns.stream() + .filter(icebergColumnHandle -> Objects.equals(icebergColumnHandle.getId(), field.sourceId())) + .findAny(); + + if (column.isPresent()) { + 
builder.put(column.get(), partitionValue); + } + }); + + Map values = builder.build(); + HivePartition newPartition = new HivePartition( + ((IcebergTableHandle) tableHandle).getSchemaTableName(), + new PartitionNameWithVersion(partition.toString(), Optional.empty()), + values); + + boolean isIncludePartition = true; + Map domains = constraint.getSummary().getDomains().get(); + for (IcebergColumnHandle column : partitionColumns) { + NullableValue value = newPartition.getKeys().get(column); + Domain allowedDomain = domains.get(column); + if (allowedDomain != null && !allowedDomain.includesNullableValue(value.getValue())) { + isIncludePartition = false; + break; + } + } + + if (constraint.predicate().isPresent() && !constraint.predicate().get().test(newPartition.getKeys())) { + isIncludePartition = false; + } + + if (isIncludePartition) { + partitions.add(newPartition); + } + } + runtimeStats.addMetricValue(loadingCountKey, NONE, 1); + runtimeStats.addMetricValue(loadingTimeKey, NANO, System.nanoTime() - startTime); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + return ImmutableList.copyOf(partitions); + } + + @Override + public synchronized boolean isEmpty() + { + if (isEmptyTable) { + return true; + } + + try (CloseableIterable fileScanTasks = tableScan.planFiles(); + CloseableIterator iterator = fileScanTasks.iterator()) { + return !iterator.hasNext(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPlugin.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPlugin.java index cd9cdab76bcfb..f1ef5379b36c1 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPlugin.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergPlugin.java @@ -13,16 +13,13 @@ */ package com.facebook.presto.iceberg; -import com.facebook.presto.iceberg.function.changelog.ApplyChangelogFunction; import 
com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.connector.ConnectorFactory; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import javax.management.MBeanServer; import java.lang.management.ManagementFactory; -import java.util.Set; public class IcebergPlugin implements Plugin @@ -44,12 +41,4 @@ public Iterable getConnectorFactories() { return ImmutableList.of(new IcebergConnectorFactory(mBeanServer)); } - - @Override - public Set> getFunctions() - { - return ImmutableSet.>builder() - .add(ApplyChangelogFunction.class) - .build(); - } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergProcedureContext.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergProcedureContext.java new file mode 100644 index 0000000000000..e5f2d04c97e4c --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergProcedureContext.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.spi.connector.ConnectorProcedureContext; +import org.apache.iceberg.Table; +import org.apache.iceberg.Transaction; + +import static java.util.Objects.requireNonNull; + +public class IcebergProcedureContext + implements ConnectorProcedureContext +{ + final Table table; + final Transaction transaction; + + public IcebergProcedureContext(Table table, Transaction transaction) + { + this.table = requireNonNull(table, "table is null"); + this.transaction = requireNonNull(transaction, "transaction is null"); + } + + public Table getTable() + { + return table; + } + + public Transaction getTransaction() + { + return transaction; + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSessionProperties.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSessionProperties.java index a11338b2c5f64..c789da98caa02 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSessionProperties.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSessionProperties.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.iceberg; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.cache.CacheConfig; import com.facebook.presto.hive.HiveCompressionCodec; import com.facebook.presto.hive.OrcFileWriterConfig; @@ -20,32 +21,33 @@ import com.facebook.presto.iceberg.nessie.IcebergNessieConfig; import com.facebook.presto.iceberg.util.StatisticsUtil; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.statistics.ColumnStatisticType; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; +import jakarta.inject.Inject; import org.apache.parquet.column.ParquetProperties; -import javax.inject.Inject; - import java.util.EnumSet; 
import java.util.List; import java.util.Optional; +import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.iceberg.util.StatisticsUtil.SUPPORTED_MERGE_FLAGS; import static com.facebook.presto.iceberg.util.StatisticsUtil.decodeMergeFlags; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_SESSION_PROPERTY; import static com.facebook.presto.spi.session.PropertyMetadata.booleanProperty; import static com.facebook.presto.spi.session.PropertyMetadata.doubleProperty; import static com.facebook.presto.spi.session.PropertyMetadata.integerProperty; import static com.facebook.presto.spi.session.PropertyMetadata.longProperty; import static com.facebook.presto.spi.session.PropertyMetadata.stringProperty; +import static java.lang.String.format; public final class IcebergSessionProperties { - private static final String COMPRESSION_CODEC = "compression_codec"; private static final String PARQUET_WRITER_BLOCK_SIZE = "parquet_writer_block_size"; private static final String PARQUET_WRITER_PAGE_SIZE = "parquet_writer_page_size"; private static final String PARQUET_WRITER_VERSION = "parquet_writer_version"; @@ -58,15 +60,18 @@ public final class IcebergSessionProperties private static final String MINIMUM_ASSIGNED_SPLIT_WEIGHT = "minimum_assigned_split_weight"; private static final String NESSIE_REFERENCE_NAME = "nessie_reference_name"; private static final String NESSIE_REFERENCE_HASH = "nessie_reference_hash"; + static final String COMPRESSION_CODEC = "compression_codec"; public static final String PARQUET_DEREFERENCE_PUSHDOWN_ENABLED = "parquet_dereference_pushdown_enabled"; public static final String MERGE_ON_READ_MODE_ENABLED = "merge_on_read_enabled"; public static final String PUSHDOWN_FILTER_ENABLED = "pushdown_filter_enabled"; public static final String 
DELETE_AS_JOIN_REWRITE_ENABLED = "delete_as_join_rewrite_enabled"; + public static final String DELETE_AS_JOIN_REWRITE_MAX_DELETE_COLUMNS = "delete_as_join_rewrite_max_delete_columns"; public static final String HIVE_METASTORE_STATISTICS_MERGE_STRATEGY = "hive_statistics_merge_strategy"; public static final String STATISTIC_SNAPSHOT_RECORD_DIFFERENCE_WEIGHT = "statistic_snapshot_record_difference_weight"; public static final String ROWS_FOR_METADATA_OPTIMIZATION_THRESHOLD = "rows_for_metadata_optimization_threshold"; public static final String STATISTICS_KLL_SKETCH_K_PARAMETER = "statistics_kll_sketch_k_parameter"; public static final String TARGET_SPLIT_SIZE_BYTES = "target_split_size_bytes"; + public static final String MATERIALIZED_VIEW_STORAGE_PREFIX = "materialized_view_storage_prefix"; private final List> sessionProperties; @@ -181,6 +186,23 @@ public IcebergSessionProperties( "When enabled equality delete row filtering will be pushed down into a join.", icebergConfig.isDeleteAsJoinRewriteEnabled(), false)) + .add(new PropertyMetadata<>( + DELETE_AS_JOIN_REWRITE_MAX_DELETE_COLUMNS, + "The maximum number of columns that can be used in a delete as join rewrite. " + + "If the number of columns exceeds this value, the delete as join rewrite will not be applied.", + INTEGER, + Integer.class, + icebergConfig.getDeleteAsJoinRewriteMaxDeleteColumns(), + false, + value -> { + int intValue = ((Number) value).intValue(); + if (intValue < 0 || intValue > 400) { + throw new PrestoException(INVALID_SESSION_PROPERTY, + format("Invalid value for %s: %s. It must be between 0 and 400.", DELETE_AS_JOIN_REWRITE_MAX_DELETE_COLUMNS, intValue)); + } + return intValue; + }, + integer -> integer)) .add(integerProperty( ROWS_FOR_METADATA_OPTIMIZATION_THRESHOLD, "The max partitions number to utilize metadata optimization. When partitions number " + @@ -196,6 +218,13 @@ public IcebergSessionProperties( TARGET_SPLIT_SIZE_BYTES, "The target split size. 
Set to 0 to use the iceberg table's read.split.target-size property", 0L, + false)) + .add(stringProperty( + MATERIALIZED_VIEW_STORAGE_PREFIX, + "Default prefix for generated materialized view storage table names. " + + "This is only used when the storage_table table property is not explicitly set. " + + "When a custom table name is provided, it takes precedence over this prefix.", + icebergConfig.getMaterializedViewStoragePrefix(), false)); nessieConfig.ifPresent((config) -> propertiesBuilder @@ -311,6 +340,11 @@ public static boolean isDeleteToJoinPushdownEnabled(ConnectorSession session) return session.getProperty(DELETE_AS_JOIN_REWRITE_ENABLED, Boolean.class); } + public static int getDeleteAsJoinRewriteMaxDeleteColumns(ConnectorSession session) + { + return session.getProperty(DELETE_AS_JOIN_REWRITE_MAX_DELETE_COLUMNS, Integer.class); + } + public static int getRowsForMetadataOptimizationThreshold(ConnectorSession session) { return session.getProperty(ROWS_FOR_METADATA_OPTIMIZATION_THRESHOLD, Integer.class); @@ -335,4 +369,9 @@ public static Long getTargetSplitSize(ConnectorSession session) { return session.getProperty(TARGET_SPLIT_SIZE_BYTES, Long.class); } + + public static String getMaterializedViewStoragePrefix(ConnectorSession session) + { + return session.getProperty(MATERIALIZED_VIEW_STORAGE_PREFIX, String.class); + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java index 7b1415a4beced..aeaf5dea8b202 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergSplitManager.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import 
org.apache.iceberg.DeleteFile; import org.apache.iceberg.IncrementalChangelogScan; import org.apache.iceberg.Table; @@ -34,8 +35,6 @@ import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; @@ -83,6 +82,7 @@ public ConnectorSplitSource getSplits( TupleDomain predicate = getNonMetadataColumnConstraints(layoutHandle .getValidPredicate()); + Table icebergTable = getIcebergTable(transactionManager.get(transaction), session, table.getSchemaTableName()); if (table.getIcebergTableName().getTableType() == CHANGELOG) { @@ -91,6 +91,7 @@ public ConnectorSplitSource getSplits( long toSnapshot = table.getIcebergTableName().getChangelogEndSnapshot() .orElseGet(icebergTable.currentSnapshot()::snapshotId); IncrementalChangelogScan scan = icebergTable.newIncrementalChangelogScan() + .metricsReporter(new RuntimeStatsMetricsReporter(session.getRuntimeStats())) .fromSnapshotExclusive(fromSnapshot) .toSnapshot(toSnapshot); return new ChangelogSplitSource(session, typeManager, icebergTable, scan); @@ -100,12 +101,14 @@ else if (table.getIcebergTableName().getTableType() == EQUALITY_DELETES) { table.getIcebergTableName().getSnapshotId().get(), predicate, table.getPartitionSpecId(), - table.getEqualityFieldIds()); + table.getEqualityFieldIds(), + session.getRuntimeStats()); return new EqualityDeletesSplitSource(session, icebergTable, deleteFiles); } else { TableScan tableScan = icebergTable.newScan() + .metricsReporter(new RuntimeStatsMetricsReporter(session.getRuntimeStats())) .filter(toIcebergExpression(predicate)) .useSnapshot(table.getIcebergTableName().getSnapshotId().get()) .planWith(executor); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableHandle.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableHandle.java index 633f80d51eb33..fbff48e18d1f4 100644 --- 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableHandle.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableHandle.java @@ -15,6 +15,7 @@ import com.facebook.presto.hive.BaseHiveTableHandle; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.SchemaTableName; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -40,6 +41,7 @@ public class IcebergTableHandle private final Optional> equalityFieldIds; private final List sortOrder; private final List updatedColumns; + private final Optional materializedViewName; @JsonCreator public IcebergTableHandle( @@ -52,7 +54,8 @@ public IcebergTableHandle( @JsonProperty("partitionFieldIds") Optional> partitionFieldIds, @JsonProperty("equalityFieldIds") Optional> equalityFieldIds, @JsonProperty("sortOrder") List sortOrder, - @JsonProperty("updatedColumns") List updatedColumns) + @JsonProperty("updatedColumns") List updatedColumns, + @JsonProperty("materializedViewName") Optional materializedViewName) { super(schemaName, icebergTableName.getTableName()); @@ -65,6 +68,7 @@ public IcebergTableHandle( this.equalityFieldIds = requireNonNull(equalityFieldIds, "equalityFieldIds is null"); this.sortOrder = ImmutableList.copyOf(requireNonNull(sortOrder, "sortOrder is null")); this.updatedColumns = requireNonNull(updatedColumns, "updatedColumns is null"); + this.materializedViewName = requireNonNull(materializedViewName, "materializedViewName is null"); } @JsonProperty @@ -121,6 +125,12 @@ public List getUpdatedColumns() return updatedColumns; } + @JsonProperty + public Optional getMaterializedViewName() + { + return materializedViewName; + } + public IcebergTableHandle withUpdatedColumns(List updatedColumns) { return new IcebergTableHandle( @@ -133,7 +143,8 @@ public IcebergTableHandle withUpdatedColumns(List updatedCo partitionFieldIds, 
equalityFieldIds, sortOrder, - updatedColumns); + updatedColumns, + materializedViewName); } @Override @@ -152,13 +163,14 @@ public boolean equals(Object o) snapshotSpecified == that.snapshotSpecified && Objects.equals(sortOrder, that.sortOrder) && Objects.equals(tableSchemaJson, that.tableSchemaJson) && - Objects.equals(equalityFieldIds, that.equalityFieldIds); + Objects.equals(equalityFieldIds, that.equalityFieldIds) && + Objects.equals(materializedViewName, that.materializedViewName); } @Override public int hashCode() { - return Objects.hash(getSchemaName(), icebergTableName, sortOrder, snapshotSpecified, tableSchemaJson, equalityFieldIds); + return Objects.hash(getSchemaName(), icebergTableName, sortOrder, snapshotSpecified, tableSchemaJson, equalityFieldIds, materializedViewName); } @Override diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableLayoutHandle.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableLayoutHandle.java index 0b98757367486..ba4a18e63fd86 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableLayoutHandle.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableLayoutHandle.java @@ -17,7 +17,7 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.hive.BaseHiveColumnHandle; import com.facebook.presto.hive.BaseHiveTableLayoutHandle; -import com.facebook.presto.hive.HivePartition; +import com.facebook.presto.hive.PartitionSet; import com.facebook.presto.hive.metastore.Column; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.relation.RowExpression; @@ -77,7 +77,7 @@ public IcebergTableLayoutHandle( Optional> requestedColumns, boolean pushdownFilterEnabled, TupleDomain partitionColumnPredicate, - Optional> partitions, + Optional partitions, IcebergTableHandle table) { super( @@ -171,7 +171,7 @@ public static class Builder private Optional> requestedColumns; private boolean 
pushdownFilterEnabled; private TupleDomain partitionColumnPredicate; - private Optional> partitions; + private Optional partitions; private IcebergTableHandle table; public Builder setPartitionColumns(List partitionColumns) @@ -222,7 +222,7 @@ public Builder setPartitionColumnPredicate(TupleDomain partitionCo return this; } - public Builder setPartitions(Optional> partitions) + public Builder setPartitions(Optional partitions) { this.partitions = partitions; return this; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableName.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableName.java index b199a9569ea25..e619cdbcc2ef0 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableName.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableName.java @@ -28,6 +28,7 @@ import static com.facebook.presto.iceberg.IcebergTableType.FILES; import static com.facebook.presto.iceberg.IcebergTableType.HISTORY; import static com.facebook.presto.iceberg.IcebergTableType.MANIFESTS; +import static com.facebook.presto.iceberg.IcebergTableType.METADATA_LOG_ENTRIES; import static com.facebook.presto.iceberg.IcebergTableType.PARTITIONS; import static com.facebook.presto.iceberg.IcebergTableType.PROPERTIES; import static com.facebook.presto.iceberg.IcebergTableType.REFS; @@ -51,7 +52,7 @@ public class IcebergTableName private final Optional changelogEndSnapshot; - private static final Set SYSTEM_TABLES = Sets.immutableEnumSet(FILES, MANIFESTS, PARTITIONS, HISTORY, SNAPSHOTS, PROPERTIES, REFS); + private static final Set SYSTEM_TABLES = Sets.immutableEnumSet(FILES, MANIFESTS, PARTITIONS, HISTORY, SNAPSHOTS, PROPERTIES, REFS, METADATA_LOG_ENTRIES); @JsonCreator public IcebergTableName( diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableProperties.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableProperties.java index 
5d537100c5fef..450dd3e5f786a 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableProperties.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableProperties.java @@ -22,14 +22,14 @@ import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.apache.iceberg.RowLevelOperationMode; import org.apache.iceberg.TableProperties; -import javax.inject.Inject; - import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -45,6 +45,11 @@ import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static org.apache.iceberg.TableProperties.COMMIT_NUM_RETRIES; +import static org.apache.iceberg.TableProperties.HIVE_LOCK_ENABLED; +import static org.apache.iceberg.TableProperties.METADATA_DELETE_AFTER_COMMIT_ENABLED; +import static org.apache.iceberg.TableProperties.METRICS_MAX_INFERRED_COLUMN_DEFAULTS; +import static org.apache.iceberg.TableProperties.ORC_COMPRESSION; +import static org.apache.iceberg.TableProperties.PARQUET_COMPRESSION; import static org.apache.iceberg.TableProperties.UPDATE_MODE; import static org.apache.iceberg.TableProperties.WRITE_DATA_LOCATION; @@ -100,14 +105,19 @@ public class IcebergTableProperties .put(COMMIT_RETRIES, TableProperties.COMMIT_NUM_RETRIES) .put(DELETE_MODE, TableProperties.DELETE_MODE) .put(METADATA_PREVIOUS_VERSIONS_MAX, TableProperties.METADATA_PREVIOUS_VERSIONS_MAX) - .put(METADATA_DELETE_AFTER_COMMIT, TableProperties.METADATA_DELETE_AFTER_COMMIT_ENABLED) - .put(METRICS_MAX_INFERRED_COLUMN, TableProperties.METRICS_MAX_INFERRED_COLUMN_DEFAULTS) + .put(METADATA_DELETE_AFTER_COMMIT, METADATA_DELETE_AFTER_COMMIT_ENABLED) + .put(METRICS_MAX_INFERRED_COLUMN, METRICS_MAX_INFERRED_COLUMN_DEFAULTS) .build(); private static final Set 
UPDATABLE_PROPERTIES = ImmutableSet.builder() .add(COMMIT_RETRIES) .add(COMMIT_NUM_RETRIES) .add(TARGET_SPLIT_SIZE) + .add(METADATA_DELETE_AFTER_COMMIT) + .add(METADATA_DELETE_AFTER_COMMIT_ENABLED) + .add(METADATA_PREVIOUS_VERSIONS_MAX) + .add(HIVE_LOCK_ENABLED) + .add(TableProperties.METADATA_PREVIOUS_VERSIONS_MAX) .build(); private static final String DEFAULT_FORMAT_VERSION = "2"; @@ -119,7 +129,7 @@ public class IcebergTableProperties @Inject public IcebergTableProperties(IcebergConfig icebergConfig) { - List> properties = ImmutableList.>builder() + List> baseTableProperties = ImmutableList.>builder() .add(new PropertyMetadata<>( PARTITIONING_PROPERTY, "Partition transforms", @@ -184,15 +194,20 @@ public IcebergTableProperties(IcebergConfig icebergConfig) icebergConfig.getMetadataPreviousVersionsMax(), false)) .add(booleanProperty( - TableProperties.METADATA_DELETE_AFTER_COMMIT_ENABLED, + METADATA_DELETE_AFTER_COMMIT_ENABLED, "Whether enables to delete the oldest metadata file after commit", icebergConfig.isMetadataDeleteAfterCommit(), false)) .add(integerProperty( - TableProperties.METRICS_MAX_INFERRED_COLUMN_DEFAULTS, + METRICS_MAX_INFERRED_COLUMN_DEFAULTS, "The maximum number of columns for which metrics are collected", icebergConfig.getMetricsMaxInferredColumn(), false)) + .add(booleanProperty( + HIVE_LOCK_ENABLED, + "Whether to enable hive locks", + null, + false)) .add(new PropertyMetadata<>( UPDATE_MODE, "Update mode for the table", @@ -206,9 +221,19 @@ public IcebergTableProperties(IcebergConfig icebergConfig) "Desired size of split to generate during query scan planning", TableProperties.SPLIT_SIZE_DEFAULT, false)) + .add(stringProperty( + PARQUET_COMPRESSION, + "Compression codec for Parquet format", + null, + false)) + .add(stringProperty( + ORC_COMPRESSION, + "Compression codec for ORC format", + null, + false)) .build(); - deprecatedPropertyMetadata = properties.stream() + deprecatedPropertyMetadata = baseTableProperties.stream() .filter(prop -> 
DEPRECATED_PROPERTIES.inverse().containsKey(prop.getName())) .map(prop -> new PropertyMetadata<>( DEPRECATED_PROPERTIES.inverse().get(prop.getName()), @@ -222,15 +247,23 @@ public IcebergTableProperties(IcebergConfig icebergConfig) .collect(toImmutableMap(property -> property.getName(), property -> property)); tableProperties = ImmutableList.>builder() - .addAll(properties) + .addAll(baseTableProperties) .addAll(deprecatedPropertyMetadata.values().iterator()) .build(); - columnProperties = ImmutableList.of(stringProperty( - PARTITIONING_PROPERTY, - "This column's partition transform", - null, - false)); + columnProperties = ImmutableList.of( + new PropertyMetadata<>( + PARTITIONING_PROPERTY, + "This column's partition transforms, supports both string expressions (e.g., 'bucket(4)') and array expressions (e.g. ARRAY['bucket(4)', 'identity'])", + new ArrayType(VARCHAR), + List.class, + ImmutableList.of(), + false, + value -> ((Collection) value).stream() + .map(name -> ((String) name).toLowerCase(ENGLISH)) + .collect(toImmutableList()), + value -> value) + .withAdditionalTypeHandler(VARCHAR, ImmutableList::of)); } public List> getTableProperties() @@ -243,7 +276,7 @@ public List> getColumnProperties() return columnProperties; } - public Set getUpdatableProperties() + public static Set getUpdatableProperties() { return UPDATABLE_PROPERTIES; } @@ -252,31 +285,38 @@ public Set getUpdatableProperties() * @return a map of deprecated property name to new property name, or null if the property is * removed entirely. 
*/ - public Map getDeprecatedProperties() + public static Map getDeprecatedProperties() { return DEPRECATED_PROPERTIES; } + public boolean isTablePropertySupported(String propertyName) + { + return tableProperties.stream() + .map(PropertyMetadata::getName) + .anyMatch(name -> name.equalsIgnoreCase(propertyName)); + } + public FileFormat getFileFormat(ConnectorSession session, Map tableProperties) { return (FileFormat) getTablePropertyWithDeprecationWarning(session, tableProperties, TableProperties.DEFAULT_FILE_FORMAT); } @SuppressWarnings("unchecked") - public List getPartitioning(Map tableProperties) + public static List getPartitioning(Map tableProperties) { List partitioning = (List) tableProperties.get(PARTITIONING_PROPERTY); return partitioning == null ? ImmutableList.of() : ImmutableList.copyOf(partitioning); } @SuppressWarnings("unchecked") - public List getSortOrder(Map tableProperties) + public static List getSortOrder(Map tableProperties) { List sortedBy = (List) tableProperties.get(SORTED_BY_PROPERTY); return sortedBy == null ? ImmutableList.of() : ImmutableList.copyOf(sortedBy); } - public String getTableLocation(Map tableProperties) + public static String getTableLocation(Map tableProperties) { return (String) tableProperties.get(LOCATION_PROPERTY); } @@ -285,6 +325,10 @@ public static String getWriteDataLocation(Map tableProperties) { return (String) tableProperties.get(WRITE_DATA_LOCATION); } + public static Optional isHiveLocksEnabled(Map tableProperties) + { + return tableProperties.containsKey(HIVE_LOCK_ENABLED) ? 
Optional.of(String.valueOf(tableProperties.get(HIVE_LOCK_ENABLED))) : Optional.empty(); + } public String getFormatVersion(ConnectorSession session, Map tableProperties) { @@ -308,12 +352,12 @@ public Integer getMetadataPreviousVersionsMax(ConnectorSession session, Map tableProperties) { - return (Boolean) getTablePropertyWithDeprecationWarning(session, tableProperties, TableProperties.METADATA_DELETE_AFTER_COMMIT_ENABLED); + return (Boolean) getTablePropertyWithDeprecationWarning(session, tableProperties, METADATA_DELETE_AFTER_COMMIT_ENABLED); } public Integer getMetricsMaxInferredColumn(ConnectorSession session, Map tableProperties) { - return (Integer) getTablePropertyWithDeprecationWarning(session, tableProperties, TableProperties.METRICS_MAX_INFERRED_COLUMN_DEFAULTS); + return (Integer) getTablePropertyWithDeprecationWarning(session, tableProperties, METRICS_MAX_INFERRED_COLUMN_DEFAULTS); } public RowLevelOperationMode getUpdateMode(Map tableProperties) diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableType.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableType.java index 7c3c5d73c3b7f..68a113ca62477 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableType.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergTableType.java @@ -22,6 +22,7 @@ public enum IcebergTableType PARTITIONS(true), FILES(true), REFS(true), + METADATA_LOG_ENTRIES(true), PROPERTIES(true), CHANGELOG(true), EQUALITY_DELETES(true), diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUpdateablePageSource.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUpdateablePageSource.java index 16cc207683a9b..8a0bbdd1b16e8 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUpdateablePageSource.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUpdateablePageSource.java @@ -15,9 +15,12 @@ import 
com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.block.ColumnarRow; import com.facebook.presto.common.block.RowBlock; +import com.facebook.presto.common.block.RunLengthEncodedBlock; import com.facebook.presto.hive.HivePartitionKey; +import com.facebook.presto.iceberg.delete.DeleteFilter; import com.facebook.presto.iceberg.delete.IcebergDeletePageSink; import com.facebook.presto.iceberg.delete.RowPredicate; import com.facebook.presto.spi.ConnectorPageSource; @@ -25,8 +28,10 @@ import com.facebook.presto.spi.UpdatablePageSource; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import org.apache.iceberg.Schema; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.Pair; import java.io.IOException; import java.io.UncheckedIOException; @@ -34,14 +39,19 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; +import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Collectors; import java.util.stream.IntStream; import static com.facebook.presto.common.block.ColumnarRow.toColumnarRow; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_BAD_DATA; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_MISSING_COLUMN; import static com.google.common.base.Throwables.throwIfInstanceOf; @@ -64,12 +74,14 @@ public class IcebergUpdateablePageSource private final Supplier deleteSinkSupplier; private IcebergDeletePageSink positionDeleteSink; private final Supplier> deletePredicate; + private final Supplier> 
deleteFilters; private final List columns; /** * Columns actually updated in the query */ private final List updatedColumns; + private final List delegateColumns; private final Schema tableSchema; private final Supplier updatedRowPageSinkSupplier; private IcebergPageSink updatedRowPageSink; @@ -83,6 +95,8 @@ public class IcebergUpdateablePageSource // Maps the Iceberg field ids of modified columns to their indexes in the updatedColumns columnValueAndRowIdChannels array private final Map columnIdentityToUpdatedColumnIndex = new HashMap<>(); private final int[] outputColumnToDelegateMapping; + private final int isDeletedColumnId; + private final int deleteFilePathColumnId; public IcebergUpdateablePageSource( Schema tableSchema, @@ -95,30 +109,33 @@ public IcebergUpdateablePageSource( List delegateColumns, Supplier deleteSinkSupplier, Supplier> deletePredicate, + Supplier> deleteFilters, Supplier updatedRowPageSinkSupplier, // the columns that this page source is supposed to update List updatedColumns, - Optional updateRowIdColumn) + Optional rowIdColumn) { requireNonNull(partitionKeys, "partitionKeys is null"); this.tableSchema = requireNonNull(tableSchema, "tableSchema is null"); this.columns = requireNonNull(outputColumns, "columns is null"); this.delegate = requireNonNull(delegate, "delegate is null"); + this.delegateColumns = requireNonNull(delegateColumns, "delegateColumns is null"); // information for deletes this.deleteSinkSupplier = deleteSinkSupplier; this.deletePredicate = requireNonNull(deletePredicate, "deletePredicate is null"); + this.deleteFilters = requireNonNull(deleteFilters, "deleteFilters is null"); // information for updates this.updatedRowPageSinkSupplier = requireNonNull(updatedRowPageSinkSupplier, "updatedRowPageSinkSupplier is null"); this.updatedColumns = requireNonNull(updatedColumns, "updatedColumns is null"); this.outputColumnToDelegateMapping = new int[columns.size()]; - this.updateRowIdColumnIndex = 
updateRowIdColumn.map(columns::indexOf).orElse(-1); - this.updateRowIdChildColumnIndexes = updateRowIdColumn + this.updateRowIdColumnIndex = rowIdColumn.map(columns::indexOf).orElse(-1); + this.updateRowIdChildColumnIndexes = rowIdColumn .map(column -> new int[column.getColumnIdentity().getChildren().size()]) .orElse(new int[0]); Map columnToIndex = IntStream.range(0, delegateColumns.size()) .boxed() .collect(toImmutableMap(index -> delegateColumns.get(index).getColumnIdentity(), identity())); - updateRowIdColumn.ifPresent(column -> { + rowIdColumn.ifPresent(column -> { List rowIdFields = column.getColumnIdentity().getChildren(); for (int i = 0; i < rowIdFields.size(); i++) { ColumnIdentity columnIdentity = rowIdFields.get(i); @@ -134,17 +151,20 @@ public IcebergUpdateablePageSource( } } for (int i = 0; i < outputColumnToDelegateMapping.length; i++) { - if (outputColumns.get(i).isUpdateRowIdColumn()) { + IcebergColumnHandle outputColumn = outputColumns.get(i); + if (outputColumn.isUpdateRowIdColumn() || outputColumn.isMergeTargetTableRowIdColumn()) { continue; } - if (!columnToIndex.containsKey(outputColumns.get(i).getColumnIdentity())) { - throw new PrestoException(ICEBERG_MISSING_COLUMN, format("Column %s not found in delegate column map", outputColumns.get(i))); + if (!columnToIndex.containsKey(outputColumn.getColumnIdentity())) { + throw new PrestoException(ICEBERG_MISSING_COLUMN, format("Column %s not found in delegate column map", outputColumn)); } else { - outputColumnToDelegateMapping[i] = columnToIndex.get(outputColumns.get(i).getColumnIdentity()); + outputColumnToDelegateMapping[i] = columnToIndex.get(outputColumn.getColumnIdentity()); } } + this.isDeletedColumnId = getDelegateColumnId(IcebergColumnHandle::isDeletedColumn); + this.deleteFilePathColumnId = getDelegateColumnId(IcebergColumnHandle::isDeleteFilePathColumn); } @Override @@ -179,7 +199,7 @@ public boolean isFinished() * {@link IcebergPartitionInsertingPageSource}. * 2. 
Using the newly retrieved page, apply any necessary delete filters. * 3. Finally, take the necessary channels from the page with the delete filters applied and - * nest them into the updateRowId channel in {@link #setUpdateRowIdBlock(Page)} + * nest them into the updateRowId channel in {@link #setRowIdBlock(Page)} */ @Override public Page getNextPage() @@ -191,11 +211,26 @@ public Page getNextPage() } Optional deleteFilterPredicate = deletePredicate.get(); - if (deleteFilterPredicate.isPresent()) { + if (isDeletedColumnId != -1 || deleteFilePathColumnId != -1) { + if (isDeletedColumnId != -1) { + if (deleteFilterPredicate.isPresent()) { + // Instead of filtering rows, we mark whether the row is deleted in the $deleted column + dataPage = deleteFilterPredicate.get().markDeleted(dataPage, isDeletedColumnId); + } + else { + Block allFalseBlock = RunLengthEncodedBlock.create(BOOLEAN, false, dataPage.getPositionCount()); + dataPage = dataPage.replaceColumn(isDeletedColumnId, allFalseBlock); + } + } + if (deleteFilePathColumnId != -1) { + dataPage = markDeleteFilePath(dataPage, deleteFilePathColumnId); + } + } + else if (deleteFilterPredicate.isPresent()) { dataPage = deleteFilterPredicate.get().filterPage(dataPage); } - return setUpdateRowIdBlock(dataPage); + return setRowIdBlock(dataPage); } catch (RuntimeException e) { closeWithSuppression(e); @@ -213,6 +248,14 @@ public void deleteRows(Block rowIds) positionDeleteSink.appendPage(new Page(rowIds)); } + /** + * @param page This page contains the following channels: + *
    + *
  • One channel for the row ID, which includes the position number of this row within the file and the values of the unmodified columns.
  • + *
  • One additional channel for each updated column. These channels contain the new values for the updated columns.
  • + *
+ * @param columnValueAndRowIdChannels Channel numbers of the column values and the row ID's channel number at the end of the list. + */ @Override public void updateRows(Page page, List columnValueAndRowIdChannels) { @@ -234,6 +277,7 @@ public void updateRows(Page page, List columnValueAndRowIdChannels) Set updatedColumnFieldIds = columnIdentityToUpdatedColumnIndex.keySet(); List tableColumns = tableSchema.columns(); Block[] fullPage = new Block[tableColumns.size()]; + // Build a page that will contain the values of the updated rows. The rows stored in the "fullPage" include both updated and non-updated field values. for (int targetChannel = 0; targetChannel < tableColumns.size(); targetChannel++) { Types.NestedField column = tableColumns.get(targetChannel); ColumnIdentity columnIdentity = ColumnIdentity.createColumnIdentity(column); @@ -275,18 +319,20 @@ public void abort() } /** - * The $row_id column used for updates is a composite column of at least one other column in the Page. + * The $row_id column used for updates and merge is a composite column of at least one other column in the Page. * The indexes of the columns needed for the $row_id are in the updateRowIdChildColumnIndexes array. * * @param page The raw Page from the Parquet/ORC reader. * @return A Page where the $row_id channel has been populated. 
*/ - private Page setUpdateRowIdBlock(Page page) + private Page setRowIdBlock(Page page) { Block[] fullPage = new Block[columns.size()]; Block[] rowIdFields; Consumer loopFunc; - if (updateRowIdColumnIndex == -1 || updatedColumns.isEmpty()) { + boolean isMergeTargetTable = columns.stream().anyMatch(IcebergColumnHandle::isMergeTargetTableRowIdColumn); + + if ((updateRowIdColumnIndex == -1 || updatedColumns.isEmpty()) && !isMergeTargetTable) { loopFunc = (channel) -> fullPage[channel] = page.getBlock(outputColumnToDelegateMapping[channel]); } else { @@ -311,6 +357,87 @@ private Page setUpdateRowIdBlock(Page page) return new Page(page.getPositionCount(), fullPage); } + private int getDelegateColumnId(Predicate columnPredicate) + { + int targetColumnId = -1; + for (int i = 0; i < columns.size(); i++) { + if (columnPredicate.test(columns.get(i))) { + targetColumnId = i; + break; + } + } + if (targetColumnId == -1) { + return -1; + } + return outputColumnToDelegateMapping[targetColumnId]; + } + + private Page markDeleteFilePath(Page page, int deleteFilePathDelegateColumnId) + { + List> filterPredicates = deleteFilters.get().stream() + .map(filter -> Pair.of(filter, filter.createPredicate(delegateColumns))) + .collect(Collectors.toList()); + + int positionCount = page.getPositionCount(); + if (positionCount == 0) { + return page; + } + + boolean allSameValues = true; + Optional firstValue = getDeleteFilePath(page, 0, filterPredicates); + BlockBuilder blockBuilder = null; + // Build the varchar block with the deleted file path or null if the row isn't deleted + for (int position = 1; position < positionCount; position++) { + Optional deleteFilePath = getDeleteFilePath(page, position, filterPredicates); + if (allSameValues && !Objects.equals(firstValue.orElse(null), deleteFilePath.orElse(null))) { + blockBuilder = VARCHAR.createBlockBuilder(null, positionCount); + for (int idx = 0; idx < position; idx++) { + writeStringOrNull(blockBuilder, firstValue); + } + 
writeStringOrNull(blockBuilder, deleteFilePath); + allSameValues = false; + } + else if (!allSameValues) { + writeStringOrNull(blockBuilder, deleteFilePath); + } + } + + Block block; + if (blockBuilder != null) { + block = blockBuilder.build(); + } + else { + Slice slice = firstValue.map(Slices::utf8Slice).orElse(null); + block = RunLengthEncodedBlock.create(VARCHAR, slice, positionCount); + } + + return page.replaceColumn(deleteFilePathDelegateColumnId, block); + } + + private void writeStringOrNull(BlockBuilder blockBuilder, Optional toWrite) + { + if (toWrite.isPresent()) { + VARCHAR.writeString(blockBuilder, toWrite.get()); + } + else { + blockBuilder.appendNull(); + } + } + + private Optional getDeleteFilePath(Page page, int position, List> filterPredicates) + { + for (Pair pair : filterPredicates) { + boolean deleted = !pair.second().test(page, position); + if (deleted) { + String path = pair.first().getDeleteFilePath().orElse(null); + if (path != null) { + return Optional.of(path); + } + } + } + return Optional.empty(); + } + @Override public void close() { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUtil.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUtil.java index d925a37db40bd..6ec40b607a022 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUtil.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergUtil.java @@ -14,7 +14,9 @@ package com.facebook.presto.iceberg; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.GenericInternalException; +import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.predicate.NullableValue; import com.facebook.presto.common.predicate.TupleDomain; @@ -28,10 +30,11 @@ import com.facebook.presto.hive.HdfsContext; import com.facebook.presto.hive.HdfsEnvironment; import 
com.facebook.presto.hive.HiveColumnConverterProvider; -import com.facebook.presto.hive.HivePartition; +import com.facebook.presto.hive.HiveCompressionCodec; import com.facebook.presto.hive.HivePartitionKey; +import com.facebook.presto.hive.HiveStorageFormat; import com.facebook.presto.hive.HiveType; -import com.facebook.presto.hive.PartitionNameWithVersion; +import com.facebook.presto.hive.PartitionSet; import com.facebook.presto.hive.metastore.Column; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.MetastoreContext; @@ -49,7 +52,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; -import io.airlift.units.DataSize; import org.apache.iceberg.BaseTable; import org.apache.iceberg.ContentFile; import org.apache.iceberg.ContentScanTask; @@ -71,6 +73,7 @@ import org.apache.iceberg.TableOperations; import org.apache.iceberg.TableScan; import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.TableIdentifier; import org.apache.iceberg.catalog.ViewCatalog; import org.apache.iceberg.expressions.Expression; import org.apache.iceberg.hive.HiveSchemaUtil; @@ -84,7 +87,6 @@ import org.apache.iceberg.view.View; import java.io.IOException; -import java.io.UncheckedIOException; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -101,10 +103,10 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.Stream; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.Chars.isCharType; @@ -142,8 +144,8 @@ import static com.facebook.presto.iceberg.IcebergSessionProperties.getCompressionCodec; 
import static com.facebook.presto.iceberg.IcebergSessionProperties.isMergeOnReadModeEnabled; import static com.facebook.presto.iceberg.IcebergTableProperties.getWriteDataLocation; +import static com.facebook.presto.iceberg.IcebergTableProperties.isHiveLocksEnabled; import static com.facebook.presto.iceberg.TypeConverter.toIcebergType; -import static com.facebook.presto.iceberg.TypeConverter.toPrestoType; import static com.facebook.presto.iceberg.util.IcebergPrestoModelConverters.toIcebergTableIdentifier; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; @@ -158,7 +160,6 @@ import static com.google.common.collect.Streams.stream; import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.slice.Slices.wrappedBuffer; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Double.doubleToRawLongBits; import static java.lang.Double.longBitsToDouble; import static java.lang.Double.parseDouble; @@ -188,6 +189,7 @@ import static org.apache.iceberg.TableProperties.DELETE_MODE; import static org.apache.iceberg.TableProperties.DELETE_MODE_DEFAULT; import static org.apache.iceberg.TableProperties.FORMAT_VERSION; +import static org.apache.iceberg.TableProperties.HIVE_LOCK_ENABLED; import static org.apache.iceberg.TableProperties.MERGE_MODE; import static org.apache.iceberg.TableProperties.METADATA_DELETE_AFTER_COMMIT_ENABLED; import static org.apache.iceberg.TableProperties.METADATA_DELETE_AFTER_COMMIT_ENABLED_DEFAULT; @@ -211,7 +213,6 @@ public final class IcebergUtil { - private static final Pattern SIMPLE_NAME = Pattern.compile("[a-z][a-z0-9]*"); private static final Logger log = Logger.get(IcebergUtil.class); public static final int MIN_FORMAT_VERSION_FOR_DELETE = 2; @@ -246,7 +247,7 @@ public static Table getShallowWrappedIcebergTable(Schema schema, PartitionSpec s return new PrestoIcebergTableForMetricsConfig(schema, spec, properties, 
sortOrder); } - public static Table getHiveIcebergTable(ExtendedHiveMetastore metastore, HdfsEnvironment hdfsEnvironment, IcebergHiveTableOperationsConfig config, ManifestFileCache manifestFileCache, ConnectorSession session, SchemaTableName table) + public static Table getHiveIcebergTable(ExtendedHiveMetastore metastore, HdfsEnvironment hdfsEnvironment, IcebergHiveTableOperationsConfig config, ManifestFileCache manifestFileCache, ConnectorSession session, IcebergCatalogName catalogName, SchemaTableName table) { HdfsContext hdfsContext = new HdfsContext(session, table.getSchemaName(), table.getTableName()); TableOperations operations = new HiveTableOperations( @@ -258,7 +259,7 @@ public static Table getHiveIcebergTable(ExtendedHiveMetastore metastore, HdfsEnv manifestFileCache, table.getSchemaName(), table.getTableName()); - return new BaseTable(operations, quotedTableName(table)); + return new BaseTable(operations, fullTableName(catalogName.getCatalogName(), TableIdentifier.of(table.getSchemaName(), table.getTableName()))); } public static Table getNativeIcebergTable(IcebergNativeCatalogFactory catalogFactory, ConnectorSession session, SchemaTableName table) @@ -272,7 +273,16 @@ public static View getNativeIcebergView(IcebergNativeCatalogFactory catalogFacto if (!(catalog instanceof ViewCatalog)) { throw new PrestoException(NOT_SUPPORTED, "This connector does not support get views"); } - return ((ViewCatalog) catalog).loadView(toIcebergTableIdentifier(table, catalogFactory.isNestedNamespaceEnabled())); + return ((ViewCatalog) catalog).loadView(toIcebergTableIdentifier(getBaseSchemaTableName(table), catalogFactory.isNestedNamespaceEnabled())); + } + + /** + * Removes Iceberg-specific suffixes from the table name + */ + private static SchemaTableName getBaseSchemaTableName(SchemaTableName table) + { + IcebergTableName icebergTableName = IcebergTableName.from(table.getTableName()); + return new SchemaTableName(table.getSchemaName(), icebergTableName.getTableName()); 
} public static List getPartitionKeyColumnHandles(IcebergTableHandle tableHandle, Table table, TypeManager typeManager) @@ -361,6 +371,25 @@ public static List getColumns(Stream fields, Schem .collect(toImmutableList()); } + public static List getColumnsForWrite(Schema schema, PartitionSpec partitionSpec, TypeManager typeManager) + { + return getColumnsForWrite(schema.columns().stream().map(NestedField::fieldId), schema, partitionSpec, typeManager); + } + + private static List getColumnsForWrite(Stream fields, Schema schema, PartitionSpec partitionSpec, TypeManager typeManager) + { + Set partitionSourceIds = partitionSpec.fields().stream() + .map(PartitionField::sourceId) + .collect(toImmutableSet()); + + return fields + .map(schema::findField) + .map(column -> partitionSourceIds.contains(column.fieldId()) ? + IcebergColumnHandle.create(column, typeManager, PARTITION_KEY) : + IcebergColumnHandle.create(column, typeManager, REGULAR)) + .collect(toImmutableList()); + } + public static Map getIdentityPartitions(PartitionSpec partitionSpec) { // TODO: expose transform information in Iceberg library @@ -388,12 +417,22 @@ public static List toHiveColumns(List columns) return columns.stream() .map(column -> new Column( column.name(), - HiveType.toHiveType(HiveSchemaUtil.convert(column.type())), + icebergTypeToHiveType(column.type()), Optional.empty(), Optional.empty())) .collect(toImmutableList()); } + private static HiveType icebergTypeToHiveType(org.apache.iceberg.types.Type icebergType) + { + // Special handling for TIME type: use bigint instead of 'string' + if (icebergType.typeId() == org.apache.iceberg.types.Type.TypeID.TIME) { + return HiveType.HIVE_LONG; + } + + return HiveType.toHiveType(HiveSchemaUtil.convert(icebergType)); + } + public static FileFormat getFileFormat(Table table) { return FileFormat.valueOf(table.properties() @@ -411,23 +450,13 @@ public static Optional getViewComment(View view) return Optional.ofNullable(view.properties().get(TABLE_COMMENT)); 
} - private static String quotedTableName(SchemaTableName name) - { - return quotedName(name.getSchemaName()) + "." + quotedName(name.getTableName()); - } - - private static String quotedName(String name) - { - if (SIMPLE_NAME.matcher(name).matches()) { - return name; - } - return '"' + name.replace("\"", "\"\"") + '"'; - } - - public static TableScan getTableScan(TupleDomain predicates, Optional snapshotId, Table icebergTable) + public static TableScan getTableScan(TupleDomain predicates, Optional snapshotId, Table icebergTable, RuntimeStats runtimeStats) { Expression expression = ExpressionConverter.toIcebergExpression(predicates); - TableScan tableScan = icebergTable.newScan().filter(expression); + TableScan tableScan = icebergTable + .newScan() + .metricsReporter(new RuntimeStatsMetricsReporter(runtimeStats)) + .filter(expression); return snapshotId .map(id -> isSnapshot(icebergTable, id) ? tableScan.useSnapshot(id) : tableScan.asOfTime(id)) .orElse(tableScan); @@ -448,9 +477,11 @@ public static LocationProvider getLocationProvider(SchemaTableName schemaTableNa return locationsFor(tableLocation, storageProperties); } - public static TableScan buildTableScan(Table icebergTable, MetadataTableType metadataTableType) + public static TableScan buildTableScan(Table icebergTable, MetadataTableType metadataTableType, RuntimeStats runtimeStats) { - return createMetadataTableInstance(icebergTable, metadataTableType).newScan(); + return createMetadataTableInstance(icebergTable, metadataTableType) + .newScan() + .metricsReporter(new RuntimeStatsMetricsReporter(runtimeStats)); } public static Map columnNameToPositionInSchema(Schema schema) @@ -558,7 +589,7 @@ private static void verifyPartitionTypeSupported(String partitionName, Type type } } - private static NullableValue parsePartitionValue( + protected static NullableValue parsePartitionValue( FileFormat fileFormat, String partitionStringValue, Type prestoType, @@ -602,100 +633,22 @@ else if (constraint.getKey() == 
DATA_SEQUENCE_NUMBER_COLUMN_HANDLE) { return matches; } - public static List getPartitions( + public static PartitionSet getPartitions( TypeManager typeManager, ConnectorTableHandle tableHandle, Table icebergTable, Constraint constraint, - List partitionColumns) + List partitionColumns, + RuntimeStats runtimeStats) { - IcebergTableName name = ((IcebergTableHandle) tableHandle).getIcebergTableName(); - FileFormat fileFormat = getFileFormat(icebergTable); - // Empty iceberg table would cause `snapshotId` not present - Optional snapshotId = resolveSnapshotIdByName(icebergTable, name); - if (!snapshotId.isPresent()) { - return ImmutableList.of(); - } - - TableScan tableScan = icebergTable.newScan() - .filter(toIcebergExpression(getNonMetadataColumnConstraints(constraint - .getSummary() - .simplify()))) - .useSnapshot(snapshotId.get()); - - Set partitions = new HashSet<>(); - - try (CloseableIterable fileScanTasks = tableScan.planFiles()) { - for (FileScanTask fileScanTask : fileScanTasks) { - // If exists delete files, skip the metadata optimization based on partition values as they might become incorrect - if (!fileScanTask.deletes().isEmpty()) { - return ImmutableList.of(new HivePartition(((IcebergTableHandle) tableHandle).getSchemaTableName())); - } - StructLike partition = fileScanTask.file().partition(); - PartitionSpec spec = fileScanTask.spec(); - Map fieldToIndex = getIdentityPartitions(spec); - ImmutableMap.Builder builder = ImmutableMap.builder(); - - fieldToIndex.forEach((field, index) -> { - int id = field.sourceId(); - org.apache.iceberg.types.Type type = spec.schema().findType(id); - Class javaClass = type.typeId().javaClass(); - Object value = partition.get(index, javaClass); - String partitionStringValue; - - if (value == null) { - partitionStringValue = null; - } - else { - if (type.typeId() == FIXED || type.typeId() == BINARY) { - partitionStringValue = Base64.getEncoder().encodeToString(((ByteBuffer) value).array()); - } - else { - 
partitionStringValue = value.toString(); - } - } - - NullableValue partitionValue = parsePartitionValue(fileFormat, partitionStringValue, toPrestoType(type, typeManager), partition.toString()); - Optional column = partitionColumns.stream() - .filter(icebergColumnHandle -> Objects.equals(icebergColumnHandle.getId(), field.sourceId())) - .findAny(); - - if (column.isPresent()) { - builder.put(column.get(), partitionValue); - } - }); - - Map values = builder.build(); - HivePartition newPartition = new HivePartition( - ((IcebergTableHandle) tableHandle).getSchemaTableName(), - new PartitionNameWithVersion(partition.toString(), Optional.empty()), - values); - - boolean isIncludePartition = true; - Map domains = constraint.getSummary().getDomains().get(); - for (IcebergColumnHandle column : partitionColumns) { - NullableValue value = newPartition.getKeys().get(column); - Domain allowedDomain = domains.get(column); - if (allowedDomain != null && !allowedDomain.includesNullableValue(value.getValue())) { - isIncludePartition = false; - break; - } - } - - if (constraint.predicate().isPresent() && !constraint.predicate().get().test(newPartition.getKeys())) { - isIncludePartition = false; - } - - if (isIncludePartition) { - partitions.add(newPartition); - } - } - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - - return new ArrayList<>(partitions); + return new PartitionSet( + new IcebergPartitionLoader( + typeManager, + tableHandle, + icebergTable, + constraint, + partitionColumns, + runtimeStats)); } public static Optional tryGetSchema(Table table) @@ -892,10 +845,16 @@ public static CloseableIterable getDeleteFiles(Table table, long snapshot, TupleDomain filter, Optional> requestedPartitionSpec, - Optional> requestedSchema) + Optional> requestedSchema, + RuntimeStats runtimeStats) { Expression filterExpression = toIcebergExpression(filter); - CloseableIterable fileTasks = table.newScan().useSnapshot(snapshot).filter(filterExpression).planFiles(); + 
CloseableIterable fileTasks = table + .newScan() + .metricsReporter(new RuntimeStatsMetricsReporter(runtimeStats)) + .useSnapshot(snapshot) + .filter(filterExpression) + .planFiles(); return new CloseableIterable() { @@ -1162,12 +1121,19 @@ public static Map populateTableProperties(IcebergAbstractMetadat Integer commitRetries = tableProperties.getCommitRetries(session, tableMetadata.getProperties()); propertiesBuilder.put(DEFAULT_FILE_FORMAT, fileFormat.toString()); propertiesBuilder.put(COMMIT_NUM_RETRIES, String.valueOf(commitRetries)); + HiveCompressionCodec compressionCodec = getCompressionCodec(session); switch (fileFormat) { case PARQUET: - propertiesBuilder.put(PARQUET_COMPRESSION, getCompressionCodec(session).getParquetCompressionCodec().get().toString()); + if (!compressionCodec.isSupportedStorageFormat(HiveStorageFormat.PARQUET)) { + throw new PrestoException(NOT_SUPPORTED, format("Compression codec %s is not supported for Parquet format", compressionCodec)); + } + propertiesBuilder.put(PARQUET_COMPRESSION, compressionCodec.getParquetCompressionCodec().name()); break; case ORC: - propertiesBuilder.put(ORC_COMPRESSION, getCompressionCodec(session).getOrcCompressionKind().name()); + if (!compressionCodec.isSupportedStorageFormat(HiveStorageFormat.ORC)) { + throw new PrestoException(NOT_SUPPORTED, format("Compression codec %s is not supported for ORC format", compressionCodec)); + } + propertiesBuilder.put(ORC_COMPRESSION, compressionCodec.getOrcCompressionKind().name()); break; } if (tableMetadata.getComment().isPresent()) { @@ -1200,6 +1166,8 @@ public static Map populateTableProperties(IcebergAbstractMetadat propertiesBuilder.put(SPLIT_SIZE, String.valueOf(IcebergTableProperties.getTargetSplitSize(tableMetadata.getProperties()))); + isHiveLocksEnabled(tableMetadata.getProperties()).ifPresent(value -> propertiesBuilder.put(HIVE_LOCK_ENABLED, value)); + return propertiesBuilder.build(); } @@ -1329,4 +1297,30 @@ public static DataSize 
getTargetSplitSize(ConnectorSession session, Scan storageProperties; private final List sortOrder; + private final Optional materializedViewName; public IcebergWritableTableHandle( String schemaName, @@ -46,6 +49,22 @@ public IcebergWritableTableHandle( HiveCompressionCodec compressionCodec, Map storageProperties, List sortOrder) + { + this(schemaName, tableName, schema, partitionSpec, inputColumns, outputPath, fileFormat, compressionCodec, storageProperties, sortOrder, Optional.empty()); + } + + public IcebergWritableTableHandle( + String schemaName, + IcebergTableName tableName, + PrestoIcebergSchema schema, + PrestoIcebergPartitionSpec partitionSpec, + List inputColumns, + String outputPath, + FileFormat fileFormat, + HiveCompressionCodec compressionCodec, + Map storageProperties, + List sortOrder, + Optional materializedViewName) { this.schemaName = requireNonNull(schemaName, "schemaName is null"); this.tableName = requireNonNull(tableName, "tableName is null"); @@ -57,6 +76,7 @@ public IcebergWritableTableHandle( this.compressionCodec = requireNonNull(compressionCodec, "compressionCodec is null"); this.storageProperties = requireNonNull(storageProperties, "storageProperties is null"); this.sortOrder = ImmutableList.copyOf(requireNonNull(sortOrder, "sortOrder is null")); + this.materializedViewName = requireNonNull(materializedViewName, "materializedViewName is null"); } @JsonProperty @@ -124,4 +144,10 @@ public List getSortOrder() { return sortOrder; } + + @JsonProperty + public Optional getMaterializedViewName() + { + return materializedViewName; + } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java index fe421aa78fb7b..325e529c1b01b 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java +++ 
b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/InternalIcebergConnectorFactory.java @@ -28,12 +28,15 @@ import com.facebook.presto.hive.gcs.HiveGcsModule; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.s3.HiveS3Module; -import com.facebook.presto.plugin.base.security.AllowAllAccessControl; +import com.facebook.presto.hive.security.SystemTableAwareAccessControl; +import com.facebook.presto.iceberg.security.IcebergSecurityModule; +import com.facebook.presto.spi.ConnectorSystemConfig; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorAccessControl; import com.facebook.presto.spi.connector.ConnectorContext; import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; @@ -47,13 +50,14 @@ import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; -import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableSet; import com.google.inject.Injector; import com.google.inject.Key; -import com.google.inject.util.Types; +import com.google.inject.TypeLiteral; import org.weakref.jmx.guice.MBeanModule; import javax.management.MBeanServer; @@ -86,6 +90,7 @@ public static Connector createConnector( new HiveS3Module(catalogName), new HiveGcsModule(), new HiveAuthenticationModule(), + 
new IcebergSecurityModule(), new CachingModule(), new HiveCommonModule(), binder -> { @@ -93,12 +98,14 @@ public static Connector createConnector( binder.bind(NodeVersion.class).toInstance(new NodeVersion(context.getNodeManager().getCurrentNode().getVersion())); binder.bind(NodeManager.class).toInstance(context.getNodeManager()); binder.bind(TypeManager.class).toInstance(context.getTypeManager()); + binder.bind(ProcedureRegistry.class).toInstance(context.getProcedureRegistry()); binder.bind(PageIndexerFactory.class).toInstance(context.getPageIndexerFactory()); binder.bind(PageSorter.class).toInstance(context.getPageSorter()); binder.bind(StandardFunctionResolution.class).toInstance(context.getStandardFunctionResolution()); binder.bind(FunctionMetadataManager.class).toInstance(context.getFunctionMetadataManager()); binder.bind(RowExpressionService.class).toInstance(context.getRowExpressionService()); binder.bind(FilterStatsCalculatorService.class).toInstance(context.getFilterStatsCalculatorService()); + binder.bind(ConnectorSystemConfig.class).toInstance(context.getConnectorSystemConfig()); }); Injector injector = app @@ -116,11 +123,14 @@ public static Connector createConnector( IcebergSessionProperties icebergSessionProperties = injector.getInstance(IcebergSessionProperties.class); HiveCommonSessionProperties hiveCommonSessionProperties = injector.getInstance(HiveCommonSessionProperties.class); IcebergTableProperties icebergTableProperties = injector.getInstance(IcebergTableProperties.class); - Set procedures = injector.getInstance((Key>) Key.get(Types.setOf(Procedure.class))); + IcebergMaterializedViewProperties icebergMaterializedViewProperties = injector.getInstance(IcebergMaterializedViewProperties.class); + Set> procedures = + injector.getInstance(Key.get(new TypeLiteral>>() {})); ConnectorPlanOptimizerProvider planOptimizerProvider = injector.getInstance(ConnectorPlanOptimizerProvider.class); List> allSessionProperties = new 
ArrayList<>(icebergSessionProperties.getSessionProperties()); allSessionProperties.addAll(hiveCommonSessionProperties.getSessionProperties()); + ConnectorAccessControl accessControl = new SystemTableAwareAccessControl(injector.getInstance(ConnectorAccessControl.class)); return new IcebergConnector( lifeCycleManager, @@ -134,8 +144,9 @@ public static Connector createConnector( allSessionProperties, SchemaProperties.SCHEMA_PROPERTIES, icebergTableProperties.getTableProperties(), + icebergMaterializedViewProperties.getMaterializedViewProperties(), icebergTableProperties.getColumnProperties(), - new AllowAllAccessControl(), + accessControl, procedures, planOptimizerProvider); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/MetadataLogTable.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/MetadataLogTable.java new file mode 100644 index 0000000000000..96276ed0c3f00 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/MetadataLogTable.java @@ -0,0 +1,145 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.InMemoryRecordSet; +import com.facebook.presto.spi.RecordCursor; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SystemTable; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.common.collect.ImmutableList; +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.TableMetadata.MetadataLogEntry; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.stream.Collectors; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DateTimeEncoding.packDateTimeWithZone; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static java.util.Objects.requireNonNull; +import static org.apache.iceberg.util.SnapshotUtil.snapshotIdAsOfTime; + +public class MetadataLogTable + implements SystemTable +{ + private final ConnectorTableMetadata tableMetadata; + private final Table icebergTable; + + private static final List COLUMN_DEFINITIONS = ImmutableList.builder() + .add(ColumnMetadata.builder().setName("timestamp").setType(TIMESTAMP_WITH_TIME_ZONE).build()) + .add(ColumnMetadata.builder().setName("file").setType(VARCHAR).build()) + .add(ColumnMetadata.builder().setName("latest_snapshot_id").setType(BIGINT).build()) + 
.add(ColumnMetadata.builder().setName("latest_schema_id").setType(INTEGER).build()) + .add(ColumnMetadata.builder().setName("latest_sequence_number").setType(BIGINT).build()) + .build(); + + private static final List COLUMN_TYPES = COLUMN_DEFINITIONS.stream() + .map(ColumnMetadata::getType) + .collect(Collectors.toList()); + + public MetadataLogTable(SchemaTableName tableName, Table icebergTable) + { + tableMetadata = new ConnectorTableMetadata(requireNonNull(tableName, "tableName is null"), COLUMN_DEFINITIONS); + this.icebergTable = requireNonNull(icebergTable, "icebergTable is null"); + } + + @Override + public Distribution getDistribution() + { + return Distribution.SINGLE_COORDINATOR; + } + + @Override + public ConnectorTableMetadata getTableMetadata() + { + return tableMetadata; + } + + @Override + public RecordCursor cursor(ConnectorTransactionHandle transactionHandle, ConnectorSession session, TupleDomain constraint) + { + Iterable> rowIterable = () -> new Iterator>() + { + private final Iterator metadataLogEntriesIterator = ((BaseTable) icebergTable).operations().current().previousFiles().iterator(); + private boolean addedLatestEntry; + + @Override + public boolean hasNext() + { + return metadataLogEntriesIterator.hasNext() || !addedLatestEntry; + } + + @Override + public List next() + { + if (metadataLogEntriesIterator.hasNext()) { + return processMetadataLogEntries(session, metadataLogEntriesIterator.next()); + } + if (!addedLatestEntry) { + addedLatestEntry = true; + TableMetadata currentMetadata = ((BaseTable) icebergTable).operations().current(); + return buildLatestMetadataRow(session, currentMetadata); + } + throw new NoSuchElementException(); + } + }; + return new InMemoryRecordSet(COLUMN_TYPES, rowIterable).cursor(); + } + + private List processMetadataLogEntries(ConnectorSession session, MetadataLogEntry metadataLogEntry) + { + Long snapshotId = null; + Snapshot snapshot = null; + try { + snapshotId = snapshotIdAsOfTime(icebergTable, 
metadataLogEntry.timestampMillis()); + snapshot = icebergTable.snapshot(snapshotId); + } + catch (IllegalArgumentException ignored) { + // Implies this metadata file was created during table creation + } + return addRow(session, metadataLogEntry.timestampMillis(), metadataLogEntry.file(), snapshotId, snapshot); + } + + private List buildLatestMetadataRow(ConnectorSession session, TableMetadata metadata) + { + Snapshot latestSnapshot = icebergTable.currentSnapshot(); + Long latestSnapshotId = (latestSnapshot != null) ? latestSnapshot.snapshotId() : null; + + return addRow(session, metadata.lastUpdatedMillis(), metadata.metadataFileLocation(), latestSnapshotId, latestSnapshot); + } + + private List addRow(ConnectorSession session, long timestampMillis, String fileLocation, Long snapshotId, Snapshot snapshot) + { + return Arrays.asList( + packDateTimeWithZone(timestampMillis, session.getSqlFunctionProperties().getTimeZoneKey()), + fileLocation, + snapshotId, + snapshot != null ? snapshot.schemaId() : null, + snapshot != null ? 
snapshot.sequenceNumber() : null); + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionData.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionData.java index 3a2727ffde858..015972ac949a5 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionData.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionData.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.io.StringWriter; import java.io.UncheckedIOException; +import java.math.BigDecimal; import java.nio.ByteBuffer; import java.util.Arrays; @@ -184,7 +185,18 @@ public static Object getValue(JsonNode partitionValue, Type type) throw new UncheckedIOException("Failed during JSON conversion of " + partitionValue, e); } case DECIMAL: - return partitionValue.decimalValue().setScale(((DecimalType) type).scale()); + if (partitionValue.isLong()) { + return BigDecimal.valueOf(partitionValue.asLong(), ((DecimalType) type).scale()); + } + else if (partitionValue.isInt()) { + return BigDecimal.valueOf(partitionValue.asInt(), ((DecimalType) type).scale()); + } + else if (partitionValue.isBigInteger()) { + return new BigDecimal(partitionValue.bigIntegerValue(), ((DecimalType) type).scale()); + } + else { + return partitionValue.decimalValue().setScale(((DecimalType) type).scale()); + } } throw new UnsupportedOperationException("Type not supported as partition column: " + type); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionFields.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionFields.java index eb6a2463a50ff..d29275e88cbfe 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionFields.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionFields.java @@ -13,15 +13,16 @@ */ package com.facebook.presto.iceberg; +import com.google.common.annotations.VisibleForTesting; +import jakarta.annotation.Nullable; import 
org.apache.iceberg.PartitionField; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; import org.apache.iceberg.expressions.Term; -import javax.annotation.Nullable; - import java.util.List; import java.util.Optional; +import java.util.OptionalInt; import java.util.function.Consumer; import java.util.regex.MatchResult; import java.util.regex.Matcher; @@ -78,12 +79,25 @@ public static PartitionSpec parsePartitionFields(Schema schema, List fie .orElseGet(() -> PartitionSpec.builderFor(schema)); for (String field : fields) { - parsePartitionField(builder, field); + buildPartitionField(builder, field); + } + return builder.build(); + } + + public static PartitionSpec parseIcebergPartitionFields(Schema schema, List fields, @Nullable Integer specId) + { + PartitionSpec.Builder builder = Optional.ofNullable(specId) + .map(id -> PartitionSpec.builderFor(schema).withSpecId(id)) + .orElseGet(() -> PartitionSpec.builderFor(schema)); + + for (IcebergPartitionField field : fields) { + buildPartitionSpec(builder, field); } return builder.build(); } - public static void parsePartitionField(PartitionSpec.Builder builder, String field) + @VisibleForTesting + static void buildPartitionField(PartitionSpec.Builder builder, String field) { @SuppressWarnings("PointlessBooleanExpression") boolean matched = false || @@ -99,6 +113,35 @@ public static void parsePartitionField(PartitionSpec.Builder builder, String fie } } + private static void buildPartitionSpec(PartitionSpec.Builder builder, IcebergPartitionField partitionField) + { + String field = partitionField.getName(); + PartitionTransformType type = partitionField.getTransform(); + OptionalInt parameter = partitionField.getParameter(); + switch (type) { + case IDENTITY: + builder.identity(field); + break; + case YEAR: + builder.year(field); + break; + case MONTH: + builder.month(field); + break; + case DAY: + builder.day(field); + break; + case HOUR: + builder.hour(field); + break; + case BUCKET: + 
builder.bucket(field, parameter.getAsInt()); + break; + case TRUNCATE: + builder.truncate(field, parameter.getAsInt()); + } + } + private static boolean tryMatch(CharSequence value, Pattern pattern, Consumer match) { Matcher matcher = pattern.matcher(value); @@ -116,6 +159,41 @@ public static List toPartitionFields(PartitionSpec spec) .collect(toImmutableList()); } + private static String toPartitionField(PartitionSpec spec, PartitionField field) + { + String name = spec.schema().findColumnName(field.sourceId()); + String transform = field.transform().toString(); + + switch (transform) { + case "identity": + return name; + case "year": + case "month": + case "day": + case "hour": + return format("%s(%s)", transform, name); + } + + Matcher matcher = ICEBERG_BUCKET_PATTERN.matcher(transform); + if (matcher.matches()) { + return format("bucket(%s, %s)", name, matcher.group(1)); + } + + matcher = ICEBERG_TRUNCATE_PATTERN.matcher(transform); + if (matcher.matches()) { + return format("truncate(%s, %s)", name, matcher.group(1)); + } + + throw new UnsupportedOperationException("Unsupported partition transform: " + field); + } + + public static List toIcebergPartitionFields(PartitionSpec spec) + { + return spec.fields().stream() + .map(field -> toIcebergPartitionField(spec, field)) + .collect(toImmutableList()); + } + // Keep consistency with PartitionSpec.Builder protected static String getPartitionColumnName(String columnName, String transform) { @@ -134,11 +212,21 @@ protected static String getPartitionColumnName(String columnName, String transfo return columnName + "_bucket"; } + matcher = ICEBERG_BUCKET_PATTERN.matcher(transform); + if (matcher.matches()) { + return columnName + "_bucket"; + } + matcher = COLUMN_TRUNCATE_PATTERN.matcher(transform); if (matcher.matches()) { return columnName + "_trunc"; } + matcher = ICEBERG_TRUNCATE_PATTERN.matcher(transform); + if (matcher.matches()) { + return columnName + "_trunc"; + } + throw new 
UnsupportedOperationException("Unknown partition transform: " + transform); } @@ -170,32 +258,23 @@ protected static Term getTransformTerm(String columnName, String transform) throw new UnsupportedOperationException("Unknown partition transform: " + transform); } - private static String toPartitionField(PartitionSpec spec, PartitionField field) + private static IcebergPartitionField toIcebergPartitionField(PartitionSpec spec, PartitionField field) { String name = spec.schema().findColumnName(field.sourceId()); String transform = field.transform().toString(); - - switch (transform) { - case "identity": - return name; - case "year": - case "month": - case "day": - case "hour": - return format("%s(%s)", transform, name); - } - + IcebergPartitionField.Builder builder = IcebergPartitionField.builder(); + builder.setTransform(PartitionTransformType.fromStringOrFail(transform)).setFieldId(field.fieldId()).setSourceId(field.sourceId()).setName(name); Matcher matcher = ICEBERG_BUCKET_PATTERN.matcher(transform); if (matcher.matches()) { - return format("bucket(%s, %s)", name, matcher.group(1)); + builder.setParameter(OptionalInt.of(Integer.parseInt(matcher.group(1)))); + return builder.build(); } - matcher = ICEBERG_TRUNCATE_PATTERN.matcher(transform); if (matcher.matches()) { - return format("truncate(%s, %s)", name, matcher.group(1)); + builder.setParameter(OptionalInt.of(Integer.parseInt(matcher.group(1)))); + return builder.build(); } - - throw new UnsupportedOperationException("Unsupported partition transform: " + field); + return builder.build(); } public static String quotedName(String name) diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionSpecConverter.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionSpecConverter.java index 921168870dc9e..cc191ca6f1be0 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionSpecConverter.java +++ 
b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionSpecConverter.java @@ -16,8 +16,8 @@ import com.facebook.presto.common.type.TypeManager; import org.apache.iceberg.PartitionSpec; -import static com.facebook.presto.iceberg.PartitionFields.parsePartitionFields; -import static com.facebook.presto.iceberg.PartitionFields.toPartitionFields; +import static com.facebook.presto.iceberg.PartitionFields.parseIcebergPartitionFields; +import static com.facebook.presto.iceberg.PartitionFields.toIcebergPartitionFields; import static com.facebook.presto.iceberg.SchemaConverter.toIcebergSchema; import static com.facebook.presto.iceberg.SchemaConverter.toPrestoSchema; @@ -30,12 +30,12 @@ public static PrestoIcebergPartitionSpec toPrestoPartitionSpec(PartitionSpec spe return new PrestoIcebergPartitionSpec( spec.specId(), toPrestoSchema(spec.schema(), typeManager), - toPartitionFields(spec)); + toIcebergPartitionFields(spec)); } public static PartitionSpec toIcebergPartitionSpec(PrestoIcebergPartitionSpec spec) { - return parsePartitionFields( + return parseIcebergPartitionFields( toIcebergSchema(spec.getSchema()), spec.getFields(), spec.getSpecId()); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTable.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTable.java index 77ec7d4ba1443..08a5887bed551 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTable.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTable.java @@ -162,6 +162,7 @@ public RecordCursor cursor(ConnectorTransactionHandle transactionHandle, Connect return new InMemoryRecordSet(resultTypes, ImmutableList.of()).cursor(); } TableScan tableScan = icebergTable.newScan() + .metricsReporter(new RuntimeStatsMetricsReporter(session.getRuntimeStats())) .useSnapshot(snapshotId.get()) .includeColumnStats(); return buildRecordCursor(getPartitions(tableScan), icebergTable.spec().fields()); diff --git 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTransformType.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTransformType.java new file mode 100644 index 0000000000000..819830d282fff --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTransformType.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +public enum PartitionTransformType +{ + IDENTITY("identity", 0), + HOUR("hour", 1), + DAY("day", 2), + MONTH("month", 3), + YEAR("year", 4), + BUCKET("bucket", 5), + TRUNCATE("truncate", 6); + + private static final Map TRANSFORM_MAP = new HashMap<>(); + private static final Map CODE_MAP = new HashMap<>(); + + static { + for (PartitionTransformType type : values()) { + TRANSFORM_MAP.put(type.transform, type); + CODE_MAP.put(type.code, type); + } + } + + private final String transform; + private final int code; + + PartitionTransformType(String transform, int code) + { + this.transform = transform; + this.code = code; + } + + public String getTransform() + { + return transform; + } + + public int getCode() + { + return code; + } + + public static Optional fromString(String transform) + { + if (transform == null) { + return Optional.empty(); + } + + PartitionTransformType type = TRANSFORM_MAP.get(transform); + if (type != null) { + return 
Optional.of(type); + } + + // Handle bucket and truncate transforms with parameters + if (transform.startsWith(BUCKET.transform + "[")) { + return Optional.of(BUCKET); + } + if (transform.startsWith(TRUNCATE.transform + "[")) { + return Optional.of(TRUNCATE); + } + + return Optional.empty(); + } + + public static PartitionTransformType fromCode(int code) + { + PartitionTransformType type = CODE_MAP.get(code); + if (type == null) { + throw new IllegalArgumentException("Unknown transform code: " + code); + } + return type; + } + + public static PartitionTransformType fromStringOrFail(String transform) + { + return fromString(transform) + .orElseThrow(() -> new IllegalArgumentException("Unsupported transform type: " + transform)); + } + + @Override + public String toString() + { + return transform; + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTransforms.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTransforms.java index e24c0464fad35..1eba3d13e4389 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTransforms.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionTransforms.java @@ -21,11 +21,10 @@ import io.airlift.slice.Murmur3Hash32; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import jakarta.annotation.Nullable; import org.apache.iceberg.PartitionField; import org.joda.time.DateTimeField; -import javax.annotation.Nullable; - import java.math.BigDecimal; import java.math.BigInteger; import java.util.function.Function; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PrestoIcebergPartitionSpec.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PrestoIcebergPartitionSpec.java index 119a819868f2f..38039b2643568 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PrestoIcebergPartitionSpec.java +++ 
b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PrestoIcebergPartitionSpec.java @@ -26,13 +26,13 @@ public class PrestoIcebergPartitionSpec { private final int specId; private final PrestoIcebergSchema schema; - private final List fields; + private final List fields; @JsonCreator public PrestoIcebergPartitionSpec( @JsonProperty("specId") int specId, @JsonProperty("schema") PrestoIcebergSchema schema, - @JsonProperty("fields") List fields) + @JsonProperty("fields") List fields) { this.specId = specId; this.schema = requireNonNull(schema, "schema is null"); @@ -52,7 +52,7 @@ public PrestoIcebergSchema getSchema() } @JsonProperty - public List getFields() + public List getFields() { return fields; } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PropertiesTable.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PropertiesTable.java index 3381eeef21445..2ba009f7ec017 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PropertiesTable.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PropertiesTable.java @@ -37,8 +37,9 @@ public class PropertiesTable { private final ConnectorTableMetadata tableMetadata; private final Table icebergTable; + private final IcebergTableProperties tableProperties; - public PropertiesTable(SchemaTableName tableName, Table icebergTable) + public PropertiesTable(SchemaTableName tableName, Table icebergTable, IcebergTableProperties tableProperties) { this.icebergTable = requireNonNull(icebergTable, "icebergTable is null"); @@ -46,7 +47,10 @@ public PropertiesTable(SchemaTableName tableName, Table icebergTable) ImmutableList.builder() .add(ColumnMetadata.builder().setName("key").setType(VARCHAR).build()) .add(ColumnMetadata.builder().setName("value").setType(VARCHAR).build()) + .add(ColumnMetadata.builder().setName("is_supported_by_presto").setType(VARCHAR).build()) .build()); + + this.tableProperties = requireNonNull(tableProperties, "tableProperties is null"); } 
@Override @@ -64,17 +68,20 @@ public ConnectorTableMetadata getTableMetadata() @Override public ConnectorPageSource pageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, TupleDomain constraint) { - return new FixedPageSource(buildPages(tableMetadata, icebergTable)); + return new FixedPageSource(buildPages(tableMetadata, icebergTable, tableProperties)); } - private static List buildPages(ConnectorTableMetadata tableMetadata, Table icebergTable) + private static List buildPages(ConnectorTableMetadata tableMetadata, Table icebergTable, IcebergTableProperties tableProperties) { PageListBuilder pagesBuilder = PageListBuilder.forTable(tableMetadata); icebergTable.properties().forEach((key, value) -> { + boolean isSupported = tableProperties.isTablePropertySupported(key); + pagesBuilder.beginRow(); pagesBuilder.appendVarchar(key); pagesBuilder.appendVarchar(value); + pagesBuilder.appendVarchar(Boolean.toString(isSupported)); pagesBuilder.endRow(); }); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/RuntimeStatsMetricsReporter.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/RuntimeStatsMetricsReporter.java new file mode 100644 index 0000000000000..f017254753b62 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/RuntimeStatsMetricsReporter.java @@ -0,0 +1,183 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.common.RuntimeUnit; +import org.apache.iceberg.metrics.MetricsReport; +import org.apache.iceberg.metrics.MetricsReporter; +import org.apache.iceberg.metrics.ScanReport; + +/** + * A MetricsReporter implementation for reporting + * Iceberg scan metrics to Presto's RuntimeStats. + */ + +public final class RuntimeStatsMetricsReporter + implements MetricsReporter +{ + /** + * RuntimeStats variable used for storing scan metrics from Iceberg reports. + */ + private final RuntimeStats runtimeStats; + + /** + * Constructs a RuntimeStatsMetricsReporter. + * + * @param runtimeStat the RuntimeStats instance to report metrics to + */ + public RuntimeStatsMetricsReporter(final RuntimeStats runtimeStat) + { + this.runtimeStats = runtimeStat; + } + + @Override + public void report(final MetricsReport report) + { + if (!(report instanceof ScanReport)) { + return; + } + + ScanReport scanReport = (ScanReport) report; + String tableName = scanReport.tableName(); + + if (scanReport.scanMetrics().totalPlanningDuration() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "totalPlanningDuration"), + RuntimeUnit.NANO, + scanReport.scanMetrics().totalPlanningDuration() + .totalDuration().toNanos()); + } + + if (scanReport.scanMetrics().resultDataFiles() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "resultDataFiles"), + RuntimeUnit.NONE, + scanReport.scanMetrics().resultDataFiles().value()); + } + + if (scanReport.scanMetrics().resultDeleteFiles() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "resultDeleteFiles"), + RuntimeUnit.NONE, + scanReport.scanMetrics().resultDeleteFiles().value()); + } + + if (scanReport.scanMetrics().totalDataManifests() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "totalDataManifests"), + RuntimeUnit.NONE, + 
scanReport.scanMetrics().totalDataManifests().value()); + } + + if (scanReport.scanMetrics().totalDeleteManifests() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "totalDeleteManifests"), + RuntimeUnit.NONE, + scanReport.scanMetrics().totalDeleteManifests().value()); + } + + if (scanReport.scanMetrics().scannedDataManifests() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "scannedDataManifests"), + RuntimeUnit.NONE, + scanReport.scanMetrics().scannedDataManifests().value()); + } + + if (scanReport.scanMetrics().skippedDataManifests() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "skippedDataManifests"), + RuntimeUnit.NONE, + scanReport.scanMetrics().skippedDataManifests().value()); + } + + if (scanReport.scanMetrics().totalFileSizeInBytes() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "totalFileSizeInBytes"), + RuntimeUnit.BYTE, + scanReport.scanMetrics().totalFileSizeInBytes() + .value()); + } + + if (scanReport.scanMetrics().totalDeleteFileSizeInBytes() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "totalDeleteFileSizeInBytes"), + RuntimeUnit.BYTE, + scanReport.scanMetrics().totalDeleteFileSizeInBytes() + .value()); + } + + if (scanReport.scanMetrics().skippedDataFiles() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "skippedDataFiles"), + RuntimeUnit.NONE, + scanReport.scanMetrics().skippedDataFiles() + .value()); + } + + if (scanReport.scanMetrics().skippedDeleteFiles() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "skippedDeleteFiles"), + RuntimeUnit.NONE, + scanReport.scanMetrics().skippedDeleteFiles().value()); + } + + if (scanReport.scanMetrics().scannedDeleteManifests() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "scannedDeleteManifests"), + RuntimeUnit.NONE, + scanReport.scanMetrics().scannedDeleteManifests().value()); + } + + if 
(scanReport.scanMetrics().skippedDeleteManifests() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "skippedDeleteManifests"), + RuntimeUnit.NONE, + scanReport.scanMetrics().skippedDeleteManifests().value()); + } + + if (scanReport.scanMetrics().indexedDeleteFiles() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "indexedDeleteFiles"), + RuntimeUnit.NONE, + scanReport.scanMetrics().indexedDeleteFiles().value()); + } + + if (scanReport.scanMetrics().equalityDeleteFiles() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "equalityDeleteFiles"), + RuntimeUnit.NONE, + scanReport.scanMetrics().equalityDeleteFiles().value()); + } + + if (scanReport.scanMetrics().positionalDeleteFiles() != null) { + runtimeStats.addMetricValue( + tableScanString(tableName, "positionalDeleteFiles"), + RuntimeUnit.NONE, + scanReport.scanMetrics().positionalDeleteFiles().value()); + } + } + + /** + * Helper method to construct the full metric name for a table scan. + * + * @param tableName the name of the table + * @param metricName the name of the metric + * @return the composed metric name in the format: table.scan.metric + */ + private static String tableScanString(final String tableName, final String metricName) + { + return tableName + ".scan." 
+ metricName; + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/SnapshotsTable.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/SnapshotsTable.java index 2874a020ded92..a876f826506dd 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/SnapshotsTable.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/SnapshotsTable.java @@ -14,6 +14,7 @@ package com.facebook.presto.iceberg; import com.facebook.presto.common.Page; +import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.TimeZoneKey; @@ -101,7 +102,8 @@ public ConnectorPageSource pageSource(ConnectorTransactionHandle transactionHand private static List buildPages(ConnectorTableMetadata tableMetadata, ConnectorSession session, Table icebergTable) { PageListBuilder pagesBuilder = PageListBuilder.forTable(tableMetadata); - TableScan tableScan = buildTableScan(icebergTable, SNAPSHOTS); + RuntimeStats runtimeStats = session.getRuntimeStats(); + TableScan tableScan = buildTableScan(icebergTable, SNAPSHOTS, runtimeStats); TimeZoneKey timeZoneKey = session.getTimeZoneKey(); Map columnNameToPosition = columnNameToPositionInSchema(tableScan.schema()); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/SortParameters.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/SortParameters.java index 05377d8199a3e..dac3e27531d1d 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/SortParameters.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/SortParameters.java @@ -18,8 +18,7 @@ import com.facebook.presto.hive.OrcFileWriterFactory; import com.facebook.presto.hive.SortingFileWriterConfig; import com.facebook.presto.spi.PageSorter; - -import javax.inject.Inject; +import jakarta.inject.Inject; public class SortParameters { diff --git 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java index 25b1095577384..a379a0492675b 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java @@ -42,6 +42,7 @@ import com.google.common.collect.Maps; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import jakarta.annotation.Nullable; import org.apache.datasketches.memory.Memory; import org.apache.datasketches.theta.CompactSketch; import org.apache.iceberg.ContentFile; @@ -69,8 +70,6 @@ import org.apache.iceberg.types.Types; import org.apache.iceberg.util.Pair; -import javax.annotation.Nullable; - import java.io.IOException; import java.io.UncheckedIOException; import java.nio.ByteBuffer; @@ -248,26 +247,29 @@ private TableStatistics makeTableStatistics(StatisticsFileCache statisticsFileCa for (IcebergColumnHandle columnHandle : selectedColumns) { int fieldId = columnHandle.getId(); ColumnStatistics.Builder columnBuilder = tableStats.getOrDefault(fieldId, ColumnStatistics.builder()); - Long nullCount = summary.getNullCounts().get(fieldId); - if (nullCount != null) { - columnBuilder.setNullsFraction(Estimate.of(nullCount / recordCount)); - } - Object min = summary.getMinValues().get(fieldId); - Object max = summary.getMaxValues().get(fieldId); - if (min instanceof Number && max instanceof Number) { - DoubleRange range = new DoubleRange(((Number) min).doubleValue(), ((Number) max).doubleValue()); - columnBuilder.setRange(Optional.of(range)); - - // the histogram is generated by scanning the entire dataset. It is possible that - // the constraint prevents scanning portions of the table. Given that we know the - // range that the scan provides for a particular column, bound the histogram to the - // scanned range. 
- - final DoubleRange histRange = range; - columnBuilder.setHistogram(columnBuilder.getHistogram() - .map(histogram -> DisjointRangeDomainHistogram - .addConjunction(histogram, Range.range(DOUBLE, histRange.getMin(), true, histRange.getMax(), true)))); + if (summary.hasValidColumnMetrics()) { + Long nullCount = summary.getNullCounts().get(fieldId); + if (nullCount != null) { + columnBuilder.setNullsFraction(Estimate.of(nullCount / recordCount)); + } + + Object min = summary.getMinValues().get(fieldId); + Object max = summary.getMaxValues().get(fieldId); + if (min instanceof Number && max instanceof Number) { + DoubleRange range = new DoubleRange(((Number) min).doubleValue(), ((Number) max).doubleValue()); + columnBuilder.setRange(Optional.of(range)); + + // the histogram is generated by scanning the entire dataset. It is possible that + // the constraint prevents scanning portions of the table. Given that we know the + // range that the scan provides for a particular column, bound the histogram to the + // scanned range. 
+ + final DoubleRange histRange = range; + columnBuilder.setHistogram(columnBuilder.getHistogram() + .map(histogram -> DisjointRangeDomainHistogram + .addConjunction(histogram, Range.range(DOUBLE, histRange.getMin(), true, histRange.getMax(), true)))); + } } result.setColumnStatistics(columnHandle, columnBuilder.build()); } @@ -282,6 +284,7 @@ private Partition getDataTableSummary(IcebergTableHandle tableHandle, List partitionFields) { TableScan tableScan = icebergTable.newScan() + .metricsReporter(new RuntimeStatsMetricsReporter(session.getRuntimeStats())) .filter(toIcebergExpression(intersection)) .select(selectedColumns.stream().map(IcebergColumnHandle::getName).collect(Collectors.toList())) .useSnapshot(tableHandle.getIcebergTableName().getSnapshotId().get()) @@ -301,7 +304,8 @@ private Partition getEqualityDeleteTableSummary(IcebergTableHandle tableHandle, tableHandle.getIcebergTableName().getSnapshotId().get(), intersection, tableHandle.getPartitionSpecId(), - tableHandle.getEqualityFieldIds()); + tableHandle.getEqualityFieldIds(), + session.getRuntimeStats()); CloseableIterable> files = CloseableIterable.transform(deleteFiles, deleteFile -> deleteFile); return getSummaryFromFiles(files, idToTypeMapping, nonPartitionPrimitiveColumns, partitionFields); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/DeleteFilter.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/DeleteFilter.java index 42122d8103920..eb3f77bbf1f7b 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/DeleteFilter.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/DeleteFilter.java @@ -16,8 +16,11 @@ import com.facebook.presto.iceberg.IcebergColumnHandle; import java.util.List; +import java.util.Optional; public interface DeleteFilter { RowPredicate createPredicate(List columns); + + Optional getDeleteFilePath(); } diff --git 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/EqualityDeleteFilter.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/EqualityDeleteFilter.java index c86f0dae9a28d..2441ba9fcea76 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/EqualityDeleteFilter.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/EqualityDeleteFilter.java @@ -17,12 +17,14 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.iceberg.IcebergColumnHandle; import com.facebook.presto.spi.ConnectorPageSource; +import jakarta.annotation.Nullable; import org.apache.iceberg.Schema; import org.apache.iceberg.StructLike; import org.apache.iceberg.util.StructLikeSet; import org.apache.iceberg.util.StructProjection; import java.util.List; +import java.util.Optional; import static com.facebook.presto.iceberg.IcebergUtil.schemaFromHandles; import static java.util.Objects.requireNonNull; @@ -32,11 +34,14 @@ public final class EqualityDeleteFilter { private final Schema schema; private final StructLikeSet deleteSet; + @Nullable + private final String deleteFilePath; - private EqualityDeleteFilter(Schema schema, StructLikeSet deleteSet) + private EqualityDeleteFilter(Schema schema, StructLikeSet deleteSet, @Nullable String deleteFilePath) { this.schema = requireNonNull(schema, "schema is null"); this.deleteSet = requireNonNull(deleteSet, "deleteSet is null"); + this.deleteFilePath = deleteFilePath; } @Override @@ -55,7 +60,13 @@ public RowPredicate createPredicate(List columns) }; } - public static DeleteFilter readEqualityDeletes(ConnectorPageSource pageSource, List columns, Schema tableSchema) + @Override + public Optional getDeleteFilePath() + { + return Optional.ofNullable(deleteFilePath); + } + + public static DeleteFilter readEqualityDeletes(ConnectorPageSource pageSource, List columns, String deleteFilePath) { Type[] types = columns.stream() .map(IcebergColumnHandle::getType) @@ -75,6 +86,6 @@ 
public static DeleteFilter readEqualityDeletes(ConnectorPageSource pageSource, L } } - return new EqualityDeleteFilter(deleteSchema, deleteSet); + return new EqualityDeleteFilter(deleteSchema, deleteSet, deleteFilePath); } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/IcebergDeletePageSink.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/IcebergDeletePageSink.java index ac7ee1aa9bbcf..9d735f40d409a 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/IcebergDeletePageSink.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/IcebergDeletePageSink.java @@ -36,7 +36,6 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.iceberg.MetricsConfig; import org.apache.iceberg.PartitionSpec; -import org.apache.iceberg.PartitionSpecParser; import org.apache.iceberg.Schema; import org.apache.iceberg.io.LocationProvider; @@ -81,8 +80,7 @@ public class IcebergDeletePageSink private static final MetricsConfig FULL_METRICS_CONFIG = MetricsConfig.fromProperties(ImmutableMap.of(DEFAULT_WRITE_METRICS_MODE, "full")); public IcebergDeletePageSink( - Schema outputSchema, - String partitionSpecAsJson, + PartitionSpec partitionSpec, Optional partitionDataAsJson, LocationProvider locationProvider, IcebergFileWriterFactory fileWriterFactory, @@ -101,7 +99,7 @@ public IcebergDeletePageSink( this.session = requireNonNull(session, "session is null"); this.dataFile = requireNonNull(dataFile, "dataFile is null"); this.fileFormat = requireNonNull(fileFormat, "fileFormat is null"); - this.partitionSpec = PartitionSpecParser.fromJson(outputSchema, partitionSpecAsJson); + this.partitionSpec = requireNonNull(partitionSpec, "partitionSpec is null"); this.partitionData = partitionDataFromJson(partitionSpec, partitionDataAsJson); String fileName = fileFormat.addExtension(String.format("delete_file_%s", randomUUID().toString())); this.outputPath = partitionData.map(partition -> new 
Path(locationProvider.newDataLocation(partitionSpec, partition, fileName))) @@ -182,6 +180,9 @@ public IcebergPositionDeleteWriter() this.writer = createWriter(); } + /** + * @param page Only one channel. It contains the list of row positions to delete. + */ public void appendPage(Page page) { if (page.getChannelCount() == 1) { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/PositionDeleteFilter.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/PositionDeleteFilter.java index f9fca85d170b2..bbe44a5367223 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/PositionDeleteFilter.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/PositionDeleteFilter.java @@ -18,10 +18,12 @@ import com.facebook.presto.iceberg.IcebergColumnHandle; import com.facebook.presto.spi.ConnectorPageSource; import io.airlift.slice.Slice; +import jakarta.annotation.Nullable; import org.roaringbitmap.longlong.ImmutableLongBitmapDataProvider; import org.roaringbitmap.longlong.LongBitmapDataProvider; import java.util.List; +import java.util.Optional; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; @@ -32,10 +34,13 @@ public final class PositionDeleteFilter implements DeleteFilter { private final ImmutableLongBitmapDataProvider deletedRows; + @Nullable + private final String deleteFilePath; - public PositionDeleteFilter(ImmutableLongBitmapDataProvider deletedRows) + public PositionDeleteFilter(ImmutableLongBitmapDataProvider deletedRows, @Nullable String deleteFilePath) { this.deletedRows = requireNonNull(deletedRows, "deletedRows is null"); + this.deleteFilePath = deleteFilePath; } @Override @@ -48,6 +53,11 @@ public RowPredicate createPredicate(List columns) }; } + public Optional getDeleteFilePath() + { + return Optional.ofNullable(deleteFilePath); + } + private static int rowPositionChannel(List columns) { for (int i 
= 0; i < columns.size(); i++) { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/RowPredicate.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/RowPredicate.java index 450eccb3f8314..e61bc49294c26 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/RowPredicate.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/delete/RowPredicate.java @@ -14,7 +14,11 @@ package com.facebook.presto.iceberg.delete; import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static java.util.Objects.requireNonNull; public interface RowPredicate @@ -43,4 +47,40 @@ default Page filterPage(Page page) } return page.getPositions(retained, 0, retainedCount); } + + default Page markDeleted(Page page, int deletedDelegateColumnId) + { + int positionCount = page.getPositionCount(); + if (positionCount == 0) { + return page; + } + + boolean allSameValues = true; + boolean firstValue = !test(page, 0); + BlockBuilder blockBuilder = null; + for (int position = 1; position < positionCount; position++) { + boolean deleted = !test(page, position); + if (allSameValues && deleted != firstValue) { + blockBuilder = BOOLEAN.createFixedSizeBlockBuilder(positionCount); + for (int idx = 0; idx < position; idx++) { + BOOLEAN.writeBoolean(blockBuilder, firstValue); + } + BOOLEAN.writeBoolean(blockBuilder, deleted); + allSameValues = false; + } + else if (!allSameValues) { + BOOLEAN.writeBoolean(blockBuilder, deleted); + } + } + + Block block; + if (blockBuilder != null) { + block = blockBuilder.build(); + } + else { + block = RunLengthEncodedBlock.create(BOOLEAN, firstValue, positionCount); + } + + return page.replaceColumn(deletedDelegateColumnId, block); + } } diff --git 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/function/IcebergBucketFunction.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/function/IcebergBucketFunction.java new file mode 100644 index 0000000000000..429b296f79d8e --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/function/IcebergBucketFunction.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.function; + +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.LiteralParameter; +import com.facebook.presto.spi.function.LiteralParameters; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import io.airlift.slice.Slice; +import org.apache.iceberg.transforms.Transforms; +import org.apache.iceberg.types.Types; + +import java.math.BigDecimal; +import java.math.MathContext; + +import static com.facebook.presto.common.type.DateTimeEncoding.unpackMillisUtc; +import static com.facebook.presto.common.type.Decimals.decodeUnscaledValue; +import static com.facebook.presto.common.type.SqlTimestamp.MICROSECONDS_PER_MILLISECOND; + +public final class IcebergBucketFunction +{ + private IcebergBucketFunction() {} + + @ScalarFunction("bucket") + @SqlType(StandardTypes.BIGINT) + public static long bucketInteger(@SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.INTEGER) long numberOfBuckets) + { + 
return Transforms.bucket((int) numberOfBuckets) + .bind(Types.LongType.get()) + .apply(value); + } + + @ScalarFunction("bucket") + @SqlType(StandardTypes.BIGINT) + public static long bucketVarchar(@SqlType(StandardTypes.VARCHAR) Slice value, @SqlType(StandardTypes.INTEGER) long numberOfBuckets) + { + return (long) Transforms.bucket((int) numberOfBuckets) + .bind(Types.StringType.get()) + .apply(value.toStringUtf8()); + } + + @ScalarFunction("bucket") + @SqlType(StandardTypes.BIGINT) + public static long bucketVarbinary(@SqlType(StandardTypes.VARBINARY) Slice value, @SqlType(StandardTypes.INTEGER) long numberOfBuckets) + { + return (long) Transforms.bucket((int) numberOfBuckets) + .bind(Types.BinaryType.get()) + .apply(value.toByteBuffer()); + } + + @ScalarFunction("bucket") + @SqlType(StandardTypes.BIGINT) + public static long bucketDate(@SqlType(StandardTypes.DATE) long value, @SqlType(StandardTypes.INTEGER) long numberOfBuckets) + { + return Transforms.bucket((int) numberOfBuckets) + .bind(Types.DateType.get()) + .apply((int) value); + } + + @ScalarFunction("bucket") + @SqlType(StandardTypes.BIGINT) + public static long bucketTimestamp(@SqlType(StandardTypes.TIMESTAMP) long value, @SqlType(StandardTypes.INTEGER) long numberOfBuckets) + { + return Transforms.bucket((int) numberOfBuckets) + .bind(Types.TimestampType.withoutZone()) + .apply(value); + } + + @ScalarFunction("bucket") + @SqlType(StandardTypes.BIGINT) + public static long bucketTimestampWithTimeZone(@SqlType(StandardTypes.TIMESTAMP_WITH_TIME_ZONE) long value, @SqlType(StandardTypes.INTEGER) long numberOfBuckets) + { + return Transforms.bucket((int) numberOfBuckets) + .bind(Types.TimestampType.withZone()) + .apply(unpackMillisUtc(value) * MICROSECONDS_PER_MILLISECOND); + } + + @ScalarFunction("bucket") + public static final class Bucket + { + @LiteralParameters({"p", "s"}) + @SqlType(StandardTypes.BIGINT) + public static long bucketShortDecimal(@LiteralParameter("p") long numPrecision, 
@LiteralParameter("s") long numScale, @SqlType("decimal(p, s)") long value, @SqlType(StandardTypes.INTEGER) long numberOfBuckets) + { + return Transforms.bucket((int) numberOfBuckets) + .bind(Types.DecimalType.of((int) numPrecision, (int) numScale)) + .apply(BigDecimal.valueOf(value)); + } + + @LiteralParameters({"p", "s"}) + @SqlType(StandardTypes.BIGINT) + public static long bucketLongDecimal(@LiteralParameter("p") long numPrecision, @LiteralParameter("s") long numScale, @SqlType("decimal(p, s)") Slice value, @SqlType(StandardTypes.INTEGER) long numberOfBuckets) + { + return Transforms.bucket((int) numberOfBuckets) + .bind(Types.DecimalType.of((int) numPrecision, (int) numScale)) + .apply(new BigDecimal(decodeUnscaledValue(value), (int) numScale, new MathContext((int) numPrecision))); + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/function/changelog/ApplyChangelogState.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/function/changelog/ApplyChangelogState.java index 024088fe4292f..156833899a714 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/function/changelog/ApplyChangelogState.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/function/changelog/ApplyChangelogState.java @@ -18,8 +18,7 @@ import com.facebook.presto.spi.function.AccumulatorState; import com.facebook.presto.spi.function.AccumulatorStateMetadata; import com.facebook.presto.spi.function.GroupedAccumulatorState; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Optional; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/hive/IcebergFileHiveMetastore.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/hive/IcebergFileHiveMetastore.java index 1822a92b0bd19..8ba5f88cd8fbf 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/hive/IcebergFileHiveMetastore.java +++ 
b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/hive/IcebergFileHiveMetastore.java @@ -19,13 +19,12 @@ import com.facebook.presto.hive.metastore.file.FileHiveMetastore; import com.facebook.presto.hive.metastore.file.FileHiveMetastoreConfig; import com.facebook.presto.spi.PrestoException; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.FileUtil; import org.apache.hadoop.fs.Path; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.io.IOException; import java.util.Optional; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/hive/IcebergHiveFileMetastoreModule.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/hive/IcebergHiveFileMetastoreModule.java index f06a2fba1542e..162675bfa1e75 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/hive/IcebergHiveFileMetastoreModule.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/hive/IcebergHiveFileMetastoreModule.java @@ -16,6 +16,7 @@ import com.facebook.presto.hive.ForCachingHiveMetastore; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore; +import com.facebook.presto.hive.metastore.MetastoreCacheSpecProvider; import com.facebook.presto.hive.metastore.file.FileHiveMetastoreConfig; import com.google.inject.Binder; import com.google.inject.Module; @@ -40,6 +41,7 @@ public IcebergHiveFileMetastoreModule(String connectorId) public void configure(Binder binder) { configBinder(binder).bindConfig(FileHiveMetastoreConfig.class); + binder.bind(MetastoreCacheSpecProvider.class).in(Scopes.SINGLETON); binder.bind(ExtendedHiveMetastore.class).annotatedWith(ForCachingHiveMetastore.class).to(IcebergFileHiveMetastore.class).in(Scopes.SINGLETON); 
binder.bind(ExtendedHiveMetastore.class).to(InMemoryCachingHiveMetastore.class).in(Scopes.SINGLETON); newExporter(binder).export(ExtendedHiveMetastore.class) diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/nessie/IcebergNessieCatalogFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/nessie/IcebergNessieCatalogFactory.java index 90b332ce291e7..b33aa15301279 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/nessie/IcebergNessieCatalogFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/nessie/IcebergNessieCatalogFactory.java @@ -20,8 +20,7 @@ import com.facebook.presto.iceberg.IcebergNativeCatalogFactory; import com.facebook.presto.spi.ConnectorSession; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Map; import java.util.Optional; @@ -60,6 +59,7 @@ protected Map getCatalogProperties(ConnectorSession session) if (hash != null) { properties.put("ref.hash", hash); } + catalogConfig.getAuthenticationType().ifPresent(val -> properties.put("nessie.authentication.type", val.toString())); catalogConfig.getReadTimeoutMillis().ifPresent(val -> properties.put("transport.read-timeout", val.toString())); catalogConfig.getConnectTimeoutMillis().ifPresent(val -> properties.put("transport.connect-timeout", val.toString())); catalogConfig.getClientBuilderImpl().ifPresent(val -> properties.put("client-builder-impl", val)); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/nessie/IcebergNessieConfig.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/nessie/IcebergNessieConfig.java index 8e68bc5f5953b..13f8c021aa077 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/nessie/IcebergNessieConfig.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/nessie/IcebergNessieConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import 
com.facebook.airlift.configuration.ConfigDescription; - -import javax.validation.constraints.NotEmpty; +import jakarta.validation.constraints.NotEmpty; import java.util.Optional; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergEqualityDeleteAsJoin.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergEqualityDeleteAsJoin.java index 18635bb4ce806..ebef2d465de58 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergEqualityDeleteAsJoin.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergEqualityDeleteAsJoin.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.iceberg.optimizer; +import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.BigintType; @@ -83,13 +84,17 @@ import static com.facebook.presto.iceberg.FileContent.EQUALITY_DELETES; import static com.facebook.presto.iceberg.FileContent.fromIcebergFileContent; import static com.facebook.presto.iceberg.IcebergColumnHandle.DATA_SEQUENCE_NUMBER_COLUMN_HANDLE; +import static com.facebook.presto.iceberg.IcebergColumnHandle.DELETE_FILE_PATH_COLUMN_HANDLE; +import static com.facebook.presto.iceberg.IcebergColumnHandle.IS_DELETED_COLUMN_HANDLE; import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_FILESYSTEM_ERROR; import static com.facebook.presto.iceberg.IcebergMetadataColumn.DATA_SEQUENCE_NUMBER; +import static com.facebook.presto.iceberg.IcebergSessionProperties.getDeleteAsJoinRewriteMaxDeleteColumns; import static com.facebook.presto.iceberg.IcebergSessionProperties.isDeleteToJoinPushdownEnabled; import static com.facebook.presto.iceberg.IcebergUtil.getDeleteFiles; import static com.facebook.presto.iceberg.IcebergUtil.getIcebergTable; import static com.facebook.presto.iceberg.TypeConverter.toPrestoType; import static 
com.facebook.presto.spi.ConnectorPlanRewriter.rewriteWith; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; /** @@ -130,7 +135,9 @@ public class IcebergEqualityDeleteAsJoin @Override public PlanNode optimize(PlanNode maxSubplan, ConnectorSession session, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator) { - if (!isDeleteToJoinPushdownEnabled(session)) { + int maxDeleteColumns = getDeleteAsJoinRewriteMaxDeleteColumns(session); + checkArgument(maxDeleteColumns >= 0, "maxDeleteColumns must be non-negative, got %s", maxDeleteColumns); + if (!isDeleteToJoinPushdownEnabled(session) || maxDeleteColumns == 0) { return maxSubplan; } return rewriteWith(new DeleteAsJoinRewriter(functionResolution, @@ -175,6 +182,16 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) return node; } + if (node.getAssignments().containsValue(IS_DELETED_COLUMN_HANDLE) || node.getAssignments().containsValue(DELETE_FILE_PATH_COLUMN_HANDLE)) { + // Skip this optimization if metadata columns `$deleted` or `$delete_file_path` exist + return node; + } + + if (icebergTableHandle.getMaterializedViewName().isPresent()) { + // Materialized views should not have delete files + return node; + } + IcebergAbstractMetadata metadata = (IcebergAbstractMetadata) transactionManager.get(table.getTransaction()); Table icebergTable = getIcebergTable(metadata, session, icebergTableHandle.getSchemaTableName()); @@ -184,12 +201,16 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) .orElseGet(TupleDomain::all); // Collect info about each unique delete schema to join by - ImmutableMap, DeleteSetInfo> deleteSchemas = collectDeleteInformation(icebergTable, predicate, tableName.getSnapshotId().get()); + ImmutableMap, DeleteSetInfo> deleteSchemas = collectDeleteInformation(icebergTable, predicate, tableName.getSnapshotId().get(), session); if (deleteSchemas.isEmpty()) { // no equality 
deletes return node; } + if (deleteSchemas.keySet().stream().anyMatch(equalityIds -> equalityIds.size() > getDeleteAsJoinRewriteMaxDeleteColumns(session))) { + // Too many fields in the delete schema, don't rewrite + return node; + } // Add all the fields required by the join that were not added by the user's query ImmutableMap unselectedAssignments = createAssignmentsForUnselectedFields(node, deleteSchemas, icebergTable); @@ -266,21 +287,25 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext context) new SpecialFormExpression(SpecialFormExpression.Form.IS_NULL, BooleanType.BOOLEAN, new SpecialFormExpression(SpecialFormExpression.Form.COALESCE, BigintType.BIGINT, deleteVersionColumns))); + boolean hasExplicitDataSequenceNumberCol = node.getAssignments().containsValue(DATA_SEQUENCE_NUMBER_COLUMN_HANDLE); Assignments.Builder assignmentsBuilder = Assignments.builder(); filter.getOutputVariables().stream() - .filter(variableReferenceExpression -> !variableReferenceExpression.getName().startsWith(DATA_SEQUENCE_NUMBER_COLUMN_HANDLE.getName())) + .filter(variableReferenceExpression -> hasExplicitDataSequenceNumberCol || !variableReferenceExpression.getName().startsWith(DATA_SEQUENCE_NUMBER_COLUMN_HANDLE.getName())) .forEach(variableReferenceExpression -> assignmentsBuilder.put(variableReferenceExpression, variableReferenceExpression)); return new ProjectNode(Optional.empty(), idAllocator.getNextId(), filter, assignmentsBuilder.build(), ProjectNode.Locality.LOCAL); } private static ImmutableMap, DeleteSetInfo> collectDeleteInformation(Table icebergTable, TupleDomain predicate, - long snapshotId) + long snapshotId, + ConnectorSession session) + { // Delete schemas can repeat, so using a normal hashmap to dedup, will be converted to immutable at the end of the function. 
HashMap, DeleteSetInfo> deleteInformations = new HashMap<>(); + RuntimeStats runtimeStats = session.getRuntimeStats(); try (CloseableIterator files = - getDeleteFiles(icebergTable, snapshotId, predicate, Optional.empty(), Optional.empty()).iterator()) { + getDeleteFiles(icebergTable, snapshotId, predicate, Optional.empty(), Optional.empty(), runtimeStats).iterator()) { files.forEachRemaining(delete -> { if (fromIcebergFileContent(delete.content()) == EQUALITY_DELETES) { ImmutableMap.Builder partitionFieldsBuilder = new ImmutableMap.Builder<>(); @@ -334,7 +359,8 @@ private TableScanNode createDeletesTableScan(ImmutableMap assignmentsBuilder = ImmutableMap.builder() - .put(dataSequenceNumberVariableReference, DATA_SEQUENCE_NUMBER_COLUMN_HANDLE) .putAll(unselectedAssignments) .putAll(node.getAssignments()); ImmutableList.Builder outputsBuilder = ImmutableList.builder(); outputsBuilder.addAll(node.getOutputVariables()); - if (!node.getAssignments().containsKey(dataSequenceNumberVariableReference)) { + if (!node.getAssignments().containsValue(DATA_SEQUENCE_NUMBER_COLUMN_HANDLE)) { + assignmentsBuilder.put(dataSequenceNumberVariableReference, DATA_SEQUENCE_NUMBER_COLUMN_HANDLE); outputsBuilder.add(dataSequenceNumberVariableReference); } outputsBuilder.addAll(unselectedAssignments.keySet()); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergFilterPushdown.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergFilterPushdown.java index f60e149c56750..64f45cecef2fb 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergFilterPushdown.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/optimizer/IcebergFilterPushdown.java @@ -13,10 +13,11 @@ */ package com.facebook.presto.iceberg.optimizer; +import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.predicate.TupleDomain; import 
com.facebook.presto.common.type.TypeManager; -import com.facebook.presto.hive.HivePartition; +import com.facebook.presto.hive.PartitionSet; import com.facebook.presto.hive.rule.BaseSubfieldExtractionRewriter; import com.facebook.presto.iceberg.IcebergAbstractMetadata; import com.facebook.presto.iceberg.IcebergColumnHandle; @@ -155,12 +156,14 @@ protected ConnectorPushdownFilterResult getConnectorPushdownFilterResult( TupleDomain partitionColumnPredicate = TupleDomain.withColumnDomains(Maps.filterKeys( constraint.getSummary().getDomains().get(), Predicates.in(partitionColumns))); - List partitions = getPartitions( + RuntimeStats runtimeStats = session.getRuntimeStats(); + PartitionSet partitions = getPartitions( typeManager, tableHandle, icebergTable, constraint, - partitionColumns); + partitionColumns, + runtimeStats); return new ConnectorPushdownFilterResult( metadata.getTableLayout( @@ -174,7 +177,7 @@ protected ConnectorPushdownFilterResult getConnectorPushdownFilterResult( .setRequestedColumns(requestedColumns) .setPushdownFilterEnabled(true) .setPartitionColumnPredicate(partitionColumnPredicate) - .setPartitions(Optional.ofNullable(partitions.size() == 0 ? null : partitions)) + .setPartitions(Optional.ofNullable(partitions.isEmpty() ? 
null : partitions)) .setTable((IcebergTableHandle) tableHandle) .build()), remainingExpressions.getDynamicFilterExpression()); diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ExpireSnapshotsProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ExpireSnapshotsProcedure.java index afbc4793b48bc..aea6060a571f6 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ExpireSnapshotsProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ExpireSnapshotsProcedure.java @@ -21,11 +21,12 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import org.apache.iceberg.ExpireSnapshots; import org.apache.iceberg.Table; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; @@ -65,11 +66,11 @@ public Procedure get() "system", "expire_snapshots", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("older_than", TIMESTAMP, false, null), - new Procedure.Argument("retain_last", INTEGER, false, null), - new Procedure.Argument("snapshot_ids", "array(bigint)", false, null)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("older_than", TIMESTAMP, false, null), + new Argument("retain_last", INTEGER, false, null), + new Argument("snapshot_ids", "array(bigint)", false, null)), EXPIRE_SNAPSHOTS.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java index 96898ad485397..08b43a84640a5 
100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/FastForwardBranchProcedure.java @@ -18,10 +18,11 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import org.apache.iceberg.Table; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; @@ -58,10 +59,10 @@ public Procedure get() "system", "fast_forward", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("branch", VARCHAR), - new Procedure.Argument("to", VARCHAR)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("branch", VARCHAR), + new Argument("to", VARCHAR)), FAST_FORWARD.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ManifestFileCacheInvalidationProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ManifestFileCacheInvalidationProcedure.java index 6779af23be3e8..1f3eb2f708eb6 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ManifestFileCacheInvalidationProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/ManifestFileCacheInvalidationProcedure.java @@ -17,8 +17,8 @@ import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; diff --git 
a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RegisterTableProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RegisterTableProcedure.java index 28fc304ce8fdc..72a1edadee4f6 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RegisterTableProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RegisterTableProcedure.java @@ -24,12 +24,13 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import javax.inject.Inject; import javax.inject.Provider; import java.io.IOException; @@ -83,10 +84,10 @@ public Procedure get() "system", "register_table", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("metadata_location", VARCHAR), - new Procedure.Argument("metadata_file", VARCHAR, false, null)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("metadata_location", VARCHAR), + new Argument("metadata_file", VARCHAR, false, null)), REGISTER_TABLE.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RemoveOrphanFiles.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RemoveOrphanFiles.java index 1cd87cee0b6fd..2d88d01c14501 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RemoveOrphanFiles.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RemoveOrphanFiles.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.procedure.Procedure.Argument; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.LocatedFileStatus; import org.apache.hadoop.fs.Path; @@ -39,7 +40,6 @@ import org.apache.iceberg.Snapshot; import org.apache.iceberg.Table; -import javax.inject.Inject; import javax.inject.Provider; import java.io.IOException; @@ -225,7 +225,7 @@ private static ManifestReader> readerForManifest(Table case DATA: return ManifestFiles.read(manifest, table.io()); case DELETES: - ManifestFiles.readDeleteManifest(manifest, table.io(), table.specs()); + return ManifestFiles.readDeleteManifest(manifest, table.io(), table.specs()); default: throw new PrestoException(ICEBERG_UNKNOWN_MANIFEST_TYPE, "Unknown manifest file content: " + manifest.content()); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RewriteDataFilesProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RewriteDataFilesProcedure.java new file mode 100644 index 0000000000000..17a72c8bf03ec --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RewriteDataFilesProcedure.java @@ -0,0 +1,266 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg.procedure; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.iceberg.CommitTaskData; +import com.facebook.presto.iceberg.IcebergColumnHandle; +import com.facebook.presto.iceberg.IcebergDistributedProcedureHandle; +import com.facebook.presto.iceberg.IcebergProcedureContext; +import com.facebook.presto.iceberg.IcebergTableHandle; +import com.facebook.presto.iceberg.IcebergTableLayoutHandle; +import com.facebook.presto.iceberg.PartitionData; +import com.facebook.presto.iceberg.RuntimeStatsMetricsReporter; +import com.facebook.presto.iceberg.SortField; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure.Argument; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slice; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; +import org.apache.iceberg.DeleteFile; +import org.apache.iceberg.FileContent; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.RewriteFiles; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; +import org.apache.iceberg.Transaction; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.types.Type; + +import 
javax.inject.Inject; +import javax.inject.Provider; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.OptionalInt; +import java.util.Set; +import java.util.function.Consumer; + +import static com.facebook.presto.common.Utils.checkArgument; +import static com.facebook.presto.common.type.StandardTypes.VARCHAR; +import static com.facebook.presto.iceberg.ExpressionConverter.toIcebergExpression; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.getSupportedSortFields; +import static com.facebook.presto.iceberg.IcebergSessionProperties.getCompressionCodec; +import static com.facebook.presto.iceberg.IcebergUtil.getColumns; +import static com.facebook.presto.iceberg.IcebergUtil.getFileFormat; +import static com.facebook.presto.iceberg.PartitionSpecConverter.toPrestoPartitionSpec; +import static com.facebook.presto.iceberg.SchemaConverter.toPrestoSchema; +import static com.facebook.presto.iceberg.SortFieldUtils.parseSortFields; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.SCHEMA; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.TABLE_NAME; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class RewriteDataFilesProcedure + implements Provider +{ + TypeManager typeManager; + JsonCodec commitTaskCodec; + + @Inject + public RewriteDataFilesProcedure( + TypeManager typeManager, + JsonCodec commitTaskCodec) + { + this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.commitTaskCodec = requireNonNull(commitTaskCodec, "commitTaskCodec is null"); + } + + @Override + public 
DistributedProcedure get() + { + return new TableDataRewriteDistributedProcedure( + "system", + "rewrite_data_files", + ImmutableList.of( + new Argument(SCHEMA, VARCHAR), + new Argument(TABLE_NAME, VARCHAR), + new Argument("filter", VARCHAR, false, "TRUE"), + new Argument("sorted_by", "array(varchar)", false, null), + new Argument("options", "map(varchar, varchar)", false, null)), + (session, procedureContext, tableLayoutHandle, arguments, sortOrderIndex) -> beginCallDistributedProcedure(session, (IcebergProcedureContext) procedureContext, (IcebergTableLayoutHandle) tableLayoutHandle, arguments, sortOrderIndex), + ((session, procedureContext, tableHandle, fragments) -> finishCallDistributedProcedure(session, (IcebergProcedureContext) procedureContext, tableHandle, fragments)), + arguments -> { + checkArgument(arguments.length == 2, format("invalid number of arguments: %s (should have %s)", arguments.length, 2)); + checkArgument(arguments[0] instanceof Table && arguments[1] instanceof Transaction, "Invalid arguments, required: [Table, Transaction]"); + return new IcebergProcedureContext((Table) arguments[0], (Transaction) arguments[1]); + }); + } + + private ConnectorDistributedProcedureHandle beginCallDistributedProcedure(ConnectorSession session, IcebergProcedureContext procedureContext, IcebergTableLayoutHandle layoutHandle, Object[] arguments, OptionalInt sortOrderIndex) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + Table icebergTable = procedureContext.getTable(); + IcebergTableHandle tableHandle = layoutHandle.getTable(); + + SortOrder sortOrder = icebergTable.sortOrder(); + List sortFieldStrings = ImmutableList.of(); + if (sortOrderIndex.isPresent()) { + Object value = arguments[sortOrderIndex.getAsInt()]; + if (value == null) { + sortFieldStrings = ImmutableList.of(); + } + else if (value instanceof List) { + sortFieldStrings = ((List) value).stream() + .map(String.class::cast) + 
.collect(toImmutableList()); + } + else { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "sorted_by must be an array(varchar)"); + } + } + if (sortFieldStrings != null && !sortFieldStrings.isEmpty()) { + SortOrder specifiedSortOrder = parseSortFields(icebergTable.schema(), sortFieldStrings); + if (specifiedSortOrder.satisfies(sortOrder)) { + // If the specified sort order satisfies the target table's internal sort order, use the specified sort order + sortOrder = specifiedSortOrder; + } + else { + throw new PrestoException(NOT_SUPPORTED, "Specified sort order is incompatible with the target table's internal sort order"); + } + } + + List sortFields = getSupportedSortFields(icebergTable.schema(), sortOrder); + return new IcebergDistributedProcedureHandle( + tableHandle.getSchemaName(), + tableHandle.getIcebergTableName(), + toPrestoSchema(icebergTable.schema(), typeManager), + toPrestoPartitionSpec(icebergTable.spec(), typeManager), + getColumns(icebergTable.schema(), icebergTable.spec(), typeManager), + icebergTable.location(), + getFileFormat(icebergTable), + getCompressionCodec(session), + icebergTable.properties(), + layoutHandle, + sortFields, + ImmutableMap.of()); + } + } + + private void finishCallDistributedProcedure(ConnectorSession session, IcebergProcedureContext procedureContext, ConnectorDistributedProcedureHandle procedureHandle, Collection fragments) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + IcebergDistributedProcedureHandle handle = (IcebergDistributedProcedureHandle) procedureHandle; + Table icebergTable = procedureContext.getTransaction().table(); + + List commitTasks = fragments.stream() + .map(slice -> commitTaskCodec.fromJson(slice.getBytes())) + .collect(toImmutableList()); + + org.apache.iceberg.types.Type[] partitionColumnTypes = icebergTable.spec().fields().stream() + .map(field -> field.transform().getResultType( + icebergTable.schema().findType(field.sourceId()))) + 
.toArray(Type[]::new); + + Set newFiles = new HashSet<>(); + for (CommitTaskData task : commitTasks) { + DataFiles.Builder builder = DataFiles.builder(icebergTable.spec()) + .withPath(task.getPath()) + .withFileSizeInBytes(task.getFileSizeInBytes()) + .withFormat(handle.getFileFormat().name()) + .withMetrics(task.getMetrics().metrics()); + + if (!icebergTable.spec().fields().isEmpty()) { + String partitionDataJson = task.getPartitionDataJson() + .orElseThrow(() -> new VerifyException("No partition data for partitioned table")); + builder.withPartition(PartitionData.fromJson(partitionDataJson, partitionColumnTypes)); + } + newFiles.add(builder.build()); + } + + IcebergTableLayoutHandle layoutHandle = handle.getTableLayoutHandle(); + IcebergTableHandle tableHandle = layoutHandle.getTable(); + final Set scannedDataFiles = new HashSet<>(); + final Set fullyAppliedDeleteFiles = new HashSet<>(); + if (tableHandle.getIcebergTableName().getSnapshotId().isPresent()) { + TupleDomain predicate = layoutHandle.getValidPredicate(); + + Consumer fileScanTaskConsumer = (task) -> { + scannedDataFiles.add(task.file()); + if (!task.deletes().isEmpty()) { + task.deletes().forEach(deleteFile -> { + if (deleteFile.content() == FileContent.EQUALITY_DELETES && + !icebergTable.specs().get(deleteFile.specId()).isPartitioned() && + !predicate.isAll()) { + // Equality files with an unpartitioned spec are applied as global deletes + // So they should not be cleaned up unless the whole table is optimized + return; + } + fullyAppliedDeleteFiles.add(deleteFile); + }); + } + }; + + TableScan tableScan = procedureContext.getTable().newScan() + .metricsReporter(new RuntimeStatsMetricsReporter(session.getRuntimeStats())) + .filter(toIcebergExpression(predicate)) + .useSnapshot(tableHandle.getIcebergTableName().getSnapshotId().get()); + CloseableIterable fileScanTaskIterable = tableScan.planFiles(); + CloseableIterator fileScanTaskIterator = fileScanTaskIterable.iterator(); + 
fileScanTaskIterator.forEachRemaining(fileScanTaskConsumer); + try { + fileScanTaskIterable.close(); + fileScanTaskIterator.close(); + // TODO: remove this once org.apache.iceberg.io.CloseableIterator#withClose + // correctly releases the resources held by the iterator. + fileScanTaskIterator = CloseableIterator.empty(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + if (fragments.isEmpty() && + scannedDataFiles.isEmpty() && + fullyAppliedDeleteFiles.isEmpty()) { + return; + } + + RewriteFiles rewriteFiles = procedureContext.getTransaction().newRewrite() + .rewriteFiles(scannedDataFiles, fullyAppliedDeleteFiles, newFiles, ImmutableSet.of()); + + // Table.snapshot method returns null if there is no matching snapshot + Snapshot snapshot = requireNonNull( + handle.getTableName() + .getSnapshotId() + .map(icebergTable::snapshot) + .orElse(null), + "snapshot is null"); + if (icebergTable.currentSnapshot() != null) { + rewriteFiles.validateFromSnapshot(snapshot.snapshotId()); + } + rewriteFiles.commit(); + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RewriteManifestsProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RewriteManifestsProcedure.java new file mode 100644 index 0000000000000..b82084cc70926 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RewriteManifestsProcedure.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.procedure; + +import com.facebook.presto.iceberg.IcebergMetadataFactory; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; +import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; +import org.apache.iceberg.RewriteManifests; +import org.apache.iceberg.Table; + +import javax.inject.Provider; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle; +import static com.facebook.presto.common.type.StandardTypes.INTEGER; +import static com.facebook.presto.common.type.StandardTypes.VARCHAR; +import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_SPEC_ID; +import static com.facebook.presto.iceberg.IcebergUtil.getIcebergTable; +import static java.util.Objects.requireNonNull; + +public class RewriteManifestsProcedure + implements Provider +{ + private static final MethodHandle REWRITE_MANIFESTS = methodHandle( + RewriteManifestsProcedure.class, + "rewriteManifests", + ConnectorSession.class, + String.class, + String.class, + Integer.class); + + private final IcebergMetadataFactory metadataFactory; + + @Inject + public RewriteManifestsProcedure(IcebergMetadataFactory metadataFactory) + { + this.metadataFactory = requireNonNull(metadataFactory, "metadataFactory is null"); + } + + @Override + public Procedure get() + { + return new Procedure( + "system", + "rewrite_manifests", + ImmutableList.of( + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new 
Argument("spec_id", INTEGER, false, null)), + REWRITE_MANIFESTS.bindTo(this)); + } + + public void rewriteManifests(ConnectorSession clientSession, String schemaName, String tableName, Integer specId) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + SchemaTableName schemaTableName = new SchemaTableName(schemaName, tableName); + ConnectorMetadata metadata = metadataFactory.create(); + Table icebergTable = getIcebergTable(metadata, clientSession, schemaTableName); + RewriteManifests rewriteManifests = icebergTable.rewriteManifests().clusterBy(file -> "file"); + int targetSpecId; + if (specId != null) { + if (!icebergTable.specs().containsKey(specId)) { + throw new PrestoException(ICEBERG_INVALID_SPEC_ID, "Given spec id does not exist: " + specId); + } + targetSpecId = specId; + } + else { + targetSpecId = icebergTable.spec().specId(); + } + rewriteManifests.rewriteIf(manifest -> manifest.partitionSpecId() == targetSpecId).commit(); + } + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java index df9feca8fc69e..c50c2c8b1e6c9 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToSnapshotProcedure.java @@ -18,9 +18,10 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; @@ -57,9 +58,9 @@ public Procedure get() "system", "rollback_to_snapshot", ImmutableList.of( - new 
Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("snapshot_id", BIGINT)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("snapshot_id", BIGINT)), ROLLBACK_TO_SNAPSHOT.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToTimestampProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToTimestampProcedure.java index 4195e2ca099c8..513a4f4ae57d7 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToTimestampProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/RollbackToTimestampProcedure.java @@ -20,9 +20,10 @@ import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; @@ -59,9 +60,9 @@ public Procedure get() "system", "rollback_to_timestamp", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("timestamp", TIMESTAMP)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("timestamp", TIMESTAMP)), ROLLBACK_TO_TIMESTAMP.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java index 5359af02f44c0..8fdfa2310fac9 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java +++ 
b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetCurrentSnapshotProcedure.java @@ -18,11 +18,12 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import org.apache.iceberg.SnapshotRef; import org.apache.iceberg.Table; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; @@ -61,10 +62,10 @@ public Procedure get() "system", "set_current_snapshot", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("snapshot_id", BIGINT, false, null), - new Procedure.Argument("ref", VARCHAR, false, null)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("snapshot_id", BIGINT, false, null), + new Argument("ref", VARCHAR, false, null)), SET_CURRENT_SNAPSHOT.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetTablePropertyProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetTablePropertyProcedure.java index deb5dd961d066..d4f5457bc3a80 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetTablePropertyProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/SetTablePropertyProcedure.java @@ -13,32 +13,34 @@ */ package com.facebook.presto.iceberg.procedure; -import com.facebook.airlift.log.Logger; -import com.facebook.presto.hive.HdfsEnvironment; -import com.facebook.presto.iceberg.IcebergConfig; import com.facebook.presto.iceberg.IcebergMetadataFactory; import com.facebook.presto.iceberg.IcebergTableName; +import com.facebook.presto.iceberg.IcebergTableProperties; import 
com.facebook.presto.iceberg.IcebergUtil; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import org.apache.iceberg.Table; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; import static com.facebook.presto.common.block.MethodHandleUtil.methodHandle; import static com.facebook.presto.common.type.StandardTypes.VARCHAR; +import static com.facebook.presto.iceberg.IcebergWarningCode.ICEBERG_UNSUPPORTED_TABLE_PROPERTY; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class SetTablePropertyProcedure implements Provider { - private static final Logger LOG = Logger.get(SetTablePropertyProcedure.class); private static final MethodHandle SET_TABLE_PROPERTY = methodHandle( SetTablePropertyProcedure.class, "setTableProperty", @@ -48,19 +50,14 @@ public class SetTablePropertyProcedure String.class, String.class); - private final IcebergConfig config; private final IcebergMetadataFactory metadataFactory; - private final HdfsEnvironment hdfsEnvironment; + private final IcebergTableProperties tableProperties; @Inject - public SetTablePropertyProcedure( - IcebergConfig config, - IcebergMetadataFactory metadataFactory, - HdfsEnvironment hdfsEnvironment) + public SetTablePropertyProcedure(IcebergMetadataFactory metadataFactory, IcebergTableProperties tableProperties) { - this.config = requireNonNull(config); this.metadataFactory = requireNonNull(metadataFactory); - this.hdfsEnvironment = requireNonNull(hdfsEnvironment); + this.tableProperties = requireNonNull(tableProperties, 
"tableProperties is null"); } @Override @@ -70,10 +67,10 @@ public Procedure get() "system", "set_table_property", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR), - new Procedure.Argument("key", VARCHAR), - new Procedure.Argument("value", VARCHAR)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR), + new Argument("key", VARCHAR), + new Argument("value", VARCHAR)), SET_TABLE_PROPERTY.bindTo(this)); } @@ -86,13 +83,24 @@ public Procedure get() */ public void setTableProperty(ConnectorSession session, String schema, String table, String key, String value) { - ConnectorMetadata metadata = metadataFactory.create(); - IcebergTableName tableName = IcebergTableName.from(table); - SchemaTableName schemaTableName = new SchemaTableName(schema, tableName.getTableName()); - Table icebergTable = IcebergUtil.getIcebergTable(metadata, session, schemaTableName); + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(getClass().getClassLoader())) { + // Warn if property is not recognized by Presto + if (!tableProperties.isTablePropertySupported(key)) { + PrestoWarning warning = new PrestoWarning(ICEBERG_UNSUPPORTED_TABLE_PROPERTY, format( + "Iceberg table property '%s' is not recognized by Presto. 
" + + "It will be stored in Iceberg metadata but ignored by the Presto engine.", + key)); + session.getWarningCollector().add(warning); + } - icebergTable.updateProperties() - .set(key, value) - .commit(); + ConnectorMetadata metadata = metadataFactory.create(); + IcebergTableName tableName = IcebergTableName.from(table); + SchemaTableName schemaTableName = new SchemaTableName(schema, tableName.getTableName()); + Table icebergTable = IcebergUtil.getIcebergTable(metadata, session, schemaTableName); + + icebergTable.updateProperties() + .set(key, value) + .commit(); + } } } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/StatisticsFileCacheInvalidationProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/StatisticsFileCacheInvalidationProcedure.java index 0ddc4fd2c98f4..6e91c57b53cf3 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/StatisticsFileCacheInvalidationProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/StatisticsFileCacheInvalidationProcedure.java @@ -17,8 +17,8 @@ import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/UnregisterTableProcedure.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/UnregisterTableProcedure.java index 898257e6b85b8..15c1b801632a5 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/UnregisterTableProcedure.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/procedure/UnregisterTableProcedure.java @@ -20,9 +20,10 @@ import com.facebook.presto.spi.SchemaTableName; import 
com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.inject.Provider; import java.lang.invoke.MethodHandle; @@ -56,8 +57,8 @@ public Procedure get() "system", "unregister_table", ImmutableList.of( - new Procedure.Argument("schema", VARCHAR), - new Procedure.Argument("table_name", VARCHAR)), + new Argument("schema", VARCHAR), + new Argument("table_name", VARCHAR)), UNREGISTER_TABLE.bindTo(this)); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/rest/IcebergRestCatalogFactory.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/rest/IcebergRestCatalogFactory.java index 0c588d71e2ef3..fd13589ea7440 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/rest/IcebergRestCatalogFactory.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/rest/IcebergRestCatalogFactory.java @@ -25,14 +25,13 @@ import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.UncheckedExecutionException; import io.jsonwebtoken.Jwts; +import jakarta.inject.Inject; import org.apache.iceberg.CatalogProperties; import org.apache.iceberg.catalog.Catalog; import org.apache.iceberg.catalog.SessionCatalog.SessionContext; import org.apache.iceberg.rest.HTTPClient; import org.apache.iceberg.rest.RESTCatalog; -import javax.inject.Inject; - import java.util.Date; import java.util.Map; import java.util.Optional; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/rest/IcebergRestConfig.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/rest/IcebergRestConfig.java index 613a48d02f13a..fed989889b475 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/rest/IcebergRestConfig.java +++ 
b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/rest/IcebergRestConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.util.Optional; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/security/IcebergSecurityModule.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/security/IcebergSecurityModule.java new file mode 100644 index 0000000000000..243ce9dc5d5d7 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/security/IcebergSecurityModule.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg.security; + +import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; +import com.facebook.presto.plugin.base.security.AllowAllAccessControlModule; +import com.facebook.presto.plugin.base.security.FileBasedAccessControlModule; +import com.google.inject.Binder; +import com.google.inject.Module; + +import static com.facebook.airlift.configuration.ConditionalModule.installModuleIf; + +public class IcebergSecurityModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + bindSecurityModule("allow-all", new AllowAllAccessControlModule()); + bindSecurityModule("file", new FileBasedAccessControlModule()); + } + + private void bindSecurityModule(String name, Module module) + { + install(installModuleIf( + SecurityConfig.class, + security -> name.equalsIgnoreCase(security.getSecuritySystem()), + module)); + } +} diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/security/SecurityConfig.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/security/SecurityConfig.java new file mode 100644 index 0000000000000..40332ebca28d3 --- /dev/null +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/security/SecurityConfig.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg.security; + +import com.facebook.airlift.configuration.Config; +import jakarta.validation.constraints.NotNull; + +public class SecurityConfig +{ + private String securitySystem = "allow-all"; + + @NotNull + public String getSecuritySystem() + { + return securitySystem; + } + + @Config("iceberg.security") + public SecurityConfig setSecuritySystem(String securitySystem) + { + this.securitySystem = securitySystem; + return this; + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/BenchmarkIcebergLazyLoading.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/BenchmarkIcebergLazyLoading.java new file mode 100644 index 0000000000000..51a5f5d9fd2ae --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/BenchmarkIcebergLazyLoading.java @@ -0,0 +1,115 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.tests.DistributedQueryRunner; +import org.intellij.lang.annotations.Language; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; +import org.openjdk.jmh.runner.options.WarmupMode; + +import java.util.concurrent.TimeUnit; + +import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; + +@State(Scope.Thread) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Fork(1) +@Warmup(iterations = 5, time = 2, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 10, time = 2, timeUnit = TimeUnit.SECONDS) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkIcebergLazyLoading +{ + @Language("SQL") final String optimizationQuery = "select min(a), max(a), min(b), max(b) from iceberg_partition"; + + @Param({"300 * 2", "1600 * 4"}) + private String recordCount = "300 * 2"; + + DistributedQueryRunner queryRunner; + int batchCount; + int countPerBatch; + @Language("SQL") String normalQuery; + + @Setup + public void setup() throws Exception + { + queryRunner = IcebergQueryRunner.builder() + .build() + .getQueryRunner(); + queryRunner.execute("create table iceberg_partition(a 
int, b int, c double) with (partitioning = ARRAY['a', 'b'])"); + + String[] batchAndPerBatch = recordCount.split("\\*"); + batchCount = Integer.parseInt(batchAndPerBatch[0].trim()); + countPerBatch = Integer.parseInt(batchAndPerBatch[1].trim()); + normalQuery = String.format("select a, c from iceberg_partition where b >= %s", countPerBatch - 1); + + for (int b = 0; b < batchCount; b++) { + StringBuilder sqlBuilder = new StringBuilder(String.format("values (%d, %d, %d)", b, b, b * b)); + for (int i = 1; i < countPerBatch; i++) { + sqlBuilder.append(String.format(", (%d, %d, %d)", b, i, b * i)); + } + String valuesSql = sqlBuilder.toString(); + queryRunner.execute("insert into iceberg_partition " + valuesSql); + } + } + + @Benchmark + public void testFurtherOptimize(Blackhole bh) + { + MaterializedResult result = queryRunner.execute(optimizationQuery); + bh.consume(result.getRowCount()); + } + + @Benchmark + public void testNormalQuery(Blackhole bh) + { + MaterializedResult result = queryRunner.execute(normalQuery); + bh.consume(result.getRowCount()); + } + + @TearDown + public void finish() + { + queryRunner.execute("drop table iceberg_partition"); + closeAllRuntimeException(queryRunner); + queryRunner = null; + } + + public static void main(String[] args) throws RunnerException + { + Options options = new OptionsBuilder() + .verbosity(VerboseMode.NORMAL) + .warmupMode(WarmupMode.INDI) + .include(".*" + BenchmarkIcebergLazyLoading.class.getSimpleName() + ".*") + .build(); + new Runner(options).run(); + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java index 70f771882736e..a9723243deeec 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java @@ -15,6 +15,7 @@ 
import com.facebook.presto.Session; import com.facebook.presto.Session.SessionBuilder; +import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.hive.HdfsEnvironment; import com.facebook.presto.hive.HiveClientConfig; @@ -25,6 +26,9 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.planPrinter.IOPlanPrinter.ColumnConstraint; +import com.facebook.presto.sql.planner.planPrinter.IOPlanPrinter.IOPlan; +import com.facebook.presto.sql.planner.planPrinter.IOPlanPrinter.IOPlan.TableColumnInfo; import com.facebook.presto.sql.tree.AstVisitor; import com.facebook.presto.sql.tree.ColumnDefinition; import com.facebook.presto.sql.tree.CreateTable; @@ -34,6 +38,7 @@ import com.facebook.presto.testing.assertions.Assert; import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; import com.facebook.presto.tests.DistributedQueryRunner; +import com.facebook.presto.tests.ResultWithQueryId; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.hadoop.fs.FileSystem; @@ -51,7 +56,9 @@ import java.util.Optional; import java.util.function.BiConsumer; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.SystemSessionProperties.LEGACY_TIMESTAMP; +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.TimeZoneKey.UTC_KEY; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.iceberg.CatalogType.HADOOP; @@ -70,6 +77,7 @@ import static com.facebook.presto.iceberg.procedure.RegisterTableProcedure.getFileSystem; import static com.facebook.presto.iceberg.procedure.RegisterTableProcedure.resolveLatestMetadataLocation; import static 
com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.tests.sql.TestTable.randomTableSuffix; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.String.format; @@ -133,16 +141,16 @@ public void testTime() @Override public void testDescribeTable() { - MaterializedResult expectedColumns = resultBuilder(getQueryRunner().getDefaultSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("orderkey", "bigint", "", "") - .row("custkey", "bigint", "", "") - .row("orderstatus", "varchar", "", "") - .row("totalprice", "double", "", "") - .row("orderdate", "date", "", "") - .row("orderpriority", "varchar", "", "") - .row("clerk", "varchar", "", "") - .row("shippriority", "integer", "", "") - .row("comment", "varchar", "", "") + MaterializedResult expectedColumns = resultBuilder(getQueryRunner().getDefaultSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("orderkey", "bigint", "", "", 19L, null, null) + .row("custkey", "bigint", "", "", 19L, null, null) + .row("orderstatus", "varchar", "", "", null, null, 2147483647L) + .row("totalprice", "double", "", "", 53L, null, null) + .row("orderdate", "date", "", "", null, null, null) + .row("orderpriority", "varchar", "", "", null, null, 2147483647L) + .row("clerk", "varchar", "", "", null, null, 2147483647L) + .row("shippriority", "integer", "", "", 10L, null, null) + .row("comment", "varchar", "", "", null, null, 2147483647L) .build(); MaterializedResult actualColumns = computeActual("DESCRIBE orders"); Assert.assertEquals(actualColumns, expectedColumns); @@ -668,12 +676,12 @@ public void testColumnComments() assertUpdate(session, "CREATE TABLE test_column_comments (_bigint BIGINT COMMENT 'test column comment')"); assertQuery(session, "SHOW COLUMNS FROM test_column_comments", - "VALUES ('_bigint', 'bigint', '', 'test column comment')"); + "VALUES ('_bigint', 
'bigint', '', 'test column comment', 19L, null, null)"); assertUpdate("ALTER TABLE test_column_comments ADD COLUMN _varchar VARCHAR COMMENT 'test new column comment'"); assertQuery( "SHOW COLUMNS FROM test_column_comments", - "VALUES ('_bigint', 'bigint', '', 'test column comment'), ('_varchar', 'varchar', '', 'test new column comment')"); + "VALUES ('_bigint', 'bigint', '', 'test column comment', 19L, null, null), ('_varchar', 'varchar', '', 'test new column comment', null, null, 2147483647L)"); dropTable(session, "test_column_comments"); } @@ -739,7 +747,118 @@ private long getLatestSnapshotId() @Test public void testInsertIntoNotNullColumn() { - // TODO: To support non-null column. (NOT_NULL_COLUMN_CONSTRAINT) + assertUpdate("CREATE TABLE test_not_null_table (c1 INTEGER, c2 INTEGER NOT NULL)"); + assertUpdate("INSERT INTO test_not_null_table (c2) VALUES (2)", 1); + assertQuery("SELECT * FROM test_not_null_table", "VALUES (NULL, 2)"); + assertQueryFails("INSERT INTO test_not_null_table (c1) VALUES (1)", "NULL value not allowed for NOT NULL column: c2"); + assertUpdate("DROP TABLE IF EXISTS test_not_null_table"); + + assertUpdate("CREATE TABLE test_commuted_not_null_table (a BIGINT, b BIGINT NOT NULL)"); + assertUpdate("INSERT INTO test_commuted_not_null_table (b) VALUES (2)", 1); + assertQuery("SELECT * FROM test_commuted_not_null_table", "VALUES (NULL, 2)"); + assertQueryFails("INSERT INTO test_commuted_not_null_table (b, a) VALUES (NULL, 3),(4, NULL),(NULL, NULL)", "NULL value not allowed for NOT NULL column: b"); + assertUpdate("DROP TABLE IF EXISTS test_commuted_not_null_table"); + } + + @Test + public void testAddColumnWithMultiplePartitionTransforms() + { + Session session = getSession(); + String catalog = session.getCatalog().get(); + String schema = format("\"%s\"", session.getSchema().get()); + + assertQuerySucceeds("create table add_multiple_partition_column(a int)"); + assertUpdate("insert into add_multiple_partition_column values 1", 1); + + 
validateShowCreateTable(catalog, schema, "add_multiple_partition_column", + ImmutableList.of(columnDefinition("a", "integer")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_multiple_partition_column") + "'"))); + + // Add a varchar column with partition transforms `ARRAY['bucket(4)', 'truncate(2)', 'identity']` + assertQuerySucceeds("alter table add_multiple_partition_column add column b varchar with(partitioning = ARRAY['bucket(4)', 'truncate(2)', 'identity'])"); + assertUpdate("insert into add_multiple_partition_column values(2, '1002')", 1); + + validateShowCreateTable(catalog, schema, "add_multiple_partition_column", + ImmutableList.of( + columnDefinition("a", "integer"), + columnDefinition("b", "varchar")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_multiple_partition_column") + "'", + "partitioning", "ARRAY['bucket(b, 4)','truncate(b, 2)','b']"))); + + // Add a date column with partition transforms `ARRAY['year', 'bucket(8)', 'identity']` + assertQuerySucceeds("alter table add_multiple_partition_column add column c date with(partitioning = ARRAY['year', 'bucket(8)', 'identity'])"); + assertUpdate("insert into add_multiple_partition_column values(3, '1003', date '1984-12-08')", 1); + + validateShowCreateTable(catalog, schema, "add_multiple_partition_column", + ImmutableList.of( + columnDefinition("a", "integer"), + columnDefinition("b", "varchar"), + columnDefinition("c", "date")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_multiple_partition_column") + "'", + "partitioning", "ARRAY['bucket(b, 4)','truncate(b, 2)','b','year(c)','bucket(c, 8)','c']"))); + + assertQuery("select * from add_multiple_partition_column", + "values(1, null, null), (2, '1002', null), (3, '1003', date '1984-12-08')"); + dropTable(getSession(), 
"add_multiple_partition_column"); + } + + @Test + public void testAddColumnWithRedundantOrDuplicatedPartitionTransforms() + { + Session session = getSession(); + String catalog = session.getCatalog().get(); + String schema = format("\"%s\"", session.getSchema().get()); + + assertQuerySucceeds("create table add_redundant_partition_column(a int)"); + + // Specify duplicated transforms would fail + assertQueryFails("alter table add_redundant_partition_column add column b varchar with(partitioning = ARRAY['bucket(4)', 'truncate(2)', 'bucket(4)'])", + "Cannot add duplicate partition field: .*"); + assertQueryFails("alter table add_redundant_partition_column add column b varchar with(partitioning = ARRAY['identity', 'identity'])", + "Cannot add duplicate partition field: .*"); + + // Specify redundant transforms would fail + assertQueryFails("alter table add_redundant_partition_column add column c date with(partitioning = ARRAY['year', 'month'])", + "Cannot add redundant partition field: .*"); + assertQueryFails("alter table add_redundant_partition_column add column c timestamp with(partitioning = ARRAY['day', 'hour'])", + "Cannot add redundant partition field: .*"); + + validateShowCreateTable(catalog, schema, "add_redundant_partition_column", + ImmutableList.of(columnDefinition("a", "integer")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_redundant_partition_column") + "'"))); + + dropTable(getSession(), "add_redundant_partition_column"); + } + + @Test + public void testAddColumnWithUnsupportedPropertyValueTypes() + { + Session session = getSession(); + String catalog = session.getCatalog().get(); + String schema = format("\"%s\"", session.getSchema().get()); + + assertQuerySucceeds("create table add_invalid_partition_column(a int)"); + + assertQueryFails("alter table add_invalid_partition_column add column b varchar with(partitioning = 123)", + "Invalid value for column property 
'partitioning': Cannot convert '123' to array\\(varchar\\) or any of \\[varchar]"); + assertQueryFails("alter table add_invalid_partition_column add column b varchar with(partitioning = ARRAY[123, 234])", + "Invalid value for column property 'partitioning': Cannot convert 'ARRAY\\[123,234]' to array\\(varchar\\) or any of \\[varchar]"); + + validateShowCreateTable(catalog, schema, "add_invalid_partition_column", + ImmutableList.of(columnDefinition("a", "integer")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_invalid_partition_column") + "'"))); + + dropTable(getSession(), "add_invalid_partition_column"); } @Test @@ -798,7 +917,7 @@ protected void testCreateTableLike() "\"" + schemaName + "\"", "test_create_table_like_copy1", getCustomizedTableProperties(ImmutableMap.of( - "location", "'" + getLocation(schemaName, "test_create_table_like_copy1") + "'"))); + "location", "'" + getLocation(schemaName, "test_create_table_like_copy1") + "'"))); dropTable(session, "test_create_table_like_copy1"); assertUpdate(session, "CREATE TABLE test_create_table_like_copy2 (LIKE test_create_table_like_original EXCLUDING PROPERTIES)"); @@ -992,6 +1111,41 @@ protected void unregisterTable(String schemaName, String newTableName) assertUpdate("CALL system.unregister_table('" + schemaName + "', '" + newTableName + "')"); } + @DataProvider + public Object[][] compressionCodecTestData() + { + return new Object[][] { + // codec, format, shouldSucceed, expectedErrorMessage + {"ZSTD", "PARQUET", true, null}, + {"LZ4", "PARQUET", false, "Compression codec LZ4 is not supported for Parquet format"}, + {"LZ4", "ORC", true, null}, + {"ZSTD", "ORC", true, null}, + {"SNAPPY", "ORC", true, null}, + {"SNAPPY", "PARQUET", true, null}, + {"GZIP", "PARQUET", true, null}, + {"NONE", "PARQUET", true, null}, + }; + } + + @Test(dataProvider = "compressionCodecTestData") + public void testCompressionCodecValidation(String codec, String 
format, boolean shouldSucceed, String expectedErrorMessage) + { + Session session = Session.builder(getSession()) + .setCatalogSessionProperty("iceberg", "compression_codec", codec) + .build(); + + String tableName = "test_compression_" + codec.toLowerCase() + "_" + format.toLowerCase(); + String createTableSql = "CREATE TABLE " + tableName + " (x int) WITH (format = '" + format + "')"; + + if (shouldSucceed) { + assertUpdate(session, createTableSql); + dropTable(session, tableName); + } + else { + assertQueryFails(session, createTableSql, expectedErrorMessage); + } + } + @Test public void testCreateNestedPartitionedTable() { @@ -1588,6 +1742,9 @@ public void testRegisterTable() { String schemaName = getSession().getSchema().get(); String tableName = "register"; + // Create a `noise` table in the same schema to test that the `getLocation` method finds and returns the right metadata location. + String noiseTableName = "register1"; + assertUpdate("CREATE TABLE " + noiseTableName + " (id integer, value integer)"); assertUpdate("CREATE TABLE " + tableName + " (id integer, value integer)"); assertUpdate("INSERT INTO " + tableName + " VALUES(1, 1)", 1); @@ -1599,6 +1756,7 @@ public void testRegisterTable() unregisterTable(schemaName, newTableName); dropTable(getSession(), tableName); + dropTable(getSession(), noiseTableName); } @Test @@ -1860,6 +2018,66 @@ public void testMetadataDeleteOnTableWithUnsupportedSpecsWhoseDataAllDeleted(Str } } + @Test(dataProvider = "version_and_mode") + public void testMetadataDeleteOnTableAfterWholeRewriteDataFiles(String version, String mode) + { + String errorMessage = "This connector only supports delete where one or more partitions are deleted entirely.*"; + String schemaName = getSession().getSchema().get(); + String tableName = "test_rewrite_data_files_table_" + randomTableSuffix(); + try { + // Create a table with partition column `a`, and insert some data under this partition spec + assertUpdate("CREATE TABLE " + tableName + " (a 
INTEGER, b VARCHAR) WITH (format_version = '" + version + "', delete_mode = '" + mode + "')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, '1001'), (2, '1002')", 2); + + // Then evaluate the partition spec by adding a partition column `c`, and insert some data under the new partition spec + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN c INTEGER WITH (partitioning = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (3, '1003', 3), (4, '1004', 4), (5, '1005', 5)", 3); + + // Do not support metadata delete with filter on column `c`, because we have data with old partition spec + assertQueryFails("DELETE FROM " + tableName + " WHERE c > 3", errorMessage); + + // Call procedure rewrite_data_files without filter to rewrite all data files + assertUpdate("call system.rewrite_data_files(table_name => '" + tableName + "', schema => '" + schemaName + "')", 5); + + // Then we can do metadata delete on column `c`, because all data files are rewritten under new partition spec + assertUpdate("DELETE FROM " + tableName + " WHERE c > 3", 2); + assertQuery("SELECT * FROM " + tableName, "VALUES (1, '1001', NULL), (2, '1002', NULL), (3, '1003', 3)"); + } + finally { + dropTable(getSession(), tableName); + } + } + + @Test(dataProvider = "version_and_mode") + public void testMetadataDeleteOnTableAfterPartialRewriteDataFiles(String version, String mode) + { + String errorMessage = "This connector only supports delete where one or more partitions are deleted entirely.*"; + String schemaName = getSession().getSchema().get(); + String tableName = "test_rewrite_data_files_table_" + randomTableSuffix(); + try { + // Create a table with partition column `a`, and insert some data under this partition spec + assertUpdate("CREATE TABLE " + tableName + " (a INTEGER, b VARCHAR) WITH (format_version = '" + version + "', delete_mode = '" + mode + "', partitioning = ARRAY['a'])"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, '1001'), (2, '1002')", 
2); + + // Then evaluate the partition spec by adding a partition column `c`, and insert some data under the new partition spec + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN c INTEGER WITH (partitioning = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (3, '1003', 3), (4, '1004', 4), (5, '1005', 5)", 3); + + // Do not support metadata delete with filter on column `c`, because we have data with old partition spec + assertQueryFails("DELETE FROM " + tableName + " WHERE c > 3", errorMessage); + + // Call procedure rewrite_data_files with filter to rewrite data files under the prior partition spec + assertUpdate("call system.rewrite_data_files(table_name => '" + tableName + "', schema => '" + schemaName + "', filter => 'a in (1, 2)')", 2); + + // Then we can do metadata delete on column `c`, because all data files are now under new partition spec + assertUpdate("DELETE FROM " + tableName + " WHERE c > 3", 2); + assertQuery("SELECT * FROM " + tableName, "VALUES (1, '1001', NULL), (2, '1002', NULL), (3, '1003', 3)"); + } + finally { + dropTable(getSession(), tableName); + } + } + @DataProvider(name = "version_and_mode") public Object[][] versionAndMode() { @@ -2089,6 +2307,50 @@ public void testDeprecatedTablePropertiesAlterTable() }); } + @Test + public void testRuntimeMetricsReporter() + { + ResultWithQueryId result = getDistributedQueryRunner() + .executeWithQueryId(getSession(), "SELECT * FROM orders WHERE orderkey < 100"); + + DistributedQueryRunner distributedQueryRunner = (DistributedQueryRunner) getQueryRunner(); + + RuntimeStats runtimestats = distributedQueryRunner.getCoordinator() + .getQueryManager() + .getFullQueryInfo(result.getQueryId()) + .getQueryStats() + .getRuntimeStats(); + + String catalog = getSession().getCatalog().get(); + String schema = getSession().getSchema().get(); + String tableName = catalog + "." 
+ schema + ".orders"; + + assertTrue(runtimestats + .getMetrics() + .get(tableName + ".scan.totalPlanningDuration") + .getSum() > 0); + + assertTrue(runtimestats + .getMetrics() + .get(tableName + ".scan.resultDataFiles") + .getCount() > 1); + + assertTrue(runtimestats + .getMetrics() + .get(tableName + ".scan.totalDeleteManifests") + .getCount() > 0); + + assertTrue(runtimestats + .getMetrics() + .get(tableName + ".scan.totalFileSizeInBytes") + .getCount() > 0); + + assertTrue(runtimestats + .getMetrics() + .get(tableName + ".scan.totalFileSizeInBytes") + .getSum() > 0); + } + protected HdfsEnvironment getHdfsEnvironment() { HiveClientConfig hiveClientConfig = new HiveClientConfig(); @@ -2132,8 +2394,8 @@ protected void validatePropertiesForShowCreateTable(String catalog, String schem } protected void validateShowCreateTable(String table, - List columnDefinitions, - Map propertyDescriptions) + List columnDefinitions, + Map propertyDescriptions) { String catalog = getSession().getCatalog().get(); String schema = getSession().getSchema().get(); @@ -2142,9 +2404,9 @@ protected void validateShowCreateTable(String table, } protected void validateShowCreateTable(String catalog, String schema, String table, - List columnDefinitions, - String comment, - Map propertyDescriptions) + List columnDefinitions, + String comment, + Map propertyDescriptions) { validateShowCreateTableInner(catalog, schema, table, Optional.ofNullable(columnDefinitions), Optional.ofNullable(comment), propertyDescriptions); @@ -2156,21 +2418,23 @@ protected ColumnDefinition columnDefinition(String name, String type) } private void validateShowCreateTableInner(String catalog, String schema, String table, - Optional> columnDefinitions, - Optional commentDescription, - Map propertyDescriptions) + Optional> columnDefinitions, + Optional commentDescription, + Map propertyDescriptions) { MaterializedResult showCreateTable = computeActual(format("SHOW CREATE TABLE %s.%s.%s", catalog, schema, table)); String 
createTableSql = (String) getOnlyElement(showCreateTable.getOnlyColumnAsSet()); SqlParser parser = new SqlParser(); - parser.createStatement(createTableSql).accept(new AstVisitor() { + parser.createStatement(createTableSql).accept(new AstVisitor() + { @Override protected Void visitCreateTable(CreateTable node, Void context) { columnDefinitions.ifPresent(columnDefinitionList -> { ImmutableList.Builder columnDefinitionsBuilder = ImmutableList.builder(); - node.getElements().forEach(element -> element.accept(new AstVisitor() { + node.getElements().forEach(element -> element.accept(new AstVisitor() + { @Override protected Void visitColumnDefinition(ColumnDefinition node, Void context) { @@ -2206,4 +2470,30 @@ private static String getMetadataFileLocation(ConnectorSession session, HdfsEnvi fileSystem, metadataDir).getName(); } + + @Test + public void testIOExplainWithTimestampWithTimeZone() + { + assertUpdate("CREATE TABLE test_tstz_io (id BIGINT, tstz TIMESTAMP WITH TIME ZONE)"); + try { + assertUpdate("INSERT INTO test_tstz_io VALUES (1, TIMESTAMP '2020-01-15 10:30:45.000 UTC')", 1); + + MaterializedResult result = computeActual("EXPLAIN (TYPE IO, FORMAT JSON) SELECT * FROM test_tstz_io " + + "WHERE tstz = TIMESTAMP '2020-01-15 10:30:45.000 UTC'"); + IOPlan ioPlan = jsonCodec(IOPlan.class).fromJson((String) getOnlyElement(result.getOnlyColumnAsSet())); + assertEquals(ioPlan.getInputTableColumnInfos().size(), 1); + TableColumnInfo tableInfo = ioPlan.getInputTableColumnInfos().iterator().next(); + + Optional tstzConstraint = tableInfo.getColumnConstraints().stream() + .filter(c -> c.getColumnName().equals("tstz")) + .findFirst(); + assertTrue(tstzConstraint.isPresent(), "Expected timestamp with time zone column constraint"); + String tstzValue = tstzConstraint.get().getDomain().getRanges().iterator().next().getLow().getValue().get(); + assertTrue(tstzValue.matches("^\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3} .+$"), + "Timestamp with time zone should be formatted 
as yyyy-MM-dd HH:mm:ss.SSS TZ but was: " + tstzValue); + } + finally { + assertUpdate("DROP TABLE test_tstz_io"); + } + } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java index e8910cc6b276b..6af8175848750 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedTestBase.java @@ -20,6 +20,7 @@ import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.common.type.FixedWidthType; +import com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeParameter; @@ -29,9 +30,13 @@ import com.facebook.presto.hive.HdfsContext; import com.facebook.presto.hive.HdfsEnvironment; import com.facebook.presto.hive.HiveClientConfig; +import com.facebook.presto.hive.HiveCompressionCodec; import com.facebook.presto.hive.HiveHdfsConfiguration; +import com.facebook.presto.hive.HiveStorageFormat; +import com.facebook.presto.hive.HiveType; import com.facebook.presto.hive.MetastoreClientConfig; import com.facebook.presto.hive.authentication.NoHdfsAuthentication; +import com.facebook.presto.hive.metastore.Column; import com.facebook.presto.hive.s3.HiveS3Config; import com.facebook.presto.hive.s3.PrestoS3ConfigurationUpdater; import com.facebook.presto.hive.s3.S3ConfigurationUpdater; @@ -45,6 +50,7 @@ import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.analyzer.MetadataResolver; import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorMetadata; +import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.security.AllowAllAccessControl; import 
com.facebook.presto.spi.statistics.ColumnStatistics; import com.facebook.presto.spi.statistics.ConnectorHistogram; @@ -69,7 +75,10 @@ import org.apache.hadoop.fs.Path; import org.apache.iceberg.BaseTable; import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.DataFile; +import org.apache.iceberg.DataFiles; import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Metrics; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; import org.apache.iceberg.Snapshot; @@ -106,13 +115,17 @@ import java.lang.reflect.Field; import java.net.URI; import java.nio.ByteBuffer; +import java.time.Instant; import java.time.LocalDateTime; import java.time.LocalTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -135,24 +148,34 @@ import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.RealType.REAL; import static com.facebook.presto.common.type.TimeZoneKey.UTC_KEY; +import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.SYNTHESIZED; import static com.facebook.presto.hive.HiveCommonSessionProperties.PARQUET_BATCH_READ_OPTIMIZATION_ENABLED; +import static com.facebook.presto.iceberg.CatalogType.HADOOP; import static com.facebook.presto.iceberg.FileContent.EQUALITY_DELETES; import static com.facebook.presto.iceberg.FileContent.POSITION_DELETES; +import static com.facebook.presto.iceberg.FileFormat.ORC; +import static com.facebook.presto.iceberg.FileFormat.PARQUET; import static com.facebook.presto.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; import 
static com.facebook.presto.iceberg.IcebergQueryRunner.getIcebergDataDirectoryPath; +import static com.facebook.presto.iceberg.IcebergSessionProperties.COMPRESSION_CODEC; import static com.facebook.presto.iceberg.IcebergSessionProperties.DELETE_AS_JOIN_REWRITE_ENABLED; +import static com.facebook.presto.iceberg.IcebergSessionProperties.DELETE_AS_JOIN_REWRITE_MAX_DELETE_COLUMNS; import static com.facebook.presto.iceberg.IcebergSessionProperties.PUSHDOWN_FILTER_ENABLED; import static com.facebook.presto.iceberg.IcebergSessionProperties.STATISTIC_SNAPSHOT_RECORD_DIFFERENCE_WEIGHT; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.INSERT_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.UPDATE_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.privilege; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; import static com.facebook.presto.testing.assertions.Assert.assertEquals; @@ -161,13 +184,17 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; import 
static java.nio.file.Files.createTempDirectory; +import static java.util.Locale.ENGLISH; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; import static java.util.UUID.randomUUID; import static java.util.function.Function.identity; import static org.apache.iceberg.SnapshotSummary.TOTAL_DATA_FILES_PROP; import static org.apache.iceberg.SnapshotSummary.TOTAL_DELETE_FILES_PROP; +import static org.apache.iceberg.types.Type.TypeID.TIME; import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_1_0; import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_2_0; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertNotNull; @@ -292,7 +319,7 @@ public void testDeleteWithPartitionSpecEvolution() } @Test - public void testRenamePartitionColumn() + public void testRenameIdentityPartitionColumn() { assertQuerySucceeds("create table test_partitioned_table(a int, b varchar) with (partitioning = ARRAY['a'])"); assertQuerySucceeds("insert into test_partitioned_table values(1, '1001'), (2, '1002')"); @@ -309,6 +336,120 @@ public void testRenamePartitionColumn() assertQuerySucceeds("DROP TABLE test_partitioned_table"); } + @Test(dataProvider = "fileFormat") + public void testQueryOnSchemaEvolution(String fileFormat) + { + String tableName = "test_query_on_schema_evolution_" + randomTableSuffix(); + assertUpdate("CREATE TABLE " + tableName + "(a int, b varchar) with (\"write.format.default\" = '" + fileFormat + "')"); + assertUpdate("INSERT INTO " + tableName + " VALUES(1, '1001'), (2, '1002')", 2); + assertUpdate("ALTER TABLE " + tableName + " RENAME COLUMN a to a2"); + assertQuery("SELECT * FROM " + tableName, "VALUES(1, '1001'), (2, '1002')"); + + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN a varchar"); + assertQuery("SELECT * FROM " + tableName, "VALUES(1, 
'1001', NULL), (2, '1002', NULL)"); + + assertUpdate("ALTER TABLE " + tableName + " DROP COLUMN a"); + assertQuery("SELECT * FROM " + tableName, "VALUES(1, '1001'), (2, '1002')"); + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN a int"); + assertQuery("SELECT * FROM " + tableName, "VALUES(1, '1001', NULL), (2, '1002', NULL)"); + + assertUpdate("DROP TABLE " + tableName); + } + + @DataProvider(name = "transforms") + public String[][] transforms() + { + return new String[][] { + {"a int", "a"}, + {"a int", "bucket(a, 3)"}, + {"a int", "bucket(a, 3)', 'a"}, + {"a int", "truncate(a, 2)"}, + {"a int", "truncate(a, 2)', 'a', 'bucket(a, 3)"} + }; + } + + @DataProvider(name = "dateTimeTransforms") + public String[][] dateTimeTransforms() + { + return new String[][] { + {"a timestamp", "year(a)"}, + {"a timestamp", "month(a)"}, + {"a timestamp", "day(a)"}, + {"a timestamp", "hour(a)"}, + {"a timestamp", "a', 'month(a)"} + }; + } + + @Test(dataProvider = "transforms") + public void testRenamePartitionColumn(String[] transform) + { + assertQuerySucceeds("DROP TABLE IF EXISTS test_partitioned_table"); + assertQuerySucceeds(format("create table test_partitioned_table(%s) with (partitioning = ARRAY['%s'])", transform[0], transform[1])); + assertQuerySucceeds("insert into test_partitioned_table values(1), (2)"); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where a = 1").getOnlyValue(), 1L); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where a = 2").getOnlyValue(), 1L); + + assertQuerySucceeds("alter table test_partitioned_table rename column a to d"); + assertQuerySucceeds("insert into test_partitioned_table values(1), (2), (3)"); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where d = 1").getOnlyValue(), 2L); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where d = 2").getOnlyValue(), 2L); + 
assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where d = 3").getOnlyValue(), 1L); + assertQueryFails("select a from test_partitioned_table", "line 1:8: Column 'a' cannot be resolved"); + + assertQuerySucceeds("alter table test_partitioned_table rename column d to e"); + assertQuerySucceeds("insert into test_partitioned_table values (3)"); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where e = 1").getOnlyValue(), 2L); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where e = 2").getOnlyValue(), 2L); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where e = 3").getOnlyValue(), 2L); + + assertQuerySucceeds("alter table test_partitioned_table rename column e to a"); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where a = 1").getOnlyValue(), 2L); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where a = 2").getOnlyValue(), 2L); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where a = 3").getOnlyValue(), 2L); + assertQuerySucceeds("insert into test_partitioned_table values (3)"); + assertEquals(getQueryRunner().execute("SELECT count(*) FROM test_partitioned_table where a = 3").getOnlyValue(), 3L); + assertQueryFails("select d from test_partitioned_table", "line 1:8: Column 'd' cannot be resolved"); + assertQueryFails("select e from test_partitioned_table", "line 1:8: Column 'e' cannot be resolved"); + assertQuerySucceeds("DROP TABLE test_partitioned_table"); + } + + @Test(dataProvider = "dateTimeTransforms") + public void testRenameDatetimePartitionColumn(String[] transform) + { + Session session = Session.builder(getSession()) + .setSystemProperty(LEGACY_TIMESTAMP, "false") + .build(); + assertQuerySucceeds("DROP TABLE IF EXISTS test_partitioned_table"); + assertQuerySucceeds(format("create table 
test_partitioned_table(%s) with (partitioning = ARRAY['%s'])", transform[0], transform[1])); + assertQuerySucceeds("insert into test_partitioned_table values(localtimestamp), (localtimestamp)"); + assertEquals(getQueryRunner().execute( + session, + "SELECT count(*) FROM test_partitioned_table where a <= localtimestamp").getOnlyValue(), 2L); + + assertQuerySucceeds("alter table test_partitioned_table rename column a to d"); + assertQuerySucceeds("insert into test_partitioned_table values(localtimestamp), (localtimestamp), (localtimestamp)"); + assertEquals(getQueryRunner().execute( + session, + "SELECT count(*) FROM test_partitioned_table where d <= localtimestamp").getOnlyValue(), 5L); + assertQueryFails("select a from test_partitioned_table", "line 1:8: Column 'a' cannot be resolved"); + + assertQuerySucceeds("alter table test_partitioned_table rename column d to e"); + assertQuerySucceeds("insert into test_partitioned_table values (localtimestamp)"); + assertEquals(getQueryRunner().execute( + session, + "SELECT count(*) FROM test_partitioned_table where e < localtimestamp").getOnlyValue(), 6L); + + assertQuerySucceeds("alter table test_partitioned_table rename column e to a"); + assertEquals(getQueryRunner().execute( + session, + "SELECT count(*) FROM test_partitioned_table where a < localtimestamp").getOnlyValue(), 6L); + assertQuerySucceeds("insert into test_partitioned_table values (localtimestamp)"); + assertEquals(getQueryRunner().execute(session, "SELECT count(*) FROM test_partitioned_table").getOnlyValue(), 7L); + assertQueryFails("select d from test_partitioned_table", "line 1:8: Column 'd' cannot be resolved"); + assertQueryFails("select e from test_partitioned_table", "line 1:8: Column 'e' cannot be resolved"); + assertQuerySucceeds("DROP TABLE test_partitioned_table"); + } + @Test public void testAddPartitionColumn() { @@ -535,10 +676,10 @@ public void testShowColumnsForPartitionedTable() MaterializedResult actual = computeActual("SHOW COLUMNS FROM 
show_columns_only_identity_partition"); - MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("id", "integer", "", "") - .row("name", "varchar", "", "") - .row("team", "varchar", "partition key", "") + MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("id", "integer", "", "", 10L, null, null) + .row("name", "varchar", "", "", null, null, 2147483647L) + .row("team", "varchar", "partition key", "", null, null, 2147483647L) .build(); assertEquals(actual, expectedParametrizedVarchar); @@ -551,10 +692,10 @@ public void testShowColumnsForPartitionedTable() actual = computeActual("SHOW COLUMNS FROM show_columns_with_non_identity_partition"); - expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("id", "integer", "", "") - .row("name", "varchar", "", "") - .row("team", "varchar", "partition by truncate[1], identity", "") + expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("id", "integer", "", "", 10L, null, null) + .row("name", "varchar", "", "", null, null, 2147483647L) + .row("team", "varchar", "partition by truncate[1], identity", "", null, null, 2147483647L) .build(); assertEquals(actual, expectedParametrizedVarchar); @@ -635,6 +776,21 @@ public void testCreateTableWithCustomLocation() } } + @Test + protected void testCreateTableAndValidateIcebergTableName() + { + String tableName = "test_create_table_for_validate_name"; + Session session = getSession(); + assertUpdate(session, format("CREATE TABLE %s (col1 INTEGER, aDate DATE)", tableName)); + Table icebergTable = loadTable(tableName); + + String catalog = session.getCatalog().get(); + String schemaName = session.getSchema().get(); + assertEquals(icebergTable.name(), catalog + "." + schemaName + "." 
+ tableName); + + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + @Test public void testPartitionedByTimeType() { @@ -1408,6 +1564,36 @@ public void testEqualityDeletesWithHiddenPartitionsEvolution(String fileFormat, assertQuery(session, "SELECT * FROM " + tableName, "VALUES (1, '1001', NULL, NULL), (3, '1003', NULL, NULL), (6, '1004', 1, NULL), (6, '1006', 2, 'th002')"); } + @Test(dataProvider = "equalityDeleteOptions") + public void testEqualityDeletesWithDataSequenceNumber(String fileFormat, boolean joinRewriteEnabled) + throws Exception + { + Session session = deleteAsJoinEnabled(joinRewriteEnabled); + String tableName = "test_v2_row_delete_" + randomTableSuffix(); + String tableName2 = "test_v2_row_delete_2_" + randomTableSuffix(); + assertUpdate("CREATE TABLE " + tableName + "(id int, data varchar) WITH (\"write.format.default\" = '" + fileFormat + "')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'a')", 1); + + assertUpdate("CREATE TABLE " + tableName2 + "(id int, data varchar) WITH (\"write.format.default\" = '" + fileFormat + "')"); + assertUpdate("INSERT INTO " + tableName2 + " VALUES (1, 'a')", 1); + + Table icebergTable = updateTable(tableName); + writeEqualityDeleteToNationTable(icebergTable, ImmutableMap.of("id", 1)); + + Table icebergTable2 = updateTable(tableName2); + writeEqualityDeleteToNationTable(icebergTable2, ImmutableMap.of("id", 1)); + + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'b'), (2, 'a'), (3, 'a')", 3); + assertUpdate("INSERT INTO " + tableName2 + " VALUES (1, 'b'), (2, 'a'), (3, 'a')", 3); + + assertQuery(session, "SELECT * FROM " + tableName, "VALUES (1, 'b'), (2, 'a'), (3, 'a')"); + + assertQuery(session, "SELECT \"$data_sequence_number\", * FROM " + tableName, "VALUES (3, 1, 'b'), (3, 2, 'a'), (3, 3, 'a')"); + + assertQuery(session, "SELECT a.\"$data_sequence_number\", b.\"$data_sequence_number\" from " + tableName + " as a, " + tableName2 + " as b where a.id = b.id", + "VALUES (3, 3), (3, 3), 
(3, 3)"); + } + @Test public void testPartShowStatsWithFilters() { @@ -1501,6 +1687,216 @@ public void testWithoutSortOrder() } } + @Test + public void testRewriteDataFilesWithSortOrder() + throws IOException + { + String tableName = "test_rewrite_data_with_sort_order_" + randomTableSuffix(); + String schema = getSession().getSchema().get(); + try { + assertUpdate("CREATE TABLE " + tableName + "(id int, emp_name varchar)"); + assertUpdate("INSERT INTO " + tableName + " VALUES (5, 'EEEE'), (3, 'CCCC'), (1, 'AAAA')", 3); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'BBBB'), (4,'DDDD')", 2); + assertUpdate("INSERT INTO " + tableName + " VALUES (9, 'CCCC'), (11,'FFFF')", 2); + + assertUpdate(format("CALL system.rewrite_data_files(schema => '%s', table_name => '%s', sorted_by => ARRAY['id'])", schema, tableName), 7); + MaterializedResult result = computeActual("SELECT file_path from \"" + tableName + "$files\""); + assertEquals(result.getOnlyColumnAsSet().size(), 1); + String filePath = String.valueOf(result.getOnlyValue()); + assertTrue(isFileSorted(filePath, "id", "ASC")); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + @Test + public void testRewriteDataFilesWithSortOrderOnPartitionedTables() + throws IOException + { + String tableName = "test_rewrite_data_with_sort_order_" + randomTableSuffix(); + String schema = getSession().getSchema().get(); + try { + assertUpdate("CREATE TABLE " + tableName + "(id int, emp_name varchar) with (partitioning = ARRAY['emp_name'])"); + assertUpdate("INSERT INTO " + tableName + " VALUES (5, 'AAAA'), (3, 'CCCC'), (1, 'BBBB')", 3); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'BBBB'), (4,'AAAA')", 2); + assertUpdate("INSERT INTO " + tableName + " VALUES (9, 'CCCC'), (11,'BBBB')", 2); + + assertUpdate(format("CALL system.rewrite_data_files(schema => '%s', table_name => '%s', sorted_by => ARRAY['id'])", schema, tableName), 7); + MaterializedResult result = computeActual("SELECT 
file_path from \"" + tableName + "$files\""); + assertEquals(result.getOnlyColumnAsSet().size(), 3); + for (Object filePath : result.getOnlyColumnAsSet()) { + assertTrue(isFileSorted(String.valueOf(filePath), "id", "ASC")); + } + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + @Test + public void testRewriteDataFilesWithDescSortOrder() + throws IOException + { + String tableName = "test_rewrite_data_with_sort_order_" + randomTableSuffix(); + String schema = getSession().getSchema().get(); + try { + assertUpdate("CREATE TABLE " + tableName + "(id int, emp_name varchar)"); + assertUpdate("INSERT INTO " + tableName + " VALUES (5, 'EEEE'), (3, 'CCCC'), (1, 'AAAA')", 3); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'BBBB'), (4,'DDDD')", 2); + assertUpdate("INSERT INTO " + tableName + " VALUES (9, 'CCCC'), (11,'FFFF')", 2); + + assertUpdate(format("CALL system.rewrite_data_files(schema => '%s', table_name => '%s', sorted_by => ARRAY['id DESC'])", schema, tableName), 7); + MaterializedResult result = computeActual("SELECT file_path from \"" + tableName + "$files\""); + assertEquals(result.getOnlyColumnAsSet().size(), 1); + String filePath = String.valueOf(result.getOnlyValue()); + assertTrue(isFileSorted(filePath, "id", "DESC")); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + @Test + public void testRewriteDataFilesWithDescSortOrderOnPartitionedTables() + throws IOException + { + String tableName = "test_rewrite_data_with_sort_order_" + randomTableSuffix(); + String schema = getSession().getSchema().get(); + try { + assertUpdate("CREATE TABLE " + tableName + "(id int, emp_name varchar) with (partitioning = ARRAY['emp_name'])"); + assertUpdate("INSERT INTO " + tableName + " VALUES (5, 'AAAA'), (3, 'CCCC'), (1, 'BBBB')", 3); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'BBBB'), (4,'AAAA')", 2); + assertUpdate("INSERT INTO " + tableName + " VALUES (9, 'CCCC'), (11,'BBBB')", 2); + + 
assertUpdate(format("CALL system.rewrite_data_files(schema => '%s', table_name => '%s', sorted_by => ARRAY['id DESC'])", schema, tableName), 7); + MaterializedResult result = computeActual("SELECT file_path from \"" + tableName + "$files\""); + assertEquals(result.getOnlyColumnAsSet().size(), 3); + for (Object filePath : result.getOnlyColumnAsSet()) { + assertTrue(isFileSorted(String.valueOf(filePath), "id", "DESC")); + } + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + @Test + public void testRewriteDataFilesWithCompatibleSortOrderForSortedTable() + throws IOException + { + String tableName = "test_rewrite_data_with_sort_order_" + randomTableSuffix(); + String schema = getSession().getSchema().get(); + try { + assertUpdate("CREATE TABLE " + tableName + "(id int, emp_name varchar) with (sorted_by = ARRAY['id DESC'])"); + assertUpdate("INSERT INTO " + tableName + " VALUES (5, 'EEEE'), (3, 'CCCC'), (1, 'AAAA')", 3); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'BBBB'), (4,'DDDD')", 2); + assertUpdate("INSERT INTO " + tableName + " VALUES (9, 'CCCC'), (11,'FFFF')", 2); + for (Object filePath : computeActual("SELECT file_path from \"" + tableName + "$files\"").getOnlyColumnAsSet()) { + assertTrue(isFileSorted(String.valueOf(filePath), "id", "DESC")); + } + + assertUpdate(format("CALL system.rewrite_data_files(schema => '%s', table_name => '%s', sorted_by => ARRAY['id DESC', 'emp_name ASC'])", schema, tableName), 7); + MaterializedResult result = computeActual("SELECT file_path from \"" + tableName + "$files\""); + assertEquals(result.getOnlyColumnAsSet().size(), 1); + String filePath = String.valueOf(result.getOnlyValue()); + assertTrue(isFileSorted(filePath, "id", "DESC")); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + @Test + public void testNotAllRewriteDataFilesWithIncompatibleSortOrderForSortedTable() + throws IOException + { + String tableName = "test_rewrite_data_with_sort_order_" + 
randomTableSuffix(); + String schema = getSession().getSchema().get(); + try { + assertUpdate("CREATE TABLE " + tableName + "(id int, emp_name varchar) with (sorted_by = ARRAY['id'])"); + assertUpdate("INSERT INTO " + tableName + " VALUES (5, 'EEEE'), (3, 'CCCC'), (1, 'AAAA')", 3); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'BBBB'), (4,'DDDD')", 2); + assertUpdate("INSERT INTO " + tableName + " VALUES (9, 'CCCC'), (11,'FFFF')", 2); + for (Object filePath : computeActual("SELECT file_path from \"" + tableName + "$files\"").getOnlyColumnAsSet()) { + assertTrue(isFileSorted(String.valueOf(filePath), "id", "ASC")); + } + + assertQueryFails(format("CALL system.rewrite_data_files(schema => '%s', table_name => '%s', sorted_by => ARRAY['id DESC'])", schema, tableName), + "Specified sort order is incompatible with the target table's internal sort order"); + + assertQueryFails(format("CALL system.rewrite_data_files(schema => '%s', table_name => '%s', sorted_by => ARRAY['emp_name ASC', 'id ASC'])", schema, tableName), + "Specified sort order is incompatible with the target table's internal sort order"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + @Test + public void testRewriteDataFilesWithFilterAndSortOrder() + throws IOException + { + String tableName = "test_rewrite_data_with_filter_and_sort_order_" + randomTableSuffix(); + String schema = getSession().getSchema().get(); + try { + assertUpdate("CREATE TABLE " + tableName + " (id int, emp_name varchar) with (partitioning = ARRAY['emp_name'])"); + + // Create multiple data files with mixed id values so that only a subset is rewritten + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'AAAAA'), (2, 'BBBBB'), (4, 'AAAAA')", 3); + assertUpdate("INSERT INTO " + tableName + " VALUES (4, 'BBBBB'), (0, 'BBBBB')", 2); + assertUpdate("INSERT INTO " + tableName + " VALUES (3, 'AAAAA'), (3, 'BBBBB')", 2); + + // Rewrite only rows with `emp_name = 'AAAAA'` and sort the rewritten 
data files by `id desc` + assertUpdate(format( + "CALL system.rewrite_data_files(" + + "schema => '%s', " + + "table_name => '%s', " + + "filter => 'emp_name = ''AAAAA''', " + + "sorted_by => ARRAY['id desc'])", + schema, tableName), 3); + + // Rewrite only rows with `emp_name = 'BBBBB'` and sort the rewritten data files by `id asc` + assertUpdate(format( + "CALL system.rewrite_data_files(" + + "schema => '%s', " + + "table_name => '%s', " + + "filter => 'emp_name = ''BBBBB''', " + + "sorted_by => ARRAY['id asc'])", + schema, tableName), 4); + + // All data is still present + assertQuery( + "SELECT id, emp_name FROM " + tableName, + "VALUES " + + "(1, 'AAAAA'), " + + "(2, 'BBBBB'), " + + "(4, 'AAAAA'), " + + "(4, 'BBBBB'), " + + "(0, 'BBBBB'), " + + "(3, 'AAAAA'), " + + "(3, 'BBBBB')"); + + // There are 2 data files after the rewriting + MaterializedResult result = computeActual("SELECT file_path from \"" + tableName + "$files\""); + List paths = result.getOnlyColumn().map(String::valueOf).distinct().toList(); + assertEquals(paths.size(), 2); + + // The data file under partition `emp_name = 'AAAAA'` is sorted by `id DESC` + List dataFileA = paths.stream().filter(str -> str.contains("AAAAA")).toList(); + assertEquals(dataFileA.size(), 1); + assertTrue(isFileSorted(String.valueOf(dataFileA.get(0)), "id", "DESC")); + + // The data file under partition `emp_name = 'BBBBB'` is sorted by `id ASC` + List dataFileB = paths.stream().filter(str -> str.contains("BBBBB")).toList(); + assertEquals(dataFileB.size(), 1); + assertTrue(isFileSorted(String.valueOf(dataFileB.get(0)), "id", "ASC")); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + public boolean isFileSorted(String path, String sortColumnName, String sortOrder) throws IOException { @@ -1783,6 +2179,52 @@ public void testMetadataVersionsMaintainingProperties() } } + @Test + public void testAlteringMetadataVersionsMaintainingProperties() + throws Exception + { + String alteringTableName = 
"test_table_with_altering_properties"; + try { + // Create a table with default table properties that maintain 100 previous metadata versions in current metadata, + // and do not automatically delete any metadata files + assertUpdate("CREATE TABLE " + alteringTableName + " (a INTEGER, b VARCHAR)"); + + assertUpdate("INSERT INTO " + alteringTableName + " VALUES (1, '1001'), (2, '1002')", 2); + assertUpdate("INSERT INTO " + alteringTableName + " VALUES (3, '1003'), (4, '1004')", 2); + assertUpdate("INSERT INTO " + alteringTableName + " VALUES (5, '1005'), (6, '1006')", 2); + assertUpdate("INSERT INTO " + alteringTableName + " VALUES (7, '1007'), (8, '1008')", 2); + assertUpdate("INSERT INTO " + alteringTableName + " VALUES (9, '1009'), (10, '1010')", 2); + + Table targetTable = loadTable(alteringTableName); + TableMetadata currentTableMetadata = ((BaseTable) targetTable).operations().current(); + // Target table's current metadata record all 5 previous metadata files + assertEquals(currentTableMetadata.previousFiles().size(), 5); + + FileSystem fileSystem = getHdfsEnvironment().getFileSystem(new HdfsContext(SESSION), new Path(targetTable.location())); + // Target table's all existing metadata files count is 6 + FileStatus[] settingTableFiles = fileSystem.listStatus(new Path(targetTable.location(), "metadata"), name -> name.getName().contains(METADATA_FILE_EXTENSION)); + assertEquals(settingTableFiles.length, 6); + + // Alter the table to set properties that maintain only 1 previous metadata version in current metadata, + // and delete unuseful metadata files after each commit + assertUpdate("ALTER TABLE " + alteringTableName + " SET PROPERTIES(\"write.metadata.previous-versions-max\" = 1, \"write.metadata.delete-after-commit.enabled\" = true)"); + assertUpdate("INSERT INTO " + alteringTableName + " VALUES (11, '1011'), (12, '1012')", 2); + assertUpdate("INSERT INTO " + alteringTableName + " VALUES (13, '1013'), (14, '1014')", 2); + + targetTable = 
loadTable(alteringTableName); + currentTableMetadata = ((BaseTable) targetTable).operations().current(); + // Table `test_table_with_altering_properties`'s current metadata only records 1 previous metadata file + assertEquals(currentTableMetadata.previousFiles().size(), 1); + + // Target table's all existing metadata files count is 2 + FileStatus[] defaultTableFiles = fileSystem.listStatus(new Path(targetTable.location(), "metadata"), name -> name.getName().contains(METADATA_FILE_EXTENSION)); + assertEquals(defaultTableFiles.length, 2); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + alteringTableName); + } + } + @DataProvider(name = "batchReadEnabled") public Object[] batchReadEnabledReader() { @@ -1820,6 +2262,117 @@ public void testDecimal(boolean decimalVectorReaderEnabled) } } + @Test + public void testMetadataDeleteOnV2MorTableWithRewriteDataFiles() + { + String tableName = "test_rewrite_data_files_table_" + randomTableSuffix(); + try { + // Create a table with partition column `a`, and insert some data under this partition spec + assertUpdate("CREATE TABLE " + tableName + " (a INTEGER, b VARCHAR) WITH (format_version = '2', delete_mode = 'merge-on-read')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, '1001'), (2, '1002')", 2); + assertUpdate("DELETE FROM " + tableName + " WHERE a = 1", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002')"); + + Table icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 1); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 1); + + // Evolve the partition spec by adding a partition column `c`, and insert some data under the new partition spec + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN c INTEGER WITH (partitioning = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (3, '1003', 3), (4, '1004', 4), (5, '1005', 5)", 3); + + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 4); + 
assertHasDeleteFiles(icebergTable.currentSnapshot(), 1); + + // Execute row level delete with filter on column `b` + assertUpdate("DELETE FROM " + tableName + " WHERE b = '1004'", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002', NULL), (3, '1003', 3), (5, '1005', 5)"); + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 4); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 2); + + assertQueryFails("call system.rewrite_data_files(table_name => '" + tableName + "', schema => 'tpch', filter => 'a > 3')", ".*"); + assertQueryFails("call system.rewrite_data_files(table_name => '" + tableName + "', schema => 'tpch', filter => 'c > 3')", ".*"); + + assertUpdate("call system.rewrite_data_files(table_name => '" + tableName + "', schema => 'tpch')", 3); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002', NULL), (3, '1003', 3), (5, '1005', 5)"); + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 3); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 0); + + // Do metadata delete on column `c`: after the rewrite, the remaining data files were written under the latest partition spec, which contains partition column `c` + assertUpdate("DELETE FROM " + tableName + " WHERE c = 5", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002', NULL), (3, '1003', 3)"); + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 2); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 0); + + assertUpdate("call system.rewrite_data_files(table_name => '" + tableName + "', schema => 'tpch', filter => 'c > 2')", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (2, '1002', NULL), (3, '1003', 3)"); + icebergTable = loadTable(tableName); + assertHasDataFiles(icebergTable.currentSnapshot(), 2); + assertHasDeleteFiles(icebergTable.currentSnapshot(), 0); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + @Test + public void testDropBranch() + { + 
assertUpdate("CREATE TABLE test_table_branch (id1 BIGINT, id2 BIGINT)"); + assertUpdate("INSERT INTO test_table_branch VALUES (0, 00), (1, 10)", 2); + + Table icebergTable = loadTable("test_table_branch"); + icebergTable.manageSnapshots().createBranch("testBranch1").commit(); + + assertUpdate("INSERT INTO test_table_branch VALUES (2, 30), (3, 30)", 2); + icebergTable.manageSnapshots().createBranch("testBranch2").commit(); + assertUpdate("INSERT INTO test_table_branch VALUES (4, 40), (5, 50)", 2); + assertEquals(icebergTable.refs().size(), 3); + + assertQuery("SELECT count(*) FROM test_table_branch FOR SYSTEM_VERSION AS OF 'testBranch1'", "VALUES 2"); + assertQuery("SELECT count(*) FROM test_table_branch FOR SYSTEM_VERSION AS OF 'testBranch2'", "VALUES 4"); + assertQuery("SELECT count(*) FROM test_table_branch FOR SYSTEM_VERSION AS OF 'main'", "VALUES 6"); + + assertQuerySucceeds("ALTER TABLE test_table_branch DROP BRANCH 'testBranch1'"); + icebergTable = loadTable("test_table_branch"); + assertEquals(icebergTable.refs().size(), 2); + assertQueryFails("ALTER TABLE test_table_branch DROP BRANCH 'testBranchNotExist'", "Branch testBranchNotExist doesn't exist in table test_table_branch"); + assertQuerySucceeds("ALTER TABLE test_table_branch DROP BRANCH IF EXISTS 'testBranch2'"); + assertQuerySucceeds("ALTER TABLE test_table_branch DROP BRANCH IF EXISTS 'testBranchNotExist'"); + assertQuerySucceeds("DROP TABLE test_table_branch"); + } + + @Test + public void testDropTag() + { + assertUpdate("CREATE TABLE test_table_tag (id1 BIGINT, id2 BIGINT)"); + assertUpdate("INSERT INTO test_table_tag VALUES (0, 00), (1, 10)", 2); + + Table icebergTable = loadTable("test_table_tag"); + icebergTable.manageSnapshots().createTag("testTag1", icebergTable.currentSnapshot().snapshotId()).commit(); + + assertUpdate("INSERT INTO test_table_tag VALUES (2, 30), (3, 30)", 2); + icebergTable.manageSnapshots().createTag("testTag2", icebergTable.currentSnapshot().snapshotId()).commit(); + 
assertUpdate("INSERT INTO test_table_tag VALUES (4, 40), (5, 50)", 2); + assertEquals(icebergTable.refs().size(), 3); + + assertQuery("SELECT count(*) FROM test_table_tag FOR SYSTEM_VERSION AS OF 'testTag1'", "VALUES 2"); + assertQuery("SELECT count(*) FROM test_table_tag FOR SYSTEM_VERSION AS OF 'testTag2'", "VALUES 4"); + assertQuery("SELECT count(*) FROM test_table_tag FOR SYSTEM_VERSION AS OF 'main'", "VALUES 6"); + + assertQuerySucceeds("ALTER TABLE test_table_tag DROP TAG 'testTag1'"); + icebergTable = loadTable("test_table_tag"); + assertEquals(icebergTable.refs().size(), 2); + assertQueryFails("ALTER TABLE test_table_tag DROP TAG 'testTagNotExist'", "Tag testTagNotExist doesn't exist in table test_table_tag"); + assertQuerySucceeds("ALTER TABLE test_table_tag DROP TAG IF EXISTS 'testTag2'"); + assertQuerySucceeds("ALTER TABLE test_table_tag DROP TAG IF EXISTS 'testTagNotExist'"); + assertQuerySucceeds("DROP TABLE test_table_tag"); + } + @Test public void testRefsTable() { @@ -1872,6 +2424,75 @@ public void testRefsTable() assertQuery("SELECT * FROM test_table_references FOR SYSTEM_VERSION AS OF 'testTag' where id1=1", "VALUES(1, NULL)"); } + @Test + public void testMetadataLogTable() + { + try { + assertUpdate("CREATE TABLE test_table_metadatalog (id1 BIGINT, id2 BIGINT)"); + assertQuery("SELECT count(*) FROM \"test_table_metadatalog$metadata_log_entries\"", "VALUES 1"); + //metadata file created at table creation + assertQuery("SELECT latest_snapshot_id FROM \"test_table_metadatalog$metadata_log_entries\"", "VALUES NULL"); + + assertUpdate("INSERT INTO test_table_metadatalog VALUES (0, 00), (1, 10), (2, 20)", 3); + Table icebergTable = loadTable("test_table_metadatalog"); + Snapshot latestSnapshot = icebergTable.currentSnapshot(); + assertQuery("SELECT count(*) FROM \"test_table_metadatalog$metadata_log_entries\"", "VALUES 2"); + assertQuery("SELECT latest_snapshot_id FROM \"test_table_metadatalog$metadata_log_entries\" order by timestamp DESC limit 1", 
"values " + latestSnapshot.snapshotId()); + } + finally { + assertUpdate("DROP TABLE IF EXISTS test_table_metadatalog"); + } + } + + @DataProvider(name = "timezoneId") + public Object[][] getTimezonesId() + { + return new Object[][]{{"UTC"}, {"America/Los_Angeles"}, {"Asia/Shanghai"}, {"Asia/Kolkata"}, {"America/Bahia_Banderas"}, {"Europe/Brussels"}}; + } + + @Test(dataProvider = "timezoneId") + public void testMetadataLogTableWithTimeZoneId(String zoneId) + { + try { + Session sessionForTimeZone = Session.builder(getSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey(zoneId)).build(); + + assertUpdate(sessionForTimeZone, "CREATE TABLE test_table_metadatalog_tz_id (id1 BIGINT, id2 BIGINT)"); + assertQuery(sessionForTimeZone, "SELECT count(*) FROM \"test_table_metadatalog_tz_id$metadata_log_entries\"", "VALUES 1"); + assertQuery(sessionForTimeZone, "SELECT latest_snapshot_id FROM \"test_table_metadatalog_tz_id$metadata_log_entries\"", "VALUES NULL"); + Table icebergTable = loadTable("test_table_metadatalog_tz_id"); + TableMetadata tableMetadata = ((BaseTable) icebergTable).operations().current(); + ZonedDateTime zonedDateTime1 = Instant.ofEpochMilli(tableMetadata.lastUpdatedMillis()) + .atZone(ZoneId.of(zoneId)); + //metadata file created at table creation + String metadataFileLocation1 = tableMetadata.metadataFileLocation(); + + assertUpdate("INSERT INTO test_table_metadatalog_tz_id VALUES (0, 00), (1, 10), (2, 20)", 3); + icebergTable = loadTable("test_table_metadatalog_tz_id"); + tableMetadata = ((BaseTable) icebergTable).operations().current(); + ZonedDateTime zonedDateTime2 = Instant.ofEpochMilli(tableMetadata.lastUpdatedMillis()) + .atZone(ZoneId.of(zoneId)); + //metadata file created after table insertion + String metadataFileLocation2 = tableMetadata.metadataFileLocation(); + + Snapshot latestSnapshot = icebergTable.currentSnapshot(); + assertQuery("SELECT count(*) FROM \"test_table_metadatalog_tz_id$metadata_log_entries\"", "VALUES 2"); + 
assertQuery("SELECT latest_snapshot_id FROM \"test_table_metadatalog_tz_id$metadata_log_entries\" order by timestamp DESC limit 1", "values " + latestSnapshot.snapshotId()); + + MaterializedResult actual = getQueryRunner().execute(sessionForTimeZone, "SELECT * FROM \"test_table_metadatalog_tz_id$metadata_log_entries\""); + assertThat(actual).hasSize(2); + MaterializedResult expected = resultBuilder(getSession(), TIMESTAMP_WITH_TIME_ZONE, VARCHAR, BIGINT, INTEGER, BIGINT) + .row(zonedDateTime1, metadataFileLocation1, null, null, null) + .row(zonedDateTime2, metadataFileLocation2, latestSnapshot.snapshotId(), latestSnapshot.schemaId(), latestSnapshot.sequenceNumber()) + .build(); + + assertEquals(actual, expected); + } + finally { + assertUpdate("DROP TABLE IF EXISTS test_table_metadatalog_tz_id"); + } + } + @Test public void testAllIcebergType() { @@ -2001,6 +2622,69 @@ public void testHiddenColumns() testDataSequenceNumberHiddenColumn(); } + @Test + public void testDeleteWithSpecialCharacterColumnName() + { + assertUpdate("CREATE TABLE test_special_character_column_name (\"\" int, name varchar)"); + assertUpdate("INSERT INTO test_special_character_column_name VALUES (1, 'abc'), (2, 'def'), (3, 'ghi')", 3); + assertUpdate("DELETE FROM test_special_character_column_name where \"\" = 2", 1); + assertUpdate("DROP TABLE IF EXISTS test_special_character_column_name"); + } + + @Test + public void testDeletedHiddenColumn() + { + assertUpdate("DROP TABLE IF EXISTS test_deleted"); + assertUpdate("CREATE TABLE test_deleted AS SELECT * FROM tpch.tiny.region WHERE regionkey=0", 1); + assertUpdate("INSERT INTO test_deleted SELECT * FROM tpch.tiny.region WHERE regionkey=1", 1); + + assertQuery("SELECT \"$deleted\" FROM test_deleted", format("VALUES %s, %s", "false", "false")); + + assertUpdate("DELETE FROM test_deleted WHERE regionkey=1", 1); + assertEquals(computeActual("SELECT * FROM test_deleted").getRowCount(), 1); + assertQuery("SELECT \"$deleted\" FROM test_deleted ORDER BY 
\"$deleted\"", format("VALUES %s, %s", "false", "true")); + } + + @Test + public void testDeleteFilePathHiddenColumn() + { + assertUpdate("DROP TABLE IF EXISTS test_delete_file_path"); + assertUpdate("CREATE TABLE test_delete_file_path AS SELECT * FROM tpch.tiny.region WHERE regionkey=0", 1); + assertUpdate("INSERT INTO test_delete_file_path SELECT * FROM tpch.tiny.region WHERE regionkey=1", 1); + + assertQuery("SELECT \"$delete_file_path\" FROM test_delete_file_path", format("VALUES %s, %s", "NULL", "NULL")); + + assertUpdate("DELETE FROM test_delete_file_path WHERE regionkey=1", 1); + assertEquals(computeActual("SELECT * FROM test_delete_file_path").getRowCount(), 1); + assertEquals(computeActual("SELECT \"$delete_file_path\" FROM test_delete_file_path").getRowCount(), 2); + + assertUpdate("DELETE FROM test_delete_file_path WHERE regionkey=0", 1); + computeActual("SELECT \"$delete_file_path\" FROM test_delete_file_path").getMaterializedRows().forEach(row -> { + assertEquals(row.getFieldCount(), 1); + assertNotNull(row.getField(0)); + }); + } + + @Test(dataProvider = "equalityDeleteOptions") + public void testEqualityDeletesWithDeletedHiddenColumn(String fileFormat, boolean joinRewriteEnabled) + throws Exception + { + Session session = deleteAsJoinEnabled(joinRewriteEnabled); + String tableName = "test_v2_row_delete_" + randomTableSuffix(); + assertUpdate("CREATE TABLE " + tableName + "(id int, data varchar) WITH (\"write.format.default\" = '" + fileFormat + "')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'a')", 1); + + Table icebergTable = updateTable(tableName); + writeEqualityDeleteToNationTable(icebergTable, ImmutableMap.of("id", 1)); + + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'b'), (2, 'a'), (3, 'a')", 3); + + assertQuery(session, "SELECT * FROM " + tableName, "VALUES (1, 'b'), (2, 'a'), (3, 'a')"); + + assertQuery(session, "SELECT \"$deleted\", * FROM " + tableName, + "VALUES (true, 1, 'a'), (false, 1, 'b'), (false, 2, 'a'), 
(false, 3, 'a')"); + } + @DataProvider(name = "pushdownFilterEnabled") public Object[][] pushdownFilterEnabledProvider() { @@ -2336,12 +3020,44 @@ public void testInformationSchemaQueries() } @Test - public void testUpdateWithPredicates() + public void testUpdateWithDuplicateValues() { - String tableName = "test_update_predicates_" + randomTableSuffix(); - assertUpdate("CREATE TABLE " + tableName + "(id int, full_name varchar(20))"); - assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'aaaa'), (2, 'bbbb'), (3, 'cccc')", 3); - // update single row on id + String tableName = "test_update_duplicate_values_" + randomTableSuffix(); + assertUpdate("CREATE TABLE " + tableName + "(id int, column1 varchar(10), column2 varchar(10), column3 int)"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'a', 'a', 1), (2, 'b', 'b', 1), (3, 'c', 'c', 1)", 3); + + // update single row with duplicate values + assertUpdate("UPDATE " + tableName + " SET column1 = CAST(1 as varchar), column2 = CAST(1 as varchar), column3 = 11 WHERE id = 1", 1); + assertQuery("SELECT id, column1, column2, column3 FROM " + tableName, "VALUES (1, '1', '1', 11), (2, 'b', 'b', 1), (3, 'c', 'c', 1)"); + } + + @Test + public void testUpdateWithDifferentCaseColumnNames() + { + String tableName = "test_update_case_" + randomTableSuffix(); + try { + assertUpdate("CREATE TABLE " + tableName + " (id INT, str1 VARCHAR)"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'a')", 1); + + assertUpdate("UPDATE " + tableName + " SET id = 11 WHERE id = 1", 1); + assertQuery("SELECT id FROM " + tableName, "VALUES 11"); + + assertUpdate("UPDATE " + tableName + " SET ID = 111 WHERE ID = 11", 1); + assertQuery("SELECT ID FROM " + tableName, "VALUES 111"); + assertQuery("SELECT id FROM " + tableName, "VALUES 111"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Test + public void testUpdateWithPredicates() + { + String tableName = "test_update_predicates_" + randomTableSuffix(); + 
assertUpdate("CREATE TABLE " + tableName + "(id int, full_name varchar(20))"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'aaaa'), (2, 'bbbb'), (3, 'cccc')", 3); + // update single row on id assertUpdate("UPDATE " + tableName + " SET full_name = 'aaaa AAAA' WHERE id = 1", 1); assertQuery("SELECT id, full_name FROM " + tableName, "VALUES (1, 'aaaa AAAA'), (2, 'bbbb'), (3, 'cccc')"); @@ -2426,6 +3142,1012 @@ public void testUpdateOnPartitionTable() assertQuery("SELECT a, b FROM " + tableName, "VALUES (3,'first'), (4,'4th'), (3,'third')"); } + @DataProvider + public Object[][] partitionedProvider() + { + return new Object[][] { + {""}, // Without partitions. + {"WITH (partitioning = ARRAY['address'])"} + }; + } + + @Test(dataProvider = "partitionedProvider") + public void testMergeSimpleQuery(String partitioning) + { + String targetTable = "merge_query_" + randomTableSuffix(); + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) %s", targetTable, partitioning)); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')) AS s(customer, purchases, address) " + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP 
TABLE " + targetTable); + } + } + + @Test + public void testMergeSimpleQueryPartitioned() + { + String targetTable = "merge_simple_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['customer'])", targetTable)); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(SELECT * FROM (VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville'))) AS s(customer, purchases, address) " + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeWithoutTablesAliases() + { + String targetTable = "test_without_aliases_target_" + randomTableSuffix(); + String sourceTable = "test_without_aliases_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + assertUpdate(format("INSERT INTO %s 
(customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s USING %s ", targetTable, sourceTable) + + format("ON (%s.customer = %s.customer) ", targetTable, sourceTable) + + format("WHEN MATCHED THEN" + + " UPDATE SET purchases = %s.purchases + %s.purchases, address = %s.address ", sourceTable, targetTable, sourceTable) + + format("WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(%s.customer, %s.purchases, %s.address)", sourceTable, sourceTable, sourceTable); + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeUsingUpdateAndInsert() + { + String targetTable = "merge_simple_target_" + randomTableSuffix(); + String sourceTable = "merge_simple_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " 
UPDATE SET purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Ed', 7, 'Etherville'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @DataProvider + public Object[][] mergeIncludeWhenAndWhenNotMatchedProvider() + { + return new Object[][] { + {true}, + {false}, + }; + } + + @Test(dataProvider = "mergeIncludeWhenAndWhenNotMatchedProvider") + public void testMergeOnlyInsertNewRows(boolean includeWhenMatched) + { + // This test verifies that the MERGE command works correctly when no rows in the source table meet the MERGE condition. + // It means that the MERGE command will behave as an INSERT command. + String targetTable = "merge_inserts_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena')", targetTable), 2); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')) AS s(customer, purchases, address)" + + "ON (t.customer = s.customer)" + + (includeWhenMatched ? 
+ "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address " : "") + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 2); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test(dataProvider = "mergeIncludeWhenAndWhenNotMatchedProvider") + public void testMergeOnlyUpdateExistingRows(boolean includeWhenNotMatched) + { + // This test verifies that the MERGE command works correctly when all rows in the source table meet the MERGE condition. + // It means that the MERGE command will behave as an UPDATE command. + String targetTable = "merge_all_columns_updated_target_" + randomTableSuffix(); + String sourceTable = "merge_all_columns_updated_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable), 4); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire'), ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville')", sourceTable), 3); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address " + + (includeWhenNotMatched ? 
+ "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)" : ""); + + assertUpdate(sqlMergeCommand, 3); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Dave_updated', 22, 'Darbyshire'), ('Aaron_updated', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol_updated', 12, 'Centreville')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeEmptyTargetTable() + { + String targetTable = "merge_inserts_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')) AS s(customer, purchases, address)" + + "ON (t.customer = s.customer)" + + "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 2); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire')"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeEmptySourceTable() + { + String targetTable = "merge_all_columns_updated_target_" + randomTableSuffix(); + String sourceTable = "merge_all_columns_updated_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 
7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_updated'), purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 0); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @DataProvider + public Object[][] partitionedAndBucketedProvider() + { + return new Object[][] { + {""}, // Without partitions. + {"WITH (partitioning = ARRAY['customer'])"}, + {"WITH (partitioning = ARRAY['purchases'])"}, + {"WITH (partitioning = ARRAY['bucket(customer, 3)'])"}, + {"WITH (partitioning = ARRAY['bucket(purchases, 4)'])"}, + }; + } + + @Test(dataProvider = "partitionedAndBucketedProvider") + public void testMergeUsingSelectQuery(String partitioning) + { + String targetTable = "merge_various_target_" + randomTableSuffix(); + String sourceTable = "merge_various_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases VARCHAR) %s", targetTable, partitioning)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3); + assertUpdate(format("INSERT INTO %s (customer, purchases) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable), 3); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING (SELECT 
customer, purchases FROM %s) s ", targetTable, sourceTable) + + "ON (t.purchases = s.purchases) " + + "WHEN MATCHED THEN" + + " UPDATE SET customer = CONCAT(t.customer, '_', s.customer) " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases) VALUES(s.customer, s.purchases)"; + + assertUpdate(sqlMergeCommand, 3); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Dave', 'dates'), ('Carol_Craig', 'candles'), ('Lou_Len', 'limes'), ('Joe', 'jellybeans')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test(dataProvider = "partitionedAndBucketedProvider") + public void testMultipleMergeCommands(String partitioning) + { + int targetCustomerCount = 32; + String targetTable = "merge_multiple_" + randomTableSuffix(); + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR) %s", targetTable, partitioning)); + + // joe_1, 1000, 91000, jan_1, 1 Poe Ct + // ... + // joe_15, 1000, 91000, jan_15, 15 Poe Ct + String originalInsertFirstHalf = IntStream.range(1, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 1000, 91000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + // joe_16, 2000, 92000, jan_16, 16 Poe Ct + // ... + // joe_32, 2000, 92000, jan_32, 32 Poe Ct + String originalInsertSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 2000, 92000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) " + + "VALUES %s, %s", targetTable, originalInsertFirstHalf, originalInsertSecondHalf), targetCustomerCount - 1); + + // joe_16, 3000, 83000, jan_16, 16 Eop Ct + // ... 
+ // joe_32, 3000, 83000, jan_32, 32 Eop Ct + String firstMergeSource = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jill_%s', '%s Eop Ct')", intValue, 3000, 83000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchases, zipcode, spouse, address)", targetTable, firstMergeSource) + + "ON t.customer = s.customer " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases, zipcode = s.zipcode, spouse = s.spouse, address = s.address"; + + assertUpdate(sqlMergeCommand, targetCustomerCount / 2); + + assertQuery( + format("SELECT customer, purchases, zipcode, spouse, address FROM %s", targetTable), + format("VALUES %s, %s", originalInsertFirstHalf, firstMergeSource)); + + // jack_32, 4000, 74000, jan_32, 32 Poe Ct + // ... + // jack_48, 4000, 74000, jan_48, 48 Poe Ct + String nextInsert = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('jack_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 4000, 74000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertUpdate(format("INSERT INTO %s (customer, purchases, zipcode, spouse, address) VALUES %s", targetTable, nextInsert), targetCustomerCount / 2); + + // joe_1, 5000, 85000, jen_32, 32 Poe Ct + // ... + // joe_48, 5000, 85000, jen_48, 48 Poe Ct + String secondMergeSource = IntStream.range(1, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + // Note that the following MERGE INTO does not update the "purchases" column. 
+ sqlMergeCommand = + format("MERGE INTO %s t USING (VALUES %s) AS s(customer, purchases, zipcode, spouse, address)", targetTable, secondMergeSource) + + "ON t.customer = s.customer " + + "WHEN MATCHED THEN" + + " UPDATE SET zipcode = s.zipcode, spouse = s.spouse, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, zipcode, spouse, address) VALUES(s.customer, s.purchases, s.zipcode, s.spouse, s.address)"; + + assertUpdate(sqlMergeCommand, targetCustomerCount * 3 / 2 - 1); + + // joe_1, 1000, 85000, jen_1, 1 Poe Ct + // ... + // joe_15, 1000, 85000, jen_15, 15 Poe Ct + String updatedFirstHalf = IntStream.range(1, targetCustomerCount / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 1000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + // joe_16, 3000, 85000, jen_16, 16 Poe Ct + // ... + // joe_32, 3000, 85000, jen_32, 32 Poe Ct + String updatedSecondHalf = IntStream.range(targetCustomerCount / 2, targetCustomerCount) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 3000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + // jack_32, 4000, 74000, jan_32, 32 Poe Ct + // ... + // jack_48, 4000, 74000, jan_48, 48 Poe Ct + String nonUpdatedRows = nextInsert; + + // joe_32, 5000, 85000, jen_32, 32 Poe Ct + // ... 
+ // joe_48, 5000, 85000, jen_48, 48 Poe Ct + String insertedRows = IntStream.range(targetCustomerCount, targetCustomerCount * 3 / 2) + .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jen_%s', '%s Poe Ct')", intValue, 5000, 85000, intValue, intValue)) + .collect(Collectors.joining(", ")); + + assertQuery( + format("SELECT customer, purchases, zipcode, spouse, address FROM %s", targetTable), + format("VALUES %s, %s, %s, %s", updatedFirstHalf, updatedSecondHalf, nonUpdatedRows, insertedRows)); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeMillionRows() + { + String tableName = "test_merge_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (orderkey BIGINT, custkey BIGINT, totalprice DOUBLE)", tableName)); + + // Initialize the merge target table with data: + // When "mod(orderkey, 3) = 0" -> copy rows, when "mod(orderkey, 3) = 1" -> double price, when "mod(orderkey, 3) = 2" -> rows with new orderkey + assertUpdate( + format("INSERT INTO %s " + + "SELECT orderkey, custkey, totalprice FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 0 " + // rows copied + "UNION ALL " + + "SELECT orderkey, custkey, 2*totalprice as totalprice FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 1 " + // rows with updated price + "UNION ALL " + + "SELECT orderkey + 100000002 as orderkey, custkey, totalprice as totalprice FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 2", // rows with new orderkey + tableName), + (long) computeActual("SELECT count(*) FROM tpch.sf1.orders").getOnlyValue()); + + // verify copied rows: same total price + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 0", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 0"); + + // verify rows will be updated: double total price + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE 
mod(orderkey, 3) = 1", + "SELECT count(*), round(2*sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 1"); + + // verify rows will be inserted: same total price and different orderkey. + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 2", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 2"); + + // MERGE INTO command to update the price of the existing orders and insert new orders, multiplying the original price by 3. + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING (SELECT * FROM tpch.sf1.orders) s ", tableName) + + "ON (t.orderkey = s.orderkey) " + + "WHEN MATCHED THEN" + + " UPDATE SET totalprice = s.totalprice " + + "WHEN NOT MATCHED THEN" + + " INSERT (orderkey, custkey, totalprice) VALUES (s.orderkey, s.custkey, 3*s.totalprice)"; + + assertUpdate(sqlMergeCommand, 1_500_000); + + // verify unmodified rows: same total price + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 0", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 0"); + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 2 AND orderkey > 100000002", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 2"); + + // verify updated rows: same total price (these rows originally had double total price in the target table) + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 1", + "SELECT count(*), round(sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 1"); + + // verify inserted rows: triple original price + assertQueryWithSameQueryRunner( + "SELECT count(*), round(sum(totalprice)) FROM " + tableName + " WHERE mod(orderkey, 3) = 2 AND orderkey < 
100000002", + "SELECT count(*), round(3*sum(totalprice)) FROM tpch.sf1.orders WHERE mod(orderkey, 3) = 2"); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + + @Test + public void testMergeQueryWithWeirdColumnsCapitalization() + { + String targetTable = "merge_weird_capitalization_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable.toUpperCase(ENGLISH)) + + "(VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')) AS s(customer, purchases, address) " + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET purCHases = s.PurchaseS + t.pUrchases, aDDress = s.addrESs " + + "WHEN NOT MATCHED THEN" + + " INSERT (CUSTOMER, purchases, addRESS) VALUES(s.custoMer, s.Purchases, s.ADDress)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeWithMultipleConditions() + { + String targetTable = "merge_predicates_target_" + randomTableSuffix(); + String sourceTable = "merge_predicates_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (1, 'Dave', 
10, 'Devon'), (2, 'Dave', 20, 'Darbyshire')", targetTable), 2); + assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (3, 'Dave', 2, 'Madrid'), (4, 'Dave', 15, 'Barcelona')", sourceTable), 2); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON t.customer = s.customer AND s.purchases < 6 " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases + t.purchases, address = concat(t.address, '/', s.address) " + + "WHEN NOT MATCHED THEN" + + " INSERT (id, customer, purchases, address) VALUES (s.id, s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 3); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES (1, 'Dave', 12, 'Devon/Madrid'), (2, 'Dave', 22, 'Darbyshire/Madrid'), (4, 'Dave', 15, 'Barcelona')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeCasts() + { + String targetTable = "merge_cast_target_" + randomTableSuffix(); + String sourceTable = "merge_cast_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (col1 INT, col2 BIGINT, col3 REAL, col4 DOUBLE, col5 DOUBLE)", targetTable)); + assertUpdate(format("CREATE TABLE %s (col1 INT, col2 INT, col3 INT, col4 INT, col5 REAL)", sourceTable)); + + assertUpdate(format("INSERT INTO %s VALUES (1, 2, 3, 4, 5)", targetTable), 1); + assertUpdate(format("INSERT INTO %s VALUES (2, 3, 4, 5, 6)", sourceTable), 1); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.col1 + 1 = s.col1) " + // Note that the merge condition contains a sum. 
+ "WHEN MATCHED THEN" + + " UPDATE SET col1 = s.col1, col2 = s.col2, col3 = s.col3, col4 = s.col4, col5 = s.col5"; + + assertUpdate(sqlMergeCommand, 1); + + assertQuery("SELECT * FROM " + targetTable, "VALUES (2, 3, 4.0, 5.0, 6.0)"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeSubqueries() + { + String targetTable = "merge_nation_target_" + randomTableSuffix(); + String sourceTable = "merge_nation_source_" + randomTableSuffix(); + + try { + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (nation_name, region_name) VALUES ('GERMANY', 'EUROPE'), ('ALGERIA', 'AFRICA'), ('FRANCE', 'EUROPE')", targetTable), 3); + assertUpdate(format("INSERT INTO %s VALUES ('ALGERIA', 'AFRICA'), ('FRANCE', 'EUROPE'), ('EGYPT', 'MIDDLE EAST'), ('RUSSIA', 'EUROPE')", sourceTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.nation_name = s.nation_name) " + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = (SELECT CONCAT(name, '_UPDATED') FROM tpch.tiny.region WHERE name = t.region_name) " + + "WHEN NOT MATCHED THEN" + + " INSERT VALUES(s.nation_name, (SELECT CONCAT(name, '_INSERTED') FROM tpch.tiny.region WHERE name = s.region_name))"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('GERMANY', 'EUROPE'), " + + "('ALGERIA', 'AFRICA_UPDATED'), ('FRANCE', 'EUROPE_UPDATED'), " + + "('EGYPT', 'MIDDLE EAST_INSERTED'), ('RUSSIA', 'EUROPE_INSERTED')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @DataProvider + public Object[][] partitionedBucketedFailure() + { + return new Object[][] { + {"CREATE TABLE %s 
(customer VARCHAR, purchases INT, address VARCHAR)"}, + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['customer'])"}, + {"CREATE TABLE %s (customer VARCHAR, address VARCHAR, purchases INT) WITH (partitioning = ARRAY['address'])"}, + {"CREATE TABLE %s (purchases INT, customer VARCHAR, address VARCHAR) WITH (partitioning = ARRAY['customer', 'address'])"}, + {"CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) WITH (partitioning = ARRAY['bucket(customer, 3)'])"} + }; + } + + @Test(dataProvider = "partitionedBucketedFailure") + public void testMergeMultipleRowsMatchMustFails(String createTableSql) + { + String targetTable = "merge_multiple_rows_match_target_" + randomTableSuffix(); + String sourceTable = "merge_multiple_rows_match_source_" + randomTableSuffix(); + + try { + assertUpdate(format(createTableSql, targetTable)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Antioch')", targetTable), 2); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Adelphi'), ('Aaron', 8, 'Ashland')", sourceTable), 2); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET address = s.address"; + + assertQueryFails(sqlMergeCommand, ".*The MERGE INTO command requires each target row to match at most one source row.*"); + + assertUpdate(format("DELETE FROM %s WHERE purchases = 8", sourceTable), 1); + + assertUpdate(sqlMergeCommand, 1); + + assertQuery("SELECT customer, purchases, address FROM " + targetTable, + "VALUES ('Aaron', 5, 'Adelphi'), ('Bill', 7, 'Antioch')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + 
} + + private void createNationRegionTable(String targetTable) + { + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR NOT NULL)", targetTable)); + } + + @Test + public void testMergeNonNullableColumns() + { + String targetTable = "merge_non_nullable_target_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + assertUpdate(format("INSERT INTO %s (nation_name, region_name) VALUES ('FRANCE', 'EUROPE'), ('ALGERIA', 'AFRICA'), ('GERMANY', 'EUROPE')", targetTable), 3); + + List sqlMergeCommands = Arrays.asList( + // Command to check that updating using a null value fails. + format("MERGE INTO %s t ", targetTable) + + "USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) " + + "ON (t.nation_name = s.nation_name)\n" + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = NULL", + + // Command to check that inserting using a null value fails. + format("MERGE INTO %s t ", targetTable) + + " USING (VALUES ('ANGOLA', 'AFRICA')) s(nation_name, region_name) " + + "ON (t.nation_name = s.nation_name) " + + "WHEN NOT MATCHED THEN" + + " INSERT (nation_name, region_name) VALUES (s.nation_name, NULL)", + + // Command to check that inserting using an implicit null value fails. + format("MERGE INTO %s t ", targetTable) + + "USING (VALUES ('ANGOLA', 'AFRICA')) s(nation_name, region_name) " + + "ON (t.nation_name = s.nation_name) " + + "WHEN NOT MATCHED THEN" + + " INSERT (nation_name) VALUES ('CANADA')", + + // Command to check that if the updated value is provided by a function unpredictably computing null, the merge fails. + format("MERGE INTO %s t ", targetTable) + + "USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) " + + "ON (t.nation_name = s.nation_name) " + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = CAST(TRY(5/0) AS VARCHAR)"); + + for (@Language("SQL") String sqlMergeCommand : sqlMergeCommands) { + assertQueryFails(sqlMergeCommand, "NULL value not allowed for NOT NULL column. 
Table: merge_non_nullable_target_.* Column: region_name"); + } + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @DataProvider + public Object[][] targetAndSourceWithDifferentPartitioning() + { + return new Object[][] { + { + "target_flat_source_flat", + "", + "" + }, + { + "target_partitioned_source_flat", + "WITH (partitioning = ARRAY['customer'])", + "" + }, + { + "target_bucketed_source_flat", + "WITH (partitioning = ARRAY['bucket(customer, 3)'])", + "" + }, + { + "target_partitioned_and_bucketed_source_flat", + "WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + "" + }, + { + "target_partitioned_and_bucketed_source_partitioned", + "WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + "WITH (partitioning = ARRAY['customer'])" + }, + { + "target_and_source_partitioned_and_bucketed", + "WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])", + "WITH (partitioning = ARRAY['address', 'bucket(customer, 3)'])" + } + }; + } + + @Test(dataProvider = "targetAndSourceWithDifferentPartitioning") + public void testMergeWithDifferentPartitioning(String testDescription, String targetTablePartitioning, String sourceTablePartitioning) + { + String targetTable = format("%s_target_%s", testDescription, randomTableSuffix()); + String sourceTable = format("%s_source_%s", testDescription, randomTableSuffix()); + + try { + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) %s", targetTable, targetTablePartitioning)); + assertUpdate(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR) %s", sourceTable, sourceTablePartitioning)); + + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); + assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 
'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.customer = s.customer) " + + "WHEN MATCHED THEN" + + " UPDATE SET purchases = s.purchases + t.purchases, address = s.address " + + "WHEN NOT MATCHED THEN" + + " INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Carol', 12, 'Centreville'), ('Dave', 22, 'Darbyshire'), ('Ed', 7, 'Etherville')"); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeAccessControl() + { + String catalogName = getSession().getCatalog().get(); + String schemaName = getSession().getSchema().get(); + + String targetTable = "merge_nation_target_" + randomTableSuffix(); + String targetName = format("%s.%s.%s", catalogName, schemaName, targetTable); + + String sourceTable = "merge_nation_source_" + randomTableSuffix(); + String sourceName = format("%s.%s.%s", catalogName, schemaName, sourceTable); + + try { + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", targetTable)); + assertUpdate(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", sourceTable)); + + String baseMergeSql = format("MERGE INTO %s t USING %s s ", targetTable, sourceTable) + + "ON (t.nation_name = s.nation_name) "; + String updateCase = + "WHEN MATCHED THEN" + + " UPDATE SET nation_name = concat(s.nation_name, '_foo')"; + String insertCase = + "WHEN NOT MATCHED THEN" + + " INSERT VALUES(s.nation_name, (SELECT 'EUROPE'))"; + + ImmutableList mergeCases = ImmutableList.of(updateCase, insertCase); + for (String mergeCase : mergeCases) { + // Show that without SELECT privilege on the source table, the MERGE fails regardless of which 
case is included + assertAccessDenied(baseMergeSql + mergeCase, "Cannot select from columns .* in table or view " + sourceName, privilege(sourceTable, SELECT_COLUMN)); + + // Show that without SELECT privilege on the target table, the MERGE fails regardless of which case is included + assertAccessDenied(baseMergeSql + mergeCase, "Cannot select from columns .* in table or view " + targetName, privilege(targetTable, SELECT_COLUMN)); + } + + // Show that without INSERT privilege on the target table, the MERGE fails + assertAccessDenied(baseMergeSql + insertCase, "Cannot insert into table " + targetName, privilege(targetTable, INSERT_TABLE)); + + // Show that without UPDATE privilege on the target table, the MERGE fails + assertAccessDenied(baseMergeSql + updateCase, "Cannot update columns \\[\\[nation_name\\]\\] in table " + targetName, privilege(targetTable, UPDATE_TABLE)); + } + finally { + assertUpdate("DROP TABLE " + sourceTable); + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testInvalidMergePredicate() + { + String targetTable = "merge_invalid_predicate_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) ", targetTable) + + "ON (t.nation_name) " + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = s.region_name"; + + assertQueryFails(sqlMergeCommand, ".*The MERGE predicate must evaluate to a boolean: actual type varchar"); + + sqlMergeCommand = + format("MERGE INTO %s t USING (VALUES (1, 'ALGERIA', 'AFRICA')) s(nation_id, nation_name, region_name) ", targetTable) + + "ON (t.nation_name = s.nation_id) " + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = s.region_name"; + + assertQueryFails(sqlMergeCommand, ".*'=' cannot be applied to varchar, integer"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void 
testMergeUnknownColumnName() + { + String targetTable = "merge_unknown_column_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + + String baseMergeSql = format("MERGE INTO %s t USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) ", targetTable) + + "ON (t.nation_name = s.nation_name) "; + + List sqlMergeCommands = Arrays.asList( + // Unknown column in the UPDATE statement. + baseMergeSql + + "WHEN MATCHED THEN" + + " UPDATE SET unknown_column = s.region_name", + + // Unknown column in the INSERT statement. + baseMergeSql + + "WHEN NOT MATCHED THEN" + + " INSERT (nation_name, unknown_column) VALUES(s.nation_name, (SELECT 'EUROPE'))"); + + for (@Language("SQL") String sqlMergeCommand : sqlMergeCommands) { + assertQueryFails(sqlMergeCommand, ".*Merge column name does not exist in target table: unknown_column"); + } + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeDuplicateColumnName() + { + String targetTable = "merge_duplicate_column_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + + String baseMergeSql = format("MERGE INTO %s t USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) ", targetTable) + + "ON (t.nation_name = s.nation_name) "; + + List sqlMergeCommands = Arrays.asList( + // Duplicate column in the UPDATE statement. + baseMergeSql + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = s.region_name, region_name = 'AFRICA'", + + // Duplicate column in the INSERT statement. 
+ baseMergeSql + + "WHEN NOT MATCHED THEN" + + " INSERT (nation_name, region_name, region_name) VALUES(s.nation_name, (SELECT 'EUROPE'), 'AFRICA')"); + + for (@Language("SQL") String sqlMergeCommand : sqlMergeCommands) { + assertQueryFails(sqlMergeCommand, ".*Merge column name is specified more than once: region_name"); + } + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeMismatchedColumnDataTypes() + { + String targetTable = "merge_mismatched_column_data_types_" + randomTableSuffix(); + + try { + createNationRegionTable(targetTable); + + String baseMergeSql = format("MERGE INTO %s t USING (VALUES ('ALGERIA', 'AFRICA')) s(nation_name, region_name) ", targetTable) + + "ON (t.nation_name = s.nation_name) "; + + List sqlMergeCommands = Arrays.asList( + // Mismatched column in the UPDATE statement. + baseMergeSql + + "WHEN MATCHED THEN" + + " UPDATE SET region_name = 1", + + // Mismatched column in the INSERT statement. + baseMergeSql + + "WHEN NOT MATCHED THEN" + + " INSERT (region_name) VALUES(1)"); + + for (@Language("SQL") String sqlMergeCommand : sqlMergeCommands) { + assertQueryFails(sqlMergeCommand, + ".*MERGE table column types don't match for MERGE case 0, SET expressions: Table: \\[varchar\\], Expressions: \\[integer\\]"); + } + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeWithPartitionSpecEvolutionAddPartitionedField() + { + String targetTable = "merge_query_" + randomTableSuffix(); + try { + assertUpdate(format("CREATE TABLE %s (a int, b varchar)", targetTable)); + + assertUpdate(format("INSERT INTO %s VALUES (1, '1001'), (2, '1002')", targetTable), 2); + assertUpdate(format("INSERT INTO %s VALUES (3, '1003'), (4, '1004')", targetTable), 2); + + // Add a partition field to the target iceberg table. 
+ assertUpdate(format("ALTER TABLE %s ADD COLUMN c int WITH(partitioning = 'identity')", targetTable)); + + assertUpdate(format("INSERT INTO %s VALUES (5, '1005', 5), (6, '1006', 6)", targetTable), 2); + + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES (1, 11), (3, 33), (5, 55), (7, 77)) AS s(a, c) " + + "ON (t.a = s.a) " + + "WHEN MATCHED THEN" + + " UPDATE SET c = s.c " + + "WHEN NOT MATCHED THEN" + + " INSERT (a, b, c) VALUES(s.a, 'NEW_LINE', s.c)"; + + assertUpdate(sqlMergeCommand, 4); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES (1, '1001', 11), (2, '1002', NULL), (3, '1003', 33), (4, '1004', NULL), (5, '1005', 55), (6, '1006', 6), (7, 'NEW_LINE', 77)"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + + @Test + public void testMergeWithPartitionSpecEvolutionRemovePartitionedField() + { + String targetTable = "merge_query_" + randomTableSuffix(); + try { + assertUpdate(format("CREATE TABLE %s (a int, b varchar, c int) with(partitioning = ARRAY['a', 'c'])", targetTable)); + assertUpdate(format("INSERT INTO %s VALUES (1, '1001', 11), (2, '1002', 12)", targetTable), 2); + + // Remove a partitioned field from the target iceberg table. 
+ Table icebergTable = loadTable(targetTable); + String partitionFieldName = icebergTable.spec().fields().get(0).name(); + icebergTable.updateSpec().removeField(partitionFieldName).commit(); + + assertUpdate(format("INSERT INTO %s VALUES (3, '1003', 13), (4, '1004', 14)", targetTable), 2); + @Language("SQL") String sqlMergeCommand = + format("MERGE INTO %s t USING ", targetTable) + + "(VALUES (1, 111), (3, 333), (5, 555)) AS s(a, c) " + + "ON (t.a = s.a) " + + "WHEN MATCHED THEN" + + " UPDATE SET c = s.c " + + "WHEN NOT MATCHED THEN" + + " INSERT (a, b, c) VALUES(s.a, 'NEW_LINE', s.c)"; + + assertUpdate(sqlMergeCommand, 3); + + assertQuery("SELECT * FROM " + targetTable, + "VALUES (1, '1001', 111), (2, '1002', 12), (3, '1003', 333), (4, '1004', 14), (5, 'NEW_LINE', 555)"); + } + finally { + assertUpdate("DROP TABLE " + targetTable); + } + } + private void testCheckDeleteFiles(Table icebergTable, int expectedSize, List expectedFileContent) { // check delete file list @@ -2453,7 +4175,7 @@ private void writePositionDeleteToNationTable(Table icebergTable, String dataFil FileSystem fs = getHdfsEnvironment().getFileSystem(new HdfsContext(SESSION), metadataDir); Path path = new Path(metadataDir, deleteFileName); PositionDeleteWriter writer = Parquet.writeDeletes(HadoopOutputFile.fromPath(path, fs)) - .createWriterFunc(GenericParquetWriter::buildWriter) + .createWriterFunc(GenericParquetWriter::create) .forTable(icebergTable) .overwrite() .rowSchema(icebergTable.schema()) @@ -2487,7 +4209,7 @@ private void writeEqualityDeleteToNationTable(Table icebergTable, Map test) { - test.accept(session, FileFormat.PARQUET); - test.accept(session, FileFormat.ORC); + test.accept(session, PARQUET); + test.accept(session, ORC); } - private void assertHasDataFiles(Snapshot snapshot, int dataFilesCount) + protected void assertHasDataFiles(Snapshot snapshot, int dataFilesCount) { Map map = snapshot.summary(); int totalDataFiles = Integer.valueOf(map.get(TOTAL_DATA_FILES_PROP)); 
assertEquals(totalDataFiles, dataFilesCount); } - private void assertHasDeleteFiles(Snapshot snapshot, int deleteFilesCount) + protected void assertHasDeleteFiles(Snapshot snapshot, int deleteFilesCount) { Map map = snapshot.summary(); int totalDeleteFiles = Integer.valueOf(map.get(TOTAL_DELETE_FILES_PROP)); @@ -2787,6 +4510,36 @@ public void testStatisticsFileCacheInvalidationProcedure() getQueryRunner().execute("DROP TABLE test_statistics_file_cache_procedure"); } + @DataProvider(name = "testFormatAndCompressionCodecs") + public Object[][] compressionCodecs() + { + return Stream.of(PARQUET, ORC) + .flatMap(format -> Arrays.stream(HiveCompressionCodec.values()) + .map(codec -> new Object[] {codec, format})) + .toArray(Object[][]::new); + } + + @Test(dataProvider = "testFormatAndCompressionCodecs") + public void testFormatAndCompressionCodecs(HiveCompressionCodec codec, FileFormat format) + { + String tableName = "test_" + format.name().toLowerCase(ROOT) + "_compression_codec_" + codec.name().toLowerCase(ROOT); + Session session = Session.builder(getSession()) + .setCatalogSessionProperty("iceberg", COMPRESSION_CODEC, codec.name()).build(); + if (codec.isSupportedStorageFormat(format == PARQUET ? HiveStorageFormat.PARQUET : HiveStorageFormat.ORC)) { + String codecName = format == PARQUET ? 
codec.getParquetCompressionCodec().name() : codec.getOrcCompressionKind().name(); + assertQuerySucceeds(session, format("CREATE TABLE %s WITH (\"write.format.default\" = '%s') as select * from lineitem with no data", tableName, format.name())); + assertQuery(session, format("SELECT value FROM \"%s$properties\" WHERE key = 'write.%s.compression-codec'", tableName, format.name().toLowerCase(ROOT)), format("VALUES '%s'", codecName)); + assertQuery(session, format("SELECT value FROM \"%s$properties\" WHERE key = 'write.format.default'", tableName), format("VALUES '%s'", format.name())); + assertUpdate(session, format("INSERT INTO %s SELECT * from lineitem", tableName), "select count(*) from lineitem"); + assertQuery(session, format("SELECT * FROM %s", tableName), "select * from lineitem"); + assertQuerySucceeds(format("DROP TABLE %s", tableName)); + } + else { + assertQueryFails(session, format("CREATE TABLE %s WITH (\"write.format.default\" = '%s') as select * from lineitem with no data", tableName, format.name()), + format("Compression codec %s is not supported for .*", codec)); + } + } + @DataProvider(name = "sortedTableWithSortTransform") public static Object[][] sortedTableWithSortTransform() { @@ -2806,4 +4559,225 @@ protected void dropTable(Session session, String table) assertUpdate(session, "DROP TABLE " + table); assertFalse(getQueryRunner().tableExists(session, table)); } + + @Test + public void testEqualityDeleteAsJoinWithMaximumFieldsLimitUnderLimit() + throws Exception + { + int maxColumns = 10; + String tableName = "test_eq_delete_under_max_cols_" + randomTableSuffix(); + + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setCatalogSessionProperty(ICEBERG_CATALOG, DELETE_AS_JOIN_REWRITE_ENABLED, "true") + // Make sure the max columns is set to one more than the number of columns in the table + .setCatalogSessionProperty(ICEBERG_CATALOG, DELETE_AS_JOIN_REWRITE_MAX_DELETE_COLUMNS, "" + (maxColumns + 1)) + .build(); + + try { + // 
Test with exactly max columns - should work fine + // Create table with specified number of columns + List columnDefinitions = IntStream.range(0, maxColumns) + .mapToObj(i -> "col_" + i + " varchar") + .collect(Collectors.toList()); + columnDefinitions.add(0, "id bigint"); + + String createTableSql = "CREATE TABLE " + tableName + " (" + + String.join(", ", columnDefinitions) + ")"; + assertUpdate(session, createTableSql); + + // Insert test rows + for (int row = 1; row <= 3; row++) { + final int currentRow = row; + List values = IntStream.range(0, maxColumns) + .mapToObj(i -> "'val_" + currentRow + "_" + i + "'") + .collect(Collectors.toList()); + values.add(0, String.valueOf(currentRow)); + + String insertSql = "INSERT INTO " + tableName + " VALUES (" + + String.join(", ", values) + ")"; + assertUpdate(session, insertSql, 1); + } + + // Verify all rows exist + assertQuery(session, "SELECT count(*) FROM " + tableName, "VALUES (3)"); + + // Update table to format version 2 and create equality delete files + Table icebergTable = updateTable(tableName); + + // Create equality delete using ALL columns + Map deleteRow = new HashMap<>(); + deleteRow.put("id", 2L); + for (int i = 0; i < maxColumns; i++) { + deleteRow.put("col_" + i, "val_2_" + i); + } + + // Write equality delete with ALL columns + writeEqualityDeleteToNationTable(icebergTable, deleteRow); + + // Query should work correctly regardless of optimization + assertQuery(session, "SELECT count(*) FROM " + tableName, "VALUES (2)"); + assertQuery(session, "SELECT id FROM " + tableName + " ORDER BY id", "VALUES (1), (3)"); + + // With <= max columns, query plan should use JOIN (optimization enabled) + assertPlan(session, "SELECT * FROM " + tableName, + anyTree( + node(JoinNode.class, + anyTree(tableScan(tableName)), + anyTree(tableScan(tableName))))); + } + finally { + dropTable(session, tableName); + } + } + + @Test + public void testEqualityDeleteAsJoinWithMaximumFieldsLimitOverLimit() + throws Exception + { + 
int maxColumns = 10; + String tableName = "test_eq_delete_max_cols_" + randomTableSuffix(); + + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setCatalogSessionProperty(ICEBERG_CATALOG, DELETE_AS_JOIN_REWRITE_ENABLED, "true") + .setCatalogSessionProperty(ICEBERG_CATALOG, DELETE_AS_JOIN_REWRITE_MAX_DELETE_COLUMNS, "" + maxColumns) + .build(); + + try { + // Test with max columns - optimization should be disabled to prevent stack overflow + // Create table with specified number of columns + List columnDefinitions = IntStream.range(0, maxColumns) + .mapToObj(i -> "col_" + i + " varchar") + .collect(Collectors.toList()); + columnDefinitions.add(0, "id bigint"); + + String createTableSql = "CREATE TABLE " + tableName + " (" + + String.join(", ", columnDefinitions) + ")"; + assertUpdate(session, createTableSql); + + // Insert test rows + for (int row = 1; row <= 3; row++) { + final int currentRow = row; + List values = IntStream.range(0, maxColumns) + .mapToObj(i -> "'val_" + currentRow + "_" + i + "'") + .collect(Collectors.toList()); + values.add(0, String.valueOf(currentRow)); + + String insertSql = "INSERT INTO " + tableName + " VALUES (" + + String.join(", ", values) + ")"; + assertUpdate(session, insertSql, 1); + } + + // Verify all rows exist + assertQuery(session, "SELECT count(*) FROM " + tableName, "VALUES (3)"); + + // Update table to format version 2 and create equality delete files + Table icebergTable = updateTable(tableName); + + // Create equality delete using ALL columns + Map deleteRow = new HashMap<>(); + deleteRow.put("id", 2L); + for (int i = 0; i < maxColumns; i++) { + deleteRow.put("col_" + i, "val_2_" + i); + } + + // Write equality delete with ALL columns + writeEqualityDeleteToNationTable(icebergTable, deleteRow); + + // Query should work correctly regardless of optimization + assertQuery(session, "SELECT count(*) FROM " + tableName, "VALUES (2)"); + assertQuery(session, "SELECT id FROM " + tableName + " ORDER BY id", 
"VALUES (1), (3)"); + + // With > max columns, optimization is disabled - no JOIN in plan + // Verify the query works but doesn't contain a join node + assertQuery(session, "SELECT * FROM " + tableName + " WHERE id = 1", + "VALUES (" + Stream.concat(Stream.of("1"), + IntStream.range(0, maxColumns).mapToObj(i -> "'val_1_" + i + "'")) + .collect(Collectors.joining(", ")) + ")"); + + // To verify no join is present, we can check that the plan only contains table scan + assertPlan(session, "SELECT * FROM " + tableName, + anyTree( + anyNot(JoinNode.class, + tableScan(tableName)))); + } + finally { + dropTable(session, tableName); + } + } + + @Test + public void testTableWithNullColumnStats() + { + String tableName1 = "test_null_stats1"; + String tableName2 = "test_null_stats2"; + try { + assertUpdate(String.format("CREATE TABLE %s (id int, name varchar) WITH (\"write.format.default\" = 'PARQUET')", tableName1)); + assertUpdate(String.format("INSERT INTO %s VALUES(1, '1001'), (2, '1002'), (3, '1003')", tableName1), 3); + Table icebergTable1 = loadTable(tableName1); + String dataFilePath = (String) computeActual(String.format("SELECT file_path FROM \"%s$files\" LIMIT 1", tableName1)).getOnlyValue(); + + assertUpdate(String.format("CREATE TABLE %s (id int, name varchar) WITH (\"write.format.default\" = 'PARQUET')", tableName2)); + Table icebergTable2 = loadTable(tableName2); + Metrics newMetrics = new Metrics(3L, null, null, null, null); + DataFile dataFile = DataFiles.builder(icebergTable1.spec()) + .withPath(dataFilePath) + .withFormat("PARQUET") + .withFileSizeInBytes(1234L) + .withMetrics(newMetrics) + .build(); + icebergTable2.newAppend().appendFile(dataFile).commit(); + + TableStatistics stats = getTableStats(tableName2); + assertEquals(stats.getRowCount(), Estimate.of(3.0)); + + // Assert that column statistics are present (even if they don't have detailed metrics) + assertFalse(stats.getColumnStatistics().isEmpty()); + + for (Map.Entry entry : 
stats.getColumnStatistics().entrySet()) { + ColumnStatistics columnStats = entry.getValue(); + assertNotNull(columnStats); + } + + assertQuery(String.format("SELECT t1.id, t2.name FROM %s t1 INNER JOIN %s t2 ON t1.id = t2.id ORDER BY t1.id", tableName1, tableName2), + "VALUES(1, '1001'), (2, '1002'), (3, '1003')"); + } + finally { + assertUpdate(String.format("DROP TABLE IF EXISTS %s", tableName2)); + assertUpdate(String.format("DROP TABLE IF EXISTS %s", tableName1)); + } + } + + @Test + public void testTimeColumnPhysicalType() + { + String tableName = "test_time_type"; + + try { + assertUpdate("CREATE TABLE " + tableName + " (id BIGINT, time TIME, name VARCHAR)"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, TIME '12:34:56', 'test')", 1); + MaterializedResult result = computeActual("SELECT * FROM " + tableName); + List types = result.getTypes(); + assertEquals(types.size(), 3); + assertTrue(types.get(1) instanceof TimeType, "Expected TIME type but got " + types.get(1)); + + Table icebergTable = loadTable(tableName); + Schema schema = icebergTable.schema(); + Types.NestedField timeField = schema.findField("time"); + + assertEquals(timeField.type().typeId(), TIME, + "Iceberg schema should have TIME type, not STRING type"); + + List hiveColumns = IcebergUtil.toHiveColumns(schema.columns()); + Column timeColumn = hiveColumns.stream() + .filter(col -> col.getName().equals("time")) + .findFirst() + .orElseThrow(() -> new AssertionError("time not found in Hive columns")); + + assertEquals(timeColumn.getType(), HiveType.HIVE_LONG, + "TIME column should be converted to HIVE_LONG"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergQueryRunner.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergQueryRunner.java index 217522ea721a6..c27ed00acbc81 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergQueryRunner.java 
+++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergQueryRunner.java @@ -255,8 +255,8 @@ public IcebergQueryRunner build() } } else { - queryRunner.execute("CREATE SCHEMA tpch"); - queryRunner.execute("CREATE SCHEMA tpcds"); + queryRunner.execute("CREATE SCHEMA IF NOT EXISTS tpch"); + queryRunner.execute("CREATE SCHEMA IF NOT EXISTS tpcds"); } if (createTpchTables) { diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java index 43f06897ad3c6..bf25891c38d57 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConfig.java @@ -22,16 +22,16 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static com.facebook.presto.hive.HiveCompressionCodec.GZIP; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctDataSize; import static com.facebook.presto.hive.HiveCompressionCodec.NONE; +import static com.facebook.presto.hive.HiveCompressionCodec.ZSTD; import static com.facebook.presto.iceberg.CatalogType.HADOOP; import static com.facebook.presto.iceberg.CatalogType.HIVE; import static com.facebook.presto.iceberg.IcebergFileFormat.ORC; import static com.facebook.presto.iceberg.IcebergFileFormat.PARQUET; import static com.facebook.presto.spi.statistics.ColumnStatisticType.NUMBER_OF_DISTINCT_VALUES; import static com.facebook.presto.spi.statistics.ColumnStatisticType.TOTAL_SIZE_IN_BYTES; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.succinctDataSize; import static 
org.apache.iceberg.CatalogProperties.IO_MANIFEST_CACHE_EXPIRATION_INTERVAL_MS_DEFAULT; import static org.apache.iceberg.CatalogProperties.IO_MANIFEST_CACHE_MAX_CONTENT_LENGTH_DEFAULT; import static org.apache.iceberg.CatalogProperties.IO_MANIFEST_CACHE_MAX_TOTAL_BYTES_DEFAULT; @@ -46,7 +46,7 @@ public void testDefaults() { assertRecordedDefaults(recordDefaults(IcebergConfig.class) .setFileFormat(PARQUET) - .setCompressionCodec(GZIP) + .setCompressionCodec(ZSTD) .setCatalogType(HIVE) .setCatalogWarehouse(null) .setCatalogWarehouseDataDir(null) @@ -60,6 +60,7 @@ public void testDefaults() .setMergeOnReadModeEnabled(true) .setPushdownFilterEnabled(false) .setDeleteAsJoinRewriteEnabled(true) + .setDeleteAsJoinRewriteMaxDeleteColumns(400) .setRowsForMetadataOptimizationThreshold(1000) .setManifestCachingEnabled(true) .setFileIOImpl(HadoopFileIO.class.getName()) @@ -72,7 +73,8 @@ public void testDefaults() .setMetricsMaxInferredColumn(METRICS_MAX_INFERRED_COLUMN_DEFAULTS_DEFAULT) .setManifestCacheMaxChunkSize(succinctDataSize(2, MEGABYTE)) .setMaxStatisticsFileCacheSize(succinctDataSize(256, MEGABYTE)) - .setStatisticsKllSketchKParameter(1024)); + .setStatisticsKllSketchKParameter(1024) + .setMaterializedViewStoragePrefix("__mv_storage__")); } @Test @@ -93,7 +95,8 @@ public void testExplicitPropertyMappings() .put("iceberg.statistic-snapshot-record-difference-weight", "1.0") .put("iceberg.hive-statistics-merge-strategy", NUMBER_OF_DISTINCT_VALUES.name() + "," + TOTAL_SIZE_IN_BYTES.name()) .put("iceberg.pushdown-filter-enabled", "true") - .put("iceberg.delete-as-join-rewrite-enabled", "false") + .put("deprecated.iceberg.delete-as-join-rewrite-enabled", "false") + .put("iceberg.delete-as-join-rewrite-max-delete-columns", "1") .put("iceberg.rows-for-metadata-optimization-threshold", "500") .put("iceberg.io.manifest.cache-enabled", "false") .put("iceberg.io-impl", "com.facebook.presto.iceberg.HdfsFileIO") @@ -107,6 +110,7 @@ public void testExplicitPropertyMappings() 
.put("iceberg.metrics-max-inferred-column", "16") .put("iceberg.max-statistics-file-cache-size", "512MB") .put("iceberg.statistics-kll-sketch-k-parameter", "4096") + .put("iceberg.materialized-view-storage-prefix", "custom_mv_prefix") .build(); IcebergConfig expected = new IcebergConfig() @@ -125,6 +129,7 @@ public void testExplicitPropertyMappings() .setHiveStatisticsMergeFlags("NUMBER_OF_DISTINCT_VALUES,TOTAL_SIZE_IN_BYTES") .setPushdownFilterEnabled(true) .setDeleteAsJoinRewriteEnabled(false) + .setDeleteAsJoinRewriteMaxDeleteColumns(1) .setRowsForMetadataOptimizationThreshold(500) .setManifestCachingEnabled(false) .setFileIOImpl("com.facebook.presto.iceberg.HdfsFileIO") @@ -137,7 +142,8 @@ public void testExplicitPropertyMappings() .setMetadataDeleteAfterCommit(true) .setMetricsMaxInferredColumn(16) .setMaxStatisticsFileCacheSize(succinctDataSize(512, MEGABYTE)) - .setStatisticsKllSketchKParameter(4096); + .setStatisticsKllSketchKParameter(4096) + .setMaterializedViewStoragePrefix("custom_mv_prefix"); assertFullMapping(properties, expected); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConnectorFactory.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConnectorFactory.java index c9572b3499e34..e043b8279cd4a 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConnectorFactory.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergConnectorFactory.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.iceberg; +import com.facebook.presto.hive.metastore.AbstractCachingHiveMetastore.MetastoreCacheScope; import com.facebook.presto.spi.connector.ConnectorFactory; import com.facebook.presto.testing.TestingConnectorContext; import com.google.common.collect.ImmutableMap; @@ -30,11 +31,53 @@ public void testCachingHiveMetastore() { Map config = ImmutableMap.builder() .put("hive.metastore.uri", "thrift://localhost:9083") - .put("hive.metastore-cache-ttl", 
"10m") + .put("hive.metastore.cache.ttl.default", "10m") .buildOrThrow(); assertThatThrownBy(() -> createConnector(config)) - .hasMessageContaining("In-memory hive metastore caching must not be enabled for Iceberg"); + .hasMessageContaining("In-memory hive metastore caching for tables must not be enabled for Iceberg"); + } + + @Test + public void testMetastoreCachingDisallowedWhenTableCacheEnabledViaEnabledCachesAll() + { + Map config = ImmutableMap.builder() + .put("hive.metastore.uri", "thrift://localhost:9083") + .put("hive.metastore.cache.ttl.default", "10m") + // Enabling all caches implicitly enables table cache + .put("hive.metastore.cache.enabled-caches", "ALL") + .buildOrThrow(); + + assertThatThrownBy(() -> createConnector(config)) + .hasMessageContaining("In-memory hive metastore caching for tables must not be enabled for Iceberg"); + } + + @Test + public void testMetastoreCachingDisallowedWhenTableCacheExplicitlyEnabledViaEnabledCachesTable() + { + Map config = ImmutableMap.builder() + .put("hive.metastore.uri", "thrift://localhost:9083") + .put("hive.metastore.cache.ttl.default", "10m") + // Explicitly enable table cache + .put("hive.metastore.cache.enabled-caches", "TABLE") + .buildOrThrow(); + + assertThatThrownBy(() -> createConnector(config)) + .hasMessageContaining("In-memory hive metastore caching for tables must not be enabled for Iceberg"); + } + + @Test + public void testLegacyMetastoreCacheScopeAllWithNonZeroTtlDisallowed() + { + Map config = ImmutableMap.builder() + .put("hive.metastore.uri", "thrift://localhost:9083") + // Non-zero default TTL combined with ALL scope should be disallowed + .put("hive.metastore.cache.ttl.default", "10m") + .put("hive.metastore.cache.scope", MetastoreCacheScope.ALL.name()) + .buildOrThrow(); + + assertThatThrownBy(() -> createConnector(config)) + .hasMessageContaining("In-memory hive metastore caching for tables must not be enabled for Iceberg"); } private static void createConnector(Map config) diff --git 
a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergDistributedQueries.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergDistributedQueries.java index 4e14bc5be829c..57dd62d142e3a 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergDistributedQueries.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergDistributedQueries.java @@ -115,6 +115,18 @@ public void testDescribeOutputNamedAndUnnamed() assertEqualsIgnoreOrder(actual, expected); } + @Override + public void testNonAutoCommitTransactionWithRollback() + { + // Catalog iceberg only supports writes using autocommit + } + + @Override + public void testNonAutoCommitTransactionWithCommit() + { + // Catalog iceberg only supports writes using autocommit + } + /** * Increased the optimizer timeout from 15000ms to 25000ms */ @@ -184,4 +196,31 @@ public void testStringFilters() assertQuery("SELECT count(*) FROM test_varcharn_filter WHERE shipmode = 'AIR '", "VALUES (0)"); assertQuery("SELECT count(*) FROM test_varcharn_filter WHERE shipmode = 'NONEXIST'", "VALUES (0)"); } + + @Test + public void testRenameView() + { + skipTestUnless(supportsViews()); + assertQuerySucceeds("CREATE TABLE iceberg_test_table (_string VARCHAR, _integer INTEGER)"); + assertUpdate("CREATE VIEW test_view_to_be_renamed AS SELECT * FROM iceberg_test_table"); + assertUpdate("ALTER VIEW IF EXISTS test_view_to_be_renamed RENAME TO test_view_renamed"); + assertUpdate("CREATE VIEW test_view2_to_be_renamed AS SELECT * FROM iceberg_test_table"); + assertUpdate("ALTER VIEW test_view2_to_be_renamed RENAME TO test_view2_renamed"); + assertQuerySucceeds("SELECT * FROM test_view_renamed"); + assertQuerySucceeds("SELECT * FROM test_view2_renamed"); + assertUpdate("DROP VIEW test_view_renamed"); + assertUpdate("DROP VIEW test_view2_renamed"); + assertUpdate("DROP TABLE iceberg_test_table"); + } + + @Test + public void testRenameViewIfNotExists() + { + String 
catalog = getSession().getCatalog().get(); + String schema = getSession().getSchema().get(); + skipTestUnless(supportsViews()); + assertQueryFails("ALTER VIEW test_rename_view_not_exist RENAME TO test_renamed_view_not_exist", + format("line 1:1: View '%s.%s.test_rename_view_not_exist' does not exist", catalog, schema)); + assertQuerySucceeds("ALTER VIEW IF EXISTS test_rename_view_not_exist RENAME TO test_renamed_view_not_exist"); + } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergFileBasedSecurity.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergFileBasedSecurity.java new file mode 100644 index 0000000000000..009a866f647fc --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergFileBasedSecurity.java @@ -0,0 +1,197 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.security.ViewExpression; +import com.facebook.presto.sql.query.QueryAssertions; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.testing.TestingAccessControlManager; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.nio.file.Path; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.iceberg.CatalogType.HIVE; +import static com.facebook.presto.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; +import static com.facebook.presto.iceberg.IcebergQueryRunner.getIcebergDataDirectoryPath; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.ThrowingRunnable; + +@Test(singleThreaded = true) +public class TestIcebergFileBasedSecurity + extends AbstractTestQueryFramework +{ + private QueryAssertions assertions; + private TestingAccessControlManager accessControl; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder() + .setCatalog(ICEBERG_CATALOG) + .build(); + String path = this.getClass().getResource("security.json").getPath(); + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session) + .build(); + + Path dataDirectory = queryRunner.getCoordinator().getDataDirectory(); + Path catalogDirectory = getIcebergDataDirectoryPath(dataDirectory, HIVE.name(), new IcebergConfig().getFileFormat(), false); + + queryRunner.installPlugin(new IcebergPlugin()); + Map icebergProperties = 
ImmutableMap.builder() + .put("hive.metastore", "file") + .put("hive.metastore.catalog.dir", catalogDirectory.toFile().toURI().toString()) + .put("iceberg.security", "file") + .put("security.config-file", path) + .build(); + + queryRunner.createCatalog(ICEBERG_CATALOG, "iceberg", icebergProperties); + + assertions = new QueryAssertions(queryRunner); + accessControl = assertions.getQueryRunner().getAccessControl(); + + return queryRunner; + } + + @Test + public void testCallNormalProcedures() + { + Session icebergAdmin = Session.builder(getSession("iceberg")) + .build(); + Session alice = Session.builder(getSession("alice")) + .build(); + Session bob = Session.builder(getSession("bob")) + .build(); + Session joe = Session.builder(getSession("joe")) + .build(); + + // `icebergAdmin`, `alice`, and `bob` have the permission to execute `iceberg.system.invalidate_statistics_file_cache` + assertUpdate(icebergAdmin, "call system.invalidate_statistics_file_cache()"); + assertUpdate(alice, "call system.invalidate_statistics_file_cache()"); + assertUpdate(bob, "call system.invalidate_statistics_file_cache()"); + // `joe` do not have the permission to execute `iceberg.system.invalidate_statistics_file_cache` + assertDenied(() -> assertUpdate(joe, "call iceberg.system.invalidate_statistics_file_cache()"), + "Access Denied: Cannot call procedure system.invalidate_statistics_file_cache"); + } + + @Test + public void testCallDistributedProceduresWithInsertDeletePermission() + { + Session icebergAdmin = Session.builder(getSession("iceberg")) + .build(); + Session alice = Session.builder(getSession("alice")) + .build(); + Session bob = Session.builder(getSession("bob")) + .build(); + Session joe = Session.builder(getSession("joe")) + .build(); + + String schema = getSession().getSchema().get(); + String tableName = "test_rewrite_table"; + + try { + assertUpdate(icebergAdmin, "create schema if not exists " + schema); + assertUpdate(icebergAdmin, "create table " + tableName + " (a 
int, b varchar)"); + assertUpdate(icebergAdmin, "insert into " + tableName + " values(1, '1001')", 1); + assertUpdate(icebergAdmin, "insert into " + tableName + " values(2, '1002')", 1); + + // `icebergAdmin` has permission to execute `iceberg.system.rewrite_data_files` and + // perform INSERT/DELETE operations on the target table involved in the procedure + assertUpdate(icebergAdmin, format("call system.rewrite_data_files('%s', '%s')", schema, tableName), 2); + // `alice` and `bob` have the permission to execute `iceberg.system.rewrite_data_files`, + // but they lack the necessary permission to perform INSERT or DELETE on the target table + assertDenied(() -> assertUpdate(alice, format("call system.rewrite_data_files('%s', '%s')", schema, tableName)), + format("Access Denied: Cannot delete from table %s.%s", schema, tableName)); + assertDenied(() -> assertUpdate(bob, format("call system.rewrite_data_files('%s', '%s')", schema, tableName)), + format("Access Denied: Cannot insert into table %s.%s", schema, tableName)); + // `joe` do not have the permission to execute `iceberg.system.rewrite_data_files` + assertDenied(() -> assertUpdate(joe, format("call system.rewrite_data_files('%s', '%s')", schema, tableName)), + "Access Denied: Cannot call procedure system.rewrite_data_files"); + } + finally { + assertUpdate(icebergAdmin, "drop table if exists " + tableName); + assertUpdate(icebergAdmin, "drop schema if exists " + schema); + } + } + + @Test + public void testCallDistributedProceduresWithRowFiltersAndColumnMasks() + { + Session icebergAdmin = Session.builder(getSession("iceberg")) + .build(); + + String schema = getSession().getSchema().get(); + String tableName = "test_rewrite_table"; + + try { + assertUpdate(icebergAdmin, "create schema if not exists " + schema); + assertUpdate(icebergAdmin, "create table " + tableName + " (a int, b varchar)"); + assertUpdate(icebergAdmin, "insert into " + tableName + " values(1, '1001')", 1); + assertUpdate(icebergAdmin, 
"insert into " + tableName + " values(2, '1002')", 1); + + // `icebergAdmin` has permission to execute `iceberg.system.rewrite_data_files` and + // perform INSERT/DELETE operations on the target table involved in the procedure + assertUpdate(icebergAdmin, format("call system.rewrite_data_files('%s', '%s')", schema, tableName), 2); + + QualifiedObjectName qualifiedTableName = new QualifiedObjectName("iceberg", schema, tableName); + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.rowFilter( + qualifiedTableName, + "iceberg", + new ViewExpression("iceberg", Optional.empty(), Optional.empty(), "a < 2")); + assertions.assertQuery(icebergAdmin, "SELECT count(*) FROM " + tableName, "VALUES BIGINT '1'"); + assertions.assertFails(icebergAdmin, format("call system.rewrite_data_files('%s', '%s')", schema, tableName), + "Access Denied: Full data access is restricted by row filters and column masks for table: " + qualifiedTableName); + }); + + assertions.executeExclusively(() -> { + accessControl.reset(); + accessControl.columnMask(qualifiedTableName, "b", "iceberg", + new ViewExpression("iceberg", Optional.empty(), Optional.empty(), "'noop'")); + assertions.assertFails(icebergAdmin, format("call system.rewrite_data_files('%s', '%s')", schema, tableName), + "Access Denied: Full data access is restricted by row filters and column masks for table: " + qualifiedTableName); + }); + } + finally { + assertUpdate(icebergAdmin, "drop table if exists " + tableName); + assertUpdate(icebergAdmin, "drop schema if exists " + schema); + } + } + + private Session getSession(String user) + { + return testSessionBuilder() + .setCatalog(getSession().getCatalog().get()) + .setSchema(getSession().getSchema().get()) + .setIdentity(new Identity(user, Optional.empty())).build(); + } + + private static void assertDenied(ThrowingRunnable runnable, String message) + { + assertThatThrownBy(runnable::run) + .isInstanceOf(RuntimeException.class) + 
.hasMessageMatching(message); + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergFileWriter.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergFileWriter.java index b87e793547de7..31f0729907da2 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergFileWriter.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergFileWriter.java @@ -14,6 +14,7 @@ package com.facebook.presto.iceberg; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.BooleanType; @@ -40,7 +41,6 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.session.PropertyMetadata; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.JobConf; import org.apache.iceberg.MetricsConfig; @@ -189,5 +189,11 @@ public List getTypes() { return ImmutableList.of(BooleanType.BOOLEAN, INTEGER, BIGINT, DoubleType.DOUBLE, VARCHAR, VARBINARY, TIMESTAMP, DATE, HYPER_LOG_LOG); } + + @Override + public boolean hasType(TypeSignature signature) + { + return getType(signature) != null; + } } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergLogicalPlanner.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergLogicalPlanner.java index ed274b55b7bf5..0c30c2b91105a 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergLogicalPlanner.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergLogicalPlanner.java @@ -15,6 +15,8 @@ import com.facebook.presto.Session; import com.facebook.presto.Session.SessionBuilder; +import com.facebook.presto.common.RuntimeMetric; +import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.Subfield; import 
com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.predicate.TupleDomain; @@ -86,9 +88,11 @@ import static com.facebook.presto.hive.MetadataUtils.isEntireColumn; import static com.facebook.presto.iceberg.IcebergColumnHandle.getSynthesizedIcebergColumnHandle; import static com.facebook.presto.iceberg.IcebergColumnHandle.isPushedDownSubfield; +import static com.facebook.presto.iceberg.IcebergPartitionLoader.LAZY_LOADING_COUNT_KEY_TEMPLATE; import static com.facebook.presto.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; import static com.facebook.presto.iceberg.IcebergSessionProperties.PARQUET_DEREFERENCE_PUSHDOWN_ENABLED; import static com.facebook.presto.iceberg.IcebergSessionProperties.PUSHDOWN_FILTER_ENABLED; +import static com.facebook.presto.iceberg.IcebergSessionProperties.ROWS_FOR_METADATA_OPTIMIZATION_THRESHOLD; import static com.facebook.presto.iceberg.IcebergSessionProperties.isPushdownFilterEnabled; import static com.facebook.presto.parquet.ParquetTypeUtils.pushdownColumnNameForSubfield; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND; @@ -98,6 +102,7 @@ import static com.facebook.presto.sql.planner.assertions.MatchResult.match; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.callDistributedProcedure; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; @@ -107,8 +112,14 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictTableScan; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFinish; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static com.facebook.presto.tests.sql.TestTable.randomTableSuffix; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -118,6 +129,8 @@ import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; public class TestIcebergLogicalPlanner @@ -201,6 +214,90 @@ public void testMetadataQueryOptimizer(boolean enabled) } } + @Test(dataProvider = "push_down_filter_enabled") + public void testMetadataQueryOptimizerWithMaxPartitionThreshold(boolean enabled) + { + QueryRunner queryRunner = getQueryRunner(); + String schemaName = getSession().getSchema().get(); + String tableName = "metadata_optimize_with_threshold_" + randomTableSuffix(); + try { + queryRunner.execute("create table " + tableName + "(v1 int, v2 varchar, a int, b varchar)" + + " with(partitioning = ARRAY['a', 'b'])"); + // insert data into 4 different partitions + queryRunner.execute("insert into " + tableName + " values" + + " (1, '1001', 1, '1001')," + + " (2, '1002', 2, '1001')," + + " (3, 
'1003', 3, '1002')," + + " (4, '1004', 4, '1002')"); + + // Perform metadata optimization when the number of partitions does not exceed the threshold + Session sessionWithThresholdBigger = getSessionWithOptimizeMetadataQueriesAndThreshold(enabled, 4); + assertQuery(sessionWithThresholdBigger, "select b, max(a), min(a) from " + tableName + " group by b", + "values('1001', 2, 1), ('1002', 4, 3)"); + + assertPlan(sessionWithThresholdBigger, "select b, max(a), min(a) from " + tableName + " group by b", + anyTree(values( + ImmutableList.of("a", "b"), + ImmutableList.of( + ImmutableList.of(new LongLiteral("1"), new StringLiteral("1001")), + ImmutableList.of(new LongLiteral("2"), new StringLiteral("1001")), + ImmutableList.of(new LongLiteral("3"), new StringLiteral("1002")), + ImmutableList.of(new LongLiteral("4"), new StringLiteral("1002")))))); + + String countKey = format(LAZY_LOADING_COUNT_KEY_TEMPLATE, schemaName, tableName); + RuntimeMetric partitionsLazyLoadingCountMetric = sessionWithThresholdBigger.getRuntimeStats().getMetrics().get(countKey); + assertNotNull(partitionsLazyLoadingCountMetric); + assertEquals(partitionsLazyLoadingCountMetric.getCount(), 1); + + assertQuery(sessionWithThresholdBigger, "select distinct a, b from " + tableName, + "values(1, '1001'), (2, '1001'), (3, '1002'), (4, '1002')"); + assertPlan(sessionWithThresholdBigger, "select distinct a, b from " + tableName, + anyTree(values( + ImmutableList.of("a", "b"), + ImmutableList.of( + ImmutableList.of(new LongLiteral("1"), new StringLiteral("1001")), + ImmutableList.of(new LongLiteral("2"), new StringLiteral("1001")), + ImmutableList.of(new LongLiteral("3"), new StringLiteral("1002")), + ImmutableList.of(new LongLiteral("4"), new StringLiteral("1002")))))); + assertEquals(partitionsLazyLoadingCountMetric.getCount(), 2); + + // Do not perform metadata optimization when the number of partitions exceeds the threshold + Session sessionWithThresholdSmaller = 
getSessionWithOptimizeMetadataQueriesAndThreshold(false, 3); + assertQuery(sessionWithThresholdSmaller, "select b, max(a), min(a) from " + tableName + " group by b", + "values('1001', 2, 1), ('1002', 4, 3)"); + assertPlan(sessionWithThresholdSmaller, "select b, max(a), min(a) from " + tableName + " group by b", + anyTree(strictTableScan(tableName, identityMap("a", "b")))); + + RuntimeMetric partitionsLazyLoadingCountMetric2 = sessionWithThresholdSmaller.getRuntimeStats().getMetrics().get(countKey); + assertNotNull(partitionsLazyLoadingCountMetric2); + assertEquals(partitionsLazyLoadingCountMetric2.getCount(), 1); + + assertQuery(sessionWithThresholdSmaller, "select distinct a, b from " + tableName, + "values(1, '1001'), (2, '1001'), (3, '1002'), (4, '1002')"); + assertPlan(sessionWithThresholdSmaller, "select distinct a, b from " + tableName, + anyTree(strictTableScan(tableName, identityMap("a", "b")))); + assertEquals(partitionsLazyLoadingCountMetric2.getCount(), 2); + + // Perform further reducible optimization regardless of whether the number of partitions exceeds the threshold + assertQuery(sessionWithThresholdBigger, "select min(a), max(b) from " + tableName, "values(1, '1002')"); + assertPlan(sessionWithThresholdBigger, "select min(a), max(b) from " + tableName, + anyNot(AggregationNode.class, strictProject( + ImmutableMap.of("a", expression("1"), "b", expression("1002")), + anyTree(values())))); + assertEquals(partitionsLazyLoadingCountMetric.getCount(), 3); + + assertQuery(sessionWithThresholdSmaller, "select min(a), max(b) from " + tableName, "values(1, '1002')"); + assertPlan(sessionWithThresholdSmaller, "select min(a), max(b) from " + tableName, + anyNot(AggregationNode.class, strictProject( + ImmutableMap.of("a", expression("1"), "b", expression("1002")), + anyTree(values())))); + assertEquals(partitionsLazyLoadingCountMetric2.getCount(), 3); + } + finally { + queryRunner.execute("DROP TABLE IF EXISTS " + tableName); + } + } + @Test(dataProvider = 
"push_down_filter_enabled") public void testMetadataQueryOptimizerOnPartitionEvolution(boolean enabled) { @@ -517,16 +614,21 @@ public void testMetadataQueryOptimizerOnMetadataDelete(boolean enabled) public void testFilterByUnmatchedValue(boolean enabled) { Session session = getSessionWithOptimizeMetadataQueries(enabled); - String tableName = "test_filter_by_unmatched_value"; + String schemaName = session.getSchema().get(); + String tableName = "test_filter_by_unmatched_value_" + randomTableSuffix(); assertUpdate("CREATE TABLE " + tableName + " (a varchar, b integer, r row(c int, d varchar)) WITH(partitioning = ARRAY['a'])"); // query with normal column filter on empty table assertPlan(session, "select a, r from " + tableName + " where b = 1001", output(values("a", "r"))); + String countKey = format(LAZY_LOADING_COUNT_KEY_TEMPLATE, schemaName, tableName); + assertNull(session.getRuntimeStats().getMetrics().get(countKey)); + // query with partition column filter on empty table assertPlan(session, "select b, r from " + tableName + " where a = 'var3'", output(values("b", "r"))); + assertNull(session.getRuntimeStats().getMetrics().get(countKey)); assertUpdate("INSERT INTO " + tableName + " VALUES ('var1', 1, (1001, 't1')), ('var1', 3, (1003, 't3'))", 2); assertUpdate("INSERT INTO " + tableName + " VALUES ('var2', 8, (1008, 't8')), ('var2', 10, (1010, 't10'))", 2); @@ -535,10 +637,12 @@ public void testFilterByUnmatchedValue(boolean enabled) // query with unmatched normal column filter assertPlan(session, "select a, r from " + tableName + " where b = 1001", output(values("a", "r"))); + assertNull(session.getRuntimeStats().getMetrics().get(countKey)); // query with unmatched partition column filter assertPlan(session, "select b, r from " + tableName + " where a = 'var3'", output(values("b", "r"))); + assertNull(session.getRuntimeStats().getMetrics().get(countKey)); assertUpdate("DROP TABLE " + tableName); } @@ -547,18 +651,20 @@ public void 
testFilterByUnmatchedValue(boolean enabled) public void testFiltersWithPushdownDisable() { // The filter pushdown session property is disabled by default - Session sessionWithoutFilterPushdown = getQueryRunner().getDefaultSession(); + Session sessionWithoutFilterPushdown = getSessionWithNewRuntimeStats(getQueryRunner().getDefaultSession()); + String schemaName = sessionWithoutFilterPushdown.getSchema().get(); + String tableName = "test_filters_with_pushdown_disable_" + randomTableSuffix(); - assertUpdate("CREATE TABLE test_filters_with_pushdown_disable(id int, name varchar, r row(a int, b varchar)) with (partitioning = ARRAY['id'])"); - assertUpdate("INSERT INTO test_filters_with_pushdown_disable VALUES(10, 'adam', (10, 'adam')), (11, 'hd001', (11, 'hd001'))", 2); + assertUpdate("CREATE TABLE " + tableName + "(id int, name varchar, r row(a int, b varchar)) with (partitioning = ARRAY['id'])"); + assertUpdate("INSERT INTO " + tableName + " VALUES(10, 'adam', (10, 'adam')), (11, 'hd001', (11, 'hd001'))", 2); // Only identity partition column predicates, would be enforced totally by tableScan - assertPlan(sessionWithoutFilterPushdown, "SELECT name, r FROM test_filters_with_pushdown_disable WHERE id = 10", + assertPlan(sessionWithoutFilterPushdown, "SELECT name, r FROM " + tableName + " WHERE id = 10", output(exchange( - strictTableScan("test_filters_with_pushdown_disable", identityMap("name", "r")))), + strictTableScan(tableName, identityMap("name", "r")))), plan -> assertTableLayout( plan, - "test_filters_with_pushdown_disable", + tableName, withColumnDomains(ImmutableMap.of(new Subfield( "id", ImmutableList.of()), @@ -566,78 +672,91 @@ public void testFiltersWithPushdownDisable() TRUE_CONSTANT, ImmutableSet.of("id"))); + String countKey = format(LAZY_LOADING_COUNT_KEY_TEMPLATE, schemaName, tableName); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); + // Only normal column predicates, would not be enforced by tableScan - 
assertPlan(sessionWithoutFilterPushdown, "SELECT id, r FROM test_filters_with_pushdown_disable WHERE name = 'adam'", + assertPlan(sessionWithoutFilterPushdown, "SELECT id, r FROM " + tableName + " WHERE name = 'adam'", output(exchange(project( filter("name='adam'", - strictTableScan("test_filters_with_pushdown_disable", identityMap("id", "name", "r"))))))); + strictTableScan(tableName, identityMap("id", "name", "r"))))))); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); // Only subfield column predicates, would not be enforced by tableScan - assertPlan(sessionWithoutFilterPushdown, "SELECT id, name FROM test_filters_with_pushdown_disable WHERE r.a = 10", + assertPlan(sessionWithoutFilterPushdown, "SELECT id, name FROM " + tableName + " WHERE r.a = 10", output(exchange(project( filter("r.a=10", - strictTableScan("test_filters_with_pushdown_disable", identityMap("id", "name", "r"))))))); + strictTableScan(tableName, identityMap("id", "name", "r"))))))); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); // Predicates with identity partition column and normal column // The predicate was enforced partially by tableScan, so the filterNode drop it's filter condition `id=10` - assertPlan(sessionWithoutFilterPushdown, "SELECT id, r FROM test_filters_with_pushdown_disable WHERE id = 10 and name = 'adam'", + assertPlan(sessionWithoutFilterPushdown, "SELECT id, r FROM " + tableName + " WHERE id = 10 and name = 'adam'", output(exchange(project( ImmutableMap.of("id", expression("10")), filter("name='adam'", - strictTableScan("test_filters_with_pushdown_disable", identityMap("name", "r"))))))); + strictTableScan(tableName, identityMap("name", "r"))))))); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); // Predicates with identity partition column and subfield column // The predicate was enforced partially by tableScan, so the filterNode drop it's filter 
condition `id=10` - assertPlan(sessionWithoutFilterPushdown, "SELECT id, name FROM test_filters_with_pushdown_disable WHERE id = 10 and r.b = 'adam'", + assertPlan(sessionWithoutFilterPushdown, "SELECT id, name FROM " + tableName + " WHERE id = 10 and r.b = 'adam'", output(exchange(project( ImmutableMap.of("id", expression("10")), filter("r.b='adam'", - strictTableScan("test_filters_with_pushdown_disable", identityMap("name", "r"))))))); + strictTableScan(tableName, identityMap("name", "r"))))))); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); // Predicates expression `in` for identity partition columns could be enforced by iceberg table as well - assertPlan(sessionWithoutFilterPushdown, "SELECT id, name FROM test_filters_with_pushdown_disable WHERE id in (1, 3, 5, 7, 9, 10) and r.b = 'adam'", + assertPlan(sessionWithoutFilterPushdown, "SELECT id, name FROM " + tableName + " WHERE id in (1, 3, 5, 7, 9, 10) and r.b = 'adam'", output(exchange(project( filter("r.b='adam'", - strictTableScan("test_filters_with_pushdown_disable", identityMap("id", "name", "r"))))))); + strictTableScan(tableName, identityMap("id", "name", "r"))))))); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); // When predicate simplification causing changes in the predicate, it could not be enforced by iceberg table String params = "(" + Joiner.on(", ").join(IntStream.rangeClosed(1, 50).mapToObj(i -> String.valueOf(2 * i + 1)).toArray()) + ")"; - assertPlan(sessionWithoutFilterPushdown, "SELECT name FROM test_filters_with_pushdown_disable WHERE id in " + params + " and r.b = 'adam'", + assertPlan(sessionWithoutFilterPushdown, "SELECT name FROM " + tableName + " WHERE id in " + params + " and r.b = 'adam'", output(exchange(project( filter("r.b='adam' AND id in " + params, - strictTableScan("test_filters_with_pushdown_disable", identityMap("id", "name", "r"))))))); + strictTableScan(tableName, identityMap("id", "name", 
"r"))))))); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); // Add a new identity partitioned column for iceberg table - assertUpdate("ALTER TABLE test_filters_with_pushdown_disable add column newpart bigint with (partitioning = 'identity')"); - assertUpdate("INSERT INTO test_filters_with_pushdown_disable VALUES(10, 'newman', (10, 'newman'), 1001)", 1); + assertUpdate("ALTER TABLE " + tableName + " add column newpart bigint with (partitioning = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " VALUES(10, 'newman', (10, 'newman'), 1001)", 1); // Predicates with originally present identity partition column and newly added identity partition column // Only the predicate on originally present identity partition column could be enforced by tableScan - assertPlan(sessionWithoutFilterPushdown, "SELECT id, name FROM test_filters_with_pushdown_disable WHERE id = 10 and newpart = 1001", + assertPlan(sessionWithoutFilterPushdown, "SELECT id, name FROM " + tableName + " WHERE id = 10 and newpart = 1001", output(exchange(project( ImmutableMap.of("id", expression("10")), filter("newpart=1001", - strictTableScan("test_filters_with_pushdown_disable", identityMap("name", "newpart"))))))); + strictTableScan(tableName, identityMap("name", "newpart"))))))); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); - assertUpdate("DROP TABLE test_filters_with_pushdown_disable"); + assertUpdate("DROP TABLE " + tableName); - assertUpdate("CREATE TABLE test_filters_with_pushdown_disable(id int, name varchar, r row(a int, b varchar)) with (partitioning = ARRAY['id', 'truncate(name, 2)'])"); - assertUpdate("INSERT INTO test_filters_with_pushdown_disable VALUES (10, 'hd001', (10, 'newman'))", 1); + assertUpdate("CREATE TABLE " + tableName + "(id int, name varchar, r row(a int, b varchar)) with (partitioning = ARRAY['id', 'truncate(name, 2)'])"); + assertUpdate("INSERT INTO " + tableName + " VALUES (10, 
'hd001', (10, 'newman'))", 1); // Predicates with non-identity partitioned column could not be enforced by tableScan - assertPlan(sessionWithoutFilterPushdown, "SELECT id FROM test_filters_with_pushdown_disable WHERE name = 'hd001'", + assertPlan(sessionWithoutFilterPushdown, "SELECT id FROM " + tableName + " WHERE name = 'hd001'", output(exchange(project( filter("name='hd001'", - strictTableScan("test_filters_with_pushdown_disable", identityMap("id", "name"))))))); + strictTableScan(tableName, identityMap("id", "name"))))))); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); // Predicates with identity partition column and non-identity partitioned column // Only the predicate on identity partition column could be enforced by tableScan - assertPlan(sessionWithoutFilterPushdown, "SELECT id, r FROM test_filters_with_pushdown_disable WHERE id = 10 and name = 'hd001'", + assertPlan(sessionWithoutFilterPushdown, "SELECT id, r FROM " + tableName + " WHERE id = 10 and name = 'hd001'", output(exchange(project( ImmutableMap.of("id", expression("10")), filter("name='hd001'", - strictTableScan("test_filters_with_pushdown_disable", identityMap("name", "r"))))))); - assertUpdate("DROP TABLE test_filters_with_pushdown_disable"); + strictTableScan(tableName, identityMap("name", "r"))))))); + assertNull(sessionWithoutFilterPushdown.getRuntimeStats().getMetrics().get(countKey)); + + assertUpdate("DROP TABLE " + tableName); } @Test @@ -730,6 +849,59 @@ public void testThoroughlyPushdownForTableWithUnsupportedSpecsWhoseDataAllDelete } } + @Test + public void testCallDistributedProcedureOnPartitionedTable() + { + String tableName = "partition_table_for_call_distributed_procedure"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar) with (partitioning = ARRAY['c1'])"); + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 
'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + + assertPlan(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s')", tableName, getSession().getSchema().get()), + output(tableFinish(exchange(REMOTE_STREAMING, GATHER, + callDistributedProcedure( + exchange(LOCAL, GATHER, + exchange(REMOTE_STREAMING, REPARTITION, + strictTableScan(tableName, identityMap("c1", "c2"))))))))); + + // Do not support the filter that couldn't be enforced totally by tableScan + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c2 > ''bar''')", tableName, getSession().getSchema().get()), + "Unexpected FilterNode found in plan; probably connector was not able to handle provided WHERE expression"); + + // Support the filter that could be enforced totally by tableScan + assertPlan(getSession(), format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c1 > 3')", tableName, getSession().getSchema().get()), + output(tableFinish(exchange(REMOTE_STREAMING, GATHER, + callDistributedProcedure( + exchange(LOCAL, GATHER, + exchange(REMOTE_STREAMING, REPARTITION, + strictTableScan(tableName, identityMap("c1", "c2")))))))), + plan -> assertTableLayout( + plan, + tableName, + withColumnDomains(ImmutableMap.of( + new Subfield( + "c1", + ImmutableList.of()), + Domain.create(ValueSet.ofRanges(greaterThan(INTEGER, 3L)), false))), + TRUE_CONSTANT, + ImmutableSet.of("c1"))); + + // Support filter conditions that are always false, which cause the underlying TableScanNode to be optimized into an empty ValuesNode + assertPlan(getSession(), format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => '1 > 2')", tableName, getSession().getSchema().get()), + output(tableFinish(exchange(REMOTE_STREAMING, GATHER, + callDistributedProcedure( + exchange(LOCAL, GATHER, + exchange(REMOTE_STREAMING, REPARTITION, + values(ImmutableList.of("c1", "c2"), + 
ImmutableList.of())))))))); + } + finally { + assertUpdate("DROP TABLE " + tableName); + } + } + @DataProvider(name = "timezones") public Object[][] timezones() { @@ -2304,7 +2476,24 @@ private Session sessionForTimezone(String zoneId, boolean legacyTimestamp) protected Session getSessionWithOptimizeMetadataQueries(boolean enabled) { return Session.builder(super.getSession()) + .setRuntimeStats(new RuntimeStats()) + .setCatalogSessionProperty(ICEBERG_CATALOG, PUSHDOWN_FILTER_ENABLED, String.valueOf(enabled)) + .build(); + } + + protected Session getSessionWithNewRuntimeStats(Session session) + { + return Session.builder(session) + .setRuntimeStats(new RuntimeStats()) + .build(); + } + + protected Session getSessionWithOptimizeMetadataQueriesAndThreshold(boolean enabled, int rowsForMetadataOptimizationThreshold) + { + return Session.builder(super.getSession()) + .setRuntimeStats(new RuntimeStats()) .setCatalogSessionProperty(ICEBERG_CATALOG, PUSHDOWN_FILTER_ENABLED, String.valueOf(enabled)) + .setCatalogSessionProperty(ICEBERG_CATALOG, ROWS_FOR_METADATA_OPTIMIZATION_THRESHOLD, String.valueOf(rowsForMetadataOptimizationThreshold)) .build(); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMaterializedViewMetadata.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMaterializedViewMetadata.java new file mode 100644 index 0000000000000..c441132c7654a --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMaterializedViewMetadata.java @@ -0,0 +1,802 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.airlift.http.server.testing.TestingHttpServer; +import com.facebook.presto.Session; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.rest.RESTCatalog; +import org.apache.iceberg.types.Types; +import org.apache.iceberg.view.View; +import org.assertj.core.util.Files; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.iceberg.CatalogType.REST; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.PRESTO_MATERIALIZED_VIEW_BASE_TABLES; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.PRESTO_MATERIALIZED_VIEW_COLUMN_MAPPINGS; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.PRESTO_MATERIALIZED_VIEW_LAST_REFRESH_SNAPSHOT_ID; +import static 
com.facebook.presto.iceberg.IcebergAbstractMetadata.PRESTO_MATERIALIZED_VIEW_ORIGINAL_SQL; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.PRESTO_MATERIALIZED_VIEW_OWNER; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.PRESTO_MATERIALIZED_VIEW_SECURITY_MODE; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.PRESTO_MATERIALIZED_VIEW_STORAGE_SCHEMA; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.PRESTO_MATERIALIZED_VIEW_STORAGE_TABLE_NAME; +import static com.facebook.presto.iceberg.rest.IcebergRestTestUtil.getRestServer; +import static com.facebook.presto.iceberg.rest.IcebergRestTestUtil.restConnectorProperties; +import static com.google.common.io.MoreFiles.deleteRecursively; +import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestIcebergMaterializedViewMetadata + extends AbstractTestQueryFramework +{ + private File warehouseLocation; + private TestingHttpServer restServer; + private String serverUri; + + @BeforeClass + @Override + public void init() + throws Exception + { + warehouseLocation = Files.newTemporaryFolder(); + + restServer = getRestServer(warehouseLocation.getAbsolutePath()); + restServer.start(); + + serverUri = restServer.getBaseUrl().toString(); + super.init(); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + throws Exception + { + if (restServer != null) { + restServer.stop(); + } + deleteRecursively(warehouseLocation.toPath(), ALLOW_INSECURE); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return IcebergQueryRunner.builder() + .setCatalogType(REST) + 
.setExtraConnectorProperties(restConnectorProperties(serverUri)) + .setDataDirectory(Optional.of(warehouseLocation.toPath())) + .setSchemaName("test_schema") + .setCreateTpchTables(false) + .setExtraProperties(ImmutableMap.of( + "experimental.legacy-materialized-views", "false", + "experimental.allow-legacy-materialized-views-toggle", "true")) + .build().getQueryRunner(); + } + + @Test + public void testMaterializedViewSnapshotTracking() + throws Exception + { + assertUpdate("CREATE TABLE test_snapshot_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_snapshot_base VALUES (1, 100)", 1); + + assertUpdate("CREATE MATERIALIZED VIEW test_snapshot_mv AS SELECT id, value FROM test_snapshot_base"); + + RESTCatalog catalog = new RESTCatalog(); + Map catalogProps = new HashMap<>(); + catalogProps.put("uri", serverUri); + catalogProps.put("warehouse", warehouseLocation.getAbsolutePath()); + catalog.initialize("test_catalog", catalogProps); + + try { + TableIdentifier viewId = TableIdentifier.of(Namespace.of("test_schema"), "test_snapshot_mv"); + + View viewBeforeRefresh = catalog.loadView(viewId); + String lastRefreshBefore = viewBeforeRefresh.properties().get(PRESTO_MATERIALIZED_VIEW_LAST_REFRESH_SNAPSHOT_ID); + assertNull(lastRefreshBefore, "Expected last_refresh_snapshot_id to be null before refresh"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_snapshot_mv", 1); + + View viewAfterRefresh = catalog.loadView(viewId); + String lastRefreshAfter = viewAfterRefresh.properties().get(PRESTO_MATERIALIZED_VIEW_LAST_REFRESH_SNAPSHOT_ID); + assertNotNull(lastRefreshAfter, "Expected last_refresh_snapshot_id to be set after refresh"); + + String baseSnapshot = viewAfterRefresh.properties().get("presto.materialized_view.base_snapshot.test_schema.test_snapshot_base"); + assertNotNull(baseSnapshot, "Expected base table snapshot ID to be tracked"); + + assertUpdate("INSERT INTO test_snapshot_base VALUES (2, 200)", 1); + assertUpdate("REFRESH MATERIALIZED VIEW 
test_snapshot_mv", 2); + + View viewAfterSecondRefresh = catalog.loadView(viewId); + String lastRefreshAfter2 = viewAfterSecondRefresh.properties().get(PRESTO_MATERIALIZED_VIEW_LAST_REFRESH_SNAPSHOT_ID); + String baseSnapshot2 = viewAfterSecondRefresh.properties().get("presto.materialized_view.base_snapshot.test_schema.test_snapshot_base"); + + assertNotEquals(lastRefreshAfter2, lastRefreshAfter, "Expected last_refresh_snapshot_id to change after second refresh"); + assertNotEquals(baseSnapshot2, baseSnapshot, "Expected base table snapshot ID to change after INSERT"); + + assertUpdate("CREATE TABLE test_empty_snapshot (id BIGINT)"); + assertUpdate("CREATE MATERIALIZED VIEW test_empty_mv AS SELECT id FROM test_empty_snapshot"); + assertUpdate("REFRESH MATERIALIZED VIEW test_empty_mv", 0); + + TableIdentifier emptyViewId = TableIdentifier.of(Namespace.of("test_schema"), "test_empty_mv"); + View emptyView = catalog.loadView(emptyViewId); + String emptySnapshot = emptyView.properties().get("presto.materialized_view.base_snapshot.test_schema.test_empty_snapshot"); + + assertEquals(emptySnapshot, "0", "Expected empty base table snapshot ID to be 0"); + + assertUpdate("DROP MATERIALIZED VIEW test_empty_mv"); + assertUpdate("DROP TABLE test_empty_snapshot"); + } + finally { + catalog.close(); + } + + assertUpdate("DROP MATERIALIZED VIEW test_snapshot_mv"); + assertUpdate("DROP TABLE test_snapshot_base"); + } + + @Test + public void testMaterializedViewWithComplexColumnMappings() + throws Exception + { + assertUpdate("CREATE TABLE test_mapping_orders (order_id BIGINT, customer_id BIGINT, product_id BIGINT, order_date DATE, amount BIGINT)"); + assertUpdate("CREATE TABLE test_mapping_customers (customer_id BIGINT, customer_name VARCHAR, region VARCHAR)"); + assertUpdate("CREATE TABLE test_mapping_products (product_id BIGINT, product_name VARCHAR, category VARCHAR)"); + + assertUpdate("INSERT INTO test_mapping_orders VALUES (1, 100, 1, DATE '2024-01-01', 500), (2, 200, 2, DATE 
'2024-01-02', 750)", 2); + assertUpdate("INSERT INTO test_mapping_customers VALUES (100, 'Alice', 'US'), (200, 'Bob', 'EU')", 2); + assertUpdate("INSERT INTO test_mapping_products VALUES (1, 'Widget', 'Electronics'), (2, 'Gadget', 'Electronics')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_mapping_mv AS " + + "SELECT o.order_id, c.customer_name, c.region, o.order_date, o.amount, p.product_name, p.category " + + "FROM test_mapping_orders o " + + "JOIN test_mapping_customers c ON o.customer_id = c.customer_id " + + "JOIN test_mapping_products p ON o.product_id = p.product_id"); + + assertQuery("SELECT COUNT(*) FROM test_mapping_mv", "SELECT 2"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mapping_mv", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mapping_mv\"", "SELECT 2"); + + assertQuery("SELECT order_id, customer_name, region FROM test_mapping_mv WHERE order_id = 1", + "VALUES (1, 'Alice', 'US')"); + + RESTCatalog catalog = new RESTCatalog(); + Map catalogProps = new HashMap<>(); + catalogProps.put("uri", serverUri); + catalogProps.put("warehouse", warehouseLocation.getAbsolutePath()); + catalog.initialize("test_catalog", catalogProps); + + try { + TableIdentifier viewId = TableIdentifier.of(Namespace.of("test_schema"), "test_mapping_mv"); + View view = catalog.loadView(viewId); + + String columnMappings = view.properties().get("presto.materialized_view.column_mappings"); + assertNotNull(columnMappings, "Expected column_mappings property to be set"); + + assertFalse(columnMappings.isEmpty() || columnMappings.equals("[]"), "Expected non-empty column mappings for multi-table join"); + } + finally { + catalog.close(); + } + + assertUpdate("INSERT INTO test_mapping_orders VALUES (3, 100, 1, DATE '2024-01-03', 1000)", 1); + assertUpdate("REFRESH MATERIALIZED VIEW test_mapping_mv", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mapping_mv\"", "SELECT 3"); + + assertUpdate("DROP MATERIALIZED VIEW test_mapping_mv"); + 
assertUpdate("DROP TABLE test_mapping_products"); + assertUpdate("DROP TABLE test_mapping_customers"); + assertUpdate("DROP TABLE test_mapping_orders"); + } + + @Test + public void testMaterializedViewWithSpecialCharactersInTableNames() + throws Exception + { + assertUpdate("CREATE TABLE test_base_123 (id BIGINT, value_1 BIGINT)"); + assertUpdate("CREATE TABLE test_base_456_special (id BIGINT, value_2 BIGINT)"); + + assertUpdate("INSERT INTO test_base_123 VALUES (1, 100), (2, 200)", 2); + assertUpdate("INSERT INTO test_base_456_special VALUES (1, 300), (2, 400)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_special_chars_mv AS " + + "SELECT a.id, a.value_1, b.value_2 " + + "FROM test_base_123 a " + + "JOIN test_base_456_special b ON a.id = b.id"); + + assertQuery("SELECT COUNT(*) FROM test_special_chars_mv", "SELECT 2"); + assertQuery("SELECT * FROM test_special_chars_mv ORDER BY id", + "VALUES (1, 100, 300), (2, 200, 400)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_special_chars_mv", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_special_chars_mv\"", "SELECT 2"); + assertQuery("SELECT * FROM \"__mv_storage__test_special_chars_mv\" ORDER BY id", + "VALUES (1, 100, 300), (2, 200, 400)"); + + RESTCatalog catalog = new RESTCatalog(); + Map catalogProps = new HashMap<>(); + catalogProps.put("uri", serverUri); + catalogProps.put("warehouse", warehouseLocation.getAbsolutePath()); + catalog.initialize("test_catalog", catalogProps); + + try { + TableIdentifier viewId = TableIdentifier.of(Namespace.of("test_schema"), "test_special_chars_mv"); + View view = catalog.loadView(viewId); + + String baseTables = view.properties().get("presto.materialized_view.base_tables"); + assertNotNull(baseTables, "Expected base_tables property to be set"); + + assertTrue(baseTables.contains("test_base_123"), "Expected base_tables to contain 'test_base_123', got: " + baseTables); + assertTrue(baseTables.contains("test_base_456_special"), "Expected base_tables to 
contain 'test_base_456_special', got: " + baseTables); + + String snapshot1 = view.properties().get("presto.materialized_view.base_snapshot.test_schema.test_base_123"); + String snapshot2 = view.properties().get("presto.materialized_view.base_snapshot.test_schema.test_base_456_special"); + + assertNotNull(snapshot1, "Expected snapshot for test_base_123 to be tracked"); + assertNotNull(snapshot2, "Expected snapshot for test_base_456_special to be tracked"); + } + finally { + catalog.close(); + } + + assertUpdate("INSERT INTO test_base_123 VALUES (3, 500)", 1); + assertUpdate("INSERT INTO test_base_456_special VALUES (3, 600)", 1); + + assertQuery("SELECT COUNT(*) FROM test_special_chars_mv", "SELECT 3"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_special_chars_mv", 3); + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_special_chars_mv\"", "SELECT 3"); + + assertUpdate("DROP MATERIALIZED VIEW test_special_chars_mv"); + assertUpdate("DROP TABLE test_base_456_special"); + assertUpdate("DROP TABLE test_base_123"); + } + + @Test + public void testMaterializedViewOtherValidationErrors() + throws Exception + { + assertUpdate("CREATE TABLE test_other_validation_base (id BIGINT, name VARCHAR)"); + assertUpdate("INSERT INTO test_other_validation_base VALUES (1, 'Alice')", 1); + + RESTCatalog catalog = new RESTCatalog(); + Map catalogProps = new HashMap<>(); + catalogProps.put("uri", serverUri); + catalogProps.put("warehouse", warehouseLocation.getAbsolutePath()); + catalog.initialize("test_catalog", catalogProps); + + try { + TableIdentifier viewId1 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_empty_base_tables"); + Map properties = new HashMap<>(); + properties.put(PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION, CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION + ""); + properties.put(PRESTO_MATERIALIZED_VIEW_ORIGINAL_SQL, "SELECT 1 as id"); + properties.put(PRESTO_MATERIALIZED_VIEW_STORAGE_SCHEMA, "test_schema"); + 
properties.put(PRESTO_MATERIALIZED_VIEW_STORAGE_TABLE_NAME, "storage_empty_base"); + properties.put(PRESTO_MATERIALIZED_VIEW_COLUMN_MAPPINGS, "[]"); + properties.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, ""); + properties.put(PRESTO_MATERIALIZED_VIEW_OWNER, "test_user"); + properties.put(PRESTO_MATERIALIZED_VIEW_SECURITY_MODE, "DEFINER"); + + assertUpdate("CREATE TABLE storage_empty_base (id BIGINT)"); + + catalog.buildView(viewId1) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()))) + .withQuery("spark", "SELECT 1 as id") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties) + .create(); + + assertQuery("SELECT * FROM test_mv_empty_base_tables", "SELECT 1"); + + catalog.dropView(viewId1); + assertUpdate("DROP TABLE storage_empty_base"); + + TableIdentifier viewId2 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_invalid_json"); + Map properties2 = new HashMap<>(); + properties2.put(PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION, CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION + ""); + properties2.put(PRESTO_MATERIALIZED_VIEW_ORIGINAL_SQL, "SELECT id FROM test_other_validation_base"); + properties2.put(PRESTO_MATERIALIZED_VIEW_STORAGE_SCHEMA, "test_schema"); + properties2.put(PRESTO_MATERIALIZED_VIEW_STORAGE_TABLE_NAME, "storage_invalid_json"); + properties2.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, "[{\"schema\":\"test_schema\", \"table\": \"test_other_validation_base\"}]"); + properties2.put(PRESTO_MATERIALIZED_VIEW_COLUMN_MAPPINGS, "{invalid json here"); + properties2.put(PRESTO_MATERIALIZED_VIEW_OWNER, "test_user"); + properties2.put(PRESTO_MATERIALIZED_VIEW_SECURITY_MODE, "DEFINER"); + + assertUpdate("CREATE TABLE storage_invalid_json (id BIGINT)"); + + catalog.buildView(viewId2) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()))) + .withQuery("spark", "SELECT id FROM test_other_validation_base") + 
.withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties2) + .create(); + + assertQueryFails("SELECT * FROM test_mv_invalid_json", + ".*Invalid JSON string.*"); + + catalog.dropView(viewId2); + assertUpdate("DROP TABLE storage_invalid_json"); + + TableIdentifier viewId3 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_nonexistent_base"); + Map properties3 = new HashMap<>(); + properties3.put(PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION, CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION + ""); + properties3.put(PRESTO_MATERIALIZED_VIEW_ORIGINAL_SQL, "SELECT id FROM nonexistent_table"); + properties3.put(PRESTO_MATERIALIZED_VIEW_STORAGE_SCHEMA, "test_schema"); + properties3.put(PRESTO_MATERIALIZED_VIEW_STORAGE_TABLE_NAME, "storage_nonexistent_base"); + properties3.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, "[{\"schema\":\"test_schema\", \"table\": \"nonexistent_table\"}]"); + properties3.put(PRESTO_MATERIALIZED_VIEW_COLUMN_MAPPINGS, "[]"); + properties3.put(PRESTO_MATERIALIZED_VIEW_OWNER, "test_user"); + properties3.put(PRESTO_MATERIALIZED_VIEW_SECURITY_MODE, "DEFINER"); + + assertUpdate("CREATE TABLE storage_nonexistent_base (id BIGINT)"); + + catalog.buildView(viewId3) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()))) + .withQuery("spark", "SELECT id FROM nonexistent_table") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties3) + .create(); + + assertQueryFails("SELECT * FROM test_mv_nonexistent_base", + ".*(does not exist|not found).*"); + + catalog.dropView(viewId3); + assertUpdate("DROP TABLE storage_nonexistent_base"); + } + finally { + catalog.close(); + } + + assertUpdate("CREATE TABLE existing_storage_table (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO existing_storage_table VALUES (1, 100)", 1); + + assertQueryFails("CREATE MATERIALIZED VIEW test_mv_duplicate_storage " + + "WITH (storage_table = 'existing_storage_table') " + + "AS SELECT id, 
name FROM test_other_validation_base", + ".*already exists.*"); + + assertUpdate("DROP TABLE existing_storage_table"); + assertUpdate("DROP TABLE test_other_validation_base"); + } + + @Test + public void testMaterializedViewInvalidBaseTableNameFormat() + throws Exception + { + assertUpdate("CREATE TABLE test_format_base (id BIGINT, name VARCHAR)"); + assertUpdate("INSERT INTO test_format_base VALUES (1, 'Alice')", 1); + + assertUpdate("CREATE TABLE storage_table_1 (id BIGINT)"); + assertUpdate("CREATE TABLE storage_table_2 (id BIGINT)"); + assertUpdate("CREATE TABLE storage_table_3 (id BIGINT)"); + assertUpdate("CREATE TABLE storage_table_4 (id BIGINT)"); + assertUpdate("CREATE TABLE storage_table_5 (id BIGINT)"); + assertUpdate("CREATE TABLE storage_table_6 (id BIGINT)"); + + RESTCatalog catalog = new RESTCatalog(); + Map catalogProps = new HashMap<>(); + catalogProps.put("uri", serverUri); + catalogProps.put("warehouse", warehouseLocation.getAbsolutePath()); + catalog.initialize("test_catalog", catalogProps); + + try { + TableIdentifier viewId1 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_no_schema"); + Map properties1 = createValidMvProperties("storage_table_1"); + properties1.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, "table_only"); + + catalog.buildView(viewId1) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()))) + .withQuery("spark", "SELECT id FROM table_only") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties1) + .create(); + + assertQueryFails("SELECT * FROM test_mv_no_schema", + ".*Invalid base table name format: table_only.*"); + + catalog.dropView(viewId1); + + TableIdentifier viewId2 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_empty_schema"); + Map properties2 = createValidMvProperties("storage_table_2"); + properties2.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, "schema."); + + catalog.buildView(viewId2) + .withSchema(new 
org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()))) + .withQuery("spark", "SELECT id FROM schema.") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties2) + .create(); + + assertQueryFails("SELECT * FROM test_mv_empty_schema", + ".*Invalid base table name format: schema\\..*"); + + catalog.dropView(viewId2); + + TableIdentifier viewId3 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_empty_table"); + Map properties3 = createValidMvProperties("storage_table_3"); + properties3.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, ".table"); + + catalog.buildView(viewId3) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()))) + .withQuery("spark", "SELECT id FROM .table") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties3) + .create(); + + assertQueryFails("SELECT * FROM test_mv_empty_table", + ".*Invalid base table name format: \\.table.*"); + + catalog.dropView(viewId3); + + TableIdentifier viewId4 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_double_dots"); + Map properties4 = createValidMvProperties("storage_table_4"); + properties4.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, "schema..table"); + + catalog.buildView(viewId4) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()))) + .withQuery("spark", "SELECT id FROM schema..table") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties4) + .create(); + + assertQueryFails("SELECT * FROM test_mv_double_dots", + ".*Invalid base table name format: schema\\.\\.table.*"); + + catalog.dropView(viewId4); + + TableIdentifier viewId5 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_too_many_parts"); + Map properties5 = createValidMvProperties("storage_table_5"); + properties5.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, "a.b.c"); + + catalog.buildView(viewId5) + .withSchema(new 
org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()))) + .withQuery("spark", "SELECT id FROM a.b.c") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties5) + .create(); + + assertQueryFails("SELECT * FROM test_mv_too_many_parts", + ".*Invalid base table name format: a\\.b\\.c.*"); + + catalog.dropView(viewId5); + + TableIdentifier viewId6 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_no_separator"); + Map properties6 = createValidMvProperties("storage_table_6"); + properties6.put(PRESTO_MATERIALIZED_VIEW_BASE_TABLES, "schema_table"); + + catalog.buildView(viewId6) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()))) + .withQuery("spark", "SELECT id FROM test_schema.test_format_base") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties6) + .create(); + + assertQueryFails("SELECT * FROM test_mv_no_separator", + ".*Invalid base table name format: schema_table.*"); + + catalog.dropView(viewId6); + } + finally { + catalog.close(); + } + + assertUpdate("DROP TABLE test_format_base"); + } + + @Test + public void testMaterializedViewMissingRequiredProperties() + throws Exception + { + assertUpdate("CREATE TABLE test_validation_base (id BIGINT, name VARCHAR)"); + assertUpdate("INSERT INTO test_validation_base VALUES (1, 'Alice')", 1); + + RESTCatalog catalog = new RESTCatalog(); + Map catalogProps = new HashMap<>(); + catalogProps.put("uri", serverUri); + catalogProps.put("warehouse", warehouseLocation.getAbsolutePath()); + catalog.initialize("test_catalog", catalogProps); + + try { + TableIdentifier viewId1 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_missing_base_tables"); + Map properties1 = new HashMap<>(); + properties1.put("presto.materialized_view.format_version", CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION + ""); + properties1.put("presto.materialized_view.original_sql", "SELECT id, name FROM 
test_validation_base"); + properties1.put("presto.materialized_view.storage_schema", "test_schema"); + properties1.put("presto.materialized_view.storage_table_name", "storage1"); + properties1.put("presto.materialized_view.column_mappings", "[]"); + + catalog.buildView(viewId1) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.required(2, "name", Types.StringType.get()))) + .withQuery("spark", "SELECT id, name FROM test_validation_base") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties1) + .create(); + + assertQueryFails("SELECT * FROM test_mv_missing_base_tables", + ".*Materialized view missing required property: presto.materialized_view.base_tables.*"); + + catalog.dropView(viewId1); + + TableIdentifier viewId2 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_missing_column_mappings"); + Map properties2 = new HashMap<>(); + properties2.put("presto.materialized_view.format_version", CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION + ""); + properties2.put("presto.materialized_view.original_sql", "SELECT id, name FROM test_validation_base"); + properties2.put("presto.materialized_view.storage_schema", "test_schema"); + properties2.put("presto.materialized_view.storage_table_name", "storage2"); + properties2.put("presto.materialized_view.base_tables", "[{\"schema\":\"test_schema\", \"table\": \"test_validation_base\"}]"); + + catalog.buildView(viewId2) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.required(2, "name", Types.StringType.get()))) + .withQuery("spark", "SELECT id, name FROM test_validation_base") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties2) + .create(); + + assertQueryFails("SELECT * FROM test_mv_missing_column_mappings", + ".*Materialized view missing required property: presto.materialized_view.column_mappings.*"); + + 
catalog.dropView(viewId2); + + TableIdentifier viewId3 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_missing_storage_schema"); + Map properties3 = new HashMap<>(); + properties3.put("presto.materialized_view.format_version", CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION + ""); + properties3.put("presto.materialized_view.original_sql", "SELECT id, name FROM test_validation_base"); + properties3.put("presto.materialized_view.storage_table_name", "storage3"); + properties3.put("presto.materialized_view.base_tables", "[{\"schema\":\"test_schema\", \"table\": \"test_validation_base\"}]"); + properties3.put("presto.materialized_view.column_mappings", "[]"); + + catalog.buildView(viewId3) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.required(2, "name", Types.StringType.get()))) + .withQuery("spark", "SELECT id, name FROM test_validation_base") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties3) + .create(); + + assertQueryFails("SELECT * FROM test_mv_missing_storage_schema", + ".*Materialized view missing required property: presto.materialized_view.storage_schema.*"); + + catalog.dropView(viewId3); + + TableIdentifier viewId4 = TableIdentifier.of(Namespace.of("test_schema"), "test_mv_missing_storage_table_name"); + Map properties4 = new HashMap<>(); + properties4.put("presto.materialized_view.format_version", CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION + ""); + properties4.put("presto.materialized_view.original_sql", "SELECT id, name FROM test_validation_base"); + properties4.put("presto.materialized_view.storage_schema", "test_schema"); + properties4.put("presto.materialized_view.base_tables", "[{\"schema\":\"test_schema\", \"table\": \"test_validation_base\"}]"); + properties4.put("presto.materialized_view.column_mappings", "[]"); + + catalog.buildView(viewId4) + .withSchema(new org.apache.iceberg.Schema( + Types.NestedField.required(1, "id", 
Types.LongType.get()), + Types.NestedField.required(2, "name", Types.StringType.get()))) + .withQuery("spark", "SELECT id, name FROM test_validation_base") + .withDefaultNamespace(Namespace.of("test_schema")) + .withProperties(properties4) + .create(); + + assertQueryFails("SELECT * FROM test_mv_missing_storage_table_name", + ".*Materialized view missing required property: presto.materialized_view.storage_table_name.*"); + + catalog.dropView(viewId4); + } + finally { + catalog.close(); + } + + assertUpdate("DROP TABLE test_validation_base"); + } + + private Map createValidMvProperties(String storageTableName) + { + Map properties = new HashMap<>(); + properties.put(PRESTO_MATERIALIZED_VIEW_FORMAT_VERSION, CURRENT_MATERIALIZED_VIEW_FORMAT_VERSION + ""); + properties.put(PRESTO_MATERIALIZED_VIEW_ORIGINAL_SQL, "SELECT id FROM test_format_base"); + properties.put(PRESTO_MATERIALIZED_VIEW_STORAGE_SCHEMA, "test_schema"); + properties.put(PRESTO_MATERIALIZED_VIEW_STORAGE_TABLE_NAME, storageTableName); + properties.put(PRESTO_MATERIALIZED_VIEW_COLUMN_MAPPINGS, "[]"); + properties.put(PRESTO_MATERIALIZED_VIEW_OWNER, "test_user"); + properties.put(PRESTO_MATERIALIZED_VIEW_SECURITY_MODE, "DEFINER"); + return properties; + } + + @Test + public void testBaseTableSnapshotTracking() + throws Exception + { + assertUpdate("CREATE TABLE test_freshtime_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_freshtime_base VALUES (1, 100)", 1); + + assertUpdate("CREATE MATERIALIZED VIEW test_freshtime_mv AS SELECT id, value FROM test_freshtime_base"); + + RESTCatalog catalog = new RESTCatalog(); + Map catalogProps = new HashMap<>(); + catalogProps.put("uri", serverUri); + catalogProps.put("warehouse", warehouseLocation.getAbsolutePath()); + catalog.initialize("test_catalog", catalogProps); + + try { + assertUpdate("REFRESH MATERIALIZED VIEW test_freshtime_mv", 1); + + Table baseTable = catalog.loadTable( + TableIdentifier.of(Namespace.of("test_schema"), 
"test_freshtime_base")); + long recordedSnapshotId = baseTable.currentSnapshot().snapshotId(); + long recordedSnapshotTimestamp = baseTable.currentSnapshot().timestampMillis(); + + TableIdentifier viewId = TableIdentifier.of(Namespace.of("test_schema"), "test_freshtime_mv"); + View view = catalog.loadView(viewId); + String storedSnapshotId = view.properties().get("presto.materialized_view.base_snapshot.test_schema.test_freshtime_base"); + assertEquals(Long.parseLong(storedSnapshotId), recordedSnapshotId, + "MV should store the base table snapshot ID at refresh time"); + + Thread.sleep(100); + + assertUpdate("INSERT INTO test_freshtime_base VALUES (2, 200)", 1); + + baseTable.refresh(); + long newSnapshotId = baseTable.currentSnapshot().snapshotId(); + long newSnapshotTimestamp = baseTable.currentSnapshot().timestampMillis(); + + assertNotEquals(newSnapshotId, recordedSnapshotId, + "New snapshot should have different ID"); + assertTrue(newSnapshotTimestamp > recordedSnapshotTimestamp, + "New snapshot timestamp should be later than recorded timestamp"); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_catalog = 'iceberg' AND table_schema = 'test_schema' AND table_name = 'test_freshtime_mv'", + "SELECT 'PARTIALLY_MATERIALIZED'"); + + Snapshot recordedSnapshot = baseTable.snapshot(recordedSnapshotId); + assertNotNull(recordedSnapshot, "Recorded snapshot should still exist"); + assertEquals(recordedSnapshot.timestampMillis(), recordedSnapshotTimestamp, + "Recorded snapshot timestamp should not have changed"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_freshtime_mv", 2); + + view = catalog.loadView(viewId); + String newStoredSnapshotId = view.properties().get("presto.materialized_view.base_snapshot.test_schema.test_freshtime_base"); + assertEquals(Long.parseLong(newStoredSnapshotId), newSnapshotId, + "After refresh, MV should store the new base table snapshot ID"); + + assertQuery( + "SELECT freshness_state 
FROM information_schema.materialized_views " + + "WHERE table_catalog = 'iceberg' AND table_schema = 'test_schema' AND table_name = 'test_freshtime_mv'", + "SELECT 'FULLY_MATERIALIZED'"); + } + finally { + catalog.close(); + } + + assertUpdate("DROP MATERIALIZED VIEW test_freshtime_mv"); + assertUpdate("DROP TABLE test_freshtime_base"); + } + + @Test + public void testStalenessPropertiesStoredInView() + throws Exception + { + assertUpdate("CREATE TABLE test_staleness_props_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_staleness_props_base VALUES (1, 100)", 1); + + // Create MV with staleness properties + assertUpdate("CREATE MATERIALIZED VIEW test_staleness_props_mv " + + "WITH (stale_read_behavior = 'FAIL', staleness_window = '1h') " + + "AS SELECT id, value FROM test_staleness_props_base"); + + RESTCatalog catalog = new RESTCatalog(); + Map catalogProps = new HashMap<>(); + catalogProps.put("uri", serverUri); + catalogProps.put("warehouse", warehouseLocation.getAbsolutePath()); + catalog.initialize("test_catalog", catalogProps); + + try { + TableIdentifier viewId = TableIdentifier.of(Namespace.of("test_schema"), "test_staleness_props_mv"); + View view = catalog.loadView(viewId); + + String staleReadBehavior = view.properties().get("presto.materialized_view.stale_read_behavior"); + String stalenessWindow = view.properties().get("presto.materialized_view.staleness_window"); + + assertEquals(staleReadBehavior, "FAIL", + "stale_read_behavior should be stored in view properties"); + assertEquals(stalenessWindow, "1.00h", + "staleness_window should be stored in view properties"); + } + finally { + catalog.close(); + } + + assertUpdate("DROP MATERIALIZED VIEW test_staleness_props_mv"); + assertUpdate("DROP TABLE test_staleness_props_base"); + } + + @Test + public void testNoOrphanStorageTableOnValidationFailure() + throws Exception + { + try (RESTCatalog catalog = new RESTCatalog()) { + assertUpdate("CREATE TABLE test_orphan_base (id BIGINT, value 
BIGINT)"); + assertUpdate("INSERT INTO test_orphan_base VALUES (1, 100)", 1); + + Session legacySession = Session.builder(getSession()) + .setSystemProperty("legacy_materialized_views", "true") + .build(); + + String mvName = "test_orphan_mv"; + String storageTableName = "__mv_storage__" + mvName; + + assertQueryFails( + legacySession, + "CREATE MATERIALIZED VIEW " + mvName + " AS SELECT id, value FROM test_orphan_base", + ".*Materialized view security mode is required.*"); + + assertQueryFails( + "SELECT COUNT(*) FROM \"" + storageTableName + "\"", + ".*(does not exist|not found).*"); + + Map catalogProps = new HashMap<>(); + catalogProps.put("uri", serverUri); + catalogProps.put("warehouse", warehouseLocation.getAbsolutePath()); + catalog.initialize("test_catalog", catalogProps); + + TableIdentifier storageTableId = TableIdentifier.of(Namespace.of("test_schema"), storageTableName); + boolean tableExists = catalog.tableExists(storageTableId); + assertFalse(tableExists, + "Storage table should not exist after failed MV creation. " + + "This would indicate validation happened after storage table creation."); + } + finally { + assertUpdate("DROP TABLE test_orphan_base"); + } + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMaterializedViewOptimizer.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMaterializedViewOptimizer.java new file mode 100644 index 0000000000000..8480b987817a1 --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMaterializedViewOptimizer.java @@ -0,0 +1,152 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.airlift.http.server.testing.TestingHttpServer; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import org.assertj.core.util.Files; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.util.Optional; + +import static com.facebook.presto.iceberg.CatalogType.REST; +import static com.facebook.presto.iceberg.rest.IcebergRestTestUtil.getRestServer; +import static com.facebook.presto.iceberg.rest.IcebergRestTestUtil.restConnectorProperties; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.google.common.io.MoreFiles.deleteRecursively; +import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; + +/** + * Plan-level tests for MaterializedView optimizer rule. + * Verifies that the optimizer correctly decides when to use UNION stitching vs full recompute. 
+ */ +@Test(singleThreaded = true) +public class TestIcebergMaterializedViewOptimizer + extends AbstractTestQueryFramework +{ + private File warehouseLocation; + private TestingHttpServer restServer; + + @BeforeClass + @Override + public void init() + throws Exception + { + warehouseLocation = Files.newTemporaryFolder(); + restServer = getRestServer(warehouseLocation.getAbsolutePath()); + restServer.start(); + super.init(); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + throws Exception + { + if (restServer != null) { + restServer.stop(); + } + if (warehouseLocation != null) { + deleteRecursively(warehouseLocation.toPath(), ALLOW_INSECURE); + } + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return IcebergQueryRunner.builder() + .setCatalogType(REST) + .setExtraConnectorProperties(restConnectorProperties(restServer.getBaseUrl().toString())) + .setDataDirectory(Optional.of(warehouseLocation.toPath())) + .setSchemaName("test_schema") + .setCreateTpchTables(false) + .setExtraProperties(ImmutableMap.of("experimental.legacy-materialized-views", "false")) + .build().getQueryRunner(); + } + + @Test + public void testBasicOptimization() + { + assertUpdate("CREATE TABLE base_no_parts (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO base_no_parts VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_no_parts AS SELECT id, value FROM base_no_parts"); + getQueryRunner().execute("REFRESH MATERIALIZED VIEW mv_no_parts"); + + assertUpdate("INSERT INTO base_no_parts VALUES (3, 300)", 1); + + assertPlan("SELECT * FROM mv_no_parts", + anyTree(tableScan("base_no_parts"))); + + getQueryRunner().execute("REFRESH MATERIALIZED VIEW mv_no_parts"); + + assertPlan("SELECT * FROM mv_no_parts", + anyTree(tableScan("__mv_storage__mv_no_parts"))); + + assertUpdate("DROP MATERIALIZED VIEW mv_no_parts"); + assertUpdate("DROP TABLE base_no_parts"); + } + + @Test + public void testMultiTableStaleness() + { + // 
Create two partitioned base tables + assertUpdate("CREATE TABLE orders (order_id BIGINT, customer_id BIGINT, ds VARCHAR) " + + "WITH (partitioning = ARRAY['ds'])"); + assertUpdate("CREATE TABLE customers (customer_id BIGINT, name VARCHAR, reg_date VARCHAR) " + + "WITH (partitioning = ARRAY['reg_date'])"); + + assertUpdate("INSERT INTO orders VALUES (1, 100, '2024-01-01')", 1); + assertUpdate("INSERT INTO customers VALUES (100, 'Alice', '2024-01-01')", 1); + + // Create JOIN MV with partition columns in output + assertUpdate("CREATE MATERIALIZED VIEW mv_join AS " + + "SELECT o.order_id, c.name, o.ds, c.reg_date " + + "FROM orders o JOIN customers c ON o.customer_id = c.customer_id"); + getQueryRunner().execute("REFRESH MATERIALIZED VIEW mv_join"); + + // Make one table stale + assertUpdate("INSERT INTO orders VALUES (2, 200, '2024-01-02')", 1); + + assertPlan("SELECT * FROM mv_join", + anyTree( + anyTree( + join( + anyTree(tableScan("orders")), + anyTree(tableScan("customers")))))); + + getQueryRunner().execute("REFRESH MATERIALIZED VIEW mv_join"); + + // Make both tables stale + assertUpdate("INSERT INTO orders VALUES (2, 200, '2024-01-02')", 1); + assertUpdate("INSERT INTO customers VALUES (200, 'Bob', '2024-01-02')", 1); + + assertPlan("SELECT * FROM mv_join", + anyTree( + anyTree( + join( + anyTree(tableScan("orders")), + anyTree(tableScan("customers")))))); + assertUpdate("DROP MATERIALIZED VIEW mv_join"); + assertUpdate("DROP TABLE customers"); + assertUpdate("DROP TABLE orders"); + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMaterializedViews.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMaterializedViews.java new file mode 100644 index 0000000000000..0f28075e33898 --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMaterializedViews.java @@ -0,0 +1,1748 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.facebook.airlift.http.server.testing.TestingHttpServer; +import com.facebook.presto.Session; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import org.assertj.core.util.Files; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.File; +import java.util.Optional; + +import static com.facebook.presto.iceberg.CatalogType.REST; +import static com.facebook.presto.iceberg.rest.IcebergRestTestUtil.getRestServer; +import static com.facebook.presto.iceberg.rest.IcebergRestTestUtil.restConnectorProperties; +import static com.google.common.io.MoreFiles.deleteRecursively; +import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; + +@Test(singleThreaded = true) +public class TestIcebergMaterializedViews + extends AbstractTestQueryFramework +{ + private File warehouseLocation; + private TestingHttpServer restServer; + private String serverUri; + + @BeforeClass + @Override + public void init() + throws Exception + { + warehouseLocation = Files.newTemporaryFolder(); + + restServer = getRestServer(warehouseLocation.getAbsolutePath()); + restServer.start(); + + serverUri = restServer.getBaseUrl().toString(); + super.init(); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + throws Exception + { 
+ if (restServer != null) { + restServer.stop(); + } + deleteRecursively(warehouseLocation.toPath(), ALLOW_INSECURE); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return IcebergQueryRunner.builder() + .setCatalogType(REST) + .setExtraConnectorProperties(restConnectorProperties(serverUri)) + .setDataDirectory(Optional.of(warehouseLocation.toPath())) + .setSchemaName("test_schema") + .setCreateTpchTables(false) + .setExtraProperties(ImmutableMap.of("experimental.legacy-materialized-views", "false")) + .build().getQueryRunner(); + } + + @Test + public void testCreateMaterializedView() + { + assertUpdate("CREATE TABLE test_mv_base (id BIGINT, name VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO test_mv_base VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_simple AS SELECT id, name, value FROM test_mv_base"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_simple\"", "SELECT 0"); + + assertQuery("SELECT COUNT(*) FROM test_mv_simple", "SELECT 3"); + assertQuery("SELECT * FROM test_mv_simple ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_simple"); + assertUpdate("DROP TABLE test_mv_base"); + } + + @Test + public void testCreateMaterializedViewWithFilter() + { + assertUpdate("CREATE TABLE test_mv_filtered_base (id BIGINT, status VARCHAR, amount BIGINT)"); + assertUpdate("INSERT INTO test_mv_filtered_base VALUES (1, 'active', 100), (2, 'inactive', 200), (3, 'active', 300)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_filtered AS SELECT id, amount FROM test_mv_filtered_base WHERE status = 'active'"); + + assertQuery("SELECT COUNT(*) FROM test_mv_filtered", "SELECT 2"); + assertQuery("SELECT * FROM test_mv_filtered ORDER BY id", + "VALUES (1, 100), (3, 300)"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_filtered"); + assertUpdate("DROP TABLE 
test_mv_filtered_base"); + } + + @Test + public void testCreateMaterializedViewWithAggregation() + { + assertUpdate("CREATE TABLE test_mv_sales (product_id BIGINT, category VARCHAR, revenue BIGINT)"); + assertUpdate("INSERT INTO test_mv_sales VALUES (1, 'Electronics', 1000), (2, 'Electronics', 1500), (3, 'Books', 500), (4, 'Books', 300)", 4); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_category_sales AS " + + "SELECT category, COUNT(*) as product_count, SUM(revenue) as total_revenue " + + "FROM test_mv_sales GROUP BY category"); + + assertQuery("SELECT COUNT(*) FROM test_mv_category_sales", "SELECT 2"); + assertQuery("SELECT * FROM test_mv_category_sales ORDER BY category", + "VALUES ('Books', 2, 800), ('Electronics', 2, 2500)"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_category_sales"); + assertUpdate("DROP TABLE test_mv_sales"); + } + + @Test + public void testMaterializedViewStaleness() + { + assertUpdate("CREATE TABLE test_mv_stale_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_mv_stale_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_stale AS SELECT id, value FROM test_mv_stale_base"); + + assertQuery("SELECT COUNT(*) FROM test_mv_stale", "SELECT 2"); + assertQuery("SELECT * FROM test_mv_stale ORDER BY id", "VALUES (1, 100), (2, 200)"); + + assertUpdate("INSERT INTO test_mv_stale_base VALUES (3, 300)", 1); + + assertQuery("SELECT COUNT(*) FROM test_mv_stale", "SELECT 3"); + assertQuery("SELECT * FROM test_mv_stale ORDER BY id", + "VALUES (1, 100), (2, 200), (3, 300)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_stale", 3); + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_stale\"", "SELECT 3"); + + assertUpdate("TRUNCATE TABLE test_mv_stale_base"); + assertQuery("SELECT COUNT(*) FROM test_mv_stale_base", "SELECT 0"); + assertQuery("SELECT COUNT(*) FROM test_mv_stale", "SELECT 0"); + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_stale\"", "SELECT 3"); + + 
assertUpdate("DROP MATERIALIZED VIEW test_mv_stale"); + assertUpdate("DROP TABLE test_mv_stale_base"); + } + + @Test + public void testDropMaterializedView() + { + assertUpdate("CREATE TABLE test_mv_drop_base (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO test_mv_drop_base VALUES (1, 'test')", 1); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_drop AS SELECT id, value FROM test_mv_drop_base"); + + assertQuery("SELECT COUNT(*) FROM test_mv_drop", "SELECT 1"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_drop\"", "SELECT 0"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_drop"); + + assertQueryFails("SELECT * FROM \"__mv_storage__test_mv_drop\"", ".*does not exist.*"); + + assertQuery("SELECT COUNT(*) FROM test_mv_drop_base", "SELECT 1"); + + assertUpdate("DROP TABLE test_mv_drop_base"); + } + + @Test + public void testMaterializedViewMetadata() + { + assertUpdate("CREATE TABLE test_mv_metadata_base (id BIGINT, name VARCHAR)"); + assertUpdate("INSERT INTO test_mv_metadata_base VALUES (1, 'test')", 1); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_metadata AS SELECT id, name FROM test_mv_metadata_base WHERE id > 0"); + + assertQuery("SELECT table_name, table_type FROM information_schema.tables " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_mv_metadata'", + "VALUES ('test_mv_metadata', 'MATERIALIZED VIEW')"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_metadata"); + assertUpdate("DROP TABLE test_mv_metadata_base"); + } + + @DataProvider(name = "baseTableNames") + public Object[][] baseTableNamesProvider() + { + return new Object[][] { + {"tt1"}, + {"\"tt2\""}, + {"\"tt.3\""}, + {"\"tt,4.5\""}, + {"\"tt\"\"tt,123\"\".123\""} + }; + } + + @Test(dataProvider = "baseTableNames") + public void testMaterializedViewWithSpecialBaseTableName(String tableName) + { + assertUpdate("CREATE TABLE " + tableName + " (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 100), (2, 200)", 2); + 
+ assertUpdate("CREATE MATERIALIZED VIEW test_mv_refresh AS SELECT id, value FROM " + tableName); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_refresh\"", "SELECT 0"); + + assertQuery("SELECT COUNT(*) FROM test_mv_refresh", "SELECT 2"); + assertQuery("SELECT * FROM test_mv_refresh ORDER BY id", "VALUES (1, 100), (2, 200)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_refresh", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_refresh\"", "SELECT 2"); + assertQuery("SELECT * FROM \"__mv_storage__test_mv_refresh\" ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertQuery("SELECT COUNT(*) FROM test_mv_refresh", "SELECT 2"); + assertQuery("SELECT * FROM test_mv_refresh ORDER BY id", "VALUES (1, 100), (2, 200)"); + + assertUpdate("INSERT INTO " + tableName + " VALUES (3, 300)", 1); + + assertQuery("SELECT COUNT(*) FROM test_mv_refresh", "SELECT 3"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_refresh\"", "SELECT 2"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_refresh", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_refresh\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_mv_refresh\" ORDER BY id", + "VALUES (1, 100), (2, 200), (3, 300)"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_refresh"); + assertUpdate("DROP TABLE " + tableName); + } + + @Test + public void testRefreshMaterializedView() + { + assertUpdate("CREATE TABLE test_mv_refresh_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_mv_refresh_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_refresh AS SELECT id, value FROM test_mv_refresh_base"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_refresh\"", "SELECT 0"); + + assertQuery("SELECT COUNT(*) FROM test_mv_refresh", "SELECT 2"); + assertQuery("SELECT * FROM test_mv_refresh ORDER BY id", "VALUES (1, 100), (2, 200)"); + + assertUpdate("REFRESH MATERIALIZED VIEW 
test_mv_refresh", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_refresh\"", "SELECT 2"); + assertQuery("SELECT * FROM \"__mv_storage__test_mv_refresh\" ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertQuery("SELECT COUNT(*) FROM test_mv_refresh", "SELECT 2"); + assertQuery("SELECT * FROM test_mv_refresh ORDER BY id", "VALUES (1, 100), (2, 200)"); + + assertUpdate("INSERT INTO test_mv_refresh_base VALUES (3, 300)", 1); + + assertQuery("SELECT COUNT(*) FROM test_mv_refresh", "SELECT 3"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_refresh\"", "SELECT 2"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_refresh", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_refresh\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_mv_refresh\" ORDER BY id", + "VALUES (1, 100), (2, 200), (3, 300)"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_refresh"); + assertUpdate("DROP TABLE test_mv_refresh_base"); + } + + @Test + public void testRefreshMaterializedViewWithAggregation() + { + assertUpdate("CREATE TABLE test_mv_agg_refresh_base (category VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO test_mv_agg_refresh_base VALUES ('A', 10), ('B', 20), ('A', 15)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_agg_refresh AS " + + "SELECT category, SUM(value) as total FROM test_mv_agg_refresh_base GROUP BY category"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_agg_refresh\"", "SELECT 0"); + + assertQuery("SELECT * FROM test_mv_agg_refresh ORDER BY category", + "VALUES ('A', 25), ('B', 20)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_agg_refresh", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_agg_refresh\"", "SELECT 2"); + + assertUpdate("INSERT INTO test_mv_agg_refresh_base VALUES ('A', 5), ('C', 30)", 2); + + assertQuery("SELECT * FROM test_mv_agg_refresh ORDER BY category", + "VALUES ('A', 30), ('B', 20), ('C', 30)"); + + 
assertUpdate("REFRESH MATERIALIZED VIEW test_mv_agg_refresh", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_agg_refresh\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_mv_agg_refresh\" ORDER BY category", + "VALUES ('A', 30), ('B', 20), ('C', 30)"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_agg_refresh"); + assertUpdate("DROP TABLE test_mv_agg_refresh_base"); + } + + @Test + public void testPartitionedMaterializedViewWithStaleDataConstraints() + { + assertUpdate("CREATE TABLE test_mv_partitioned_base (" + + "id BIGINT, " + + "event_date DATE, " + + "value BIGINT) " + + "WITH (partitioning = ARRAY['event_date'])"); + + assertUpdate("INSERT INTO test_mv_partitioned_base VALUES " + + "(1, DATE '2024-01-01', 100), " + + "(2, DATE '2024-01-01', 200), " + + "(3, DATE '2024-01-02', 300)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_partitioned AS " + + "SELECT id, event_date, value FROM test_mv_partitioned_base"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_partitioned\"", "SELECT 0"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_partitioned", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_partitioned\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_mv_partitioned\" ORDER BY id", + "VALUES (1, DATE '2024-01-01', 100), (2, DATE '2024-01-01', 200), (3, DATE '2024-01-02', 300)"); + + assertQuery("SELECT COUNT(*) FROM test_mv_partitioned", "SELECT 3"); + assertQuery("SELECT * FROM test_mv_partitioned ORDER BY id", + "VALUES (1, DATE '2024-01-01', 100), (2, DATE '2024-01-01', 200), (3, DATE '2024-01-02', 300)"); + + assertUpdate("INSERT INTO test_mv_partitioned_base VALUES " + + "(4, DATE '2024-01-03', 400), " + + "(5, DATE '2024-01-03', 500)", 2); + + assertQuery("SELECT COUNT(*) FROM test_mv_partitioned", "SELECT 5"); + assertQuery("SELECT * FROM test_mv_partitioned ORDER BY id", + "VALUES (1, DATE '2024-01-01', 100), " + + "(2, DATE '2024-01-01', 200), " + 
+ "(3, DATE '2024-01-02', 300), " + + "(4, DATE '2024-01-03', 400), " + + "(5, DATE '2024-01-03', 500)"); + + assertUpdate("INSERT INTO test_mv_partitioned_base VALUES " + + "(6, DATE '2024-01-04', 600)", 1); + + assertQuery("SELECT COUNT(*) FROM test_mv_partitioned", "SELECT 6"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_partitioned"); + assertUpdate("DROP TABLE test_mv_partitioned_base"); + } + + @Test + public void testMinimalRefresh() + { + assertUpdate("CREATE TABLE minimal_table (id BIGINT)"); + assertUpdate("INSERT INTO minimal_table VALUES (1)", 1); + assertUpdate("CREATE MATERIALIZED VIEW minimal_mv AS SELECT id FROM minimal_table"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__minimal_mv\"", "SELECT 0"); + + try { + assertUpdate("REFRESH MATERIALIZED VIEW minimal_mv", 1); + } + catch (Exception e) { + System.err.println("REFRESH failed with: " + e.getMessage()); + } + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__minimal_mv\"", "SELECT 1"); + assertQuery("SELECT * FROM \"__mv_storage__minimal_mv\"", "SELECT 1"); + + assertUpdate("DROP MATERIALIZED VIEW minimal_mv"); + assertUpdate("DROP TABLE minimal_table"); + } + + @Test + public void testJoinMaterializedViewLifecycle() + { + assertUpdate("CREATE TABLE test_mv_orders (order_id BIGINT, customer_id BIGINT, amount BIGINT)"); + assertUpdate("CREATE TABLE test_mv_customers (customer_id BIGINT, customer_name VARCHAR)"); + + assertUpdate("INSERT INTO test_mv_orders VALUES (1, 100, 50), (2, 200, 75), (3, 100, 25)", 3); + assertUpdate("INSERT INTO test_mv_customers VALUES (100, 'Alice'), (200, 'Bob')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_order_details AS " + + "SELECT o.order_id, c.customer_name, o.amount " + + "FROM test_mv_orders o JOIN test_mv_customers c ON o.customer_id = c.customer_id"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_order_details\"", "SELECT 0"); + + assertQuery("SELECT COUNT(*) FROM test_mv_order_details", "SELECT 3"); + 
assertQuery("SELECT * FROM test_mv_order_details ORDER BY order_id", + "VALUES (1, 'Alice', 50), (2, 'Bob', 75), (3, 'Alice', 25)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_order_details", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_order_details\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_mv_order_details\" ORDER BY order_id", + "VALUES (1, 'Alice', 50), (2, 'Bob', 75), (3, 'Alice', 25)"); + + assertQuery("SELECT COUNT(*) FROM test_mv_order_details", "SELECT 3"); + + assertUpdate("INSERT INTO test_mv_orders VALUES (4, 200, 100)", 1); + + assertQuery("SELECT COUNT(*) FROM test_mv_order_details", "SELECT 4"); + assertQuery("SELECT * FROM test_mv_order_details ORDER BY order_id", + "VALUES (1, 'Alice', 50), (2, 'Bob', 75), (3, 'Alice', 25), (4, 'Bob', 100)"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_order_details\"", "SELECT 3"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_order_details", 4); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_order_details\"", "SELECT 4"); + assertQuery("SELECT * FROM \"__mv_storage__test_mv_order_details\" ORDER BY order_id", + "VALUES (1, 'Alice', 50), (2, 'Bob', 75), (3, 'Alice', 25), (4, 'Bob', 100)"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_order_details"); + assertUpdate("DROP TABLE test_mv_customers"); + assertUpdate("DROP TABLE test_mv_orders"); + } + + @Test + public void testPartitionedJoinMaterializedView() + { + assertUpdate("CREATE TABLE test_mv_part_orders (" + + "order_id BIGINT, " + + "customer_id BIGINT, " + + "order_date DATE, " + + "amount BIGINT) " + + "WITH (partitioning = ARRAY['order_date'])"); + + assertUpdate("CREATE TABLE test_mv_part_customers (customer_id BIGINT, customer_name VARCHAR)"); + + assertUpdate("INSERT INTO test_mv_part_orders VALUES " + + "(1, 100, DATE '2024-01-01', 50), " + + "(2, 200, DATE '2024-01-01', 75), " + + "(3, 100, DATE '2024-01-02', 25)", 3); + assertUpdate("INSERT INTO 
test_mv_part_customers VALUES (100, 'Alice'), (200, 'Bob')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_part_join AS " + + "SELECT o.order_id, c.customer_name, o.order_date, o.amount " + + "FROM test_mv_part_orders o JOIN test_mv_part_customers c ON o.customer_id = c.customer_id"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_part_join\"", "SELECT 0"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_part_join", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_part_join\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_mv_part_join\" ORDER BY order_id", + "VALUES (1, 'Alice', DATE '2024-01-01', 50), " + + "(2, 'Bob', DATE '2024-01-01', 75), " + + "(3, 'Alice', DATE '2024-01-02', 25)"); + + assertQuery("SELECT COUNT(*) FROM test_mv_part_join", "SELECT 3"); + assertQuery("SELECT * FROM test_mv_part_join ORDER BY order_id", + "VALUES (1, 'Alice', DATE '2024-01-01', 50), " + + "(2, 'Bob', DATE '2024-01-01', 75), " + + "(3, 'Alice', DATE '2024-01-02', 25)"); + + assertUpdate("INSERT INTO test_mv_part_orders VALUES (4, 200, DATE '2024-01-03', 100)", 1); + + assertQuery("SELECT COUNT(*) FROM test_mv_part_join", "SELECT 4"); + assertQuery("SELECT * FROM test_mv_part_join ORDER BY order_id", + "VALUES (1, 'Alice', DATE '2024-01-01', 50), " + + "(2, 'Bob', DATE '2024-01-01', 75), " + + "(3, 'Alice', DATE '2024-01-02', 25), " + + "(4, 'Bob', DATE '2024-01-03', 100)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_part_join", 4); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_part_join\"", "SELECT 4"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_part_join"); + assertUpdate("DROP TABLE test_mv_part_customers"); + assertUpdate("DROP TABLE test_mv_part_orders"); + } + + @Test + public void testMultiTableStaleness_TwoTablesBothStale() + { + assertUpdate("CREATE TABLE test_mv_orders (" + + "order_id BIGINT, " + + "order_date DATE, " + + "amount BIGINT) " + + "WITH (partitioning = 
ARRAY['order_date'])"); + + assertUpdate("CREATE TABLE test_mv_customers (" + + "customer_id BIGINT, " + + "reg_date DATE, " + + "name VARCHAR) " + + "WITH (partitioning = ARRAY['reg_date'])"); + + assertUpdate("INSERT INTO test_mv_orders VALUES " + + "(1, DATE '2024-01-01', 100), " + + "(2, DATE '2024-01-02', 200)", 2); + assertUpdate("INSERT INTO test_mv_customers VALUES " + + "(1, DATE '2024-01-01', 'Alice'), " + + "(2, DATE '2024-01-02', 'Bob')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_multi_stale AS " + + "SELECT o.order_id, c.name, o.order_date, c.reg_date, o.amount " + + "FROM test_mv_orders o JOIN test_mv_customers c ON o.order_id = c.customer_id"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_multi_stale", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_multi_stale\"", "SELECT 2"); + + assertQuery("SELECT COUNT(*) FROM test_mv_multi_stale", "SELECT 2"); + + assertUpdate("INSERT INTO test_mv_orders VALUES (3, DATE '2024-01-03', 300)", 1); + assertUpdate("INSERT INTO test_mv_customers VALUES (3, DATE '2024-01-03', 'Charlie')", 1); + + assertQuery("SELECT COUNT(*) FROM test_mv_multi_stale", "SELECT 3"); + assertQuery("SELECT order_id, name, order_date, reg_date, amount FROM test_mv_multi_stale ORDER BY order_id", + "VALUES (1, 'Alice', DATE '2024-01-01', DATE '2024-01-01', 100), " + + "(2, 'Bob', DATE '2024-01-02', DATE '2024-01-02', 200), " + + "(3, 'Charlie', DATE '2024-01-03', DATE '2024-01-03', 300)"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_multi_stale"); + assertUpdate("DROP TABLE test_mv_customers"); + assertUpdate("DROP TABLE test_mv_orders"); + } + + @Test + public void testMultiTableStaleness_ThreeTablesWithTwoStale() + { + assertUpdate("CREATE TABLE test_mv_t1 (" + + "id BIGINT, " + + "date1 DATE, " + + "value1 BIGINT) " + + "WITH (partitioning = ARRAY['date1'])"); + + assertUpdate("CREATE TABLE test_mv_t2 (" + + "id BIGINT, " + + "date2 DATE, " + + "value2 BIGINT) " + + "WITH (partitioning = 
ARRAY['date2'])"); + + assertUpdate("CREATE TABLE test_mv_t3 (" + + "id BIGINT, " + + "date3 DATE, " + + "value3 BIGINT) " + + "WITH (partitioning = ARRAY['date3'])"); + + assertUpdate("INSERT INTO test_mv_t1 VALUES (1, DATE '2024-01-01', 100)", 1); + assertUpdate("INSERT INTO test_mv_t2 VALUES (1, DATE '2024-01-01', 200)", 1); + assertUpdate("INSERT INTO test_mv_t3 VALUES (1, DATE '2024-01-01', 300)", 1); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_three_tables AS " + + "SELECT t1.id, t1.date1, t2.date2, t3.date3, " + + " t1.value1, t2.value2, t3.value3 " + + "FROM test_mv_t1 t1 " + + "JOIN test_mv_t2 t2 ON t1.id = t2.id " + + "JOIN test_mv_t3 t3 ON t1.id = t3.id"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_three_tables", 1); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_three_tables\"", "SELECT 1"); + + assertUpdate("INSERT INTO test_mv_t1 VALUES (2, DATE '2024-01-02', 150)", 1); + assertUpdate("INSERT INTO test_mv_t2 VALUES (2, DATE '2024-01-01', 250)", 1); + assertUpdate("INSERT INTO test_mv_t3 VALUES (2, DATE '2024-01-02', 350)", 1); + + assertQuery("SELECT COUNT(*) FROM test_mv_three_tables", "SELECT 2"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_three_tables"); + assertUpdate("DROP TABLE test_mv_t3"); + assertUpdate("DROP TABLE test_mv_t2"); + assertUpdate("DROP TABLE test_mv_t1"); + } + + @Test + public void testMultiTableStaleness_DifferentPartitionCounts() + { + assertUpdate("CREATE TABLE test_mv_table_a (" + + "id BIGINT, " + + "date_a DATE, " + + "value BIGINT) " + + "WITH (partitioning = ARRAY['date_a'])"); + + assertUpdate("CREATE TABLE test_mv_table_b (" + + "id BIGINT, " + + "date_b DATE, " + + "status VARCHAR) " + + "WITH (partitioning = ARRAY['date_b'])"); + + assertUpdate("INSERT INTO test_mv_table_a VALUES " + + "(1, DATE '2024-01-01', 100), " + + "(2, DATE '2024-01-02', 200)", 2); + assertUpdate("INSERT INTO test_mv_table_b VALUES " + + "(1, DATE '2024-01-01', 'active'), " + + "(2, DATE '2024-01-02', 
'inactive')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_diff_partitions AS " + + "SELECT a.id, a.date_a, b.date_b, a.value, b.status " + + "FROM test_mv_table_a a JOIN test_mv_table_b b ON a.id = b.id"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_diff_partitions", 2); + + assertUpdate("INSERT INTO test_mv_table_a VALUES " + + "(3, DATE '2024-01-03', 300), " + + "(4, DATE '2024-01-04', 400), " + + "(5, DATE '2024-01-05', 500)", 3); + + assertUpdate("INSERT INTO test_mv_table_b VALUES " + + "(3, DATE '2024-01-03', 'active'), " + + "(4, DATE '2024-01-04', 'active'), " + + "(5, DATE '2024-01-05', 'pending')", 3); + + assertQuery("SELECT COUNT(*) FROM test_mv_diff_partitions", "SELECT 5"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_diff_partitions"); + assertUpdate("DROP TABLE test_mv_table_b"); + assertUpdate("DROP TABLE test_mv_table_a"); + } + + @Test + public void testMultiTableStaleness_NonPartitionedAndPartitionedBothStale() + { + assertUpdate("CREATE TABLE test_mv_non_part (id BIGINT, category VARCHAR)"); + + assertUpdate("CREATE TABLE test_mv_part_sales (" + + "id BIGINT, " + + "sale_date DATE, " + + "amount BIGINT) " + + "WITH (partitioning = ARRAY['sale_date'])"); + + assertUpdate("INSERT INTO test_mv_non_part VALUES (1, 'Electronics'), (2, 'Books')", 2); + assertUpdate("INSERT INTO test_mv_part_sales VALUES " + + "(1, DATE '2024-01-01', 500), " + + "(2, DATE '2024-01-02', 300)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_mv_mixed_stale AS " + + "SELECT c.id, c.category, s.sale_date, s.amount " + + "FROM test_mv_non_part c JOIN test_mv_part_sales s ON c.id = s.id"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_mixed_stale\"", "SELECT 0"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_mixed_stale", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_mv_mixed_stale\"", "SELECT 2"); + assertQuery("SELECT id, category, sale_date, amount FROM \"__mv_storage__test_mv_mixed_stale\" ORDER BY id", + 
"VALUES (1, 'Electronics', DATE '2024-01-01', 500), (2, 'Books', DATE '2024-01-02', 300)"); + + assertUpdate("INSERT INTO test_mv_non_part VALUES (3, 'Toys')", 1); + assertUpdate("INSERT INTO test_mv_part_sales VALUES (3, DATE '2024-01-03', 700)", 1); + + assertQuery("SELECT COUNT(*) FROM test_mv_mixed_stale", "SELECT 3"); + assertQuery("SELECT id, category, sale_date, amount FROM test_mv_mixed_stale ORDER BY id", + "VALUES (1, 'Electronics', DATE '2024-01-01', 500), " + + "(2, 'Books', DATE '2024-01-02', 300), " + + "(3, 'Toys', DATE '2024-01-03', 700)"); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_mixed_stale"); + assertUpdate("DROP TABLE test_mv_part_sales"); + assertUpdate("DROP TABLE test_mv_non_part"); + } + + @Test + public void testPartitionAlignment_MatchingColumns() + { + assertUpdate("CREATE TABLE test_pa_matching_base (" + + "id BIGINT, " + + "event_date DATE, " + + "amount BIGINT) " + + "WITH (partitioning = ARRAY['event_date'])"); + + assertUpdate("INSERT INTO test_pa_matching_base VALUES " + + "(1, DATE '2024-01-01', 100), " + + "(2, DATE '2024-01-02', 200), " + + "(3, DATE '2024-01-03', 300)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_pa_matching_mv AS " + + "SELECT id, event_date, amount FROM test_pa_matching_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_pa_matching_mv", 3); + + assertUpdate("INSERT INTO test_pa_matching_base VALUES (4, DATE '2024-01-04', 400)", 1); + + assertQuery("SELECT COUNT(*) FROM test_pa_matching_mv", "SELECT 4"); + assertQuery("SELECT id, event_date, amount FROM test_pa_matching_mv ORDER BY id", + "VALUES " + + "(1, DATE '2024-01-01', 100), " + + "(2, DATE '2024-01-02', 200), " + + "(3, DATE '2024-01-03', 300), " + + "(4, DATE '2024-01-04', 400)"); + + assertUpdate("DROP MATERIALIZED VIEW test_pa_matching_mv"); + assertUpdate("DROP TABLE test_pa_matching_base"); + } + + @Test + public void testPartitionAlignment_MissingConstraintColumn() + { + assertUpdate("CREATE TABLE test_pa_missing_base (" + + 
"id BIGINT, " + + "event_date DATE, " + + "amount BIGINT) " + + "WITH (partitioning = ARRAY['event_date'])"); + + assertUpdate("INSERT INTO test_pa_missing_base VALUES " + + "(1, DATE '2024-01-01', 100), " + + "(2, DATE '2024-01-02', 200), " + + "(3, DATE '2024-01-03', 300)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_pa_missing_mv AS " + + "SELECT id, amount FROM test_pa_missing_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_pa_missing_mv", 3); + + assertUpdate("INSERT INTO test_pa_missing_base VALUES (4, DATE '2024-01-04', 400)", 1); + + assertQuery("SELECT COUNT(*) FROM test_pa_missing_mv", "SELECT 4"); + assertQuery("SELECT id, amount FROM test_pa_missing_mv ORDER BY id", + "VALUES (1, 100), (2, 200), (3, 300), (4, 400)"); + + assertUpdate("DROP MATERIALIZED VIEW test_pa_missing_mv"); + assertUpdate("DROP TABLE test_pa_missing_base"); + } + + @Test + public void testPartitionAlignment_OverSpecifiedStorage() + { + assertUpdate("CREATE TABLE test_pa_over_table_a (" + + "id BIGINT, " + + "event_date DATE, " + + "amount BIGINT) " + + "WITH (partitioning = ARRAY['event_date'])"); + + assertUpdate("CREATE TABLE test_pa_over_table_b (" + + "customer_id BIGINT, " + + "region VARCHAR, " + + "name VARCHAR) " + + "WITH (partitioning = ARRAY['region'])"); + + assertUpdate("INSERT INTO test_pa_over_table_a VALUES " + + "(1, DATE '2024-01-01', 100), " + + "(2, DATE '2024-01-02', 200), " + + "(3, DATE '2024-01-03', 300)", 3); + + assertUpdate("INSERT INTO test_pa_over_table_b VALUES " + + "(1, 'US', 'Alice'), " + + "(2, 'US', 'Bob'), " + + "(3, 'UK', 'Charlie')", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_pa_over_mv AS " + + "SELECT a.id, a.event_date, a.amount, b.region, b.name " + + "FROM test_pa_over_table_a a " + + "JOIN test_pa_over_table_b b ON a.id = b.customer_id"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_pa_over_mv", 3); + + assertUpdate("INSERT INTO test_pa_over_table_a VALUES (1, DATE '2024-01-04', 150)", 1); + + 
assertQuery("SELECT COUNT(*) FROM test_pa_over_mv", "SELECT 4"); + assertQuery("SELECT id, event_date, amount, region, name FROM test_pa_over_mv ORDER BY id, event_date", + "VALUES " + + "(1, DATE '2024-01-01', 100, 'US', 'Alice'), " + + "(1, DATE '2024-01-04', 150, 'US', 'Alice'), " + + "(2, DATE '2024-01-02', 200, 'US', 'Bob'), " + + "(3, DATE '2024-01-03', 300, 'UK', 'Charlie')"); + + assertUpdate("DROP MATERIALIZED VIEW test_pa_over_mv"); + assertUpdate("DROP TABLE test_pa_over_table_b"); + assertUpdate("DROP TABLE test_pa_over_table_a"); + } + + @Test + public void testAggregationMV_MisalignedPartitioning() + { + // Bug: When GROUP BY column differs from partition column and multiple partitions + // are stale, the current implementation creates partial aggregates per partition + // and GROUP BY treats them as distinct rows instead of re-aggregating. + assertUpdate("CREATE TABLE test_agg_misaligned (" + + "id BIGINT, " + + "partition_col VARCHAR, " + + "region VARCHAR, " + + "sales BIGINT) " + + "WITH (partitioning = ARRAY['partition_col'])"); + + assertUpdate("INSERT INTO test_agg_misaligned VALUES " + + "(1, 'A', 'US', 100), " + + "(2, 'A', 'EU', 50), " + + "(3, 'B', 'US', 200), " + + "(4, 'B', 'EU', 75)", 4); + + assertUpdate("CREATE MATERIALIZED VIEW test_agg_mv AS " + + "SELECT region, SUM(sales) as total_sales " + + "FROM test_agg_misaligned " + + "GROUP BY region"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_agg_mv", 2); + + assertQuery("SELECT * FROM test_agg_mv ORDER BY region", + "VALUES ('EU', 125), ('US', 300)"); + + assertUpdate("INSERT INTO test_agg_misaligned VALUES " + + "(5, 'A', 'US', 10), " + + "(6, 'B', 'US', 20)", 2); + + assertQuery("SELECT * FROM test_agg_mv ORDER BY region", + "VALUES ('EU', 125), ('US', 330)"); + + assertUpdate("DROP MATERIALIZED VIEW test_agg_mv"); + assertUpdate("DROP TABLE test_agg_misaligned"); + } + + @Test + public void testAggregationMV_MultiTableJoin_BothStale() + { + // Bug: When both tables are stale, 
creates partial aggregates for each branch + // which are treated as distinct rows instead of being re-aggregated. + assertUpdate("CREATE TABLE test_multi_orders (" + + "order_id BIGINT, " + + "product_id BIGINT, " + + "order_date DATE, " + + "quantity BIGINT) " + + "WITH (partitioning = ARRAY['order_date'])"); + + assertUpdate("CREATE TABLE test_multi_products (" + + "product_id BIGINT, " + + "product_category VARCHAR, " + + "price BIGINT) " + + "WITH (partitioning = ARRAY['product_category'])"); + + assertUpdate("INSERT INTO test_multi_orders VALUES " + + "(1, 100, DATE '2024-01-01', 5), " + + "(2, 200, DATE '2024-01-01', 3)", 2); + assertUpdate("INSERT INTO test_multi_products VALUES " + + "(100, 'Electronics', 50), " + + "(200, 'Books', 20)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_multi_agg_mv AS " + + "SELECT p.product_category, SUM(o.quantity * p.price) as total_revenue " + + "FROM test_multi_orders o " + + "JOIN test_multi_products p ON o.product_id = p.product_id " + + "GROUP BY p.product_category"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_multi_agg_mv", 2); + + assertQuery("SELECT * FROM test_multi_agg_mv ORDER BY product_category", + "VALUES ('Books', 60), ('Electronics', 250)"); + + assertUpdate("INSERT INTO test_multi_orders VALUES " + + "(3, 100, DATE '2024-01-02', 2), " + + "(4, 200, DATE '2024-01-02', 4)", 2); + + assertUpdate("INSERT INTO test_multi_products VALUES " + + "(300, 'Toys', 30)", 1); + + assertUpdate("INSERT INTO test_multi_orders VALUES " + + "(5, 300, DATE '2024-01-02', 1)", 1); + + String explainResult = (String) computeScalar("EXPLAIN SELECT * FROM test_multi_agg_mv ORDER BY product_category"); + System.out.println("=== EXPLAIN PLAN ==="); + System.out.println(explainResult); + System.out.println("==================="); + + assertQuery("SELECT * FROM test_multi_agg_mv ORDER BY product_category", + "VALUES ('Books', 140), ('Electronics', 350), ('Toys', 30)"); + + assertUpdate("DROP MATERIALIZED VIEW 
test_multi_agg_mv"); + assertUpdate("DROP TABLE test_multi_products"); + assertUpdate("DROP TABLE test_multi_orders"); + } + + @Test + public void testMaterializedViewWithCustomStorageTableName() + { + assertUpdate("CREATE TABLE test_custom_storage_base (id BIGINT, name VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO test_custom_storage_base VALUES (1, 'Alice', 100), (2, 'Bob', 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_custom_storage_mv " + + "WITH (storage_table = 'my_custom_storage_table') " + + "AS SELECT id, name, value FROM test_custom_storage_base"); + + assertQuery("SELECT COUNT(*) FROM my_custom_storage_table", "SELECT 0"); + + assertQueryFails("SELECT * FROM \"__mv_storage__test_custom_storage_mv\"", ".*does not exist.*"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_custom_storage_mv", 2); + + assertQuery("SELECT COUNT(*) FROM my_custom_storage_table", "SELECT 2"); + assertQuery("SELECT * FROM my_custom_storage_table ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200)"); + + assertQuery("SELECT * FROM test_custom_storage_mv ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200)"); + + assertUpdate("INSERT INTO test_custom_storage_base VALUES (3, 'Charlie', 300)", 1); + assertUpdate("REFRESH MATERIALIZED VIEW test_custom_storage_mv", 3); + + assertQuery("SELECT COUNT(*) FROM my_custom_storage_table", "SELECT 3"); + assertQuery("SELECT * FROM my_custom_storage_table ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertUpdate("DROP MATERIALIZED VIEW test_custom_storage_mv"); + + assertQueryFails("SELECT * FROM my_custom_storage_table", ".*does not exist.*"); + + assertUpdate("DROP TABLE test_custom_storage_base"); + } + + @Test + public void testMaterializedViewWithCustomStorageSchema() + { + assertUpdate("CREATE SCHEMA IF NOT EXISTS test_storage_schema"); + + assertUpdate("CREATE TABLE test_custom_schema_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO 
test_custom_schema_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_custom_schema_mv " + + "WITH (storage_schema = 'test_storage_schema', " + + "storage_table = 'storage_table') " + + "AS SELECT id, value FROM test_schema.test_custom_schema_base"); + + assertQuery("SELECT COUNT(*) FROM test_storage_schema.storage_table", "SELECT 0"); + + assertQueryFails("SELECT * FROM test_schema.storage_table", ".*does not exist.*"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_schema.test_custom_schema_mv", 2); + + assertQuery("SELECT COUNT(*) FROM test_storage_schema.storage_table", "SELECT 2"); + assertQuery("SELECT * FROM test_storage_schema.storage_table ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertQuery("SELECT * FROM test_custom_schema_mv ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertUpdate("DROP MATERIALIZED VIEW test_schema.test_custom_schema_mv"); + assertQueryFails("SELECT * FROM test_storage_schema.storage_table", ".*does not exist.*"); + + assertUpdate("DROP TABLE test_custom_schema_base"); + assertUpdate("DROP SCHEMA test_storage_schema"); + } + + @Test + public void testMaterializedViewWithCustomPrefix() + { + assertUpdate("CREATE TABLE test_custom_prefix_base (id BIGINT, name VARCHAR)"); + assertUpdate("INSERT INTO test_custom_prefix_base VALUES (1, 'test')", 1); + + Session sessionWithCustomPrefix = Session.builder(getSession()) + .setCatalogSessionProperty("iceberg", "materialized_view_storage_prefix", "custom_prefix_") + .build(); + + assertUpdate(sessionWithCustomPrefix, "CREATE MATERIALIZED VIEW test_custom_prefix_mv " + + "AS SELECT id, name FROM test_custom_prefix_base"); + + assertQuery("SELECT COUNT(*) FROM custom_prefix_test_custom_prefix_mv", "SELECT 0"); + + assertQueryFails("SELECT * FROM \"__mv_storage__test_custom_prefix_mv\"", ".*does not exist.*"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_custom_prefix_mv", 1); + + assertQuery("SELECT COUNT(*) FROM 
custom_prefix_test_custom_prefix_mv", "SELECT 1"); + assertQuery("SELECT * FROM custom_prefix_test_custom_prefix_mv", "VALUES (1, 'test')"); + + assertQuery("SELECT * FROM test_custom_prefix_mv", "VALUES (1, 'test')"); + + assertUpdate("DROP MATERIALIZED VIEW test_custom_prefix_mv"); + assertQueryFails("SELECT * FROM custom_prefix_test_custom_prefix_mv", ".*does not exist.*"); + + assertUpdate("DROP TABLE test_custom_prefix_base"); + } + + @Test + public void testMaterializedViewWithValuesOnly() + { + assertUpdate("CREATE MATERIALIZED VIEW test_values_mv AS SELECT * FROM (VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)) AS t(id, name, value)"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_values_mv\"", "SELECT 0"); + + assertQuery("SELECT COUNT(*) FROM test_values_mv", "SELECT 3"); + assertQuery("SELECT * FROM test_values_mv ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_values_mv", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_values_mv\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_values_mv\" ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertQuery("SELECT * FROM test_values_mv ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertUpdate("DROP MATERIALIZED VIEW test_values_mv"); + assertQueryFails("SELECT * FROM \"__mv_storage__test_values_mv\"", ".*does not exist.*"); + } + + @Test + public void testMaterializedViewWithBaseTableButNoColumnsSelected() + { + assertUpdate("CREATE TABLE test_no_cols_base (id BIGINT, name VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO test_no_cols_base VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_no_cols_mv AS " + + "SELECT 'constant' as label, 42 as fixed_value FROM test_no_cols_base"); + + assertQuery("SELECT COUNT(*) FROM 
\"__mv_storage__test_no_cols_mv\"", "SELECT 0"); + + assertQuery("SELECT COUNT(*) FROM test_no_cols_mv", "SELECT 3"); + assertQuery("SELECT * FROM test_no_cols_mv", + "VALUES ('constant', 42), ('constant', 42), ('constant', 42)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_no_cols_mv", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_no_cols_mv\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_no_cols_mv\"", + "VALUES ('constant', 42), ('constant', 42), ('constant', 42)"); + + assertUpdate("INSERT INTO test_no_cols_base VALUES (4, 'Dave', 400)", 1); + + assertQuery("SELECT COUNT(*) FROM test_no_cols_mv", "SELECT 4"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_no_cols_mv", 4); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_no_cols_mv\"", "SELECT 4"); + + assertUpdate("DROP MATERIALIZED VIEW test_no_cols_mv"); + assertQueryFails("SELECT * FROM \"__mv_storage__test_no_cols_mv\"", ".*does not exist.*"); + + assertUpdate("DROP TABLE test_no_cols_base"); + } + + @Test + public void testMaterializedViewOnEmptyBaseTable() + { + assertUpdate("CREATE TABLE test_empty_base (id BIGINT, name VARCHAR, value BIGINT)"); + + assertUpdate("CREATE MATERIALIZED VIEW test_empty_mv AS SELECT id, name, value FROM test_empty_base"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_empty_mv\"", "SELECT 0"); + + assertQuery("SELECT COUNT(*) FROM test_empty_mv", "SELECT 0"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_empty_mv", 0); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_empty_mv\"", "SELECT 0"); + + assertUpdate("INSERT INTO test_empty_base VALUES (1, 'Alice', 100), (2, 'Bob', 200)", 2); + + assertQuery("SELECT COUNT(*) FROM test_empty_mv", "SELECT 2"); + assertQuery("SELECT * FROM test_empty_mv ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_empty_mv", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_empty_mv\"", "SELECT 2"); 
+ assertQuery("SELECT * FROM \"__mv_storage__test_empty_mv\" ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200)"); + + assertUpdate("DROP MATERIALIZED VIEW test_empty_mv"); + assertQueryFails("SELECT * FROM \"__mv_storage__test_empty_mv\"", ".*does not exist.*"); + + assertUpdate("DROP TABLE test_empty_base"); + } + + @Test + public void testRefreshFailurePreservesOldData() + { + assertUpdate("CREATE TABLE test_refresh_failure_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_refresh_failure_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_refresh_failure_mv AS " + + "SELECT id, value FROM test_refresh_failure_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_refresh_failure_mv", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_refresh_failure_mv\"", "SELECT 2"); + assertQuery("SELECT * FROM \"__mv_storage__test_refresh_failure_mv\" ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertUpdate("DROP TABLE test_refresh_failure_base"); + + try { + getQueryRunner().execute("REFRESH MATERIALIZED VIEW test_refresh_failure_mv"); + throw new AssertionError("Expected REFRESH to fail when base table doesn't exist"); + } + catch (Exception e) { + if (!e.getMessage().contains("does not exist") && !e.getMessage().contains("not found")) { + throw new AssertionError("Expected 'does not exist' error, got: " + e.getMessage()); + } + } + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_refresh_failure_mv\"", "SELECT 2"); + assertQuery("SELECT * FROM \"__mv_storage__test_refresh_failure_mv\" ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertUpdate("DROP MATERIALIZED VIEW test_refresh_failure_mv"); + } + + @Test + public void testBaseTableDroppedAndRecreated() + { + assertUpdate("CREATE TABLE test_recreate_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_recreate_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_recreate_mv AS SELECT 
id, value FROM test_recreate_base"); + assertUpdate("REFRESH MATERIALIZED VIEW test_recreate_mv", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_recreate_mv\"", "SELECT 2"); + assertQuery("SELECT * FROM \"__mv_storage__test_recreate_mv\" ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertUpdate("DROP TABLE test_recreate_base"); + + assertUpdate("CREATE TABLE test_recreate_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_recreate_base VALUES (3, 300), (4, 400), (5, 500)", 3); + + assertQuery("SELECT COUNT(*) FROM test_recreate_mv", "SELECT 3"); + assertQuery("SELECT * FROM test_recreate_mv ORDER BY id", + "VALUES (3, 300), (4, 400), (5, 500)"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_recreate_mv\"", "SELECT 2"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_recreate_mv", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_recreate_mv\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_recreate_mv\" ORDER BY id", + "VALUES (3, 300), (4, 400), (5, 500)"); + + assertUpdate("DROP MATERIALIZED VIEW test_recreate_mv"); + assertUpdate("DROP TABLE test_recreate_base"); + } + + @Test + public void testStorageTableDroppedDirectly() + { + assertUpdate("CREATE TABLE test_storage_drop_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_storage_drop_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_storage_drop_mv AS SELECT id, value FROM test_storage_drop_base"); + assertUpdate("REFRESH MATERIALIZED VIEW test_storage_drop_mv", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_storage_drop_mv\"", "SELECT 2"); + + assertUpdate("DROP TABLE \"__mv_storage__test_storage_drop_mv\""); + + assertQueryFails("SELECT * FROM \"__mv_storage__test_storage_drop_mv\"", ".*does not exist.*"); + + assertQueryFails("SELECT * FROM test_storage_drop_mv", ".*does not exist.*"); + + assertUpdate("DROP MATERIALIZED VIEW test_storage_drop_mv"); + 
assertUpdate("DROP TABLE test_storage_drop_base"); + } + + @Test + public void testMaterializedViewWithRenamedColumns() + { + assertUpdate("CREATE TABLE test_renamed_base (id BIGINT, original_name VARCHAR, original_value BIGINT)"); + assertUpdate("INSERT INTO test_renamed_base VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_renamed_mv AS " + + "SELECT id AS person_id, original_name AS full_name, original_value AS amount " + + "FROM test_renamed_base"); + + assertQuery("SELECT COUNT(*) FROM test_renamed_mv", "SELECT 3"); + assertQuery("SELECT * FROM test_renamed_mv ORDER BY person_id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_renamed_mv", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_renamed_mv\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_renamed_mv\" ORDER BY person_id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertQuery("SELECT * FROM test_renamed_mv ORDER BY person_id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertQuery("SELECT person_id, full_name FROM test_renamed_mv WHERE amount > 150 ORDER BY person_id", + "VALUES (2, 'Bob'), (3, 'Charlie')"); + + assertUpdate("INSERT INTO test_renamed_base VALUES (4, 'Dave', 400)", 1); + + assertQuery("SELECT COUNT(*) FROM test_renamed_mv", "SELECT 4"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_renamed_mv", 4); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_renamed_mv\"", "SELECT 4"); + assertQuery("SELECT * FROM \"__mv_storage__test_renamed_mv\" ORDER BY person_id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300), (4, 'Dave', 400)"); + + assertUpdate("DROP MATERIALIZED VIEW test_renamed_mv"); + assertUpdate("DROP TABLE test_renamed_base"); + } + + @Test + public void testMaterializedViewWithComputedColumns() + { + assertUpdate("CREATE 
TABLE test_computed_base (id BIGINT, quantity BIGINT, unit_price BIGINT)"); + assertUpdate("INSERT INTO test_computed_base VALUES (1, 5, 100), (2, 10, 50), (3, 3, 200)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_computed_mv AS " + + "SELECT id, " + + "quantity, " + + "unit_price, " + + "quantity * unit_price AS total_price, " + + "quantity * 2 AS double_quantity, " + + "'Order_' || CAST(id AS VARCHAR) AS order_label " + + "FROM test_computed_base"); + + assertQuery("SELECT COUNT(*) FROM test_computed_mv", "SELECT 3"); + assertQuery("SELECT id, quantity, unit_price, total_price, double_quantity, order_label FROM test_computed_mv ORDER BY id", + "VALUES (1, 5, 100, 500, 10, 'Order_1'), " + + "(2, 10, 50, 500, 20, 'Order_2'), " + + "(3, 3, 200, 600, 6, 'Order_3')"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_computed_mv", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_computed_mv\"", "SELECT 3"); + assertQuery("SELECT id, quantity, unit_price, total_price, double_quantity, order_label FROM \"__mv_storage__test_computed_mv\" ORDER BY id", + "VALUES (1, 5, 100, 500, 10, 'Order_1'), " + + "(2, 10, 50, 500, 20, 'Order_2'), " + + "(3, 3, 200, 600, 6, 'Order_3')"); + + assertQuery("SELECT * FROM test_computed_mv WHERE total_price > 550 ORDER BY id", + "VALUES (3, 3, 200, 600, 6, 'Order_3')"); + + assertQuery("SELECT id, order_label FROM test_computed_mv WHERE double_quantity >= 10 ORDER BY id", + "VALUES (1, 'Order_1'), (2, 'Order_2')"); + + assertUpdate("INSERT INTO test_computed_base VALUES (4, 8, 75)", 1); + + assertQuery("SELECT COUNT(*) FROM test_computed_mv", "SELECT 4"); + assertQuery("SELECT id, total_price, order_label FROM test_computed_mv WHERE id = 4", + "VALUES (4, 600, 'Order_4')"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_computed_mv", 4); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_computed_mv\"", "SELECT 4"); + assertQuery("SELECT id, quantity, unit_price, total_price, order_label FROM 
\"__mv_storage__test_computed_mv\" WHERE id = 4", + "VALUES (4, 8, 75, 600, 'Order_4')"); + + assertUpdate("DROP MATERIALIZED VIEW test_computed_mv"); + assertUpdate("DROP TABLE test_computed_base"); + } + + @Test + public void testMaterializedViewWithCustomTableProperties() + { + assertUpdate("CREATE TABLE test_custom_props_base (id BIGINT, name VARCHAR, region VARCHAR)"); + assertUpdate("INSERT INTO test_custom_props_base VALUES (1, 'Alice', 'US'), (2, 'Bob', 'EU'), (3, 'Charlie', 'APAC')", 3); + + assertUpdate("CREATE MATERIALIZED VIEW test_custom_props_mv " + + "WITH (" + + " partitioning = ARRAY['region'], " + + " sorted_by = ARRAY['id'], " + + " \"write.format.default\" = 'ORC'" + + ") AS " + + "SELECT id, name, region FROM test_custom_props_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_custom_props_mv", 3); + + assertQuery("SELECT COUNT(*) FROM test_custom_props_mv", "SELECT 3"); + assertQuery("SELECT name FROM test_custom_props_mv WHERE region = 'US'", "VALUES ('Alice')"); + assertQuery("SELECT name FROM test_custom_props_mv WHERE region = 'EU'", "VALUES ('Bob')"); + + String storageTableName = "__mv_storage__test_custom_props_mv"; + assertQuery("SELECT COUNT(*) FROM \"" + storageTableName + "\"", "SELECT 3"); + + assertQuery("SELECT COUNT(*) FROM \"" + storageTableName + "\" WHERE region = 'APAC'", "SELECT 1"); + + assertUpdate("INSERT INTO test_custom_props_base VALUES (4, 'David', 'US')", 1); + assertUpdate("REFRESH MATERIALIZED VIEW test_custom_props_mv", 4); + + assertQuery("SELECT COUNT(*) FROM test_custom_props_mv WHERE region = 'US'", "SELECT 2"); + assertQuery("SELECT name FROM test_custom_props_mv WHERE region = 'US' ORDER BY id", + "VALUES ('Alice'), ('David')"); + + assertUpdate("DROP MATERIALIZED VIEW test_custom_props_mv"); + assertUpdate("DROP TABLE test_custom_props_base"); + } + + @Test + public void testMaterializedViewWithNestedTypes() + { + assertUpdate("CREATE TABLE test_nested_base (" + + "id BIGINT, " + + "tags 
ARRAY(VARCHAR), " + + "properties MAP(VARCHAR, VARCHAR), " + + "address ROW(street VARCHAR, city VARCHAR, zipcode VARCHAR))"); + + assertUpdate("INSERT INTO test_nested_base VALUES " + + "(1, ARRAY['tag1', 'tag2'], MAP(ARRAY['key1', 'key2'], ARRAY['value1', 'value2']), ROW('123 Main St', 'NYC', '10001')), " + + "(2, ARRAY['tag3'], MAP(ARRAY['key3'], ARRAY['value3']), ROW('456 Oak Ave', 'LA', '90001'))", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_nested_mv AS " + + "SELECT id, tags, properties, address FROM test_nested_base"); + + assertQuery("SELECT COUNT(*) FROM test_nested_mv", "SELECT 2"); + assertQuery("SELECT id, cardinality(tags) FROM test_nested_mv ORDER BY id", + "VALUES (1, 2), (2, 1)"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_nested_mv", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_nested_mv\"", "SELECT 2"); + + assertQuery("SELECT id, cardinality(tags), address.city FROM test_nested_mv ORDER BY id", + "VALUES (1, 2, 'NYC'), (2, 1, 'LA')"); + + assertQuery("SELECT id FROM test_nested_mv WHERE element_at(properties, 'key1') = 'value1'", + "VALUES (1)"); + + assertUpdate("INSERT INTO test_nested_base VALUES " + + "(3, ARRAY['tag4', 'tag5', 'tag6'], MAP(ARRAY['key4'], ARRAY['value4']), ROW('789 Elm St', 'Chicago', '60601'))", 1); + + assertQuery("SELECT COUNT(*) FROM test_nested_mv", "SELECT 3"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_nested_mv", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_nested_mv\"", "SELECT 3"); + assertQuery("SELECT id, address.zipcode FROM test_nested_mv WHERE id = 3", + "VALUES (3, '60601')"); + + assertUpdate("DROP MATERIALIZED VIEW test_nested_mv"); + assertUpdate("DROP TABLE test_nested_base"); + } + + @Test + public void testMaterializedViewAfterColumnAdded() + { + assertUpdate("CREATE TABLE test_evolve_add_base (id BIGINT, name VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO test_evolve_add_base VALUES (1, 'Alice', 100), (2, 'Bob', 200)", 2); + + 
assertUpdate("CREATE MATERIALIZED VIEW test_evolve_add_mv AS " + + "SELECT id, name, value FROM test_evolve_add_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_evolve_add_mv", 2); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_evolve_add_mv\"", "SELECT 2"); + assertQuery("SELECT * FROM test_evolve_add_mv ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200)"); + + assertUpdate("ALTER TABLE test_evolve_add_base ADD COLUMN region VARCHAR"); + + assertUpdate("INSERT INTO test_evolve_add_base VALUES (3, 'Charlie', 300, 'US')", 1); + + assertQuery("SELECT COUNT(*) FROM test_evolve_add_mv", "SELECT 3"); + assertQuery("SELECT * FROM test_evolve_add_mv ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_evolve_add_mv\"", "SELECT 2"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_evolve_add_mv", 3); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_evolve_add_mv\"", "SELECT 3"); + assertQuery("SELECT * FROM \"__mv_storage__test_evolve_add_mv\" ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertUpdate("CREATE MATERIALIZED VIEW test_evolve_add_mv2 AS " + + "SELECT id, name, value, region FROM test_evolve_add_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_evolve_add_mv2", 3); + + assertQuery("SELECT * FROM test_evolve_add_mv2 WHERE id = 3", + "VALUES (3, 'Charlie', 300, 'US')"); + assertQuery("SELECT id, region FROM test_evolve_add_mv2 WHERE id IN (1, 2) ORDER BY id", + "VALUES (1, NULL), (2, NULL)"); + + assertUpdate("DROP MATERIALIZED VIEW test_evolve_add_mv"); + assertUpdate("DROP MATERIALIZED VIEW test_evolve_add_mv2"); + assertUpdate("DROP TABLE test_evolve_add_base"); + } + + @Test + public void testMaterializedViewAfterColumnDropped() + { + assertUpdate("CREATE TABLE test_evolve_drop_base (id BIGINT, name VARCHAR, value BIGINT, status VARCHAR)"); + assertUpdate("INSERT INTO 
test_evolve_drop_base VALUES (1, 'Alice', 100, 'active'), (2, 'Bob', 200, 'inactive')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_evolve_drop_mv_all AS " + + "SELECT id, name, value, status FROM test_evolve_drop_base"); + + assertUpdate("CREATE MATERIALIZED VIEW test_evolve_drop_mv_subset AS " + + "SELECT id, name, value FROM test_evolve_drop_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_evolve_drop_mv_all", 2); + assertUpdate("REFRESH MATERIALIZED VIEW test_evolve_drop_mv_subset", 2); + + assertQuery("SELECT * FROM test_evolve_drop_mv_all ORDER BY id", + "VALUES (1, 'Alice', 100, 'active'), (2, 'Bob', 200, 'inactive')"); + assertQuery("SELECT * FROM test_evolve_drop_mv_subset ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200)"); + + assertUpdate("ALTER TABLE test_evolve_drop_base DROP COLUMN status"); + + assertUpdate("INSERT INTO test_evolve_drop_base VALUES (3, 'Charlie', 300)", 1); + + assertQuery("SELECT COUNT(*) FROM test_evolve_drop_mv_subset", "SELECT 3"); + assertQuery("SELECT * FROM test_evolve_drop_mv_subset ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertQueryFails("SELECT * FROM test_evolve_drop_mv_all", + ".*Column 'status' cannot be resolved.*"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_evolve_drop_mv_all\"", "SELECT 2"); + assertQuery("SELECT * FROM \"__mv_storage__test_evolve_drop_mv_all\" ORDER BY id", + "VALUES (1, 'Alice', 100, 'active'), (2, 'Bob', 200, 'inactive')"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_evolve_drop_mv_subset", 3); + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_evolve_drop_mv_subset\"", "SELECT 3"); + + assertUpdate("DROP MATERIALIZED VIEW test_evolve_drop_mv_all"); + assertUpdate("DROP MATERIALIZED VIEW test_evolve_drop_mv_subset"); + assertUpdate("DROP TABLE test_evolve_drop_base"); + } + + @Test + public void testDropNonExistentMaterializedView() + { + assertQueryFails("DROP MATERIALIZED VIEW non_existent_mv", + 
".*does not exist.*"); + } + + @Test + public void testCreateMaterializedViewWithSameNameAsExistingTable() + { + assertUpdate("CREATE TABLE existing_table_name (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO existing_table_name VALUES (1, 'test')", 1); + + assertQueryFails("CREATE MATERIALIZED VIEW existing_table_name AS SELECT id, value FROM existing_table_name", + ".*already exists.*"); + + assertQuery("SELECT COUNT(*) FROM existing_table_name", "SELECT 1"); + assertQuery("SELECT * FROM existing_table_name", "VALUES (1, 'test')"); + + assertUpdate("CREATE TABLE test_mv_base (id BIGINT, name VARCHAR)"); + assertUpdate("INSERT INTO test_mv_base VALUES (2, 'foo')", 1); + + assertQueryFails("CREATE MATERIALIZED VIEW existing_table_name AS SELECT id, name FROM test_mv_base", + ".*already exists.*"); + + assertUpdate("DROP TABLE existing_table_name"); + assertUpdate("DROP TABLE test_mv_base"); + } + + @Test + public void testInformationSchemaMaterializedViews() + { + assertUpdate("CREATE TABLE test_is_mv_base1 (id BIGINT, name VARCHAR, value BIGINT)"); + assertUpdate("CREATE TABLE test_is_mv_base2 (category VARCHAR, amount BIGINT)"); + + assertUpdate("INSERT INTO test_is_mv_base1 VALUES (1, 'Alice', 100), (2, 'Bob', 200)", 2); + assertUpdate("INSERT INTO test_is_mv_base2 VALUES ('A', 50), ('B', 75)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_is_mv1 AS SELECT id, name, value FROM test_is_mv_base1 WHERE id > 0"); + assertUpdate("CREATE MATERIALIZED VIEW test_is_mv2 AS SELECT category, SUM(amount) as total FROM test_is_mv_base2 GROUP BY category"); + + assertQuery( + "SELECT table_name FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name IN ('test_is_mv1', 'test_is_mv2') " + + "ORDER BY table_name", + "VALUES ('test_is_mv1'), ('test_is_mv2')"); + + assertQuery( + "SELECT table_catalog, table_schema, table_name, storage_schema, storage_table_name, base_tables " + + "FROM 
information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv1'", + "SELECT 'iceberg', 'test_schema', 'test_is_mv1', 'test_schema', '__mv_storage__test_is_mv1', 'iceberg.test_schema.test_is_mv_base1'"); + + assertQuery( + "SELECT COUNT(*) FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv1' " + + "AND view_definition IS NOT NULL AND length(view_definition) > 0", + "SELECT 1"); + + assertQuery( + "SELECT table_name FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv2'", + "VALUES ('test_is_mv2')"); + + assertQuery( + "SELECT COUNT(*) FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv1' " + + "AND view_owner IS NOT NULL", + "SELECT 1"); + + assertQuery( + "SELECT COUNT(*) FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv1' " + + "AND view_security IS NOT NULL", + "SELECT 1"); + + assertQuery( + "SELECT base_tables FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv2'", + "VALUES ('iceberg.test_schema.test_is_mv_base2')"); + + assertUpdate("DROP MATERIALIZED VIEW test_is_mv1"); + assertUpdate("DROP MATERIALIZED VIEW test_is_mv2"); + assertUpdate("DROP TABLE test_is_mv_base1"); + assertUpdate("DROP TABLE test_is_mv_base2"); + + assertQuery( + "SELECT COUNT(*) FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name IN ('test_is_mv1', 'test_is_mv2')", + "VALUES 0"); + } + + @Test + public void testInformationSchemaTablesWithMaterializedViews() + { + assertUpdate("CREATE TABLE test_is_tables_base (id BIGINT, name VARCHAR)"); + assertUpdate("CREATE VIEW test_is_tables_view AS SELECT id, name FROM test_is_tables_base"); + assertUpdate("CREATE MATERIALIZED VIEW 
test_is_tables_mv AS SELECT id, name FROM test_is_tables_base"); + + assertQuery( + "SELECT table_name, table_type FROM information_schema.tables " + + "WHERE table_schema = 'test_schema' AND table_name IN ('test_is_tables_base', 'test_is_tables_view', 'test_is_tables_mv') " + + "ORDER BY table_name", + "VALUES ('test_is_tables_base', 'BASE TABLE'), ('test_is_tables_mv', 'MATERIALIZED VIEW'), ('test_is_tables_view', 'VIEW')"); + + assertQuery( + "SELECT table_name FROM information_schema.views " + + "WHERE table_schema = 'test_schema' AND table_name IN ('test_is_tables_view', 'test_is_tables_mv') " + + "ORDER BY table_name", + "VALUES ('test_is_tables_view')"); + + assertUpdate("DROP MATERIALIZED VIEW test_is_tables_mv"); + assertUpdate("DROP VIEW test_is_tables_view"); + assertUpdate("DROP TABLE test_is_tables_base"); + } + + @Test + public void testInformationSchemaMaterializedViewsAfterRefresh() + { + assertUpdate("CREATE TABLE test_is_mv_refresh_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_is_mv_refresh_base VALUES (1, 100), (2, 200)", 2); + assertUpdate("CREATE MATERIALIZED VIEW test_is_mv_refresh AS SELECT id, value FROM test_is_mv_refresh_base"); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv_refresh'", + "SELECT 'NOT_MATERIALIZED'"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_is_mv_refresh", 2); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv_refresh'", + "SELECT 'FULLY_MATERIALIZED'"); + + assertUpdate("INSERT INTO test_is_mv_refresh_base VALUES (3, 300)", 1); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv_refresh'", + "SELECT 'PARTIALLY_MATERIALIZED'"); + + assertUpdate("UPDATE test_is_mv_refresh_base SET 
value = 250 WHERE id = 2", 1); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv_refresh'", + "SELECT 'PARTIALLY_MATERIALIZED'"); + + assertUpdate("DELETE FROM test_is_mv_refresh_base WHERE id = 1", 1); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv_refresh'", + "SELECT 'PARTIALLY_MATERIALIZED'"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_is_mv_refresh", 2); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv_refresh'", + "SELECT 'FULLY_MATERIALIZED'"); + + assertUpdate("DROP MATERIALIZED VIEW test_is_mv_refresh"); + assertUpdate("DROP TABLE test_is_mv_refresh_base"); + + assertQuery( + "SELECT COUNT(*) FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_is_mv_refresh'", + "VALUES 0"); + } + + @Test + public void testStaleReadBehaviorFail() + { + assertUpdate("CREATE TABLE test_stale_fail_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_stale_fail_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_stale_fail " + + "WITH (stale_read_behavior = 'FAIL', staleness_window = '0s') " + + "AS SELECT id, value FROM test_stale_fail_base"); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_stale_fail'", + "SELECT 'NOT_MATERIALIZED'"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_stale_fail", 2); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_stale_fail'", + "SELECT 'FULLY_MATERIALIZED'"); + + assertQuery("SELECT COUNT(*) FROM 
test_stale_fail", "SELECT 2"); + assertQuery("SELECT * FROM test_stale_fail ORDER BY id", "VALUES (1, 100), (2, 200)"); + + assertUpdate("INSERT INTO test_stale_fail_base VALUES (3, 300)", 1); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_stale_fail'", + "SELECT 'PARTIALLY_MATERIALIZED'"); + + assertQueryFails("SELECT * FROM test_stale_fail", + ".*Materialized view .* is stale.*"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_stale_fail", 3); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_stale_fail'", + "SELECT 'FULLY_MATERIALIZED'"); + + assertQuery("SELECT COUNT(*) FROM test_stale_fail", "SELECT 3"); + + assertUpdate("DROP MATERIALIZED VIEW test_stale_fail"); + assertUpdate("DROP TABLE test_stale_fail_base"); + } + + @Test + public void testStaleReadBehaviorUseViewQuery() + { + assertUpdate("CREATE TABLE test_stale_use_query_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_stale_use_query_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_stale_use_query " + + "WITH (stale_read_behavior = 'USE_VIEW_QUERY', staleness_window = '0s') " + + "AS SELECT id, value FROM test_stale_use_query_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_stale_use_query", 2); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_stale_use_query'", + "SELECT 'FULLY_MATERIALIZED'"); + + assertQuery("SELECT COUNT(*) FROM test_stale_use_query", "SELECT 2"); + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_stale_use_query\"", "SELECT 2"); + + assertUpdate("INSERT INTO test_stale_use_query_base VALUES (3, 300)", 1); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE 
table_schema = 'test_schema' AND table_name = 'test_stale_use_query'", + "SELECT 'PARTIALLY_MATERIALIZED'"); + + assertQuery("SELECT COUNT(*) FROM test_stale_use_query", "SELECT 3"); + assertQuery("SELECT * FROM test_stale_use_query ORDER BY id", + "VALUES (1, 100), (2, 200), (3, 300)"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_stale_use_query\"", "SELECT 2"); + + assertUpdate("DROP MATERIALIZED VIEW test_stale_use_query"); + assertUpdate("DROP TABLE test_stale_use_query_base"); + } + + @Test + public void testMaterializedViewWithNoStaleReadBehavior() + { + assertUpdate("CREATE TABLE test_no_stale_config_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_no_stale_config_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW test_no_stale_config AS SELECT id, value FROM test_no_stale_config_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_no_stale_config", 2); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_no_stale_config'", + "SELECT 'FULLY_MATERIALIZED'"); + + assertQuery("SELECT COUNT(*) FROM test_no_stale_config", "SELECT 2"); + + assertUpdate("INSERT INTO test_no_stale_config_base VALUES (3, 300)", 1); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_no_stale_config'", + "SELECT 'PARTIALLY_MATERIALIZED'"); + + assertQuery("SELECT COUNT(*) FROM test_no_stale_config", "SELECT 3"); + + assertUpdate("DROP MATERIALIZED VIEW test_no_stale_config"); + assertUpdate("DROP TABLE test_no_stale_config_base"); + } + + @Test + public void testStalenessWindowAllowsStaleReads() + { + assertUpdate("CREATE TABLE test_staleness_window_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO test_staleness_window_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW 
test_staleness_window_mv " + + "WITH (stale_read_behavior = 'FAIL', staleness_window = '1h') " + + "AS SELECT id, value FROM test_staleness_window_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW test_staleness_window_mv", 2); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_staleness_window_mv'", + "SELECT 'FULLY_MATERIALIZED'"); + + assertQuery("SELECT COUNT(*) FROM test_staleness_window_mv", "SELECT 2"); + assertQuery("SELECT * FROM test_staleness_window_mv ORDER BY id", "VALUES (1, 100), (2, 200)"); + + assertUpdate("INSERT INTO test_staleness_window_base VALUES (3, 300)", 1); + + assertQuery( + "SELECT freshness_state FROM information_schema.materialized_views " + + "WHERE table_schema = 'test_schema' AND table_name = 'test_staleness_window_mv'", + "SELECT 'PARTIALLY_MATERIALIZED'"); + + assertQuery("SELECT COUNT(*) FROM test_staleness_window_mv", "SELECT 2"); + + assertQuery("SELECT COUNT(*) FROM \"__mv_storage__test_staleness_window_mv\"", "SELECT 2"); + + assertUpdate("DROP MATERIALIZED VIEW test_staleness_window_mv"); + assertUpdate("DROP TABLE test_staleness_window_base"); + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMetadataListing.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMetadataListing.java index 6f522c12f7aad..cc6504bab715f 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMetadataListing.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergMetadataListing.java @@ -119,7 +119,7 @@ public void testTableColumnListing() @Test public void testTableDescribing() { - assertQuery("DESCRIBE iceberg.test_schema.iceberg_table1", "VALUES ('_string', 'varchar', '', ''), ('_integer', 'integer', '', '')"); + assertQuery("DESCRIBE iceberg.test_schema.iceberg_table1", "VALUES ('_string', 'varchar', '', '', null, null, 2147483647L), 
('_integer', 'integer', '', '', 10l, null, null)"); } /* @@ -154,28 +154,4 @@ public void testTableValidation() assertQuerySucceeds("SELECT * FROM iceberg.test_schema.iceberg_table1"); assertQueryFails("SELECT * FROM iceberg.test_schema.hive_table", "Not an Iceberg table: test_schema.hive_table"); } - - @Test - public void testRenameView() - { - assertQuerySucceeds("CREATE SCHEMA iceberg.test_rename_view_schema"); - assertQuerySucceeds("CREATE TABLE iceberg.test_rename_view_schema.iceberg_test_table (_string VARCHAR, _integer INTEGER)"); - assertUpdate("CREATE VIEW iceberg.test_rename_view_schema.test_view_to_be_renamed AS SELECT * FROM iceberg.test_rename_view_schema.iceberg_test_table"); - assertUpdate("ALTER VIEW IF EXISTS iceberg.test_rename_view_schema.test_view_to_be_renamed RENAME TO iceberg.test_rename_view_schema.test_view_renamed"); - assertUpdate("CREATE VIEW iceberg.test_rename_view_schema.test_view2_to_be_renamed AS SELECT * FROM iceberg.test_rename_view_schema.iceberg_test_table"); - assertUpdate("ALTER VIEW iceberg.test_rename_view_schema.test_view2_to_be_renamed RENAME TO iceberg.test_rename_view_schema.test_view2_renamed"); - assertQuerySucceeds("SELECT * FROM iceberg.test_rename_view_schema.test_view_renamed"); - assertQuerySucceeds("SELECT * FROM iceberg.test_rename_view_schema.test_view2_renamed"); - assertUpdate("DROP VIEW iceberg.test_rename_view_schema.test_view_renamed"); - assertUpdate("DROP VIEW iceberg.test_rename_view_schema.test_view2_renamed"); - assertUpdate("DROP TABLE iceberg.test_rename_view_schema.iceberg_test_table"); - assertQuerySucceeds("DROP SCHEMA IF EXISTS iceberg.test_rename_view_schema"); - } - - @Test - public void testRenameViewIfNotExists() - { - assertQueryFails("ALTER VIEW iceberg.test_schema.test_rename_view_not_exist RENAME TO iceberg.test_schema.test_renamed_view_not_exist", "line 1:1: View 'iceberg.test_schema.test_rename_view_not_exist' does not exist"); - assertQuerySucceeds("ALTER VIEW IF EXISTS 
iceberg.test_schema.test_rename_view_not_exist RENAME TO iceberg.test_schema.test_renamed_view_not_exist"); - } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergScalarFunctions.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergScalarFunctions.java new file mode 100644 index 0000000000000..c6253afcb9638 --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergScalarFunctions.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.iceberg.function.IcebergBucketFunction; +import com.facebook.presto.metadata.FunctionExtractor; +import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; +import com.facebook.presto.type.DateOperators; +import com.facebook.presto.type.TimestampOperators; +import com.facebook.presto.type.TimestampWithTimeZoneOperators; +import org.testcontainers.shaded.com.google.common.collect.ImmutableList; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.math.BigDecimal; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.Decimals.encodeScaledValue; +import static com.facebook.presto.iceberg.function.IcebergBucketFunction.Bucket.bucketLongDecimal; +import static com.facebook.presto.iceberg.function.IcebergBucketFunction.Bucket.bucketShortDecimal; +import static com.facebook.presto.iceberg.function.IcebergBucketFunction.bucketDate; +import static com.facebook.presto.iceberg.function.IcebergBucketFunction.bucketInteger; +import static com.facebook.presto.iceberg.function.IcebergBucketFunction.bucketTimestamp; +import static com.facebook.presto.iceberg.function.IcebergBucketFunction.bucketTimestampWithTimeZone; +import static com.facebook.presto.iceberg.function.IcebergBucketFunction.bucketVarbinary; +import static com.facebook.presto.iceberg.function.IcebergBucketFunction.bucketVarchar; +import static io.airlift.slice.Slices.utf8Slice; + +public class TestIcebergScalarFunctions + extends AbstractTestFunctions +{ + public TestIcebergScalarFunctions() + { + super(TEST_SESSION, new FeaturesConfig(), new FunctionsConfig(), false); + } + + @BeforeClass + public void 
registerFunction() + { + ImmutableList.Builder> functions = ImmutableList.builder(); + functions.add(IcebergBucketFunction.class) + .add(IcebergBucketFunction.Bucket.class); + functionAssertions.addConnectorFunctions(FunctionExtractor.extractFunctions(functions.build(), + new CatalogSchemaName("iceberg", "system")), "iceberg"); + } + + @Test + public void testBucketFunction() + { + String catalogSchema = "iceberg.system"; + functionAssertions.assertFunction(catalogSchema + ".bucket(cast(10 as tinyint), 3)", BIGINT, bucketInteger(10, 3)); + functionAssertions.assertFunction(catalogSchema + ".bucket(cast(1950 as smallint), 4)", BIGINT, bucketInteger(1950, 4)); + functionAssertions.assertFunction(catalogSchema + ".bucket(cast(2375645 as int), 5)", BIGINT, bucketInteger(2375645, 5)); + functionAssertions.assertFunction(catalogSchema + ".bucket(cast(2779099983928392323 as bigint), 6)", BIGINT, bucketInteger(2779099983928392323L, 6)); + functionAssertions.assertFunction(catalogSchema + ".bucket(cast(456.43 as DECIMAL(5,2)), 12)", BIGINT, bucketShortDecimal(5, 2, 45643, 12)); + functionAssertions.assertFunction(catalogSchema + ".bucket(cast('12345678901234567890.1234567890' as DECIMAL(30,10)), 12)", BIGINT, bucketLongDecimal(30, 10, encodeScaledValue(new BigDecimal("12345678901234567890.1234567890")), 12)); + + functionAssertions.assertFunction(catalogSchema + ".bucket(cast('nasdbsdnsdms' as varchar), 7)", BIGINT, bucketVarchar(utf8Slice("nasdbsdnsdms"), 7)); + functionAssertions.assertFunction(catalogSchema + ".bucket(cast('nasdbsdnsdms' as varbinary), 8)", BIGINT, bucketVarbinary(utf8Slice("nasdbsdnsdms"), 8)); + + functionAssertions.assertFunction(catalogSchema + ".bucket(cast('2018-04-06' as date), 9)", BIGINT, bucketDate(DateOperators.castFromSlice(utf8Slice("2018-04-06")), 9)); + functionAssertions.assertFunction(catalogSchema + ".bucket(CAST('2018-04-06 04:35:00.000' AS TIMESTAMP),10)", BIGINT, 
bucketTimestamp(TimestampOperators.castFromSlice(TEST_SESSION.getSqlFunctionProperties(), utf8Slice("2018-04-06 04:35:00.000")), 10)); + functionAssertions.assertFunction(catalogSchema + ".bucket(CAST('2018-04-06 04:35:00.000 GMT' AS TIMESTAMP WITH TIME ZONE), 11)", BIGINT, bucketTimestampWithTimeZone(TimestampWithTimeZoneOperators.castFromSlice(TEST_SESSION.getSqlFunctionProperties(), utf8Slice("2018-04-06 04:35:00.000 GMT")), 11)); + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergSystemTables.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergSystemTables.java index 2147d3894c3b8..ad73de00ae762 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergSystemTables.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergSystemTables.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.iceberg; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.spi.security.AllowAllAccessControl; @@ -23,7 +24,6 @@ import com.facebook.presto.tests.DistributedQueryRunner; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -112,11 +112,11 @@ public void testPartitionTable() { assertQuery("SELECT count(*) FROM test_schema.test_table", "VALUES 6"); assertQuery("SHOW COLUMNS FROM test_schema.\"test_table$partitions\"", - "VALUES ('_date', 'date', '', '')," + - "('row_count', 'bigint', '', '')," + - "('file_count', 'bigint', '', '')," + - "('total_size', 'bigint', '', '')," + - "('_bigint', 'row(\"min\" bigint, \"max\" bigint, \"null_count\" bigint)', '', '')"); + "VALUES ('_date', 'date', '', '', null, null, null)," + + "('row_count', 'bigint', '', '', 
19L, null, null)," + + "('file_count', 'bigint', '', '', 19L, null, null)," + + "('total_size', 'bigint', '', '', 19L, null, null)," + + "('_bigint', 'row(\"min\" bigint, \"max\" bigint, \"null_count\" bigint)', '', '', null, null, null)"); MaterializedResult result = computeActual("SELECT * from test_schema.\"test_table$partitions\""); assertEquals(result.getRowCount(), 3); @@ -145,10 +145,10 @@ public void testPartitionTable() public void testHistoryTable() { assertQuery("SHOW COLUMNS FROM test_schema.\"test_table$history\"", - "VALUES ('made_current_at', 'timestamp with time zone', '', '')," + - "('snapshot_id', 'bigint', '', '')," + - "('parent_id', 'bigint', '', '')," + - "('is_current_ancestor', 'boolean', '', '')"); + "VALUES ('made_current_at', 'timestamp with time zone', '', '', null, null, null)," + + "('snapshot_id', 'bigint', '', '', 19l, null, null)," + + "('parent_id', 'bigint', '', '', 19l, null, null)," + + "('is_current_ancestor', 'boolean', '', '',null , null, null)"); // Test the number of history entries assertQuery("SELECT count(*) FROM test_schema.\"test_table$history\"", "VALUES 2"); @@ -158,12 +158,12 @@ public void testHistoryTable() public void testSnapshotsTable() { assertQuery("SHOW COLUMNS FROM test_schema.\"test_table$snapshots\"", - "VALUES ('committed_at', 'timestamp with time zone', '', '')," + - "('snapshot_id', 'bigint', '', '')," + - "('parent_id', 'bigint', '', '')," + - "('operation', 'varchar', '', '')," + - "('manifest_list', 'varchar', '', '')," + - "('summary', 'map(varchar, varchar)', '', '')"); + "VALUES ('committed_at', 'timestamp with time zone', '', '', null, null, null)," + + "('snapshot_id', 'bigint', '', '', 19L, null, null)," + + "('parent_id', 'bigint', '', '', 19L, null, null)," + + "('operation', 'varchar', '', '', null, null, 2147483647L)," + + "('manifest_list', 'varchar', '', '', null, null, 2147483647L)," + + "('summary', 'map(varchar, varchar)', '', '', null, null, null)"); assertQuery("SELECT operation 
FROM test_schema.\"test_table$snapshots\"", "VALUES 'append', 'append'"); assertQuery("SELECT summary['total-records'] FROM test_schema.\"test_table$snapshots\"", "VALUES '3', '6'"); @@ -173,14 +173,14 @@ public void testSnapshotsTable() public void testManifestsTable() { assertQuery("SHOW COLUMNS FROM test_schema.\"test_table$manifests\"", - "VALUES ('path', 'varchar', '', '')," + - "('length', 'bigint', '', '')," + - "('partition_spec_id', 'integer', '', '')," + - "('added_snapshot_id', 'bigint', '', '')," + - "('added_data_files_count', 'integer', '', '')," + - "('existing_data_files_count', 'integer', '', '')," + - "('deleted_data_files_count', 'integer', '', '')," + - "('partitions', 'array(row(\"contains_null\" boolean, \"lower_bound\" varchar, \"upper_bound\" varchar))', '', '')"); + "VALUES ('path', 'varchar', '', '', null, null, 2147483647L)," + + "('length', 'bigint', '', '', 19L, null, null)," + + "('partition_spec_id', 'integer', '', '', 10L, null, null)," + + "('added_snapshot_id', 'bigint', '', '', 19L, null, null)," + + "('added_data_files_count', 'integer', '', '', 10L, null, null)," + + "('existing_data_files_count', 'integer', '', '', 10L, null, null)," + + "('deleted_data_files_count', 'integer', '', '', 10L, null, null)," + + "('partitions', 'array(row(\"contains_null\" boolean, \"lower_bound\" varchar, \"upper_bound\" varchar))', '', '', null, null, null)"); assertQuerySucceeds("SELECT * FROM test_schema.\"test_table$manifests\""); assertQuerySucceeds("SELECT * FROM test_schema.\"test_table_multilevel_partitions$manifests\""); @@ -190,20 +190,20 @@ public void testManifestsTable() public void testFilesTable() { assertQuery("SHOW COLUMNS FROM test_schema.\"test_table$files\"", - "VALUES ('content', 'integer', '', '')," + - "('file_path', 'varchar', '', '')," + - "('file_format', 'varchar', '', '')," + - "('record_count', 'bigint', '', '')," + - "('file_size_in_bytes', 'bigint', '', '')," + - "('column_sizes', 'map(integer, bigint)', '', '')," + 
- "('value_counts', 'map(integer, bigint)', '', '')," + - "('null_value_counts', 'map(integer, bigint)', '', '')," + - "('nan_value_counts', 'map(integer, bigint)', '', '')," + - "('lower_bounds', 'map(integer, varchar)', '', '')," + - "('upper_bounds', 'map(integer, varchar)', '', '')," + - "('key_metadata', 'varbinary', '', '')," + - "('split_offsets', 'array(bigint)', '', '')," + - "('equality_ids', 'array(integer)', '', '')"); + "VALUES ('content', 'integer', '', '', 10L, null, null)," + + "('file_path', 'varchar', '', '', null, null, 2147483647L)," + + "('file_format', 'varchar', '', '', null, null, 2147483647L)," + + "('record_count', 'bigint', '', '', 19L, null, null)," + + "('file_size_in_bytes', 'bigint', '', '', 19L, null, null)," + + "('column_sizes', 'map(integer, bigint)', '', '', null, null, null)," + + "('value_counts', 'map(integer, bigint)', '', '', null, null, null)," + + "('null_value_counts', 'map(integer, bigint)', '', '', null, null, null)," + + "('nan_value_counts', 'map(integer, bigint)', '', '', null, null, null)," + + "('lower_bounds', 'map(integer, varchar)', '', '', null, null, null)," + + "('upper_bounds', 'map(integer, varchar)', '', '', null, null, null)," + + "('key_metadata', 'varbinary', '', '', null, null, null)," + + "('split_offsets', 'array(bigint)', '', '', null, null, null)," + + "('equality_ids', 'array(integer)', '', '', null, null, null)"); assertQuerySucceeds("SELECT * FROM test_schema.\"test_table$files\""); } @@ -211,12 +211,12 @@ public void testFilesTable() public void testRefsTable() { assertQuery("SHOW COLUMNS FROM test_schema.\"test_table$refs\"", - "VALUES ('name', 'varchar', '', '')," + - "('type', 'varchar', '', '')," + - "('snapshot_id', 'bigint', '', '')," + - "('max_reference_age_in_ms', 'bigint', '', '')," + - "('min_snapshots_to_keep', 'bigint', '', '')," + - "('max_snapshot_age_in_ms', 'bigint', '', '')"); + "VALUES ('name', 'varchar', '', '', null, null, 2147483647L)," + + "('type', 'varchar', '', '', 
null, null, 2147483647L)," + + "('snapshot_id', 'bigint', '', '', 19L, null, null)," + + "('max_reference_age_in_ms', 'bigint', '', '', 19L, null, null)," + + "('min_snapshots_to_keep', 'bigint', '', '', 19L, null, null)," + + "('max_snapshot_age_in_ms', 'bigint', '', '', 19L, null, null)"); assertQuerySucceeds("SELECT * FROM test_schema.\"test_table$refs\""); // Check main branch entry @@ -226,6 +226,20 @@ public void testRefsTable() assertQuerySucceeds("SELECT * FROM test_schema.\"test_table_multilevel_partitions$refs\""); } + @Test + public void testMetadataLogTable() + { + assertQuery("SHOW COLUMNS FROM test_schema.\"test_table$metadata_log_entries\"", + "VALUES ('timestamp', 'timestamp with time zone', '', '', null, null, null)," + + "('file', 'varchar', '', '', null, null, 2147483647)," + + "('latest_snapshot_id', 'bigint', '', '', 19, null, null)," + + "('latest_schema_id', 'integer', '', '', 10, null, null)," + + "('latest_sequence_number', 'bigint', '', '', 19, null, null)"); + assertQuerySucceeds("SELECT * FROM test_schema.\"test_table$metadata_log_entries\""); + + assertQuerySucceeds("SELECT * FROM test_schema.\"test_table_multilevel_partitions$metadata_log_entries\""); + } + @Test public void testSessionPropertiesInManuallyStartedTransaction() { @@ -235,7 +249,7 @@ public void testSessionPropertiesInManuallyStartedTransaction() MaterializedResult materializedRows = getQueryRunner().execute("select * from test_schema.\"test_session_properties_table$properties\""); assertThat(materializedRows) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", "merge-on-read"))); + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", "merge-on-read", "true"))); // Simulate `set session iceberg.merge_on_read_enabled=false` to disable merge on read mode for iceberg tables in session level Session session = Session.builder(getQueryRunner().getDefaultSession()) @@ 
-268,8 +282,12 @@ protected void checkTableProperties(String tableName, String deleteMode) protected void checkTableProperties(String schemaName, String tableName, String deleteMode, int propertiesCount, Map additionalValidateProperties) { - assertQuery(String.format("SHOW COLUMNS FROM %s.\"%s$properties\"", schemaName, tableName), - "VALUES ('key', 'varchar', '', '')," + "('value', 'varchar', '', '')"); + assertQuery( + String.format("SHOW COLUMNS FROM %s.\"%s$properties\"", schemaName, tableName), + "VALUES " + + "('key', 'varchar', '', '', null, null, 2147483647)," + + "('value', 'varchar', '', '', null, null, 2147483647)," + + "('is_supported_by_presto', 'varchar', '', '', null, null, 2147483647)"); assertQuery(String.format("SELECT COUNT(*) FROM %s.\"%s$properties\"", schemaName, tableName), "VALUES " + propertiesCount); List materializedRows = computeActual(getSession(), String.format("SELECT * FROM %s.\"%s$properties\"", schemaName, tableName)).getMaterializedRows(); @@ -277,34 +295,35 @@ protected void checkTableProperties(String schemaName, String tableName, String assertThat(materializedRows).hasSize(propertiesCount); assertThat(materializedRows) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", deleteMode))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", deleteMode, "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.update.mode", deleteMode))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.update.mode", deleteMode, "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.format.default", "PARQUET"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.format.default", "PARQUET", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new 
MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.parquet.compression-codec", "GZIP"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.parquet.compression-codec", "ZSTD", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "commit.retry.num-retries", "4"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "commit.retry.num-retries", "4", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "100"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "100", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "false"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "false", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "100"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "100", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, IcebergTableProperties.TARGET_SPLIT_SIZE, Long.toString(DataSize.valueOf("128MB").toBytes())))); + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, IcebergTableProperties.TARGET_SPLIT_SIZE, Long.toString(DataSize.valueOf("128MB").toBytes()), "true"))); additionalValidateProperties.entrySet().stream() .forEach(entry -> assertThat(materializedRows) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, entry.getKey(), entry.getValue())))); + 
.isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, entry.getKey(), entry.getValue(), "true")))); } protected void checkORCFormatTableProperties(String tableName, String deleteMode) { assertQuery(String.format("SHOW COLUMNS FROM test_schema.\"%s$properties\"", tableName), - "VALUES ('key', 'varchar', '', '')," + "('value', 'varchar', '', '')"); + "VALUES ('key', 'varchar', '', '', null, null, 2147483647L)," + "('value', 'varchar', '', '', null, null, 2147483647L)," + + "('is_supported_by_presto', 'varchar', '', '', null, null, 2147483647L)"); assertQuery(String.format("SELECT COUNT(*) FROM test_schema.\"%s$properties\"", tableName), "VALUES 10"); List materializedRows = computeActual(getSession(), String.format("SELECT * FROM test_schema.\"%s$properties\"", tableName)).getMaterializedRows(); @@ -312,25 +331,25 @@ protected void checkORCFormatTableProperties(String tableName, String deleteMode assertThat(materializedRows).hasSize(10); assertThat(materializedRows) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", deleteMode))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", deleteMode, "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.update.mode", deleteMode))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.update.mode", deleteMode, "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.format.default", "ORC"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.format.default", "ORC", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.orc.compression-codec", "ZLIB"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, 
"write.orc.compression-codec", "ZSTD", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.parquet.compression-codec", "zstd"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.parquet.compression-codec", "zstd", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "commit.retry.num-retries", "4"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "commit.retry.num-retries", "4", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "100"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "100", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "false"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "false", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "100"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "100", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, IcebergTableProperties.TARGET_SPLIT_SIZE, Long.toString(DataSize.valueOf("128MB").toBytes())))); + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, IcebergTableProperties.TARGET_SPLIT_SIZE, Long.toString(DataSize.valueOf("128MB").toBytes()), "true"))); } @Test @@ -366,9 +385,9 @@ public void testMetadataVersionsMaintainingProperties() MaterializedResult materializedRows = 
getQueryRunner().execute("select * from test_schema.\"test_metadata_versions_maintain$properties\""); assertThat(materializedRows) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "1"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "1", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "true"))); + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "true", "true"))); } @Test @@ -377,7 +396,7 @@ public void testMetricsMaxInferredColumnProperties() MaterializedResult materializedRows = getQueryRunner().execute("select * from test_schema.\"test_metrics_max_inferred_column$properties\""); assertThat(materializedRows) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "16"))); + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "16", "true"))); } @AfterClass(alwaysRun = true) diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergTableChangelog.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergTableChangelog.java index 4c45b70080488..b6eb6a7c362ec 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergTableChangelog.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergTableChangelog.java @@ -65,10 +65,10 @@ public void testSchema() { assertQuery(String.format("SHOW COLUMNS FROM \"ctas_orders@%d$changelog\"", snapshots[0]), "VALUES" + - "('operation', 'varchar', '', '')," + - "('ordinal', 'bigint', '', '')," + - "('snapshotid', 'bigint', '', '')," + - 
"('rowdata', 'row(\"orderkey\" bigint, \"custkey\" bigint, \"orderstatus\" varchar, \"totalprice\" double, \"orderdate\" date, \"orderpriority\" varchar, \"clerk\" varchar, \"shippriority\" integer, \"comment\" varchar)', '', '')"); + "('operation', 'varchar', '', '', null ,null, 2147483647)," + + "('ordinal', 'bigint', '', '', 19, null, null)," + + "('snapshotid', 'bigint', '', '', 19, null, null)," + + "('rowdata', 'row(\"orderkey\" bigint, \"custkey\" bigint, \"orderstatus\" varchar, \"totalprice\" double, \"orderdate\" date, \"orderpriority\" varchar, \"clerk\" varchar, \"shippriority\" integer, \"comment\" varchar)', '', '', null, null, null)"); } @Test @@ -308,4 +308,28 @@ private long getSnapshot(int idx, String tableName) .mapToLong(Long.class::cast) .skip(idx).findFirst().getAsLong(); } + + @Test + public void testApplyChangelogFunctionInSystemNamespace() + { + assertQuery( + "SELECT iceberg.system.apply_changelog(1, 'INSERT', 'test_value') IS NOT NULL", + "SELECT true"); + } + + @Test + public void testApplyChangelogFunctionNotInGlobalNamespace() + { + assertQueryFails( + "SELECT apply_changelog(1, 'INSERT', 'test_value')", + "line 1:8: Function apply_changelog not registered"); + } + + @Test + public void testApplyChangelogFunctionNotInPrestoDefaultNamespace() + { + assertQueryFails( + "SELECT presto.default.apply_changelog(1, 'INSERT', 'test_value')", + "line 1:8: Function apply_changelog not registered"); + } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergTableVersion.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergTableVersion.java index c8f497db7272b..d0e1c8b5e1ffd 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergTableVersion.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergTableVersion.java @@ -349,15 +349,15 @@ public void testTableVersionErrors() assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF id", 
".* cannot be resolved"); assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF (SELECT CURRENT_TIMESTAMP)", ".* Constant expression cannot contain a subquery"); assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF NULL", "Table version AS OF/BEFORE expression cannot be NULL for .*"); - assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF TIMESTAMP " + "'" + tab2Timestamp1 + "' - INTERVAL '1' MONTH", "No history found based on timestamp for table \"test_tt_schema\".\"test_table_version_tab2\""); - assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF CAST ('2023-01-01' AS TIMESTAMP WITH TIME ZONE)", "No history found based on timestamp for table \"test_tt_schema\".\"test_table_version_tab2\""); - assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF CAST ('2023-01-01' AS TIMESTAMP)", "No history found based on timestamp for table \"test_tt_schema\".\"test_table_version_tab2\""); + assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF TIMESTAMP " + "'" + tab2Timestamp1 + "' - INTERVAL '1' MONTH", "No history found based on timestamp for table iceberg.test_tt_schema.test_table_version_tab2"); + assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF CAST ('2023-01-01' AS TIMESTAMP WITH TIME ZONE)", "No history found based on timestamp for table iceberg.test_tt_schema.test_table_version_tab2"); + assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF CAST ('2023-01-01' AS TIMESTAMP)", "No history found based on timestamp for table iceberg.test_tt_schema.test_table_version_tab2"); assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF CAST ('2023-01-01' AS DATE)", ".* Type date is invalid. 
Supported table version AS OF/BEFORE expression type is Timestamp or Timestamp with Time Zone."); assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF CURRENT_DATE", ".* Type date is invalid. Supported table version AS OF/BEFORE expression type is Timestamp or Timestamp with Time Zone."); - assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF TIMESTAMP '2023-01-01 00:00:00.000'", "No history found based on timestamp for table \"test_tt_schema\".\"test_table_version_tab2\""); + assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP AS OF TIMESTAMP '2023-01-01 00:00:00.000'", "No history found based on timestamp for table iceberg.test_tt_schema.test_table_version_tab2"); - assertQueryFails("SELECT desc FROM " + tableName1 + " FOR VERSION BEFORE " + tab1VersionId1 + " ORDER BY 1", "No history found based on timestamp for table \"test_tt_schema\".\"test_table_version_tab1\""); - assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP BEFORE TIMESTAMP " + "'" + tab2Timestamp1 + "' - INTERVAL '1' MONTH", "No history found based on timestamp for table \"test_tt_schema\".\"test_table_version_tab2\""); + assertQueryFails("SELECT desc FROM " + tableName1 + " FOR VERSION BEFORE " + tab1VersionId1 + " ORDER BY 1", "No history found based on timestamp for table iceberg.test_tt_schema.test_table_version_tab1"); + assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP BEFORE TIMESTAMP " + "'" + tab2Timestamp1 + "' - INTERVAL '1' MONTH", "No history found based on timestamp for table iceberg.test_tt_schema.test_table_version_tab2"); assertQueryFails("SELECT desc FROM " + tableName2 + " FOR VERSION BEFORE 100", ".* Type integer is invalid. 
Supported table version AS OF/BEFORE expression type is BIGINT or VARCHAR"); assertQueryFails("SELECT desc FROM " + tableName2 + " FOR VERSION BEFORE " + tab2VersionId1 + " - " + tab2VersionId1, "Iceberg snapshot ID does not exists: 0"); assertQueryFails("SELECT desc FROM " + tableName2 + " FOR TIMESTAMP BEFORE 'bad'", ".* Type varchar\\(3\\) is invalid. Supported table version AS OF/BEFORE expression type is Timestamp or Timestamp with Time Zone."); diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergUtil.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergUtil.java index 47c06c962a9e4..132ca9661c20a 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergUtil.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestIcebergUtil.java @@ -14,9 +14,17 @@ package com.facebook.presto.iceberg; import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.hive.HiveCompressionCodec; +import com.facebook.presto.hive.HiveStorageFormat; +import com.facebook.presto.hive.HiveType; +import com.facebook.presto.hive.metastore.Column; +import com.google.common.collect.ImmutableList; +import org.apache.iceberg.types.Types; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.math.BigDecimal; +import java.util.List; import java.util.Optional; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -377,4 +385,62 @@ public void testGetTargetSplitSize() assertEquals(1024, getTargetSplitSize(1024, 512).toBytes()); assertEquals(512, getTargetSplitSize(0, 512).toBytes()); } + + @DataProvider + public Object[][] compressionCodecMatrix() + { + return new Object[][] { + // format, codec, expectedSupport + {HiveStorageFormat.PARQUET, HiveCompressionCodec.NONE, true}, + {HiveStorageFormat.PARQUET, HiveCompressionCodec.SNAPPY, true}, + {HiveStorageFormat.PARQUET, HiveCompressionCodec.GZIP, true}, + {HiveStorageFormat.PARQUET, 
HiveCompressionCodec.LZ4, false}, + {HiveStorageFormat.PARQUET, HiveCompressionCodec.ZSTD, true}, + {HiveStorageFormat.ORC, HiveCompressionCodec.NONE, true}, + {HiveStorageFormat.ORC, HiveCompressionCodec.SNAPPY, true}, + {HiveStorageFormat.ORC, HiveCompressionCodec.GZIP, true}, + {HiveStorageFormat.ORC, HiveCompressionCodec.ZSTD, true}, + {HiveStorageFormat.ORC, HiveCompressionCodec.LZ4, true}, + }; + } + + @Test(dataProvider = "compressionCodecMatrix") + public void testCompressionCodecSupport(HiveStorageFormat format, HiveCompressionCodec codec, boolean expectedSupport) + { + assertThat(codec.isSupportedStorageFormat(format)) + .as("Codec %s support for %s format", codec, format) + .isEqualTo(expectedSupport); + } + + @Test + public void testParquetCompressionCodecAvailability() + { + assertThat(HiveCompressionCodec.NONE.getParquetCompressionCodec()).isNotNull(); + assertThat(HiveCompressionCodec.SNAPPY.getParquetCompressionCodec()).isNotNull(); + assertThat(HiveCompressionCodec.GZIP.getParquetCompressionCodec()).isNotNull(); + + assertThat(HiveCompressionCodec.LZ4.getParquetCompressionCodec()).isNotNull(); + assertThat(HiveCompressionCodec.ZSTD.getParquetCompressionCodec()).isNotNull(); + } + + @Test + public void testToHiveColumnsWithTimeType() + { + List icebergColumns = ImmutableList.of( + Types.NestedField.required(1, "id", Types.LongType.get()), + Types.NestedField.optional(2, "time", Types.TimeType.get()), + Types.NestedField.optional(3, "name", Types.StringType.get())); + + List hiveColumns = IcebergUtil.toHiveColumns(icebergColumns); + assertThat(hiveColumns).hasSize(3); + + assertThat(hiveColumns.get(0).getName()).isEqualTo("id"); + assertThat(hiveColumns.get(0).getType()).isEqualTo(HiveType.HIVE_LONG); + + assertThat(hiveColumns.get(1).getName()).isEqualTo("time"); + assertThat(hiveColumns.get(1).getType()).isEqualTo(HiveType.HIVE_LONG); + + assertThat(hiveColumns.get(2).getName()).isEqualTo("name"); + 
assertThat(hiveColumns.get(2).getType()).isEqualTo(HiveType.HIVE_STRING); + } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestOutputColumnTypes.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestOutputColumnTypes.java index 8f247835e7418..89551f5d388a7 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestOutputColumnTypes.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestOutputColumnTypes.java @@ -15,10 +15,12 @@ package com.facebook.presto.iceberg; import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.SourceColumn; import com.facebook.presto.spi.Plugin; -import com.facebook.presto.spi.eventlistener.Column; import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.spi.eventlistener.EventListenerFactory; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.facebook.presto.spi.eventlistener.QueryCompletedEvent; import com.facebook.presto.spi.eventlistener.QueryCreatedEvent; import com.facebook.presto.spi.eventlistener.SplitCompletedEvent; @@ -26,6 +28,7 @@ import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import org.intellij.lang.annotations.Language; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; @@ -108,34 +111,214 @@ public void testOutputColumnsForInsertAsSelect() assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) .containsExactly( - new Column("clerk", "varchar"), - new Column("orderkey", "bigint"), - new Column("totalprice", "double")); + new OutputColumnMetadata("clerk", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("orderkey", "bigint", ImmutableSet.of( 
+ new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"))), + new OutputColumnMetadata("totalprice", "double", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "totalprice")))); + } + + @Test + public void testOutputColumnsForInsertAsSelectAllWithAliasedRelation() + throws Exception + { + runQueryAndWaitForEvents("CREATE TABLE create_insert_output1 AS SELECT clerk AS test_clerk, orderkey AS test_orderkey, totalprice AS test_totalprice FROM orders", 2); + runQueryAndWaitForEvents("INSERT INTO create_insert_output1(test_clerk,test_orderkey,test_totalprice) SELECT clerk AS test_clerk, orderkey AS test_orderkey, totalprice AS test_totalprice FROM (SELECT * from orders) orders_a", 2); + QueryCompletedEvent event = getOnlyElement(generatedEvents.getQueryCompletedEvents()); + + assertThat(event.getIoMetadata().getOutput().get().getCatalogName()).isEqualTo("iceberg"); + assertThat(event.getIoMetadata().getOutput().get().getSchema()).isEqualTo("tpch"); + assertThat(event.getIoMetadata().getOutput().get().getTable()).isEqualTo("create_insert_output1"); + assertThat(event.getMetadata().getUpdateQueryType().get()).isEqualTo("INSERT"); + + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly( + new OutputColumnMetadata("test_clerk", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("test_orderkey", "bigint", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"))), + new OutputColumnMetadata("test_totalprice", "double", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "totalprice")))); + } + + @Test + public void testOutputColumnsForInsertAsSelectColumnAliasInAliasedRelation() + throws Exception + { + runQueryAndWaitForEvents("CREATE TABLE create_insert_output2 AS SELECT clerk AS test_clerk, 
orderkey AS test_orderkey, totalprice AS test_totalprice FROM orders", 2); + runQueryAndWaitForEvents("INSERT INTO create_insert_output2(test_clerk,test_orderkey,test_totalprice) SELECT aliased_clerk AS test_clerk, aliased_orderkey AS test_orderkey, aliased_totalprice AS test_totalprice FROM (SELECT clerk, orderkey, totalprice from orders) orders_a(aliased_clerk, aliased_orderkey, aliased_totalprice)", 2); + QueryCompletedEvent event = getOnlyElement(generatedEvents.getQueryCompletedEvents()); + + assertThat(event.getIoMetadata().getOutput().get().getCatalogName()).isEqualTo("iceberg"); + assertThat(event.getIoMetadata().getOutput().get().getSchema()).isEqualTo("tpch"); + assertThat(event.getIoMetadata().getOutput().get().getTable()).isEqualTo("create_insert_output2"); + assertThat(event.getMetadata().getUpdateQueryType().get()).isEqualTo("INSERT"); + + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly( + new OutputColumnMetadata("test_clerk", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("test_orderkey", "bigint", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"))), + new OutputColumnMetadata("test_totalprice", "double", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "totalprice")))); } @Test public void testOutputColumnsForCreateTableAS() throws Exception { - runQueryAndWaitForEvents("CREATE TABLE create_update_table AS SELECT * FROM orders ", 2); + runQueryAndWaitForEvents("CREATE TABLE create_update_table2 AS SELECT * FROM orders ", 2); + QueryCompletedEvent event = getOnlyElement(generatedEvents.getQueryCompletedEvents()); + + assertThat(event.getIoMetadata().getOutput().get().getCatalogName()).isEqualTo("iceberg"); + assertThat(event.getIoMetadata().getOutput().get().getSchema()).isEqualTo("tpch"); + 
assertThat(event.getIoMetadata().getOutput().get().getTable()).isEqualTo("create_update_table2"); + assertThat(event.getMetadata().getUpdateQueryType().get()).isEqualTo("CREATE TABLE"); + + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly( + new OutputColumnMetadata("orderkey", "bigint", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"))), + new OutputColumnMetadata("custkey", "bigint", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "custkey"))), + new OutputColumnMetadata("orderstatus", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderstatus"))), + new OutputColumnMetadata("totalprice", "double", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "totalprice"))), + new OutputColumnMetadata("orderdate", "date", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderdate"))), + new OutputColumnMetadata("orderpriority", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderpriority"))), + new OutputColumnMetadata("clerk", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("shippriority", "integer", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "shippriority"))), + new OutputColumnMetadata("comment", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "comment")))); + } + + @Test + public void testOutputColumnsForCreateTableAsSelectWithColumns() + throws Exception + { + runQueryAndWaitForEvents("CREATE TABLE create_update_table3 AS SELECT clerk, orderkey, totalprice FROM orders", 2); + QueryCompletedEvent event = 
getOnlyElement(generatedEvents.getQueryCompletedEvents()); + + assertThat(event.getIoMetadata().getOutput().get().getCatalogName()).isEqualTo("iceberg"); + assertThat(event.getIoMetadata().getOutput().get().getSchema()).isEqualTo("tpch"); + assertThat(event.getIoMetadata().getOutput().get().getTable()).isEqualTo("create_update_table3"); + assertThat(event.getMetadata().getUpdateQueryType().get()).isEqualTo("CREATE TABLE"); + + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly( + new OutputColumnMetadata("clerk", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("orderkey", "bigint", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"))), + new OutputColumnMetadata("totalprice", "double", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "totalprice")))); + } + + @Test + public void testOutputColumnsForCreateTableAsSelectWithAlias() + throws Exception + { + runQueryAndWaitForEvents("CREATE TABLE create_update_table4 AS SELECT clerk AS clerk_name, orderkey, totalprice AS annual_totalprice FROM orders", 2); + QueryCompletedEvent event = getOnlyElement(generatedEvents.getQueryCompletedEvents()); + + assertThat(event.getIoMetadata().getOutput().get().getCatalogName()).isEqualTo("iceberg"); + assertThat(event.getIoMetadata().getOutput().get().getSchema()).isEqualTo("tpch"); + assertThat(event.getIoMetadata().getOutput().get().getTable()).isEqualTo("create_update_table4"); + assertThat(event.getMetadata().getUpdateQueryType().get()).isEqualTo("CREATE TABLE"); + + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly( + new OutputColumnMetadata("clerk_name", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("orderkey", "bigint", 
ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"))), + new OutputColumnMetadata("annual_totalprice", "double", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "totalprice")))); + } + + @Test + public void testOutputColumnsForCreateTableAsSelectAllWithAliasedRelation() + throws Exception + { + runQueryAndWaitForEvents("CREATE TABLE table_alias1 AS SELECT clerk AS test_clerk, orderkey AS test_orderkey, totalprice AS test_totalprice FROM (SELECT * from orders) orders_a", 2); + QueryCompletedEvent event = getOnlyElement(generatedEvents.getQueryCompletedEvents()); + + assertThat(event.getIoMetadata().getOutput().get().getCatalogName()).isEqualTo("iceberg"); + assertThat(event.getIoMetadata().getOutput().get().getSchema()).isEqualTo("tpch"); + assertThat(event.getIoMetadata().getOutput().get().getTable()).isEqualTo("table_alias1"); + assertThat(event.getMetadata().getUpdateQueryType().get()).isEqualTo("CREATE TABLE"); + + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly( + new OutputColumnMetadata("test_clerk", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("test_orderkey", "bigint", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"))), + new OutputColumnMetadata("test_totalprice", "double", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "totalprice")))); + } + + @Test + public void testOutputColumnsForCreateTableAsSelectColumnAliasInAliasedRelation() + throws Exception + { + runQueryAndWaitForEvents("CREATE TABLE table_alias2 AS SELECT aliased_clerk AS test_clerk, aliased_orderkey AS test_orderkey, aliased_totalprice AS test_totalprice FROM (SELECT clerk,orderkey,totalprice from orders) orders_a(aliased_clerk,aliased_orderkey,aliased_totalprice)", 2); 
QueryCompletedEvent event = getOnlyElement(generatedEvents.getQueryCompletedEvents()); assertThat(event.getIoMetadata().getOutput().get().getCatalogName()).isEqualTo("iceberg"); assertThat(event.getIoMetadata().getOutput().get().getSchema()).isEqualTo("tpch"); - assertThat(event.getIoMetadata().getOutput().get().getTable()).isEqualTo("create_update_table"); + assertThat(event.getIoMetadata().getOutput().get().getTable()).isEqualTo("table_alias2"); assertThat(event.getMetadata().getUpdateQueryType().get()).isEqualTo("CREATE TABLE"); assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) .containsExactly( - new Column("orderkey", "bigint"), - new Column("custkey", "bigint"), - new Column("orderstatus", "varchar"), - new Column("totalprice", "double"), - new Column("orderdate", "date"), - new Column("orderpriority", "varchar"), - new Column("clerk", "varchar"), - new Column("shippriority", "integer"), - new Column("comment", "varchar")); + new OutputColumnMetadata("test_clerk", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("test_orderkey", "bigint", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"))), + new OutputColumnMetadata("test_totalprice", "double", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "totalprice")))); + } + + @Test + public void testOutputColumnsForSetOperationUnion() + throws Exception + { + runQueryAndWaitForEvents("CREATE TABLE table_alias3 AS SELECT orderpriority AS test_orderpriority, orderkey AS test_orderkey FROM orders UNION SELECT clerk, custkey FROM orders", 2); + QueryCompletedEvent event = getOnlyElement(generatedEvents.getQueryCompletedEvents()); + + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly( + new OutputColumnMetadata("test_orderpriority", "varchar", ImmutableSet.of( + new 
SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderpriority"), + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("test_orderkey", "bigint", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"), + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "custkey")))); + } + + @Test + public void testOutputColumnsForSetOperationUnionAll() + throws Exception + { + runQueryAndWaitForEvents("CREATE TABLE table_alias4 AS SELECT orderpriority AS test_orderpriority, orderkey AS test_orderkey FROM orders UNION ALL SELECT clerk, custkey FROM orders", 2); + QueryCompletedEvent event = getOnlyElement(generatedEvents.getQueryCompletedEvents()); + + assertThat(event.getIoMetadata().getOutput().get().getColumns().get()) + .containsExactly( + new OutputColumnMetadata("test_orderpriority", "varchar", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderpriority"), + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "clerk"))), + new OutputColumnMetadata("test_orderkey", "bigint", ImmutableSet.of( + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "orderkey"), + new SourceColumn(new QualifiedObjectName("iceberg", "tpch", "orders"), "custkey")))); } static class TestingEventListenerPlugin diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionData.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionData.java new file mode 100644 index 0000000000000..4befbedfb625a --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionData.java @@ -0,0 +1,203 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import org.apache.iceberg.types.Types; +import org.testng.annotations.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + +public class TestPartitionData +{ + private static final JsonNodeFactory JSON = JsonNodeFactory.instance; + + @Test + public void testGetValueWithNull() + { + JsonNode nullNode = JSON.nullNode(); + assertNull(PartitionData.getValue(nullNode, Types.IntegerType.get())); + assertNull(PartitionData.getValue(nullNode, Types.LongType.get())); + assertNull(PartitionData.getValue(nullNode, Types.StringType.get())); + assertNull(PartitionData.getValue(nullNode, Types.DecimalType.of(10, 2))); + } + + @Test + public void testGetValueWithDecimalFromLong() + { + Types.DecimalType decimalType = Types.DecimalType.of(10, 2); + JsonNode longNode = JSON.numberNode(12345L); + + BigDecimal result = (BigDecimal) PartitionData.getValue(longNode, decimalType); + assertEquals(result, new BigDecimal("123.45")); + assertEquals(result.scale(), 2); + assertEquals(result.unscaledValue(), BigInteger.valueOf(12345L)); + } + + @Test + public void testGetValueWithDecimalFromInt() + { + Types.DecimalType decimalType = Types.DecimalType.of(5, 2); + JsonNode intNode = JSON.numberNode(999); + + BigDecimal result = (BigDecimal) PartitionData.getValue(intNode, decimalType); + assertEquals(result, new 
BigDecimal("9.99")); + assertEquals(result.scale(), 2); + assertEquals(result.unscaledValue(), BigInteger.valueOf(999)); + } + + @Test + public void testGetValueWithDecimalFromBigInteger() + { + Types.DecimalType decimalType = Types.DecimalType.of(20, 3); + BigInteger bigInt = new BigInteger("123456789012345"); + JsonNode bigIntNode = JSON.numberNode(bigInt); + + BigDecimal result = (BigDecimal) PartitionData.getValue(bigIntNode, decimalType); + assertEquals(result.scale(), 3); + assertEquals(result.unscaledValue(), bigInt); + assertEquals(result, new BigDecimal(bigInt, 3)); + } + + @Test + public void testGetValueWithDecimalFromDecimal() + { + Types.DecimalType decimalType = Types.DecimalType.of(10, 4); + JsonNode decimalNode = JSON.numberNode(new BigDecimal("123.456")); + + BigDecimal result = (BigDecimal) PartitionData.getValue(decimalNode, decimalType); + assertEquals(result, new BigDecimal("123.4560")); + assertEquals(result.scale(), 4); + } + + @Test + public void testGetValueWithDecimalZeroScale() + { + Types.DecimalType decimalType = Types.DecimalType.of(10, 0); + JsonNode longNode = JSON.numberNode(12345L); + + BigDecimal result = (BigDecimal) PartitionData.getValue(longNode, decimalType); + + assertEquals(result, new BigDecimal("12345")); + assertEquals(result.scale(), 0); + } + + @Test + public void testGetValueWithDecimalLargeScale() + { + Types.DecimalType decimalType = Types.DecimalType.of(15, 10); + JsonNode intNode = JSON.numberNode(123); + + BigDecimal result = (BigDecimal) PartitionData.getValue(intNode, decimalType); + assertEquals(result.scale(), 10); + assertEquals(result.unscaledValue(), BigInteger.valueOf(123)); + } + + @Test + public void testGetValueWithDecimalNegativeValue() + { + Types.DecimalType decimalType = Types.DecimalType.of(10, 2); + JsonNode longNode = JSON.numberNode(-12345L); + + BigDecimal result = (BigDecimal) PartitionData.getValue(longNode, decimalType); + + assertEquals(result, new BigDecimal("-123.45")); + 
assertEquals(result.scale(), 2); + } + + @Test + public void testGetValueWithDecimalVeryLargeNumber() + { + Types.DecimalType decimalType = Types.DecimalType.of(38, 5); + BigInteger veryLarge = new BigInteger("12345678901234567890123456789012"); + JsonNode bigIntNode = JSON.numberNode(veryLarge); + + BigDecimal result = (BigDecimal) PartitionData.getValue(bigIntNode, decimalType); + + assertEquals(result.scale(), 5); + assertEquals(result.unscaledValue(), veryLarge); + } + + @Test + public void testJsonRoundTripWithDecimals() + { + // This tests all new code paths: isLong(), isInt(), isBigInteger(), and fallback + org.apache.iceberg.types.Type[] types = new org.apache.iceberg.types.Type[] { + Types.DecimalType.of(15, 2), // Will deserialize from long (12345L in JSON) + Types.DecimalType.of(10, 3), // Will deserialize from long (9876543210L in JSON) + Types.DecimalType.of(5, 2), // Will deserialize from int (999 in JSON) + Types.DecimalType.of(20, 5), // Will deserialize from decimal + Types.DecimalType.of(38, 10) // Will deserialize from BigInteger + }; + + Object[] values = new Object[] { + 12345L, + 9876543210L, + 999, + new BigDecimal("123456.78901"), + new BigDecimal("1234567890123456789012345678.0123456789") + }; + + PartitionData original = new PartitionData(values); + String json = original.toJson(); + + PartitionData deserialized = PartitionData.fromJson(json, types); + assertEquals(deserialized.get(0, BigDecimal.class), new BigDecimal("123.45")); + assertEquals(deserialized.get(0, BigDecimal.class).scale(), 2); + assertEquals(deserialized.get(1, BigDecimal.class), new BigDecimal("9876543.210")); + assertEquals(deserialized.get(1, BigDecimal.class).scale(), 3); + assertEquals(deserialized.get(2, BigDecimal.class), new BigDecimal("9.99")); + assertEquals(deserialized.get(2, BigDecimal.class).scale(), 2); + assertEquals(deserialized.get(3, BigDecimal.class).compareTo(new BigDecimal("123456.78901")), 0); + assertEquals(deserialized.get(3, 
BigDecimal.class).scale(), 5); + assertEquals(deserialized.get(4, BigDecimal.class).compareTo(new BigDecimal("1234567890123456789012345678.0123456789")), 0); + assertEquals(deserialized.get(4, BigDecimal.class).scale(), 10); + } + + @Test + public void testJsonRoundTripWithMixedTypes() + { + org.apache.iceberg.types.Type[] types = new org.apache.iceberg.types.Type[] { + Types.IntegerType.get(), + Types.LongType.get(), + Types.DecimalType.of(10, 2), + Types.StringType.get(), + Types.DecimalType.of(5, 3) + }; + + Object[] values = new Object[] { + 42, + 9876543210L, + new BigDecimal("999.99"), + "test_partition", + new BigDecimal("12.345") + }; + + PartitionData original = new PartitionData(values); + String json = original.toJson(); + + PartitionData deserialized = PartitionData.fromJson(json, types); + + assertEquals(deserialized.get(0, Integer.class), Integer.valueOf(42)); + assertEquals(deserialized.get(1, Long.class), Long.valueOf(9876543210L)); + assertEquals(deserialized.get(2, BigDecimal.class).compareTo(new BigDecimal("999.99")), 0); + assertEquals(deserialized.get(3, String.class), "test_partition"); + assertEquals(deserialized.get(4, BigDecimal.class).compareTo(new BigDecimal("12.345")), 0); + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionFields.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionFields.java index 1f3a859e2a342..29f9705d40408 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionFields.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionFields.java @@ -26,10 +26,8 @@ import java.util.function.Consumer; -import static com.facebook.presto.iceberg.PartitionFields.parsePartitionField; -import static com.facebook.presto.iceberg.PartitionFields.toPartitionFields; +import static com.facebook.presto.iceberg.PartitionFields.buildPartitionField; import static com.facebook.presto.testing.assertions.Assert.assertEquals; -import 
static com.google.common.collect.Iterables.getOnlyElement; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; public class TestPartitionFields @@ -60,7 +58,6 @@ private static void assertParse(String value, PartitionSpec expected) { assertEquals(expected.fields().size(), 1); assertEquals(parseField(value), expected); - assertEquals(getOnlyElement(toPartitionFields(expected)), value); } private static void assertInvalid(String value, String message) @@ -75,7 +72,7 @@ private static void assertInvalid(String value, String message) private static PartitionSpec parseField(String value) { - return partitionSpec(builder -> parsePartitionField(builder, value)); + return partitionSpec(builder -> buildPartitionField(builder, value)); } private static PartitionSpec partitionSpec(Consumer consumer) diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionSpecConverter.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionSpecConverter.java index 41b3556956820..938819231b4f4 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionSpecConverter.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestPartitionSpecConverter.java @@ -20,13 +20,13 @@ import java.util.ArrayList; import java.util.List; +import java.util.OptionalInt; import static com.facebook.presto.iceberg.PartitionSpecConverter.toIcebergPartitionSpec; import static com.facebook.presto.iceberg.PartitionSpecConverter.toPrestoPartitionSpec; import static com.facebook.presto.iceberg.TestSchemaConverter.prestoIcebergSchema; import static com.facebook.presto.iceberg.TestSchemaConverter.schema; import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; -import static java.lang.String.format; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; @@ -46,26 +46,6 @@ public static 
Object[][] testAllTransforms() }; } - @Test(dataProvider = "allTransforms") - public void testToPrestoPartitionSpec(String transform, String name) - { - // Create a test TypeManager - TypeManager typeManager = createTestFunctionAndTypeManager(); - - // Create a mock PartitionSpec - PartitionSpec partitionSpec = partitionSpec(transform, name); - - PrestoIcebergPartitionSpec expectedPrestoPartitionSpec = prestoIcebergPartitionSpec(transform, name, typeManager); - - // Convert Iceberg PartitionSpec to Presto Iceberg Partition Spec - PrestoIcebergPartitionSpec prestoIcebergPartitionSpec = toPrestoPartitionSpec(partitionSpec, typeManager); - - // Check that the result is not null - assertNotNull(prestoIcebergPartitionSpec); - - assertEquals(prestoIcebergPartitionSpec, expectedPrestoPartitionSpec); - } - @Test(dataProvider = "allTransforms") public void testToIcebergPartitionSpec(String transform, String name) { @@ -110,23 +90,35 @@ public void validateConversion(String transform, String name) private static PrestoIcebergPartitionSpec prestoIcebergPartitionSpec(String transform, String name, TypeManager typeManager) { - List fields = new ArrayList<>(); - + List fields = new ArrayList<>(); + IcebergPartitionField.Builder builder = IcebergPartitionField.builder(); + builder.setName(name); switch (transform) { case "identity": - fields.add(name); + builder.setTransform(PartitionTransformType.IDENTITY); break; case "year": + builder.setTransform(PartitionTransformType.YEAR); + break; case "month": + builder.setTransform(PartitionTransformType.MONTH); + break; case "day": - fields.add(format("%s(%s)", transform, name)); + builder.setTransform(PartitionTransformType.DAY); + break; + case "hour": + builder.setTransform(PartitionTransformType.HOUR); break; case "bucket": + builder.setTransform(PartitionTransformType.BUCKET) + .setParameter(OptionalInt.of(3)); + break; case "truncate": - fields.add(format("%s(%s, 3)", transform, name)); + 
builder.setTransform(PartitionTransformType.TRUNCATE) + .setParameter(OptionalInt.of(3)); break; } - + fields.add(builder.build()); return new PrestoIcebergPartitionSpec(0, prestoIcebergSchema(typeManager), fields); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestRewriteDataFilesProcedure.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestRewriteDataFilesProcedure.java new file mode 100644 index 0000000000000..fb79f69618fca --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestRewriteDataFilesProcedure.java @@ -0,0 +1,508 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg; + +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.CatalogUtil; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.hadoop.HadoopCatalog; +import org.apache.iceberg.io.CloseableIterator; +import org.testng.annotations.Test; + +import java.io.File; +import java.nio.file.Path; +import java.util.Map; +import java.util.OptionalInt; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.facebook.presto.iceberg.CatalogType.HADOOP; +import static com.facebook.presto.iceberg.FileFormat.PARQUET; +import static com.facebook.presto.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; +import static com.facebook.presto.iceberg.IcebergQueryRunner.getIcebergDataDirectoryPath; +import static java.lang.String.format; +import static org.apache.iceberg.SnapshotSummary.TOTAL_DATA_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.TOTAL_DELETE_FILES_PROP; +import static org.apache.iceberg.expressions.Expressions.alwaysTrue; +import static org.testng.Assert.assertEquals; + +public class TestRewriteDataFilesProcedure + extends AbstractTestQueryFramework +{ + public static final String TEST_SCHEMA = "tpch"; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return IcebergQueryRunner.builder() + .setCatalogType(HADOOP) + .setFormat(PARQUET) + .setNodeCount(OptionalInt.of(1)) + .setCreateTpchTables(false) + .setAddJmxPlugin(false) + .build().getQueryRunner(); + } + + public void dropTable(String tableName) + { + assertQuerySucceeds("DROP TABLE IF EXISTS " + tableName); + } + + @Test + public void testRewriteDataFilesInEmptyTable() + { + 
String tableName = "default_empty_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (id integer, value integer)"); + assertUpdate(format("CALL system.rewrite_data_files('%s', '%s')", TEST_SCHEMA, tableName), 0); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesOnPartitionTable() + { + String tableName = "example_partition_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar) with (partitioning = ARRAY['c2'])"); + + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 10,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 10); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 10, 0); + + assertUpdate("DELETE from " + tableName + " WHERE c1 = 7", 1); + assertUpdate("DELETE from " + tableName + " WHERE c1 in (8, 10)", 2); + + table.refresh(); + assertHasSize(table.snapshots(), 7); + //The number of data files is 10,and the number of delete files is 3 + assertHasDataFiles(table.currentSnapshot(), 10); + assertHasDeleteFiles(table.currentSnapshot(), 3); + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(9, 'foo')"); + + assertUpdate(format("CALL 
system.rewrite_data_files(table_name => '%s', schema => '%s')", tableName, TEST_SCHEMA), 7); + + table.refresh(); + assertHasSize(table.snapshots(), 8); + //The number of data files is 2,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 2); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .filter(alwaysTrue()) + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 2, 0); + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(9, 'foo')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesOnNonPartitionTable() + { + String tableName = "example_non_partition_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar)"); + + // create 5 files + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 5,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 5, 0); + + assertUpdate("DELETE from " + tableName + " WHERE c1 = 7", 1); + assertUpdate("DELETE from " + tableName + " WHERE c1 in (9, 10)", 2); + + table.refresh(); + assertHasSize(table.snapshots(), 7); + //The number of data files is 
5,and the number of delete files is 2 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 2); + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(8, 'bar')"); + + assertUpdate(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s')", tableName, TEST_SCHEMA), 7); + + table.refresh(); + assertHasSize(table.snapshots(), 8); + //The number of data files is 1,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 1); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .filter(alwaysTrue()) + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 1, 0); + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(8, 'bar')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesWithFilter() + { + String tableName = "example_partition_filter_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar) with (partitioning = ARRAY['c2'])"); + + // create 5 files for each partition (c2 = 'foo' and c2 = 'bar') + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 10,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 10); + assertHasDeleteFiles(table.currentSnapshot(), 
0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 10, 0); + + // do not support rewrite files filtered by non-identity columns + assertQueryFails(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c1 > 3')", tableName, TEST_SCHEMA), ".*"); + + // select 5 files to rewrite + assertUpdate(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c2 = ''bar''')", tableName, TEST_SCHEMA), 5); + table.refresh(); + assertHasSize(table.snapshots(), 6); + //The number of data files is 6,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 6); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 6, 0); + + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(7, 'foo'), (8, 'bar'), " + + "(9, 'foo'), (10, 'bar')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesWithDeterministicTrueFilter() + { + String tableName = "example_non_partition_true_filter_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar)"); + + // create 5 files + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 5,and the number of 
delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 5, 0); + + // do not support rewrite files filtered by non-identity columns + assertQueryFails(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c1 > 3')", tableName, TEST_SCHEMA), ".*"); + + // the filter is `true` means select all files to rewrite + assertUpdate(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => '1 = 1')", tableName, TEST_SCHEMA), 10); + + table.refresh(); + assertHasSize(table.snapshots(), 6); + //The number of data files is 1,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 1); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 1, 0); + + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(7, 'foo'), (8, 'bar'), " + + "(9, 'foo'), (10, 'bar')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesWithDeterministicFalseFilter() + { + String tableName = "example_non_partition_false_filter_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (c1 integer, c2 varchar)"); + + // create 5 files + assertUpdate("INSERT INTO " + tableName + " values(1, 'foo'), (2, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(3, 'foo'), (4, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(5, 'foo'), (6, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(7, 'foo'), (8, 'bar')", 2); + assertUpdate("INSERT INTO " + tableName + " values(9, 'foo'), (10, 
'bar')", 2); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 5); + //The number of data files is 5,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 5, 0); + + // the filter is `false` means select no file to rewrite + assertUpdate(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => '1 = 0')", tableName, TEST_SCHEMA), 0); + + table.refresh(); + assertHasSize(table.snapshots(), 5); + //The number of data files is still 5,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 5); + assertHasDeleteFiles(table.currentSnapshot(), 0); + fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 5, 0); + + assertQuery("select * from " + tableName, + "values(1, 'foo'), (2, 'bar'), " + + "(3, 'foo'), (4, 'bar'), " + + "(5, 'foo'), (6, 'bar'), " + + "(7, 'foo'), (8, 'bar'), " + + "(9, 'foo'), (10, 'bar')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteDataFilesWithDeleteAndPartitionEvolution() + { + String tableName = "example_partition_evolution_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (a int, b varchar)"); + assertUpdate("INSERT INTO " + tableName + " values(1, '1001'), (2, '1002')", 2); + assertUpdate("DELETE FROM " + tableName + " WHERE a = 1", 1); + assertQuery("select * from " + tableName, "values(2, '1002')"); + + Table table = loadTable(tableName); + assertHasSize(table.snapshots(), 2); + //The number of data files is 1,and the number of delete files is 1 + assertHasDataFiles(table.currentSnapshot(), 1); + assertHasDeleteFiles(table.currentSnapshot(), 1); + + assertUpdate("alter table " + 
tableName + " add column c int with (partitioning = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " values(5, '1005', 5), (6, '1006', 6), (7, '1007', 7)", 3); + assertUpdate("DELETE FROM " + tableName + " WHERE b = '1006'", 1); + assertQuery("select * from " + tableName, "values(2, '1002', NULL), (5, '1005', 5), (7, '1007', 7)"); + + table.refresh(); + assertHasSize(table.snapshots(), 4); + //The number of data files is 4,and the number of delete files is 2 + assertHasDataFiles(table.currentSnapshot(), 4); + assertHasDeleteFiles(table.currentSnapshot(), 2); + + assertQueryFails(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'a > 3')", tableName, TEST_SCHEMA), ".*"); + assertQueryFails(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c > 3')", tableName, TEST_SCHEMA), ".*"); + + assertUpdate(format("call system.rewrite_data_files(table_name => '%s', schema => '%s')", tableName, TEST_SCHEMA), 3); + table.refresh(); + assertHasSize(table.snapshots(), 5); + //The number of data files is 3,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 3); + assertHasDeleteFiles(table.currentSnapshot(), 0); + CloseableIterator fileScanTasks = table.newScan() + .useSnapshot(table.currentSnapshot().snapshotId()) + .planFiles().iterator(); + assertFilesPlan(fileScanTasks, 3, 0); + assertQuery("select * from " + tableName, "values(2, '1002', NULL), (5, '1005', 5), (7, '1007', 7)"); + + assertUpdate("delete from " + tableName + " where b = '1002'", 1); + table.refresh(); + assertHasSize(table.snapshots(), 6); + //The number of data files is 3,and the number of delete files is 1 + assertHasDataFiles(table.currentSnapshot(), 3); + assertHasDeleteFiles(table.currentSnapshot(), 1); + assertUpdate(format("call system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'c is null')", tableName, TEST_SCHEMA), 0); + + table.refresh(); + 
assertHasSize(table.snapshots(), 7); + //The number of data files is 2,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 2); + assertHasDeleteFiles(table.currentSnapshot(), 0); + assertQuery("select * from " + tableName, "values(5, '1005', 5), (7, '1007', 7)"); + + // This is a metadata delete + assertUpdate("delete from " + tableName + " where c = 7", 1); + table.refresh(); + assertHasSize(table.snapshots(), 8); + //The number of data files is 1,and the number of delete files is 0 + assertHasDataFiles(table.currentSnapshot(), 1); + assertHasDeleteFiles(table.currentSnapshot(), 0); + assertQuery("select * from " + tableName, "values(5, '1005', 5)"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testInvalidParameterCases() + { + String tableName = "invalid_parameter_table"; + try { + assertUpdate("CREATE TABLE " + tableName + " (a int, b varchar, c int)"); + assertQueryFails("CALL system.rewrite_data_files('n', table_name => 't')", ".*Named and positional arguments cannot be mixed"); + assertQueryFails("CALL custom.rewrite_data_files('n', 't')", "Procedure not registered: custom.rewrite_data_files"); + assertQueryFails("CALL system.rewrite_data_files()", ".*Required procedure argument 'schema' is missing"); + assertQueryFails("CALL system.rewrite_data_files('s', 'n')", "Schema s does not exist"); + assertQueryFails("CALL system.rewrite_data_files('', '')", "Table name is empty"); + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => '''hello''')", tableName, TEST_SCHEMA), ".*WHERE clause must evaluate to a boolean: actual type varchar\\(5\\)"); + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => '1001')", tableName, TEST_SCHEMA), ".*WHERE clause must evaluate to a boolean: actual type integer"); + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'a')", 
tableName, TEST_SCHEMA), ".*WHERE clause must evaluate to a boolean: actual type integer"); + assertQueryFails(format("CALL system.rewrite_data_files(table_name => '%s', schema => '%s', filter => 'n')", tableName, TEST_SCHEMA), ".*Column 'n' cannot be resolved"); + } + finally { + dropTable(tableName); + } + } + + private Table loadTable(String tableName) + { + Catalog catalog = CatalogUtil.loadCatalog(HadoopCatalog.class.getName(), ICEBERG_CATALOG, getProperties(), new Configuration()); + return catalog.loadTable(TableIdentifier.of(TEST_SCHEMA, tableName)); + } + + private Map getProperties() + { + File metastoreDir = getCatalogDirectory(); + return ImmutableMap.of("warehouse", metastoreDir.toString()); + } + + private File getCatalogDirectory() + { + Path dataDirectory = getDistributedQueryRunner().getCoordinator().getDataDirectory(); + Path catalogDirectory = getIcebergDataDirectoryPath(dataDirectory, HADOOP.name(), new IcebergConfig().getFileFormat(), false); + return catalogDirectory.toFile(); + } + + private void assertHasSize(Iterable iterable, int size) + { + AtomicInteger count = new AtomicInteger(0); + iterable.forEach(obj -> count.incrementAndGet()); + assertEquals(count.get(), size); + } + + private void assertHasDataFiles(Snapshot snapshot, int dataFilesCount) + { + Map map = snapshot.summary(); + int totalDataFiles = Integer.valueOf(map.get(TOTAL_DATA_FILES_PROP)); + assertEquals(totalDataFiles, dataFilesCount); + } + + private void assertHasDeleteFiles(Snapshot snapshot, int deleteFilesCount) + { + Map map = snapshot.summary(); + int totalDeleteFiles = Integer.valueOf(map.get(TOTAL_DELETE_FILES_PROP)); + assertEquals(totalDeleteFiles, deleteFilesCount); + } + + private void assertFilesPlan(CloseableIterator iterator, int dataFileCount, int deleteFileCount) + { + AtomicInteger dataCount = new AtomicInteger(0); + AtomicInteger deleteCount = new AtomicInteger(0); + while (iterator.hasNext()) { + FileScanTask fileScanTask = iterator.next(); + 
dataCount.incrementAndGet(); + deleteCount.addAndGet(fileScanTask.deletes().size()); + } + assertEquals(dataCount.get(), dataFileCount); + assertEquals(deleteCount.get(), deleteFileCount); + + try { + iterator.close(); + iterator = CloseableIterator.empty(); + } + catch (Exception e) { + // do nothing + } + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestStatisticsUtil.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestStatisticsUtil.java index 6d894ddb934c2..99d4e728e56ce 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestStatisticsUtil.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/TestStatisticsUtil.java @@ -287,7 +287,7 @@ public void testGenerateStatisticColumnSets() .setDataColumns(ImmutableList.of()) .setPredicateColumns(ImmutableMap.of()) .setRequestedColumns(Optional.empty()) - .setTable(new IcebergTableHandle("test", IcebergTableName.from("test"), false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableList.of(), ImmutableList.of())) + .setTable(new IcebergTableHandle("test", IcebergTableName.from("test"), false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableList.of(), ImmutableList.of(), Optional.empty())) .setDomainPredicate(TupleDomain.all()); // verify all selected columns are included List includedColumns = combineSelectedAndPredicateColumns( diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/container/IcebergMinIODataLake.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/container/IcebergMinIODataLake.java index ff626fb4e1ea9..e851c38e71003 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/container/IcebergMinIODataLake.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/container/IcebergMinIODataLake.java @@ -13,18 +13,22 @@ */ package com.facebook.presto.iceberg.container; -import 
com.amazonaws.auth.AWSStaticCredentialsProvider; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.client.builder.AwsClientBuilder; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; import com.facebook.presto.testing.containers.MinIOContainer; import com.facebook.presto.util.AutoCloseableCloser; import com.google.common.collect.ImmutableMap; import org.testcontainers.containers.Network; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.sync.RequestBody; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.s3.model.CreateBucketRequest; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; import java.io.Closeable; import java.io.IOException; +import java.net.URI; import java.util.concurrent.atomic.AtomicBoolean; import static java.util.Objects.requireNonNull; @@ -39,15 +43,19 @@ public class IcebergMinIODataLake private final String bucketName; private final String warehouseDir; private final MinIOContainer minIOContainer; - private final AtomicBoolean isStarted = new AtomicBoolean(false); private final AutoCloseableCloser closer = AutoCloseableCloser.create(); public IcebergMinIODataLake(String bucketName, String warehouseDir) + { + this(bucketName, warehouseDir, newNetwork()); + } + + public IcebergMinIODataLake(String bucketName, String warehouseDir, Network network) { this.bucketName = requireNonNull(bucketName, "bucketName is null"); this.warehouseDir = requireNonNull(warehouseDir, "warehouseDir is null"); - Network network = closer.register(newNetwork()); + closer.register(network); this.minIOContainer = closer.register( MinIOContainer.builder() .withNetwork(network) @@ -63,19 +71,39 @@ public void start() 
if (isStarted()) { return; } + try { this.minIOContainer.start(); - AmazonS3 s3Client = AmazonS3ClientBuilder - .standard() - .withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration( - "http://localhost:" + minIOContainer.getMinioApiEndpoint().getPort(), - "us-east-1")) - .withPathStyleAccessEnabled(true) - .withCredentials(new AWSStaticCredentialsProvider( - new BasicAWSCredentials(ACCESS_KEY, SECRET_KEY))) + + S3Client s3Client = S3Client.builder() + .endpointOverride(URI.create("http://localhost:" + minIOContainer.getMinioApiEndpoint().getPort())) + .region(Region.US_EAST_1) + .forcePathStyle(true) + .serviceConfiguration(S3Configuration.builder() + // Disable checksum validation and chunked encoding for MinIO compatibility + // MinIO checksum handling differs from AWS S3 + // Prevents chunked transfer encoding issues with MinIO + .checksumValidationEnabled(false) + .chunkedEncodingEnabled(false) + .build()) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create(ACCESS_KEY, SECRET_KEY))) .build(); - s3Client.createBucket(this.bucketName); - s3Client.putObject(this.bucketName, this.warehouseDir, ""); + + s3Client.createBucket(CreateBucketRequest.builder() + .bucket(this.bucketName) + .build()); + String objectKey = this.warehouseDir.endsWith("/") + ? 
this.warehouseDir + ".keep" + : this.warehouseDir + "/.keep"; + + s3Client.putObject( + PutObjectRequest.builder() + .bucket(this.bucketName) + .key(objectKey) + .build(), + RequestBody.fromString("placeholder")); + closer.register(s3Client); } finally { isStarted.set(true); diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hadoop/TestIcebergDistributedOnS3Hadoop.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hadoop/TestIcebergDistributedOnS3Hadoop.java index 5e1eda7e5afd3..30abb698f929f 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hadoop/TestIcebergDistributedOnS3Hadoop.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hadoop/TestIcebergDistributedOnS3Hadoop.java @@ -37,6 +37,7 @@ import java.net.URI; import static com.facebook.presto.iceberg.CatalogType.HADOOP; +import static com.facebook.presto.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; import static com.facebook.presto.iceberg.container.IcebergMinIODataLake.ACCESS_KEY; import static com.facebook.presto.iceberg.container.IcebergMinIODataLake.SECRET_KEY; import static com.facebook.presto.testing.TestingConnectorSession.SESSION; @@ -130,7 +131,7 @@ protected HdfsEnvironment getHdfsEnvironment() protected Table loadTable(String tableName) { Configuration configuration = getHdfsEnvironment().getConfiguration(new HdfsContext(SESSION), getCatalogDataDirectory()); - Catalog catalog = CatalogUtil.loadCatalog(HADOOP.getCatalogImpl(), "test-hive", getProperties(), configuration); + Catalog catalog = CatalogUtil.loadCatalog(HADOOP.getCatalogImpl(), ICEBERG_CATALOG, getProperties(), configuration); return catalog.loadTable(TableIdentifier.of("tpch", tableName)); } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hadoop/TestIcebergSmokeOnS3Hadoop.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hadoop/TestIcebergSmokeOnS3Hadoop.java index 29633808a3113..0c3319e5ee4cd 100644 --- 
a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hadoop/TestIcebergSmokeOnS3Hadoop.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hadoop/TestIcebergSmokeOnS3Hadoop.java @@ -330,6 +330,112 @@ public void testTableComments() dropTable(session, "test_table_comments"); } + @Test + public void testAddColumnWithMultiplePartitionTransforms() + { + Session session = getSession(); + String catalog = session.getCatalog().get(); + String schema = format("\"%s\"", session.getSchema().get()); + + assertQuerySucceeds("create table add_multiple_partition_column(a int)"); + assertUpdate("insert into add_multiple_partition_column values 1", 1); + + validateShowCreateTable(catalog, schema, "add_multiple_partition_column", + ImmutableList.of(columnDefinition("a", "integer")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_multiple_partition_column") + "'", + "write.data.path", "'" + getPathBasedOnDataDirectory(session.getSchema().get() + "/add_multiple_partition_column") + "'"))); + + // Add a varchar column with partition transforms `ARRAY['bucket(4)', 'truncate(2)', 'identity']` + assertQuerySucceeds("alter table add_multiple_partition_column add column b varchar with(partitioning = ARRAY['bucket(4)', 'truncate(2)', 'identity'])"); + assertUpdate("insert into add_multiple_partition_column values(2, '1002')", 1); + + validateShowCreateTable(catalog, schema, "add_multiple_partition_column", + ImmutableList.of( + columnDefinition("a", "integer"), + columnDefinition("b", "varchar")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_multiple_partition_column") + "'", + "partitioning", "ARRAY['bucket(b, 4)','truncate(b, 2)','b']", + "write.data.path", "'" + getPathBasedOnDataDirectory(session.getSchema().get() + "/add_multiple_partition_column") + "'"))); + + // Add a date column with partition transforms 
`ARRAY['year', 'bucket(8)', 'identity']` + assertQuerySucceeds("alter table add_multiple_partition_column add column c date with(partitioning = ARRAY['year', 'bucket(8)', 'identity'])"); + assertUpdate("insert into add_multiple_partition_column values(3, '1003', date '1984-12-08')", 1); + + validateShowCreateTable(catalog, schema, "add_multiple_partition_column", + ImmutableList.of( + columnDefinition("a", "integer"), + columnDefinition("b", "varchar"), + columnDefinition("c", "date")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_multiple_partition_column") + "'", + "partitioning", "ARRAY['bucket(b, 4)','truncate(b, 2)','b','year(c)','bucket(c, 8)','c']", + "write.data.path", "'" + getPathBasedOnDataDirectory(session.getSchema().get() + "/add_multiple_partition_column") + "'"))); + + assertQuery("select * from add_multiple_partition_column", + "values(1, null, null), (2, '1002', null), (3, '1003', date '1984-12-08')"); + dropTable(getSession(), "add_multiple_partition_column"); + } + + @Test + public void testAddColumnWithRedundantOrDuplicatedPartitionTransforms() + { + Session session = getSession(); + String catalog = session.getCatalog().get(); + String schema = format("\"%s\"", session.getSchema().get()); + + assertQuerySucceeds("create table add_redundant_partition_column(a int)"); + + // Specify duplicated transforms would fail + assertQueryFails("alter table add_redundant_partition_column add column b varchar with(partitioning = ARRAY['bucket(4)', 'truncate(2)', 'bucket(4)'])", + "Cannot add duplicate partition field: .*"); + assertQueryFails("alter table add_redundant_partition_column add column b varchar with(partitioning = ARRAY['identity', 'identity'])", + "Cannot add duplicate partition field: .*"); + + // Specify redundant transforms would fail + assertQueryFails("alter table add_redundant_partition_column add column c date with(partitioning = ARRAY['year', 'month'])", + 
"Cannot add redundant partition field: .*"); + assertQueryFails("alter table add_redundant_partition_column add column c timestamp with(partitioning = ARRAY['day', 'hour'])", + "Cannot add redundant partition field: .*"); + + validateShowCreateTable(catalog, schema, "add_redundant_partition_column", + ImmutableList.of(columnDefinition("a", "integer")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_redundant_partition_column") + "'", + "write.data.path", "'" + getPathBasedOnDataDirectory(session.getSchema().get() + "/add_redundant_partition_column") + "'"))); + + dropTable(getSession(), "add_redundant_partition_column"); + } + + @Test + public void testAddColumnWithUnsupportedPropertyValueTypes() + { + Session session = getSession(); + String catalog = session.getCatalog().get(); + String schema = format("\"%s\"", session.getSchema().get()); + + assertQuerySucceeds("create table add_invalid_partition_column(a int)"); + + assertQueryFails("alter table add_invalid_partition_column add column b varchar with(partitioning = 123)", + "Invalid value for column property 'partitioning': Cannot convert '123' to array\\(varchar\\) or any of \\[varchar]"); + assertQueryFails("alter table add_invalid_partition_column add column b varchar with(partitioning = ARRAY[123, 234])", + "Invalid value for column property 'partitioning': Cannot convert 'ARRAY\\[123,234]' to array\\(varchar\\) or any of \\[varchar]"); + + validateShowCreateTable(catalog, schema, "add_invalid_partition_column", + ImmutableList.of(columnDefinition("a", "integer")), + null, + getCustomizedTableProperties(ImmutableMap.of( + "location", "'" + getLocation(session.getSchema().get(), "add_invalid_partition_column") + "'", + "write.data.path", "'" + getPathBasedOnDataDirectory(session.getSchema().get() + "/add_invalid_partition_column") + "'"))); + + dropTable(getSession(), "add_invalid_partition_column"); + } + @Override protected String 
getLocation(String schema, String table) { diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestHiveTableOperationsConfig.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestHiveTableOperationsConfig.java index db1b2d8d885f2..42265f3fa5470 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestHiveTableOperationsConfig.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestHiveTableOperationsConfig.java @@ -22,7 +22,7 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.Duration.succinctDuration; +import static com.facebook.airlift.units.Duration.succinctDuration; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; @@ -37,7 +37,8 @@ public void testDefaults() .setTableRefreshBackoffMaxSleepTime(succinctDuration(5, SECONDS)) .setTableRefreshMaxRetryTime(succinctDuration(1, MINUTES)) .setTableRefreshBackoffScaleFactor(4.0) - .setTableRefreshRetries(20)); + .setTableRefreshRetries(20) + .setLockingEnabled(true)); } @Test @@ -49,6 +50,7 @@ public void testExplicitPropertyMappings() .put("iceberg.hive.table-refresh.max-retry-time", "30s") .put("iceberg.hive.table-refresh.retries", "42") .put("iceberg.hive.table-refresh.backoff-scale-factor", "2.0") + .put("iceberg.engine.hive.lock-enabled", "false") .build(); IcebergHiveTableOperationsConfig expected = new IcebergHiveTableOperationsConfig() @@ -56,7 +58,8 @@ public void testExplicitPropertyMappings() .setTableRefreshBackoffMaxSleepTime(succinctDuration(20, SECONDS)) .setTableRefreshMaxRetryTime(succinctDuration(30, SECONDS)) 
.setTableRefreshBackoffScaleFactor(2.0) - .setTableRefreshRetries(42); + .setTableRefreshRetries(42) + .setLockingEnabled(false); assertFullMapping(properties, expected); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergDistributedHive.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergDistributedHive.java index 720621f1892e0..49a01ed705ff0 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergDistributedHive.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergDistributedHive.java @@ -16,7 +16,12 @@ import com.facebook.presto.Session; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.transaction.TransactionId; +import com.facebook.presto.hive.HdfsContext; +import com.facebook.presto.hive.HiveColumnConverterProvider; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; +import com.facebook.presto.hive.metastore.MetastoreContext; +import com.facebook.presto.iceberg.HiveTableOperations; +import com.facebook.presto.iceberg.IcebergCatalogName; import com.facebook.presto.iceberg.IcebergDistributedTestBase; import com.facebook.presto.iceberg.IcebergHiveMetadata; import com.facebook.presto.iceberg.IcebergHiveTableOperationsConfig; @@ -25,7 +30,9 @@ import com.facebook.presto.metadata.CatalogManager; import com.facebook.presto.metadata.CatalogMetadata; import com.facebook.presto.metadata.MetadataUtil; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.connector.classloader.ClassLoaderSafeConnectorMetadata; @@ -33,20 +40,34 @@ import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheStats; import com.google.common.collect.ImmutableMap; +import 
org.apache.iceberg.BaseTable; +import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Table; +import org.apache.iceberg.TableMetadata; +import org.apache.iceberg.Transaction; import org.testng.annotations.Test; import java.lang.reflect.Field; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; +import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore.memoizeMetastore; +import static com.facebook.presto.hive.metastore.MetastoreUtil.getMetastoreHeaders; +import static com.facebook.presto.hive.metastore.MetastoreUtil.isUserDefinedTypeEncodingEnabled; import static com.facebook.presto.iceberg.CatalogType.HIVE; +import static com.facebook.presto.iceberg.IcebergAbstractMetadata.toIcebergSchema; import static com.facebook.presto.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; import static com.facebook.presto.spi.statistics.ColumnStatisticType.NUMBER_OF_DISTINCT_VALUES; import static com.facebook.presto.spi.statistics.ColumnStatisticType.TOTAL_SIZE_IN_BYTES; +import static com.google.common.io.Files.createTempDir; import static java.lang.String.format; +import static org.apache.iceberg.TableMetadata.newTableMetadata; +import static org.apache.iceberg.TableProperties.HIVE_LOCK_ENABLED; +import static org.apache.iceberg.Transactions.createTableTransaction; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -55,6 +76,15 @@ public class TestIcebergDistributedHive extends IcebergDistributedTestBase { + public TestIcebergDistributedHive(Map extraConnectorProperties) + { + super(HIVE, ImmutableMap.builder() + .put("iceberg.hive-statistics-merge-strategy", Joiner.on(",").join( + NUMBER_OF_DISTINCT_VALUES.name(), + TOTAL_SIZE_IN_BYTES.name())) + .putAll(extraConnectorProperties) + .build()); + } public 
TestIcebergDistributedHive() { super(HIVE, ImmutableMap.of("iceberg.hive-statistics-merge-strategy", Joiner.on(",").join(NUMBER_OF_DISTINCT_VALUES.name(), TOTAL_SIZE_IN_BYTES.name()))); @@ -86,6 +116,21 @@ public void testStatisticsFileCache() // so this test won't complete successfully. } + @Test + public void testCreateAlterTableWithHiveLocksDisabled() + { + assertQuerySucceeds("CREATE TABLE test_table(i int) WITH (\"engine.hive.lock-enabled\" = false)"); + assertEquals(getQueryRunner().execute("SELECT value FROM \"test_table$properties\" WHERE key = 'engine.hive.lock-enabled'").getOnlyValue(), + "false"); + assertQuerySucceeds("CREATE TABLE sample_table(i int)"); + assertEquals(getQueryRunner().execute("SELECT value FROM \"sample_table$properties\" WHERE key = 'engine.hive.lock-enabled'").getRowCount(), + 0); + assertUpdate("ALTER TABLE sample_table SET PROPERTIES(\"engine.hive.lock-enabled\" = false)"); + + assertEquals(getQueryRunner().execute("SELECT value FROM \"sample_table$properties\" WHERE key = 'engine.hive.lock-enabled'").getOnlyValue(), + "false"); + } + @Test public void testManifestFileCaching() throws Exception @@ -186,9 +231,26 @@ public void testManifestFileCachingDisabled() assertQuerySucceeds(session, "DROP SCHEMA default"); } + @Test + public void testCommitTableMetadataForNoLock() + { + createTable("iceberg-test-table", createTempDir().toURI().toString(), ImmutableMap.of("engine.hive.lock-enabled", "false"), 2); + BaseTable table = (BaseTable) loadTable("iceberg-test-table"); + assertEquals(table.properties().get(HIVE_LOCK_ENABLED), "false"); + HiveTableOperations operations = (HiveTableOperations) table.operations(); + TableMetadata currentMetadata = operations.current(); + + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.putAll(currentMetadata.properties()); + builder.put("test_property_new", "test_value_new"); + operations.commit(currentMetadata, 
TableMetadata.buildFrom(currentMetadata).setProperties(builder.build()).build()); + assertEquals(operations.current().properties(), builder.build()); + } + @Override protected Table loadTable(String tableName) { + tableName = normalizeIdentifier(tableName, ICEBERG_CATALOG); CatalogManager catalogManager = getDistributedQueryRunner().getCoordinator().getCatalogManager(); ConnectorId connectorId = catalogManager.getCatalog(ICEBERG_CATALOG).get().getConnectorId(); @@ -197,6 +259,7 @@ protected Table loadTable(String tableName) new IcebergHiveTableOperationsConfig(), new ManifestFileCache(CacheBuilder.newBuilder().build(), false, 0, 1024 * 1024), getQueryRunner().getDefaultSession().toConnectorSession(connectorId), + new IcebergCatalogName(ICEBERG_CATALOG), SchemaTableName.valueOf("tpch." + tableName)); } @@ -207,4 +270,35 @@ protected ExtendedHiveMetastore getFileHiveMetastore() "test"); return memoizeMetastore(fileHiveMetastore, false, 1000, 0); } + + protected Table createTable(String tableName, String targetPath, Map tableProperties, int columns) + { + CatalogManager catalogManager = getDistributedQueryRunner().getCoordinator().getCatalogManager(); + ConnectorId connectorId = catalogManager.getCatalog(getDistributedQueryRunner().getDefaultSession().getCatalog().get()).get().getConnectorId(); + ConnectorSession session = getQueryRunner().getDefaultSession().toConnectorSession(connectorId); + MetastoreContext context = new MetastoreContext(session.getIdentity(), session.getQueryId(), session.getClientInfo(), session.getClientTags(), session.getSource(), getMetastoreHeaders(session), isUserDefinedTypeEncodingEnabled(session), HiveColumnConverterProvider.DEFAULT_COLUMN_CONVERTER_PROVIDER, session.getWarningCollector(), session.getRuntimeStats()); + HdfsContext hdfsContext = new HdfsContext(session, "tpch", tableName); + HiveTableOperations operations = new HiveTableOperations( + getFileHiveMetastore(), + context, + getHdfsEnvironment(), + hdfsContext, + new 
IcebergHiveTableOperationsConfig().setLockingEnabled(false), + new ManifestFileCache(CacheBuilder.newBuilder().build(), false, 0, 1024), + "tpch", + tableName, + session.getUser(), + targetPath); + List columnMetadataList = new ArrayList<>(); + for (int i = 0; i < columns; i++) { + columnMetadataList.add(ColumnMetadata.builder().setName("column" + i).setType(INTEGER).build()); + } + TableMetadata metadata = newTableMetadata( + toIcebergSchema(columnMetadataList), + PartitionSpec.unpartitioned(), targetPath, + tableProperties); + Transaction transaction = createTableTransaction(tableName, operations, metadata); + transaction.commitTransaction(); + return transaction.table(); + } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergHiveStatistics.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergHiveStatistics.java index 13465aed27149..ec29e4aae83e2 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergHiveStatistics.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergHiveStatistics.java @@ -31,6 +31,7 @@ import com.facebook.presto.hive.authentication.NoHdfsAuthentication; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.iceberg.CatalogType; +import com.facebook.presto.iceberg.IcebergCatalogName; import com.facebook.presto.iceberg.IcebergColumnHandle; import com.facebook.presto.iceberg.IcebergHiveTableOperationsConfig; import com.facebook.presto.iceberg.IcebergMetadataColumn; @@ -53,6 +54,7 @@ import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; import com.google.common.cache.CacheBuilder; @@ -465,6 +467,32 @@ public void 
testStatisticsCachePartialEviction() } } + @Test + public void testShowStatsWithTimestampWithTimeZone() + { + assertQuerySucceeds("CREATE TABLE test_timestamp_tz(id BIGINT, ts TIMESTAMP WITH TIME ZONE)"); + assertUpdate("INSERT INTO test_timestamp_tz VALUES " + + "(1, TIMESTAMP '2024-01-01 12:00:00 UTC'), " + + "(2, TIMESTAMP '2024-01-02 18:30:00 UTC'), " + + "(3, TIMESTAMP '2024-01-03 00:00:00 America/New_York')", 3); + + MaterializedResult stats = getQueryRunner().execute("SHOW STATS FOR test_timestamp_tz"); + + assertStatValue(StatsSchema.LOW_VALUE, stats, ImmutableSet.of("ts"), null, true); + assertStatValue(StatsSchema.HIGH_VALUE, stats, ImmutableSet.of("ts"), null, true); + + Optional tsRow = stats.getMaterializedRows().stream() + .filter(row -> row.getField(StatsSchema.COLUMN_NAME.ordinal()) != null) + .filter(row -> row.getField(StatsSchema.COLUMN_NAME.ordinal()).equals("ts")) + .findFirst(); + assertTrue(tsRow.isPresent(), "Statistics for column 'ts' not found"); + MaterializedRow row = tsRow.get(); + assertEquals((String) row.getField(StatsSchema.LOW_VALUE.ordinal()), "2024-01-01 12:00:00.000 UTC"); + assertEquals((String) row.getField(StatsSchema.HIGH_VALUE.ordinal()), "2024-01-03 05:00:00.000 UTC"); + + assertQuerySucceeds("DROP TABLE test_timestamp_tz"); + } + private TableStatistics getScanStatsEstimate(Session session, @Language("SQL") String sql) { Plan plan = plan(sql, session); @@ -592,6 +620,7 @@ private void deleteTableStatistics(String tableName) private Table loadTable(String tableName) { + tableName = normalizeIdentifier(tableName, ICEBERG_CATALOG); CatalogManager catalogManager = getDistributedQueryRunner().getCoordinator().getCatalogManager(); ConnectorId connectorId = catalogManager.getCatalog(ICEBERG_CATALOG).get().getConnectorId(); @@ -600,6 +629,7 @@ private Table loadTable(String tableName) new IcebergHiveTableOperationsConfig(), new ManifestFileCache(CacheBuilder.newBuilder().build(), false, 0, 1024), 
getQueryRunner().getDefaultSession().toConnectorSession(connectorId), + new IcebergCatalogName(ICEBERG_CATALOG), SchemaTableName.valueOf("tpch." + tableName)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergSmokeHive.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergSmokeHive.java index c261f4be3fba5..3e42fb1b1a814 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergSmokeHive.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestIcebergSmokeHive.java @@ -13,7 +13,9 @@ */ package com.facebook.presto.iceberg.hive; +import com.facebook.presto.FullConnectorSession; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; +import com.facebook.presto.iceberg.IcebergCatalogName; import com.facebook.presto.iceberg.IcebergConfig; import com.facebook.presto.iceberg.IcebergDistributedSmokeTestBase; import com.facebook.presto.iceberg.IcebergHiveTableOperationsConfig; @@ -64,11 +66,13 @@ protected ExtendedHiveMetastore getFileHiveMetastore() @Override protected Table getIcebergTable(ConnectorSession session, String schema, String tableName) { + String defaultCatalog = ((FullConnectorSession) session).getSession().getCatalog().get(); return IcebergUtil.getHiveIcebergTable(getFileHiveMetastore(), getHdfsEnvironment(), new IcebergHiveTableOperationsConfig(), new ManifestFileCache(CacheBuilder.newBuilder().build(), false, 0, 1024), session, + new IcebergCatalogName(defaultCatalog), SchemaTableName.valueOf(schema + "." 
+ tableName)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java index d89d8dcb2601c..4fe2febcddba9 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/hive/TestRenameTableOnFragileFileSystem.java @@ -36,6 +36,7 @@ import com.facebook.presto.hive.metastore.file.FileHiveMetastoreConfig; import com.facebook.presto.hive.metastore.file.TableMetadata; import com.facebook.presto.iceberg.CommitTaskData; +import com.facebook.presto.iceberg.IcebergCatalogName; import com.facebook.presto.iceberg.IcebergConfig; import com.facebook.presto.iceberg.IcebergHiveMetadata; import com.facebook.presto.iceberg.IcebergHiveMetadataFactory; @@ -47,6 +48,7 @@ import com.facebook.presto.iceberg.IcebergTableType; import com.facebook.presto.iceberg.ManifestFileCache; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; +import com.facebook.presto.metadata.BuiltInProcedureRegistry; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.ConnectorSession; @@ -75,6 +77,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.reflect.TypeToken; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FSDataOutputStream; @@ -207,7 +210,8 @@ public String formatRowExpression(ConnectorSession session, RowExpression expres Optional.empty(), Optional.empty(), ImmutableList.of(), - ImmutableList.of()); + ImmutableList.of(), + Optional.empty()); @Test public void testRenameTableSucceed() @@ -405,18 +409,23 @@ private ConnectorMetadata 
getIcebergHiveMetadata(ExtendedHiveMetastore metastore { HdfsEnvironment hdfsEnvironment = new TestingHdfsEnvironment(); IcebergHiveMetadataFactory icebergHiveMetadataFactory = new IcebergHiveMetadataFactory( + new IcebergCatalogName("unimportant"), metastore, hdfsEnvironment, FUNCTION_AND_TYPE_MANAGER, + new BuiltInProcedureRegistry(METADATA.getFunctionAndTypeManager()), FUNCTION_RESOLUTION, ROW_EXPRESSION_SERVICE, jsonCodec(CommitTaskData.class), + jsonCodec(new TypeToken<>() {}), + jsonCodec(new TypeToken<>() {}), new NodeVersion("test_node_v1"), FILTER_STATS_CALCULATOR_SERVICE, new IcebergHiveTableOperationsConfig(), new StatisticsFileCache(CacheBuilder.newBuilder().build()), new ManifestFileCache(CacheBuilder.newBuilder().build(), false, 0, 1024), - new IcebergTableProperties(new IcebergConfig())); + new IcebergTableProperties(new IcebergConfig()), + () -> false); return icebergHiveMetadataFactory.create(); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergNessieRestCatalogDistributedQueries.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergNessieRestCatalogDistributedQueries.java new file mode 100644 index 0000000000000..95c6048730f3e --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergNessieRestCatalogDistributedQueries.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg.nessie; + +import com.facebook.presto.iceberg.IcebergQueryRunner; +import com.facebook.presto.iceberg.container.IcebergMinIODataLake; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.testing.containers.MinIOContainer; +import com.facebook.presto.testing.containers.NessieContainer; +import com.google.common.collect.ImmutableMap; +import com.google.common.net.HostAndPort; +import org.testcontainers.containers.Network; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import software.amazon.awssdk.regions.Region; + +import java.net.URI; +import java.util.Map; + +import static com.facebook.presto.iceberg.CatalogType.REST; +import static com.facebook.presto.iceberg.container.IcebergMinIODataLake.ACCESS_KEY; +import static com.facebook.presto.iceberg.container.IcebergMinIODataLake.SECRET_KEY; +import static com.facebook.presto.tests.sql.TestTable.randomTableSuffix; +import static java.lang.String.format; + +public class TestIcebergNessieRestCatalogDistributedQueries + extends TestIcebergNessieCatalogDistributedQueries +{ + static final String WAREHOUSE_DATA_DIR = "warehouse_data/"; + private NessieContainer nessieContainer; + private IcebergMinIODataLake dockerizedS3DataLake; + private String bucketName; + HostAndPort minioApiEndpoint; + + @BeforeClass + @Override + public void init() + throws Exception + { + Network network = Network.newNetwork(); + + bucketName = "fornessie-" + randomTableSuffix(); + dockerizedS3DataLake = new IcebergMinIODataLake(bucketName, WAREHOUSE_DATA_DIR, network); + dockerizedS3DataLake.start(); + minioApiEndpoint = dockerizedS3DataLake.getMinio().getMinioApiEndpoint(); + + Map envVars = ImmutableMap.builder() + .putAll(NessieContainer.DEFAULT_ENV_VARS) + .put("NESSIE_CATALOG_SERVICE_S3_DEFAULT-OPTIONS_ACCESS-KEY", "urn:nessie-secret:quarkus:nessie.catalog.secrets.access-key") + .put("NESSIE_CATALOG_SECRETS_ACCESS-KEY_NAME", ACCESS_KEY) + 
.put("NESSIE_CATALOG_SECRETS_ACCESS-KEY_SECRET", SECRET_KEY) + .put("NESSIE_CATALOG_SERVICE_S3_DEFAULT-OPTIONS_PATH_STYLE_ACCESS", "true") + .put("NESSIE_CATALOG_SERVICE_S3_DEFAULT-OPTIONS_ENDPOINT", format("http://%s:%s", MinIOContainer.DEFAULT_HOST_NAME, MinIOContainer.MINIO_API_PORT)) + .put("NESSIE_CATALOG_SERVICE_S3_DEFAULT-OPTIONS_REGION", Region.US_EAST_1.toString()) + .put("NESSIE_CATALOG_WAREHOUSES_WAREHOUSE_LOCATION", getCatalogDataDirectory().toString()) + .put("NESSIE_CATALOG_DEFAULT-WAREHOUSE", "warehouse") + .put("NESSIE_CATALOG_SERVICE_S3.DEFAULT-OPTIONS_EXTERNAL-ENDPOINT", format("http://%s:%s", minioApiEndpoint.getHost(), minioApiEndpoint.getPort())) + .buildOrThrow(); + nessieContainer = NessieContainer.builder().withEnvVars(envVars).withNetwork(network).build(); + nessieContainer.start(); + + super.init(); + } + + @AfterClass(alwaysRun = true) + @Override + public void tearDown() + { + super.tearDown(); + if (nessieContainer != null) { + nessieContainer.stop(); + } + if (dockerizedS3DataLake != null) { + dockerizedS3DataLake.stop(); + } + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Map icebergProperties = ImmutableMap.builder() + .put("iceberg.rest.uri", nessieContainer.getIcebergRestUri()) + .put("iceberg.catalog.warehouse", getCatalogDataDirectory().toString()) + .put("hive.s3.aws-access-key", ACCESS_KEY) + .put("hive.s3.aws-secret-key", SECRET_KEY) + .put("hive.s3.endpoint", format("http://%s:%s", minioApiEndpoint.getHost(), minioApiEndpoint.getPort())) + .put("hive.s3.path-style-access", "true") + .build(); + + return IcebergQueryRunner.builder() + .setCatalogType(REST) + .setExtraConnectorProperties(icebergProperties) + .build().getQueryRunner(); + } + + protected org.apache.hadoop.fs.Path getCatalogDataDirectory() + { + return new org.apache.hadoop.fs.Path(URI.create(format("s3a://%s/%s", bucketName, WAREHOUSE_DATA_DIR))); + } +} diff --git 
a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSmokeNessie.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSmokeNessie.java index 064679730c9e2..9c4faeb7546d4 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSmokeNessie.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSmokeNessie.java @@ -82,7 +82,7 @@ protected String getLocation(String schema, String table) Path dataDirectory = ((DistributedQueryRunner) queryRunner).getCoordinator().getDataDirectory(); Path icebergDataDirectory = getIcebergDataDirectoryPath(dataDirectory, NESSIE.name(), new IcebergConfig().getFileFormat(), false); Optional tempTableLocation = Arrays.stream(requireNonNull(icebergDataDirectory.resolve(schema).toFile().listFiles())) - .filter(file -> file.toURI().toString().contains(table)).findFirst(); + .filter(file -> endsWithTableUUID(table, file.toURI().toString())).findFirst(); String dataLocation = icebergDataDirectory.toFile().toURI().toString(); String relativeTableLocation = tempTableLocation.get().toURI().toString().replace(dataLocation, ""); @@ -119,4 +119,9 @@ protected Table getIcebergTable(ConnectorSession session, String schema, String session, SchemaTableName.valueOf(schema + "." 
+ tableName)); } + + private static boolean endsWithTableUUID(String tableName, String tablePath) + { + return tablePath.matches(format(".*%s_[-a-fA-F0-9]{36}/$", tableName)); + } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSystemTablesNessie.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSystemTablesNessie.java index 4c2e169cffc98..f45d3e096c10f 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSystemTablesNessie.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSystemTablesNessie.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.iceberg.nessie; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.iceberg.IcebergConfig; import com.facebook.presto.iceberg.IcebergPlugin; @@ -24,7 +25,6 @@ import com.facebook.presto.testing.containers.NessieContainer; import com.facebook.presto.tests.DistributedQueryRunner; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -93,7 +93,10 @@ protected QueryRunner createQueryRunner() protected void checkTableProperties(String tableName, String deleteMode) { assertQuery(String.format("SHOW COLUMNS FROM test_schema.\"%s$properties\"", tableName), - "VALUES ('key', 'varchar', '', '')," + "('value', 'varchar', '', '')"); + "VALUES " + + "('key', 'varchar', '', '', null, null, 2147483647)," + + "('value', 'varchar', '', '', null, null, 2147483647)," + + "('is_supported_by_presto', 'varchar', '', '', null, null, 2147483647)"); assertQuery(String.format("SELECT COUNT(*) FROM test_schema.\"%s$properties\"", tableName), "VALUES 11"); List materializedRows = computeActual(getSession(), String.format("SELECT * FROM test_schema.\"%s$properties\"", tableName)).getMaterializedRows(); @@ -101,30 +104,34 @@ protected void 
checkTableProperties(String tableName, String deleteMode) assertThat(materializedRows).hasSize(11); assertThat(materializedRows) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", deleteMode))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", deleteMode, "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.format.default", "PARQUET"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.format.default", "PARQUET", "true"))) .anySatisfy(row -> assertThat(row.getField(0)).isEqualTo("nessie.commit.id")) - .anySatisfy(row -> assertThat(row).isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "gc.enabled", "false"))) + .anySatisfy(row -> assertThat(row).isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "gc.enabled", "false", "false"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.parquet.compression-codec", "GZIP"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.parquet.compression-codec", "ZSTD", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "false"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "false", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "commit.retry.num-retries", "4"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "commit.retry.num-retries", "4", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "100"))) + .isEqualTo(new 
MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "100", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "100"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "100", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, IcebergTableProperties.TARGET_SPLIT_SIZE, Long.toString(DataSize.valueOf("128MB").toBytes())))); + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, IcebergTableProperties.TARGET_SPLIT_SIZE, Long.toString(DataSize.valueOf("128MB").toBytes()), "true"))); } @Override protected void checkORCFormatTableProperties(String tableName, String deleteMode) { assertQuery(String.format("SHOW COLUMNS FROM test_schema.\"%s$properties\"", tableName), - "VALUES ('key', 'varchar', '', '')," + "('value', 'varchar', '', '')"); + "VALUES " + + "('key', 'varchar', '', '', null, null, 2147483647)," + + "('value', 'varchar', '', '', null, null, 2147483647)," + + "('is_supported_by_presto', 'varchar', '', '', null, null, 2147483647)"); + assertQuery(String.format("SELECT COUNT(*) FROM test_schema.\"%s$properties\"", tableName), "VALUES 12"); List materializedRows = computeActual(getSession(), String.format("SELECT * FROM test_schema.\"%s$properties\"", tableName)).getMaterializedRows(); @@ -132,22 +139,22 @@ protected void checkORCFormatTableProperties(String tableName, String deleteMode assertThat(materializedRows).hasSize(12); assertThat(materializedRows) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", deleteMode))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.delete.mode", deleteMode, "true"))) .anySatisfy(row -> assertThat(row) - 
.isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.format.default", "ORC"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.format.default", "ORC", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.orc.compression-codec", "ZLIB"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.orc.compression-codec", "ZSTD", "true"))) .anySatisfy(row -> assertThat(row.getField(0)).isEqualTo("nessie.commit.id")) - .anySatisfy(row -> assertThat(row).isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "gc.enabled", "false"))) + .anySatisfy(row -> assertThat(row).isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "gc.enabled", "false", "false"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.parquet.compression-codec", "zstd"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.parquet.compression-codec", "zstd", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "false"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.delete-after-commit.enabled", "false", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "commit.retry.num-retries", "4"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "commit.retry.num-retries", "4", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "100"))) + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.previous-versions-max", "100", "true"))) .anySatisfy(row -> assertThat(row) - .isEqualTo(new 
MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "100"))); + .isEqualTo(new MaterializedRow(MaterializedResult.DEFAULT_PRECISION, "write.metadata.metrics.max-inferred-column-defaults", "100", "true"))); } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSystemTablesNessieWithBearerAuth.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSystemTablesNessieWithBearerAuth.java new file mode 100644 index 0000000000000..35c831c35a047 --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestIcebergSystemTablesNessieWithBearerAuth.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.iceberg.nessie; + +import com.facebook.presto.Session; +import com.facebook.presto.iceberg.IcebergConfig; +import com.facebook.presto.iceberg.IcebergPlugin; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.testing.containers.KeycloakContainer; +import com.facebook.presto.testing.containers.NessieContainer; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.testcontainers.containers.Network; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.nio.file.Path; +import java.util.Map; + +import static com.facebook.presto.iceberg.CatalogType.NESSIE; +import static com.facebook.presto.iceberg.IcebergQueryRunner.ICEBERG_CATALOG; +import static com.facebook.presto.iceberg.IcebergQueryRunner.getIcebergDataDirectoryPath; +import static com.facebook.presto.iceberg.nessie.NessieTestUtil.nessieConnectorProperties; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class TestIcebergSystemTablesNessieWithBearerAuth + extends TestIcebergSystemTablesNessie +{ + private NessieContainer nessieContainer; + private KeycloakContainer keycloakContainer; + + @BeforeClass + @Override + public void init() + throws Exception + { + Map envVars = ImmutableMap.builder() + .putAll(NessieContainer.DEFAULT_ENV_VARS) + .put("QUARKUS_OIDC_AUTH_SERVER_URL", KeycloakContainer.SERVER_URL + "/realms/" + KeycloakContainer.MASTER_REALM) + .put("QUARKUS_OIDC_CLIENT_ID", "nessie") + .put("NESSIE_SERVER_AUTHENTICATION_ENABLED", "true") + .buildOrThrow(); + + Network network = Network.newNetwork(); + + nessieContainer = NessieContainer.builder().withEnvVars(envVars).withNetwork(network).build(); + nessieContainer.start(); + keycloakContainer = KeycloakContainer.builder().withNetwork(network).build(); + keycloakContainer.start(); + + super.init(); + } + + 
@AfterClass(alwaysRun = true) + @Override + public void tearDown() + { + super.tearDown(); + if (nessieContainer != null) { + nessieContainer.stop(); + } + if (keycloakContainer != null) { + keycloakContainer.stop(); + } + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder() + .setCatalog(ICEBERG_CATALOG) + .build(); + + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).build(); + + Path dataDirectory = queryRunner.getCoordinator().getDataDirectory(); + Path catalogDirectory = getIcebergDataDirectoryPath(dataDirectory, "NESSIE", new IcebergConfig().getFileFormat(), false); + + queryRunner.installPlugin(new IcebergPlugin()); + Map icebergProperties = ImmutableMap.builder() + .put("iceberg.catalog.type", String.valueOf(NESSIE)) + .putAll(nessieConnectorProperties(nessieContainer.getRestApiUri())) + .put("iceberg.catalog.warehouse", catalogDirectory.getParent().toFile().toURI().toString()) + .put("iceberg.nessie.auth.type", "BEARER") + .put("iceberg.nessie.auth.bearer.token", keycloakContainer.getAccessToken()) + .build(); + + queryRunner.createCatalog(ICEBERG_CATALOG, "iceberg", icebergProperties); + + icebergProperties = ImmutableMap.builder() + .put("iceberg.catalog.type", String.valueOf(NESSIE)) + .putAll(nessieConnectorProperties(nessieContainer.getRestApiUri())) + .put("iceberg.catalog.warehouse", catalogDirectory.getParent().toFile().toURI().toString()) + .put("iceberg.nessie.auth.type", "BEARER") + .put("iceberg.nessie.auth.bearer.token", "invalid_token") + .build(); + + queryRunner.createCatalog("iceberg_invalid_credentials", "iceberg", icebergProperties); + + return queryRunner; + } + + @Test + public void testInvalidBearerToken() + { + assertQueryFails("CREATE SCHEMA iceberg_invalid_credentials.test_schema", "Unauthorized \\(HTTP/401\\).*", true); + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestNessieMultiBranching.java 
b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestNessieMultiBranching.java index 1cde1e1549f97..f9a66646c9483 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestNessieMultiBranching.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/nessie/TestNessieMultiBranching.java @@ -19,8 +19,8 @@ import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.testing.containers.NessieContainer; import com.facebook.presto.tests.AbstractTestQueryFramework; +import org.projectnessie.client.NessieClientBuilder; import org.projectnessie.client.api.NessieApiV1; -import org.projectnessie.client.http.HttpClientBuilder; import org.projectnessie.error.NessieConflictException; import org.projectnessie.error.NessieNotFoundException; import org.projectnessie.model.Branch; @@ -53,7 +53,7 @@ public void init() { nessieContainer = NessieContainer.builder().build(); nessieContainer.start(); - nessieApiV1 = HttpClientBuilder.builder().withUri(nessieContainer.getRestApiUri()).build(NessieApiV1.class); + nessieApiV1 = NessieClientBuilder.createClientBuilder(null, null).withUri(nessieContainer.getRestApiUri()).build(NessieApiV1.class); super.init(); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestExpireSnapshotProcedure.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestExpireSnapshotProcedure.java index cb5389106148a..5ecb275cfd10b 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestExpireSnapshotProcedure.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestExpireSnapshotProcedure.java @@ -248,6 +248,7 @@ private String getTimestampString(long timeMillsUtc, String zoneId) private Table loadTable(String tableName) { + tableName = normalizeIdentifier(tableName, ICEBERG_CATALOG); Catalog catalog = CatalogUtil.loadCatalog(HadoopCatalog.class.getName(), ICEBERG_CATALOG, getProperties(), new 
Configuration()); return catalog.loadTable(TableIdentifier.of(TEST_SCHEMA, tableName)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestFastForwardBranchProcedure.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestFastForwardBranchProcedure.java index 708b053a4a42e..fa685cb30d163 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestFastForwardBranchProcedure.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestFastForwardBranchProcedure.java @@ -262,6 +262,7 @@ public void testFastForwardNonExistingBranch() private Table loadTable(String tableName) { + tableName = normalizeIdentifier(tableName, ICEBERG_CATALOG); Catalog catalog = CatalogUtil.loadCatalog(HadoopCatalog.class.getName(), ICEBERG_CATALOG, getProperties(), new Configuration()); return catalog.loadTable(TableIdentifier.of(TEST_SCHEMA, tableName)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureBase.java index 8807ab0307885..31bd60609ee2f 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureBase.java @@ -56,6 +56,7 @@ import static com.google.common.io.Files.createTempDir; import static java.lang.String.format; import static java.util.Objects.requireNonNull; +import static org.apache.iceberg.TableProperties.DELETE_MODE; import static org.apache.iceberg.TableProperties.WRITE_DATA_LOCATION; import static org.apache.iceberg.TableProperties.WRITE_METADATA_LOCATION; import static org.testng.Assert.assertEquals; @@ -131,6 +132,8 @@ public void testRemoveOrphanFilesInMetadataAndDataFolder(String zoneId, boolean assertUpdate(session, format("create table 
%s (a int, b varchar)", tableName)); assertUpdate(session, format("insert into %s values(1, '1001'), (2, '1002')", tableName), 2); assertUpdate(session, format("insert into %s values(3, '1003'), (4, '1004')", tableName), 2); + assertUpdate(session, format("insert into %s values(5, '1005'), (6, '1006')", tableName), 2); + assertUpdate(session, format("delete from %s where a between 5 and 6", tableName), 2); assertQuery(session, "select * from " + tableName, "values(1, '1001'), (2, '1002'), (3, '1003'), (4, '1004')"); Table table = loadTable(tableName); @@ -185,12 +188,15 @@ public void testRemoveOrphanFilesWithNonDefaultMetadataPath(String zoneId, boole // Create an iceberg table using specified table properties Table table = createTable(tempTableName, tableTargetPath, - ImmutableMap.of(WRITE_METADATA_LOCATION, specifiedMetadataPath)); + ImmutableMap.of(WRITE_METADATA_LOCATION, specifiedMetadataPath, + DELETE_MODE, "merge-on-read")); assertNotNull(table.properties().get(WRITE_METADATA_LOCATION)); assertEquals(table.properties().get(WRITE_METADATA_LOCATION), specifiedMetadataPath); assertUpdate(session, format("CALL system.register_table('%s', '%s', '%s')", TEST_SCHEMA, tableName, metadataLocation(table))); assertUpdate(session, "insert into " + tableName + " values(1, '1001'), (2, '1002')", 2); + assertUpdate(session, "insert into " + tableName + " values(3, '1003'), (4, '1004')", 2); + assertUpdate(session, "delete from " + tableName + " where a between 3 and 4", 2); assertQuery(session, "select * from " + tableName, "values(1, '1001'), (2, '1002')"); int metadataFilesCountBefore = allMetadataFilesCount(session, table); @@ -236,12 +242,15 @@ public void testRemoveOrphanFilesWithNonDefaultDataPath(String zoneId, boolean l // Create an iceberg table using specified table properties Table table = createTable(tempTableName, tableTargetPath, - ImmutableMap.of(WRITE_DATA_LOCATION, specifiedDataPath)); + ImmutableMap.of(WRITE_DATA_LOCATION, specifiedDataPath, + 
DELETE_MODE, "merge-on-read")); assertNotNull(table.properties().get(WRITE_DATA_LOCATION)); assertEquals(table.properties().get(WRITE_DATA_LOCATION), specifiedDataPath); assertUpdate(session, format("CALL system.register_table('%s', '%s', '%s')", TEST_SCHEMA, tableName, metadataLocation(table))); assertUpdate(session, "insert into " + tableName + " values(1, '1001'), (2, '1002')", 2); + assertUpdate(session, "insert into " + tableName + " values(3, '1003'), (4, '1004')", 2); + assertUpdate(session, "delete from " + tableName + " where a between 3 and 4", 2); assertQuery(session, "select * from " + tableName, "values(1, '1001'), (2, '1002')"); int metadataFilesCountBefore = allMetadataFilesCount(session, table); diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureHadoop.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureHadoop.java index d65f134583042..ba62a81799786 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureHadoop.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureHadoop.java @@ -74,6 +74,7 @@ Table createTable(String tableName, String targetPath, Map table @Override Table loadTable(String tableName) { + tableName = normalizeIdentifier(tableName, ICEBERG_CATALOG); Catalog catalog = CatalogUtil.loadCatalog(HADOOP.getCatalogImpl(), ICEBERG_CATALOG, getProperties(), new Configuration()); return catalog.loadTable(TableIdentifier.of(TEST_SCHEMA, tableName)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureHive.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureHive.java index 7eda1fd2b720e..b186ef87cf663 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureHive.java 
+++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRemoveOrphanFilesProcedureHive.java @@ -18,6 +18,7 @@ import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.MetastoreContext; import com.facebook.presto.iceberg.HiveTableOperations; +import com.facebook.presto.iceberg.IcebergCatalogName; import com.facebook.presto.iceberg.IcebergHiveTableOperationsConfig; import com.facebook.presto.iceberg.IcebergUtil; import com.facebook.presto.iceberg.ManifestFileCache; @@ -89,6 +90,7 @@ Table createTable(String tableName, String targetPath, Map table @Override Table loadTable(String tableName) { + tableName = normalizeIdentifier(tableName, ICEBERG_CATALOG); CatalogManager catalogManager = getDistributedQueryRunner().getCoordinator().getCatalogManager(); ConnectorId connectorId = catalogManager.getCatalog(ICEBERG_CATALOG).get().getConnectorId(); @@ -97,6 +99,7 @@ Table loadTable(String tableName) new IcebergHiveTableOperationsConfig(), new ManifestFileCache(CacheBuilder.newBuilder().build(), false, 0, 1024), getQueryRunner().getDefaultSession().toConnectorSession(connectorId), + new IcebergCatalogName(ICEBERG_CATALOG), SchemaTableName.valueOf("tpch." + tableName)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRewriteManifestsProcedure.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRewriteManifestsProcedure.java new file mode 100644 index 0000000000000..08142f52d6e32 --- /dev/null +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRewriteManifestsProcedure.java @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.iceberg.procedure; + +import com.facebook.presto.iceberg.IcebergQueryRunner; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import org.testng.annotations.Test; + +import static com.facebook.presto.iceberg.CatalogType.HADOOP; +import static java.lang.String.format; +import static org.apache.iceberg.TableProperties.MANIFEST_MERGE_ENABLED; +import static org.testng.Assert.assertEquals; + +public class TestRewriteManifestsProcedure + extends AbstractTestQueryFramework +{ + public static final String TEST_SCHEMA = "tpch"; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return IcebergQueryRunner.builder() + .setCatalogType(HADOOP) + .build() + .getQueryRunner(); + } + + private void createTable(String tableName) + { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + assertUpdate("CREATE TABLE " + tableName + " (id INTEGER, value VARCHAR)"); + } + + private void dropTable(String tableName) + { + assertQuerySucceeds("DROP TABLE IF EXISTS " + TEST_SCHEMA + "." 
+ tableName); + } + + @Test + public void testRewriteManifestsUsingPositionalArgs() + { + String tableName = "rewrite_manifests_positional"; + createTable(tableName); + try { + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'a')", 1); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'b')", 1); + + long replaceBefore = (long) computeScalar(format("SELECT count(*) FROM %s.\"%s$snapshots\" WHERE operation = 'replace'", TEST_SCHEMA, tableName)); + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0, 0"); + assertUpdate(format("CALL system.rewrite_manifests('%s', '%s')", TEST_SCHEMA, tableName)); + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0"); + long replaceAfter = (long) computeScalar(format("SELECT count(*) FROM %s.\"%s$snapshots\" WHERE operation = 'replace'", TEST_SCHEMA, tableName)); + + assertQuery(format("SELECT * FROM %s.%s ORDER BY id", TEST_SCHEMA, tableName), "VALUES (1, 'a'), (2, 'b')"); + assertEquals(replaceAfter, replaceBefore + 1); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteManifestsUsingNamedArgs() + { + String tableName = "rewrite_manifests_named"; + createTable(tableName); + try { + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'a')", 1); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'b')", 1); + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0, 0"); + assertUpdate(format("CALL system.rewrite_manifests(schema => '%s', table_name => '%s')", TEST_SCHEMA, tableName)); + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteManifestsWithValidSpecId() + { + String tableName = "rewrite_manifests_spec"; + createTable(tableName); + try { + assertUpdate("INSERT INTO " + 
tableName + " VALUES (1, 'a')", 1); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'b')", 1); + // default tables have spec_id = 0 + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0, 0"); + assertUpdate(format("CALL system.rewrite_manifests('%s', '%s', 0)", TEST_SCHEMA, tableName)); + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteManifestsWithInvalidSpecIdFails() + { + String tableName = "rewrite_manifests_invalid_spec"; + createTable(tableName); + try { + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'a')", 1); + assertQueryFails(format("CALL system.rewrite_manifests('%s', '%s', 999)", TEST_SCHEMA, tableName), "Given spec id does not exist: 999"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteManifestsIsIdempotent() + { + String tableName = "rewrite_manifests_idempotent"; + createTable(tableName); + try { + assertUpdate(format("INSERT INTO %s.%s VALUES (1, 'a')", TEST_SCHEMA, tableName), 1); + assertUpdate(format("INSERT INTO %s.%s VALUES (2, 'b')", TEST_SCHEMA, tableName), 1); + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0, 0"); + assertUpdate(format("CALL system.rewrite_manifests('%s', '%s')", TEST_SCHEMA, tableName)); + assertUpdate(format("CALL system.rewrite_manifests('%s', '%s')", TEST_SCHEMA, tableName)); + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0"); + assertQuery(format("SELECT * FROM %s.%s", TEST_SCHEMA, tableName), "VALUES (1, 'a'), (2, 'b')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testInvalidRewriteManifestsCalls() + { + assertQueryFails("CALL system.rewrite_manifests('test_table', 1)", "line 1:45: Cannot cast type integer to varchar"); + 
assertQueryFails("CALL system.rewrite_manifests(table_name => 'test_table', spec_id=> 1)", "line 1:1: Required procedure argument 'schema' is missing"); + assertQueryFails("CALL system.rewrite_manifests(schema => 'tpch', table_name => 'test', 1)", "line 1:1: Named and positional arguments cannot be mixed"); + assertQueryFails("CALL custom.rewrite_manifests('tpch', 'test')", "Procedure not registered: custom.rewrite_manifests"); + } + + @Test + public void testRewriteManifestsWithPartitionEvolution() + { + String tableName = "rewrite_manifests_spec_evolution"; + createTable(tableName); + assertUpdate(format("call system.set_table_property('%s', '%s', '%s', '%s')", TEST_SCHEMA, tableName, MANIFEST_MERGE_ENABLED, false)); + try { + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 'a')", 1); + assertUpdate("INSERT INTO " + tableName + " VALUES (2, 'b')", 1); + // default tables have spec_id = 0 + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0, 0"); + + assertUpdate("ALTER TABLE " + tableName + " ADD COLUMN col_new VARCHAR WITH (PARTITIONING = 'identity')"); + assertUpdate("INSERT INTO " + tableName + " VALUES (3, 'c', 'val_3')", 1); + + // current spec_id = 1 + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0, 0, 1"); + assertQuery("SELECT * from " + tableName, "VALUES(1, 'a', null), (2, 'b', null), (3, 'c', 'val_3')"); + + // default rewrite manifest files with current spec_id = 1 + assertUpdate(format("CALL system.rewrite_manifests('%s', '%s')", TEST_SCHEMA, tableName)); + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0, 0, 1"); + + // rewrite manifest files with specified spec_id = 0 + assertUpdate(format("CALL system.rewrite_manifests('%s', '%s', 0)", TEST_SCHEMA, tableName)); + assertQuery(format("SELECT partition_spec_id from %s.\"%s$manifests\"", TEST_SCHEMA, tableName), "VALUES 0, 
1"); + assertQuery("SELECT * from " + tableName, "VALUES(1, 'a', null), (2, 'b', null), (3, 'c', 'val_3')"); + } + finally { + dropTable(tableName); + } + } + + @Test + public void testRewriteManifestsOnNonExistingTableFails() + { + assertQueryFails("CALL system.rewrite_manifests('tpch', 'non_existing_table')", "Table does not exist: tpch.non_existing_table"); + } +} diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRollbackToTimestampProcedure.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRollbackToTimestampProcedure.java index 2f5c1ba3eb6af..11397d91c53e8 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRollbackToTimestampProcedure.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestRollbackToTimestampProcedure.java @@ -251,6 +251,7 @@ private static String getTimestampString(long timeMillsUtc, String zoneId) private Table loadTable(String tableName) { + tableName = normalizeIdentifier(tableName, ICEBERG_CATALOG); Catalog catalog = CatalogUtil.loadCatalog(HADOOP.getCatalogImpl(), ICEBERG_CATALOG, getProperties(), new Configuration()); return catalog.loadTable(TableIdentifier.of(TEST_SCHEMA, tableName)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestSetCurrentSnapshotProcedure.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestSetCurrentSnapshotProcedure.java index 7cb88bc9d4c19..fa50afa142727 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestSetCurrentSnapshotProcedure.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestSetCurrentSnapshotProcedure.java @@ -174,6 +174,7 @@ public void testSetCurrentSnapshotToRef() private Table loadTable(String tableName) { + tableName = normalizeIdentifier(tableName, ICEBERG_CATALOG); Catalog catalog = CatalogUtil.loadCatalog(HadoopCatalog.class.getName(), ICEBERG_CATALOG, 
getProperties(), new Configuration()); return catalog.loadTable(TableIdentifier.of(TEST_SCHEMA, tableName)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestSetTablePropertyProcedure.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestSetTablePropertyProcedure.java index 6f98befeca4c3..0f227ab770152 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestSetTablePropertyProcedure.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/procedure/TestSetTablePropertyProcedure.java @@ -160,6 +160,7 @@ public void testInvalidSetTablePropertyProcedureCases() private Table loadTable(String tableName) { + tableName = normalizeIdentifier(tableName, ICEBERG_CATALOG); Catalog catalog = CatalogUtil.loadCatalog(HadoopCatalog.class.getName(), ICEBERG_CATALOG, getProperties(), new Configuration()); return catalog.loadTable(TableIdentifier.of(TEST_SCHEMA, tableName)); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/IcebergRestTestUtil.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/IcebergRestTestUtil.java index a874a42168c2b..34a88837ce696 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/IcebergRestTestUtil.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/IcebergRestTestUtil.java @@ -14,6 +14,7 @@ package com.facebook.presto.iceberg.rest; import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.http.server.HttpServerConfig; import com.facebook.airlift.http.server.TheServlet; import com.facebook.airlift.http.server.testing.TestingHttpServer; import com.facebook.airlift.http.server.testing.TestingHttpServerModule; @@ -43,6 +44,8 @@ import java.util.Map; import java.util.concurrent.ThreadLocalRandom; +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static com.facebook.airlift.http.server.UriCompliance.LEGACY; import static 
com.facebook.presto.iceberg.IcebergDistributedTestBase.getHdfsEnvironment; import static java.util.Objects.requireNonNull; import static org.apache.iceberg.CatalogProperties.URI; @@ -124,8 +127,13 @@ private class RestHttpServerModule @Override public void configure(Binder binder) { + configBinder(binder) + .bindConfigDefaults(HttpServerConfig.class, config -> { + // This is required to support nested namespace URI paths + config.setUriComplianceMode(LEGACY); + }); binder.bind(new TypeLiteral>() {}).annotatedWith(TheServlet.class).toInstance(ImmutableMap.of()); - binder.bind(javax.servlet.Servlet.class).annotatedWith(TheServlet.class).toInstance(new IcebergRestCatalogServlet(adapter)); + binder.bind(jakarta.servlet.Servlet.class).annotatedWith(TheServlet.class).toInstance(new IcebergRestCatalogServlet(adapter)); binder.bind(NodeInfo.class).toInstance(new NodeInfo("test")); } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergDistributedRest.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergDistributedRest.java index baa105bc79f0b..53e076fcdef2c 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergDistributedRest.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergDistributedRest.java @@ -36,7 +36,6 @@ import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.google.common.io.MoreFiles.deleteRecursively; import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; -import static org.assertj.core.api.Assertions.assertThatThrownBy; @Test public class TestIcebergDistributedRest @@ -88,6 +87,10 @@ protected QueryRunner createQueryRunner() Map connectorProperties = ImmutableMap.builder() .putAll(restConnectorProperties(serverUri)) .put("iceberg.rest.session.type", SessionType.USER.name()) + // Enable OAuth2 authentication to trigger token exchange flow + // The credential is required to 
initialize the OAuth2Manager + .put("iceberg.rest.auth.type", "OAUTH2") + .put("iceberg.rest.auth.oauth2.credential", "client:secret") .build(); return IcebergQueryRunner.builder() @@ -98,15 +101,6 @@ protected QueryRunner createQueryRunner() .getQueryRunner(); } - @Test - public void testDeleteOnV1Table() - { - // v1 table create fails due to Iceberg REST catalog bug (see: https://github.com/apache/iceberg/issues/8756) - assertThatThrownBy(super::testDeleteOnV1Table) - .isInstanceOf(RuntimeException.class) - .hasMessageMatching("Cannot downgrade v2 table to v1"); - } - @Test public void testRestUserSessionAuthorization() { diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRest.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRest.java index 73b6d0a21e1b7..42e0df564402a 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRest.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRest.java @@ -47,7 +47,6 @@ import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; import static java.lang.String.format; import static org.apache.iceberg.rest.auth.OAuth2Properties.OAUTH2_SERVER_URI; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; @Test @@ -128,69 +127,6 @@ protected Table getIcebergTable(ConnectorSession session, String schema, String SchemaTableName.valueOf(schema + "." 
+ tableName)); } - @Test - public void testDeleteOnPartitionedV1Table() - { - // v1 table create fails due to Iceberg REST catalog bug (see: https://github.com/apache/iceberg/issues/8756) - assertThatThrownBy(super::testDeleteOnPartitionedV1Table) - .isInstanceOf(RuntimeException.class) - .hasMessageMatching("Cannot downgrade v2 table to v1"); - } - - @Test - public void testCreateTableWithFormatVersion() - { - // v1 table create fails due to Iceberg REST catalog bug (see: https://github.com/apache/iceberg/issues/8756) - assertThatThrownBy(() -> super.testMetadataDeleteOnNonIdentityPartitionColumn("1", "copy-on-write")) - .isInstanceOf(RuntimeException.class) - .hasMessageMatching("Cannot downgrade v2 table to v1"); - - // v2 succeeds - super.testCreateTableWithFormatVersion("2", "merge-on-read"); - } - - @Test(dataProvider = "version_and_mode") - public void testMetadataDeleteOnNonIdentityPartitionColumn(String version, String mode) - { - if (version.equals("1")) { - // v1 table create fails due to Iceberg REST catalog bug (see: https://github.com/apache/iceberg/issues/8756) - assertThatThrownBy(() -> super.testMetadataDeleteOnNonIdentityPartitionColumn(version, mode)) - .isInstanceOf(RuntimeException.class); - } - else { - // v2 succeeds - super.testMetadataDeleteOnNonIdentityPartitionColumn(version, mode); - } - } - - @Test(dataProvider = "version_and_mode") - public void testMetadataDeleteOnTableWithUnsupportedSpecsIncludingNoData(String version, String mode) - { - if (version.equals("1")) { - // v1 table create fails due to Iceberg REST catalog bug (see: https://github.com/apache/iceberg/issues/8756) - assertThatThrownBy(() -> super.testMetadataDeleteOnTableWithUnsupportedSpecsIncludingNoData(version, mode)) - .isInstanceOf(RuntimeException.class); - } - else { - // v2 succeeds - super.testMetadataDeleteOnTableWithUnsupportedSpecsIncludingNoData(version, mode); - } - } - - @Test(dataProvider = "version_and_mode") - public void 
testMetadataDeleteOnTableWithUnsupportedSpecsWhoseDataAllDeleted(String version, String mode) - { - if (version.equals("1")) { - // v1 table create fails due to Iceberg REST catalog bug (see: https://github.com/apache/iceberg/issues/8756) - assertThatThrownBy(() -> super.testMetadataDeleteOnTableWithUnsupportedSpecsWhoseDataAllDeleted(version, mode)) - .isInstanceOf(RuntimeException.class); - } - else { - // v2 succeeds - super.testMetadataDeleteOnTableWithUnsupportedSpecsWhoseDataAllDeleted(version, mode); - } - } - @Test public void testSetOauth2ServerUriPropertyI() { @@ -206,10 +142,4 @@ public void testSetOauth2ServerUriPropertyI() assertEquals(catalog.properties().get(OAUTH2_SERVER_URI), authEndpoint); } - - @Override - public void testDeprecatedTablePropertiesCreateTable() - { - // v1 table create fails due to Iceberg REST catalog bug (see: https://github.com/apache/iceberg/issues/8756) - } } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRestNestedNamespace.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRestNestedNamespace.java index 0a8ce9888b04f..300dc687817db 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRestNestedNamespace.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/rest/TestIcebergSmokeRestNestedNamespace.java @@ -53,7 +53,6 @@ import static java.nio.file.Files.createTempDirectory; import static java.util.Locale.ENGLISH; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; @Test public class TestIcebergSmokeRestNestedNamespace @@ -263,19 +262,6 @@ protected void testCreatePartitionedTableAs(Session session, FileFormat fileForm dropTable(session, "test_create_partitioned_table_as_" + fileFormatString); } - @Test - @Override - public void testCreateTableWithFormatVersion() - { - // v1 table create fails due to Iceberg REST catalog bug 
(see: https://github.com/apache/iceberg/issues/8756) - assertThatThrownBy(() -> testCreateTableWithFormatVersion("1", "copy-on-write")) - .hasCauseInstanceOf(RuntimeException.class) - .hasStackTraceContaining("Cannot downgrade v2 table to v1"); - - // v2 succeeds - testCreateTableWithFormatVersion("2", "merge-on-read"); - } - @Override // override due to double quotes around nested namespace protected void testCreateTableWithFormatVersion(String formatVersion, String defaultDeleteMode) { diff --git a/presto-iceberg/src/test/java/org/apache/iceberg/rest/IcebergRestCatalogServlet.java b/presto-iceberg/src/test/java/org/apache/iceberg/rest/IcebergRestCatalogServlet.java index b783ab2d7dcf9..8b48eaf7232dc 100644 --- a/presto-iceberg/src/test/java/org/apache/iceberg/rest/IcebergRestCatalogServlet.java +++ b/presto-iceberg/src/test/java/org/apache/iceberg/rest/IcebergRestCatalogServlet.java @@ -17,20 +17,20 @@ import io.jsonwebtoken.Claims; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.MalformedJwtException; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import org.apache.hc.core5.http.ContentType; import org.apache.hc.core5.http.HttpHeaders; import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.io.CharStreams; -import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.HTTPRequest.HTTPMethod; import org.apache.iceberg.rest.RESTCatalogAdapter.Route; import org.apache.iceberg.rest.responses.ErrorResponse; +import org.apache.iceberg.rest.responses.OAuthTokenResponse; import org.apache.iceberg.util.Pair; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import java.io.IOException; import java.io.InputStreamReader; import 
java.io.Reader; @@ -42,8 +42,8 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static jakarta.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR; import static java.lang.String.format; -import static javax.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR; /** * The IcebergRestCatalogServlet provides a servlet implementation used in combination with a @@ -54,6 +54,11 @@ public class IcebergRestCatalogServlet { private static final Logger LOG = Logger.get(IcebergRestCatalogServlet.class); + private static final String SUBJECT_TOKEN = "subject_token"; + private static final String GRANT_TYPE = "grant_type"; + private static final String TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"; + private static final String TOKEN_EXCHANGE_PREFIX = "token-exchange-token:sub="; + private final RESTCatalogAdapter restCatalogAdapter; private final Map responseHeaders = ImmutableMap.of(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.getMimeType()); @@ -106,36 +111,51 @@ protected void execute(ServletRequestContext context, HttpServletResponse respon } if (context.error().isPresent()) { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST); - RESTObjectMapper.mapper().writeValue(response.getWriter(), context.error().get()); + ErrorResponse error = context.error().get(); + response.setStatus(error.code()); + RESTObjectMapper.mapper().writeValue(response.getWriter(), error); return; } + // Handle token exchange requests specially to preserve user identity + if (context.route() == Route.TOKENS && context.body() instanceof Map) { + @SuppressWarnings("unchecked") + Map tokenRequest = (Map) context.body(); + String grantType = tokenRequest.get(GRANT_TYPE); + String subjectToken = tokenRequest.get(SUBJECT_TOKEN); + + if (TOKEN_EXCHANGE_GRANT_TYPE.equals(grantType) && subjectToken != null) { + // Return the subject token prefixed so that authorization check can extract the original JWT + String 
responseToken = TOKEN_EXCHANGE_PREFIX + subjectToken; + OAuthTokenResponse oauthResponse = OAuthTokenResponse.builder() + .withToken(responseToken) + .withTokenType("Bearer") + .withIssuedTokenType("urn:ietf:params:oauth:token-type:access_token") + .build(); + RESTObjectMapper.mapper().writeValue(response.getWriter(), oauthResponse); + return; + } + } + try { + HTTPRequest request = restCatalogAdapter.buildRequest( + context.method(), + context.path(), + context.queryParams(), + context.headers(), + context.body()); Object responseBody = restCatalogAdapter.execute( - context.method(), - context.path(), - context.queryParams(), - context.body(), + request, context.route().responseClass(), - context.headers(), - handle(response)); + handleResponseError(response), + handleResponseHeader(response)); if (responseBody != null) { RESTObjectMapper.mapper().writeValue(response.getWriter(), responseBody); } } catch (RESTException e) { - if ((context.route() == Route.LOAD_TABLE && e.getLocalizedMessage().contains("NoSuchTableException")) || - (context.route() == Route.LOAD_VIEW && e.getLocalizedMessage().contains("NoSuchViewException"))) { - // Suppress stack trace for load_table requests, most of which occur immediately - // preceding a create_table request - LOG.warn("Table at endpoint %s does not exist", context.path()); - } - else { - LOG.error(e, "Error processing REST request at endpoint %s", context.path()); - } response.setStatus(SC_INTERNAL_SERVER_ERROR); } catch (Exception e) { @@ -144,7 +164,14 @@ protected void execute(ServletRequestContext context, HttpServletResponse respon } } - protected Consumer handle(HttpServletResponse response) + private Consumer> handleResponseHeader(HttpServletResponse response) + { + return (responseHeaders) -> { + throw new RuntimeException("Unexpected response header: " + responseHeaders); + }; + } + + protected Consumer handleResponseError(HttpServletResponse response) { return (errorResponse) -> { 
response.setStatus(errorResponse.code()); diff --git a/presto-iceberg/src/test/resources/com/facebook/presto/iceberg/security.json b/presto-iceberg/src/test/resources/com/facebook/presto/iceberg/security.json new file mode 100644 index 0000000000000..c30b2b4d0bf53 --- /dev/null +++ b/presto-iceberg/src/test/resources/com/facebook/presto/iceberg/security.json @@ -0,0 +1,66 @@ +{ + "tables": [ + { + "user": "iceberg", + "privileges": [ + "SELECT", "INSERT", "DELETE", "OWNERSHIP" + ] + }, + { + "user": "alice", + "privileges": [ + "SELECT", "INSERT" + ] + }, + { + "user": "bob", + "privileges": [ + "SELECT" + ] + } + ], + "schemas": [ + { + "user": "iceberg", + "owner": true + } + ], + "sessionProperties": [ + { + "user": ".*", + "property": ".*", + "allow": true + } + ], + "procedures": [ + { + "user": "iceberg", + "schema": ".*", + "privileges": ["EXECUTE"] + }, + { + "user": "alice", + "schema": "system", + "privileges": ["EXECUTE"] + }, + { + "user": "bob", + "schema": "system", + "procedure": "invalidate_statistics_file_cache", + "privileges": ["EXECUTE"] + }, + { + "user": "bob", + "schema": "system", + "procedure": "rewrite_data_files", + "privileges": ["EXECUTE"] + }, + { + "user": "joe", + "schema": "system", + "procedure": "other_procedure", + "privileges": ["EXECUTE"] + } + ] +} + diff --git a/presto-jdbc/pom.xml b/presto-jdbc/pom.xml index aaa53ad254121..406aaf3854e27 100644 --- a/presto-jdbc/pom.xml +++ b/presto-jdbc/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-jdbc @@ -14,6 +14,7 @@ ${project.parent.basedir} com.facebook.presto.jdbc.internal + true @@ -62,7 +63,7 @@ - io.airlift + com.facebook.airlift units @@ -110,18 +111,48 @@ - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true + + com.facebook.airlift + concurrent + test + + + + com.facebook.airlift + configuration + test + + + + com.facebook.airlift + http-server + test + + + + com.facebook.airlift + jaxrs + 
test + + com.facebook.presto presto-testng-services test + + com.facebook.presto + presto-memory + test + + com.facebook.presto presto-main-base @@ -170,6 +201,18 @@ test + + jakarta.servlet + jakarta.servlet-api + test + + + + jakarta.ws.rs + jakarta.ws.rs-api + test + + org.testng testng @@ -194,12 +237,6 @@ test - - com.facebook.airlift - concurrent - test - - com.google.inject guice @@ -217,11 +254,13 @@ jjwt-api test + io.jsonwebtoken jjwt-impl test + io.jsonwebtoken jjwt-jackson @@ -247,7 +286,15 @@ - + + org.apache.maven.plugins + maven-dependency-plugin + + + com.facebook.airlift:security + + + org.apache.maven.plugins maven-shade-plugin @@ -335,7 +382,10 @@ com.google.thirdparty ${shadeBase}.guava - + + com.google.errorprone + ${shadeBase}.com.google.errorprone + io.airlift ${shadeBase}.io.airlift @@ -344,13 +394,21 @@ com.facebook.airlift ${shadeBase}.com.facebook.airlift + + jakarta.annotation + ${shadeBase}.jakarta.annotation + javax.annotation ${shadeBase}.javax.annotation javax.inject - ${shadeBase}.inject + ${shadeBase}.javax.inject + + + jakarta.inject + ${shadeBase}.jakarta.inject org.openjdk.jol @@ -380,6 +438,10 @@ com.google.j2objc ${shadeBase}.j2objc + + com.google.code + ${shadeBase}.com.google.code + org.apache.commons ${shadeBase}.apache.commons @@ -401,6 +463,7 @@ META-INF/services/com.fasterxml.** META-INF.versions.9.module-info LICENSE + ValidationMessages.properties diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java index c6c8e180820d7..931761723d9e5 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/ConnectionProperties.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.jdbc; +import com.facebook.airlift.units.Duration; +import com.facebook.presto.client.auth.external.ExternalRedirectStrategy; +import 
com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.net.HostAndPort; @@ -26,12 +29,14 @@ import java.util.Properties; import java.util.Set; import java.util.function.Predicate; +import java.util.stream.StreamSupport; import static com.facebook.presto.jdbc.AbstractConnectionProperty.ClassListConverter.CLASS_LIST_CONVERTER; import static com.facebook.presto.jdbc.AbstractConnectionProperty.HttpProtocolConverter.HTTP_PROTOCOL_CONVERTER; import static com.facebook.presto.jdbc.AbstractConnectionProperty.ListValidateConvertor.LIST_VALIDATE_CONVERTOR; import static com.facebook.presto.jdbc.AbstractConnectionProperty.StringMapConverter.STRING_MAP_CONVERTER; import static com.facebook.presto.jdbc.AbstractConnectionProperty.checkedPredicate; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Collections.unmodifiableMap; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; @@ -68,6 +73,11 @@ final class ConnectionProperties public static final ConnectionProperty FOLLOW_REDIRECTS = new FollowRedirects(); public static final ConnectionProperty SSL_KEY_STORE_TYPE = new SSLKeyStoreType(); public static final ConnectionProperty SSL_TRUST_STORE_TYPE = new SSLTrustStoreType(); + public static final ConnectionProperty EXTERNAL_AUTHENTICATION = new ExternalAuthentication(); + public static final ConnectionProperty EXTERNAL_AUTHENTICATION_TIMEOUT = new ExternalAuthenticationTimeout(); + public static final ConnectionProperty EXTERNAL_AUTHENTICATION_TOKEN_CACHE = new ExternalAuthenticationTokenCache(); + public static final ConnectionProperty> EXTERNAL_AUTHENTICATION_REDIRECT_HANDLERS = new ExternalAuthenticationRedirectHandlers(); + private static final Set> ALL_PROPERTIES = ImmutableSet.>builder() .add(USER) .add(PASSWORD) @@ -98,6 +108,10 @@ final class ConnectionProperties 
.add(QUERY_INTERCEPTORS) .add(VALIDATE_NEXTURI_SOURCE) .add(FOLLOW_REDIRECTS) + .add(EXTERNAL_AUTHENTICATION) + .add(EXTERNAL_AUTHENTICATION_TIMEOUT) + .add(EXTERNAL_AUTHENTICATION_TOKEN_CACHE) + .add(EXTERNAL_AUTHENTICATION_REDIRECT_HANDLERS) .build(); private static final Map> KEY_LOOKUP = unmodifiableMap(ALL_PROPERTIES.stream() @@ -411,4 +425,54 @@ public SSLKeyStoreType() super("SSLKeyStoreType", Optional.of(KeyStore.getDefaultType()), NOT_REQUIRED, ALLOWED, STRING_CONVERTER); } } + + private static Predicate isExternalAuthEnabled() + { + return checkedPredicate(properties -> EXTERNAL_AUTHENTICATION.getValue(properties).isPresent()); + } + + private static class ExternalAuthentication + extends AbstractConnectionProperty + { + public ExternalAuthentication() + { + super("externalAuthentication", Optional.of("false"), NOT_REQUIRED, ALLOWED, BOOLEAN_CONVERTER); + } + } + + private static class ExternalAuthenticationTimeout + extends AbstractConnectionProperty + { + public ExternalAuthenticationTimeout() + { + super("externalAuthenticationTimeout", NOT_REQUIRED, isExternalAuthEnabled(), Duration::valueOf); + } + } + + private static class ExternalAuthenticationTokenCache + extends AbstractConnectionProperty + { + public ExternalAuthenticationTokenCache() + { + super("externalAuthenticationTokenCache", Optional.of(KnownTokenCache.NONE.name()), NOT_REQUIRED, ALLOWED, KnownTokenCache::valueOf); + } + } + + private static class ExternalAuthenticationRedirectHandlers + extends AbstractConnectionProperty> + { + private static final Splitter ENUM_SPLITTER = Splitter.on(',').trimResults().omitEmptyStrings(); + + public ExternalAuthenticationRedirectHandlers() + { + super("externalAuthenticationRedirectHandlers", Optional.of("OPEN"), NOT_REQUIRED, ALLOWED, ExternalAuthenticationRedirectHandlers::parse); + } + + public static List parse(String value) + { + return StreamSupport.stream(ENUM_SPLITTER.split(value).spliterator(), false) + .map(ExternalRedirectStrategy::valueOf) 
+ .collect(toImmutableList()); + } + } } diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/KnownTokenCache.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/KnownTokenCache.java new file mode 100644 index 0000000000000..8792075be56d1 --- /dev/null +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/KnownTokenCache.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.jdbc; + +import com.facebook.presto.client.auth.external.KnownToken; + +public enum KnownTokenCache +{ + NONE { + @Override + KnownToken create() + { + return KnownToken.local(); + } + }; + + abstract KnownToken create(); +} diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java index 1f3f760631850..d27be99e448fe 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoConnection.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.jdbc; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.ClientSession; import com.facebook.presto.client.ServerInfo; import com.facebook.presto.client.StatementClient; @@ -23,7 +24,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Ints; -import io.airlift.units.Duration; import 
java.net.URI; import java.nio.charset.CharsetEncoder; diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java index a49db1ea6e40a..73f00a19c1833 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java @@ -1144,8 +1144,7 @@ public boolean insertsAreDetected(int type) public boolean supportsBatchUpdates() throws SQLException { - // TODO: support batch updates - return false; + return true; } @Override diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java index e355b58f32331..1dcfc47b152d8 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDriverUri.java @@ -15,6 +15,12 @@ import com.facebook.presto.client.ClientException; import com.facebook.presto.client.OkHttpUtil; +import com.facebook.presto.client.auth.external.CompositeRedirectHandler; +import com.facebook.presto.client.auth.external.ExternalAuthenticator; +import com.facebook.presto.client.auth.external.HttpTokenPoller; +import com.facebook.presto.client.auth.external.RedirectHandler; +import com.facebook.presto.client.auth.external.TokenPoller; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -27,6 +33,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.sql.SQLException; +import java.time.Duration; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -35,6 +42,7 @@ import java.util.Optional; import java.util.Properties; import java.util.TimeZone; +import 
java.util.concurrent.atomic.AtomicReference; import static com.facebook.presto.client.GCSOAuthInterceptor.GCS_CREDENTIALS_PATH_KEY; import static com.facebook.presto.client.GCSOAuthInterceptor.GCS_OAUTH_SCOPES_KEY; @@ -51,6 +59,10 @@ import static com.facebook.presto.jdbc.ConnectionProperties.CLIENT_TAGS; import static com.facebook.presto.jdbc.ConnectionProperties.CUSTOM_HEADERS; import static com.facebook.presto.jdbc.ConnectionProperties.DISABLE_COMPRESSION; +import static com.facebook.presto.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION; +import static com.facebook.presto.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION_REDIRECT_HANDLERS; +import static com.facebook.presto.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION_TIMEOUT; +import static com.facebook.presto.jdbc.ConnectionProperties.EXTERNAL_AUTHENTICATION_TOKEN_CACHE; import static com.facebook.presto.jdbc.ConnectionProperties.EXTRA_CREDENTIALS; import static com.facebook.presto.jdbc.ConnectionProperties.FOLLOW_REDIRECTS; import static com.facebook.presto.jdbc.ConnectionProperties.HTTP_PROTOCOLS; @@ -88,7 +100,7 @@ final class PrestoDriverUri private static final Splitter QUERY_SPLITTER = Splitter.on('&').omitEmptyStrings(); private static final Splitter ARG_SPLITTER = Splitter.on('=').limit(2); - + private static final AtomicReference REDIRECT_HANDLER = new AtomicReference<>(null); private final HostAndPort address; private final URI uri; @@ -282,6 +294,30 @@ public void setupClient(OkHttpClient.Builder builder) } builder.addInterceptor(tokenAuth(ACCESS_TOKEN.getValue(properties).get())); } + + if (EXTERNAL_AUTHENTICATION.getValue(properties).orElse(false)) { + if (!useSecureConnection) { + throw new SQLException("Authentication using external authorization requires SSL to be enabled"); + } + + // create HTTP client that shares the same settings, but without the external authenticator + TokenPoller poller = new HttpTokenPoller(builder.build()); + + Duration timeout = 
EXTERNAL_AUTHENTICATION_TIMEOUT.getValue(properties) + .map(value -> Duration.ofMillis(value.toMillis())) + .orElse(Duration.ofMinutes(2)); + + KnownTokenCache knownTokenCache = EXTERNAL_AUTHENTICATION_TOKEN_CACHE.getValue(properties).get(); + Optional configuredHandler = EXTERNAL_AUTHENTICATION_REDIRECT_HANDLERS.getValue(properties) + .map(CompositeRedirectHandler::new) + .map(RedirectHandler.class::cast); + RedirectHandler redirectHandler = Optional.ofNullable(REDIRECT_HANDLER.get()) + .orElseGet(() -> configuredHandler.orElseThrow(() -> new RuntimeException("External authentication redirect handler is not configured"))); + ExternalAuthenticator authenticator = new ExternalAuthenticator(redirectHandler, poller, knownTokenCache.create(), timeout); + + builder.authenticator(authenticator); + builder.addInterceptor(authenticator); + } } catch (ClientException e) { throw new SQLException(e.getMessage(), e); @@ -419,4 +455,10 @@ private static void validateConnectionProperties(Properties connectionProperties property.validate(connectionProperties); } } + + @VisibleForTesting + static void setRedirectHandler(RedirectHandler handler) + { + REDIRECT_HANDLER.set(requireNonNull(handler, "handler is null")); + } } diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoPreparedStatement.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoPreparedStatement.java index 9720e6b167e0f..bd1af7ea2233b 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoPreparedStatement.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoPreparedStatement.java @@ -23,6 +23,7 @@ import java.math.BigDecimal; import java.net.URL; import java.sql.Array; +import java.sql.BatchUpdateException; import java.sql.Blob; import java.sql.Clob; import java.sql.Date; @@ -42,6 +43,7 @@ import java.sql.Timestamp; import java.sql.Types; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.HashMap; import 
java.util.List; @@ -72,9 +74,11 @@ public class PrestoPreparedStatement implements PreparedStatement { private final Map parameters = new HashMap<>(); + private final List> batchValues = new ArrayList<>(); private final String statementName; private final String originalSql; private boolean isClosed; + private boolean isBatch; PrestoPreparedStatement(PrestoConnection connection, String statementName, String sql) throws SQLException @@ -101,7 +105,8 @@ public void close() public ResultSet executeQuery() throws SQLException { - if (!super.execute(getExecuteSql())) { + requireNonBatchStatement(); + if (!super.execute(getExecuteSql(statementName, toValues(parameters)))) { throw new SQLException("Prepared SQL statement is not a query: " + originalSql); } return getResultSet(); @@ -111,6 +116,7 @@ public ResultSet executeQuery() public int executeUpdate() throws SQLException { + requireNonBatchStatement(); return Ints.saturatedCast(executeLargeUpdate()); } @@ -118,7 +124,8 @@ public int executeUpdate() public long executeLargeUpdate() throws SQLException { - if (super.execute(getExecuteSql())) { + requireNonBatchStatement(); + if (super.execute(getExecuteSql(statementName, toValues(parameters)))) { throw new SQLException("Prepared SQL is not an update statement: " + originalSql); } return getLargeUpdateCount(); @@ -128,7 +135,8 @@ public long executeLargeUpdate() public boolean execute() throws SQLException { - return super.execute(getExecuteSql()); + requireNonBatchStatement(); + return super.execute(getExecuteSql(statementName, toValues(parameters))); } @Override @@ -430,7 +438,41 @@ else if (x instanceof Timestamp) { public void addBatch() throws SQLException { - throw new NotImplementedException("PreparedStatement", "addBatch"); + checkOpen(); + batchValues.add(toValues(parameters)); + isBatch = true; + } + + @Override + public void clearBatch() + throws SQLException + { + checkOpen(); + batchValues.clear(); + isBatch = false; + } + + @Override + public int[] 
executeBatch() + throws SQLException + { + try { + int[] batchUpdateCounts = new int[batchValues.size()]; + for (int i = 0; i < batchValues.size(); i++) { + try { + super.execute(getExecuteSql(statementName, batchValues.get(i))); + batchUpdateCounts[i] = getUpdateCount(); + } + catch (SQLException e) { + long[] updateCounts = Arrays.stream(batchUpdateCounts).mapToLong(j -> j).toArray(); + throw new BatchUpdateException(e.getMessage(), e.getSQLState(), e.getErrorCode(), updateCounts, e.getCause()); + } + } + return batchUpdateCounts; + } + finally { + clearBatch(); + } } @Override @@ -759,27 +801,34 @@ private void setParameter(int parameterIndex, String value) parameters.put(parameterIndex - 1, value); } - private void formatParametersTo(StringBuilder builder) + private static List toValues(Map parameters) throws SQLException { - List values = new ArrayList<>(); + ImmutableList.Builder values = ImmutableList.builder(); for (int index = 0; index < parameters.size(); index++) { if (!parameters.containsKey(index)) { throw new SQLException("No value specified for parameter " + (index + 1)); } values.add(parameters.get(index)); } - Joiner.on(", ").appendTo(builder, values); + return values.build(); } - private String getExecuteSql() + private void requireNonBatchStatement() throws SQLException + { + if (isBatch) { + throw new SQLException("Batch prepared statement must be executed using executeBatch method"); + } + } + + private static String getExecuteSql(String statementName, List values) { StringBuilder sql = new StringBuilder(); sql.append("EXECUTE ").append(statementName); - if (!parameters.isEmpty()) { + if (!values.isEmpty()) { sql.append(" USING "); - formatParametersTo(sql); + Joiner.on(", ").appendTo(sql, values); } return sql.toString(); } diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/WarningsManager.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/WarningsManager.java index 2c2998a541073..bd7c3c5387b0f 100644 --- 
a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/WarningsManager.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/WarningsManager.java @@ -15,9 +15,8 @@ import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.WarningCode; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.sql.SQLWarning; import java.util.HashSet; diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcExternalAuthentication.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcExternalAuthentication.java new file mode 100644 index 0000000000000..ede9a534cf2a3 --- /dev/null +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcExternalAuthentication.java @@ -0,0 +1,524 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.jdbc; + +import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; +import com.facebook.airlift.http.server.AuthenticationException; +import com.facebook.airlift.http.server.Authenticator; +import com.facebook.airlift.http.server.BasicPrincipal; +import com.facebook.airlift.log.Logging; +import com.facebook.presto.client.ClientException; +import com.facebook.presto.client.auth.external.DesktopBrowserRedirectHandler; +import com.facebook.presto.client.auth.external.RedirectException; +import com.facebook.presto.client.auth.external.RedirectHandler; +import com.facebook.presto.server.security.SecurityConfig; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.sql.parser.SqlParserOptions; +import com.facebook.presto.tpch.TpchPlugin; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Binder; +import com.google.inject.Inject; +import com.google.inject.Key; +import com.google.inject.Module; +import com.google.inject.Scopes; +import com.google.inject.multibindings.Multibinder; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.security.Principal; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ConcurrentModificationException; +import java.util.List; +import java.util.Map; +import java.util.Optional; 
+import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.IntSupplier; + +import static com.facebook.airlift.configuration.ConditionalModule.installModuleIf; +import static com.facebook.airlift.jaxrs.JaxrsBinder.jaxrsBinder; +import static com.facebook.airlift.testing.Closeables.closeAll; +import static com.facebook.presto.jdbc.PrestoDriverUri.setRedirectHandler; +import static com.facebook.presto.jdbc.TestPrestoDriver.waitForNodeRefresh; +import static com.google.common.io.Resources.getResource; +import static com.google.inject.Scopes.SINGLETON; +import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static com.google.inject.util.Modules.combine; +import static jakarta.ws.rs.core.HttpHeaders.AUTHORIZATION; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON_TYPE; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; +import static java.lang.String.format; +import static java.net.HttpURLConnection.HTTP_OK; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@Test(singleThreaded = true) +public class TestJdbcExternalAuthentication +{ + private static final String TEST_CATALOG = "test_catalog"; + private TestingPrestoServer server; + + @BeforeClass + public void setup() + throws Exception + { + Logging.initialize(); + + Map properties = ImmutableMap.builder() + .put("http-server.authentication.type", "TEST_EXTERNAL") + .put("http-server.https.enabled", "true") + .put("http-server.https.keystore.path", new File(getResource("localhost.keystore").toURI()).getPath()) + .put("http-server.https.keystore.key", "changeit") + .build(); + List additionalModules = ImmutableList.builder() + .add(new 
DummyExternalAuthModule(() -> server.getAddress().getPort())) + .build(); + + server = new TestingPrestoServer(true, properties, null, null, new SqlParserOptions(), additionalModules); + server.installPlugin(new TpchPlugin()); + server.createCatalog(TEST_CATALOG, "tpch"); + waitForNodeRefresh(server); + } + + @AfterClass(alwaysRun = true) + public void teardown() + throws Exception + { + closeAll(server); + server = null; + } + + @BeforeMethod(alwaysRun = true) + public void clearUpLoggingSessions() + { + invalidateAllTokens(); + } + + @Test + public void testSuccessfulAuthenticationWithHttpGetOnlyRedirectHandler() + throws Exception + { + try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler()); + Connection connection = createConnection(); + Statement statement = connection.createStatement()) { + assertThat(statement.execute("SELECT 123")).isTrue(); + } + } + + /** + * Ignored due to lack of ui environment with web-browser on CI servers. + * Still this test is useful for local environments. 
+ */ + @Test(enabled = false) + public void testSuccessfulAuthenticationWithDefaultBrowserRedirect() + throws Exception + { + try (Connection connection = createConnection(); + Statement statement = connection.createStatement()) { + assertThat(statement.execute("SELECT 123")).isTrue(); + } + } + + @Test + public void testAuthenticationFailsAfterUnfinishedRedirect() + throws Exception + { + try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new NoOpRedirectHandler()); + Connection connection = createConnection(); + Statement statement = connection.createStatement()) { + assertThatThrownBy(() -> statement.execute("SELECT 123")) + .isInstanceOf(SQLException.class); + } + } + + @Test + public void testAuthenticationFailsAfterRedirectException() + throws Exception + { + try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new FailingRedirectHandler()); + Connection connection = createConnection(); + Statement statement = connection.createStatement()) { + assertThatThrownBy(() -> statement.execute("SELECT 123")) + .isInstanceOf(SQLException.class) + .hasCauseExactlyInstanceOf(RedirectException.class); + } + } + + @Test + public void testAuthenticationFailsAfterServerAuthenticationFailure() + throws Exception + { + try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler()); + AutoCloseable ignore2 = TokenPollingErrorFixture.withPollingError("error occurred during token polling"); + Connection connection = createConnection(); + Statement statement = connection.createStatement()) { + assertThatThrownBy(() -> statement.execute("SELECT 123")) + .isInstanceOf(SQLException.class) + .hasMessage("error occurred during token polling"); + } + } + + @Test + public void testAuthenticationFailsAfterReceivingMalformedHeaderFromServer() + throws Exception + { + try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler()); + AutoCloseable ignored = 
WwwAuthenticateHeaderFixture.withWwwAuthenticate("Bearer no-valid-fields"); + Connection connection = createConnection(); + Statement statement = connection.createStatement()) { + assertThatThrownBy(() -> statement.execute("SELECT 123")) + .isInstanceOf(SQLException.class) + .hasCauseInstanceOf(ClientException.class) + .hasMessage("Authentication failed: Authentication required"); + } + } + + @Test + public void testAuthenticationReusesObtainedTokenPerConnection() + throws Exception + { + try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler()); + Connection connection = createConnection(); + Statement statement = connection.createStatement()) { + statement.execute("SELECT 123"); + statement.execute("SELECT 123"); + statement.execute("SELECT 123"); + + assertThat(countIssuedTokens()).isEqualTo(1); + } + } + + @Test + public void testAuthenticationAfterInitialTokenHasBeenInvalidated() + throws Exception + { + try (RedirectHandlerFixture ignore = RedirectHandlerFixture.withHandler(new HttpGetOnlyRedirectHandler()); + Connection connection = createConnection(); + Statement statement = connection.createStatement()) { + statement.execute("SELECT 123"); + + invalidateAllTokens(); + assertThat(countIssuedTokens()).isEqualTo(0); + + assertThat(statement.execute("SELECT 123")).isTrue(); + } + } + + private Connection createConnection() + throws Exception + { + String url = format("jdbc:presto://localhost:%s", server.getHttpsAddress().getPort()); + Properties properties = new Properties(); + properties.setProperty("SSL", "true"); + properties.setProperty("SSLTrustStorePath", new File(getResource("localhost.truststore").toURI()).getPath()); + properties.setProperty("SSLTrustStorePassword", "changeit"); + properties.setProperty("externalAuthentication", "true"); + properties.setProperty("externalAuthenticationTimeout", "2s"); + properties.setProperty("user", "test"); + return DriverManager.getConnection(url, properties); + } + + 
private static Multibinder authenticatorBinder(Binder binder) + { + return newSetBinder(binder, Authenticator.class); + } + + public static Module authenticatorModule(Class clazz, Module module) + { + Module authModule = binder -> authenticatorBinder(binder).addBinding().to(clazz).in(Scopes.SINGLETON); + return installModuleIf( + SecurityConfig.class, + config -> true, + combine(module, authModule)); + } + + private static class DummyExternalAuthModule + extends AbstractConfigurationAwareModule + { + private final IntSupplier port; + + public DummyExternalAuthModule(IntSupplier port) + { + this.port = requireNonNull(port, "port is null"); + } + + @Override + protected void setup(Binder ignored) + { + Module test = authenticatorModule(DummyAuthenticator.class, binder -> { + binder.bind(Authentications.class).in(SINGLETON); + binder.bind(IntSupplier.class).toInstance(port); + jaxrsBinder(binder).bind(DummyExternalAuthResources.class); + }); + + install(test); + } + } + + private static class Authentications + { + private final Map logginSessions = new ConcurrentHashMap<>(); + private final Set validTokens = ConcurrentHashMap.newKeySet(); + + public String startAuthentication() + { + String sessionId = UUID.randomUUID().toString(); + logginSessions.put(sessionId, ""); + return sessionId; + } + + public void logIn(String sessionId) + { + String token = sessionId + "_token"; + validTokens.add(token); + logginSessions.put(sessionId, token); + } + + public Optional getToken(String sessionId) + throws IllegalArgumentException + { + return Optional.ofNullable(logginSessions.get(sessionId)) + .filter(s -> !s.isEmpty()); + } + + public boolean verifyToken(String token) + { + return validTokens.contains(token); + } + + public void invalidateAllTokens() + { + validTokens.clear(); + } + + public int countValidTokens() + { + return validTokens.size(); + } + } + + private void invalidateAllTokens() + { + Authentications authentications = 
server.getInstance(Key.get(Authentications.class)); + authentications.invalidateAllTokens(); + } + + private int countIssuedTokens() + { + Authentications authentications = server.getInstance(Key.get(Authentications.class)); + return authentications.countValidTokens(); + } + + public static class DummyAuthenticator + implements Authenticator + { + private final IntSupplier port; + private final Authentications authentications; + + @Inject + public DummyAuthenticator(IntSupplier port, Authentications authentications) + { + this.port = requireNonNull(port, "port is null"); + this.authentications = requireNonNull(authentications, "authentications is null"); + } + + @Override + public Principal authenticate(HttpServletRequest request) + throws AuthenticationException + { + Optional authHeader = Optional.ofNullable(request.getHeader(AUTHORIZATION)); + List bearerHeaders = authHeader.isPresent() ? ImmutableList.of(authHeader.get()) : ImmutableList.of(); + if (bearerHeaders.stream() + .filter(header -> header.startsWith("Bearer ")) + .anyMatch(header -> authentications.verifyToken(header.substring("Bearer ".length())))) { + return new BasicPrincipal("user"); + } + + String sessionId = authentications.startAuthentication(); + + throw Optional.ofNullable(WwwAuthenticateHeaderFixture.HEADER.get()) + .map(header -> new AuthenticationException("Authentication required", header)) + .orElseGet(() -> new AuthenticationException( + "Authentication required", + format("Bearer x_redirect_server=\"http://localhost:%s/v1/authentications/dummy/logins/%s\", " + + "x_token_server=\"http://localhost:%s/v1/authentications/dummy/%s\"", + port.getAsInt(), sessionId, port.getAsInt(), sessionId))); + } + } + + @Path("/v1/authentications/dummy") + public static class DummyExternalAuthResources + { + private final Authentications authentications; + + @Inject + public DummyExternalAuthResources(Authentications authentications) + { + this.authentications = authentications; + } + + @GET + 
@Produces(TEXT_PLAIN) + @Path("logins/{sessionId}") + public String logInUser(@PathParam("sessionId") String sessionId) + { + authentications.logIn(sessionId); + return "User has been successfully logged in during " + sessionId + " session"; + } + + @GET + @Path("{sessionId}") + public Response getToken(@PathParam("sessionId") String sessionId, @Context HttpServletRequest request) + { + try { + return Optional.ofNullable(TokenPollingErrorFixture.ERROR.get()) + .map(error -> Response.ok(format("{ \"error\" : \"%s\"}", error), APPLICATION_JSON_TYPE).build()) + .orElseGet(() -> authentications.getToken(sessionId) + .map(token -> Response.ok(format("{ \"token\" : \"%s\"}", token), APPLICATION_JSON_TYPE).build()) + .orElseGet(() -> Response.ok(format("{ \"nextUri\" : \"%s\" }", request.getRequestURI()), APPLICATION_JSON_TYPE).build())); + } + catch (IllegalArgumentException ex) { + return Response.status(NOT_FOUND).build(); + } + } + } + + public static class HttpGetOnlyRedirectHandler + implements RedirectHandler + { + @Override + public void redirectTo(URI uri) + throws RedirectException + { + OkHttpClient client = new OkHttpClient(); + + Request request = new Request.Builder() + .url(HttpUrl.get(uri.toString())) + .build(); + + try (okhttp3.Response response = client.newCall(request).execute()) { + if (response.code() != HTTP_OK) { + throw new RedirectException("HTTP GET failed with status " + response.code()); + } + } + catch (IOException e) { + throw new RedirectException("Redirection failed", e); + } + } + } + + public static class NoOpRedirectHandler + implements RedirectHandler + { + @Override + public void redirectTo(URI uri) + throws RedirectException + {} + } + + public static class FailingRedirectHandler + implements RedirectHandler + { + @Override + public void redirectTo(URI uri) + throws RedirectException + { + throw new RedirectException("Redirect to uri has failed " + uri); + } + } + + static class RedirectHandlerFixture + implements AutoCloseable + { + 
private static final RedirectHandlerFixture INSTANCE = new RedirectHandlerFixture(); + + private RedirectHandlerFixture() {} + + public static RedirectHandlerFixture withHandler(RedirectHandler handler) + { + setRedirectHandler(handler); + return INSTANCE; + } + + @Override + public void close() + { + setRedirectHandler(new DesktopBrowserRedirectHandler()); + } + } + + static class TokenPollingErrorFixture + implements AutoCloseable + { + private static final AtomicReference ERROR = new AtomicReference<>(null); + + public static AutoCloseable withPollingError(String error) + { + if (ERROR.compareAndSet(null, error)) { + return new TokenPollingErrorFixture(); + } + throw new ConcurrentModificationException("polling errors can't be invoked in parallel"); + } + + @Override + public void close() + { + ERROR.set(null); + } + } + + static class WwwAuthenticateHeaderFixture + implements AutoCloseable + { + private static final AtomicReference HEADER = new AtomicReference<>(null); + + public static AutoCloseable withWwwAuthenticate(String header) + { + if (HEADER.compareAndSet(null, header)) { + return new WwwAuthenticateHeaderFixture(); + } + throw new ConcurrentModificationException("with WWW-Authenticate header can't be invoked in parallel"); + } + + @Override + public void close() + { + HEADER.set(null); + } + } +} diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcPreparedStatement.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcPreparedStatement.java index 0cfa4668d81bf..6d46f36475e61 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcPreparedStatement.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcPreparedStatement.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logging; import com.facebook.presto.plugin.blackhole.BlackHolePlugin; +import com.facebook.presto.plugin.memory.MemoryPlugin; import com.facebook.presto.server.testing.TestingPrestoServer; import 
org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -40,10 +41,13 @@ import static com.facebook.presto.jdbc.TestPrestoDriver.closeQuietly; import static com.facebook.presto.jdbc.TestPrestoDriver.waitForNodeRefresh; +import static com.facebook.presto.jdbc.TestingJdbcUtils.list; +import static com.facebook.presto.jdbc.TestingJdbcUtils.readRows; import static com.google.common.base.Strings.repeat; import static com.google.common.primitives.Ints.asList; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -61,7 +65,9 @@ public void setup() Logging.initialize(); server = new TestingPrestoServer(); server.installPlugin(new BlackHolePlugin()); + server.installPlugin(new MemoryPlugin()); server.createCatalog("blackhole", "blackhole"); + server.createCatalog("memory", "memory"); waitForNodeRefresh(server); try (Connection connection = createConnection(); @@ -636,6 +642,88 @@ public void testInvalidConversions() assertInvalidConversion((ps, i) -> ps.setObject(i, "abc", Types.SMALLINT), "Cannot convert instance of java.lang.String to SQL type " + Types.SMALLINT); } + @Test + public void testExecuteBatch() + throws Exception + { + try (Connection connection = createConnection("memory", "default")) { + try (Statement statement = connection.createStatement()) { + statement.execute("CREATE TABLE test_execute_batch(c_int integer)"); + } + + try (PreparedStatement preparedStatement = connection.prepareStatement( + "INSERT INTO test_execute_batch VALUES (?)")) { + // Run executeBatch before addBatch + assertEquals(preparedStatement.executeBatch(), new int[] {}); + + for (int i = 0; i < 3; i++) { + preparedStatement.setInt(1, i); + preparedStatement.addBatch(); + } + 
assertEquals(preparedStatement.executeBatch(), new int[] {1, 1, 1}); + + try (Statement statement = connection.createStatement()) { + ResultSet resultSet = statement.executeQuery("SELECT c_int FROM test_execute_batch"); + assertThat(readRows(resultSet)) + .containsExactlyInAnyOrder( + list(0), + list(1), + list(2)); + } + + // Make sure the above executeBatch cleared existing batch + assertEquals(preparedStatement.executeBatch(), new int[] {}); + + // clearBatch removes added batch and cancel batch mode + preparedStatement.setBoolean(1, true); + preparedStatement.clearBatch(); + assertEquals(preparedStatement.executeBatch(), new int[] {}); + + preparedStatement.setInt(1, 1); + assertEquals(preparedStatement.executeUpdate(), 1); + } + + try (Statement statement = connection.createStatement()) { + statement.execute("DROP TABLE test_execute_batch"); + } + } + } + + @Test + public void testInvalidExecuteBatch() + throws Exception + { + try (Connection connection = createConnection("blackhole", "blackhole")) { + try (Statement statement = connection.createStatement()) { + statement.execute("CREATE TABLE test_invalid_execute_batch(c_int integer)"); + } + + try (PreparedStatement statement = connection.prepareStatement( + "INSERT INTO test_invalid_execute_batch VALUES (?)")) { + statement.setInt(1, 1); + statement.addBatch(); + + String message = "Batch prepared statement must be executed using executeBatch method"; + assertThatThrownBy(statement::executeQuery) + .isInstanceOf(SQLException.class) + .hasMessage(message); + assertThatThrownBy(statement::executeUpdate) + .isInstanceOf(SQLException.class) + .hasMessage(message); + assertThatThrownBy(statement::executeLargeUpdate) + .isInstanceOf(SQLException.class) + .hasMessage(message); + assertThatThrownBy(statement::execute) + .isInstanceOf(SQLException.class) + .hasMessage(message); + } + + try (Statement statement = connection.createStatement()) { + statement.execute("DROP TABLE test_invalid_execute_batch"); + } + } + } 
+ private void assertInvalidConversion(Binder binder, String message) { assertThatThrownBy(() -> assertParameter(null, Types.NULL, binder)) diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcResultSet.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcResultSet.java index 3990213ac3963..ef2effac22da8 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcResultSet.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcResultSet.java @@ -214,11 +214,11 @@ public void testObjectTypes() }); checkRepresentation("TIMESTAMP '1970-01-01 00:14:15.227 Europe/Warsaw'", Types.TIMESTAMP /* TODO TIMESTAMP_WITH_TIMEZONE */, (rs, column) -> { - assertEquals(rs.getObject(column), Timestamp.valueOf(LocalDateTime.of(1969, 12, 31, 15, 14, 15, 227_000_000))); // TODO this should represent TIMESTAMP '1970-01-01 00:14:15.227 Europe/Warsaw' + assertEquals(rs.getObject(column), Timestamp.valueOf(LocalDateTime.of(1969, 12, 31, 16, 14, 15, 227_000_000))); // TODO this should represent TIMESTAMP '1970-01-01 00:14:15.227 Europe/Warsaw' assertThrows(() -> rs.getDate(column)); assertThrows(() -> rs.getTime(column)); // TODO this should fail, as there no java.sql.Timestamp representation for TIMESTAMP '1970-01-01 00:14:15.227ó' in America/Bahia_Banderas - assertEquals(rs.getTimestamp(column), Timestamp.valueOf(LocalDateTime.of(1969, 12, 31, 15, 14, 15, 227_000_000))); + assertEquals(rs.getTimestamp(column), Timestamp.valueOf(LocalDateTime.of(1969, 12, 31, 16, 14, 15, 227_000_000))); }); statement.execute("CREATE TYPE cat.sch.dist AS integer"); diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java index 1d91cc3bf05d7..667023ff947e9 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestPrestoDriver.java @@ -14,6 +14,7 @@ package 
com.facebook.presto.jdbc; import com.facebook.airlift.log.Logging; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.common.type.BooleanType; @@ -42,7 +43,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.testng.annotations.AfterClass; @@ -81,6 +81,7 @@ import static com.facebook.airlift.testing.Assertions.assertContains; import static com.facebook.airlift.testing.Assertions.assertInstanceOf; import static com.facebook.airlift.testing.Assertions.assertLessThan; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.common.type.CharType.createCharType; import static com.facebook.presto.common.type.DecimalType.createDecimalType; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; @@ -90,7 +91,6 @@ import static com.facebook.presto.testing.TestingSession.TESTING_CATALOG; import static com.facebook.presto.testing.TestingSession.createBogusTestingCatalog; import static com.facebook.presto.tests.AbstractTestQueries.TEST_CATALOG_PROPERTIES; -import static io.airlift.units.Duration.nanosSince; import static java.lang.Float.POSITIVE_INFINITY; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java index 71e6c5d042941..c3b37334781a7 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestQueryExecutor.java @@ -14,8 +14,8 @@ package com.facebook.presto.jdbc; import 
com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.ServerInfo; -import io.airlift.units.Duration; import okhttp3.OkHttpClient; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestingJdbcUtils.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestingJdbcUtils.java new file mode 100644 index 0000000000000..26ad29b209e50 --- /dev/null +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestingJdbcUtils.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.jdbc; + +import com.google.common.collect.ImmutableList; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import static java.util.Arrays.asList; + +public class TestingJdbcUtils +{ + private TestingJdbcUtils() {} + + public static List> readRows(ResultSet rs) + throws SQLException + { + ImmutableList.Builder> rows = ImmutableList.builder(); + int columnCount = rs.getMetaData().getColumnCount(); + while (rs.next()) { + List row = new ArrayList<>(); + for (int i = 1; i <= columnCount; i++) { + row.add(rs.getObject(i)); + } + rows.add(row); + } + return rows.build(); + } + + @SafeVarargs + public static List list(T... 
elements) + { + return asList(elements); + } +} diff --git a/presto-jmx/pom.xml b/presto-jmx/pom.xml index 291bceea12690..f20dad633d3d3 100644 --- a/presto-jmx/pom.xml +++ b/presto-jmx/pom.xml @@ -4,15 +4,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-jmx + presto-jmx Presto - JMX Connector presto-plugin ${project.parent.basedir} + true @@ -47,13 +49,13 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true @@ -63,8 +65,13 @@ - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api + + + + jakarta.inject + jakarta.inject-api @@ -86,7 +93,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -98,7 +105,7 @@ - io.airlift + com.facebook.airlift units provided @@ -152,4 +159,19 @@ test + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + com.fasterxml.jackson.core:jackson-databind + javax.inject:javax.inject + + + + + diff --git a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxColumnHandle.java b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxColumnHandle.java index ff3883e022f84..d30bac91441b4 100644 --- a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxColumnHandle.java +++ b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxColumnHandle.java @@ -87,4 +87,12 @@ public ColumnMetadata getColumnMetadata() .setType(columnType) .build(); } + + public ColumnMetadata getColumnMetadata(String name) + { + return ColumnMetadata.builder() + .setName(name) + .setType(columnType) + .build(); + } } diff --git a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxConnector.java b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxConnector.java index 23b2594970ba3..01cba8e6b43a9 100644 --- a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxConnector.java +++ 
b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxConnector.java @@ -17,8 +17,7 @@ import com.facebook.presto.spi.connector.Connector; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.spi.transaction.IsolationLevel.READ_COMMITTED; import static com.facebook.presto.spi.transaction.IsolationLevel.checkConnectorSupports; diff --git a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxConnectorConfig.java b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxConnectorConfig.java index d220c67c71484..7a109004d2b13 100644 --- a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxConnectorConfig.java +++ b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxConnectorConfig.java @@ -14,14 +14,13 @@ package com.facebook.presto.connector.jmx; import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.util.Set; import java.util.regex.Pattern; diff --git a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxHistoricalData.java b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxHistoricalData.java index 6ba56912c2fe8..cd4133214f95e 100644 --- a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxHistoricalData.java +++ b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxHistoricalData.java @@ -15,8 +15,7 @@ import 
com.google.common.collect.EvictingQueue; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.Collection; diff --git a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxMetadata.java b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxMetadata.java index d67b76c046a64..a9de60e9f0699 100644 --- a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxMetadata.java +++ b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxMetadata.java @@ -33,8 +33,8 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; import com.google.common.collect.Streams; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.management.JMException; import javax.management.MBeanAttributeInfo; import javax.management.MBeanInfo; @@ -175,7 +175,22 @@ private Stream getColumnHandles(MBeanInfo mbeanInfo) @Override public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { - return ((JmxTableHandle) tableHandle).getTableMetadata(); + JmxTableHandle jmxTableHandle = (JmxTableHandle) tableHandle; + return getTableMetadata(session, getTableName(jmxTableHandle), jmxTableHandle.getColumnHandles()); + } + + public ConnectorTableMetadata getTableMetadata(ConnectorSession session, SchemaTableName schemaTableName, List columnHandles) + { + List columns = columnHandles.stream() + .map(column -> column.getColumnMetadata(normalizeIdentifier(session, column.getColumnName()))) + .collect(toImmutableList()); + + return new ConnectorTableMetadata(schemaTableName, columns); + } + + private static SchemaTableName getTableName(ConnectorTableHandle tableHandle) + { + return ((JmxTableHandle) tableHandle).getTableName(); } @Override @@ -216,7 +231,7 @@ private List listJmxTables() public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) { 
JmxTableHandle jmxTableHandle = (JmxTableHandle) tableHandle; - return ImmutableMap.copyOf(Maps.uniqueIndex(jmxTableHandle.getColumnHandles(), column -> column.getColumnName().toLowerCase(ENGLISH))); + return ImmutableMap.copyOf(Maps.uniqueIndex(jmxTableHandle.getColumnHandles(), column -> normalizeIdentifier(session, column.getColumnName()))); } @Override @@ -248,11 +263,15 @@ public Map> listTableColumns(ConnectorSess } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { JmxTableHandle handle = (JmxTableHandle) table; ConnectorTableLayout layout = new ConnectorTableLayout(new JmxTableLayoutHandle(handle, constraint.getSummary())); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override diff --git a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxPeriodicSampler.java b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxPeriodicSampler.java index a92b6e2604da6..7fe5cf29c8483 100644 --- a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxPeriodicSampler.java +++ b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxPeriodicSampler.java @@ -16,9 +16,8 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.spi.SchemaTableName; import com.google.common.collect.ImmutableList; - -import javax.annotation.PostConstruct; -import javax.inject.Inject; +import jakarta.annotation.PostConstruct; +import jakarta.inject.Inject; import java.util.List; import java.util.concurrent.ScheduledExecutorService; diff --git a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxRecordSetProvider.java 
b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxRecordSetProvider.java index 9c820636012a7..551758b57c654 100644 --- a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxRecordSetProvider.java +++ b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxRecordSetProvider.java @@ -25,8 +25,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.management.Attribute; import javax.management.JMException; import javax.management.MBeanServer; diff --git a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxSplitManager.java b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxSplitManager.java index f5c3a62830cdf..51311dcaa81ec 100644 --- a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxSplitManager.java +++ b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/JmxSplitManager.java @@ -26,8 +26,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/util/RebindSafeMBeanServer.java b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/util/RebindSafeMBeanServer.java index 7579c1f0c9dbe..73dc0340a4343 100644 --- a/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/util/RebindSafeMBeanServer.java +++ b/presto-jmx/src/main/java/com/facebook/presto/connector/jmx/util/RebindSafeMBeanServer.java @@ -14,8 +14,8 @@ package com.facebook.presto.connector.jmx.util; import com.facebook.airlift.log.Logger; +import com.google.errorprone.annotations.ThreadSafe; -import javax.annotation.concurrent.ThreadSafe; import javax.management.Attribute; import 
javax.management.AttributeList; import javax.management.AttributeNotFoundException; diff --git a/presto-jmx/src/test/java/com/facebook/presto/connector/jmx/TestJmxConnectorConfig.java b/presto-jmx/src/test/java/com/facebook/presto/connector/jmx/TestJmxConnectorConfig.java index 950c39de138af..c5fca4f2720e2 100644 --- a/presto-jmx/src/test/java/com/facebook/presto/connector/jmx/TestJmxConnectorConfig.java +++ b/presto-jmx/src/test/java/com/facebook/presto/connector/jmx/TestJmxConnectorConfig.java @@ -14,9 +14,9 @@ package com.facebook.presto.connector.jmx; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-jmx/src/test/java/com/facebook/presto/connector/jmx/TestJmxSplitManager.java b/presto-jmx/src/test/java/com/facebook/presto/connector/jmx/TestJmxSplitManager.java index b35b8663a918c..16ba589ef282f 100644 --- a/presto-jmx/src/test/java/com/facebook/presto/connector/jmx/TestJmxSplitManager.java +++ b/presto-jmx/src/test/java/com/facebook/presto/connector/jmx/TestJmxSplitManager.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.connector.jmx; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.NodeVersion; import com.facebook.presto.common.predicate.NullableValue; import com.facebook.presto.common.predicate.TupleDomain; @@ -36,7 +37,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; diff --git a/presto-kafka/pom.xml b/presto-kafka/pom.xml index b0cc635db0fbd..4b6d0087e24c7 100644 --- a/presto-kafka/pom.xml +++ b/presto-kafka/pom.xml @@ -5,18 +5,29 @@ 
com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-kafka + presto-kafka Presto - Kafka Connector presto-plugin ${project.parent.basedir} - 2.12.2 + true + + + + com.fasterxml.jackson.datatype + jackson-datatype-jdk8 + 2.16.2 + + + + com.facebook.airlift @@ -54,8 +65,8 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -81,6 +92,11 @@ + + jakarta.inject + jakarta.inject-api + + javax.inject javax.inject @@ -120,7 +136,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -132,7 +148,7 @@ - io.airlift + com.facebook.airlift units provided @@ -156,11 +172,10 @@ runtime - com.101tec zkclient - 0.10 + 0.11 runtime @@ -184,7 +199,7 @@ org.apache.kafka - kafka_2.12 + kafka_2.13 ${dep.kafka.version} test test @@ -198,7 +213,7 @@ org.apache.kafka - kafka_2.12 + kafka_2.13 ${dep.kafka.version} test @@ -209,6 +224,12 @@ + + org.apache.kafka + kafka-metadata + test + + org.apache.kafka kafka-clients @@ -302,7 +323,6 @@ org.scala-lang scala-library - ${scala.version} test @@ -316,7 +336,7 @@ org.apache.maven.surefire surefire-testng - 3.0.0-M7 + 3.5.4 diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnector.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnector.java index a835ec0559c4a..d9e7b2c7fe0c0 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnector.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnector.java @@ -22,8 +22,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.spi.transaction.IsolationLevel.READ_COMMITTED; import static com.facebook.presto.spi.transaction.IsolationLevel.checkConnectorSupports; diff --git 
a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnectorConfig.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnectorConfig.java index 877535076369e..1ab44680a22d8 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnectorConfig.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnectorConfig.java @@ -15,14 +15,13 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; import com.facebook.presto.kafka.schema.file.FileTableDescriptionSupplier; import com.facebook.presto.kafka.server.file.FileKafkaClusterMetadataSupplier; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.io.File; import java.util.List; @@ -71,6 +70,8 @@ public class KafkaConnectorConfig */ private List resourceConfigFiles = ImmutableList.of(); + private boolean caseSensitiveNameMatching; + @NotNull public String getDefaultSchema() { @@ -174,4 +175,18 @@ public KafkaConnectorConfig setResourceConfigFiles(String files) .collect(toImmutableList()); return this; } + + public boolean isCaseSensitiveNameMatching() + { + return caseSensitiveNameMatching; + } + + @Config("case-sensitive-name-matching") + @ConfigDescription("Enable case-sensitive matching of schema, table names across the connector. 
" + + "When disabled, names are matched case-insensitively using lowercase normalization.") + public KafkaConnectorConfig setCaseSensitiveNameMatching(boolean caseSensitiveNameMatchingEnabled) + { + this.caseSensitiveNameMatching = caseSensitiveNameMatchingEnabled; + return this; + } } diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnectorModule.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnectorModule.java index c3cdce91edeeb..e4166096d49ea 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnectorModule.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaConnectorModule.java @@ -27,8 +27,7 @@ import com.google.inject.Binder; import com.google.inject.Module; import com.google.inject.Scopes; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.function.Function; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaInternalFieldDescription.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaInternalFieldDescription.java index 9de72206eb06c..8d4a42510d63f 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaInternalFieldDescription.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaInternalFieldDescription.java @@ -131,10 +131,10 @@ KafkaColumnHandle getColumnHandle(String connectorId, int index, boolean hidden) true); } - ColumnMetadata getColumnMetadata(boolean hidden) + ColumnMetadata getColumnMetadata(boolean hidden, String name) { return ColumnMetadata.builder() - .setName(columnName) + .setName(name) .setType(type) .setComment(comment) .setHidden(hidden) diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaMetadata.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaMetadata.java index 468b3a29068b9..806f156f7f11d 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaMetadata.java +++ 
b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaMetadata.java @@ -37,8 +37,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; @@ -51,6 +50,7 @@ import static com.facebook.presto.kafka.KafkaHandleResolver.convertTableHandle; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.String.format; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; /** @@ -64,6 +64,7 @@ public class KafkaMetadata private final String connectorId; private final boolean hideInternalColumns; private final TableDescriptionSupplier tableDescriptionSupplier; + private final boolean caseSensitiveNameMatching; @Inject public KafkaMetadata( @@ -76,6 +77,7 @@ public KafkaMetadata( requireNonNull(kafkaConnectorConfig, "kafkaConfig is null"); this.hideInternalColumns = kafkaConnectorConfig.isHideInternalColumns(); this.tableDescriptionSupplier = requireNonNull(tableDescriptionSupplier, "tableDescriptionSupplier is null"); + this.caseSensitiveNameMatching = kafkaConnectorConfig.isCaseSensitiveNameMatching(); } @Override @@ -113,7 +115,7 @@ private static String getDataFormat(Optional fieldGroup) @Override public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { - return getTableMetadata(convertTableHandle(tableHandle).toSchemaTableName()); + return getTableMetadata(session, convertTableHandle(tableHandle).toSchemaTableName()); } @Override @@ -183,7 +185,7 @@ public Map> listTableColumns(ConnectorSess for (SchemaTableName tableName : tableNames) { try { - columns.put(tableName, getTableMetadata(tableName).getColumns()); + columns.put(tableName, getTableMetadata(session, tableName).getColumns()); } catch (TableNotFoundException e) { // Normally it would mean the table 
disappeared during listing operation @@ -201,7 +203,11 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { KafkaTableHandle handle = convertTableHandle(table); long startTimestamp = 0; @@ -225,7 +231,7 @@ public List getTableLayouts(ConnectorSession session } ConnectorTableLayout layout = new ConnectorTableLayout(new KafkaTableLayoutHandle(handle, startTimestamp, endTimestamp)); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -234,7 +240,7 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa return new ConnectorTableLayout(handle); } - private ConnectorTableMetadata getTableMetadata(SchemaTableName schemaTableName) + private ConnectorTableMetadata getTableMetadata(ConnectorSession session, SchemaTableName schemaTableName) { KafkaTopicDescription table = getRequiredTopicDescription(schemaTableName); @@ -244,7 +250,7 @@ private ConnectorTableMetadata getTableMetadata(SchemaTableName schemaTableName) List fields = key.getFields(); if (fields != null) { for (KafkaTopicFieldDescription fieldDescription : fields) { - builder.add(fieldDescription.getColumnMetadata()); + builder.add(fieldDescription.getColumnMetadata(normalizeIdentifier(session, fieldDescription.getName()))); } } }); @@ -253,13 +259,13 @@ private ConnectorTableMetadata getTableMetadata(SchemaTableName schemaTableName) List fields = message.getFields(); if (fields != null) { for (KafkaTopicFieldDescription fieldDescription : fields) { - builder.add(fieldDescription.getColumnMetadata()); + 
builder.add(fieldDescription.getColumnMetadata(normalizeIdentifier(session, fieldDescription.getName()))); } } }); for (KafkaInternalFieldDescription fieldDescription : KafkaInternalFieldDescription.values()) { - builder.add(fieldDescription.getColumnMetadata(hideInternalColumns)); + builder.add(fieldDescription.getColumnMetadata(hideInternalColumns, normalizeIdentifier(session, fieldDescription.getColumnName()))); } return new ConnectorTableMetadata(schemaTableName, builder.build()); @@ -302,4 +308,9 @@ private Optional getTopicDescription(SchemaTableName sche { return tableDescriptionSupplier.getTopicDescription(schemaTableName); } + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return caseSensitiveNameMatching ? identifier : identifier.toLowerCase(ROOT); + } } diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaPageSinkProvider.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaPageSinkProvider.java index 738e7a4c521b8..a3fdc41935ae1 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaPageSinkProvider.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaPageSinkProvider.java @@ -26,8 +26,7 @@ import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.IOException; import java.nio.file.Files; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaRecordSetProvider.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaRecordSetProvider.java index 0f0e57ff52ff5..7b4cd735a68eb 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaRecordSetProvider.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaRecordSetProvider.java @@ -23,8 +23,7 @@ import 
com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaSecurityConfig.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaSecurityConfig.java index f2f57e8502482..c34e9ded30669 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaSecurityConfig.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaSecurityConfig.java @@ -15,10 +15,9 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import jakarta.validation.constraints.AssertTrue; import org.apache.kafka.common.security.auth.SecurityProtocol; -import javax.validation.constraints.AssertTrue; - import java.util.Optional; import static org.apache.kafka.common.security.auth.SecurityProtocol.SASL_PLAINTEXT; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaSplitManager.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaSplitManager.java index aaa5493d8179c..22265a7b277d4 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaSplitManager.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaSplitManager.java @@ -27,14 +27,13 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.CharStreams; +import jakarta.inject.Inject; import org.apache.kafka.clients.consumer.KafkaConsumer; import org.apache.kafka.clients.consumer.OffsetAndTimestamp; import org.apache.kafka.common.Node; import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; -import javax.inject.Inject; - import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; diff --git 
a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaTopicFieldDescription.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaTopicFieldDescription.java index 73751a2091155..8ceb909bb5906 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaTopicFieldDescription.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/KafkaTopicFieldDescription.java @@ -114,10 +114,10 @@ KafkaColumnHandle getColumnHandle(String connectorId, boolean keyCodec, int inde false); } - ColumnMetadata getColumnMetadata() + ColumnMetadata getColumnMetadata(String name) { return ColumnMetadata.builder() - .setName(getName()) + .setName(name) .setType(getType()) .setComment(getComment()) .setHidden(isHidden()) diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/PlainTextKafkaConsumerManager.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/PlainTextKafkaConsumerManager.java index fcd12173fa569..8b4857b7a66dc 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/PlainTextKafkaConsumerManager.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/PlainTextKafkaConsumerManager.java @@ -14,10 +14,9 @@ package com.facebook.presto.kafka; import com.facebook.presto.spi.HostAddress; +import jakarta.inject.Inject; import org.apache.kafka.common.serialization.ByteBufferDeserializer; -import javax.inject.Inject; - import java.util.Properties; import static java.util.Objects.requireNonNull; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/PlainTextKafkaProducerFactory.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/PlainTextKafkaProducerFactory.java index 4e778c96febc6..7868bf9017dfd 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/PlainTextKafkaProducerFactory.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/PlainTextKafkaProducerFactory.java @@ -14,8 +14,7 @@ package com.facebook.presto.kafka; import com.facebook.presto.spi.HostAddress; - -import 
javax.inject.Inject; +import jakarta.inject.Inject; import java.io.IOException; import java.util.HashMap; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/SaslKafkaConsumerManager.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/SaslKafkaConsumerManager.java index 9d8fc8884b5ce..a5fa29e0c882a 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/SaslKafkaConsumerManager.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/SaslKafkaConsumerManager.java @@ -16,8 +16,7 @@ import com.facebook.presto.kafka.security.ForKafkaSasl; import com.facebook.presto.kafka.security.KafkaSaslConfig; import com.facebook.presto.spi.HostAddress; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Properties; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/encoder/DispatchingRowEncoderFactory.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/encoder/DispatchingRowEncoderFactory.java index 87a638bba0cf2..3ae76f7cd2dcd 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/encoder/DispatchingRowEncoderFactory.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/encoder/DispatchingRowEncoderFactory.java @@ -15,8 +15,7 @@ import com.facebook.presto.spi.ConnectorSession; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/encoder/json/JsonRowEncoderFactory.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/encoder/json/JsonRowEncoderFactory.java index 3c3f6833ec859..5403562816121 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/encoder/json/JsonRowEncoderFactory.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/encoder/json/JsonRowEncoderFactory.java @@ -18,8 +18,7 @@ import com.facebook.presto.kafka.encoder.RowEncoderFactory; import 
com.facebook.presto.spi.ConnectorSession; import com.fasterxml.jackson.databind.ObjectMapper; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/schema/file/FileTableDescriptionSupplier.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/schema/file/FileTableDescriptionSupplier.java index 6f1779555d5fc..07216e82fd895 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/schema/file/FileTableDescriptionSupplier.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/schema/file/FileTableDescriptionSupplier.java @@ -25,8 +25,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; -import javax.inject.Inject; import javax.inject.Provider; import java.io.File; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/schema/file/FileTableDescriptionSupplierConfig.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/schema/file/FileTableDescriptionSupplierConfig.java index 123d1c5bd5a37..8041febb7bc4e 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/schema/file/FileTableDescriptionSupplierConfig.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/schema/file/FileTableDescriptionSupplierConfig.java @@ -16,8 +16,7 @@ import com.facebook.airlift.configuration.Config; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableSet; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.io.File; import java.util.Set; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/security/KafkaSaslConfig.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/security/KafkaSaslConfig.java index 65519d8523437..f0a1943b20ea1 100644 --- 
a/presto-kafka/src/main/java/com/facebook/presto/kafka/security/KafkaSaslConfig.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/security/KafkaSaslConfig.java @@ -17,8 +17,7 @@ import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.ConfigSecuritySensitive; import com.google.common.collect.ImmutableMap; - -import javax.validation.constraints.AssertTrue; +import jakarta.validation.constraints.AssertTrue; import java.util.Map; import java.util.Optional; diff --git a/presto-kafka/src/main/java/com/facebook/presto/kafka/server/file/FileKafkaClusterMetadataSupplier.java b/presto-kafka/src/main/java/com/facebook/presto/kafka/server/file/FileKafkaClusterMetadataSupplier.java index a6e979d14428e..afe1563b51a2b 100644 --- a/presto-kafka/src/main/java/com/facebook/presto/kafka/server/file/FileKafkaClusterMetadataSupplier.java +++ b/presto-kafka/src/main/java/com/facebook/presto/kafka/server/file/FileKafkaClusterMetadataSupplier.java @@ -15,8 +15,7 @@ import com.facebook.presto.kafka.server.KafkaClusterMetadataSupplier; import com.facebook.presto.spi.HostAddress; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-kafka/src/test/java/com/facebook/presto/kafka/KafkaQueryRunner.java b/presto-kafka/src/test/java/com/facebook/presto/kafka/KafkaQueryRunner.java index 5e067e6dfdb83..22898e4ced000 100644 --- a/presto-kafka/src/test/java/com/facebook/presto/kafka/KafkaQueryRunner.java +++ b/presto-kafka/src/test/java/com/facebook/presto/kafka/KafkaQueryRunner.java @@ -36,12 +36,12 @@ import java.util.Optional; import static com.facebook.airlift.testing.Closeables.closeAllSuppress; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.kafka.util.TestUtils.installKafkaPlugin; import static com.facebook.presto.kafka.util.TestUtils.loadTpchTopicDescription; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; 
import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; import static com.google.common.io.ByteStreams.toByteArray; -import static io.airlift.units.Duration.nanosSince; import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static java.util.concurrent.TimeUnit.SECONDS; @@ -58,10 +58,11 @@ private KafkaQueryRunner() public static DistributedQueryRunner createKafkaQueryRunner(EmbeddedKafka embeddedKafka, TpchTable... tables) throws Exception { - return createKafkaQueryRunner(embeddedKafka, ImmutableList.copyOf(tables)); + List extraTables = ImmutableList.of(); + return createKafkaQueryRunner(embeddedKafka, ImmutableList.copyOf(tables), ImmutableMap.of(), extraTables); } - public static DistributedQueryRunner createKafkaQueryRunner(EmbeddedKafka embeddedKafka, Iterable> tables) + public static DistributedQueryRunner createKafkaQueryRunner(EmbeddedKafka embeddedKafka, Iterable> tables, Map connectorProperties, List extraTables) throws Exception { DistributedQueryRunner queryRunner = null; @@ -69,7 +70,7 @@ public static DistributedQueryRunner createKafkaQueryRunner(EmbeddedKafka embedd queryRunner = new DistributedQueryRunner(createSession(), 2); queryRunner.installPlugin(new TpchPlugin()); - queryRunner.createCatalog("tpch", "tpch"); + queryRunner.createCatalog("tpch", "tpch", connectorProperties); embeddedKafka.start(); @@ -77,9 +78,9 @@ public static DistributedQueryRunner createKafkaQueryRunner(EmbeddedKafka embedd embeddedKafka.createTopics(kafkaTopicName(table)); } - Map topicDescriptions = createTpchTopicDescriptions(queryRunner.getCoordinator().getMetadata(), tables, embeddedKafka); + Map topicDescriptions = createTpchTopicDescriptions(queryRunner.getCoordinator().getMetadata(), tables, embeddedKafka, extraTables); - installKafkaPlugin(embeddedKafka, queryRunner, topicDescriptions); + installKafkaPlugin(embeddedKafka, queryRunner, topicDescriptions, connectorProperties); TestingPrestoClient prestoClient = 
queryRunner.getRandomClient(); @@ -111,7 +112,7 @@ private static String kafkaTopicName(TpchTable table) return TPCH_SCHEMA + "." + table.getTableName().toLowerCase(ENGLISH); } - private static Map createTpchTopicDescriptions(Metadata metadata, Iterable> tables, EmbeddedKafka embeddedKafka) + private static Map createTpchTopicDescriptions(Metadata metadata, Iterable> tables, EmbeddedKafka embeddedKafka, List extraTables) throws Exception { JsonCodec topicDescriptionJsonCodec = new CodecSupplier<>(KafkaTopicDescription.class, metadata).get(); @@ -121,7 +122,15 @@ private static Map createTpchTopicDescri String tableName = table.getTableName(); SchemaTableName tpchTable = new SchemaTableName(TPCH_SCHEMA, tableName); - topicDescriptions.put(loadTpchTopicDescription(topicDescriptionJsonCodec, tpchTable.toString(), tpchTable)); + topicDescriptions.put(loadTpchTopicDescription(topicDescriptionJsonCodec, tpchTable.toString(), tpchTable, table.getTableName())); + } + + for (SchemaTableName extra : extraTables) { + topicDescriptions.put(loadTpchTopicDescription( + topicDescriptionJsonCodec, + extra.getTableName(), + extra, + extra.getTableName().toLowerCase() + "_upper")); } List tableNames = new ArrayList<>(4); @@ -173,7 +182,7 @@ public static void main(String[] args) throws Exception { Logging.initialize(); - DistributedQueryRunner queryRunner = createKafkaQueryRunner(EmbeddedKafka.createEmbeddedKafka(), TpchTable.getTables()); + DistributedQueryRunner queryRunner = createKafkaQueryRunner(EmbeddedKafka.createEmbeddedKafka(), TpchTable.getTables(), ImmutableMap.of(), ImmutableList.of()); Thread.sleep(10); Logger log = Logger.get(KafkaQueryRunner.class); log.info("======== SERVER STARTED ========"); diff --git a/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaConnectorConfig.java b/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaConnectorConfig.java index a83ba9d67564e..c72247eaabf49 100644 --- 
a/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaConnectorConfig.java +++ b/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaConnectorConfig.java @@ -33,7 +33,8 @@ public void testDefaults() .setHideInternalColumns(true) .setMaxPartitionFetchBytes(1048576) .setMaxPollRecords(500) - .setResourceConfigFiles("")); + .setResourceConfigFiles("") + .setCaseSensitiveNameMatching(false)); } @Test @@ -51,6 +52,7 @@ public void testExplicitPropertyMappings() .put("kafka.max-partition-fetch-bytes", "1024") .put("kafka.max-poll-records", "1000") .put("kafka.config.resources", tempFile1 + "," + tempFile2) + .put("case-sensitive-name-matching", "true") .build(); KafkaConnectorConfig expected = new KafkaConnectorConfig() @@ -61,7 +63,8 @@ public void testExplicitPropertyMappings() .setHideInternalColumns(false) .setMaxPartitionFetchBytes(1024) .setMaxPollRecords(1000) - .setResourceConfigFiles(tempFile1 + "," + tempFile2); + .setResourceConfigFiles(tempFile1 + "," + tempFile2) + .setCaseSensitiveNameMatching(true); ConfigAssertions.assertFullMapping(properties, expected); } diff --git a/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaDistributed.java b/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaDistributed.java index ef3679980ad8a..1329c3934006c 100644 --- a/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaDistributed.java +++ b/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaDistributed.java @@ -16,6 +16,8 @@ import com.facebook.presto.kafka.util.EmbeddedKafka; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueries; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.airlift.tpch.TpchTable; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; @@ -36,7 +38,7 @@ protected QueryRunner createQueryRunner() throws Exception { this.embeddedKafka = createEmbeddedKafka(); - 
return createKafkaQueryRunner(embeddedKafka, TpchTable.getTables()); + return createKafkaQueryRunner(embeddedKafka, TpchTable.getTables(), ImmutableMap.of(), ImmutableList.of()); } @AfterClass(alwaysRun = true) diff --git a/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaIntegrationMixedCase.java b/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaIntegrationMixedCase.java new file mode 100644 index 0000000000000..3a2274e28e4df --- /dev/null +++ b/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaIntegrationMixedCase.java @@ -0,0 +1,233 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.kafka; + +import com.facebook.presto.Session; +import com.facebook.presto.kafka.util.EmbeddedKafka; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.tpch.TpchTable; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.util.List; + +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.kafka.KafkaQueryRunner.createKafkaQueryRunner; +import static com.facebook.presto.kafka.util.EmbeddedKafka.createEmbeddedKafka; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.tests.QueryAssertions.assertContains; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +@Test +public class TestKafkaIntegrationMixedCase + extends AbstractTestQueryFramework +{ + private EmbeddedKafka embeddedKafka; + private Session session; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + List extraTables = ImmutableList.of( + new SchemaTableName("TPCH", "ORDERS")); + embeddedKafka = createEmbeddedKafka(); + + return createKafkaQueryRunner(embeddedKafka, ImmutableList.of(TpchTable.ORDERS), + ImmutableMap.of("case-sensitive-name-matching", "true"), extraTables); + } + + @BeforeClass(alwaysRun = true) + public final void setUp() + { + session = testSessionBuilder() + .setCatalog("kafka") + .setSchema("tpch") + .build(); + } + + @AfterClass(alwaysRun = true) + public void destroy() + throws IOException + { + if (embeddedKafka != null) { + embeddedKafka.close(); + } + } + + @Test + 
public void testTableExists() + { + assertTrue(getQueryRunner().tableExists(session, "orders")); + assertFalse(getQueryRunner().tableExists(session, "Orders")); + assertFalse(getQueryRunner().tableExists(session, "oRdErS")); + + assertFalse(getQueryRunner().tableExists(session, "nonexistent")); + assertFalse(getQueryRunner().tableExists(session, "NONEXISTENT")); + } + + @Test + public void testSelect() + { + // Should work with exact case + assertQuerySucceeds(session, "SELECT count(*) FROM orders"); + assertQuerySucceeds(session, "SELECT count(*) FROM tpch.orders"); + assertQuerySucceeds(session, "SELECT count(*) FROM TPCH.ORDERS"); + assertQuerySucceeds(session, "SELECT CASE WHEN (SELECT count(*) FROM tpch.orders) = (SELECT count(*) FROM TPCH.\"ORDERS\") THEN 1 ELSE 0 END"); + assertQuerySucceeds(session, "SELECT CASE WHEN (SELECT min(orderkey) FROM tpch.orders) = (SELECT min(ORDERKEY) FROM TPCH.\"ORDERS\") THEN 1 ELSE 0 END"); + assertQuerySucceeds(session, "SELECT CASE WHEN (SELECT max(totalprice) FROM tpch.orders) = (SELECT max(TOTALPRICE) FROM TPCH.\"ORDERS\") THEN 1 ELSE 0 END"); + + // Should fail with wrong case when case-sensitive is enabled + assertQueryFails(session, "SELECT count(*) FROM Orders", "Table kafka.tpch.Orders does not exist"); + assertQueryFails(session, "SELECT count(*) FROM oRdErS", "Table kafka.tpch.oRdErS does not exist"); + assertQueryFails(session, "SELECT count(*) FROM TPCH.orders", "Table kafka.TPCH.orders does not exist"); + } + + @Test + public void testDescribeTable() + { + assertQuerySucceeds(session, "DESCRIBE orders"); + assertQuerySucceeds(session, "DESCRIBE tpch.orders"); + assertQuerySucceeds(session, "DESCRIBE TPCH.ORDERS"); + + assertQueryFails(session, "DESCRIBE Orders", "line 1:1: Table 'kafka.tpch.Orders' does not exist"); + assertQueryFails(session, "DESCRIBE oRdErS", "line 1:1: Table 'kafka.tpch.oRdErS' does not exist"); + assertQueryFails(session, "DESCRIBE TPCH.orders", "line 1:1: Table 'kafka.TPCH.orders' does not 
exist"); + + // Validate full column metadata for lowercase + assertQuery( + session, + "SELECT column_name, data_type FROM information_schema.columns " + + "WHERE table_catalog = 'kafka' AND table_schema = 'tpch' AND table_name = 'orders' " + + "ORDER BY column_name", + "SELECT * FROM (VALUES " + + "('clerk','varchar(15)')," + + "('comment','varchar(79)')," + + "('custkey','bigint')," + + "('orderdate','date')," + + "('orderkey','bigint')," + + "('orderpriority','varchar(15)')," + + "('orderstatus','varchar(1)')," + + "('shippriority','integer')," + + "('totalprice','double')) AS t(column_name, data_type)"); + // Validate full column metadata for uppercase + assertQuery( + session, + "SELECT column_name, data_type FROM information_schema.columns " + + "WHERE table_catalog = 'kafka' AND table_schema = 'TPCH' AND table_name = 'ORDERS' " + + "ORDER BY column_name", + "SELECT * FROM (VALUES " + + "('CLERK','varchar(15)')," + + "('COMMENT','varchar(79)')," + + "('CUSTKEY','bigint')," + + "('ORDERDATE','date')," + + "('ORDERKEY','bigint')," + + "('ORDERPRIORITY','varchar(15)')," + + "('ORDERSTATUS','varchar(1)')," + + "('OrderStatus','varchar(1)')," + + "('SHIPPRIORITY','integer')," + + "('TOTALPRICE','double')) AS t(column_name, data_type)"); + } + + @Test + public void testShowTables() + { + // Both lowercase and uppercase tables are registered + assertQuery(session, "SHOW TABLES FROM tpch", "VALUES ('orders')"); + assertQuery(session, "SHOW TABLES FROM TPCH", "VALUES ('ORDERS')"); + } + + @Test + public void testInformationSchema() + { + assertQuery(session, "SELECT table_name FROM information_schema.tables WHERE table_name = 'orders'", "VALUES ('orders')"); + + assertQuerySucceeds(session, "SELECT table_name FROM information_schema.tables WHERE table_schema = 'tpch'"); + assertQuerySucceeds(session, "SELECT table_name FROM information_schema.tables WHERE table_schema = 'TPCH'"); + + assertQuery(session, "SELECT table_name FROM information_schema.tables WHERE 
table_name = 'Orders'", "SELECT 'empty' WHERE false"); + assertQuery(session, "SELECT table_name FROM information_schema.tables WHERE table_schema = 'Tpch'", "SELECT 'empty' WHERE false"); + } + + @Test + public void testMixedCaseQueries() + { + assertQuerySucceeds(session, "SELECT count(*) FROM orders WHERE orderkey > 100"); + assertQuerySucceeds(session, "SELECT o.orderkey FROM orders o LIMIT 1"); + + assertQuerySucceeds(session, "SELECT count(*) FROM TPCH.ORDERS WHERE ORDERKEY > 100"); + assertQuerySucceeds(session, "SELECT o.ORDERKEY FROM TPCH.ORDERS o LIMIT 1"); + + assertQueryFails(session, "SELECT COUNT(*) FROM Orders WHERE OrderKey > 100", "Table kafka.tpch.Orders does not exist"); + assertQueryFails(session, "SELECT * FROM TPCH.Orders", "Table kafka.TPCH.Orders does not exist"); + } + + @Test + public void testJoinsWithCaseSensitivity() + { + assertQuerySucceeds(session, "SELECT count(*) FROM orders o1 JOIN orders o2 ON o1.orderkey = o2.orderkey LIMIT 10"); + + assertQueryFails(session, "SELECT count(*) FROM Orders o1 JOIN orders o2 ON o1.orderkey = o2.orderkey", "Table kafka.tpch.Orders does not exist"); + assertQueryFails(session, "SELECT count(*) FROM orders o1 JOIN Orders o2 ON o1.orderkey = o2.orderkey", "Table kafka.tpch.Orders does not exist"); + } + + @Test + public void testSchemaCasing() + { + assertQuerySucceeds(session, "SHOW TABLES FROM tpch"); + assertQuerySucceeds("SHOW TABLES FROM TPCH"); + assertQueryFails(session, "SHOW TABLES FROM Tpch", "line 1:1: Schema 'Tpch' does not exist"); + + assertQuery(session, + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'tpch'", + "VALUES ('orders')"); + + assertQuery( + session, + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'TPCH'", + "VALUES ('ORDERS')"); + + assertQuery( + session, + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'Tpch'", + "SELECT 'empty' WHERE false"); + } + + @Test + public void testShowColumns() + { + 
MaterializedResult actual = computeActual("SHOW COLUMNS FROM tpch.orders"); + + MaterializedResult expected = MaterializedResult.resultBuilder(getQueryRunner().getDefaultSession(), createUnboundedVarcharType()) + .row("orderkey", "bigint", "", "", Long.valueOf(19), null, null) + .row("custkey", "bigint", "", "", Long.valueOf(19), null, null) + .row("orderstatus", "varchar(1)", "", "", null, null, Long.valueOf(1)) + .row("totalprice", "double", "", "", Long.valueOf(53), null, null) + .row("orderdate", "date", "", "", null, null, null) + .row("orderpriority", "varchar(15)", "", "", null, null, Long.valueOf(15)) + .row("clerk", "varchar(15)", "", "", null, null, Long.valueOf(15)) + .row("shippriority", "integer", "", "", Long.valueOf(10), null, null) + .row("comment", "varchar(79)", "", "", null, null, Long.valueOf(79)) + .build(); + assertContains(actual, expected); + } +} diff --git a/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaSecurityConfig.java b/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaSecurityConfig.java index fe4b08390a444..fba61d823f4b1 100644 --- a/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaSecurityConfig.java +++ b/presto-kafka/src/test/java/com/facebook/presto/kafka/TestKafkaSecurityConfig.java @@ -14,10 +14,9 @@ package com.facebook.presto.kafka; import com.google.common.collect.ImmutableMap; +import jakarta.validation.constraints.AssertTrue; import org.testng.annotations.Test; -import javax.validation.constraints.AssertTrue; - import java.util.Map; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; diff --git a/presto-kafka/src/test/java/com/facebook/presto/kafka/security/TestKafkaSaslConfig.java b/presto-kafka/src/test/java/com/facebook/presto/kafka/security/TestKafkaSaslConfig.java index eb3a7e16c8afa..e603cd834fc79 100644 --- a/presto-kafka/src/test/java/com/facebook/presto/kafka/security/TestKafkaSaslConfig.java +++ 
b/presto-kafka/src/test/java/com/facebook/presto/kafka/security/TestKafkaSaslConfig.java @@ -15,10 +15,9 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import jakarta.validation.constraints.AssertTrue; import org.testng.annotations.Test; -import javax.validation.constraints.AssertTrue; - import java.io.FileWriter; import java.io.IOException; import java.nio.file.Files; diff --git a/presto-kafka/src/test/java/com/facebook/presto/kafka/util/KafkaLoader.java b/presto-kafka/src/test/java/com/facebook/presto/kafka/util/KafkaLoader.java index 901baa5b23465..ddb31c6bd19cb 100644 --- a/presto-kafka/src/test/java/com/facebook/presto/kafka/util/KafkaLoader.java +++ b/presto-kafka/src/test/java/com/facebook/presto/kafka/util/KafkaLoader.java @@ -120,7 +120,7 @@ public void addResults(QueryStatusInfo statusInfo, QueryData data) } @Override - public Void build(Map setSessionProperties, Set resetSessionProperties) + public Void build(Map setSessionProperties, Set resetSessionProperties, String startTransactionId, boolean clearTransactionId) { return null; } diff --git a/presto-kafka/src/test/java/com/facebook/presto/kafka/util/TestUtils.java b/presto-kafka/src/test/java/com/facebook/presto/kafka/util/TestUtils.java index 96c49541f7e1b..38b03513f3f05 100644 --- a/presto-kafka/src/test/java/com/facebook/presto/kafka/util/TestUtils.java +++ b/presto-kafka/src/test/java/com/facebook/presto/kafka/util/TestUtils.java @@ -64,7 +64,7 @@ public static Properties toProperties(Map map) return properties; } - public static void installKafkaPlugin(EmbeddedKafka embeddedKafka, QueryRunner queryRunner, Map topicDescriptions) + public static void installKafkaPlugin(EmbeddedKafka embeddedKafka, QueryRunner queryRunner, Map topicDescriptions, Map connectorProperties) { FileKafkaClusterMetadataSupplierConfig clusterMetadataSupplierConfig = new FileKafkaClusterMetadataSupplierConfig(); 
clusterMetadataSupplierConfig.setNodes(embeddedKafka.getConnectString()); @@ -82,12 +82,20 @@ public static void installKafkaPlugin(EmbeddedKafka embeddedKafka, QueryRunner q queryRunner.installPlugin(kafkaPlugin); - Map kafkaConfig = ImmutableMap.of( - "kafka.cluster-metadata-supplier", TEST, - "kafka.table-description-supplier", TEST, - "kafka.connect-timeout", "120s", - "kafka.default-schema", "default"); - queryRunner.createCatalog("kafka", "kafka", kafkaConfig); + ImmutableMap.Builder kafkaConfigBuilder = ImmutableMap.builder(); + kafkaConfigBuilder.put("kafka.cluster-metadata-supplier", TEST); + kafkaConfigBuilder.put("kafka.table-description-supplier", TEST); + kafkaConfigBuilder.put("kafka.connect-timeout", "120s"); + kafkaConfigBuilder.put("kafka.default-schema", "default"); + + kafkaConfigBuilder.putAll(connectorProperties); + + queryRunner.createCatalog("kafka", "kafka", kafkaConfigBuilder.build()); + } + + public static void installKafkaPlugin(EmbeddedKafka embeddedKafka, QueryRunner queryRunner, Map topicDescriptions) + { + installKafkaPlugin(embeddedKafka, queryRunner, topicDescriptions, ImmutableMap.of()); } public static void loadTpchTopic(EmbeddedKafka embeddedKafka, TestingPrestoClient prestoClient, String topicName, QualifiedObjectName tpchTableName) @@ -98,16 +106,15 @@ public static void loadTpchTopic(EmbeddedKafka embeddedKafka, TestingPrestoClien } } - public static Map.Entry loadTpchTopicDescription(JsonCodec topicDescriptionJsonCodec, String topicName, SchemaTableName schemaTableName) + public static Map.Entry loadTpchTopicDescription(JsonCodec topicDescriptionJsonCodec, String topicName, SchemaTableName schemaTableName, String fileName) throws IOException { - KafkaTopicDescription tpchTemplate = topicDescriptionJsonCodec.fromJson(ByteStreams.toByteArray(TestUtils.class.getResourceAsStream(format("/tpch/%s.json", schemaTableName.getTableName())))); + KafkaTopicDescription tpchTemplate = 
topicDescriptionJsonCodec.fromJson(ByteStreams.toByteArray(TestUtils.class.getResourceAsStream(format("/tpch/%s.json", fileName)))); return new AbstractMap.SimpleImmutableEntry<>( schemaTableName, new KafkaTopicDescription(schemaTableName.getTableName(), Optional.of(schemaTableName.getSchemaName()), topicName, tpchTemplate.getKey(), tpchTemplate.getMessage())); } - public static Map.Entry createEmptyTopicDescription(String topicName, SchemaTableName schemaTableName) { return new AbstractMap.SimpleImmutableEntry<>( diff --git a/presto-kafka/src/test/resources/tpch/orders_upper.json b/presto-kafka/src/test/resources/tpch/orders_upper.json new file mode 100644 index 0000000000000..2c2839651b0db --- /dev/null +++ b/presto-kafka/src/test/resources/tpch/orders_upper.json @@ -0,0 +1,72 @@ +{ + "tableName": "ORDERS", + "schemaName": "TPCH", + "topicName": "TPCH.ORDERS", + "key": { + "dataFormat": "raw", + "fields": [ + { + "name": "kafka_key", + "dataFormat": "LONG", + "type": "BIGINT", + "hidden": "true" + } + ] + }, + "message": { + "dataFormat": "json", + "fields": [ + { + "name": "ORDERKEY", + "mapping": "ORDERKEY", + "type": "BIGINT" + }, + { + "name": "CUSTKEY", + "mapping": "CUSTKEY", + "type": "BIGINT" + }, + { + "name": "ORDERSTATUS", + "mapping": "ORDERSTATUS", + "type": "VARCHAR(1)" + }, + { + "name": "OrderStatus", + "mapping": "OrderStatus", + "type": "VARCHAR(1)" + }, + { + "name": "TOTALPRICE", + "mapping": "TOTALPRICE", + "type": "DOUBLE" + }, + { + "name": "ORDERDATE", + "mapping": "ORDERDATE", + "type": "DATE", + "dataFormat": "iso8601" + }, + { + "name": "ORDERPRIORITY", + "mapping": "ORDERPRIORITY", + "type": "VARCHAR(15)" + }, + { + "name": "CLERK", + "mapping": "CLERK", + "type": "VARCHAR(15)" + }, + { + "name": "SHIPPRIORITY", + "mapping": "SHIPPRIORITY", + "type": "INTEGER" + }, + { + "name": "COMMENT", + "mapping": "COMMENT", + "type": "VARCHAR(79)" + } + ] + } +} diff --git a/presto-kudu/pom.xml b/presto-kudu/pom.xml index 
ccf763b019051..d9fae0dd02c36 100644 --- a/presto-kudu/pom.xml +++ b/presto-kudu/pom.xml @@ -4,16 +4,19 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-kudu + presto-kudu Presto - Kudu Connector presto-plugin ${project.parent.basedir} 1.12.0 + true + 0.15.1 @@ -31,7 +34,7 @@ com.facebook.presto.hadoop - hadoop-apache2 + hadoop-apache @@ -75,13 +78,13 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -98,7 +101,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -110,7 +113,7 @@ - io.airlift + com.facebook.airlift units provided @@ -125,7 +128,7 @@ org.apache.yetus audience-annotations - 0.8.0 + ${yetus.audience-annotations.version} provided @@ -218,7 +221,7 @@ org.apache.yetus audience-annotations - 0.8.0 + ${yetus.audience-annotations.version} diff --git a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduClientConfig.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduClientConfig.java index 887a18705c8ca..d8b660aa150d2 100755 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduClientConfig.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduClientConfig.java @@ -14,14 +14,13 @@ package com.facebook.presto.kudu; import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MaxDuration; +import com.facebook.airlift.units.MinDuration; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; -import io.airlift.units.MaxDuration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.NotNull; -import javax.validation.constraints.Size; +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; import java.util.List; import java.util.concurrent.TimeUnit; diff --git 
a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduConnector.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduConnector.java index 7cdb99f0932c2..013c4b727cf21 100755 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduConnector.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduConnector.java @@ -27,8 +27,7 @@ import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Set; diff --git a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduMetadata.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduMetadata.java index fa9744d560681..b5a52a5898c06 100755 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduMetadata.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduMetadata.java @@ -38,12 +38,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; +import jakarta.inject.Inject; import org.apache.kudu.ColumnSchema; import org.apache.kudu.Schema; import org.apache.kudu.client.KuduTable; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -100,14 +99,14 @@ public Map> listTableColumns(ConnectorSess for (SchemaTableName tableName : tables) { KuduTableHandle tableHandle = getTableHandle(session, tableName); if (tableHandle != null) { - ConnectorTableMetadata tableMetadata = getTableMetadata(tableHandle); + ConnectorTableMetadata tableMetadata = getTableMetadata(session, tableHandle); columns.put(tableName, tableMetadata.getColumns()); } } return columns.build(); } - private ColumnMetadata getColumnMetadata(ColumnSchema column) + private ColumnMetadata getColumnMetadata(ConnectorSession session, ColumnSchema column) { Map properties = new 
LinkedHashMap<>(); StringBuilder extra = new StringBuilder(); @@ -135,7 +134,7 @@ private ColumnMetadata getColumnMetadata(ColumnSchema column) Type prestoType = TypeHelper.fromKuduColumn(column); return ColumnMetadata.builder() - .setName(column.getName()) + .setName(normalizeIdentifier(session, column.getName())) .setType(prestoType) .setExtraInfo(extra.toString()) .setHidden(false) @@ -143,14 +142,14 @@ private ColumnMetadata getColumnMetadata(ColumnSchema column) .build(); } - private ConnectorTableMetadata getTableMetadata(KuduTableHandle tableHandle) + private ConnectorTableMetadata getTableMetadata(ConnectorSession session, KuduTableHandle tableHandle) { KuduTable table = tableHandle.getTable(clientSession); Schema schema = table.getSchema(); List columnsMetaList = schema.getColumns().stream() .filter(column -> !column.isKey() || !column.getName().equals(KuduColumnHandle.ROW_ID)) - .map(this::getColumnMetadata) + .map(column -> getColumnMetadata(session, column)) .collect(toImmutableList()); Map properties = clientSession.getTableProperties(tableHandle); @@ -204,7 +203,8 @@ public KuduTableHandle getTableHandle(ConnectorSession session, SchemaTableName } @Override - public List getTableLayouts(ConnectorSession session, + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, ConnectorTableHandle tableHandle, Constraint constraint, Optional> desiredColumns) @@ -212,7 +212,7 @@ public List getTableLayouts(ConnectorSession session KuduTableHandle handle = (KuduTableHandle) tableHandle; ConnectorTableLayout layout = new ConnectorTableLayout( new KuduTableLayoutHandle(handle, constraint.getSummary(), desiredColumns)); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -225,7 +225,7 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa public ConnectorTableMetadata 
getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { KuduTableHandle kuduTableHandle = (KuduTableHandle) tableHandle; - return getTableMetadata(kuduTableHandle); + return getTableMetadata(session, kuduTableHandle); } @Override @@ -325,7 +325,7 @@ public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, Map columnProperties = new HashMap<>(); columnProperties.put(KuduTableProperties.PRIMARY_KEY, true); copy.add(0, ColumnMetadata.builder() - .setName(rowId) + .setName(normalizeIdentifier(session, rowId)) .setType(VARCHAR) .setComment("key=true") .setHidden(true) @@ -368,9 +368,9 @@ public Optional finishCreateTable(ConnectorSession sess } @Override - public ColumnHandle getDeleteRowIdColumnHandle(ConnectorSession session, ConnectorTableHandle tableHandle) + public Optional getDeleteRowIdColumn(ConnectorSession session, ConnectorTableHandle tableHandle) { - return KuduColumnHandle.ROW_ID_HANDLE; + return Optional.of(KuduColumnHandle.ROW_ID_HANDLE); } @Override @@ -380,8 +380,9 @@ public ConnectorDeleteTableHandle beginDelete(ConnectorSession session, Connecto } @Override - public void finishDelete(ConnectorSession session, ConnectorDeleteTableHandle tableHandle, Collection fragments) + public Optional finishDeleteWithOutput(ConnectorSession session, ConnectorDeleteTableHandle tableHandle, Collection fragments) { + return Optional.empty(); } @Override diff --git a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduModule.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduModule.java index bb91d80faa93f..64d93e26cd608 100755 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduModule.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduModule.java @@ -29,10 +29,9 @@ import com.google.inject.Scopes; import com.google.inject.multibindings.Multibinder; import com.google.inject.multibindings.ProvidesIntoSet; +import jakarta.inject.Singleton; import org.apache.kudu.client.KuduClient; -import 
javax.inject.Singleton; - import static com.facebook.airlift.configuration.ConfigBinder.configBinder; import static java.util.Objects.requireNonNull; diff --git a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduPageSinkProvider.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduPageSinkProvider.java index fbc740d6c8534..e5be921da2632 100644 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduPageSinkProvider.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduPageSinkProvider.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.PageSinkContext; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; diff --git a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduRecordSetProvider.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduRecordSetProvider.java index 56b25a9346dcb..44d6d0a9a4b6f 100755 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduRecordSetProvider.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduRecordSetProvider.java @@ -19,8 +19,7 @@ import com.facebook.presto.spi.RecordSet; import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduSplitManager.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduSplitManager.java index bce8b6bdb898f..f488506705b0a 100755 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduSplitManager.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/KuduSplitManager.java @@ -19,8 +19,7 @@ import com.facebook.presto.spi.FixedSplitSource; 
import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-kudu/src/main/java/com/facebook/presto/kudu/procedures/RangePartitionProcedures.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/procedures/RangePartitionProcedures.java index c9638aa803ea3..4fa52348f09fb 100644 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/procedures/RangePartitionProcedures.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/procedures/RangePartitionProcedures.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.lang.invoke.MethodHandle; diff --git a/presto-kudu/src/main/java/com/facebook/presto/kudu/properties/KuduTableProperties.java b/presto-kudu/src/main/java/com/facebook/presto/kudu/properties/KuduTableProperties.java index 90cc240590238..f9ee6f1ab6676 100644 --- a/presto-kudu/src/main/java/com/facebook/presto/kudu/properties/KuduTableProperties.java +++ b/presto-kudu/src/main/java/com/facebook/presto/kudu/properties/KuduTableProperties.java @@ -18,6 +18,7 @@ import com.facebook.presto.spi.session.PropertyMetadata; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import org.apache.kudu.ColumnSchema; import org.apache.kudu.Schema; import org.apache.kudu.Type; @@ -30,8 +31,6 @@ import org.apache.kudu.shaded.com.google.common.base.Predicates; import org.apache.kudu.shaded.com.google.common.collect.Iterators; -import javax.inject.Inject; - import java.io.IOException; import java.math.BigDecimal; import java.time.Instant; diff --git a/presto-lark-sheets/pom.xml b/presto-lark-sheets/pom.xml index 
dbb68352328f3..83261338e35e4 100644 --- a/presto-lark-sheets/pom.xml +++ b/presto-lark-sheets/pom.xml @@ -3,16 +3,18 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT 4.0.0 presto-lark-sheets + presto-lark-sheets Presto - Lark Sheets Connector presto-plugin ${project.parent.basedir} + true @@ -52,16 +54,21 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true + + jakarta.annotation + jakarta.annotation-api + + com.larksuite.oapi larksuite-oapi @@ -94,7 +101,7 @@ - io.airlift + com.facebook.airlift units provided @@ -154,4 +161,19 @@ test + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + com.facebook.airlift:log + com.facebook.airlift:log-manager + + + + + diff --git a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsColumnHandle.java b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsColumnHandle.java index e62306dcf2494..35dc5d0bd0186 100644 --- a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsColumnHandle.java +++ b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsColumnHandle.java @@ -61,6 +61,11 @@ public ColumnMetadata toColumnMetadata() return ColumnMetadata.builder().setName(name).setType(type).build(); } + public ColumnMetadata toColumnMetadata(String name) + { + return ColumnMetadata.builder().setName(name).setType(type).build(); + } + @Override public boolean equals(Object o) { diff --git a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsConnector.java b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsConnector.java index 61e2b12e54613..19ccd72c201ee 100644 --- a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsConnector.java +++ b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsConnector.java @@ -25,8 
+25,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.function.Supplier; diff --git a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsMetadata.java b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsMetadata.java index 33d48235df28e..0e1ba5d8c0d6c 100644 --- a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsMetadata.java +++ b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsMetadata.java @@ -34,8 +34,7 @@ import com.facebook.presto.spi.connector.ConnectorMetadata; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.LinkedHashMap; import java.util.List; @@ -128,14 +127,15 @@ public Optional getSystemTable(ConnectorSession session, SchemaTabl } @Override - public List getTableLayouts(ConnectorSession session, + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { LarkSheetsTableHandle tableHandle = (LarkSheetsTableHandle) table; ConnectorTableLayout layout = new ConnectorTableLayout(new LarkSheetsTableLayoutHandle(tableHandle)); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -149,7 +149,7 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect { LarkSheetsTableHandle sheetsTable = (LarkSheetsTableHandle) table; List sheetsColumns = getColumns(sheetsTable); - List columnMetadatas = toColumnMetadatas(sheetsColumns); + 
List columnMetadatas = toColumnMetadatas(session, sheetsColumns); return new ConnectorTableMetadata(sheetsTable.getSchemaTableName(), columnMetadatas); } @@ -197,7 +197,7 @@ public Map> listTableColumns(ConnectorSess for (SheetInfo sheet : metaInfo.getSheets()) { SchemaTableName tableName = new SchemaTableName(schemaName, sheet.getTitle()); List columnHandles = getColumns(toSheetsTableHandle(sheet)); - List columnMetadatas = toColumnMetadatas(columnHandles); + List columnMetadatas = toColumnMetadatas(session, columnHandles); builder.put(tableName, columnMetadatas); } } @@ -209,7 +209,7 @@ public Map> listTableColumns(ConnectorSess // in order to make queries like `DESC "@sheetId"` or `DESC "#1"` work SchemaTableName tableName = new SchemaTableName(schemaName, prefixTableName); List columnHandles = getColumns(toSheetsTableHandle(sheet)); - List columnMetadatas = toColumnMetadatas(columnHandles); + List columnMetadatas = toColumnMetadatas(session, columnHandles); builder.put(tableName, columnMetadatas); } } @@ -337,8 +337,10 @@ private static LarkSheetsTableHandle toSheetsTableHandle(SheetInfo sheet) return new LarkSheetsTableHandle(sheet.getToken(), sheet.getSheetId(), sheet.getTitle(), sheet.getIndex(), sheet.getColumnCount(), sheet.getRowCount()); } - private static List toColumnMetadatas(List columnHandles) + private List toColumnMetadatas(ConnectorSession session, List columnHandles) { - return columnHandles.stream().map(LarkSheetsColumnHandle::toColumnMetadata).collect(toImmutableList()); + return columnHandles.stream() + .map(column -> column.toColumnMetadata(normalizeIdentifier(session, column.getName()))) + .collect(toImmutableList()); } } diff --git a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsRecordSetProvider.java b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsRecordSetProvider.java index 6cb0baf96e896..29daa9d45c46b 100644 --- 
a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsRecordSetProvider.java +++ b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/LarkSheetsRecordSetProvider.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.RecordSet; import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.function.Supplier; diff --git a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/api/LarkSheetsApiFactory.java b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/api/LarkSheetsApiFactory.java index 7fcdcc9fc7cda..2f98505a1e652 100644 --- a/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/api/LarkSheetsApiFactory.java +++ b/presto-lark-sheets/src/main/java/com/facebook/presto/lark/sheets/api/LarkSheetsApiFactory.java @@ -14,6 +14,7 @@ package com.facebook.presto.lark.sheets.api; import com.facebook.presto.lark.sheets.LarkSheetsConfig; +import com.google.errorprone.annotations.ThreadSafe; import com.larksuite.oapi.core.AppSettings; import com.larksuite.oapi.core.AppType; import com.larksuite.oapi.core.Config; @@ -21,9 +22,7 @@ import com.larksuite.oapi.core.Domain; import com.larksuite.oapi.service.drive_permission.v2.DrivePermissionService; import com.larksuite.oapi.service.sheets.v2.SheetsService; - -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.function.Supplier; diff --git a/presto-local-file/pom.xml b/presto-local-file/pom.xml index e0a4c50993ddc..9994a7c9c5620 100644 --- a/presto-local-file/pom.xml +++ b/presto-local-file/pom.xml @@ -5,15 +5,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-local-file + presto-local-file Presto - Local File Connector presto-plugin ${project.parent.basedir} + true @@ -32,11 +34,6 @@ 
configuration - - com.facebook.airlift - json - - com.google.guava guava @@ -48,13 +45,13 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api - com.fasterxml.jackson.core - jackson-databind + javax.inject + javax.inject @@ -71,7 +68,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -83,7 +80,7 @@ - io.airlift + com.facebook.airlift units provided @@ -107,6 +104,18 @@ test + + com.facebook.airlift + json + test + + + + com.fasterxml.jackson.core + jackson-databind + test + + org.testng testng @@ -119,4 +128,18 @@ test + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + javax.inject:javax.inject + + + + + diff --git a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileConnector.java b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileConnector.java index cf81f1549f2b3..0c416861b24e0 100644 --- a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileConnector.java +++ b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileConnector.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.spi.transaction.IsolationLevel.READ_COMMITTED; import static com.facebook.presto.spi.transaction.IsolationLevel.checkConnectorSupports; diff --git a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileMetadata.java b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileMetadata.java index 247b741e709d3..c4224fbc7c761 100644 --- a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileMetadata.java +++ b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileMetadata.java @@ -27,8 +27,7 @@ import 
com.facebook.presto.spi.connector.ConnectorMetadata; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -85,11 +84,15 @@ public List listTables(ConnectorSession session, String schemaN } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { LocalFileTableHandle tableHandle = (LocalFileTableHandle) table; ConnectorTableLayout layout = new ConnectorTableLayout(new LocalFileTableLayoutHandle(tableHandle, constraint.getSummary())); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override diff --git a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileRecordSetProvider.java b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileRecordSetProvider.java index 75befa6235554..b40f06763306a 100644 --- a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileRecordSetProvider.java +++ b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileRecordSetProvider.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileSplitManager.java b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileSplitManager.java index ec004022efd42..e18e7e2829a9f 100644 
--- a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileSplitManager.java +++ b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileSplitManager.java @@ -22,8 +22,7 @@ import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.stream.Collectors; diff --git a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileTables.java b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileTables.java index 438a2ce6f8fff..b429dc176dbb9 100644 --- a/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileTables.java +++ b/presto-local-file/src/main/java/com/facebook/presto/localfile/LocalFileTables.java @@ -22,8 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.UncheckedExecutionException; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.File; import java.util.List; diff --git a/presto-main-base/pom.xml b/presto-main-base/pom.xml index b1f972ce7a10f..dd79f85355825 100644 --- a/presto-main-base/pom.xml +++ b/presto-main-base/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-main-base @@ -13,6 +13,9 @@ ${project.parent.basedir} + 8 + true + 8.11.3 @@ -82,8 +85,9 @@ - com.google.code.findbugs - jsr305 + com.facebook.presto + presto-main-tests + test @@ -129,6 +133,7 @@ com.facebook.airlift log-manager + runtime @@ -142,7 +147,7 @@ - io.airlift + com.facebook.airlift units @@ -152,43 +157,50 @@ - com.facebook.drift + com.facebook.airlift.drift drift-server + test - com.facebook.drift + com.facebook.airlift.drift drift-transport-netty - com.facebook.drift + io.netty + netty-buffer + + + + 
com.facebook.airlift.drift drift-transport-spi - com.facebook.drift + com.facebook.airlift.drift drift-client - com.facebook.drift + com.facebook.airlift.drift drift-protocol - com.facebook.drift + com.facebook.airlift.drift drift-codec - com.facebook.drift + com.facebook.airlift.drift drift-api - com.facebook.drift + com.facebook.airlift.drift drift-codec-utils + runtime @@ -202,18 +214,13 @@ - javax.servlet - javax.servlet-api - - - - javax.annotation - javax.annotation-api + com.google.errorprone + error_prone_annotations - javax.ws.rs - javax.ws.rs-api + jakarta.annotation + jakarta.annotation-api @@ -247,8 +254,13 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api + + + + jakarta.inject + jakarta.inject-api @@ -295,6 +307,7 @@ org.apache.lucene lucene-analyzers-common + ${dep.lucene.version} @@ -320,6 +333,7 @@ com.facebook.presto presto-function-namespace-managers-common + test @@ -455,6 +469,20 @@ io.netty netty-transport + + + com.facebook.presto + presto-built-in-worker-function-tools + ${project.version} + test + + + + com.facebook.presto + presto-sql-invoked-functions-plugin + ${project.version} + test + @@ -519,6 +547,9 @@ com.facebook.presto.testing.assertions + + com.facebook.presto.server.MockHttpServletRequest + @@ -528,6 +559,12 @@ com.facebook.presto:presto-ui:jar + + + com.facebook.airlift.drift:drift-transport-spi + + io.netty:netty-buffer + diff --git a/presto-main-base/src/main/java/com/facebook/presto/ClientRequestFilterManager.java b/presto-main-base/src/main/java/com/facebook/presto/ClientRequestFilterManager.java index 79fb3a359f5c5..234f8d92710a1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/ClientRequestFilterManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/ClientRequestFilterManager.java @@ -18,9 +18,8 @@ import com.facebook.presto.spi.ClientRequestFilterFactory; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; - 
-import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/ExceededCpuLimitException.java b/presto-main-base/src/main/java/com/facebook/presto/ExceededCpuLimitException.java index 7e2a097422168..8da204f884668 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/ExceededCpuLimitException.java +++ b/presto-main-base/src/main/java/com/facebook/presto/ExceededCpuLimitException.java @@ -13,8 +13,8 @@ */ package com.facebook.presto; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.PrestoException; -import io.airlift.units.Duration; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_CPU_LIMIT; import static java.lang.String.format; diff --git a/presto-main-base/src/main/java/com/facebook/presto/ExceededIntermediateWrittenBytesException.java b/presto-main-base/src/main/java/com/facebook/presto/ExceededIntermediateWrittenBytesException.java index d2c8599cb2383..e0f648c3e042e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/ExceededIntermediateWrittenBytesException.java +++ b/presto-main-base/src/main/java/com/facebook/presto/ExceededIntermediateWrittenBytesException.java @@ -13,8 +13,8 @@ */ package com.facebook.presto; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.PrestoException; -import io.airlift.units.DataSize; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_WRITTEN_INTERMEDIATE_BYTES_LIMIT; diff --git a/presto-main-base/src/main/java/com/facebook/presto/ExceededMemoryLimitException.java b/presto-main-base/src/main/java/com/facebook/presto/ExceededMemoryLimitException.java index 8d5225e5be307..84643433614e1 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/ExceededMemoryLimitException.java +++ b/presto-main-base/src/main/java/com/facebook/presto/ExceededMemoryLimitException.java @@ -13,10 +13,10 @@ */ package com.facebook.presto; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.ErrorCause; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.StandardErrorCode; -import io.airlift.units.DataSize; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/ExceededOutputSizeLimitException.java b/presto-main-base/src/main/java/com/facebook/presto/ExceededOutputSizeLimitException.java index 305922cdd9cf5..0c3af0a089449 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/ExceededOutputSizeLimitException.java +++ b/presto-main-base/src/main/java/com/facebook/presto/ExceededOutputSizeLimitException.java @@ -13,8 +13,8 @@ */ package com.facebook.presto; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.PrestoException; -import io.airlift.units.DataSize; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_OUTPUT_SIZE_LIMIT; diff --git a/presto-main-base/src/main/java/com/facebook/presto/ExceededScanLimitException.java b/presto-main-base/src/main/java/com/facebook/presto/ExceededScanLimitException.java index 6557f7ede8954..3a95fa4c66aff 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/ExceededScanLimitException.java +++ b/presto-main-base/src/main/java/com/facebook/presto/ExceededScanLimitException.java @@ -13,8 +13,8 @@ */ package com.facebook.presto; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.PrestoException; -import io.airlift.units.DataSize; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_SCAN_RAW_BYTES_READ_LIMIT; diff --git a/presto-main-base/src/main/java/com/facebook/presto/ExceededSpillLimitException.java 
b/presto-main-base/src/main/java/com/facebook/presto/ExceededSpillLimitException.java index eeb9c22ef5eca..99ef446e47d26 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/ExceededSpillLimitException.java +++ b/presto-main-base/src/main/java/com/facebook/presto/ExceededSpillLimitException.java @@ -13,8 +13,8 @@ */ package com.facebook.presto; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.PrestoException; -import io.airlift.units.DataSize; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_SPILL_LIMIT; import static java.lang.String.format; diff --git a/presto-main-base/src/main/java/com/facebook/presto/FullConnectorSession.java b/presto-main-base/src/main/java/com/facebook/presto/FullConnectorSession.java index 1d8ee24735d8f..44ba295923010 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/FullConnectorSession.java +++ b/presto-main-base/src/main/java/com/facebook/presto/FullConnectorSession.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.function.SqlFunctionProperties; +import com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.spi.ConnectorId; @@ -48,17 +49,11 @@ public class FullConnectorSession private final SessionPropertyManager sessionPropertyManager; private final SqlFunctionProperties sqlFunctionProperties; private final Map sessionFunctions; + private final RuntimeStats runtimeStats; public FullConnectorSession(Session session, ConnectorIdentity identity) { - this.session = requireNonNull(session, "session is null"); - this.identity = requireNonNull(identity, "identity is null"); - this.properties = null; - this.connectorId = null; - this.catalog = null; - this.sessionPropertyManager = null; - this.sqlFunctionProperties = session.getSqlFunctionProperties(); - this.sessionFunctions = 
ImmutableMap.copyOf(session.getSessionFunctions()); + this(builder(session, identity, null, null, null, null)); } public FullConnectorSession( @@ -69,14 +64,123 @@ public FullConnectorSession( String catalog, SessionPropertyManager sessionPropertyManager) { - this.session = requireNonNull(session, "session is null"); - this.identity = requireNonNull(identity, "identity is null"); - this.properties = ImmutableMap.copyOf(requireNonNull(properties, "properties is null")); - this.connectorId = requireNonNull(connectorId, "connectorId is null"); - this.catalog = requireNonNull(catalog, "catalog is null"); - this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); - this.sqlFunctionProperties = session.getSqlFunctionProperties(); - this.sessionFunctions = ImmutableMap.copyOf(session.getSessionFunctions()); + this(builder(session, identity, properties, connectorId, catalog, sessionPropertyManager)); + } + + private FullConnectorSession(Builder builder) + { + this.session = builder.getSession(); + this.identity = builder.getIdentity(); + this.properties = builder.getProperties(); + this.connectorId = builder.getConnectorId(); + this.catalog = builder.getCatalog(); + this.sessionPropertyManager = builder.getSessionPropertyManager(); + this.sqlFunctionProperties = builder.getSqlFunctionProperties() != null ? builder.getSqlFunctionProperties() : builder.getSession().getSqlFunctionProperties(); + this.sessionFunctions = builder.getSessionFunctions() != null ? builder.getSessionFunctions() : ImmutableMap.copyOf(builder.getSession().getSessionFunctions()); + this.runtimeStats = builder.getRuntimeStats() != null ? 
builder.getRuntimeStats() : builder.getSession().getRuntimeStats(); + } + + public static Builder builder( + Session session, + ConnectorIdentity identity, + Map properties, + ConnectorId connectorId, + String catalog, + SessionPropertyManager sessionPropertyManager) + { + return new Builder(session, identity, properties, connectorId, catalog, sessionPropertyManager); + } + + public static class Builder + { + private final Session session; + private final ConnectorIdentity identity; + private final Map properties; + private final ConnectorId connectorId; + private final String catalog; + private final SessionPropertyManager sessionPropertyManager; + + private SqlFunctionProperties sqlFunctionProperties; + private Map sessionFunctions; + private RuntimeStats runtimeStats; + + private Builder(Session session, ConnectorIdentity identity, Map properties, ConnectorId connectorId, String catalog, SessionPropertyManager sessionPropertyManager) + { + this.session = requireNonNull(session, "session is null"); + this.identity = requireNonNull(identity, "identity is null"); + this.properties = properties; + this.connectorId = connectorId; + this.catalog = catalog; + this.sessionPropertyManager = sessionPropertyManager; + } + + public Session getSession() + { + return session; + } + + public ConnectorIdentity getIdentity() + { + return identity; + } + + public Map getProperties() + { + return properties; + } + + public ConnectorId getConnectorId() + { + return connectorId; + } + + public String getCatalog() + { + return catalog; + } + + public SessionPropertyManager getSessionPropertyManager() + { + return sessionPropertyManager; + } + + public SqlFunctionProperties getSqlFunctionProperties() + { + return sqlFunctionProperties; + } + + public Builder setSqlFunctionProperties(SqlFunctionProperties sqlFunctionProperties) + { + this.sqlFunctionProperties = sqlFunctionProperties; + return this; + } + + public Map getSessionFunctions() + { + return sessionFunctions; + } + + public 
Builder setSessionFunctions(Map sessionFunctions) + { + this.sessionFunctions = sessionFunctions; + return this; + } + + public RuntimeStats getRuntimeStats() + { + return runtimeStats; + } + + public Builder setRuntimeStats(RuntimeStats runtimeStats) + { + this.runtimeStats = runtimeStats; + return this; + } + + public FullConnectorSession build() + { + return new FullConnectorSession(this); + } } public Session getSession() @@ -166,6 +270,12 @@ public Optional getSchema() return session.getSchema(); } + @Override + public Optional getConnectorId() + { + return Optional.ofNullable(connectorId); + } + @Override public boolean isReadConstraints() { @@ -197,7 +307,13 @@ public WarningCollector getWarningCollector() @Override public RuntimeStats getRuntimeStats() { - return session.getRuntimeStats(); + return runtimeStats; + } + + @Override + public Optional getQueryType() + { + return session.getQueryType(); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/GroupByHashPageIndexerFactory.java b/presto-main-base/src/main/java/com/facebook/presto/GroupByHashPageIndexerFactory.java index e135188ba3fe2..3a6c045fb1442 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/GroupByHashPageIndexerFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/GroupByHashPageIndexerFactory.java @@ -18,8 +18,7 @@ import com.facebook.presto.spi.PageIndexer; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.sql.gen.JoinCompiler; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/PagesIndexPageSorter.java b/presto-main-base/src/main/java/com/facebook/presto/PagesIndexPageSorter.java index 7190d545e7fc1..4c01c2c1a95fb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/PagesIndexPageSorter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/PagesIndexPageSorter.java @@ -19,8 +19,7 @@ import 
com.facebook.presto.common.type.Type; import com.facebook.presto.operator.PagesIndex; import com.facebook.presto.spi.PageSorter; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/Session.java b/presto-main-base/src/main/java/com/facebook/presto/Session.java index 393b0b1e3e428..0492f7d1bd300 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/Session.java +++ b/presto-main-base/src/main/java/com/facebook/presto/Session.java @@ -13,6 +13,8 @@ */ package com.facebook.presto; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.function.SqlFunctionProperties; import com.facebook.presto.common.resourceGroups.QueryType; @@ -40,11 +42,10 @@ import com.facebook.presto.sql.planner.optimizations.OptimizerInformationCollector; import com.facebook.presto.sql.planner.optimizations.OptimizerResultCollector; import com.facebook.presto.transaction.TransactionManager; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.security.Principal; import java.util.HashMap; @@ -445,7 +446,8 @@ public Session beginTransactionId(TransactionId transactionId, boolean enableRol identity.getExtraCredentials(), identity.getExtraAuthenticators(), identity.getSelectedUser(), - identity.getReasonForSelect()), + identity.getReasonForSelect(), + identity.getCertificates()), source, catalog, schema, @@ -470,6 +472,50 @@ public Session beginTransactionId(TransactionId transactionId, boolean enableRol queryType); } + @VisibleForTesting + public Session clearTransaction(TransactionManager transactionManager, AccessControl accessControl) + { + 
checkArgument(this.transactionId.isPresent(), "Session does not have an active transaction"); + requireNonNull(transactionManager, "transactionManager is null"); + requireNonNull(accessControl, "accessControl is null"); + + for (Entry property : systemProperties.entrySet()) { + // verify permissions + accessControl.checkCanSetSystemSessionProperty(identity, context, property.getKey()); + + // validate session property value + sessionPropertyManager.validateSystemSessionProperty(property.getKey(), property.getValue()); + } + + return new Session( + queryId, + Optional.empty(), + clientTransactionSupport, + identity, + source, + catalog, + schema, + traceToken, + timeZoneKey, + locale, + remoteUserAddress, + userAgent, + clientInfo, + clientTags, + resourceEstimates, + startTime, + systemProperties, + connectorProperties, + unprocessedCatalogProperties, + sessionPropertyManager, + preparedStatements, + sessionFunctions, + tracer, + warningCollector, + runtimeStats, + queryType); + } + public ConnectorSession toConnectorSession() { return new FullConnectorSession(this, identity.toConnectorIdentity()); @@ -495,17 +541,26 @@ public SqlFunctionProperties getSqlFunctionProperties() .build(); } - public ConnectorSession toConnectorSession(ConnectorId connectorId) + public ConnectorSession toConnectorSession(ConnectorId connectorId, RuntimeStats runtimeStats) { requireNonNull(connectorId, "connectorId is null"); - return new FullConnectorSession( - this, - identity.toConnectorIdentity(connectorId.getCatalogName()), - connectorProperties.getOrDefault(connectorId, ImmutableMap.of()), - connectorId, - connectorId.getCatalogName(), - sessionPropertyManager); + FullConnectorSession.Builder connectorSessionBuilder = FullConnectorSession + .builder( + this, + identity.toConnectorIdentity(connectorId.getCatalogName()), + connectorProperties.getOrDefault(connectorId, ImmutableMap.of()), + connectorId, + connectorId.getCatalogName(), + sessionPropertyManager) + 
.setRuntimeStats(runtimeStats); + + return connectorSessionBuilder.build(); + } + + public ConnectorSession toConnectorSession(ConnectorId connectorId) + { + return toConnectorSession(connectorId, runtimeStats); } public SessionRepresentation toSessionRepresentation() diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index f1e94ae0debd3..6991f439e2755 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -13,8 +13,11 @@ */ package com.facebook.presto; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.WarningHandlingLevel; import com.facebook.presto.common.plan.PlanCanonicalizationStrategy; +import com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.cost.HistoryBasedOptimizationConfig; import com.facebook.presto.execution.QueryManagerConfig; import com.facebook.presto.execution.QueryManagerConfig.ExchangeMaterializationStrategy; @@ -24,8 +27,10 @@ import com.facebook.presto.execution.warnings.WarningCollectorConfig; import com.facebook.presto.memory.MemoryManagerConfig; import com.facebook.presto.memory.NodeMemoryConfig; +import com.facebook.presto.spi.MaterializedViewStaleReadBehavior; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.eventlistener.CTEInformation; +import com.facebook.presto.spi.security.ViewSecurity; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spiller.NodeSpillConfig; import com.facebook.presto.sql.analyzer.FeaturesConfig; @@ -40,24 +45,23 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialMergePushdownStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartitioningPrecisionStrategy; import 
com.facebook.presto.sql.analyzer.FeaturesConfig.PushDownFilterThroughCrossJoinStrategy; +import com.facebook.presto.sql.analyzer.FeaturesConfig.RandomizeNullSourceKeyInSemiJoinStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.RandomizeOuterJoinNullKeyStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.ShardedJoinStrategy; +import com.facebook.presto.sql.analyzer.FeaturesConfig.ShuffleForTableScanStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.SingleStreamSpillerChoice; import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.facebook.presto.sql.planner.CompilerConfig; -import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.tracing.TracingConfig; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Comparator; import java.util.List; import java.util.Optional; import java.util.OptionalInt; +import java.util.Set; import java.util.stream.Stream; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -77,8 +81,9 @@ import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.ALWAYS; import static com.facebook.presto.sql.analyzer.FeaturesConfig.PartialAggregationStrategy.NEVER; -import static com.google.common.base.Preconditions.checkArgument; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.parseQueryTypesFromString; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.lang.Boolean.TRUE; import static java.lang.Math.min; import static java.lang.String.format; @@ -87,6 +92,7 @@ public final class SystemSessionProperties { + public static final String 
MAX_PREFIXES_COUNT = "max_prefixes_count"; public static final String OPTIMIZE_HASH_GENERATION = "optimize_hash_generation"; public static final String JOIN_DISTRIBUTION_TYPE = "join_distribution_type"; public static final String JOIN_MAX_BROADCAST_TABLE_SIZE = "join_max_broadcast_table_size"; @@ -115,6 +121,7 @@ public final class SystemSessionProperties public static final String QUERY_MAX_BROADCAST_MEMORY = "query_max_broadcast_memory"; public static final String QUERY_MAX_TOTAL_MEMORY = "query_max_total_memory"; public static final String QUERY_MAX_TOTAL_MEMORY_PER_NODE = "query_max_total_memory_per_node"; + public static final String QUERY_MAX_QUEUED_TIME = "query_max_queued_time"; public static final String QUERY_MAX_EXECUTION_TIME = "query_max_execution_time"; public static final String QUERY_MAX_RUN_TIME = "query_max_run_time"; public static final String RESOURCE_OVERCOMMIT = "resource_overcommit"; @@ -203,6 +210,7 @@ public final class SystemSessionProperties public static final String OPTIMIZED_REPARTITIONING_ENABLED = "optimized_repartitioning"; public static final String AGGREGATION_PARTITIONING_MERGING_STRATEGY = "aggregation_partitioning_merging_strategy"; public static final String LIST_BUILT_IN_FUNCTIONS_ONLY = "list_built_in_functions_only"; + public static final String NON_BUILT_IN_FUNCTION_NAMESPACES_TO_LIST_FUNCTIONS = "non_built_in_function_namespaces_to_list_functions"; public static final String PARTITIONING_PRECISION_STRATEGY = "partitioning_precision_strategy"; public static final String EXPERIMENTAL_FUNCTIONS_ENABLED = "experimental_functions_enabled"; public static final String OPTIMIZE_COMMON_SUB_EXPRESSIONS = "optimize_common_sub_expressions"; @@ -241,9 +249,13 @@ public final class SystemSessionProperties public static final String MATERIALIZED_VIEW_DATA_CONSISTENCY_ENABLED = "materialized_view_data_consistency_enabled"; public static final String CONSIDER_QUERY_FILTERS_FOR_MATERIALIZED_VIEW_PARTITIONS = 
"consider-query-filters-for-materialized-view-partitions"; public static final String QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED = "query_optimization_with_materialized_view_enabled"; + public static final String LEGACY_MATERIALIZED_VIEWS = "legacy_materialized_views"; + public static final String MATERIALIZED_VIEW_ALLOW_FULL_REFRESH_ENABLED = "materialized_view_allow_full_refresh_enabled"; + public static final String MATERIALIZED_VIEW_STALE_READ_BEHAVIOR = "materialized_view_stale_read_behavior"; public static final String AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY = "aggregation_if_to_filter_rewrite_strategy"; public static final String JOINS_NOT_NULL_INFERENCE_STRATEGY = "joins_not_null_inference_strategy"; public static final String RESOURCE_AWARE_SCHEDULING_STRATEGY = "resource_aware_scheduling_strategy"; + public static final String SCHEDULE_SPLITS_BASED_ON_TASK_LOAD = "schedule_splits_based_on_task_load"; public static final String HEAP_DUMP_ON_EXCEEDED_MEMORY_LIMIT_ENABLED = "heap_dump_on_exceeded_memory_limit_enabled"; public static final String EXCEEDED_MEMORY_LIMIT_HEAP_DUMP_FILE_DIRECTORY = "exceeded_memory_limit_heap_dump_file_directory"; public static final String DISTRIBUTED_TRACING_MODE = "distributed_tracing_mode"; @@ -255,12 +267,15 @@ public final class SystemSessionProperties public static final String MAX_STAGE_COUNT_FOR_EAGER_SCHEDULING = "max_stage_count_for_eager_scheduling"; public static final String HYPERLOGLOG_STANDARD_ERROR_WARNING_THRESHOLD = "hyperloglog_standard_error_warning_threshold"; public static final String PREFER_MERGE_JOIN_FOR_SORTED_INPUTS = "prefer_merge_join_for_sorted_inputs"; + public static final String PREFER_SORT_MERGE_JOIN = "prefer_sort_merge_join"; + public static final String SORTED_EXCHANGE_ENABLED = "sorted_exchange_enabled"; public static final String SEGMENTED_AGGREGATION_ENABLED = "segmented_aggregation_enabled"; public static final String USE_HISTORY_BASED_PLAN_STATISTICS = 
"use_history_based_plan_statistics"; public static final String TRACK_HISTORY_BASED_PLAN_STATISTICS = "track_history_based_plan_statistics"; public static final String TRACK_HISTORY_STATS_FROM_FAILED_QUERIES = "track_history_stats_from_failed_queries"; public static final String USE_PERFECTLY_CONSISTENT_HISTORIES = "use_perfectly_consistent_histories"; public static final String HISTORY_CANONICAL_PLAN_NODE_LIMIT = "history_canonical_plan_node_limit"; + public static final String HISTORY_BASED_OPTIMIZER_ESTIMATE_SIZE_USING_VARIABLES = "history_based_optimizer_estimate_size_using_variables"; public static final String HISTORY_BASED_OPTIMIZER_TIMEOUT_LIMIT = "history_based_optimizer_timeout_limit"; public static final String RESTRICT_HISTORY_BASED_OPTIMIZATION_TO_COMPLEX_QUERY = "restrict_history_based_optimization_to_complex_query"; public static final String HISTORY_INPUT_TABLE_STATISTICS_MATCHING_THRESHOLD = "history_input_table_statistics_matching_threshold"; @@ -268,13 +283,16 @@ public final class SystemSessionProperties public static final String ENABLE_VERBOSE_HISTORY_BASED_OPTIMIZER_RUNTIME_STATS = "enable_verbose_history_based_optimizer_runtime_stats"; public static final String LOG_QUERY_PLANS_USED_IN_HISTORY_BASED_OPTIMIZER = "log_query_plans_used_in_history_based_optimizer"; public static final String ENFORCE_HISTORY_BASED_OPTIMIZER_REGISTRATION_TIMEOUT = "enforce_history_based_optimizer_register_timeout"; + public static final String QUERY_TYPES_ENABLED_FOR_HISTORY_BASED_OPTIMIZATION = "query_types_enabled_for_history_based_optimization"; public static final String MAX_LEAF_NODES_IN_PLAN = "max_leaf_nodes_in_plan"; public static final String LEAF_NODE_LIMIT_ENABLED = "leaf_node_limit_enabled"; public static final String PUSH_REMOTE_EXCHANGE_THROUGH_GROUP_ID = "push_remote_exchange_through_group_id"; public static final String OPTIMIZE_MULTIPLE_APPROX_PERCENTILE_ON_SAME_FIELD = "optimize_multiple_approx_percentile_on_same_field"; + public static final 
String OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE = "optimize_multiple_approx_distinct_on_same_type"; public static final String RANDOMIZE_OUTER_JOIN_NULL_KEY = "randomize_outer_join_null_key"; public static final String RANDOMIZE_OUTER_JOIN_NULL_KEY_STRATEGY = "randomize_outer_join_null_key_strategy"; public static final String RANDOMIZE_OUTER_JOIN_NULL_KEY_NULL_RATIO_THRESHOLD = "randomize_outer_join_null_key_null_ratio_threshold"; + public static final String RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY = "randomize_null_source_key_in_semi_join_strategy"; public static final String SHARDED_JOINS_STRATEGY = "sharded_joins_strategy"; public static final String JOIN_SHARD_COUNT = "join_shard_count"; public static final String IN_PREDICATES_AS_INNER_JOINS_ENABLED = "in_predicates_as_inner_joins_enabled"; @@ -320,6 +338,7 @@ public final class SystemSessionProperties public static final String REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION = "rewrite_expression_with_constant_expression"; public static final String PRINT_ESTIMATED_STATS_FROM_CACHE = "print_estimated_stats_from_cache"; public static final String REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT = "remove_cross_join_with_constant_single_row_input"; + public static final String OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT = "optimize_conditional_constant_approximate_distinct"; public static final String EAGER_PLAN_VALIDATION_ENABLED = "eager_plan_validation_enabled"; public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode"; public static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side"; @@ -327,13 +346,28 @@ public final class SystemSessionProperties public static final String WARN_ON_COMMON_NAN_PATTERNS = "warn_on_common_nan_patterns"; public static final String INLINE_PROJECTIONS_ON_VALUES = "inline_projections_on_values"; public static final String INCLUDE_VALUES_NODE_IN_CONNECTOR_OPTIMIZER = "include_values_node_in_connector_optimizer"; + 
public static final String ENABLE_EMPTY_CONNECTOR_OPTIMIZER = "enable_empty_connector_optimizer"; public static final String SINGLE_NODE_EXECUTION_ENABLED = "single_node_execution_enabled"; + public static final String BROADCAST_SEMI_JOIN_FOR_DELETE = "broadcast_semi_join_for_delete"; public static final String EXPRESSION_OPTIMIZER_NAME = "expression_optimizer_name"; public static final String ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID = "add_exchange_below_partial_aggregation_over_group_id"; public static final String QUERY_CLIENT_TIMEOUT = "query_client_timeout"; + public static final String REWRITE_MIN_MAX_BY_TO_TOP_N = "rewrite_min_max_by_to_top_n"; + public static final String ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD = "add_distinct_below_semi_join_build"; + public static final String UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING = "utilize_unique_property_in_query_planning"; + public static final String PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS = "pushdown_subfields_for_map_functions"; + public static final String PUSHDOWN_SUBFIELDS_FOR_CARDINALITY = "pushdown_subfields_for_cardinality"; + public static final String MAX_SERIALIZABLE_OBJECT_SIZE = "max_serializable_object_size"; + public static final String EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE = "expression_optimizer_in_row_expression_rewrite"; + public static final String TABLE_SCAN_SHUFFLE_PARALLELISM_THRESHOLD = "table_scan_shuffle_parallelism_threshold"; + public static final String TABLE_SCAN_SHUFFLE_STRATEGY = "table_scan_shuffle_strategy"; + public static final String SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION = "skip_pushdown_through_exchange_for_remote_projection"; + public static final String REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM = "remote_function_names_for_fixed_parallelism"; + public static final String REMOTE_FUNCTION_FIXED_PARALLELISM_TASK_COUNT = "remote_function_fixed_parallelism_task_count"; // TODO: Native execution related session properties that are temporarily put 
here. They will be relocated in the future. public static final String NATIVE_AGGREGATION_SPILL_ALL = "native_aggregation_spill_all"; + public static final String NATIVE_MAX_SPLIT_PRELOAD_PER_DRIVER = "native_max_split_preload_per_driver"; public static final String NATIVE_EXECUTION_ENABLED = "native_execution_enabled"; private static final String NATIVE_EXECUTION_EXECUTABLE_PATH = "native_execution_executable_path"; private static final String NATIVE_EXECUTION_PROGRAM_ARGUMENTS = "native_execution_program_arguments"; @@ -343,7 +377,6 @@ public final class SystemSessionProperties public static final String NATIVE_MIN_COLUMNAR_ENCODING_CHANNELS_TO_PREFER_ROW_WISE_ENCODING = "native_min_columnar_encoding_channels_to_prefer_row_wise_encoding"; public static final String NATIVE_ENFORCE_JOIN_BUILD_INPUT_PARTITION = "native_enforce_join_build_input_partition"; public static final String NATIVE_EXECUTION_SCALE_WRITER_THREADS_ENABLED = "native_execution_scale_writer_threads_enabled"; - private static final String NATIVE_EXECUTION_TYPE_REWRITE_ENABLED = "native_execution_type_rewrite_enabled"; private final List> sessionProperties; @@ -380,6 +413,11 @@ public SystemSessionProperties( HistoryBasedOptimizationConfig historyBasedOptimizationConfig) { sessionProperties = ImmutableList.of( + integerProperty( + MAX_PREFIXES_COUNT, + "Maximum number of prefixes (catalog/schema/table scopes used to narrow metadata lookups) that Presto generates when querying information_schema.", + featuresConfig.getMaxPrefixesCount(), + false), stringProperty( EXECUTION_POLICY, "Policy used for scheduling query tasks", @@ -546,7 +584,7 @@ public SystemSessionProperties( Integer.class, taskManagerConfig.getTaskConcurrency(), false, - value -> validateValueIsPowerOfTwo(requireNonNull(value, "value is null"), TASK_CONCURRENCY), + featuresConfig.isNativeExecutionEnabled() ? 
value -> validateIntegerValue(value, TASK_CONCURRENCY, 1, false) : value -> validateValueIsPowerOfTwo(requireNonNull(value, "value is null"), TASK_CONCURRENCY), value -> value), booleanProperty( TASK_SHARE_INDEX_LOADING, @@ -562,6 +600,15 @@ public SystemSessionProperties( false, value -> Duration.valueOf((String) value), Duration::toString), + new PropertyMetadata<>( + QUERY_MAX_QUEUED_TIME, + "Maximum Queued time of a query", + VARCHAR, + Duration.class, + queryManagerConfig.getQueryMaxQueuedTime(), + false, + value -> Duration.valueOf((String) value), + Duration::toString), new PropertyMetadata<>( QUERY_MAX_EXECUTION_TIME, "Maximum execution time of a query", @@ -1110,6 +1157,11 @@ public SystemSessionProperties( "Only List built-in functions in SHOW FUNCTIONS", featuresConfig.isListBuiltInFunctionsOnly(), false), + stringProperty( + NON_BUILT_IN_FUNCTION_NAMESPACES_TO_LIST_FUNCTIONS, + "Comma-separated list of function namespace names from which to list non-built-in functions. Only takes effect when LIST_BUILT_IN_FUNCTIONS_ONLY is false. If empty, functions from all available function namespaces will be listed.", + "", + false), new PropertyMetadata<>( PARTITIONING_PRECISION_STRATEGY, format("The strategy to use to pick when to repartition. Options are %s", @@ -1329,6 +1381,41 @@ public SystemSessionProperties( "Enable query optimization with materialized view", featuresConfig.isQueryOptimizationWithMaterializedViewEnabled(), true), + new PropertyMetadata<>( + LEGACY_MATERIALIZED_VIEWS, + "Experimental: Use legacy materialized views. This feature is under active development and may change " + + "or be removed at any time. Do not disable in production environments. 
" + + "To allow toggling this property via session, set experimental.allow-legacy-materialized-views-toggle=true in config.", + BOOLEAN, + Boolean.class, + featuresConfig.isLegacyMaterializedViews(), + true, + value -> { + if (!featuresConfig.isAllowLegacyMaterializedViewsToggle()) { + throw new PrestoException(INVALID_SESSION_PROPERTY, + "Cannot toggle legacy_materialized_views session property. " + + "Set experimental.allow-legacy-materialized-views-toggle=true in config to allow changing this setting."); + } + return (Boolean) value; + }, + object -> object), + booleanProperty( + MATERIALIZED_VIEW_ALLOW_FULL_REFRESH_ENABLED, + "Allow full refresh of MV when it's empty - potentially high cost.", + featuresConfig.isMaterializedViewAllowFullRefreshEnabled(), + true), + new PropertyMetadata<>( + MATERIALIZED_VIEW_STALE_READ_BEHAVIOR, + format("Default behavior when reading from a stale materialized view. Valid values: %s", + Stream.of(MaterializedViewStaleReadBehavior.values()) + .map(MaterializedViewStaleReadBehavior::name) + .collect(joining(","))), + VARCHAR, + MaterializedViewStaleReadBehavior.class, + featuresConfig.getMaterializedViewStaleReadBehavior(), + false, + value -> MaterializedViewStaleReadBehavior.valueOf(((String) value).toUpperCase()), + MaterializedViewStaleReadBehavior::name), stringProperty( DISTRIBUTED_TRACING_MODE, "Mode for distributed tracing. NO_TRACE, ALWAYS_TRACE, or SAMPLE_BASED", @@ -1370,6 +1457,16 @@ public SystemSessionProperties( "To make it work, the connector needs to guarantee and expose the data properties of the underlying table.", featuresConfig.isPreferMergeJoinForSortedInputs(), true), + booleanProperty( + PREFER_SORT_MERGE_JOIN, + "Prefer sort merge join for all joins. 
A SortNode is added if input is not already sorted.", + featuresConfig.isPreferSortMergeJoin(), + true), + booleanProperty( + SORTED_EXCHANGE_ENABLED, + "(Experimental) Enable pushing sort operations down to exchange nodes for distributed queries", + featuresConfig.isSortedExchangeEnabled(), + false), booleanProperty( SEGMENTED_AGGREGATION_ENABLED, "Enable segmented aggregation.", @@ -1399,6 +1496,11 @@ public SystemSessionProperties( false, value -> ResourceAwareSchedulingStrategy.valueOf(((String) value).toUpperCase()), ResourceAwareSchedulingStrategy::name), + booleanProperty( + SCHEDULE_SPLITS_BASED_ON_TASK_LOAD, + "Schedule splits based on task load, rather than on the node load.", + nodeSchedulerConfig.isScheduleSplitsBasedOnTaskLoad(), + false), stringProperty( ANALYZER_TYPE, "Analyzer type to use.", @@ -1488,6 +1590,11 @@ public SystemSessionProperties( "Enable history based optimization only for complex queries, i.e. queries with join and aggregation", true, false), + booleanProperty( + HISTORY_BASED_OPTIMIZER_ESTIMATE_SIZE_USING_VARIABLES, + "Estimate the size of the plan node output with variable statistics for HBO", + featuresConfig.isHistoryBasedOptimizerEstimateSizeUsingVariables(), + false), new PropertyMetadata<>( HISTORY_INPUT_TABLE_STATISTICS_MATCHING_THRESHOLD, "When the size difference between current table and history table exceed this threshold, do not match history statistics", @@ -1520,6 +1627,18 @@ public SystemSessionProperties( "Enforce timeout for query registration in HBO optimizer", featuresConfig.isEnforceTimeoutForHBOQueryRegistration(), false), + new PropertyMetadata<>( + QUERY_TYPES_ENABLED_FOR_HISTORY_BASED_OPTIMIZATION, + format("Query types which are enabled for history based optimization. Specify as a comma-separated string of QueryType values. 
Allowed options: %s", + Stream.of(QueryType.values()) + .map(QueryType::name) + .collect(joining(","))), + VARCHAR, + (Class>) (Class) List.class, + featuresConfig.getQueryTypesEnabledForHbo(), + false, + value -> parseQueryTypesFromString((String) value), + queryTypes -> ((List) queryTypes).stream().map(QueryType::name).collect(joining(","))), new PropertyMetadata<>( MAX_LEAF_NODES_IN_PLAN, "Maximum number of leaf nodes in the logical plan of SQL statement", @@ -1544,6 +1663,11 @@ public SystemSessionProperties( "Combine individual approx_percentile calls on individual field to evaluation on an array", featuresConfig.isOptimizeMultipleApproxPercentileOnSameFieldEnabled(), false), + booleanProperty( + OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, + "Combine individual approx_distinct calls on expressions of the same type using set_agg", + featuresConfig.isOptimizeMultipleApproxDistinctOnSameTypeEnabled(), + false), booleanProperty( NATIVE_AGGREGATION_SPILL_ALL, "Native Execution only. If true and spilling has been triggered during the input " + @@ -1552,6 +1676,11 @@ public SystemSessionProperties( "output processing stage.", true, false), + integerProperty( + NATIVE_MAX_SPLIT_PRELOAD_PER_DRIVER, + "Native Execution only. Maximum number of splits to preload per driver. Set to 0 to disable preloading.", + 0, + false), booleanProperty( NATIVE_EXECUTION_ENABLED, "Enable execution on native engine", @@ -1589,6 +1718,18 @@ public SystemSessionProperties( "Enable randomizing null join key for outer join when ratio of null join keys exceeds the threshold", 0.02, false), + new PropertyMetadata<>( + RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY, + format("When to apply randomization to source join key in semi joins to mitigate null skew. 
Value must be one of: %s", + Stream.of(RandomizeNullSourceKeyInSemiJoinStrategy.values()) + .map(RandomizeNullSourceKeyInSemiJoinStrategy::name) + .collect(joining(","))), + VARCHAR, + RandomizeNullSourceKeyInSemiJoinStrategy.class, + featuresConfig.getRandomizeNullSourceKeyInSemiJoinStrategy(), + false, + value -> RandomizeNullSourceKeyInSemiJoinStrategy.valueOf(((String) value).toUpperCase()), + RandomizeNullSourceKeyInSemiJoinStrategy::name), new PropertyMetadata<>( SHARDED_JOINS_STRATEGY, format("When to shard joins to mitigate skew. Value must be one of: %s", @@ -1825,15 +1966,15 @@ public SystemSessionProperties( new PropertyMetadata<>( DEFAULT_VIEW_SECURITY_MODE, format("Set default view security mode. Options are: %s", - Stream.of(CreateView.Security.values()) - .map(CreateView.Security::name) + Stream.of(ViewSecurity.values()) + .map(ViewSecurity::name) .collect(joining(","))), VARCHAR, - CreateView.Security.class, + ViewSecurity.class, featuresConfig.getDefaultViewSecurityMode(), false, - value -> CreateView.Security.valueOf(((String) value).toUpperCase()), - CreateView.Security::name), + value -> ViewSecurity.valueOf(((String) value).toUpperCase()), + ViewSecurity::name), booleanProperty( JOIN_PREFILTER_BUILD_SIDE, "Prefiltering the build/inner side of a join with keys from the other side", @@ -1855,6 +1996,10 @@ public SystemSessionProperties( "Include values node for connector optimizer", featuresConfig.isIncludeValuesNodeInConnectorOptimizer(), false), + booleanProperty(ENABLE_EMPTY_CONNECTOR_OPTIMIZER, + "Run optimizers which optimize queries with values node", + false, + false), booleanProperty( INNER_JOIN_PUSHDOWN_ENABLED, "Enable Join Predicate Pushdown", @@ -1864,7 +2009,7 @@ public SystemSessionProperties( INEQUALITY_JOIN_PUSHDOWN_ENABLED, "Enable Join Pushdown for Inequality Predicates", featuresConfig.isInEqualityJoinPushdownEnabled(), - false), + false), integerProperty( NATIVE_MIN_COLUMNAR_ENCODING_CHANNELS_TO_PREFER_ROW_WISE_ENCODING, 
"Minimum number of columnar encoding channels to consider row wise encoding for partitioned exchange. Native execution only", @@ -1875,23 +2020,82 @@ public SystemSessionProperties( "Enable single node execution", featuresConfig.isSingleNodeExecutionEnabled(), false), + booleanProperty( + REWRITE_MIN_MAX_BY_TO_TOP_N, + "rewrite min_by/max_by to top n", + featuresConfig.isRewriteMinMaxByToTopNEnabled(), + false), booleanProperty(NATIVE_EXECUTION_SCALE_WRITER_THREADS_ENABLED, "Enable automatic scaling of writer threads", featuresConfig.isNativeExecutionScaleWritersThreadsEnabled(), !featuresConfig.isNativeExecutionEnabled()), - booleanProperty(NATIVE_EXECUTION_TYPE_REWRITE_ENABLED, - "Enable type rewrite for native execution", - featuresConfig.isNativeExecutionTypeRewriteEnabled(), - !featuresConfig.isNativeExecutionEnabled()), stringProperty( EXPRESSION_OPTIMIZER_NAME, "Configure which expression optimizer to use", featuresConfig.getExpressionOptimizerName(), false), + stringProperty( + EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE, + "Expression optimizer used in row expression rewrite, empty means no rewrite", + featuresConfig.getExpressionOptimizerUsedInRowExpressionRewrite(), + false), + booleanProperty(BROADCAST_SEMI_JOIN_FOR_DELETE, + "Enforce broadcast join for semi join in delete", + featuresConfig.isBroadcastSemiJoinForDelete(), + false), booleanProperty(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, "Enable adding an exchange below partial aggregation over a GroupId node to improve partial aggregation performance", featuresConfig.getAddExchangeBelowPartialAggregationOverGroupId(), false), + booleanProperty( + OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, + "Optimize out APPROX_DISTINCT operations over constant conditionals", + featuresConfig.isOptimizeConditionalApproxDistinct(), + false), + booleanProperty(PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS, + "Enable subfield pruning for map functions, currently include map_subset and map_filter", + 
featuresConfig.isPushdownSubfieldForMapFunctions(), + false), + booleanProperty(PUSHDOWN_SUBFIELDS_FOR_CARDINALITY, + "Enable subfield pruning for cardinality() function to skip reading keys and values", + featuresConfig.isPushdownSubfieldForCardinality(), + false), + longProperty(MAX_SERIALIZABLE_OBJECT_SIZE, + "Configure the maximum byte size of a serializable object in expression interpreters", + featuresConfig.getMaxSerializableObjectSize(), + false), + doubleProperty( + TABLE_SCAN_SHUFFLE_PARALLELISM_THRESHOLD, + "Parallelism threshold for adding a shuffle above table scan. When the table's parallelism factor is below this threshold (0.0-1.0) and TABLE_SCAN_SHUFFLE_STRATEGY is COST_BASED, a round-robin shuffle exchange is added above the table scan to redistribute data", + featuresConfig.getTableScanShuffleParallelismThreshold(), + false), + new PropertyMetadata<>( + TABLE_SCAN_SHUFFLE_STRATEGY, + format("Strategy for adding shuffle above table scan to redistribute data. Options are %s", + Stream.of(ShuffleForTableScanStrategy.values()) + .map(ShuffleForTableScanStrategy::name) + .collect(joining(","))), + VARCHAR, + ShuffleForTableScanStrategy.class, + featuresConfig.getTableScanShuffleStrategy(), + false, + value -> ShuffleForTableScanStrategy.valueOf(((String) value).toUpperCase()), + ShuffleForTableScanStrategy::name), + booleanProperty( + SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, + "Skip pushing down remote projection through exchange", + featuresConfig.isSkipPushdownThroughExchangeForRemoteProjection(), + false), + stringProperty( + REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM, + "Regex pattern to match remote function names that should use fixed parallelism", + featuresConfig.getRemoteFunctionNamesForFixedParallelism(), + false), + integerProperty( + REMOTE_FUNCTION_FIXED_PARALLELISM_TASK_COUNT, + "Number of tasks to use for remote functions matching the fixed parallelism pattern. 
If not set, the default hash partition count will be used.", + featuresConfig.getRemoteFunctionFixedParallelismTaskCount(), + false), new PropertyMetadata<>( QUERY_CLIENT_TIMEOUT, "Configures how long the query runs without contact from the client application, such as the CLI, before it's abandoned", @@ -1900,7 +2104,20 @@ public SystemSessionProperties( queryManagerConfig.getClientTimeout(), false, value -> Duration.valueOf((String) value), - Duration::toString)); + Duration::toString), + booleanProperty(UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING, + "Utilize the unique property of input columns in query planning", + featuresConfig.isUtilizeUniquePropertyInQueryPlanning(), + false), + booleanProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, + "Add distinct aggregation below semi join build", + featuresConfig.isAddDistinctBelowSemiJoinBuild(), + false)); + } + + public static int getMaxPrefixesCount(Session session) + { + return session.getSystemProperty(MAX_PREFIXES_COUNT, Integer.class); } public static boolean isSpoolingOutputBufferEnabled(Session session) @@ -2150,6 +2367,11 @@ public static Duration getQueryMaxRunTime(Session session) return session.getSystemProperty(QUERY_MAX_RUN_TIME, Duration.class); } + public static Duration getQueryMaxQueuedTime(Session session) + { + return session.getSystemProperty(QUERY_MAX_QUEUED_TIME, Duration.class); + } + public static Duration getQueryMaxExecutionTime(Session session) { return session.getSystemProperty(QUERY_MAX_EXECUTION_TIME, Duration.class); @@ -2219,7 +2441,11 @@ public static OptionalInt getConcurrentLifespansPerNode(Session session) return OptionalInt.empty(); } else { - checkArgument(result > 0, "Concurrent lifespans per node must be positive if set to non-zero"); + if (result < 0) { + throw new PrestoException( + INVALID_SESSION_PROPERTY, + format("Concurrent lifespans per node must be positive if set to non-zero. 
Found: %s", result)); + } return OptionalInt.of(result); } } @@ -2232,7 +2458,11 @@ public static int getInitialSplitsPerNode(Session session) public static int getQueryPriority(Session session) { Integer priority = session.getSystemProperty(QUERY_PRIORITY, Integer.class); - checkArgument(priority > 0, "Query priority must be positive"); + if (priority <= 0) { + throw new PrestoException( + INVALID_SESSION_PROPERTY, + format("Query priority must be greater than zero. Found: %s", priority)); + } return priority; } @@ -2352,6 +2582,11 @@ public static boolean isSingleNodeExecutionEnabled(Session session) return session.getSystemProperty(SINGLE_NODE_EXECUTION_ENABLED, Boolean.class); } + public static boolean isRewriteMinMaxByToTopNEnabled(Session session) + { + return session.getSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, Boolean.class); + } + public static boolean isPushAggregationThroughJoin(Session session) { return session.getSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, Boolean.class); @@ -2610,6 +2845,11 @@ public static boolean isListBuiltInFunctionsOnly(Session session) return session.getSystemProperty(LIST_BUILT_IN_FUNCTIONS_ONLY, Boolean.class); } + public static Set getNonBuiltInFunctionNamespacesToListFunctions(Session session) + { + return Splitter.on(",").trimResults().splitToList(session.getSystemProperty(NON_BUILT_IN_FUNCTION_NAMESPACES_TO_LIST_FUNCTIONS, String.class)).stream().filter(x -> !x.isEmpty()).collect(toImmutableSet()); + } + public static boolean isExactPartitioningPreferred(Session session) { return session.getSystemProperty(PARTITIONING_PRECISION_STRATEGY, PartitioningPrecisionStrategy.class) @@ -2779,6 +3019,21 @@ public static boolean isQueryOptimizationWithMaterializedViewEnabled(Session ses return session.getSystemProperty(QUERY_OPTIMIZATION_WITH_MATERIALIZED_VIEW_ENABLED, Boolean.class); } + public static boolean isLegacyMaterializedViews(Session session) + { + return session.getSystemProperty(LEGACY_MATERIALIZED_VIEWS, 
Boolean.class); + } + + public static boolean isMaterializedViewAllowFullRefreshEnabled(Session session) + { + return session.getSystemProperty(MATERIALIZED_VIEW_ALLOW_FULL_REFRESH_ENABLED, Boolean.class); + } + + public static MaterializedViewStaleReadBehavior getMaterializedViewStaleReadBehavior(Session session) + { + return session.getSystemProperty(MATERIALIZED_VIEW_STALE_READ_BEHAVIOR, MaterializedViewStaleReadBehavior.class); + } + public static boolean isVerboseRuntimeStatsEnabled(Session session) { return session.getSystemProperty(VERBOSE_RUNTIME_STATS_ENABLED, Boolean.class); @@ -2824,6 +3079,16 @@ public static boolean preferMergeJoinForSortedInputs(Session session) return session.getSystemProperty(PREFER_MERGE_JOIN_FOR_SORTED_INPUTS, Boolean.class); } + public static boolean preferSortMergeJoin(Session session) + { + return session.getSystemProperty(PREFER_SORT_MERGE_JOIN, Boolean.class); + } + + public static boolean isSortedExchangeEnabled(Session session) + { + return session.getSystemProperty(SORTED_EXCHANGE_ENABLED, Boolean.class); + } + public static boolean isSegmentedAggregationEnabled(Session session) { return session.getSystemProperty(SEGMENTED_AGGREGATION_ENABLED, Boolean.class); @@ -2834,6 +3099,11 @@ public static boolean isCombineApproxPercentileEnabled(Session session) return session.getSystemProperty(OPTIMIZE_MULTIPLE_APPROX_PERCENTILE_ON_SAME_FIELD, Boolean.class); } + public static boolean isCombineApproxDistinctEnabled(Session session) + { + return session.getSystemProperty(OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, Boolean.class); + } + public static AggregationIfToFilterRewriteStrategy getAggregationIfToFilterRewriteStrategy(Session session) { return session.getSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, AggregationIfToFilterRewriteStrategy.class); @@ -2844,6 +3114,11 @@ public static ResourceAwareSchedulingStrategy getResourceAwareSchedulingStrategy return 
session.getSystemProperty(RESOURCE_AWARE_SCHEDULING_STRATEGY, ResourceAwareSchedulingStrategy.class); } + public static Boolean isScheduleSplitsBasedOnTaskLoad(Session session) + { + return session.getSystemProperty(SCHEDULE_SPLITS_BASED_ON_TASK_LOAD, Boolean.class); + } + public static String getAnalyzerType(Session session) { return session.getSystemProperty(ANALYZER_TYPE, String.class); @@ -2914,6 +3189,11 @@ public static boolean restrictHistoryBasedOptimizationToComplexQuery(Session ses return session.getSystemProperty(RESTRICT_HISTORY_BASED_OPTIMIZATION_TO_COMPLEX_QUERY, Boolean.class); } + public static boolean estimateSizeUsingVariablesForHBO(Session session) + { + return session.getSystemProperty(HISTORY_BASED_OPTIMIZER_ESTIMATE_SIZE_USING_VARIABLES, Boolean.class); + } + public static double getHistoryInputTableStatisticsMatchingThreshold(Session session) { return session.getSystemProperty(HISTORY_INPUT_TABLE_STATISTICS_MATCHING_THRESHOLD, Double.class); @@ -2933,6 +3213,11 @@ public static List getHistoryOptimizationPlanCanon return strategyList; } + public static List getQueryTypesEnabledForHBO(Session session) + { + return (List) session.getSystemProperty(QUERY_TYPES_ENABLED_FOR_HISTORY_BASED_OPTIMIZATION, List.class); + } + public static boolean enableVerboseHistoryBasedOptimizerRuntimeStats(Session session) { return session.getSystemProperty(ENABLE_VERBOSE_HISTORY_BASED_OPTIMIZER_RUNTIME_STATS, Boolean.class); @@ -2977,6 +3262,11 @@ public static double getRandomizeOuterJoinNullKeyNullRatioThreshold(Session sess return session.getSystemProperty(RANDOMIZE_OUTER_JOIN_NULL_KEY_NULL_RATIO_THRESHOLD, Double.class); } + public static RandomizeNullSourceKeyInSemiJoinStrategy getRandomizeNullSourceKeyInSemiJoinStrategy(Session session) + { + return session.getSystemProperty(RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY, RandomizeNullSourceKeyInSemiJoinStrategy.class); + } + public static ShardedJoinStrategy getShardedJoinStrategy(Session session) { return 
session.getSystemProperty(SHARDED_JOINS_STRATEGY, ShardedJoinStrategy.class); @@ -3092,7 +3382,7 @@ public static boolean isPullExpressionFromLambdaEnabled(Session session) return session.getSystemProperty(PULL_EXPRESSION_FROM_LAMBDA_ENABLED, Boolean.class); } - public static boolean isRwriteConstantArrayContainsToInExpressionEnabled(Session session) + public static boolean isRewriteConstantArrayContainsToInExpressionEnabled(Session session) { return session.getSystemProperty(REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION, Boolean.class); } @@ -3152,9 +3442,9 @@ public static boolean isEagerPlanValidationEnabled(Session session) return session.getSystemProperty(EAGER_PLAN_VALIDATION_ENABLED, Boolean.class); } - public static CreateView.Security getDefaultViewSecurityMode(Session session) + public static ViewSecurity getDefaultViewSecurityMode(Session session) { - return session.getSystemProperty(DEFAULT_VIEW_SECURITY_MODE, CreateView.Security.class); + return session.getSystemProperty(DEFAULT_VIEW_SECURITY_MODE, ViewSecurity.class); } public static boolean isJoinPrefilterEnabled(Session session) @@ -3192,6 +3482,11 @@ public static boolean isIncludeValuesNodeInConnectorOptimizer(Session session) return session.getSystemProperty(INCLUDE_VALUES_NODE_IN_CONNECTOR_OPTIMIZER, Boolean.class); } + public static boolean isEmptyConnectorOptimizerEnabled(Session session) + { + return session.getSystemProperty(ENABLE_EMPTY_CONNECTOR_OPTIMIZER, Boolean.class); + } + public static Boolean isInnerJoinPushdownEnabled(Session session) { return session.getSystemProperty(INNER_JOIN_PUSHDOWN_ENABLED, Boolean.class); @@ -3212,9 +3507,9 @@ public static boolean isNativeExecutionScaleWritersThreadsEnabled(Session sessio return session.getSystemProperty(NATIVE_EXECUTION_SCALE_WRITER_THREADS_ENABLED, Boolean.class); } - public static boolean isNativeExecutionTypeRewriteEnabled(Session session) + public static int getMaxSplitPreloadPerDriver(Session session) { - return 
session.getSystemProperty(NATIVE_EXECUTION_TYPE_REWRITE_ENABLED, Boolean.class); + return session.getSystemProperty(NATIVE_MAX_SPLIT_PRELOAD_PER_DRIVER, Integer.class); } public static String getExpressionOptimizerName(Session session) @@ -3222,11 +3517,41 @@ public static String getExpressionOptimizerName(Session session) return session.getSystemProperty(EXPRESSION_OPTIMIZER_NAME, String.class); } + public static String getExpressionOptimizerInRowExpressionRewrite(Session session) + { + return session.getSystemProperty(EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE, String.class); + } + + public static boolean isBroadcastSemiJoinForDeleteEnabled(Session session) + { + return session.getSystemProperty(BROADCAST_SEMI_JOIN_FOR_DELETE, Boolean.class); + } + public static boolean isEnabledAddExchangeBelowGroupId(Session session) { return session.getSystemProperty(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, Boolean.class); } + public static boolean isPushSubfieldsForMapFunctionsEnabled(Session session) + { + return session.getSystemProperty(PUSHDOWN_SUBFIELDS_FOR_MAP_FUNCTIONS, Boolean.class); + } + + public static boolean isPushSubfieldsForCardinalityEnabled(Session session) + { + return session.getSystemProperty(PUSHDOWN_SUBFIELDS_FOR_CARDINALITY, Boolean.class); + } + + public static boolean isUtilizeUniquePropertyInQueryPlanningEnabled(Session session) + { + return session.getSystemProperty(UTILIZE_UNIQUE_PROPERTY_IN_QUERY_PLANNING, Boolean.class); + } + + public static boolean isAddDistinctBelowSemiJoinBuildEnabled(Session session) + { + return session.getSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, Boolean.class); + } + public static boolean isCanonicalizedJsonExtract(Session session) { return session.getSystemProperty(CANONICALIZED_JSON_EXTRACT, Boolean.class); @@ -3236,4 +3561,39 @@ public static Duration getQueryClientTimeout(Session session) { return session.getSystemProperty(QUERY_CLIENT_TIMEOUT, Duration.class); } + + public static boolean 
isOptimizeConditionalApproxDistinctEnabled(Session session) + { + return session.getSystemProperty(OPTIMIZE_CONDITIONAL_CONSTANT_APPROXIMATE_DISTINCT, Boolean.class); + } + + public static long getMaxSerializableObjectSize(Session session) + { + return session.getSystemProperty(MAX_SERIALIZABLE_OBJECT_SIZE, Long.class); + } + + public static double getTableScanShuffleParallelismThreshold(Session session) + { + return session.getSystemProperty(TABLE_SCAN_SHUFFLE_PARALLELISM_THRESHOLD, Double.class); + } + + public static ShuffleForTableScanStrategy getTableScanShuffleStrategy(Session session) + { + return session.getSystemProperty(TABLE_SCAN_SHUFFLE_STRATEGY, ShuffleForTableScanStrategy.class); + } + + public static boolean isSkipPushdownThroughExchangeForRemoteProjection(Session session) + { + return session.getSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, Boolean.class); + } + + public static String getRemoteFunctionNamesForFixedParallelism(Session session) + { + return session.getSystemProperty(REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM, String.class); + } + + public static int getRemoteFunctionFixedParallelismTaskCount(Session session) + { + return session.getSystemProperty(REMOTE_FUNCTION_FIXED_PARALLELISM_TASK_COUNT, Integer.class); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/block/BlockJsonSerde.java b/presto-main-base/src/main/java/com/facebook/presto/block/BlockJsonSerde.java index 1f86629fad6da..b0e38237c213c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/block/BlockJsonSerde.java +++ b/presto-main-base/src/main/java/com/facebook/presto/block/BlockJsonSerde.java @@ -26,8 +26,7 @@ import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.IOException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/catalogserver/CatalogServer.java 
b/presto-main-base/src/main/java/com/facebook/presto/catalogserver/CatalogServer.java index b4d5ae5759ac8..0dd4f8137ce82 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/catalogserver/CatalogServer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/catalogserver/CatalogServer.java @@ -29,8 +29,7 @@ import com.facebook.presto.transaction.TransactionManager; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/catalogserver/RandomCatalogServerAddressSelector.java b/presto-main-base/src/main/java/com/facebook/presto/catalogserver/RandomCatalogServerAddressSelector.java index f98d7a0a84e54..9dd1f0ae757e1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/catalogserver/RandomCatalogServerAddressSelector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/catalogserver/RandomCatalogServerAddressSelector.java @@ -18,8 +18,7 @@ import com.facebook.presto.spi.HostAddress; import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HostAndPort; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/catalogserver/RemoteMetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/catalogserver/RemoteMetadataManager.java index b888ea547c736..6a7d166597061 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/catalogserver/RemoteMetadataManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/catalogserver/RemoteMetadataManager.java @@ -33,8 +33,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; 
+import jakarta.inject.Inject; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorCodecManager.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorCodecManager.java new file mode 100644 index 0000000000000..37eb7220007bd --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorCodecManager.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.connector; + +import com.facebook.drift.codec.ThriftCodecManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; +import com.facebook.presto.spi.ConnectorOutputTableHandle; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.thrift.RemoteCodecProvider; +import com.google.inject.Provider; + +import javax.inject.Inject; + +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; +import static java.util.Objects.requireNonNull; + +public class ConnectorCodecManager +{ + private final Map connectorCodecProviders = new ConcurrentHashMap<>(); + + @Inject + public ConnectorCodecManager(Provider thriftCodecManagerProvider) + { + requireNonNull(thriftCodecManagerProvider, "thriftCodecManager is null"); + + connectorCodecProviders.put(REMOTE_CONNECTOR_ID.toString(), new RemoteCodecProvider(thriftCodecManagerProvider)); + } + + public void addConnectorCodecProvider(ConnectorId connectorId, ConnectorCodecProvider connectorCodecProvider) + { + requireNonNull(connectorId, "connectorId is null"); + requireNonNull(connectorCodecProvider, "connectorThriftCodecProvider is null"); + connectorCodecProviders.put(connectorId.getCatalogName(), connectorCodecProvider); + } + + public Optional> getConnectorSplitCodec(String connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + return 
Optional.ofNullable(connectorCodecProviders.get(connectorId)).flatMap(ConnectorCodecProvider::getConnectorSplitCodec); + } + + public Optional> getTransactionHandleCodec(String connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + return Optional.ofNullable(connectorCodecProviders.get(connectorId)).flatMap(ConnectorCodecProvider::getConnectorTransactionHandleCodec); + } + + public Optional> getOutputTableHandleCodec(String connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + return Optional.ofNullable(connectorCodecProviders.get(connectorId)).flatMap(ConnectorCodecProvider::getConnectorOutputTableHandleCodec); + } + + public Optional> getInsertTableHandleCodec(String connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + return Optional.ofNullable(connectorCodecProviders.get(connectorId)).flatMap(ConnectorCodecProvider::getConnectorInsertTableHandleCodec); + } + + public Optional> getDeleteTableHandleCodec(String connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + return Optional.ofNullable(connectorCodecProviders.get(connectorId)).flatMap(ConnectorCodecProvider::getConnectorDeleteTableHandleCodec); + } + + public Optional> getMergeTableHandleCodec(String connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + return Optional.ofNullable(connectorCodecProviders.get(connectorId)).flatMap(ConnectorCodecProvider::getConnectorMergeTableHandleCodec); + } + + public Optional> getTableLayoutHandleCodec(String connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + return Optional.ofNullable(connectorCodecProviders.get(connectorId)).flatMap(ConnectorCodecProvider::getConnectorTableLayoutHandleCodec); + } + + public Optional> getTableHandleCodec(String connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + return Optional.ofNullable(connectorCodecProviders.get(connectorId)).flatMap(ConnectorCodecProvider::getConnectorTableHandleCodec); + } +} diff 
--git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java index c36ff87b0460b..8182e56359c9d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorContextInstance.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.FilterStatsCalculatorService; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpressionService; import static java.util.Objects.requireNonNull; @@ -32,6 +33,7 @@ public class ConnectorContextInstance { private final NodeManager nodeManager; private final TypeManager typeManager; + private final ProcedureRegistry procedureRegistry; private final FunctionMetadataManager functionMetadataManager; private final StandardFunctionResolution functionResolution; private final PageSorter pageSorter; @@ -44,6 +46,7 @@ public class ConnectorContextInstance public ConnectorContextInstance( NodeManager nodeManager, TypeManager typeManager, + ProcedureRegistry procedureRegistry, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution functionResolution, PageSorter pageSorter, @@ -55,6 +58,7 @@ public ConnectorContextInstance( { this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.pageSorter = requireNonNull(pageSorter, 
"pageSorter is null"); @@ -77,6 +81,12 @@ public TypeManager getTypeManager() return typeManager; } + @Override + public ProcedureRegistry getProcedureRegistry() + { + return procedureRegistry; + } + @Override public FunctionMetadataManager getFunctionMetadataManager() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java index 56f3e15d72aa6..64997f954ca75 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorManager.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.block.BlockEncodingSerde; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.connector.informationSchema.InformationSchemaConnector; @@ -28,7 +29,6 @@ import com.facebook.presto.index.IndexManager; import com.facebook.presto.metadata.Catalog; import com.facebook.presto.metadata.CatalogManager; -import com.facebook.presto.metadata.ConnectorMetadataUpdaterManager; import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.metadata.MetadataManager; @@ -41,18 +41,23 @@ import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.Connector; import com.facebook.presto.spi.connector.ConnectorAccessControl; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; import com.facebook.presto.spi.connector.ConnectorContext; import com.facebook.presto.spi.connector.ConnectorFactory; import com.facebook.presto.spi.connector.ConnectorIndexProvider; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdaterProvider; import 
com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; -import com.facebook.presto.spi.connector.ConnectorTypeSerdeProvider; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; @@ -71,11 +76,10 @@ import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; - -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -84,11 +88,14 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import static com.facebook.presto.metadata.FunctionExtractor.extractFunctions; import static 
com.facebook.presto.spi.ConnectorId.createInformationSchemaConnectorId; import static com.facebook.presto.spi.ConnectorId.createSystemTablesConnectorId; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -105,13 +112,12 @@ public class ConnectorManager private final IndexManager indexManager; private final PartitioningProviderManager partitioningProviderManager; private final ConnectorPlanOptimizerManager connectorPlanOptimizerManager; - private final ConnectorMetadataUpdaterManager connectorMetadataUpdaterManager; - private final ConnectorTypeSerdeManager connectorTypeSerdeManager; private final PageSinkManager pageSinkManager; private final HandleResolver handleResolver; private final InternalNodeManager nodeManager; private final TypeManager typeManager; + private final ProcedureRegistry procedureRegistry; private final PageSorter pageSorter; private final PageIndexerFactory pageIndexerFactory; private final NodeInfo nodeInfo; @@ -123,6 +129,7 @@ public class ConnectorManager private final FilterStatsCalculator filterStatsCalculator; private final BlockEncodingSerde blockEncodingSerde; private final ConnectorSystemConfig connectorSystemConfig; + private final ConnectorCodecManager connectorCodecManager; @GuardedBy("this") private final ConcurrentMap connectorFactories = new ConcurrentHashMap<>(); @@ -142,13 +149,12 @@ public ConnectorManager( IndexManager indexManager, PartitioningProviderManager partitioningProviderManager, ConnectorPlanOptimizerManager connectorPlanOptimizerManager, - ConnectorMetadataUpdaterManager connectorMetadataUpdaterManager, - ConnectorTypeSerdeManager connectorTypeSerdeManager, PageSinkManager pageSinkManager, HandleResolver handleResolver, InternalNodeManager nodeManager, NodeInfo nodeInfo, 
TypeManager typeManager, + ProcedureRegistry procedureRegistry, PageSorter pageSorter, PageIndexerFactory pageIndexerFactory, TransactionManager transactionManager, @@ -158,7 +164,8 @@ public ConnectorManager( DeterminismEvaluator determinismEvaluator, FilterStatsCalculator filterStatsCalculator, BlockEncodingSerde blockEncodingSerde, - FeaturesConfig featuresConfig) + FeaturesConfig featuresConfig, + ConnectorCodecManager connectorCodecManager) { this.metadataManager = requireNonNull(metadataManager, "metadataManager is null"); this.catalogManager = requireNonNull(catalogManager, "catalogManager is null"); @@ -168,12 +175,11 @@ public ConnectorManager( this.indexManager = requireNonNull(indexManager, "indexManager is null"); this.partitioningProviderManager = requireNonNull(partitioningProviderManager, "partitioningProviderManager is null"); this.connectorPlanOptimizerManager = requireNonNull(connectorPlanOptimizerManager, "connectorPlanOptimizerManager is null"); - this.connectorMetadataUpdaterManager = requireNonNull(connectorMetadataUpdaterManager, "connectorMetadataUpdaterManager is null"); - this.connectorTypeSerdeManager = requireNonNull(connectorTypeSerdeManager, "connectorMetadataUpdateHandleSerdeManager is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.typeManager = requireNonNull(typeManager, "typeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.pageSorter = requireNonNull(pageSorter, "pageSorter is null"); this.pageIndexerFactory = requireNonNull(pageIndexerFactory, "pageIndexerFactory is null"); this.nodeInfo = requireNonNull(nodeInfo, "nodeInfo is null"); @@ -185,6 +191,7 @@ public ConnectorManager( this.filterStatsCalculator = requireNonNull(filterStatsCalculator, "filterStatsCalculator is 
null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.connectorSystemConfig = () -> featuresConfig.isNativeExecutionEnabled(); + this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); } @PreDestroy @@ -211,6 +218,12 @@ public synchronized void addConnectorFactory(ConnectorFactory connectorFactory) ConnectorFactory existingConnectorFactory = connectorFactories.putIfAbsent(connectorFactory.getName(), connectorFactory); checkArgument(existingConnectorFactory == null, "Connector %s is already registered", connectorFactory.getName()); handleResolver.addConnectorName(connectorFactory.getName(), connectorFactory.getHandleResolver()); + connectorFactory.getTableFunctionHandleResolver().ifPresent(resolver -> { + handleResolver.addTableFunctionNamespace(connectorFactory.getName(), resolver); + }); + connectorFactory.getTableFunctionSplitResolver().ifPresent(resolver -> { + handleResolver.addTableFunctionSplitNamespace(connectorFactory.getName(), resolver); + }); } public synchronized ConnectorId createConnection(String catalogName, String connectorName, Map properties) @@ -218,10 +231,10 @@ public synchronized ConnectorId createConnection(String catalogName, String conn requireNonNull(connectorName, "connectorName is null"); ConnectorFactory connectorFactory = connectorFactories.get(connectorName); checkArgument(connectorFactory != null, "No factory for connector %s", connectorName); - return createConnection(catalogName, connectorFactory, properties); + return createConnection(catalogName, connectorFactory, properties, connectorName); } - private synchronized ConnectorId createConnection(String catalogName, ConnectorFactory connectorFactory, Map properties) + private synchronized ConnectorId createConnection(String catalogName, ConnectorFactory connectorFactory, Map properties, String connectorName) { checkState(!stopped.get(), "ConnectorManager is stopped"); 
requireNonNull(catalogName, "catalogName is null"); @@ -232,12 +245,12 @@ private synchronized ConnectorId createConnection(String catalogName, ConnectorF ConnectorId connectorId = new ConnectorId(catalogName); checkState(!connectors.containsKey(connectorId), "A connector %s already exists", connectorId); - addCatalogConnector(catalogName, connectorId, connectorFactory, properties); + addCatalogConnector(catalogName, connectorId, connectorFactory, properties, connectorName); return connectorId; } - private synchronized void addCatalogConnector(String catalogName, ConnectorId connectorId, ConnectorFactory factory, Map properties) + private synchronized void addCatalogConnector(String catalogName, ConnectorId connectorId, ConnectorFactory factory, Map properties, String connectorName) { // create all connectors before adding, so a broken connector does not leave the system half updated MaterializedConnector connector = new MaterializedConnector(connectorId, createConnector(connectorId, factory, properties)); @@ -272,7 +285,8 @@ private synchronized void addCatalogConnector(String catalogName, ConnectorId co informationSchemaConnector.getConnectorId(), informationSchemaConnector.getConnector(), systemConnector.getConnectorId(), - systemConnector.getConnector()); + systemConnector.getConnector(), + connectorName); try { addConnectorInternal(connector); @@ -312,25 +326,25 @@ private synchronized void addConnectorInternal(MaterializedConnector connector) connector.getPlanOptimizerProvider() .ifPresent(planOptimizerProvider -> connectorPlanOptimizerManager.addPlanOptimizerProvider(connectorId, planOptimizerProvider)); } - - connector.getMetadataUpdaterProvider() - .ifPresent(metadataUpdaterProvider -> connectorMetadataUpdaterManager.addMetadataUpdaterProvider(connectorId, metadataUpdaterProvider)); - - connector.getConnectorTypeSerdeProvider() - .ifPresent( - connectorTypeSerdeProvider -> - connectorTypeSerdeManager.addConnectorTypeSerdeProvider(connectorId, 
connectorTypeSerdeProvider)); - - metadataManager.getProcedureRegistry().addProcedures(connectorId, connector.getProcedures()); + connector.getConnectorCodecProvider().ifPresent(connectorCodecProvider -> connectorCodecManager.addConnectorCodecProvider(connectorId, connectorCodecProvider)); + metadataManager.getProcedureRegistry().addProcedures(connectorId, + connector.getProcedures()); + Set> systemFunctions = connector.getSystemFunctions(); + if (!systemFunctions.isEmpty()) { + metadataManager.registerConnectorFunctions(connectorId.getCatalogName(), extractFunctions(systemFunctions, new CatalogSchemaName(connectorId.getCatalogName(), "system"))); + } connector.getAccessControl() .ifPresent(accessControl -> accessControlManager.addCatalogAccessControl(connectorId, accessControl)); metadataManager.getTablePropertyManager().addProperties(connectorId, connector.getTableProperties()); + metadataManager.getMaterializedViewPropertyManager().addProperties(connectorId, connector.getMaterializedViewProperties()); metadataManager.getColumnPropertyManager().addProperties(connectorId, connector.getColumnProperties()); metadataManager.getSchemaPropertyManager().addProperties(connectorId, connector.getSchemaProperties()); metadataManager.getAnalyzePropertyManager().addProperties(connectorId, connector.getAnalyzeProperties()); metadataManager.getSessionPropertyManager().addConnectorSessionProperties(connectorId, connector.getSessionProperties()); + metadataManager.getFunctionAndTypeManager().getTableFunctionRegistry().addTableFunctions(connectorId, connector.getTableFunctions()); + metadataManager.getFunctionAndTypeManager().addTableFunctionProcessorProvider(connectorId, connector.getTableFunctionProcessorProvider()); } public synchronized void dropConnection(String catalogName) @@ -342,6 +356,8 @@ public synchronized void dropConnection(String catalogName) removeConnectorInternal(connectorId); removeConnectorInternal(createInformationSchemaConnectorId(connectorId)); 
removeConnectorInternal(createSystemTablesConnectorId(connectorId)); + metadataManager.getFunctionAndTypeManager().getTableFunctionRegistry().removeTableFunctions(connectorId); + metadataManager.getFunctionAndTypeManager().removeTableFunctionProcessorProvider(connectorId); }); } @@ -355,12 +371,12 @@ private synchronized void removeConnectorInternal(ConnectorId connectorId) metadataManager.getProcedureRegistry().removeProcedures(connectorId); accessControlManager.removeCatalogAccessControl(connectorId); metadataManager.getTablePropertyManager().removeProperties(connectorId); + metadataManager.getMaterializedViewPropertyManager().removeProperties(connectorId); metadataManager.getColumnPropertyManager().removeProperties(connectorId); metadataManager.getSchemaPropertyManager().removeProperties(connectorId); metadataManager.getAnalyzePropertyManager().removeProperties(connectorId); metadataManager.getSessionPropertyManager().removeConnectorSessionProperties(connectorId); connectorPlanOptimizerManager.removePlanOptimizerProvider(connectorId); - connectorMetadataUpdaterManager.removeMetadataUpdaterProvider(connectorId); MaterializedConnector materializedConnector = connectors.remove(connectorId); if (materializedConnector != null) { @@ -379,6 +395,7 @@ private Connector createConnector(ConnectorId connectorId, ConnectorFactory fact ConnectorContext context = new ConnectorContextInstance( new ConnectorAwareNodeManager(nodeManager, nodeInfo.getEnvironment(), connectorId), typeManager, + procedureRegistry, metadataManager.getFunctionAndTypeManager(), new FunctionResolution(metadataManager.getFunctionAndTypeManager().getFunctionAndTypeResolver()), pageSorter, @@ -398,23 +415,37 @@ private Connector createConnector(ConnectorId connectorId, ConnectorFactory fact } } + public Optional getConnectorCodecProvider(ConnectorId connectorId) + { + requireNonNull(connectorId, "connectorId is null"); + MaterializedConnector materializedConnector = connectors.get(connectorId); + if 
(materializedConnector == null) { + return Optional.empty(); + } + return materializedConnector.getConnectorCodecProvider(); + } + private static class MaterializedConnector { private final ConnectorId connectorId; private final Connector connector; private final ConnectorSplitManager splitManager; private final Set systemTables; - private final Set procedures; + private final Set> procedures; + + private final Set> functions; + private final Set connectorTableFunctions; + private final Function connectorTableFunctionProcessorProvider; private final ConnectorPageSourceProvider pageSourceProvider; private final Optional pageSinkProvider; private final Optional indexProvider; private final Optional partitioningProvider; private final Optional planOptimizerProvider; - private final Optional metadataUpdaterProvider; - private final Optional connectorTypeSerdeProvider; + private final Optional connectorCodecProvider; private final Optional accessControl; private final List> sessionProperties; private final List> tableProperties; + private final List> materializedViewProperties; private final List> schemaProperties; private final List> columnProperties; private final List> analyzeProperties; @@ -431,9 +462,19 @@ public MaterializedConnector(ConnectorId connectorId, Connector connector) requireNonNull(systemTables, "Connector %s returned a null system tables set"); this.systemTables = ImmutableSet.copyOf(systemTables); + ImmutableSet.Builder> proceduresBuilder = ImmutableSet.builder(); Set procedures = connector.getProcedures(); requireNonNull(procedures, "Connector %s returned a null procedures set"); - this.procedures = ImmutableSet.copyOf(procedures); + proceduresBuilder.addAll(procedures); + Set distributedProcedures = connector.getDistributedProcedures(); + requireNonNull(distributedProcedures, "Connector %s returned a null distributedProcedures set"); + proceduresBuilder.addAll(distributedProcedures); + this.procedures = 
ImmutableSet.copyOf(proceduresBuilder.build()); + + Set connectorTableFunctions = connector.getTableFunctions(); + requireNonNull(connectorTableFunctions, format("Connector '%s' returned a null table functions set", connectorId)); + this.connectorTableFunctions = ImmutableSet.copyOf(connectorTableFunctions); + this.connectorTableFunctionProcessorProvider = connector.getTableFunctionProcessorProvider(); ConnectorPageSourceProvider connectorPageSourceProvider = null; try { @@ -492,23 +533,14 @@ public MaterializedConnector(ConnectorId connectorId, Connector connector) } this.planOptimizerProvider = Optional.ofNullable(planOptimizerProvider); - ConnectorMetadataUpdaterProvider metadataUpdaterProvider = null; + ConnectorCodecProvider connectorCodecProvider = null; try { - metadataUpdaterProvider = connector.getConnectorMetadataUpdaterProvider(); - requireNonNull(metadataUpdaterProvider, format("Connector %s returned null metadata updater provider", connectorId)); + connectorCodecProvider = connector.getConnectorCodecProvider(); + requireNonNull(connectorCodecProvider, format("Connector %s returned null connector specific codec provider", connectorId)); } catch (UnsupportedOperationException ignored) { } - this.metadataUpdaterProvider = Optional.ofNullable(metadataUpdaterProvider); - - ConnectorTypeSerdeProvider connectorTypeSerdeProvider = null; - try { - connectorTypeSerdeProvider = connector.getConnectorTypeSerdeProvider(); - requireNonNull(connectorTypeSerdeProvider, format("Connector %s returned null connector type serde provider", connectorId)); - } - catch (UnsupportedOperationException ignored) { - } - this.connectorTypeSerdeProvider = Optional.ofNullable(connectorTypeSerdeProvider); + this.connectorCodecProvider = Optional.ofNullable(connectorCodecProvider); ConnectorAccessControl accessControl = null; try { @@ -526,6 +558,10 @@ public MaterializedConnector(ConnectorId connectorId, Connector connector) requireNonNull(tableProperties, "Connector %s returned a 
null table properties set"); this.tableProperties = ImmutableList.copyOf(tableProperties); + List> materializedViewProperties = connector.getMaterializedViewProperties(); + requireNonNull(materializedViewProperties, "Connector %s returned a null materialized view properties set"); + this.materializedViewProperties = ImmutableList.copyOf(materializedViewProperties); + List> schemaProperties = connector.getSchemaProperties(); requireNonNull(schemaProperties, "Connector %s returned a null schema properties set"); this.schemaProperties = ImmutableList.copyOf(schemaProperties); @@ -537,6 +573,10 @@ public MaterializedConnector(ConnectorId connectorId, Connector connector) List> analyzeProperties = connector.getAnalyzeProperties(); requireNonNull(analyzeProperties, "Connector %s returned a null analyze properties set"); this.analyzeProperties = ImmutableList.copyOf(analyzeProperties); + + Set> systemFunctions = connector.getSystemFunctions(); + requireNonNull(systemFunctions, "Connector %s returned a null system function set"); + this.functions = ImmutableSet.copyOf(systemFunctions); } public ConnectorId getConnectorId() @@ -559,11 +599,23 @@ public Set getSystemTables() return systemTables; } - public Set getProcedures() + public > Set getProcedures(Class targetClz) + { + return procedures.stream().filter(targetClz::isInstance) + .map(targetClz::cast) + .collect(toImmutableSet()); + } + + public Set> getProcedures() { return procedures; } + public Set> getSystemFunctions() + { + return functions; + } + public ConnectorPageSourceProvider getPageSourceProvider() { return pageSourceProvider; @@ -589,16 +641,6 @@ public Optional getPlanOptimizerProvider() return planOptimizerProvider; } - public Optional getMetadataUpdaterProvider() - { - return metadataUpdaterProvider; - } - - public Optional getConnectorTypeSerdeProvider() - { - return connectorTypeSerdeProvider; - } - public Optional getAccessControl() { return accessControl; @@ -614,6 +656,11 @@ public List> 
getTableProperties() return tableProperties; } + public List> getMaterializedViewProperties() + { + return materializedViewProperties; + } + public List> getColumnProperties() { return columnProperties; @@ -628,5 +675,20 @@ public List> getAnalyzeProperties() { return analyzeProperties; } + + public Optional getConnectorCodecProvider() + { + return connectorCodecProvider; + } + + public Set getTableFunctions() + { + return connectorTableFunctions; + } + + public Function getTableFunctionProcessorProvider() + { + return connectorTableFunctionProcessorProvider; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorTypeSerdeManager.java b/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorTypeSerdeManager.java deleted file mode 100644 index d581e5b8a09d6..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/ConnectorTypeSerdeManager.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.connector; - -import com.facebook.presto.server.ForJsonMetadataUpdateHandle; -import com.facebook.presto.spi.ConnectorId; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.ConnectorTypeSerde; -import com.facebook.presto.spi.connector.ConnectorTypeSerdeProvider; -import com.google.inject.Inject; - -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; - -import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; - -public class ConnectorTypeSerdeManager -{ - private final Map connectorTypeSerdeProviderMap = new ConcurrentHashMap<>(); - private final ConnectorTypeSerde connectorMetadataUpdateHandleJsonSerde; - - @Inject - public ConnectorTypeSerdeManager(@ForJsonMetadataUpdateHandle ConnectorTypeSerde connectorMetadataUpdateHandleJsonSerde) - { - this.connectorMetadataUpdateHandleJsonSerde = requireNonNull(connectorMetadataUpdateHandleJsonSerde, "connectorMetadataUpdateHandleJsonSerde is null"); - } - - public void addConnectorTypeSerdeProvider(ConnectorId connectorId, ConnectorTypeSerdeProvider connectorTypeSerdeProvider) - { - requireNonNull(connectorId, "connectorId is null"); - requireNonNull(connectorTypeSerdeProvider, "connectorTypeSerdeProvider is null"); - checkArgument( - connectorTypeSerdeProviderMap.putIfAbsent(connectorId, connectorTypeSerdeProvider) == null, - "ConnectorMetadataUpdateHandleSerdeProvider for connector '%s' is already registered", connectorId); - } - - public void removeConnectorTypeSerdeProvider(ConnectorId connectorId) - { - requireNonNull(connectorId, "connectorId is null"); - connectorTypeSerdeProviderMap.remove(connectorId); - } - - public ConnectorTypeSerde getMetadataUpdateHandleSerde(ConnectorId connectorId) - { - requireNonNull(connectorId, "connectorId is null"); - return Optional.ofNullable(connectorTypeSerdeProviderMap.get(connectorId)) - 
.map(ConnectorTypeSerdeProvider::getConnectorMetadataUpdateHandleSerde) - .orElse(connectorMetadataUpdateHandleJsonSerde); - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/informationSchema/InformationSchemaMetadata.java b/presto-main-base/src/main/java/com/facebook/presto/connector/informationSchema/InformationSchemaMetadata.java index 9a5deaab70716..5867793cdae9e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/informationSchema/InformationSchemaMetadata.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/informationSchema/InformationSchemaMetadata.java @@ -49,11 +49,13 @@ import java.util.function.Predicate; import java.util.stream.Stream; +import static com.facebook.presto.SystemSessionProperties.getMaxPrefixesCount; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.metadata.MetadataUtil.SchemaMetadataBuilder.schemaMetadataBuilder; import static com.facebook.presto.metadata.MetadataUtil.TableMetadataBuilder.tableMetadataBuilder; import static com.facebook.presto.metadata.MetadataUtil.findColumnMetadata; +import static com.facebook.presto.metadata.QualifiedTablePrefix.toQualifiedTablePrefix; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Predicates.compose; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -72,6 +74,7 @@ public class InformationSchemaMetadata public static final SchemaTableName TABLE_COLUMNS = new SchemaTableName(INFORMATION_SCHEMA, "columns"); public static final SchemaTableName TABLE_TABLES = new SchemaTableName(INFORMATION_SCHEMA, "tables"); public static final SchemaTableName TABLE_VIEWS = new SchemaTableName(INFORMATION_SCHEMA, "views"); + public static final SchemaTableName TABLE_MATERIALIZED_VIEWS = new SchemaTableName(INFORMATION_SCHEMA, 
"materialized_views"); public static final SchemaTableName TABLE_SCHEMATA = new SchemaTableName(INFORMATION_SCHEMA, "schemata"); public static final SchemaTableName TABLE_TABLE_PRIVILEGES = new SchemaTableName(INFORMATION_SCHEMA, "table_privileges"); public static final SchemaTableName TABLE_ROLES = new SchemaTableName(INFORMATION_SCHEMA, "roles"); @@ -90,6 +93,9 @@ public class InformationSchemaMetadata .column("data_type", createUnboundedVarcharType()) .column("comment", createUnboundedVarcharType()) .column("extra_info", createUnboundedVarcharType()) + .column("precision", BIGINT) + .column("scale", BIGINT) + .column("length", BIGINT) .build()) .table(tableMetadataBuilder(TABLE_TABLES) .column("table_catalog", createUnboundedVarcharType()) @@ -104,6 +110,18 @@ public class InformationSchemaMetadata .column("view_owner", createUnboundedVarcharType()) .column("view_definition", createUnboundedVarcharType()) .build()) + .table(tableMetadataBuilder(TABLE_MATERIALIZED_VIEWS) + .column("table_catalog", createUnboundedVarcharType()) + .column("table_schema", createUnboundedVarcharType()) + .column("table_name", createUnboundedVarcharType()) + .column("view_definition", createUnboundedVarcharType()) + .column("view_owner", createUnboundedVarcharType()) + .column("view_security", createUnboundedVarcharType()) + .column("storage_schema", createUnboundedVarcharType()) + .column("storage_table_name", createUnboundedVarcharType()) + .column("base_tables", createUnboundedVarcharType()) + .column("freshness_state", createUnboundedVarcharType()) + .build()) .table(tableMetadataBuilder(TABLE_SCHEMATA) .column("catalog_name", createUnboundedVarcharType()) .column("schema_name", createUnboundedVarcharType()) @@ -137,7 +155,6 @@ public class InformationSchemaMetadata private static final InformationSchemaColumnHandle CATALOG_COLUMN_HANDLE = new InformationSchemaColumnHandle("table_catalog"); private static final InformationSchemaColumnHandle SCHEMA_COLUMN_HANDLE = new 
InformationSchemaColumnHandle("table_schema"); private static final InformationSchemaColumnHandle TABLE_NAME_COLUMN_HANDLE = new InformationSchemaColumnHandle("table_name"); - private static final int MAX_PREFIXES_COUNT = 100; private final String catalogName; private final Metadata metadata; @@ -233,16 +250,17 @@ public Map> listTableColumns(ConnectorSess public ConnectorTableLayoutResult getTableLayoutForConstraint(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { InformationSchemaTableHandle handle = checkTableHandle(table); + int maxPrefixesCount = getMaxPrefixesCount(((FullConnectorSession) session).getSession()); Set prefixes = calculatePrefixesWithSchemaName(session, constraint.getSummary(), constraint.predicate()); if (isTablesEnumeratingTable(handle.getSchemaTableName())) { Set tablePrefixes = calculatePrefixesWithTableName(session, prefixes, constraint.getSummary(), constraint.predicate()); // in case of high number of prefixes it is better to populate all data and then filter - if (tablePrefixes.size() <= MAX_PREFIXES_COUNT) { + if (tablePrefixes.size() <= maxPrefixesCount) { prefixes = tablePrefixes; } } - if (prefixes.size() > MAX_PREFIXES_COUNT) { + if (prefixes.size() > maxPrefixesCount) { // in case of high number of prefixes it is better to populate all data and then filter prefixes = ImmutableSet.of(new QualifiedTablePrefix(catalogName)); } @@ -253,7 +271,7 @@ public ConnectorTableLayoutResult getTableLayoutForConstraint(ConnectorSession s private boolean isTablesEnumeratingTable(SchemaTableName schemaTableName) { - return ImmutableSet.of(TABLE_COLUMNS, TABLE_VIEWS, TABLE_TABLES, TABLE_TABLE_PRIVILEGES).contains(schemaTableName); + return ImmutableSet.of(TABLE_COLUMNS, TABLE_VIEWS, TABLE_MATERIALIZED_VIEWS, TABLE_TABLES, TABLE_TABLE_PRIVILEGES).contains(schemaTableName); } private Set calculatePrefixesWithSchemaName( @@ -264,7 +282,6 @@ private Set calculatePrefixesWithSchemaName( Optional> 
schemas = filterString(constraint, SCHEMA_COLUMN_HANDLE); if (schemas.isPresent()) { return schemas.get().stream() - .filter(this::isLowerCase) .map(schema -> new QualifiedTablePrefix(catalogName, schema)) .collect(toImmutableSet()); } @@ -289,20 +306,26 @@ public Set calculatePrefixesWithTableName( if (tables.isPresent()) { return prefixes.stream() .flatMap(prefix -> tables.get().stream() - .filter(this::isLowerCase) - .map(table -> table.toLowerCase(ENGLISH)) .map(table -> new QualifiedObjectName(catalogName, prefix.getSchemaName().get(), table))) .filter(objectName -> metadataResolver.getView(objectName).isPresent() || metadataResolver.getTableHandle(objectName).isPresent()) - .map(QualifiedTablePrefix::toQualifiedTablePrefix) + .map(value -> toQualifiedTablePrefix(new QualifiedObjectName( + value.getCatalogName(), + metadata.normalizeIdentifier(session, value.getCatalogName(), value.getSchemaName()), + metadata.normalizeIdentifier(session, value.getCatalogName(), value.getObjectName())))) .collect(toImmutableSet()); } return prefixes.stream() .flatMap(prefix -> Stream.concat( - metadata.listTables(session, prefix).stream(), - metadata.listViews(session, prefix).stream())) + Stream.concat( + metadata.listTables(session, prefix).stream(), + metadata.listViews(session, prefix).stream()), + metadata.listMaterializedViews(session, prefix).stream())) .filter(objectName -> !predicate.isPresent() || predicate.get().test(asFixedValues(objectName))) - .map(QualifiedTablePrefix::toQualifiedTablePrefix) + .map(value -> toQualifiedTablePrefix(new QualifiedObjectName( + value.getCatalogName(), + metadata.normalizeIdentifier(session, value.getCatalogName(), value.getSchemaName()), + metadata.normalizeIdentifier(session, value.getCatalogName(), value.getObjectName())))) .collect(toImmutableSet()); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/informationSchema/InformationSchemaPageSourceProvider.java 
b/presto-main-base/src/main/java/com/facebook/presto/connector/informationSchema/InformationSchemaPageSourceProvider.java index 7de4e65f20f7f..2c3931035291e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/informationSchema/InformationSchemaPageSourceProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/informationSchema/InformationSchemaPageSourceProvider.java @@ -18,6 +18,11 @@ import com.facebook.presto.common.Page; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.metadata.InternalTable; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.QualifiedTablePrefix; @@ -27,6 +32,8 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.FixedPageSource; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.MaterializedViewStatus; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SplitContext; import com.facebook.presto.spi.analyzer.ViewDefinition; @@ -44,15 +51,24 @@ import java.util.Map.Entry; import java.util.Set; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.common.type.SmallintType.SMALLINT; +import static com.facebook.presto.common.type.TinyintType.TINYINT; import static 
com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.TABLE_APPLICABLE_ROLES; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.TABLE_COLUMNS; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.TABLE_ENABLED_ROLES; +import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.TABLE_MATERIALIZED_VIEWS; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.TABLE_ROLES; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.TABLE_SCHEMATA; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.TABLE_TABLES; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.TABLE_TABLE_PRIVILEGES; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.TABLE_VIEWS; import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.informationSchemaTableColumns; +import static com.facebook.presto.connector.system.jdbc.ColumnJdbcTable.decimalDigits; +import static com.facebook.presto.metadata.MetadataListing.listMaterializedViews; import static com.facebook.presto.metadata.MetadataListing.listSchemas; import static com.facebook.presto.metadata.MetadataListing.listTableColumns; import static com.facebook.presto.metadata.MetadataListing.listTablePrivileges; @@ -127,6 +143,9 @@ public InternalTable getInformationSchemaTable(Session session, String catalog, if (table.equals(TABLE_VIEWS)) { return buildViews(session, prefixes); } + if (table.equals(TABLE_MATERIALIZED_VIEWS)) { + return buildMaterializedViews(session, prefixes); + } if (table.equals(TABLE_SCHEMATA)) { return buildSchemata(session, catalog); } @@ -167,7 +186,10 @@ private InternalTable buildColumns(Session session, Set pr column.isNullable() ? 
"YES" : "NO", column.getType().getDisplayName(), column.getComment().orElse(null), - column.getExtraInfo().orElse(null)); + column.getExtraInfo().orElse(null), + getPrecision(column.getType()), + decimalDigits(column.getType()), + getCharVarcharLength(column.getType())); ordinalPosition++; } } @@ -181,10 +203,19 @@ private InternalTable buildTables(Session session, Set pre for (QualifiedTablePrefix prefix : prefixes) { Set tables = listTables(session, metadata, accessControl, prefix); Set views = listViews(session, metadata, accessControl, prefix); + Set materializedViews = listMaterializedViews(session, metadata, accessControl, prefix); - for (SchemaTableName name : union(tables, views)) { - // if table and view names overlap, the view wins - String type = views.contains(name) ? "VIEW" : "BASE TABLE"; + for (SchemaTableName name : union(union(tables, views), materializedViews)) { + String type; + if (materializedViews.contains(name)) { + type = "MATERIALIZED VIEW"; + } + else if (views.contains(name)) { + type = "VIEW"; + } + else { + type = "BASE TABLE"; + } table.add( prefix.getCatalogName(), name.getSchemaName(), @@ -233,6 +264,39 @@ private InternalTable buildViews(Session session, Set pref return table.build(); } + private InternalTable buildMaterializedViews(Session session, Set prefixes) + { + InternalTable.Builder table = InternalTable.builder(informationSchemaTableColumns(TABLE_MATERIALIZED_VIEWS)); + + for (QualifiedTablePrefix prefix : prefixes) { + for (Entry entry : metadata.getMaterializedViews(session, prefix).entrySet()) { + QualifiedObjectName viewName = entry.getKey(); + MaterializedViewDefinition definition = entry.getValue(); + + String baseTablesStr = definition.getBaseTables().stream() + .map(baseTable -> viewName.getCatalogName() + "." + baseTable.getSchemaName() + "." 
+ baseTable.getTableName()) + .collect(java.util.stream.Collectors.joining(", ")); + + MaterializedViewStatus status = metadata.getMaterializedViewStatus(session, viewName, TupleDomain.all()); + String freshnessState = status.getMaterializedViewState().name(); + + table.add( + viewName.getCatalogName(), + viewName.getSchemaName(), + viewName.getObjectName(), + definition.getOriginalSql(), + definition.getOwner().orElse(null), + definition.getSecurityMode().map(Object::toString).orElse(null), + definition.getSchema(), + definition.getTable(), + baseTablesStr, + freshnessState); + } + } + + return table.build(); + } + private InternalTable buildSchemata(Session session, String catalogName) { InternalTable.Builder table = InternalTable.builder(informationSchemaTableColumns(TABLE_SCHEMATA)); @@ -281,4 +345,40 @@ private InternalTable buildEnabledRoles(Session session, String catalog) } return table.build(); } + + public static Integer getCharVarcharLength(Type type) + { + if (type instanceof VarcharType) { + return ((VarcharType) type).getLength(); + } + if (type instanceof CharType) { + return (((CharType) type).getLength()); + } + return null; + } + public static Integer getPrecision(Type type) + { + if (type.equals(BIGINT)) { + return 19; // 2**63-1 + } + if (type.equals(INTEGER)) { + return 10; // 2**31-1 + } + if (type.equals(SMALLINT)) { + return 5; // 2**15-1 + } + if (type.equals(TINYINT)) { + return 3; // 2**7-1 + } + if (type instanceof DecimalType) { + return ((DecimalType) type).getPrecision(); + } + if (type.equals(REAL)) { + return 24; // IEEE 754 + } + if (type.equals(DOUBLE)) { + return 53; // IEEE 754 + } + return null; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/AnalyzePropertiesSystemTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/AnalyzePropertiesSystemTable.java index 6377ada289bc6..cbafed277479d 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/connector/system/AnalyzePropertiesSystemTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/AnalyzePropertiesSystemTable.java @@ -15,8 +15,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.transaction.TransactionManager; - -import javax.inject.Inject; +import jakarta.inject.Inject; public class AnalyzePropertiesSystemTable extends AbstractPropertiesSystemTable diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/CatalogSystemTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/CatalogSystemTable.java index c604b2cfe9cec..e8fbcdf6a2303 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/CatalogSystemTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/CatalogSystemTable.java @@ -15,8 +15,8 @@ import com.facebook.presto.Session; import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.metadata.Catalog.CatalogContext; import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.InMemoryRecordSet; @@ -26,14 +26,13 @@ import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.security.AccessControl; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Map; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; import static com.facebook.presto.connector.system.SystemConnectorSessionUtil.toSession; -import static com.facebook.presto.metadata.MetadataListing.listCatalogs; +import static com.facebook.presto.metadata.MetadataListing.listCatalogsWithConnectorContext; import static 
com.facebook.presto.metadata.MetadataUtil.TableMetadataBuilder.tableMetadataBuilder; import static com.facebook.presto.spi.SystemTable.Distribution.SINGLE_COORDINATOR; import static java.util.Objects.requireNonNull; @@ -46,6 +45,7 @@ public class CatalogSystemTable public static final ConnectorTableMetadata CATALOG_TABLE = tableMetadataBuilder(CATALOG_TABLE_NAME) .column("catalog_name", createUnboundedVarcharType()) .column("connector_id", createUnboundedVarcharType()) + .column("connector_name", createUnboundedVarcharType()) .build(); private final Metadata metadata; private final AccessControl accessControl; @@ -74,8 +74,8 @@ public RecordCursor cursor(ConnectorTransactionHandle transactionHandle, Connect { Session session = toSession(transactionHandle, connectorSession); Builder table = InMemoryRecordSet.builder(CATALOG_TABLE); - for (Map.Entry entry : listCatalogs(session, metadata, accessControl).entrySet()) { - table.addRow(entry.getKey(), entry.getValue().toString()); + for (Map.Entry entry : listCatalogsWithConnectorContext(session, metadata, accessControl).entrySet()) { + table.addRow(entry.getKey(), entry.getValue().getCatalogName(), entry.getValue().getConnectorName()); } return table.build().cursor(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/ColumnPropertiesSystemTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/ColumnPropertiesSystemTable.java index cbec3f4c7ccea..44a9788a6d5e1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/ColumnPropertiesSystemTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/ColumnPropertiesSystemTable.java @@ -15,8 +15,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.transaction.TransactionManager; - -import javax.inject.Inject; +import jakarta.inject.Inject; public class ColumnPropertiesSystemTable extends AbstractPropertiesSystemTable diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java index 860f6e47a6285..708c3f314cc48 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnector.java @@ -15,11 +15,14 @@ import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.transaction.TransactionId; +import com.facebook.presto.operator.table.ExcludeColumns; +import com.facebook.presto.operator.table.Sequence; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableLayout; import com.facebook.presto.spi.ConnectorTableLayoutHandle; @@ -34,6 +37,9 @@ import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.transaction.IsolationLevel; import com.facebook.presto.transaction.InternalConnector; @@ -45,7 +51,9 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; +import static com.facebook.presto.operator.table.Sequence.getSequenceFunctionSplitSource; import static 
java.util.Objects.requireNonNull; public class GlobalSystemConnector @@ -56,12 +64,14 @@ public class GlobalSystemConnector private final String connectorId; private final Set systemTables; private final Set procedures; + private final Set tableFunctions; - public GlobalSystemConnector(String connectorId, Set systemTables, Set procedures) + public GlobalSystemConnector(String connectorId, Set systemTables, Set procedures, Set tableFunctions) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.systemTables = ImmutableSet.copyOf(requireNonNull(systemTables, "systemTables is null")); this.procedures = ImmutableSet.copyOf(requireNonNull(procedures, "procedures is null")); + this.tableFunctions = ImmutableSet.copyOf(requireNonNull(tableFunctions, "tableFunctions is null")); } @Override @@ -138,8 +148,22 @@ public Map> listTableColumns(ConnectorSess @Override public ConnectorSplitManager getSplitManager() { - return (transactionHandle, session, layout, splitSchedulingContext) -> { - throw new UnsupportedOperationException(); + return new ConnectorSplitManager() { + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) + { + throw new UnsupportedOperationException(); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableFunctionHandle function) + { + if (function instanceof Sequence.SequenceFunctionHandle) { + Sequence.SequenceFunctionHandle sequenceFunctionHandle = (Sequence.SequenceFunctionHandle) function; + return getSequenceFunctionSplitSource(sequenceFunctionHandle); + } + throw new UnsupportedOperationException(); + } }; } @@ -166,4 +190,24 @@ public Set getProcedures() { return procedures; } + + @Override + public Set getTableFunctions() + { + return tableFunctions; + } + + @Override + public Function 
getTableFunctionProcessorProvider() + { + return connectorTableFunctionHandle -> { + if (connectorTableFunctionHandle instanceof ExcludeColumns.ExcludeColumnsFunctionHandle) { + return ExcludeColumns.getExcludeColumnsFunctionProcessorProvider(); + } + else if (connectorTableFunctionHandle instanceof Sequence.SequenceFunctionHandle) { + return Sequence.getSequenceFunctionProcessorProvider(); + } + return null; + }; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java index 07cd6ad48a90c..223684418a819 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/GlobalSystemConnectorFactory.java @@ -18,10 +18,10 @@ import com.facebook.presto.spi.connector.Connector; import com.facebook.presto.spi.connector.ConnectorContext; import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Map; import java.util.Set; @@ -33,12 +33,14 @@ public class GlobalSystemConnectorFactory { private final Set tables; private final Set procedures; + private final Set tableFunctions; @Inject - public GlobalSystemConnectorFactory(Set tables, Set procedures) + public GlobalSystemConnectorFactory(Set tables, Set procedures, Set tableFunctions) { this.tables = ImmutableSet.copyOf(requireNonNull(tables, "tables is null")); this.procedures = ImmutableSet.copyOf(requireNonNull(procedures, "procedures is null")); + this.tableFunctions = ImmutableSet.copyOf(requireNonNull(tableFunctions, "tableFunctions is null")); } @Override @@ -56,6 +58,6 @@ public 
ConnectorHandleResolver getHandleResolver() @Override public Connector create(String catalogName, Map config, ConnectorContext context) { - return new GlobalSystemConnector(catalogName, tables, procedures); + return new GlobalSystemConnector(catalogName, tables, procedures, tableFunctions); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java index d3eb3758223d4..e236e51c43b33 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/KillQueryProcedure.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.procedure.Procedure; import com.facebook.presto.spi.procedure.Procedure.Argument; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.lang.invoke.MethodHandle; import java.util.NoSuchElementException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/NodeSystemTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/NodeSystemTable.java index 7495be91e69dc..9e989f6ada3ed 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/NodeSystemTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/NodeSystemTable.java @@ -26,8 +26,7 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Locale; import java.util.Set; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/QuerySystemTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/QuerySystemTable.java index dee2de8039720..48b7f53a0db87 
100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/QuerySystemTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/QuerySystemTable.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.connector.system; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.predicate.TupleDomain; @@ -30,9 +31,7 @@ import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; -import io.airlift.units.Duration; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.NoSuchElementException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SchemaPropertiesSystemTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SchemaPropertiesSystemTable.java index be9791fc718ca..b52c36a3f1cb6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SchemaPropertiesSystemTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SchemaPropertiesSystemTable.java @@ -15,8 +15,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.transaction.TransactionManager; - -import javax.inject.Inject; +import jakarta.inject.Inject; public class SchemaPropertiesSystemTable extends AbstractPropertiesSystemTable diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java index 8afd2f7dadea8..40c974f6e7b90 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorModule.java @@ -27,7 +27,10 @@ import com.facebook.presto.connector.system.jdbc.TableTypeJdbcTable; import com.facebook.presto.connector.system.jdbc.TypesJdbcTable; import com.facebook.presto.connector.system.jdbc.UdtJdbcTable; +import com.facebook.presto.operator.table.ExcludeColumns; +import com.facebook.presto.operator.table.Sequence; import com.facebook.presto.spi.SystemTable; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; import com.facebook.presto.spi.procedure.Procedure; import com.google.common.collect.ImmutableMap; import com.google.inject.Binder; @@ -36,8 +39,7 @@ import com.google.inject.multibindings.Multibinder; import com.google.inject.multibindings.MultibindingsScanner; import com.google.inject.multibindings.ProvidesIntoSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; public class SystemConnectorModule implements Module @@ -78,6 +80,10 @@ public void configure(Binder binder) binder.bind(GlobalSystemConnectorFactory.class).in(Scopes.SINGLETON); binder.bind(SystemConnectorRegistrar.class).asEagerSingleton(); + + Multibinder tableFunctions = Multibinder.newSetBinder(binder, ConnectorTableFunction.class); + tableFunctions.addBinding().toProvider(ExcludeColumns.class).in(Scopes.SINGLETON); + tableFunctions.addBinding().toProvider(Sequence.class).in(Scopes.SINGLETON); } @ProvidesIntoSet diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorSessionUtil.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorSessionUtil.java index 19cc00b0a7065..405b3f4cceadf 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorSessionUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemConnectorSessionUtil.java @@ -21,6 +21,10 @@ import 
com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.security.ConnectorIdentity; import com.facebook.presto.spi.security.Identity; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.Optional; import static com.facebook.presto.metadata.SessionPropertyManager.createTestingSessionPropertyManager; @@ -35,7 +39,7 @@ public static Session toSession(ConnectorTransactionHandle transactionHandle, Co { TransactionId transactionId = ((GlobalSystemTransactionHandle) transactionHandle).getTransactionId(); ConnectorIdentity connectorIdentity = session.getIdentity(); - Identity identity = new Identity(connectorIdentity.getUser(), connectorIdentity.getPrincipal(), connectorIdentity.getExtraCredentials()); + Identity identity = new Identity(connectorIdentity.getUser(), connectorIdentity.getPrincipal(), ImmutableMap.of(), connectorIdentity.getExtraCredentials(), ImmutableMap.of(), Optional.empty(), connectorIdentity.getReasonForSelect(), ImmutableList.of()); return Session.builder(createTestingSessionPropertyManager(SYSTEM_SESSION_PROPERTIES)) .setQueryId(new QueryId(session.getQueryId())) .setTransactionId(transactionId) diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemTableHandle.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemTableHandle.java index 56fbab3b9173b..71132d30d5045 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemTableHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/SystemTableHandle.java @@ -21,8 +21,6 @@ import java.util.Objects; -import static com.facebook.presto.metadata.MetadataUtil.checkSchemaName; -import static com.facebook.presto.metadata.MetadataUtil.checkTableName; import static java.util.Objects.requireNonNull; public class SystemTableHandle @@ -39,8 +37,8 @@ public SystemTableHandle( 
@JsonProperty("tableName") String tableName) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); - this.schemaName = checkSchemaName(schemaName); - this.tableName = checkTableName(tableName); + this.schemaName = requireNonNull(schemaName, "schemaName is null"); + this.tableName = requireNonNull(tableName, "tableName is null"); } public static SystemTableHandle fromSchemaTableName(ConnectorId connectorId, SchemaTableName tableName) diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/TablePropertiesSystemTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/TablePropertiesSystemTable.java index 6df37eaf77504..637cfa56b51c2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/TablePropertiesSystemTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/TablePropertiesSystemTable.java @@ -15,8 +15,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.transaction.TransactionManager; - -import javax.inject.Inject; +import jakarta.inject.Inject; public class TablePropertiesSystemTable extends AbstractPropertiesSystemTable diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/TaskSystemTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/TaskSystemTable.java index 6d50a6444b370..f2f168281308c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/TaskSystemTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/TaskSystemTable.java @@ -14,6 +14,8 @@ package com.facebook.presto.connector.system; import com.facebook.airlift.node.NodeInfo; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.execution.TaskInfo; import com.facebook.presto.execution.TaskManager; @@ -27,10 +29,7 @@ import 
com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/TransactionsSystemTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/TransactionsSystemTable.java index 83140c0ec7ff1..b2a7a80244da1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/TransactionsSystemTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/TransactionsSystemTable.java @@ -31,8 +31,7 @@ import com.facebook.presto.transaction.TransactionInfo; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.concurrent.TimeUnit; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/CatalogJdbcTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/CatalogJdbcTable.java index d61b245628b6c..c5103e4525149 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/CatalogJdbcTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/CatalogJdbcTable.java @@ -24,8 +24,7 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.security.AccessControl; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; import static 
com.facebook.presto.connector.system.SystemConnectorSessionUtil.toSession; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java index 1138d9746beab..3b91bf5ac8257 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/ColumnJdbcTable.java @@ -32,8 +32,7 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.security.AccessControl; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.sql.DatabaseMetaData; import java.sql.Types; @@ -273,7 +272,7 @@ static Integer columnSize(Type type) } // DECIMAL_DIGITS is the number of fractional digits - private static Integer decimalDigits(Type type) + public static Integer decimalDigits(Type type) { if (type instanceof DecimalType) { return ((DecimalType) type).getScale(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/SchemaJdbcTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/SchemaJdbcTable.java index bc3e77b066eb3..5d9212c8b863e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/SchemaJdbcTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/SchemaJdbcTable.java @@ -25,8 +25,7 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.security.AccessControl; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/TableJdbcTable.java 
b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/TableJdbcTable.java index e058ae6a10348..8bf503bc1f959 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/TableJdbcTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/TableJdbcTable.java @@ -27,8 +27,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.security.AccessControl; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Optional; import java.util.Set; @@ -38,6 +37,7 @@ import static com.facebook.presto.connector.system.jdbc.FilterUtil.stringFilter; import static com.facebook.presto.connector.system.jdbc.FilterUtil.tablePrefix; import static com.facebook.presto.metadata.MetadataListing.listCatalogs; +import static com.facebook.presto.metadata.MetadataListing.listMaterializedViews; import static com.facebook.presto.metadata.MetadataListing.listTables; import static com.facebook.presto.metadata.MetadataListing.listViews; import static com.facebook.presto.metadata.MetadataUtil.TableMetadataBuilder.tableMetadataBuilder; @@ -98,9 +98,17 @@ public RecordCursor cursor(ConnectorTransactionHandle transactionHandle, Connect } } + Set materializedViews = ImmutableSet.of(); + if (FilterUtil.emptyOrEquals(typeFilter, "MATERIALIZED VIEW")) { + materializedViews = ImmutableSet.copyOf(listMaterializedViews(session, metadata, accessControl, prefix)); + for (SchemaTableName name : materializedViews) { + table.addRow(tableRow(catalog, name, "MATERIALIZED VIEW")); + } + } + if (FilterUtil.emptyOrEquals(typeFilter, "TABLE")) { for (SchemaTableName name : listTables(session, metadata, accessControl, prefix)) { - if (!views.contains(name)) { + if (!views.contains(name) && !materializedViews.contains(name)) { table.addRow(tableRow(catalog, name, "TABLE")); } } diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/TypesJdbcTable.java b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/TypesJdbcTable.java index 5e330237a5faa..bce7fe113c532 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/TypesJdbcTable.java +++ b/presto-main-base/src/main/java/com/facebook/presto/connector/system/jdbc/TypesJdbcTable.java @@ -24,8 +24,7 @@ import com.facebook.presto.spi.RecordCursor; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.sql.DatabaseMetaData; import java.sql.Types; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/ConnectorFilterStatsCalculatorService.java b/presto-main-base/src/main/java/com/facebook/presto/cost/ConnectorFilterStatsCalculatorService.java index 85a51c2b6bb51..e400481929d25 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/ConnectorFilterStatsCalculatorService.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/ConnectorFilterStatsCalculatorService.java @@ -81,6 +81,9 @@ else if (!tableStatistics.getTotalSize().isUnknown() double totalSizeAfterFilter = filteredStatistics.getRowCount().getValue() / tableStatistics.getRowCount().getValue() * tableStatistics.getTotalSize().getValue(); filteredStatsWithSize.setTotalSize(Estimate.of(totalSizeAfterFilter)); } + if (!tableStatistics.getParallelismFactor().isUnknown()) { + filteredStatsWithSize.setParallelismFactor(tableStatistics.getParallelismFactor()); + } return filteredStatsWithSize.setConfidenceLevel(LOW).build(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculator.java b/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculator.java index e681c26ce7f52..dce4afaf26b4e 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculator.java @@ -16,10 +16,9 @@ import com.facebook.presto.Session; import com.facebook.presto.spi.plan.PlanNode; +import com.google.errorprone.annotations.ThreadSafe; import com.google.inject.BindingAnnotation; -import javax.annotation.concurrent.ThreadSafe; - import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java index 18c84c8ad9dac..5ec55e6734eca 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculatorUsingExchanges.java @@ -42,9 +42,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import java.util.List; import java.util.Objects; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java index 16856a4755e18..ecf1b0b12aa0c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/CostCalculatorWithEstimatedExchanges.java @@ -28,9 +28,8 @@ import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.SequenceNode; - -import 
javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import java.util.Objects; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/CostComparator.java b/presto-main-base/src/main/java/com/facebook/presto/cost/CostComparator.java index a394fdb3131cd..3dad25970abbd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/CostComparator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/CostComparator.java @@ -17,8 +17,7 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Ordering; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java b/presto-main-base/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java index 368c12d2df43c..d01c5b6ba68f2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/FilterStatsCalculator.java @@ -57,9 +57,8 @@ import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.annotation.Nullable; -import javax.inject.Inject; +import jakarta.annotation.Nullable; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/FragmentStatsProvider.java b/presto-main-base/src/main/java/com/facebook/presto/cost/FragmentStatsProvider.java index a9ec314b0bf03..f8c8b79b7cdc4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/FragmentStatsProvider.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/cost/FragmentStatsProvider.java @@ -14,12 +14,12 @@ package com.facebook.presto.cost; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.plan.PlanFragmentId; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.inject.Inject; -import io.airlift.units.DataSize; import java.util.Objects; import java.util.stream.IntStream; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedOptimizationConfig.java b/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedOptimizationConfig.java index 84dd8a0fe8353..f307eda8307d1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedOptimizationConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedOptimizationConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import com.facebook.presto.spi.function.Description; - -import javax.validation.constraints.Min; +import jakarta.validation.constraints.Min; public class HistoryBasedOptimizationConfig { diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsCalculator.java b/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsCalculator.java index 65cd001801876..a903dd9dce459 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsCalculator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsCalculator.java @@ -36,6 +36,7 @@ import java.util.function.Supplier; import static com.facebook.presto.SystemSessionProperties.enableVerboseHistoryBasedOptimizerRuntimeStats; +import static com.facebook.presto.SystemSessionProperties.estimateSizeUsingVariablesForHBO; import static com.facebook.presto.SystemSessionProperties.getHistoryBasedOptimizerTimeoutLimit; import 
static com.facebook.presto.SystemSessionProperties.getHistoryInputTableStatisticsMatchingThreshold; import static com.facebook.presto.SystemSessionProperties.isVerboseRuntimeStatsEnabled; @@ -203,7 +204,7 @@ private PlanNodeStatsEstimate getStatistics(PlanNode planNode, Session session, if ((toConfidenceLevel(predictedPlanStatistics.getConfidence()).getConfidenceOrdinal() >= delegateStats.confidenceLevel().getConfidenceOrdinal())) { return delegateStats.combineStats( predictedPlanStatistics, - new HistoryBasedSourceInfo(entry.getKey().getHash(), inputTableStatistics, Optional.ofNullable(historicalPlanStatisticsEntry.get().getHistoricalPlanStatisticsEntryInfo()))); + new HistoryBasedSourceInfo(entry.getKey().getHash(), inputTableStatistics, Optional.ofNullable(historicalPlanStatisticsEntry.get().getHistoricalPlanStatisticsEntryInfo()), estimateSizeUsingVariablesForHBO(session))); } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java b/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java index 2c02e64c85cde..12f7f96871402 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/HistoryBasedPlanStatisticsTracker.java @@ -59,12 +59,12 @@ import java.util.Set; import java.util.function.Supplier; +import static com.facebook.presto.SystemSessionProperties.estimateSizeUsingVariablesForHBO; import static com.facebook.presto.SystemSessionProperties.getHistoryBasedOptimizerTimeoutLimit; +import static com.facebook.presto.SystemSessionProperties.getQueryTypesEnabledForHBO; import static com.facebook.presto.SystemSessionProperties.trackHistoryBasedPlanStatisticsEnabled; import static com.facebook.presto.SystemSessionProperties.trackHistoryStatsFromFailedQuery; import static com.facebook.presto.SystemSessionProperties.trackPartialAggregationHistory; -import static 
com.facebook.presto.common.resourceGroups.QueryType.INSERT; -import static com.facebook.presto.common.resourceGroups.QueryType.SELECT; import static com.facebook.presto.cost.HistoricalPlanStatisticsUtil.updatePlanStatistics; import static com.facebook.presto.cost.HistoryBasedPlanStatisticsManager.historyBasedPlanCanonicalizationStrategyList; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION; @@ -78,7 +78,6 @@ public class HistoryBasedPlanStatisticsTracker { private static final Logger LOG = Logger.get(HistoryBasedPlanStatisticsTracker.class); - private static final Set ALLOWED_QUERY_TYPES = ImmutableSet.of(SELECT, INSERT); private final Supplier historyBasedPlanStatisticsProvider; private final HistoryBasedStatisticsCacheManager historyBasedStatisticsCacheManager; @@ -139,7 +138,8 @@ public Map getQueryStats(QueryIn } // Only update statistics for SELECT/INSERT queries - if (!queryInfo.getQueryType().isPresent() || !ALLOWED_QUERY_TYPES.contains(queryInfo.getQueryType().get())) { + List queryTypesEnabled = getQueryTypesEnabledForHBO(session); + if (!queryInfo.getQueryType().isPresent() || !queryTypesEnabled.contains(queryInfo.getQueryType().get())) { return ImmutableMap.of(); } @@ -242,7 +242,7 @@ else if (trackStatsForFailedQueries) { PlanStatisticsWithSourceInfo planStatsWithSourceInfo = new PlanStatisticsWithSourceInfo( planNode.getId(), newPlanNodeStats, - new HistoryBasedSourceInfo(Optional.of(hash), Optional.of(inputTableStatistics), Optional.of(historicalPlanStatisticsEntryInfo))); + new HistoryBasedSourceInfo(Optional.of(hash), Optional.of(inputTableStatistics), Optional.of(historicalPlanStatisticsEntryInfo), estimateSizeUsingVariablesForHBO(session))); planStatisticsMap.put(planNodeWithHash, planStatsWithSourceInfo); if (isAggregation(planNode, AggregationNode.Step.FINAL) && ((AggregationNode) planNode).getAggregationId().isPresent() && trackPartialAggregationHistory(session)) { diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java b/presto-main-base/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java index a25351ca7a88a..14ecc0aa325b0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java @@ -51,8 +51,7 @@ import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.type.TypeUtils; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Map; import java.util.OptionalDouble; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/StatsCalculatorModule.java b/presto-main-base/src/main/java/com/facebook/presto/cost/StatsCalculatorModule.java index 41716b3d1e687..7229d7fae23d9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/StatsCalculatorModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/StatsCalculatorModule.java @@ -14,13 +14,13 @@ package com.facebook.presto.cost; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.google.common.collect.ImmutableList; import com.google.inject.Binder; import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; - -import javax.inject.Singleton; +import jakarta.inject.Singleton; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; @@ -46,9 +46,10 @@ public static StatsCalculator createNewStatsCalculator( StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator, HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager, - FragmentStatsProvider fragmentStatsProvider) + FragmentStatsProvider fragmentStatsProvider, + ExpressionOptimizerManager expressionOptimizerManager) { - StatsCalculator delegate = 
createComposableStatsCalculator(metadata, scalarStatsCalculator, normalizer, filterStatsCalculator, fragmentStatsProvider); + StatsCalculator delegate = createComposableStatsCalculator(metadata, scalarStatsCalculator, normalizer, filterStatsCalculator, fragmentStatsProvider, expressionOptimizerManager); return historyBasedPlanStatisticsManager.getHistoryBasedPlanStatisticsCalculator(delegate); } @@ -57,14 +58,15 @@ public static ComposableStatsCalculator createComposableStatsCalculator( ScalarStatsCalculator scalarStatsCalculator, StatsNormalizer normalizer, FilterStatsCalculator filterStatsCalculator, - FragmentStatsProvider fragmentStatsProvider) + FragmentStatsProvider fragmentStatsProvider, + ExpressionOptimizerManager expressionOptimizerManager) { ImmutableList.Builder> rules = ImmutableList.builder(); rules.add(new OutputStatsRule()); rules.add(new TableScanStatsRule(metadata, normalizer)); rules.add(new SimpleFilterProjectSemiJoinStatsRule(normalizer, filterStatsCalculator, metadata.getFunctionAndTypeManager())); // this must be before FilterStatsRule rules.add(new FilterStatsRule(normalizer, filterStatsCalculator)); - rules.add(new ValuesStatsRule(metadata)); + rules.add(new ValuesStatsRule(metadata, expressionOptimizerManager)); rules.add(new LimitStatsRule(normalizer)); rules.add(new EnforceSingleRowStatsRule(normalizer)); rules.add(new ProjectStatsRule(scalarStatsCalculator, normalizer)); diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/TaskCountEstimator.java b/presto-main-base/src/main/java/com/facebook/presto/cost/TaskCountEstimator.java index 8e6228078068a..fab3516b7b857 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/TaskCountEstimator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/TaskCountEstimator.java @@ -17,8 +17,7 @@ import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.metadata.InternalNode; import 
com.facebook.presto.metadata.InternalNodeManager; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Set; import java.util.function.IntSupplier; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/UnnestStatsRule.java b/presto-main-base/src/main/java/com/facebook/presto/cost/UnnestStatsRule.java index 3ee92f0046d6c..bedd01008af6b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/UnnestStatsRule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/UnnestStatsRule.java @@ -16,10 +16,10 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.ComposableStatsCalculator.Rule; import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.plan.UnnestNode; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java b/presto-main-base/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java index a7c67756c4622..1187d3dd5136b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/cost/ValuesStatsRule.java @@ -19,7 +19,9 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.iterative.Lookup; @@ -32,10 +34,12 @@ import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static 
com.facebook.presto.cost.StatsUtil.toStatsRepresentation; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.FACT; -import static com.facebook.presto.sql.planner.RowExpressionInterpreter.evaluateConstantRowExpression; import static com.facebook.presto.sql.planner.plan.Patterns.values; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; public class ValuesStatsRule @@ -44,10 +48,12 @@ public class ValuesStatsRule private static final Pattern PATTERN = values(); private final Metadata metadata; + private final ExpressionOptimizerManager expressionOptimizerManager; - public ValuesStatsRule(Metadata metadata) + public ValuesStatsRule(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager) { - this.metadata = metadata; + this.metadata = requireNonNull(metadata, "metadata is null"); + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null"); } @Override @@ -82,7 +88,10 @@ private List getVariableValues(ValuesNode valuesNode, int symbolId, Sess } return valuesNode.getRows().stream() .map(row -> row.get(symbolId)) - .map(rowExpression -> evaluateConstantRowExpression(rowExpression, metadata.getFunctionAndTypeManager(), session.toConnectorSession())) + .map(rowExpression -> expressionOptimizerManager.getExpressionOptimizer(session.toConnectorSession()) + .optimize(rowExpression, EVALUATED, session.toConnectorSession(), i -> i)) + .peek(rowExpression -> verify(rowExpression instanceof ConstantExpression, "Expected constant expression, but got: %s", rowExpression)) + .map(rowExpression -> ((ConstantExpression) rowExpression).getValue()) .collect(toList()); } diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchExecutor.java b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchExecutor.java index 5a3f7d1816d15..859994e2d9ffe 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchExecutor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchExecutor.java @@ -19,13 +19,12 @@ import com.google.common.io.Closer; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.ListeningScheduledExecutorService; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.PreDestroy; -import javax.inject.Inject; - import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchInfo.java b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchInfo.java index 2b08af85aa5d9..6445da1f82351 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchInfo.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.dispatcher; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.ExecutionFailureInfo; -import io.airlift.units.Duration; import java.util.Optional; @@ -22,7 +22,6 @@ public class DispatchInfo { - private final Optional coordinatorLocation; private final Optional failureInfo; private final Duration elapsedTime; private final Duration waitingForPrerequisitesTime; @@ -30,48 +29,40 @@ public class DispatchInfo public static DispatchInfo waitingForPrerequisites(Duration elapsedTime, Duration waitingForPrerequisitesTime) { - return new 
DispatchInfo(Optional.empty(), Optional.empty(), elapsedTime, waitingForPrerequisitesTime, Optional.empty()); + return new DispatchInfo(Optional.empty(), elapsedTime, waitingForPrerequisitesTime, Optional.empty()); } public static DispatchInfo queued(Duration elapsedTime, Duration waitingForPrerequisitesTime, Duration queuedTime) { requireNonNull(queuedTime, "queuedTime is null"); - return new DispatchInfo(Optional.empty(), Optional.empty(), elapsedTime, waitingForPrerequisitesTime, Optional.of(queuedTime)); + return new DispatchInfo(Optional.empty(), elapsedTime, waitingForPrerequisitesTime, Optional.of(queuedTime)); } - public static DispatchInfo dispatched(CoordinatorLocation coordinatorLocation, Duration elapsedTime, Duration waitingForPrerequisitesTime, Duration queuedTime) + public static DispatchInfo dispatched(Duration elapsedTime, Duration waitingForPrerequisitesTime, Duration queuedTime) { - requireNonNull(coordinatorLocation, "coordinatorLocation is null"); requireNonNull(queuedTime, "queuedTime is null"); - return new DispatchInfo(Optional.of(coordinatorLocation), Optional.empty(), elapsedTime, waitingForPrerequisitesTime, Optional.of(queuedTime)); + return new DispatchInfo(Optional.empty(), elapsedTime, waitingForPrerequisitesTime, Optional.of(queuedTime)); } public static DispatchInfo failed(ExecutionFailureInfo failureInfo, Duration elapsedTime, Duration waitingForPrerequisitesTime, Duration queuedTime) { requireNonNull(failureInfo, "coordinatorLocation is null"); requireNonNull(queuedTime, "queuedTime is null"); - return new DispatchInfo(Optional.empty(), Optional.of(failureInfo), elapsedTime, waitingForPrerequisitesTime, Optional.of(queuedTime)); + return new DispatchInfo(Optional.of(failureInfo), elapsedTime, waitingForPrerequisitesTime, Optional.of(queuedTime)); } private DispatchInfo( - Optional coordinatorLocation, Optional failureInfo, Duration elapsedTime, Duration waitingForPrerequisitesTime, Optional queuedTime) { - this.coordinatorLocation 
= requireNonNull(coordinatorLocation, "coordinatorLocation is null"); this.failureInfo = requireNonNull(failureInfo, "failureInfo is null"); this.elapsedTime = requireNonNull(elapsedTime, "elapsedTime is null"); this.waitingForPrerequisitesTime = requireNonNull(waitingForPrerequisitesTime, "waitingForPrerequisitesTime is null"); this.queuedTime = requireNonNull(queuedTime, "queuedTime is null"); } - public Optional getCoordinatorLocation() - { - return coordinatorLocation; - } - public Optional getFailureInfo() { return failureInfo; diff --git a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchManager.java b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchManager.java index 0049ba2a3d614..25d923e3079a9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/DispatchManager.java @@ -41,13 +41,12 @@ import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.ListenableFuture; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; - import java.security.Principal; import java.util.List; import java.util.Optional; @@ -412,6 +411,11 @@ public Optional getDispatchInfo(QueryId queryId) }); } + public long getDurationUntilExpirationInMillis(QueryId queryId) + { + return queryTracker.getQuery(queryId).getDurationUntilExpirationInMillis(); + } + /** * Check if a given queryId exists in query tracker * diff --git a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/FailedDispatchQuery.java b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/FailedDispatchQuery.java index 
b3dbc66fd6f49..ed1818f916228 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/FailedDispatchQuery.java +++ b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/FailedDispatchQuery.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.dispatcher; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.ErrorCode; import com.facebook.presto.execution.ExecutionFailureInfo; @@ -23,7 +24,6 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.ResourceGroupQueryLimits; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import java.net.URI; import java.util.Optional; @@ -156,6 +156,12 @@ public long getCreateTimeInMillis() return basicQueryInfo.getQueryStats().getCreateTimeInMillis(); } + @Override + public Duration getQueuedTime() + { + return basicQueryInfo.getQueryStats().getQueuedTime(); + } + @Override public long getExecutionStartTimeInMillis() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/FailedLocalDispatchQueryFactory.java b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/FailedLocalDispatchQueryFactory.java index 0e518950e979d..4bb9ef56e66f6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/FailedLocalDispatchQueryFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/FailedLocalDispatchQueryFactory.java @@ -19,8 +19,7 @@ import com.facebook.presto.execution.LocationFactory; import com.facebook.presto.server.BasicQueryInfo; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Optional; import java.util.concurrent.ExecutorService; diff --git a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/NoOpQueryManager.java 
b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/NoOpQueryManager.java index 2e33580bc483e..14bd6b143709f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/dispatcher/NoOpQueryManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/dispatcher/NoOpQueryManager.java @@ -75,6 +75,12 @@ public QueryInfo getFullQueryInfo(QueryId queryId) throw new UnsupportedOperationException(); } + @Override + public long getDurationUntilExpirationInMillis(QueryId queryId) + { + throw new UnsupportedOperationException(); + } + @Override public Session getQuerySession(QueryId queryId) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/event/QueryMonitor.java b/presto-main-base/src/main/java/com/facebook/presto/event/QueryMonitor.java index 73f35985a3a97..5174a14bc980f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/event/QueryMonitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/event/QueryMonitor.java @@ -41,15 +41,17 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.operator.OperatorInfo; -import com.facebook.presto.operator.OperatorStats; +import com.facebook.presto.operator.OperatorInfoUnion; import com.facebook.presto.operator.TableFinishInfo; import com.facebook.presto.operator.TaskStats; import com.facebook.presto.server.BasicQueryInfo; import com.facebook.presto.server.BasicQueryStats; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.analyzer.UpdateInfo; import com.facebook.presto.spi.eventlistener.Column; import com.facebook.presto.spi.eventlistener.OperatorStatistics; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.facebook.presto.spi.eventlistener.QueryCompletedEvent; import com.facebook.presto.spi.eventlistener.QueryContext; import com.facebook.presto.spi.eventlistener.QueryCreatedEvent; @@ -74,16 +76,17 
@@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.stream.Collectors; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.SystemSessionProperties.logQueryPlansUsedInHistoryBasedOptimizer; import static com.facebook.presto.execution.QueryState.QUEUED; import static com.facebook.presto.execution.StageInfo.getAllStages; @@ -93,7 +96,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Double.NaN; import static java.lang.Math.max; import static java.lang.Math.toIntExact; @@ -223,6 +225,7 @@ public void queryImmediateFailureEvent(BasicQueryInfo queryInfo, ExecutionFailur ofMillis(0), ofMillis(0), ofMillis(0), + ofMillis(0), ofMillis(queryInfo.getQueryStats().getWaitingForPrerequisitesTime().toMillis()), ofMillis(queryInfo.getQueryStats().getQueuedTime().toMillis()), ofMillis(0), @@ -277,6 +280,7 @@ public void queryImmediateFailureEvent(BasicQueryInfo queryInfo, ExecutionFailur ImmutableSet.of(), Optional.empty(), ImmutableMap.of(), + Optional.empty(), Optional.empty())); logQueryTimeline(queryInfo); @@ -320,7 +324,8 @@ public void queryCompletedEvent(QueryInfo queryInfo) queryInfo.getWindowFunctions(), queryInfo.getPrestoSparkExecutionContext(), getPlanHash(queryInfo.getPlanCanonicalInfo(), historyBasedPlanStatisticsTracker.getStatsEquivalentPlanRootNode(queryInfo.getQueryId())), - Optional.of(queryInfo.getPlanIdNodeMap()))); + 
Optional.of(queryInfo.getPlanIdNodeMap()), + Optional.ofNullable(queryInfo.getUpdateInfo()).map(UpdateInfo::getUpdateObject))); logQueryTimeline(queryInfo); } @@ -363,7 +368,7 @@ private QueryMetadata createQueryMetadata(QueryInfo queryInfo) .map(stageId -> String.valueOf(stageId.getId())) .collect(toImmutableList()), queryInfo.getSession().getTraceToken(), - Optional.ofNullable(queryInfo.getUpdateType())); + Optional.ofNullable(queryInfo.getUpdateInfo()).map(UpdateInfo::getUpdateType)); } private List createOperatorStatistics(QueryInfo queryInfo) @@ -431,6 +436,7 @@ private QueryStatistics createQueryStatistics(QueryInfo queryInfo) ofMillis(queryStats.getTotalCpuTime().toMillis()), ofMillis(queryStats.getRetriedCpuTime().toMillis()), ofMillis(queryStats.getElapsedTime().toMillis()), + ofMillis(queryStats.getTotalScheduledTime().toMillis()), ofMillis(queryStats.getWaitingForPrerequisitesTime().toMillis()), ofMillis(queryStats.getQueuedTime().toMillis()), ofMillis(queryStats.getResourceWaitingTime().toMillis()), @@ -471,6 +477,7 @@ private QueryStatistics createQueryStatistics(BasicQueryInfo basicQueryInfo) ofMillis(queryStats.getTotalCpuTime().toMillis()), ofMillis(0), ofMillis(queryStats.getElapsedTime().toMillis()), + ofMillis(queryStats.getTotalScheduledTime().toMillis()), ofMillis(queryStats.getWaitingForPrerequisitesTime().toMillis()), ofMillis(queryStats.getQueuedTime().toMillis()), ofMillis(0), @@ -590,22 +597,35 @@ private static QueryIOMetadata getQueryIOMetadata(QueryInfo queryInfo) .collect(Collectors.toList()), input.getConnectorInfo(), input.getStatistics(), - input.getSerializedCommitOutput())); + input.getCommitOutput())); } Optional output = Optional.empty(); if (queryInfo.getOutput().isPresent()) { + // Check both info (JSON) and infoUnion (Thrift) fields for TableFinishInfo Optional tableFinishInfo = queryInfo.getQueryStats().getOperatorSummaries().stream() - .map(OperatorStats::getInfo) - .filter(TableFinishInfo.class::isInstance) - 
.map(TableFinishInfo.class::cast) + .map(operatorStats -> { + // First try the info field (JSON serialization) + OperatorInfo info = operatorStats.getInfo(); + if (info instanceof TableFinishInfo) { + return (TableFinishInfo) info; + } + // Fall back to infoUnion field (Thrift serialization) + OperatorInfoUnion infoUnion = operatorStats.getInfoUnion(); + if (infoUnion != null) { + return infoUnion.getTableFinishInfo(); + } + return null; + }) + .filter(Objects::nonNull) .findFirst(); - Optional> outputColumns = queryInfo.getOutput().get().getColumns() + Optional> outputColumnsMetadata = queryInfo.getOutput().get().getColumns() .map(columns -> columns.stream() - .map(column -> new Column( - column.getName(), - column.getType())) + .map(column -> new OutputColumnMetadata( + column.getColumnName(), + column.getColumnType(), + column.getSourceColumns())) .collect(toImmutableList())); output = Optional.of( @@ -615,8 +635,8 @@ private static QueryIOMetadata getQueryIOMetadata(QueryInfo queryInfo) queryInfo.getOutput().get().getTable(), tableFinishInfo.map(TableFinishInfo::getSerializedConnectorOutputMetadata), tableFinishInfo.map(TableFinishInfo::isJsonLengthLimitExceeded), - queryInfo.getOutput().get().getSerializedCommitOutput(), - outputColumns)); + outputColumnsMetadata, + queryInfo.getOutput().get().getCommitOutput())); } return new QueryIOMetadata(inputs.build(), output); diff --git a/presto-main-base/src/main/java/com/facebook/presto/event/QueryMonitorConfig.java b/presto-main-base/src/main/java/com/facebook/presto/event/QueryMonitorConfig.java index c43f208028f94..b48cdb9cf7b96 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/event/QueryMonitorConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/event/QueryMonitorConfig.java @@ -15,13 +15,12 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; 
-import io.airlift.units.Duration; -import io.airlift.units.MaxDataSize; -import io.airlift.units.MinDataSize; - -import javax.validation.constraints.NotNull; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MaxDataSize; +import com.facebook.airlift.units.MinDataSize; +import jakarta.validation.constraints.NotNull; import static java.util.concurrent.TimeUnit.MINUTES; diff --git a/presto-main-base/src/main/java/com/facebook/presto/event/QueryProgressMonitor.java b/presto-main-base/src/main/java/com/facebook/presto/event/QueryProgressMonitor.java index 8864a3da8a0fb..cef18a44ddd61 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/event/QueryProgressMonitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/event/QueryProgressMonitor.java @@ -13,14 +13,13 @@ */ package com.facebook.presto.event; +import com.facebook.airlift.units.Duration; import com.facebook.presto.dispatcher.DispatchManager; import com.facebook.presto.server.BasicQueryInfo; -import io.airlift.units.Duration; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicLong; diff --git a/presto-main-base/src/main/java/com/facebook/presto/event/SplitMonitor.java b/presto-main-base/src/main/java/com/facebook/presto/event/SplitMonitor.java index 10db412e3bd72..4fcb2b9a10324 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/event/SplitMonitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/event/SplitMonitor.java @@ -22,9 +22,8 @@ import 
com.facebook.presto.spi.eventlistener.SplitStatistics; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; - -import javax.annotation.Nullable; -import javax.inject.Inject; +import jakarta.annotation.Nullable; +import jakarta.inject.Inject; import java.time.Duration; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/eventlistener/EventListenerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/eventlistener/EventListenerConfig.java index 1d248f56e59a0..5346f2f7f2334 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/eventlistener/EventListenerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/eventlistener/EventListenerConfig.java @@ -17,8 +17,7 @@ import com.facebook.airlift.configuration.Config; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.io.File; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/AccessControlCheckerExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/AccessControlCheckerExecution.java index cf204e015785d..3ab62ab71b85f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/AccessControlCheckerExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/AccessControlCheckerExecution.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.analyzer.PreparedQuery; import com.facebook.presto.common.resourceGroups.QueryType; @@ -36,9 +37,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.inject.Inject; -import io.airlift.units.Duration; - -import 
javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Optional; import java.util.concurrent.ScheduledExecutorService; @@ -145,6 +144,12 @@ public long getCreateTimeInMillis() return stateMachine.getCreateTimeInMillis(); } + @Override + public Duration getQueuedTime() + { + return stateMachine.getQueuedTime(); + } + @Override public long getExecutionStartTimeInMillis() { @@ -266,7 +271,7 @@ private ListenableFuture executeTask() } stateMachine.beginColumnAccessPermissionChecking(); - checkAccessPermissions(queryAnalysis.getAccessControlReferences(), query); + checkAccessPermissions(queryAnalysis.getAccessControlReferences(), queryAnalysis.getViewDefinitionReferences(), query, getSession().getPreparedStatements(), getSession().getIdentity(), accessControl, getSession().getAccessControlContext()); stateMachine.endColumnAccessPermissionChecking(); return immediateFuture(null); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/AddColumnTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/AddColumnTask.java index 031e8e282f413..db9f05c402538 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/AddColumnTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/AddColumnTask.java @@ -14,7 +14,6 @@ package com.facebook.presto.execution; import com.facebook.presto.Session; -import com.facebook.presto.UnknownTypeException; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; @@ -25,10 +24,12 @@ import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.type.UnknownTypeException; import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.tree.AddColumn; import com.facebook.presto.sql.tree.ColumnDefinition; import 
com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; @@ -61,7 +62,7 @@ public String getName() @Override public ListenableFuture execute(AddColumn statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getName(), metadata); Optional tableHandle = metadata.getMetadataResolver(session).getTableHandle(tableName); if (!tableHandle.isPresent()) { if (!statement.isTableExists()) { @@ -114,8 +115,11 @@ public ListenableFuture execute(AddColumn statement, TransactionManager trans metadata, parameterExtractor(statement, parameters)); + Identifier columnIdentifier = element.getName(); + String name = metadata.normalizeIdentifier(session, tableName.getCatalogName(), columnIdentifier.getValue()); + ColumnMetadata column = ColumnMetadata.builder() - .setName(element.getName().getValue()) + .setName(name) .setType(type) .setNullable(element.isNullable()) .setComment(element.getComment().orElse(null)) diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/AddConstraintTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/AddConstraintTask.java index fa9754337e6a9..202e13c9df22f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/AddConstraintTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/AddConstraintTask.java @@ -89,7 +89,7 @@ public String getName() @Override public ListenableFuture execute(AddConstraint statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List 
parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName(), metadata); Optional tableHandle = metadata.getMetadataResolver(session).getTableHandle(tableName); if (!tableHandle.isPresent()) { if (!statement.isTableExists()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/AlterColumnNotNullTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/AlterColumnNotNullTask.java index f5bedb4880a1b..851546b9fbf59 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/AlterColumnNotNullTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/AlterColumnNotNullTask.java @@ -57,7 +57,7 @@ public String getName() @Override public ListenableFuture execute(AlterColumnNotNull statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTable()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTable(), metadata); Optional tableHandleOptional = metadata.getMetadataResolver(session).getTableHandle(tableName); if (!tableHandleOptional.isPresent()) { if (!statement.isTableExists()) { @@ -82,8 +82,7 @@ public ListenableFuture execute(AlterColumnNotNull statement, TransactionMana } TableHandle tableHandle = tableHandleOptional.get(); - - String column = statement.getColumn().getValueLowerCase(); + String column = metadata.normalizeIdentifier(session, tableName.getCatalogName(), statement.getColumn().getValue()); accessControl.checkCanAddConstraints(session.getRequiredTransactionId(), session.getIdentity(), 
session.getAccessControlContext(), tableName); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/AlterFunctionTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/AlterFunctionTask.java index 924ded74bf0da..57442b4fca918 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/AlterFunctionTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/AlterFunctionTask.java @@ -17,10 +17,16 @@ import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.function.AlterRoutineCharacteristics; import com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause; import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.security.AccessControlContext; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Analyzer; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.AlterFunction; @@ -29,14 +35,14 @@ import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; import java.util.Optional; import static com.facebook.presto.sql.analyzer.utils.ParameterUtils.parameterExtractor; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; import static com.google.common.collect.ImmutableList.toImmutableList; import static 
com.google.common.util.concurrent.Futures.immediateFuture; import static java.util.Objects.requireNonNull; @@ -68,8 +74,9 @@ public String explain(AlterFunction statement, List parameters) public ListenableFuture execute(AlterFunction statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { Map, Expression> parameterLookup = parameterExtractor(statement, parameters); - Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterLookup, warningCollector, query); - analyzer.analyze(statement); + Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterLookup, warningCollector, query, new ViewDefinitionReferences()); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + checkAccessPermissions(analysis.getAccessControlReferences(), analysis.getViewDefinitionReferences(), query, session.getPreparedStatements(), session.getIdentity(), accessControl, session.getAccessControlContext()); QualifiedObjectName functionName = metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver().qualifyObjectName(statement.getFunctionName()); AlterRoutineCharacteristics alterRoutineCharacteristics = new AlterRoutineCharacteristics( @@ -84,4 +91,7 @@ public ListenableFuture execute(AlterFunction statement, TransactionManager t alterRoutineCharacteristics); return immediateFuture(null); } + + @Override + public void queryPermissionCheck(AccessControl accessControl, Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) {} } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/BasicStageExecutionStats.java b/presto-main-base/src/main/java/com/facebook/presto/execution/BasicStageExecutionStats.java index 
e760019e61422..3e6c7f06e5f1d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/BasicStageExecutionStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/BasicStageExecutionStats.java @@ -13,16 +13,16 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.operator.BlockedReason; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import java.util.HashSet; import java.util.OptionalDouble; import java.util.Set; +import static com.facebook.airlift.units.Duration.succinctDuration; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.Duration.succinctDuration; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -32,6 +32,14 @@ public class BasicStageExecutionStats public static final BasicStageExecutionStats EMPTY_STAGE_STATS = new BasicStageExecutionStats( false, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, 0, 0, 0, @@ -60,6 +68,14 @@ public class BasicStageExecutionStats private final int queuedDrivers; private final int runningDrivers; private final int completedDrivers; + private final int totalNewDrivers; + private final int queuedNewDrivers; + private final int runningNewDrivers; + private final int completedNewDrivers; + private final int totalSplits; + private final int queuedSplits; + private final int runningSplits; + private final int completedSplits; private final long rawInputDataSizeInBytes; private final long rawInputPositions; private final double cumulativeUserMemory; @@ -81,6 +97,16 @@ public BasicStageExecutionStats( int runningDrivers, int completedDrivers, + int totalNewDrivers, + int queuedNewDrivers, + int runningNewDrivers, + int completedNewDrivers, + + int totalSplits, + int queuedSplits, + int runningSplits, + int completedSplits, + long rawInputDataSizeInBytes, long 
rawInputPositions, @@ -104,6 +130,14 @@ public BasicStageExecutionStats( this.queuedDrivers = queuedDrivers; this.runningDrivers = runningDrivers; this.completedDrivers = completedDrivers; + this.totalNewDrivers = totalNewDrivers; + this.queuedNewDrivers = queuedNewDrivers; + this.runningNewDrivers = runningNewDrivers; + this.completedNewDrivers = completedNewDrivers; + this.totalSplits = totalSplits; + this.queuedSplits = queuedSplits; + this.runningSplits = runningSplits; + this.completedSplits = completedSplits; checkArgument(rawInputDataSizeInBytes >= 0, "rawInputDataSizeInBytes is negative"); this.rawInputDataSizeInBytes = rawInputDataSizeInBytes; this.rawInputPositions = rawInputPositions; @@ -147,6 +181,46 @@ public int getCompletedDrivers() return completedDrivers; } + public int getTotalNewDrivers() + { + return totalNewDrivers; + } + + public int getQueuedNewDrivers() + { + return queuedNewDrivers; + } + + public int getRunningNewDrivers() + { + return runningNewDrivers; + } + + public int getCompletedNewDrivers() + { + return completedNewDrivers; + } + + public int getTotalSplits() + { + return totalSplits; + } + + public int getQueuedSplits() + { + return queuedSplits; + } + + public int getRunningSplits() + { + return runningSplits; + } + + public int getCompletedSplits() + { + return completedSplits; + } + public long getRawInputDataSizeInBytes() { return rawInputDataSizeInBytes; @@ -214,6 +288,16 @@ public static BasicStageExecutionStats aggregateBasicStageStats(Iterable execute(Call call, TransactionManager transactionMana throw new PrestoException(NOT_SUPPORTED, "Procedures cannot be called within a transaction (use autocommit mode)"); } - QualifiedObjectName procedureName = createQualifiedObjectName(session, call, call.getName()); + QualifiedObjectName procedureName = createQualifiedObjectName(session, call, call.getName(), metadata); ConnectorId connectorId = getConnectorIdOrThrow(session, metadata, procedureName.getCatalogName(), call, 
catalogError); - Procedure procedure = metadata.getProcedureRegistry().resolve(connectorId, toSchemaTableName(procedureName)); + BaseProcedure procedure = metadata.getProcedureRegistry().resolve(connectorId, toSchemaTableName(procedureName)); + accessControl.checkCanCallProcedure(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), procedureName); + Map, Expression> parameterLookup = parameterExtractor(call, parameters); + checkArgument(procedure instanceof Procedure, "Must call an inner procedure in CallTask"); + Procedure innerProcedure = (Procedure) procedure; + Object[] values = extractParameterValuesInOrder(call, innerProcedure, metadata, session, parameterLookup); + + // validate arguments + MethodType methodType = innerProcedure.getMethodHandle().type(); + for (int i = 0; i < innerProcedure.getArguments().size(); i++) { + if ((values[i] == null) && methodType.parameterType(i).isPrimitive()) { + String name = innerProcedure.getArguments().get(i).getName(); + throw new PrestoException(INVALID_PROCEDURE_ARGUMENT, "Procedure argument cannot be null: " + name); + } + } + + // insert session argument + List arguments = new ArrayList<>(); + Iterator valuesIterator = asList(values).iterator(); + for (Class type : methodType.parameterList()) { + if (ConnectorSession.class.isAssignableFrom(type)) { + arguments.add(session.toConnectorSession(connectorId)); + } + else { + arguments.add(valuesIterator.next()); + } + } + + try { + innerProcedure.getMethodHandle().invokeWithArguments(arguments); + } + catch (Throwable t) { + if (t instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throwIfInstanceOf(t, PrestoException.class); + throw new PrestoException(PROCEDURE_CALL_FAILED, t); + } + + return immediateFuture(null); + } + + public static Object[] extractParameterValuesInOrder(Call call, BaseProcedure procedure, Metadata metadata, Session session, Map, Expression> parameterLookup) + { // map declared 
argument names to positions Map positions = new HashMap<>(); for (int i = 0; i < procedure.getArguments().size(); i++) { @@ -121,9 +166,9 @@ else if (i < procedure.getArguments().size()) { } procedure.getArguments().stream() - .filter(Argument::isRequired) + .filter(BaseArgument::isRequired) .filter(argument -> !names.containsKey(argument.getName())) - .map(Argument::getName) + .map(BaseArgument::getName) .findFirst() .ifPresent(argument -> { throw new SemanticException(INVALID_PROCEDURE_ARGUMENTS, call, format("Required procedure argument '%s' is missing", argument)); @@ -131,11 +176,10 @@ else if (i < procedure.getArguments().size()) { // get argument values Object[] values = new Object[procedure.getArguments().size()]; - Map, Expression> parameterLookup = parameterExtractor(call, parameters); for (Entry entry : names.entrySet()) { CallArgument callArgument = entry.getValue(); int index = positions.get(entry.getKey()); - Argument argument = procedure.getArguments().get(index); + BaseArgument argument = procedure.getArguments().get(index); Expression expression = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(parameterLookup), callArgument.getValue()); Type type = metadata.getType(argument.getType()); @@ -148,7 +192,7 @@ else if (i < procedure.getArguments().size()) { // fill values with optional arguments defaults for (int i = 0; i < procedure.getArguments().size(); i++) { - Argument argument = procedure.getArguments().get(i); + BaseArgument argument = procedure.getArguments().get(i); if (!names.containsKey(argument.getName())) { verify(argument.isOptional()); @@ -156,39 +200,7 @@ else if (i < procedure.getArguments().size()) { } } - // validate arguments - MethodType methodType = procedure.getMethodHandle().type(); - for (int i = 0; i < procedure.getArguments().size(); i++) { - if ((values[i] == null) && methodType.parameterType(i).isPrimitive()) { - String name = procedure.getArguments().get(i).getName(); - throw new 
PrestoException(INVALID_PROCEDURE_ARGUMENT, "Procedure argument cannot be null: " + name); - } - } - - // insert session argument - List arguments = new ArrayList<>(); - Iterator valuesIterator = asList(values).iterator(); - for (Class type : methodType.parameterList()) { - if (ConnectorSession.class.isAssignableFrom(type)) { - arguments.add(session.toConnectorSession(connectorId)); - } - else { - arguments.add(valuesIterator.next()); - } - } - - try { - procedure.getMethodHandle().invokeWithArguments(arguments); - } - catch (Throwable t) { - if (t instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throwIfInstanceOf(t, PrestoException.class); - throw new PrestoException(PROCEDURE_CALL_FAILED, t); - } - - return immediateFuture(null); + return values; } private static Object toTypeObjectValue(Session session, Type type, Object value) diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/ClusterOverloadConfig.java b/presto-main-base/src/main/java/com/facebook/presto/execution/ClusterOverloadConfig.java new file mode 100644 index 0000000000000..a8ddf72707447 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/ClusterOverloadConfig.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.execution; + +import com.facebook.airlift.configuration.Config; + +public class ClusterOverloadConfig +{ + public static final String OVERLOAD_POLICY_CNT_BASED = "overload_worker_cnt_based_throttling"; + public static final String OVERLOAD_POLICY_PCT_BASED = "overload_worker_pct_based_throttling"; + private boolean clusterOverloadThrottlingEnabled; + private double allowedOverloadWorkersPct = 0.01; + private int allowedOverloadWorkersCnt; + private String overloadPolicyType = OVERLOAD_POLICY_CNT_BASED; + private int overloadCheckCacheTtlInSecs = 5; + + /** + * Gets the time-to-live for the cached cluster overload state. + * This determines how frequently the system will re-evaluate whether the cluster is overloaded. + * + * @return the cache TTL duration + */ + public int getOverloadCheckCacheTtlInSecs() + { + return overloadCheckCacheTtlInSecs; + } + + /** + * Gets the time-to-live for the cached cluster overload state. + * This determines how frequently the system will re-evaluate whether the cluster is overloaded. + * + * @return the cache TTL duration + */ + public int getOverloadCheckCacheTtlMillis() + { + return overloadCheckCacheTtlInSecs * 1000; + } + + /** + * Sets the time-to-live for the cached cluster overload state. 
+ * + * @param overloadCheckCacheTtlInSecs the cache TTL duration + * @return this for chaining + */ + @Config("cluster.overload-check-cache-ttl-secs") + public ClusterOverloadConfig setOverloadCheckCacheTtlInSecs(int overloadCheckCacheTtlInSecs) + { + this.overloadCheckCacheTtlInSecs = overloadCheckCacheTtlInSecs; + return this; + } + + @Config("cluster-overload.enable-throttling") + public ClusterOverloadConfig setClusterOverloadThrottlingEnabled(boolean clusterOverloadThrottlingEnabled) + { + this.clusterOverloadThrottlingEnabled = clusterOverloadThrottlingEnabled; + return this; + } + + public boolean isClusterOverloadThrottlingEnabled() + { + return this.clusterOverloadThrottlingEnabled; + } + + @Config("cluster-overload.allowed-overload-workers-pct") + public ClusterOverloadConfig setAllowedOverloadWorkersPct(Double allowedOverloadWorkersPct) + { + this.allowedOverloadWorkersPct = allowedOverloadWorkersPct; + return this; + } + + public double getAllowedOverloadWorkersPct() + { + return this.allowedOverloadWorkersPct; + } + + @Config("cluster-overload.allowed-overload-workers-cnt") + public ClusterOverloadConfig setAllowedOverloadWorkersCnt(int allowedOverloadWorkersCnt) + { + this.allowedOverloadWorkersCnt = allowedOverloadWorkersCnt; + return this; + } + + public double getAllowedOverloadWorkersCnt() + { + return this.allowedOverloadWorkersCnt; + } + + @Config("cluster-overload.overload-policy-type") + public ClusterOverloadConfig setOverloadPolicyType(String overloadPolicyType) + { + // validate + this.overloadPolicyType = overloadPolicyType; + return this; + } + + public String getOverloadPolicyType() + { + return this.overloadPolicyType; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/ClusterSizeMonitor.java b/presto-main-base/src/main/java/com/facebook/presto/execution/ClusterSizeMonitor.java index 6833912e31c18..708af458d5a67 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/execution/ClusterSizeMonitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/ClusterSizeMonitor.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.metadata.AllNodes; import com.facebook.presto.metadata.InternalNodeManager; @@ -22,12 +23,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.Duration; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.HashSet; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java index 318c8fed1246c..779d74f06ea33 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateFunctionTask.java @@ -18,13 +18,18 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.function.Parameter; import com.facebook.presto.spi.function.RoutineCharacteristics; import 
com.facebook.presto.spi.function.SqlFunctionHandle; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.security.AccessControlContext; +import com.facebook.presto.spi.security.Identity; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Analyzer; import com.facebook.presto.sql.parser.SqlParser; @@ -38,8 +43,7 @@ import com.facebook.presto.sql.tree.RoutineBody; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -51,6 +55,7 @@ import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; import static com.facebook.presto.sql.SqlFormatter.formatSql; import static com.facebook.presto.sql.analyzer.utils.ParameterUtils.parameterExtractor; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFuture; import static java.lang.String.format; @@ -84,8 +89,10 @@ public ListenableFuture execute(CreateFunction statement, TransactionManager { Map, Expression> parameterLookup = parameterExtractor(statement, parameters); Session session = stateMachine.getSession(); - Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterLookup, stateMachine.getWarningCollector(), query); - Analysis analysis = analyzer.analyze(statement); + Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterLookup, stateMachine.getWarningCollector(), query, new ViewDefinitionReferences()); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + 
checkAccessPermissions(analysis.getAccessControlReferences(), analysis.getViewDefinitionReferences(), query, session.getPreparedStatements(), session.getIdentity(), accessControl, session.getAccessControlContext()); + if (analysis.getFunctionHandles().values().stream() .anyMatch(SqlFunctionHandle.class::isInstance)) { throw new PrestoException(NOT_SUPPORTED, "Invoking a dynamically registered function in SQL function body is not supported"); @@ -102,6 +109,9 @@ public ListenableFuture execute(CreateFunction statement, TransactionManager return immediateFuture(null); } + @Override + public void queryPermissionCheck(AccessControl accessControl, Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) {} + private SqlInvokedFunction createSqlInvokedFunction(CreateFunction statement, Metadata metadata, Analysis analysis) { QualifiedObjectName functionName = statement.isTemporary() ? diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateMaterializedViewTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateMaterializedViewTask.java index dfaebac50954b..99ca276ae6813 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateMaterializedViewTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateMaterializedViewTask.java @@ -23,7 +23,9 @@ import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.security.ViewSecurity; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Analyzer; import com.facebook.presto.sql.analyzer.MaterializedViewColumnMappingExtractor; @@ -35,13 +37,14 @@ import com.facebook.presto.sql.tree.Parameter; 
import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; import java.util.Optional; +import static com.facebook.presto.SystemSessionProperties.getDefaultViewSecurityMode; +import static com.facebook.presto.SystemSessionProperties.isLegacyMaterializedViews; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; @@ -51,6 +54,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MATERIALIZED_VIEW_ALREADY_EXISTS; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.utils.ParameterUtils.parameterExtractor; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFuture; import static java.util.Objects.requireNonNull; @@ -75,7 +79,7 @@ public String getName() @Override public ListenableFuture execute(CreateMaterializedView statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName viewName = createQualifiedObjectName(session, statement, statement.getName()); + QualifiedObjectName viewName = createQualifiedObjectName(session, statement, statement.getName(), metadata); Optional viewHandle = metadata.getMetadataResolver(session).getTableHandle(viewName); if (viewHandle.isPresent()) { @@ -89,19 +93,20 @@ public ListenableFuture execute(CreateMaterializedView statement, Transaction 
accessControl.checkCanCreateView(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), viewName); Map, Expression> parameterLookup = parameterExtractor(statement, parameters); - Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterLookup, warningCollector, query); - Analysis analysis = analyzer.analyze(statement); + Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterLookup, warningCollector, query, new ViewDefinitionReferences()); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + checkAccessPermissions(analysis.getAccessControlReferences(), analysis.getViewDefinitionReferences(), query, session.getPreparedStatements(), session.getIdentity(), accessControl, session.getAccessControlContext()); List columnMetadata = analysis.getOutputDescriptor(statement.getQuery()) .getVisibleFields().stream() .map(field -> ColumnMetadata.builder() - .setName(field.getName().get()) + .setName(metadata.normalizeIdentifier(session, viewName.getCatalogName(), field.getName().get())) .setType(field.getType()) .build()) .collect(toImmutableList()); Map sqlProperties = mapFromProperties(statement.getProperties()); - Map properties = metadata.getTablePropertyManager().getProperties( + Map properties = metadata.getMaterializedViewPropertyManager().getProperties( getConnectorIdOrThrow(session, metadata, viewName.getCatalogName()), viewName.getCatalogName(), sqlProperties, @@ -119,7 +124,7 @@ public ListenableFuture execute(CreateMaterializedView statement, Transaction List baseTables = analysis.getTableNodes().stream() .map(table -> { - QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName(), metadata); if (!viewName.getCatalogName().equals(tableName.getCatalogName())) { throw new 
SemanticException( NOT_SUPPORTED, @@ -132,13 +137,32 @@ public ListenableFuture execute(CreateMaterializedView statement, Transaction .distinct() .collect(toImmutableList()); - MaterializedViewColumnMappingExtractor extractor = new MaterializedViewColumnMappingExtractor(analysis, session); + MaterializedViewColumnMappingExtractor extractor = new MaterializedViewColumnMappingExtractor(analysis, session, metadata); + + if (isLegacyMaterializedViews(session) && statement.getSecurity().isPresent()) { + throw new SemanticException( + NOT_SUPPORTED, + statement, + "SECURITY clause is not supported when legacy_materialized_views is enabled"); + } + + Optional owner = Optional.of(session.getUser()); + Optional securityMode; + if (isLegacyMaterializedViews(session)) { + // Legacy mode: no securityMode field, empty to preserve existing behavior + securityMode = Optional.empty(); + } + else { + securityMode = Optional.of(statement.getSecurity().orElse(getDefaultViewSecurityMode(session))); + } + MaterializedViewDefinition viewDefinition = new MaterializedViewDefinition( sql, viewName.getSchemaName(), viewName.getObjectName(), baseTables, - Optional.of(session.getUser()), + owner, + securityMode, extractor.getMaterializedViewColumnMappings(), extractor.getMaterializedViewDirectColumnMappings(), extractor.getBaseTablesOnOuterJoinSide(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateSchemaTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateSchemaTask.java index 4dc2b8d13386a..b0b7a42d263e5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateSchemaTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateSchemaTask.java @@ -53,7 +53,7 @@ public String explain(CreateSchema statement, List parameters) @Override public ListenableFuture execute(CreateSchema statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, 
List parameters, WarningCollector warningCollector, String query) { - CatalogSchemaName schema = createCatalogSchemaName(session, statement, Optional.of(statement.getSchemaName())); + CatalogSchemaName schema = createCatalogSchemaName(session, statement, Optional.of(statement.getSchemaName()), metadata); // TODO: validate that catalog exists diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateTableTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateTableTask.java index 60f2c8c083129..306bbde275f6e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateTableTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateTableTask.java @@ -14,7 +14,6 @@ package com.facebook.presto.execution; import com.facebook.presto.Session; -import com.facebook.presto.UnknownTypeException; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; @@ -28,6 +27,7 @@ import com.facebook.presto.spi.constraints.PrimaryKeyConstraint; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.type.UnknownTypeException; import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.tree.ColumnDefinition; import com.facebook.presto.sql.tree.ConstraintSpecification; @@ -48,7 +48,6 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -103,7 +102,7 @@ public ListenableFuture internalExecute(CreateTable statement, Metadata metad checkArgument(!statement.getElements().isEmpty(), "no columns for table"); Map, Expression> parameterLookup = parameterExtractor(statement, parameters); - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, 
statement.getName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getName(), metadata); Optional tableHandle = metadata.getMetadataResolver(session).getTableHandle(tableName); if (tableHandle.isPresent()) { if (!statement.isNotExists()) { @@ -121,7 +120,8 @@ public ListenableFuture internalExecute(CreateTable statement, Metadata metad for (TableElement element : statement.getElements()) { if (element instanceof ColumnDefinition) { ColumnDefinition column = (ColumnDefinition) element; - String name = column.getName().getValue().toLowerCase(Locale.ENGLISH); + String columnName = column.getName().getValue(); + String name = metadata.normalizeIdentifier(session, tableName.getCatalogName(), columnName); Type type; try { type = metadata.getType(parseTypeSignature(column.getType())); @@ -158,7 +158,7 @@ public ListenableFuture internalExecute(CreateTable statement, Metadata metad } else if (element instanceof LikeClause) { LikeClause likeClause = (LikeClause) element; - QualifiedObjectName likeTableName = createQualifiedObjectName(session, statement, likeClause.getTableName()); + QualifiedObjectName likeTableName = createQualifiedObjectName(session, statement, likeClause.getTableName(), metadata); getConnectorIdOrThrow(session, metadata, likeTableName.getCatalogName(), statement, likeTableCatalogError); if (!tableName.getCatalogName().equals(likeTableName.getCatalogName())) { throw new SemanticException(NOT_SUPPORTED, statement, "LIKE table across catalogs is not supported"); @@ -180,10 +180,10 @@ else if (element instanceof LikeClause) { likeTableMetadata.getColumns().stream() .filter(column -> !column.isHidden()) .forEach(column -> { - if (columns.containsKey(column.getName().toLowerCase(Locale.ENGLISH))) { + if (columns.containsKey(column.getName())) { throw new SemanticException(DUPLICATE_COLUMN_NAME, element, "Column name '%s' specified more than once", column.getName()); } - 
columns.put(column.getName().toLowerCase(Locale.ENGLISH), column); + columns.put(column.getName(), column); }); } else if (element instanceof ConstraintSpecification) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateTypeTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateTypeTask.java index 99dcf3a8682d3..ffd66102574be 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateTypeTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateTypeTask.java @@ -29,8 +29,7 @@ import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.Streams; import com.google.common.util.concurrent.ListenableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateViewTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateViewTask.java index 865010f4cc184..594854dba08e1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/CreateViewTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/CreateViewTask.java @@ -19,9 +19,14 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.security.AccessControlContext; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.security.ViewSecurity; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Analyzer; import 
com.facebook.presto.sql.analyzer.FeaturesConfig; @@ -31,19 +36,20 @@ import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; +import java.util.Map; import java.util.Optional; import static com.facebook.presto.SystemSessionProperties.getDefaultViewSecurityMode; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; import static com.facebook.presto.spi.analyzer.ViewDefinition.ViewColumn; +import static com.facebook.presto.spi.security.ViewSecurity.INVOKER; import static com.facebook.presto.sql.SqlFormatterUtil.getFormattedSql; import static com.facebook.presto.sql.analyzer.utils.ParameterUtils.parameterExtractor; -import static com.facebook.presto.sql.tree.CreateView.Security.INVOKER; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFuture; import static java.util.Objects.requireNonNull; @@ -80,7 +86,7 @@ public String explain(CreateView statement, List parameters) @Override public ListenableFuture execute(CreateView statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName name = createQualifiedObjectName(session, statement, statement.getName()); + QualifiedObjectName name = createQualifiedObjectName(session, statement, statement.getName(), metadata); accessControl.checkCanCreateView(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), name); @@ -95,14 +101,14 @@ public ListenableFuture execute(CreateView statement, 
TransactionManager tran List columnMetadata = columns.stream() .map(column -> ColumnMetadata.builder() - .setName(column.getName()) + .setName(metadata.normalizeIdentifier(session, name.getCatalogName(), column.getName())) .setType(column.getType()) .build()) .collect(toImmutableList()); ConnectorTableMetadata viewMetadata = new ConnectorTableMetadata(toSchemaTableName(name), columnMetadata); - CreateView.Security defaultViewSecurityMode = getDefaultViewSecurityMode(session); + ViewSecurity defaultViewSecurityMode = getDefaultViewSecurityMode(session); Optional owner = Optional.of(session.getUser()); if (statement.getSecurity().orElse(defaultViewSecurityMode) == INVOKER) { owner = Optional.empty(); @@ -115,9 +121,15 @@ public ListenableFuture execute(CreateView statement, TransactionManager tran return immediateFuture(null); } + @Override + public void queryPermissionCheck(AccessControl accessControl, Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) {} + private Analysis analyzeStatement(Statement statement, Session session, Metadata metadata, AccessControl accessControl, List parameters, WarningCollector warningCollector, String query) { - Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterExtractor(statement, parameters), warningCollector, query); - return analyzer.analyze(statement); + Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterExtractor(statement, parameters), warningCollector, query, new ViewDefinitionReferences()); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + checkAccessPermissions(analysis.getAccessControlReferences(), analysis.getViewDefinitionReferences(), query, session.getPreparedStatements(), session.getIdentity(), accessControl, session.getAccessControlContext()); + + return analysis; } } diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/execution/DDLDefinitionExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DDLDefinitionExecution.java index aeefdd2d62693..8fb65531f65fe 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/DDLDefinitionExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DDLDefinitionExecution.java @@ -24,9 +24,9 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.transaction.TransactionManager; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -62,6 +62,7 @@ private DDLDefinitionExecution( @Override protected ListenableFuture executeTask() { + task.queryPermissionCheck(accessControl, stateMachine.getSession().getIdentity(), stateMachine.getSession().getAccessControlContext(), query, stateMachine.getSession().getPreparedStatements(), ImmutableMap.of(), ImmutableMap.of()); return task.execute(statement, transactionManager, metadata, accessControl, stateMachine.getSession(), parameters, stateMachine.getWarningCollector(), query); } @@ -101,6 +102,8 @@ public DDLDefinitionExecution createQueryExecution( //TODO: PreparedQuery should be passed all the way to analyzer checkState(preparedQuery instanceof BuiltInQueryPreparer.BuiltInPreparedQuery, "Unsupported prepared query type: %s", preparedQuery.getClass().getSimpleName()); BuiltInQueryPreparer.BuiltInPreparedQuery builtInQueryPreparer = (BuiltInQueryPreparer.BuiltInPreparedQuery) preparedQuery; + Statement statement = builtInQueryPreparer.getStatement(); + stateMachine.setUpdateInfo(statement.getUpdateInfo()); return createDDLDefinitionExecution(builtInQueryPreparer.getStatement(), builtInQueryPreparer.getParameters(), stateMachine, slug, retryCount, query); } @@ -117,7 
+120,6 @@ private DDLDefinitionExecution createDDLDefinitionExecu DDLDefinitionTask task = (DDLDefinitionTask) tasks.get(statement.getClass()); checkArgument(task != null, "no task for statement: %s", statement.getClass().getSimpleName()); - stateMachine.setUpdateType(task.getName()); return new DDLDefinitionExecution<>(task, statement, slug, retryCount, transactionManager, metadata, accessControl, stateMachine, parameters, query); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java index 49c22ac63e363..abffda10d555c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DataDefinitionExecution.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.execution.StateMachine.StateChangeListener; import com.facebook.presto.memory.VersionedMemoryPoolId; @@ -28,9 +29,7 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Optional; @@ -124,6 +123,12 @@ public long getCreateTimeInMillis() return stateMachine.getCreateTimeInMillis(); } + @Override + public Duration getQueuedTime() + { + return stateMachine.getQueuedTime(); + } + @Override public long getExecutionStartTimeInMillis() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DataDefinitionTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DataDefinitionTask.java index a8b6aea7b7f08..ed806bfcf2959 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/execution/DataDefinitionTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DataDefinitionTask.java @@ -13,12 +13,19 @@ */ package com.facebook.presto.execution; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.security.AccessControlContext; +import com.facebook.presto.spi.security.Identity; import com.facebook.presto.sql.SqlFormatter; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Prepare; import com.facebook.presto.sql.tree.Statement; import java.util.List; +import java.util.Map; import java.util.Optional; public interface DataDefinitionTask @@ -33,4 +40,9 @@ default String explain(T statement, List parameters) return SqlFormatter.formatSql(statement, Optional.of(parameters)); } + + default void queryPermissionCheck(AccessControl accessControl, Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) + { + accessControl.checkQueryIntegrity(identity, context, query, preparedStatements, viewDefinitions, materializedViewDefinitions); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DropBranchTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DropBranchTask.java new file mode 100644 index 0000000000000..17d9cf2e4e3aa --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DropBranchTask.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.tree.DropBranch; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.transaction.TransactionManager; +import com.google.common.util.concurrent.ListenableFuture; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_TABLE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; +import static com.google.common.util.concurrent.Futures.immediateFuture; + +public class DropBranchTask + implements DDLDefinitionTask +{ + @Override + public String getName() + { + return "DROP BRANCH"; + } + + @Override + public ListenableFuture execute(DropBranch statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) + { + QualifiedObjectName tableName = 
createQualifiedObjectName(session, statement, statement.getTableName(), metadata); + Optional tableHandleOptional = metadata.getMetadataResolver(session).getTableHandle(tableName); + + if (!tableHandleOptional.isPresent()) { + if (!statement.isTableExists()) { + throw new SemanticException(MISSING_TABLE, statement, "Table '%s' does not exist", tableName); + } + return immediateFuture(null); + } + + Optional optionalMaterializedView = metadata.getMetadataResolver(session).getMaterializedView(tableName); + if (optionalMaterializedView.isPresent()) { + throw new SemanticException(NOT_SUPPORTED, statement, "'%s' is a materialized view, and drop branch is not supported", tableName); + } + + getConnectorIdOrThrow(session, metadata, tableName.getCatalogName()); + accessControl.checkCanDropBranch(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName); + + metadata.dropBranch(session, tableHandleOptional.get(), statement.getBranchName(), statement.isBranchExists()); + return immediateFuture(null); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DropColumnTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DropColumnTask.java index 1ced892f6e01f..1414786384b8e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/DropColumnTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DropColumnTask.java @@ -24,6 +24,7 @@ import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; @@ -48,7 +49,7 @@ public String getName() @Override public ListenableFuture execute(DropColumn statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session 
session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTable()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTable(), metadata); Optional tableHandleOptional = metadata.getMetadataResolver(session).getTableHandle(tableName); if (!tableHandleOptional.isPresent()) { @@ -67,7 +68,8 @@ public ListenableFuture execute(DropColumn statement, TransactionManager tran } TableHandle tableHandle = tableHandleOptional.get(); - String column = statement.getColumn().getValueLowerCase(); + Identifier columnIdentifier = statement.getColumn(); + String column = metadata.normalizeIdentifier(session, tableName.getCatalogName(), columnIdentifier.getValue()); accessControl.checkCanDropColumn(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DropConstraintTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DropConstraintTask.java index 4e0986d93a9df..562e4ceac1c74 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/DropConstraintTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DropConstraintTask.java @@ -49,7 +49,7 @@ public String getName() @Override public ListenableFuture execute(DropConstraint statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName(), metadata); Optional tableHandleOptional = metadata.getMetadataResolver(session).getTableHandle(tableName); if (!tableHandleOptional.isPresent()) { 
diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DropFunctionTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DropFunctionTask.java index a4af6fe23ce28..ed5d4224e9286 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/DropFunctionTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DropFunctionTask.java @@ -16,8 +16,14 @@ import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.security.AccessControlContext; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Analyzer; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.DropFunction; @@ -26,8 +32,7 @@ import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -35,6 +40,7 @@ import static com.facebook.presto.metadata.SessionFunctionHandle.SESSION_NAMESPACE; import static com.facebook.presto.sql.analyzer.utils.ParameterUtils.parameterExtractor; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFuture; import static java.lang.String.format; @@ -68,8 +74,10 @@ public String explain(DropFunction statement, List 
parameters) public ListenableFuture execute(DropFunction statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, QueryStateMachine stateMachine, List parameters, String query) { Map, Expression> parameterLookup = parameterExtractor(statement, parameters); - Analyzer analyzer = new Analyzer(stateMachine.getSession(), metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterLookup, stateMachine.getWarningCollector(), query); - analyzer.analyze(statement); + Analyzer analyzer = new Analyzer(stateMachine.getSession(), metadata, sqlParser, accessControl, Optional.empty(), parameters, parameterLookup, stateMachine.getWarningCollector(), query, new ViewDefinitionReferences()); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + checkAccessPermissions(analysis.getAccessControlReferences(), analysis.getViewDefinitionReferences(), query, stateMachine.getSession().getPreparedStatements(), stateMachine.getSession().getIdentity(), accessControl, stateMachine.getSession().getAccessControlContext()); + Optional> parameterTypes = statement.getParameterTypes().map(types -> types.stream().map(TypeSignature::parseTypeSignature).collect(toImmutableList())); if (statement.isTemporary()) { @@ -88,4 +96,7 @@ public ListenableFuture execute(DropFunction statement, TransactionManager tr return immediateFuture(null); } + + @Override + public void queryPermissionCheck(AccessControl accessControl, Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) {} } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DropMaterializedViewTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DropMaterializedViewTask.java index 2fa0ec25e5107..283a70d86bb99 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/DropMaterializedViewTask.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/DropMaterializedViewTask.java @@ -44,7 +44,7 @@ public String getName() @Override public ListenableFuture execute(DropMaterializedView statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName name = createQualifiedObjectName(session, statement, statement.getName()); + QualifiedObjectName name = createQualifiedObjectName(session, statement, statement.getName(), metadata); Optional view = metadata.getMetadataResolver(session).getMaterializedView(name); if (!view.isPresent()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DropSchemaTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DropSchemaTask.java index ff5084e7933d6..7df311897e67d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/DropSchemaTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DropSchemaTask.java @@ -55,7 +55,7 @@ public ListenableFuture execute(DropSchema statement, TransactionManager tran throw new PrestoException(NOT_SUPPORTED, "CASCADE is not yet supported for DROP SCHEMA"); } - CatalogSchemaName schema = createCatalogSchemaName(session, statement, Optional.of(statement.getSchemaName())); + CatalogSchemaName schema = createCatalogSchemaName(session, statement, Optional.of(statement.getSchemaName()), metadata); if (!metadata.getMetadataResolver(session).schemaExists(schema)) { if (!statement.isExists()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DropTableTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DropTableTask.java index e2931ab899559..463cd226e854a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/DropTableTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DropTableTask.java @@ 
-46,7 +46,7 @@ public String getName() @Override public ListenableFuture execute(DropTable statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName(), metadata); Optional tableHandle = metadata.getMetadataResolver(session).getTableHandle(tableName); if (!tableHandle.isPresent()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DropTagTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DropTagTask.java new file mode 100644 index 0000000000000..c49411d395bdd --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DropTagTask.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.execution; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.tree.DropTag; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.transaction.TransactionManager; +import com.google.common.util.concurrent.ListenableFuture; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_TABLE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; +import static com.google.common.util.concurrent.Futures.immediateFuture; + +public class DropTagTask + implements DDLDefinitionTask +{ + @Override + public String getName() + { + return "DROP TAG"; + } + + @Override + public ListenableFuture execute(DropTag statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) + { + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName(), metadata); + Optional tableHandleOptional = metadata.getMetadataResolver(session).getTableHandle(tableName); + + if (!tableHandleOptional.isPresent()) { + if (!statement.isTableExists()) { + throw new SemanticException(MISSING_TABLE, statement, "Table '%s' does not exist", tableName); + } + return immediateFuture(null); + } + + Optional optionalMaterializedView = 
metadata.getMetadataResolver(session).getMaterializedView(tableName); + if (optionalMaterializedView.isPresent()) { + throw new SemanticException(NOT_SUPPORTED, statement, "'%s' is a materialized view, and drop tag is not supported", tableName); + } + + getConnectorIdOrThrow(session, metadata, tableName.getCatalogName()); + accessControl.checkCanDropTag(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName); + + metadata.dropTag(session, tableHandleOptional.get(), statement.getTagName(), statement.isTagExists()); + return immediateFuture(null); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/DropViewTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/DropViewTask.java index 2e3d1099caf5b..65c64aa88349e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/DropViewTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/DropViewTask.java @@ -44,7 +44,7 @@ public String getName() @Override public ListenableFuture execute(DropView statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName name = createQualifiedObjectName(session, statement, statement.getName()); + QualifiedObjectName name = createQualifiedObjectName(session, statement, statement.getName(), metadata); Optional view = metadata.getMetadataResolver(session).getView(name); if (!view.isPresent()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/EagerPlanValidationExecutionMBean.java b/presto-main-base/src/main/java/com/facebook/presto/execution/EagerPlanValidationExecutionMBean.java index ad77325e42635..bccf77145a24f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/EagerPlanValidationExecutionMBean.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/EagerPlanValidationExecutionMBean.java @@ -14,11 +14,10 @@ package com.facebook.presto.execution; import com.facebook.airlift.concurrent.ThreadPoolExecutorMBean; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/ExecutionFailureInfo.java b/presto-main-base/src/main/java/com/facebook/presto/execution/ExecutionFailureInfo.java index c9da4a4ba2196..760d8f14761f2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/ExecutionFailureInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/ExecutionFailureInfo.java @@ -25,9 +25,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.regex.Matcher; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/ExplainAnalyzeContext.java b/presto-main-base/src/main/java/com/facebook/presto/execution/ExplainAnalyzeContext.java index bdc1540f8c9f8..92502c7c733e9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/ExplainAnalyzeContext.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/ExplainAnalyzeContext.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.execution; -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/Failure.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/Failure.java index 3ef50d6dbde9b..a2ed242f52557 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/Failure.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/Failure.java @@ -14,8 +14,7 @@ package com.facebook.presto.execution; import com.facebook.presto.common.ErrorCode; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/ForEagerPlanValidation.java b/presto-main-base/src/main/java/com/facebook/presto/execution/ForEagerPlanValidation.java index 5193341a73cc1..782ce58f6094e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/ForEagerPlanValidation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/ForEagerPlanValidation.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.execution; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/ForQueryExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/ForQueryExecution.java index 3dc08c52fa808..0a8c738cb6aae 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/ForQueryExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/ForQueryExecution.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.execution; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/ForTimeoutThread.java b/presto-main-base/src/main/java/com/facebook/presto/execution/ForTimeoutThread.java index ae0a584e14e49..faea2557b0294 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/execution/ForTimeoutThread.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/ForTimeoutThread.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.execution; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/FragmentResultCacheContext.java b/presto-main-base/src/main/java/com/facebook/presto/execution/FragmentResultCacheContext.java index 819ad6f391e84..984eecb4ab8fc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/FragmentResultCacheContext.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/FragmentResultCacheContext.java @@ -24,9 +24,9 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.sql.planner.CanonicalPlanFragment; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableSet; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/FutureStateChange.java b/presto-main-base/src/main/java/com/facebook/presto/execution/FutureStateChange.java index 1a3a3d75e69b4..28765f7d4fe23 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/FutureStateChange.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/FutureStateChange.java @@ -16,9 +16,8 @@ import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.concurrent.GuardedBy; 
-import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.HashSet; import java.util.Set; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/GrantTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/GrantTask.java index 084161aa75c23..952f875bb6075 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/GrantTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/GrantTask.java @@ -50,7 +50,7 @@ public String getName() @Override public ListenableFuture execute(Grant statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName(), metadata); Optional tableHandle = metadata.getMetadataResolver(session).getTableHandle(tableName); if (!tableHandle.isPresent()) { throw new SemanticException(MISSING_TABLE, statement, "Table '%s' does not exist", tableName); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/Input.java b/presto-main-base/src/main/java/com/facebook/presto/execution/Input.java index 667fda883f366..4ec3ebfdce47e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/Input.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/Input.java @@ -18,8 +18,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Objects; @@ -38,9 +37,9 @@ 
public final class Input private final Optional connectorInfo; private final Optional statistics; - // This field records the metastore commit info about this input. + // This field stores any connector specific metadata about the commit // E.g., the last data commit time for the input partitions. - private final String serializedCommitOutput; + private final Optional commitOutput; @JsonCreator public Input( @@ -50,7 +49,7 @@ public Input( @JsonProperty("connectorInfo") Optional connectorInfo, @JsonProperty("columns") List columns, @JsonProperty("statistics") Optional statistics, - @JsonProperty("serializedCommitOutput") String serializedCommitOutput) + @JsonProperty("commitOutput") Optional commitOutput) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.schema = requireNonNull(schema, "schema is null"); @@ -58,7 +57,7 @@ public Input( this.connectorInfo = requireNonNull(connectorInfo, "connectorInfo is null"); this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); this.statistics = requireNonNull(statistics, "table statistics is null"); - this.serializedCommitOutput = requireNonNull(serializedCommitOutput, "serializedCommitOutput is null"); + this.commitOutput = requireNonNull(commitOutput, "commitOutput is null"); } @JsonProperty @@ -98,9 +97,9 @@ public Optional getStatistics() } @JsonProperty - public String getSerializedCommitOutput() + public Optional getCommitOutput() { - return serializedCommitOutput; + return commitOutput; } @Override @@ -119,13 +118,13 @@ public boolean equals(Object o) Objects.equals(columns, input.columns) && Objects.equals(connectorInfo, input.connectorInfo) && Objects.equals(statistics, input.statistics) && - Objects.equals(serializedCommitOutput, input.serializedCommitOutput); + Objects.equals(commitOutput, input.commitOutput); } @Override public int hashCode() { - return Objects.hash(connectorId, schema, table, columns, connectorInfo, statistics, serializedCommitOutput); + return 
Objects.hash(connectorId, schema, table, columns, connectorInfo, statistics, commitOutput); } @Override @@ -137,7 +136,7 @@ public String toString() .addValue(table) .addValue(columns) .addValue(statistics) - .addValue(serializedCommitOutput) + .addValue(commitOutput) .toString(); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/Location.java b/presto-main-base/src/main/java/com/facebook/presto/execution/Location.java index ed5e59718992c..fc93ee63c9458 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/Location.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/Location.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.execution; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -20,17 +23,20 @@ import static java.util.Objects.requireNonNull; +@ThriftStruct public class Location { private final String location; @JsonCreator + @ThriftConstructor public Location(@JsonProperty("location") String location) { this.location = requireNonNull(location, "location is null"); } @JsonProperty + @ThriftField(1) public String getLocation() { return location; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/ManagedQueryExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/ManagedQueryExecution.java index 31caf9ae91280..217e12c197601 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/ManagedQueryExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/ManagedQueryExecution.java @@ -13,12 +13,12 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.ErrorCode; import 
com.facebook.presto.execution.StateMachine.StateChangeListener; import com.facebook.presto.server.BasicQueryInfo; import com.facebook.presto.spi.resourceGroups.ResourceGroupQueryLimits; -import io.airlift.units.Duration; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/NodeResourceStatusConfig.java b/presto-main-base/src/main/java/com/facebook/presto/execution/NodeResourceStatusConfig.java index d58f0a0caa114..e7be9804fff31 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/NodeResourceStatusConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/NodeResourceStatusConfig.java @@ -15,8 +15,7 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; - -import javax.validation.constraints.Min; +import jakarta.validation.constraints.Min; public class NodeResourceStatusConfig { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/NodeTaskMap.java b/presto-main-base/src/main/java/com/facebook/presto/execution/NodeTaskMap.java index ba015428bfe52..5213f35210c40 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/NodeTaskMap.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/NodeTaskMap.java @@ -18,9 +18,8 @@ import com.facebook.presto.util.FinalizerService; import com.google.common.collect.Sets; import com.google.common.util.concurrent.AtomicDouble; - -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/Output.java b/presto-main-base/src/main/java/com/facebook/presto/execution/Output.java index 74f633b871407..3d0d1621d0546 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/Output.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/Output.java @@ -14,11 +14,11 @@ package com.facebook.presto.execution; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Objects; @@ -32,22 +32,22 @@ public final class Output private final ConnectorId connectorId; private final String schema; private final String table; - private final String serializedCommitOutput; - private final Optional> columns; + private final Optional> columns; + private final Optional commitOutput; @JsonCreator public Output( @JsonProperty("connectorId") ConnectorId connectorId, @JsonProperty("schema") String schema, @JsonProperty("table") String table, - @JsonProperty("serializedCommitOutput") String serializedCommitOutput, - @JsonProperty("columns") Optional> columns) + @JsonProperty("columns") Optional> columns, + @JsonProperty("commitOutput") Optional commitOutput) { this.connectorId = requireNonNull(connectorId, "connectorId is null"); this.schema = requireNonNull(schema, "schema is null"); this.table = requireNonNull(table, "table is null"); - this.serializedCommitOutput = requireNonNull(serializedCommitOutput, "connectorCommitOutput is null"); this.columns = columns.map(ImmutableList::copyOf); + this.commitOutput = requireNonNull(commitOutput, "commitOutput is null"); } @JsonProperty @@ -69,15 +69,15 @@ public String getTable() } @JsonProperty - public String getSerializedCommitOutput() + public Optional> getColumns() { - return serializedCommitOutput; + return columns; } @JsonProperty - public Optional> getColumns() + public Optional getCommitOutput() { - return columns; + return commitOutput; } @Override @@ -93,13 
+93,13 @@ public boolean equals(Object o) return Objects.equals(connectorId, output.connectorId) && Objects.equals(schema, output.schema) && Objects.equals(table, output.table) && - Objects.equals(serializedCommitOutput, output.serializedCommitOutput) && - Objects.equals(columns, output.columns); + Objects.equals(columns, output.columns) && + Objects.equals(commitOutput, output.commitOutput); } @Override public int hashCode() { - return Objects.hash(connectorId, schema, table, serializedCommitOutput, columns); + return Objects.hash(connectorId, schema, table, columns, commitOutput); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/PartialResultQueryManager.java b/presto-main-base/src/main/java/com/facebook/presto/execution/PartialResultQueryManager.java index 4409340deabfc..f125a9be24807 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/PartialResultQueryManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/PartialResultQueryManager.java @@ -15,8 +15,7 @@ import com.facebook.presto.execution.scheduler.PartialResultQueryTaskTracker; import com.google.inject.Inject; - -import javax.annotation.PreDestroy; +import jakarta.annotation.PreDestroy; import java.util.concurrent.PriorityBlockingQueue; import java.util.concurrent.ScheduledExecutorService; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/PrepareTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/PrepareTask.java index c9efb3e866147..c434c553cc7c1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/PrepareTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/PrepareTask.java @@ -24,8 +24,7 @@ import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import 
java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryExecution.java index af23ce6bc9791..35c9db00bcfba 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryExecution.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.analyzer.PreparedQuery; import com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.common.type.Type; @@ -28,7 +29,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import java.net.URI; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryExecutionMBean.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryExecutionMBean.java index 3faa63f2c906d..b8b765a6bff57 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryExecutionMBean.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryExecutionMBean.java @@ -14,11 +14,10 @@ package com.facebook.presto.execution; import com.facebook.airlift.concurrent.ThreadPoolExecutorMBean; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryIdGenerator.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryIdGenerator.java index 425e1af8fd18a..ecd8f20d6c45d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryIdGenerator.java 
+++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryIdGenerator.java @@ -18,11 +18,10 @@ import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Chars; import com.google.common.util.concurrent.Uninterruptibles; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; -import javax.annotation.concurrent.GuardedBy; - import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryInfo.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryInfo.java index 0a55cb79ea4a0..7ad6d4159ae6b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryInfo.java @@ -21,6 +21,7 @@ import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.analyzer.UpdateInfo; import com.facebook.presto.spi.eventlistener.CTEInformation; import com.facebook.presto.spi.eventlistener.PlanOptimizerInformation; import com.facebook.presto.spi.function.SqlFunctionId; @@ -37,9 +38,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.net.URI; import java.util.List; @@ -77,7 +77,7 @@ public class QueryInfo private final Set deallocatedPreparedStatements; private final Optional startedTransactionId; private final boolean clearTransactionId; - private final String updateType; + private final UpdateInfo updateInfo; private final Optional outputStage; private final 
ExecutionFailureInfo failureInfo; private final ErrorType errorType; @@ -127,7 +127,7 @@ public QueryInfo( @JsonProperty("deallocatedPreparedStatements") Set deallocatedPreparedStatements, @JsonProperty("startedTransactionId") Optional startedTransactionId, @JsonProperty("clearTransactionId") boolean clearTransactionId, - @JsonProperty("updateType") String updateType, + @JsonProperty("updateInfo") UpdateInfo updateInfo, @JsonProperty("outputStage") Optional outputStage, @JsonProperty("failureInfo") ExecutionFailureInfo failureInfo, @JsonProperty("errorCode") ErrorCode errorCode, @@ -206,7 +206,7 @@ public QueryInfo( this.deallocatedPreparedStatements = ImmutableSet.copyOf(deallocatedPreparedStatements); this.startedTransactionId = startedTransactionId; this.clearTransactionId = clearTransactionId; - this.updateType = updateType; + this.updateInfo = updateInfo; this.outputStage = outputStage; this.failureInfo = failureInfo; this.errorType = errorCode == null ? null : errorCode.getType(); @@ -363,9 +363,9 @@ public boolean isClearTransactionId() @Nullable @JsonProperty - public String getUpdateType() + public UpdateInfo getUpdateInfo() { - return updateType; + return updateInfo; } @JsonProperty diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryLimit.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryLimit.java index f638dbdf9251b..690c205631d60 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryLimit.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryLimit.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.execution; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; import java.util.Arrays; import java.util.Comparator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManager.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManager.java index f5bbec8b358ae..b467275fd1d4a 100644 
--- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManager.java @@ -65,6 +65,12 @@ BasicQueryInfo getQueryInfo(QueryId queryId) QueryInfo getFullQueryInfo(QueryId queryId) throws NoSuchElementException; + /** + * @throws NoSuchElementException if query does not exist + */ + long getDurationUntilExpirationInMillis(QueryId queryId) + throws NoSuchElementException; + /** * @throws NoSuchElementException if query does not exist */ diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManagerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManagerConfig.java index 5d73177e47a3f..838f69f9fad87 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManagerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManagerConfig.java @@ -17,21 +17,20 @@ import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.DefunctConfig; import com.facebook.airlift.configuration.LegacyConfig; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDataSize; +import com.facebook.airlift.units.MinDuration; import com.facebook.presto.connector.system.GlobalSystemConnector; import com.facebook.presto.spi.api.Experimental; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; -import io.airlift.units.MinDataSize; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Max; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.util.concurrent.TimeUnit; -import static io.airlift.units.DataSize.Unit.PETABYTE; -import static 
io.airlift.units.DataSize.Unit.TERABYTE; +import static com.facebook.airlift.units.DataSize.Unit.PETABYTE; +import static com.facebook.airlift.units.DataSize.Unit.TERABYTE; import static java.util.concurrent.TimeUnit.MINUTES; @DefunctConfig({ @@ -74,6 +73,7 @@ public class QueryManagerConfig private String queryExecutionPolicy = "all-at-once"; private Duration queryMaxRunTime = new Duration(100, TimeUnit.DAYS); + private Duration queryMaxQueuedTime = new Duration(100, TimeUnit.DAYS); private Duration queryMaxExecutionTime = new Duration(100, TimeUnit.DAYS); private Duration queryMaxCpuTime = new Duration(1_000_000_000, TimeUnit.DAYS); @@ -102,6 +102,10 @@ public class QueryManagerConfig private int minColumnarEncodingChannelsToPreferRowWiseEncoding = 1000; + private int maxQueryAdmissionsPerSecond = Integer.MAX_VALUE; + + private int minRunningQueriesForPacing = 30; + @Min(1) public int getScheduleSplitBatchSize() { @@ -431,6 +435,19 @@ public QueryManagerConfig setQueryMaxRunTime(Duration queryMaxRunTime) return this; } + @NotNull + public Duration getQueryMaxQueuedTime() + { + return queryMaxQueuedTime; + } + + @Config("query.max-queued-time") + public QueryManagerConfig setQueryMaxQueuedTime(Duration queryMaxQueuedTime) + { + this.queryMaxQueuedTime = queryMaxQueuedTime; + return this; + } + @NotNull public Duration getQueryMaxExecutionTime() { @@ -753,6 +770,34 @@ public QueryManagerConfig setMinColumnarEncodingChannelsToPreferRowWiseEncoding( return this; } + @Min(1) + public int getMaxQueryAdmissionsPerSecond() + { + return maxQueryAdmissionsPerSecond; + } + + @Config("query-manager.query-pacing.max-queries-per-second") + @ConfigDescription("Maximum number of queries that can be admitted per second globally for admission pacing. Default is unlimited (Integer.MAX_VALUE). 
Set to a lower value (e.g., 1) to pace query admissions to one per second.") + public QueryManagerConfig setMaxQueryAdmissionsPerSecond(int maxQueryAdmissionsPerSecond) + { + this.maxQueryAdmissionsPerSecond = maxQueryAdmissionsPerSecond; + return this; + } + + @Min(0) + public int getMinRunningQueriesForPacing() + { + return minRunningQueriesForPacing; + } + + @Config("query-manager.query-pacing.min-running-queries") + @ConfigDescription("Minimum number of running queries before admission pacing is applied. Default is 30. Set to a higher value to only pace when cluster is busy.") + public QueryManagerConfig setMinRunningQueriesForPacing(int minRunningQueriesForPacing) + { + this.minRunningQueriesForPacing = minRunningQueriesForPacing; + return this; + } + public enum ExchangeMaterializationStrategy { NONE, diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManagerStats.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManagerStats.java index d7cbcfdcb7d2f..79a7385a4e916 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManagerStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryManagerStats.java @@ -19,11 +19,10 @@ import com.facebook.presto.dispatcher.DispatchQuery; import com.facebook.presto.execution.StateMachine.StateChangeListener; import com.facebook.presto.server.BasicQueryInfo; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.concurrent.GuardedBy; - import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateMachine.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateMachine.java index 01830e0f733e6..8ee8e42ce447a 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateMachine.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateMachine.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.ErrorCode; import com.facebook.presto.common.resourceGroups.QueryType; @@ -32,6 +33,7 @@ import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.UpdateInfo; import com.facebook.presto.spi.connector.ConnectorCommitHandle; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; @@ -56,10 +58,9 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.net.URI; import java.util.ArrayList; @@ -77,6 +78,7 @@ import java.util.function.Consumer; import java.util.function.Predicate; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.execution.BasicStageExecutionStats.EMPTY_STAGE_STATS; import static com.facebook.presto.execution.QueryState.DISPATCHING; import static com.facebook.presto.execution.QueryState.FINISHED; @@ -101,7 +103,6 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static 
io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -155,7 +156,7 @@ public class QueryStateMachine private final AtomicReference startedTransactionId = new AtomicReference<>(); private final AtomicBoolean clearTransactionId = new AtomicBoolean(); - private final AtomicReference updateType = new AtomicReference<>(); + private final AtomicReference updateInfo = new AtomicReference<>(); private final AtomicReference failureCause = new AtomicReference<>(); @@ -388,6 +389,16 @@ public BasicQueryInfo getBasicQueryInfo(Optional rootS stageStats.getRunningDrivers(), stageStats.getCompletedDrivers(), + stageStats.getTotalNewDrivers(), + stageStats.getQueuedNewDrivers(), + stageStats.getRunningNewDrivers(), + stageStats.getCompletedNewDrivers(), + + stageStats.getTotalSplits(), + stageStats.getQueuedSplits(), + stageStats.getRunningSplits(), + stageStats.getCompletedSplits(), + succinctBytes(stageStats.getRawInputDataSizeInBytes()), stageStats.getRawInputPositions(), @@ -482,7 +493,7 @@ public QueryInfo getQueryInfo(Optional rootStage) deallocatedPreparedStatements, Optional.ofNullable(startedTransactionId.get()), clearTransactionId.get(), - updateType.get(), + updateInfo.get(), rootStage, failureCause, errorCode, @@ -606,8 +617,8 @@ private void addSerializedCommitOutputToOutput(ConnectorCommitHandle commitHandl outputInfo.getConnectorId(), outputInfo.getSchema(), outputInfo.getTable(), - commitHandle.getSerializedCommitOutputForWrite(table), - outputInfo.getColumns()))); + outputInfo.getColumns(), + Optional.of(commitHandle.getCommitOutputForWrite(table))))); } private void addSerializedCommitOutputToInputs(List commitHandles) @@ -638,7 +649,7 @@ private Input attachSerializedCommitOutput(Input input, List commitHandles) input.getConnectorInfo(), input.getColumns(), input.getStatistics(), - commitHandle.getSerializedCommitOutputForRead(table)); + 
Optional.of(commitHandle.getCommitOutputForRead(table))); } } return input; @@ -753,9 +764,9 @@ public void clearTransactionId() clearTransactionId.set(true); } - public void setUpdateType(String updateType) + public void setUpdateInfo(UpdateInfo updateInfo) { - this.updateType.set(updateType); + this.updateInfo.set(updateInfo); } public void setExpandedQuery(Optional expandedQuery) @@ -1014,6 +1025,11 @@ public long getCreateTimeInMillis() return queryStateTimer.getCreateTimeInMillis(); } + public Duration getQueuedTime() + { + return queryStateTimer.getQueuedTime(); + } + public long getExecutionStartTimeInMillis() { return queryStateTimer.getExecutionStartTimeInMillis(); @@ -1122,7 +1138,7 @@ private static QueryInfo pruneFinishedQueryInfo(QueryInfo queryInfo, Set queryInfo.getDeallocatedPreparedStatements(), queryInfo.getStartedTransactionId(), queryInfo.isClearTransactionId(), - queryInfo.getUpdateType(), + queryInfo.getUpdateInfo(), queryInfo.getOutputStage().map(QueryStateMachine::pruneStatsFromStageInfo), queryInfo.getFailureInfo(), queryInfo.getErrorCode(), @@ -1161,7 +1177,7 @@ private static Set pruneInputHistograms(Set inputs) .setHistogram(Optional.empty()) .build()))) .build()), - input.getSerializedCommitOutput())) + input.getCommitOutput())) .collect(toImmutableSet()); } @@ -1196,6 +1212,7 @@ private static StageInfo pruneStatsFromStageInfo(StageInfo stage) plan.getPartitioning(), plan.getTableScanSchedulingOrder(), plan.getPartitioningScheme(), + plan.getOutputOrderingScheme(), plan.getStageExecutionDescriptor(), plan.isOutputTableWriterFragment(), plan.getStatsAndCosts().map(QueryStateMachine::pruneHistogramsFromStatsAndCosts), @@ -1240,7 +1257,7 @@ private static QueryInfo pruneExpiredQueryInfo(QueryInfo queryInfo, VersionedMem queryInfo.getDeallocatedPreparedStatements(), queryInfo.getStartedTransactionId(), queryInfo.isClearTransactionId(), - queryInfo.getUpdateType(), + queryInfo.getUpdateInfo(), prunedOutputStage, queryInfo.getFailureInfo(), 
queryInfo.getErrorCode(), @@ -1302,6 +1319,14 @@ private static QueryStats pruneQueryStats(QueryStats queryStats) queryStats.getRunningDrivers(), queryStats.getBlockedDrivers(), queryStats.getCompletedDrivers(), + queryStats.getTotalNewDrivers(), + queryStats.getQueuedNewDrivers(), + queryStats.getRunningNewDrivers(), + queryStats.getCompletedNewDrivers(), + queryStats.getTotalSplits(), + queryStats.getQueuedSplits(), + queryStats.getRunningSplits(), + queryStats.getCompletedSplits(), queryStats.getCumulativeUserMemory(), queryStats.getCumulativeTotalMemory(), queryStats.getUserMemoryReservation(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateTimer.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateTimer.java index 36da1ab1d57ad..7629b7f4d52c3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateTimer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStateTimer.java @@ -13,12 +13,12 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.google.common.base.Ticker; -import io.airlift.units.Duration; import java.util.concurrent.atomic.AtomicReference; -import static io.airlift.units.Duration.succinctNanos; +import static com.facebook.airlift.units.Duration.succinctNanos; import static java.lang.Math.max; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStats.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStats.java index 59eb42db30f35..1af44c2a98eb6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryStats.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.DataSize; +import 
com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.operator.BlockedReason; import com.facebook.presto.operator.ExchangeOperator; @@ -27,22 +29,19 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import jakarta.annotation.Nullable; import org.joda.time.DateTime; -import javax.annotation.Nullable; - import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.OptionalDouble; import java.util.Set; +import static com.facebook.airlift.units.DataSize.succinctBytes; +import static com.facebook.airlift.units.Duration.succinctDuration; import static com.facebook.presto.util.DateTimeUtils.toTimeStampInMillis; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.DataSize.succinctBytes; -import static io.airlift.units.Duration.succinctDuration; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -78,6 +77,16 @@ public class QueryStats private final int blockedDrivers; private final int completedDrivers; + private final int totalNewDrivers; + private final int queuedNewDrivers; + private final int runningNewDrivers; + private final int completedNewDrivers; + + private final int totalSplits; + private final int queuedSplits; + private final int runningSplits; + private final int completedSplits; + private final double cumulativeUserMemory; private final double cumulativeTotalMemory; private final DataSize userMemoryReservation; @@ -152,6 +161,16 @@ public QueryStats( int blockedDrivers, int completedDrivers, + int totalNewDrivers, + int queuedNewDrivers, + int runningNewDrivers, + int completedNewDrivers, + + int totalSplits, + int queuedSplits, + int runningSplits, + int 
completedSplits, + double cumulativeUserMemory, double cumulativeTotalMemory, DataSize userMemoryReservation, @@ -234,6 +253,22 @@ public QueryStats( this.blockedDrivers = blockedDrivers; checkArgument(completedDrivers >= 0, "completedDrivers is negative"); this.completedDrivers = completedDrivers; + checkArgument(totalNewDrivers >= 0, "totalNewDrivers is negative"); + this.totalNewDrivers = totalNewDrivers; + checkArgument(queuedNewDrivers >= 0, "queuedNewDrivers is negative"); + this.queuedNewDrivers = queuedNewDrivers; + checkArgument(runningNewDrivers >= 0, "runningNewDrivers is negative"); + this.runningNewDrivers = runningNewDrivers; + checkArgument(completedNewDrivers >= 0, "completedNewDrivers is negative"); + this.completedNewDrivers = completedNewDrivers; + checkArgument(totalSplits >= 0, "totalSplits is negative"); + this.totalSplits = totalSplits; + checkArgument(queuedSplits >= 0, "queuedSplits is negative"); + this.queuedSplits = queuedSplits; + checkArgument(runningSplits >= 0, "runningSplits is negative"); + this.runningSplits = runningSplits; + checkArgument(completedSplits >= 0, "completedSplits is negative"); + this.completedSplits = completedSplits; checkArgument(cumulativeUserMemory >= 0, "cumulativeUserMemory is negative"); this.cumulativeUserMemory = cumulativeUserMemory; checkArgument(cumulativeTotalMemory >= 0, "cumulativeTotalMemory is negative"); @@ -314,6 +349,16 @@ public QueryStats( @JsonProperty("blockedDrivers") int blockedDrivers, @JsonProperty("completedDrivers") int completedDrivers, + @JsonProperty("totalNewDrivers") int totalNewDrivers, + @JsonProperty("queuedNewDrivers") int queuedNewDrivers, + @JsonProperty("runningNewDrivers") int runningNewDrivers, + @JsonProperty("completedNewDrivers") int completedNewDrivers, + + @JsonProperty("totalSplits") int totalSplits, + @JsonProperty("queuedSplits") int queuedSplits, + @JsonProperty("runningSplits") int runningSplits, + @JsonProperty("completedSplits") int completedSplits, + 
@JsonProperty("cumulativeUserMemory") double cumulativeUserMemory, @JsonProperty("cumulativeTotalMemory") double cumulativeTotalMemory, @JsonProperty("userMemoryReservation") DataSize userMemoryReservation, @@ -386,6 +431,16 @@ public QueryStats( blockedDrivers, completedDrivers, + totalNewDrivers, + queuedNewDrivers, + runningNewDrivers, + completedNewDrivers, + + totalSplits, + queuedSplits, + runningSplits, + completedSplits, + cumulativeUserMemory, cumulativeTotalMemory, userMemoryReservation, @@ -452,6 +507,16 @@ public static QueryStats create( int blockedDrivers = 0; int completedDrivers = 0; + int totalNewDrivers = 0; + int queuedNewDrivers = 0; + int runningNewDrivers = 0; + int completedNewDrivers = 0; + + int totalSplits = 0; + int queuedSplits = 0; + int runningSplits = 0; + int completedSplits = 0; + double cumulativeUserMemory = 0; double cumulativeTotalMemory = 0; long userMemoryReservation = 0; @@ -501,6 +566,16 @@ public static QueryStats create( blockedDrivers += stageExecutionStats.getBlockedDrivers(); completedDrivers += stageExecutionStats.getCompletedDrivers(); + totalNewDrivers += stageExecutionStats.getTotalNewDrivers(); + queuedNewDrivers += stageExecutionStats.getQueuedNewDrivers(); + runningNewDrivers += stageExecutionStats.getRunningNewDrivers(); + completedNewDrivers += stageExecutionStats.getCompletedNewDrivers(); + + totalSplits += stageExecutionStats.getTotalSplits(); + queuedSplits += stageExecutionStats.getQueuedSplits(); + runningSplits += stageExecutionStats.getRunningSplits(); + completedSplits += stageExecutionStats.getCompletedSplits(); + cumulativeUserMemory += stageExecutionStats.getCumulativeUserMemory(); cumulativeTotalMemory += stageExecutionStats.getCumulativeTotalMemory(); userMemoryReservation += stageExecutionStats.getUserMemoryReservationInBytes(); @@ -521,7 +596,7 @@ public static QueryStats create( for (OperatorStats operatorStats : stageExecutionStats.getOperatorSummaries()) { // NOTE: we need to literally check 
each operator type to tell if the source is from table input or shuffled input. A stage can have input from both types of source. String operatorType = operatorStats.getOperatorType(); - if (operatorType.equals(ExchangeOperator.class.getSimpleName()) || operatorType.equals(MergeOperator.class.getSimpleName())) { + if (operatorType.equals(ExchangeOperator.class.getSimpleName()) || operatorType.equals(MergeOperator.class.getSimpleName()) || operatorType.equals("PrestoSparkRemoteSourceOperator") || operatorType.equals("ShuffleRead")) { shuffledPositions += operatorStats.getRawInputPositions(); shuffledDataSize += operatorStats.getRawInputDataSizeInBytes(); } @@ -600,6 +675,16 @@ else if (operatorType.equals(TableScanOperator.class.getSimpleName()) || operato blockedDrivers, completedDrivers, + totalNewDrivers, + queuedNewDrivers, + runningNewDrivers, + completedNewDrivers, + + totalSplits, + queuedSplits, + runningSplits, + completedSplits, + cumulativeUserMemory, cumulativeTotalMemory, succinctBytes(userMemoryReservation), @@ -681,6 +766,14 @@ public static QueryStats immediateFailureQueryStats() 0, 0, 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, succinctBytes(0), succinctBytes(0), succinctBytes(0), @@ -878,6 +971,54 @@ public int getCompletedDrivers() return completedDrivers; } + @JsonProperty + public int getTotalNewDrivers() + { + return totalNewDrivers; + } + + @JsonProperty + public int getQueuedNewDrivers() + { + return queuedNewDrivers; + } + + @JsonProperty + public int getRunningNewDrivers() + { + return runningNewDrivers; + } + + @JsonProperty + public int getCompletedNewDrivers() + { + return completedNewDrivers; + } + + @JsonProperty + public int getTotalSplits() + { + return totalSplits; + } + + @JsonProperty + public int getQueuedSplits() + { + return queuedSplits; + } + + @JsonProperty + public int getRunningSplits() + { + return runningSplits; + } + + @JsonProperty + public int getCompletedSplits() + { + return completedSplits; + } + @JsonProperty 
public double getCumulativeUserMemory() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryTracker.java b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryTracker.java index d6290b664cf2a..142c71485c78b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/QueryTracker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/QueryTracker.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.execution.QueryTracker.TrackedQuery; import com.facebook.presto.resourcemanager.ClusterQueryTrackerService; @@ -22,10 +23,8 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroupQueryLimits; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.Collection; import java.util.NoSuchElementException; @@ -43,6 +42,7 @@ import static com.facebook.presto.SystemSessionProperties.getQueryClientTimeout; import static com.facebook.presto.SystemSessionProperties.getQueryMaxExecutionTime; +import static com.facebook.presto.SystemSessionProperties.getQueryMaxQueuedTime; import static com.facebook.presto.SystemSessionProperties.getQueryMaxRunTime; import static com.facebook.presto.execution.QueryLimit.Source.QUERY; import static com.facebook.presto.execution.QueryLimit.Source.RESOURCE_GROUP; @@ -211,15 +211,17 @@ public long getQueriesKilledDueToTooManyTask() } /** - * Enforce query max runtime/execution time limits + * Enforce query max runtime/queued/execution time limits */ - private void enforceTimeLimits() + @VisibleForTesting + void 
enforceTimeLimits() { for (T query : queries.values()) { if (query.isDone()) { continue; } Duration queryMaxRunTime = getQueryMaxRunTime(query.getSession()); + Duration queryMaxQueuedTime = getQueryMaxQueuedTime(query.getSession()); QueryLimit queryMaxExecutionTime = getMinimum( createDurationLimit(getQueryMaxExecutionTime(query.getSession()), QUERY), query.getResourceGroupQueryLimits() @@ -227,6 +229,10 @@ private void enforceTimeLimits() .map(rgLimit -> createDurationLimit(rgLimit, RESOURCE_GROUP)).orElse(null)); long executionStartTime = query.getExecutionStartTimeInMillis(); long createTimeInMillis = query.getCreateTimeInMillis(); + long queuedTimeInMillis = query.getQueuedTime().toMillis(); + if (queuedTimeInMillis > queryMaxQueuedTime.toMillis()) { + query.fail(new PrestoException(EXCEEDED_TIME_LIMIT, "Query exceeded maximum queued time limit of " + queryMaxQueuedTime)); + } if (executionStartTime > 0 && (executionStartTime + queryMaxExecutionTime.getLimit().toMillis()) < currentTimeMillis()) { query.fail( new PrestoException(EXCEEDED_TIME_LIMIT, @@ -399,12 +405,21 @@ public interface TrackedQuery long getCreateTimeInMillis(); + Duration getQueuedTime(); + long getExecutionStartTimeInMillis(); long getLastHeartbeatInMillis(); long getEndTimeInMillis(); + default long getDurationUntilExpirationInMillis() + { + Duration queryClientTimeout = getQueryClientTimeout(getSession()); + long expireTime = getLastHeartbeatInMillis() + queryClientTimeout.toMillis(); + return Math.max(0, expireTime - currentTimeMillis()); + } + Optional getResourceGroupQueryLimits(); void fail(Throwable cause); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/RenameColumnTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/RenameColumnTask.java index 8c63ab92f96cf..6488deb47cf89 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/RenameColumnTask.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/RenameColumnTask.java @@ -23,6 +23,7 @@ import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.RenameColumn; import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; @@ -50,7 +51,7 @@ public String getName() @Override public ListenableFuture execute(RenameColumn statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTable()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTable(), metadata); Optional tableHandleOptional = metadata.getMetadataResolver(session).getTableHandle(tableName); if (!tableHandleOptional.isPresent()) { if (!statement.isTableExists()) { @@ -69,8 +70,10 @@ public ListenableFuture execute(RenameColumn statement, TransactionManager tr TableHandle tableHandle = tableHandleOptional.get(); - String source = statement.getSource().getValueLowerCase(); - String target = statement.getTarget().getValueLowerCase(); + Identifier sourceName = statement.getSource(); + String source = metadata.normalizeIdentifier(session, tableName.getCatalogName(), sourceName.getValue()); + Identifier targetName = statement.getTarget(); + String target = metadata.normalizeIdentifier(session, tableName.getCatalogName(), targetName.getValue()); accessControl.checkCanRenameColumn(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/RenameSchemaTask.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/RenameSchemaTask.java index 3ec6ee641e694..c105a452bf1ea 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/RenameSchemaTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/RenameSchemaTask.java @@ -45,7 +45,7 @@ public String getName() @Override public ListenableFuture execute(RenameSchema statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - CatalogSchemaName source = createCatalogSchemaName(session, statement, Optional.of(statement.getSource())); + CatalogSchemaName source = createCatalogSchemaName(session, statement, Optional.of(statement.getSource()), metadata); CatalogSchemaName target = new CatalogSchemaName(source.getCatalogName(), statement.getTarget().getValue()); MetadataResolver metadataResolver = metadata.getMetadataResolver(session); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/RenameTableTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/RenameTableTask.java index 1fa57877f0cc8..b05f7950063b2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/RenameTableTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/RenameTableTask.java @@ -49,7 +49,7 @@ public String getName() @Override public ListenableFuture execute(RenameTable statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getSource()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getSource(), metadata); Optional tableHandle = metadata.getMetadataResolver(session).getTableHandle(tableName); if 
(!tableHandle.isPresent()) { if (!statement.isExists()) { @@ -66,7 +66,7 @@ public ListenableFuture execute(RenameTable statement, TransactionManager tra return immediateFuture(null); } - QualifiedObjectName target = createQualifiedObjectName(session, statement, statement.getTarget()); + QualifiedObjectName target = createQualifiedObjectName(session, statement, statement.getTarget(), metadata); getConnectorIdOrThrow(session, metadata, target.getCatalogName(), statement, targetTableCatalogError); if (metadata.getMetadataResolver(session).getTableHandle(target).isPresent()) { throw new SemanticException(TABLE_ALREADY_EXISTS, statement, "Target table '%s' already exists", target); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/RenameViewTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/RenameViewTask.java index b105f9a936ced..3408adc7b8ca4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/RenameViewTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/RenameViewTask.java @@ -46,7 +46,7 @@ public String getName() public ListenableFuture execute(RenameView statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName viewName = createQualifiedObjectName(session, statement, statement.getSource()); + QualifiedObjectName viewName = createQualifiedObjectName(session, statement, statement.getSource(), metadata); Optional view = metadata.getMetadataResolver(session).getView(viewName); if (!view.isPresent()) { @@ -56,7 +56,7 @@ public ListenableFuture execute(RenameView statement, TransactionManager tran return immediateFuture(null); } - QualifiedObjectName target = createQualifiedObjectName(session, statement, statement.getTarget()); + QualifiedObjectName target = createQualifiedObjectName(session, statement, statement.getTarget(), 
metadata); if (!metadata.getCatalogHandle(session, target.getCatalogName()).isPresent()) { throw new SemanticException(MISSING_CATALOG, statement, "Target catalog '%s' does not exist", target.getCatalogName()); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/RevokeTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/RevokeTask.java index 601f7cd2a0afe..651b781235708 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/RevokeTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/RevokeTask.java @@ -50,7 +50,7 @@ public String getName() @Override public ListenableFuture execute(Revoke statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName(), metadata); Optional tableHandle = metadata.getMetadataResolver(session).getTableHandle(tableName); if (!tableHandle.isPresent()) { throw new SemanticException(MISSING_TABLE, statement, "Table '%s' does not exist", tableName); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SafeEventLoopGroup.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SafeEventLoopGroup.java index 8032af249ef4a..bca2888c6b395 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SafeEventLoopGroup.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SafeEventLoopGroup.java @@ -26,7 +26,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.ThreadFactory; import java.util.function.Consumer; -import java.util.function.Supplier; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -75,7 
+74,7 @@ protected void run() runTask(task); } catch (Throwable t) { - log.error("Error running task on event loop", t); + log.error(t, "Error executing task on event loop"); } updateLastExecutionTime(); } @@ -86,24 +85,19 @@ protected void run() public void execute(Runnable task, Consumer failureHandler, SchedulerStatsTracker statsTracker, String methodSignature) { requireNonNull(task, "task is null"); - - long initialGCTime = getTotalGCTime(); - long start = THREAD_MX_BEAN.getCurrentThreadCpuTime(); - this.execute(() -> { + long start = THREAD_MX_BEAN.getCurrentThreadCpuTime(); try { task.run(); } catch (Throwable t) { - log.error("Error executing task on event loop", t); + log.error(t, "Error executing method %s on event loop.", methodSignature); if (failureHandler != null) { failureHandler.accept(t); } } finally { - long currentGCTime = getTotalGCTime(); - long cpuTimeInNanos = THREAD_MX_BEAN.getCurrentThreadCpuTime() - start - (currentGCTime - initialGCTime); - + long cpuTimeInNanos = THREAD_MX_BEAN.getCurrentThreadCpuTime() - start; statsTracker.recordEventLoopMethodExecutionCpuTime(cpuTimeInNanos); if (slowMethodThresholdOnEventLoopInNanos > 0 && cpuTimeInNanos > slowMethodThresholdOnEventLoopInNanos) { log.warn("Slow method execution on event loop: %s took %s milliseconds", methodSignature, NANOSECONDS.toMillis(cpuTimeInNanos)); @@ -111,29 +105,5 @@ public void execute(Runnable task, Consumer failureHandler, Scheduler } }); } - - public void execute(Supplier task, Consumer successHandler, Consumer failureHandler) - { - requireNonNull(task, "task is null"); - this.execute(() -> { - try { - T result = task.get(); - if (successHandler != null) { - successHandler.accept(result); - } - } - catch (Throwable t) { - log.error("Error executing task on event loop", t); - if (failureHandler != null) { - failureHandler.accept(t); - } - } - }); - } - } - - private long getTotalGCTime() - { - return 
gcBeans.stream().mapToLong(GarbageCollectorMXBean::getCollectionTime).sum(); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SchedulerStatsTracker.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SchedulerStatsTracker.java index 6a90b597fe44e..5abb0ef8aa36f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SchedulerStatsTracker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SchedulerStatsTracker.java @@ -28,10 +28,25 @@ public void recordTaskPlanSerializedCpuTime(long nanos) {} @Override public void recordEventLoopMethodExecutionCpuTime(long nanos) {} + + @Override + public void recordDeliveredUpdates(int updates) {} + + @Override + public void recordRoundTripTime(long nanos) {} + + @Override + public void recordStartWaitForEventLoop(long nanos) {} }; void recordTaskUpdateDeliveredTime(long nanos); + void recordDeliveredUpdates(int updates); + + void recordRoundTripTime(long nanos); + + void recordStartWaitForEventLoop(long nanos); + void recordTaskUpdateSerializedCpuTime(long nanos); void recordTaskPlanSerializedCpuTime(long nanos); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SessionDefinitionExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SessionDefinitionExecution.java index 09a3bc4369fc6..f0c88ab90a6eb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SessionDefinitionExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SessionDefinitionExecution.java @@ -19,14 +19,15 @@ import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AnalyzerProvider; +import com.facebook.presto.spi.analyzer.UpdateInfo; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.analyzer.BuiltInQueryPreparer; import com.facebook.presto.sql.tree.Expression; import 
com.facebook.presto.sql.tree.Statement; import com.facebook.presto.transaction.TransactionManager; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -62,6 +63,7 @@ private SessionDefinitionExecution( @Override protected ListenableFuture executeTask() { + task.queryPermissionCheck(accessControl, stateMachine.getSession().getIdentity(), stateMachine.getSession().getAccessControlContext(), query, stateMachine.getSession().getPreparedStatements(), ImmutableMap.of(), ImmutableMap.of()); return task.execute(statement, transactionManager, metadata, accessControl, stateMachine, parameters, query); } @@ -116,7 +118,7 @@ private SessionDefinitionExecution createSessionDefinit SessionTransactionControlTask task = (SessionTransactionControlTask) tasks.get(statement.getClass()); checkArgument(task != null, "no task for statement: %s", statement.getClass().getSimpleName()); - stateMachine.setUpdateType(task.getName()); + stateMachine.setUpdateInfo(new UpdateInfo(task.getName(), "")); return new SessionDefinitionExecution<>(task, statement, slug, retryCount, transactionManager, metadata, accessControl, stateMachine, parameters, query); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SetPropertiesTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SetPropertiesTask.java index 4793f957f964d..3c5e47b887b04 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SetPropertiesTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SetPropertiesTask.java @@ -51,7 +51,7 @@ public String getName() @Override public ListenableFuture execute(SetProperties statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session session, List parameters, WarningCollector warningCollector, String query) { - 
QualifiedObjectName tableName = MetadataUtil.createQualifiedObjectName(session, statement, statement.getTableName()); + QualifiedObjectName tableName = MetadataUtil.createQualifiedObjectName(session, statement, statement.getTableName(), metadata); Map sqlProperties = mapFromProperties(statement.getProperties()); if (statement.getType() == TABLE) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SplitConcurrencyController.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SplitConcurrencyController.java index 3c069390649e4..d7b0a56a4f822 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SplitConcurrencyController.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SplitConcurrencyController.java @@ -13,9 +13,8 @@ */ package com.facebook.presto.execution; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.NotThreadSafe; +import com.facebook.airlift.concurrent.NotThreadSafe; +import com.facebook.airlift.units.Duration; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SplitRunner.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SplitRunner.java index ded9bd28c1d49..590b2a662c187 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SplitRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SplitRunner.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import java.io.Closeable; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java index 7d03909dd2ca1..dfbfc86ae0f54 100644 
--- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlQueryExecution.java @@ -14,7 +14,9 @@ package com.facebook.presto.execution; import com.facebook.airlift.concurrent.SetThreadName; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; +import com.facebook.presto.common.InvalidFunctionArgumentException; import com.facebook.presto.common.analyzer.PreparedQuery; import com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.cost.CostCalculator; @@ -65,10 +67,8 @@ import com.facebook.presto.sql.planner.sanity.PlanChecker; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -97,6 +97,7 @@ import static com.facebook.presto.execution.buffer.OutputBuffers.BROADCAST_PARTITION_ID; import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static com.facebook.presto.execution.buffer.OutputBuffers.createSpoolingOutputBuffers; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED; import static com.facebook.presto.sql.planner.PlanNodeCanonicalInfo.getCanonicalInfo; @@ -218,11 +219,11 @@ private SqlQueryExecution( .recordWallAndCpuTime(ANALYZE_TIME_NANOS, () -> queryAnalyzer.analyze(analyzerContext, preparedQuery)); } - stateMachine.setUpdateType(queryAnalysis.getUpdateType()); + stateMachine.setUpdateInfo(queryAnalysis.getUpdateInfo()); stateMachine.setExpandedQuery(queryAnalysis.getExpandedQuery()); 
stateMachine.beginColumnAccessPermissionChecking(); - checkAccessPermissions(queryAnalysis.getAccessControlReferences(), query); + checkAccessPermissions(queryAnalysis.getAccessControlReferences(), queryAnalysis.getViewDefinitionReferences(), query, getSession().getPreparedStatements(), getSession().getIdentity(), accessControl, getSession().getAccessControlContext()); stateMachine.endColumnAccessPermissionChecking(); // when the query finishes cache the final query info, and clear the reference to the output stage @@ -339,6 +340,12 @@ public long getCreateTimeInMillis() return stateMachine.getCreateTimeInMillis(); } + @Override + public Duration getQueuedTime() + { + return stateMachine.getQueuedTime(); + } + /** * For a query that has started executing, returns the timestamp when this query started executing * Otherwise returns a {@link Optional#empty()} @@ -620,6 +627,9 @@ private PlanRoot doCreateLogicalPlanAndOptimize() catch (StackOverflowError e) { throw new PrestoException(NOT_SUPPORTED, "statement is too large (stack overflow during analysis)", e); } + catch (InvalidFunctionArgumentException e) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e.getMessage(), e); + } } private PlanRoot runCreateLogicalPlanAsync() @@ -639,7 +649,7 @@ private PlanRoot runCreateLogicalPlanAsync() private void createQueryScheduler(PlanRoot plan) { - CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager::getSplits); + CloseableSplitSourceProvider splitSourceProvider = new CloseableSplitSourceProvider(splitManager); // ensure split sources are closed stateMachine.addStateChangeListener(state -> { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlStageExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlStageExecution.java index 173dbbb32e919..038dc54fcc59d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlStageExecution.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlStageExecution.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.ErrorCode; import com.facebook.presto.execution.StateMachine.StateChangeListener; @@ -41,10 +42,8 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.net.URI; import java.util.ArrayList; @@ -638,7 +637,13 @@ public List getRequiredCTEList() { // Collect all CTE IDs referenced by TableScanNodes with TemporaryTableInfo return PlanNodeSearcher.searchFrom(planFragment.getRoot()) - .where(planNode -> planNode instanceof TableScanNode) + .where(planNode -> { + if (planNode instanceof TableScanNode) { + TableScanNode tableScanNode = (TableScanNode) planNode; + return tableScanNode.getCteMaterializationInfo().isPresent(); + } + return false; + }) .findAll().stream() .map(planNode -> ((TableScanNode) planNode).getCteMaterializationInfo() .orElseThrow(() -> new IllegalStateException("TableScanNode has no TemporaryTableInfo"))) diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTask.java index 08036cb7664d9..b03c9ef04e3ef 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTask.java @@ -27,16 +27,12 @@ import com.facebook.presto.execution.buffer.SpoolingOutputBufferFactory; import com.facebook.presto.execution.scheduler.TableWriteInfo; import 
com.facebook.presto.memory.QueryContext; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.operator.ExchangeClientSupplier; import com.facebook.presto.operator.PipelineContext; import com.facebook.presto.operator.PipelineStatus; import com.facebook.presto.operator.TaskContext; import com.facebook.presto.operator.TaskExchangeClientManager; import com.facebook.presto.operator.TaskStats; -import com.facebook.presto.spi.ConnectorId; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdater; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.sql.planner.PlanFragment; import com.google.common.base.Function; @@ -44,10 +40,9 @@ import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import jakarta.annotation.Nullable; import org.joda.time.DateTime; -import javax.annotation.Nullable; - import java.net.URI; import java.util.List; import java.util.Optional; @@ -349,24 +344,6 @@ private TaskStats getTaskStats(TaskHolder taskHolder) return new TaskStats(new DateTime(taskStateMachine.getCreatedTimeInMillis()), new DateTime(endTimeInMillis)); } - private MetadataUpdates getMetadataUpdateRequests(TaskHolder taskHolder) - { - ConnectorId connectorId = null; - ImmutableList.Builder connectorMetadataUpdatesBuilder = ImmutableList.builder(); - - if (taskHolder.getTaskExecution() != null) { - TaskMetadataContext taskMetadataContext = taskHolder.getTaskExecution().getTaskContext().getTaskMetadataContext(); - if (!taskMetadataContext.getMetadataUpdaters().isEmpty()) { - connectorId = taskMetadataContext.getConnectorId(); - for (ConnectorMetadataUpdater metadataUpdater : taskMetadataContext.getMetadataUpdaters()) { - connectorMetadataUpdatesBuilder.addAll(metadataUpdater.getPendingMetadataUpdateRequests()); - } - } - } - - return new MetadataUpdates(connectorId, 
connectorMetadataUpdatesBuilder.build()); - } - private static Set getNoMoreSplits(TaskHolder taskHolder) { TaskInfo finalTaskInfo = taskHolder.getFinalTaskInfo(); @@ -384,7 +361,6 @@ private TaskInfo createTaskInfo(TaskHolder taskHolder) { TaskStats taskStats = getTaskStats(taskHolder); Set noMoreSplits = getNoMoreSplits(taskHolder); - MetadataUpdates metadataRequests = getMetadataUpdateRequests(taskHolder); TaskStatus taskStatus = createTaskStatus(taskHolder); return new TaskInfo( @@ -395,7 +371,6 @@ private TaskInfo createTaskInfo(TaskHolder taskHolder) noMoreSplits, taskStats, needsPlan.get(), - metadataRequests, nodeId); } @@ -481,11 +456,6 @@ public TaskInfo updateTask( return getTaskInfo(); } - public TaskMetadataContext getTaskMetadataContext() - { - return taskHolderReference.get().taskExecution.getTaskContext().getTaskMetadataContext(); - } - public ListenableFuture getTaskResults(OutputBufferId bufferId, long startingSequenceId, long maxSizeInBytes) { requireNonNull(bufferId, "bufferId is null"); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java index 43885270d21a6..0268c52825ce1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskExecution.java @@ -13,7 +13,9 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.airlift.concurrent.SetThreadName; +import com.facebook.airlift.units.Duration; import com.facebook.presto.event.SplitMonitor; import com.facebook.presto.execution.StateMachine.StateChangeListener; import com.facebook.presto.execution.buffer.BufferState; @@ -39,12 +41,9 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; 
-import io.airlift.units.Duration; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.NotThreadSafe; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.lang.ref.WeakReference; import java.util.ArrayList; @@ -1082,7 +1081,10 @@ public ListenableFuture processFor(Duration duration) @Override public String getInfo() { - return (partitionedSplit == null) ? "" : partitionedSplit.getSplit().getInfo().toString(); + if (partitionedSplit != null && partitionedSplit.getSplit() != null && partitionedSplit.getSplit().getInfo() != null) { + return partitionedSplit.getSplit().getInfo().toString(); + } + return ""; } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskManager.java b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskManager.java index 443e96adc789e..6f5d5fd2c94a7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/SqlTaskManager.java @@ -19,6 +19,8 @@ import com.facebook.airlift.node.NodeInfo; import com.facebook.airlift.stats.CounterStat; import com.facebook.airlift.stats.GcMonitor; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.block.BlockEncodingSerde; import com.facebook.presto.event.SplitMonitor; @@ -36,13 +38,11 @@ import com.facebook.presto.memory.MemoryPoolAssignmentsRequest; import com.facebook.presto.memory.NodeMemoryConfig; import com.facebook.presto.memory.QueryContext; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.operator.ExchangeClientSupplier; import com.facebook.presto.operator.FragmentResultCacheManager; 
import com.facebook.presto.operator.TaskMemoryReservationSummary; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.QueryId; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdater; import com.facebook.presto.spiller.LocalSpillManager; import com.facebook.presto.spiller.NodeSpillConfig; import com.facebook.presto.sql.gen.OrderingCompiler; @@ -54,18 +54,15 @@ import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.concurrent.GuardedBy; import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; - import java.io.Closeable; import java.nio.file.Paths; import java.util.List; @@ -434,15 +431,6 @@ public TaskInfo updateTask( return sqlTask.updateTask(session, fragment, sources, outputBuffers, tableWriteInfo); } - @Override - public void updateMetadataResults(TaskId taskId, MetadataUpdates metadataUpdates) - { - TaskMetadataContext metadataContext = tasks.getUnchecked(taskId).getTaskMetadataContext(); - for (ConnectorMetadataUpdater metadataUpdater : metadataContext.getMetadataUpdaters()) { - metadataUpdater.setMetadataUpdateResults(metadataUpdates.getMetadataUpdates()); - } - } - @Override public ListenableFuture getTaskResults(TaskId taskId, OutputBufferId bufferId, long startingSequenceId, long maxSizeInBytes) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionInfo.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionInfo.java index f03c5e364bc86..cea22884f15d8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionInfo.java @@ -33,6 +33,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.airlift.units.Duration.succinctDuration; import static com.facebook.presto.common.RuntimeMetricName.DRIVER_COUNT_PER_TASK; import static com.facebook.presto.common.RuntimeMetricName.TASK_BLOCKED_TIME_NANOS; import static com.facebook.presto.common.RuntimeMetricName.TASK_ELAPSED_TIME_NANOS; @@ -42,7 +43,6 @@ import static com.facebook.presto.common.RuntimeUnit.NONE; import static com.facebook.presto.execution.StageExecutionState.FINISHED; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.Duration.succinctDuration; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.Math.toIntExact; @@ -113,6 +113,16 @@ public static StageExecutionInfo create( taskStatsAggregator.blockedDrivers, taskStatsAggregator.completedDrivers, + taskStatsAggregator.totalNewDrivers, + taskStatsAggregator.queuedNewDrivers, + taskStatsAggregator.runningNewDrivers, + taskStatsAggregator.completedNewDrivers, + + taskStatsAggregator.totalSplits, + taskStatsAggregator.queuedSplits, + taskStatsAggregator.runningSplits, + taskStatsAggregator.completedSplits, + taskStatsAggregator.cumulativeUserMemory, taskStatsAggregator.cumulativeTotalMemory, taskStatsAggregator.userMemoryReservation, @@ -246,6 +256,16 @@ private static class TaskStatsAggregator private int blockedDrivers; private int completedDrivers; + private int totalNewDrivers; + private int queuedNewDrivers; + private int runningNewDrivers; + private int completedNewDrivers; + + private int totalSplits; + private int queuedSplits; + private int 
runningSplits; + private int completedSplits; + private double cumulativeUserMemory; private double cumulativeTotalMemory; private long userMemoryReservation; @@ -291,6 +311,16 @@ public void processTaskStats(TaskStats taskStats) blockedDrivers += taskStats.getBlockedDrivers(); completedDrivers += taskStats.getCompletedDrivers(); + totalNewDrivers += taskStats.getTotalNewDrivers(); + queuedNewDrivers += taskStats.getQueuedNewDrivers(); + runningNewDrivers += taskStats.getRunningNewDrivers(); + completedNewDrivers += taskStats.getCompletedNewDrivers(); + + totalSplits += taskStats.getTotalSplits(); + queuedSplits += taskStats.getQueuedSplits(); + runningSplits += taskStats.getRunningSplits(); + completedSplits += taskStats.getCompletedSplits(); + cumulativeUserMemory += taskStats.getCumulativeUserMemory(); cumulativeTotalMemory += taskStats.getCumulativeTotalMemory(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionStateMachine.java b/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionStateMachine.java index 0973a369c77ec..ff6497f1eac57 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionStateMachine.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionStateMachine.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.Distribution; +import com.facebook.presto.common.RuntimeMetricName; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.execution.StateMachine.StateChangeListener; import com.facebook.presto.execution.scheduler.ScheduleResult; @@ -23,8 +24,7 @@ import com.facebook.presto.operator.TaskStats; import com.facebook.presto.util.Failures; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.HashSet; import java.util.List; @@ -37,6 +37,7 @@ import 
java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import static com.facebook.airlift.units.Duration.succinctNanos; import static com.facebook.presto.common.RuntimeMetricName.EVENT_LOOP_METHOD_EXECUTION_CPU_TIME_NANOS; import static com.facebook.presto.common.RuntimeMetricName.GET_SPLITS_TIME_NANOS; import static com.facebook.presto.common.RuntimeMetricName.SCAN_STAGE_SCHEDULER_BLOCKED_TIME_NANOS; @@ -46,9 +47,11 @@ import static com.facebook.presto.common.RuntimeMetricName.SCHEDULER_CPU_TIME_NANOS; import static com.facebook.presto.common.RuntimeMetricName.SCHEDULER_WALL_TIME_NANOS; import static com.facebook.presto.common.RuntimeMetricName.TASK_PLAN_SERIALIZED_CPU_TIME_NANOS; +import static com.facebook.presto.common.RuntimeMetricName.TASK_START_WAIT_FOR_EVENT_LOOP; import static com.facebook.presto.common.RuntimeMetricName.TASK_UPDATE_DELIVERED_WALL_TIME_NANOS; import static com.facebook.presto.common.RuntimeMetricName.TASK_UPDATE_SERIALIZED_CPU_TIME_NANOS; import static com.facebook.presto.common.RuntimeUnit.NANO; +import static com.facebook.presto.common.RuntimeUnit.NONE; import static com.facebook.presto.execution.StageExecutionState.ABORTED; import static com.facebook.presto.execution.StageExecutionState.CANCELED; import static com.facebook.presto.execution.StageExecutionState.FAILED; @@ -63,7 +66,6 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static io.airlift.units.Duration.succinctNanos; import static java.lang.Math.max; import static java.lang.Math.min; import static java.lang.System.currentTimeMillis; @@ -253,6 +255,16 @@ public BasicStageExecutionStats getBasicStageStats(Supplier> int runningDrivers = 0; int completedDrivers = 0; + int totalNewDrivers = 0; + int queuedNewDrivers = 0; + int runningNewDrivers = 0; + int completedNewDrivers = 0; + + int totalSplits = 
0; + int queuedSplits = 0; + int runningSplits = 0; + int completedSplits = 0; + double cumulativeUserMemory = 0; double cumulativeTotalMemory = 0; long userMemoryReservationInBytes = 0; @@ -278,6 +290,16 @@ public BasicStageExecutionStats getBasicStageStats(Supplier> runningDrivers += taskStats.getRunningDrivers(); completedDrivers += taskStats.getCompletedDrivers(); + totalNewDrivers += taskStats.getTotalNewDrivers(); + queuedNewDrivers += taskStats.getQueuedNewDrivers(); + runningNewDrivers += taskStats.getRunningNewDrivers(); + completedNewDrivers += taskStats.getCompletedNewDrivers(); + + totalSplits += taskStats.getTotalSplits(); + queuedSplits += taskStats.getQueuedSplits(); + runningSplits += taskStats.getRunningSplits(); + completedSplits += taskStats.getCompletedSplits(); + cumulativeUserMemory += taskStats.getCumulativeUserMemory(); long taskUserMemory = taskStats.getUserMemoryReservationInBytes(); @@ -313,6 +335,16 @@ public BasicStageExecutionStats getBasicStageStats(Supplier> runningDrivers, completedDrivers, + totalNewDrivers, + queuedNewDrivers, + runningNewDrivers, + completedNewDrivers, + + totalSplits, + queuedSplits, + runningSplits, + completedSplits, + rawInputDataSizeInBytes, rawInputPositions, @@ -402,6 +434,22 @@ public void recordTaskUpdateDeliveredTime(long nanos) runtimeStats.addMetricValue(TASK_UPDATE_DELIVERED_WALL_TIME_NANOS, NANO, max(nanos, 0)); } + @Override + public void recordStartWaitForEventLoop(long nanos) + { + runtimeStats.addMetricValue(TASK_START_WAIT_FOR_EVENT_LOOP, NANO, max(nanos, 0)); + } + + public void recordDeliveredUpdates(int updates) + { + runtimeStats.addMetricValue(RuntimeMetricName.TASK_UPDATE_DELIVERED_UPDATES, NONE, max(updates, 0)); + } + + public void recordRoundTripTime(long nanos) + { + runtimeStats.addMetricValue(RuntimeMetricName.TASK_UPDATE_ROUND_TRIP_TIME, NANO, max(nanos, 0)); + } + @Override public void recordTaskUpdateSerializedCpuTime(long nanos) { diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionStats.java b/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionStats.java index 236bcc1095997..b684fa12cb152 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/StageExecutionStats.java @@ -15,6 +15,7 @@ import com.facebook.airlift.stats.Distribution; import com.facebook.airlift.stats.Distribution.DistributionSnapshot; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.operator.BlockedReason; import com.facebook.presto.operator.OperatorStats; @@ -23,9 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.OptionalDouble; @@ -57,6 +56,16 @@ public class StageExecutionStats private final int blockedDrivers; private final int completedDrivers; + private final int totalNewDrivers; + private final int queuedNewDrivers; + private final int runningNewDrivers; + private final int completedNewDrivers; + + private final int totalSplits; + private final int queuedSplits; + private final int runningSplits; + private final int completedSplits; + private final double cumulativeUserMemory; private final double cumulativeTotalMemory; private final long userMemoryReservationInBytes; @@ -111,6 +120,16 @@ public StageExecutionStats( @JsonProperty("blockedDrivers") int blockedDrivers, @JsonProperty("completedDrivers") int completedDrivers, + @JsonProperty("totalNewDrivers") int totalNewDrivers, + @JsonProperty("queuedNewDrivers") int queuedNewDrivers, + @JsonProperty("runningNewDrivers") int runningNewDrivers, + 
@JsonProperty("completedNewDrivers") int completedNewDrivers, + + @JsonProperty("totalSplits") int totalSplits, + @JsonProperty("queuedSplits") int queuedSplits, + @JsonProperty("runningSplits") int runningSplits, + @JsonProperty("completedSplits") int completedSplits, + @JsonProperty("cumulativeUserMemory") double cumulativeUserMemory, @JsonProperty("cumulativeTotalMemory") double cumulativeTotalMemory, @JsonProperty("userMemoryReservationInBytes") long userMemoryReservationInBytes, @@ -169,18 +188,31 @@ public StageExecutionStats( this.blockedDrivers = blockedDrivers; checkArgument(completedDrivers >= 0, "completedDrivers is negative"); this.completedDrivers = completedDrivers; - checkArgument(cumulativeUserMemory >= 0, "cumulativeUserMemory is negative"); - this.cumulativeUserMemory = cumulativeUserMemory; - checkArgument(cumulativeTotalMemory >= 0, "cumulativeTotalMemory is negative"); - this.cumulativeTotalMemory = cumulativeTotalMemory; - checkArgument(userMemoryReservationInBytes >= 0, "userMemoryReservationInBytes is negative"); - this.userMemoryReservationInBytes = userMemoryReservationInBytes; - checkArgument(totalMemoryReservationInBytes >= 0, "totalMemoryReservationInBytes is negative"); - this.totalMemoryReservationInBytes = totalMemoryReservationInBytes; - checkArgument(peakUserMemoryReservationInBytes >= 0, "peakUserMemoryReservationInBytes is negative"); - this.peakUserMemoryReservationInBytes = peakUserMemoryReservationInBytes; - checkArgument(peakNodeTotalMemoryReservationInBytes >= 0, "peakNodeTotalMemoryReservationInBytes is negative"); - this.peakNodeTotalMemoryReservationInBytes = peakNodeTotalMemoryReservationInBytes; + + checkArgument(totalNewDrivers >= 0, "totalNewDrivers is negative"); + this.totalNewDrivers = totalNewDrivers; + checkArgument(queuedNewDrivers >= 0, "queuedNewDrivers is negative"); + this.queuedNewDrivers = queuedNewDrivers; + checkArgument(runningNewDrivers >= 0, "runningNewDrivers is negative"); + this.runningNewDrivers = 
runningNewDrivers; + checkArgument(completedNewDrivers >= 0, "completedNewDrivers is negative"); + this.completedNewDrivers = completedNewDrivers; + + checkArgument(totalSplits >= 0, "totalSplits is negative"); + this.totalSplits = totalSplits; + checkArgument(queuedSplits >= 0, "queuedSplits is negative"); + this.queuedSplits = queuedSplits; + checkArgument(runningSplits >= 0, "runningSplits is negative"); + this.runningSplits = runningSplits; + checkArgument(completedSplits >= 0, "completedSplits is negative"); + this.completedSplits = completedSplits; + + this.cumulativeUserMemory = (cumulativeUserMemory >= 0) ? cumulativeUserMemory : Long.MAX_VALUE; + this.cumulativeTotalMemory = (cumulativeTotalMemory >= 0) ? cumulativeTotalMemory : Long.MAX_VALUE; + this.userMemoryReservationInBytes = (userMemoryReservationInBytes >= 0) ? userMemoryReservationInBytes : Long.MAX_VALUE; + this.totalMemoryReservationInBytes = (totalMemoryReservationInBytes >= 0) ? totalMemoryReservationInBytes : Long.MAX_VALUE; + this.peakUserMemoryReservationInBytes = (peakUserMemoryReservationInBytes >= 0) ? peakUserMemoryReservationInBytes : Long.MAX_VALUE; + this.peakNodeTotalMemoryReservationInBytes = (peakNodeTotalMemoryReservationInBytes >= 0) ? 
peakNodeTotalMemoryReservationInBytes : Long.MAX_VALUE; this.totalScheduledTime = requireNonNull(totalScheduledTime, "totalScheduledTime is null"); this.totalCpuTime = requireNonNull(totalCpuTime, "totalCpuTime is null"); @@ -189,29 +221,20 @@ public StageExecutionStats( this.fullyBlocked = fullyBlocked; this.blockedReasons = ImmutableSet.copyOf(requireNonNull(blockedReasons, "blockedReasons is null")); - checkArgument(totalAllocationInBytes >= 0, "totalAllocationInBytes is negative"); - this.totalAllocationInBytes = totalAllocationInBytes; - checkArgument(rawInputDataSizeInBytes >= 0, "rawInputDataSizeInBytes is negative"); - this.rawInputDataSizeInBytes = rawInputDataSizeInBytes; - checkArgument(rawInputPositions >= 0, "rawInputPositions is negative"); - this.rawInputPositions = rawInputPositions; + this.totalAllocationInBytes = (totalAllocationInBytes >= 0) ? totalAllocationInBytes : Long.MAX_VALUE; + this.rawInputDataSizeInBytes = (rawInputDataSizeInBytes >= 0) ? rawInputDataSizeInBytes : Long.MAX_VALUE; + this.rawInputPositions = (rawInputPositions >= 0) ? rawInputPositions : Long.MAX_VALUE; - checkArgument(processedInputDataSizeInBytes >= 0, "processedInputDataSizeInBytes is negative"); - this.processedInputDataSizeInBytes = processedInputDataSizeInBytes; - checkArgument(processedInputPositions >= 0, "processedInputPositions is negative"); - this.processedInputPositions = processedInputPositions; + this.processedInputDataSizeInBytes = (processedInputDataSizeInBytes >= 0) ? processedInputDataSizeInBytes : Long.MAX_VALUE; + this.processedInputPositions = (processedInputPositions >= 0) ? processedInputPositions : Long.MAX_VALUE; - checkArgument(bufferedDataSizeInBytes >= 0, "bufferedDataSizeInBytes is negative"); - this.bufferedDataSizeInBytes = bufferedDataSizeInBytes; + this.bufferedDataSizeInBytes = (bufferedDataSizeInBytes >= 0) ? bufferedDataSizeInBytes : Long.MAX_VALUE; - // An overflow could have occurred on this stat - handle this gracefully. 
this.outputDataSizeInBytes = (outputDataSizeInBytes >= 0) ? outputDataSizeInBytes : Long.MAX_VALUE; - checkArgument(outputPositions >= 0, "outputPositions is negative"); - this.outputPositions = outputPositions; + this.outputPositions = (outputPositions >= 0) ? outputPositions : Long.MAX_VALUE; - checkArgument(physicalWrittenDataSizeInBytes >= 0, "writtenDataSizeInBytes is negative"); - this.physicalWrittenDataSizeInBytes = physicalWrittenDataSizeInBytes; + this.physicalWrittenDataSizeInBytes = (physicalWrittenDataSizeInBytes >= 0) ? physicalWrittenDataSizeInBytes : Long.MAX_VALUE; this.gcInfo = requireNonNull(gcInfo, "gcInfo is null"); @@ -291,6 +314,54 @@ public int getCompletedDrivers() return completedDrivers; } + @JsonProperty + public int getTotalNewDrivers() + { + return totalNewDrivers; + } + + @JsonProperty + public int getQueuedNewDrivers() + { + return queuedNewDrivers; + } + + @JsonProperty + public int getRunningNewDrivers() + { + return runningNewDrivers; + } + + @JsonProperty + public int getCompletedNewDrivers() + { + return completedNewDrivers; + } + + @JsonProperty + public int getTotalSplits() + { + return totalSplits; + } + + @JsonProperty + public int getQueuedSplits() + { + return queuedSplits; + } + + @JsonProperty + public int getRunningSplits() + { + return runningSplits; + } + + @JsonProperty + public int getCompletedSplits() + { + return completedSplits; + } + @JsonProperty public double getCumulativeUserMemory() { @@ -450,6 +521,14 @@ public BasicStageExecutionStats toBasicStageStats(StageExecutionState stageExecu queuedDrivers, runningDrivers, completedDrivers, + totalNewDrivers, + queuedNewDrivers, + runningNewDrivers, + completedNewDrivers, + totalSplits, + queuedSplits, + runningSplits, + completedSplits, rawInputDataSizeInBytes, rawInputPositions, cumulativeUserMemory, @@ -470,6 +549,15 @@ public static StageExecutionStats zero(int stageId) 0L, new Distribution().snapshot(), 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, 0, 0, diff 
--git a/presto-main-base/src/main/java/com/facebook/presto/execution/StageInfo.java b/presto-main-base/src/main/java/com/facebook/presto/execution/StageInfo.java index 2ac6e04f0de10..028ba2141aefb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/StageInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/StageInfo.java @@ -17,8 +17,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.net.URI; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/StateMachine.java b/presto-main-base/src/main/java/com/facebook/presto/execution/StateMachine.java index 559e245720291..49af8d9757002 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/StateMachine.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/StateMachine.java @@ -19,9 +19,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/TaskInfo.java b/presto-main-base/src/main/java/com/facebook/presto/execution/TaskInfo.java index 57ed6931318d8..f0e9c48121915 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/TaskInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/TaskInfo.java @@ -18,23 +18,20 @@ import com.facebook.drift.annotations.ThriftStruct; import 
com.facebook.presto.execution.buffer.BufferInfo; import com.facebook.presto.execution.buffer.OutputBufferInfo; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.operator.TaskStats; import com.facebook.presto.spi.plan.PlanNodeId; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.Immutable; import org.joda.time.DateTime; -import javax.annotation.concurrent.Immutable; - import java.net.URI; import java.util.List; import java.util.Set; import static com.facebook.presto.execution.TaskStatus.initialTaskStatus; import static com.facebook.presto.execution.buffer.BufferState.OPEN; -import static com.facebook.presto.metadata.MetadataUpdates.DEFAULT_METADATA_UPDATES; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static java.lang.System.currentTimeMillis; @@ -50,9 +47,7 @@ public class TaskInfo private final OutputBufferInfo outputBuffers; private final Set noMoreSplits; private final TaskStats stats; - private final boolean needsPlan; - private final MetadataUpdates metadataUpdates; private final String nodeId; @JsonCreator @@ -65,7 +60,6 @@ public TaskInfo( @JsonProperty("noMoreSplits") Set noMoreSplits, @JsonProperty("stats") TaskStats stats, @JsonProperty("needsPlan") boolean needsPlan, - @JsonProperty("metadataUpdates") MetadataUpdates metadataUpdates, @JsonProperty("nodeId") String nodeId) { this.taskId = requireNonNull(taskId, "taskId is null"); @@ -77,7 +71,6 @@ public TaskInfo( this.stats = requireNonNull(stats, "stats is null"); this.needsPlan = needsPlan; - this.metadataUpdates = metadataUpdates; this.nodeId = requireNonNull(nodeId, "nodeId is null"); } @@ -137,13 +130,6 @@ public boolean isNeedsPlan() @JsonProperty @ThriftField(8) - public MetadataUpdates getMetadataUpdates() - { - return 
metadataUpdates; - } - - @JsonProperty - @ThriftField(9) public String getNodeId() { return nodeId; @@ -160,7 +146,6 @@ public TaskInfo summarize() noMoreSplits, stats.summarizeFinal(), needsPlan, - metadataUpdates, nodeId); } return new TaskInfo( @@ -171,7 +156,6 @@ public TaskInfo summarize() noMoreSplits, stats.summarize(), needsPlan, - metadataUpdates, nodeId); } @@ -194,12 +178,11 @@ public static TaskInfo createInitialTask(TaskId taskId, URI location, List metadataUpdaters; - private ConnectorId connectorId; - - public TaskMetadataContext() - { - this.metadataUpdaters = new CopyOnWriteArrayList<>(); - } - - public void setConnectorId(ConnectorId connectorId) - { - this.connectorId = connectorId; - } - - public ConnectorId getConnectorId() - { - return connectorId; - } - - public void addMetadataUpdater(ConnectorMetadataUpdater metadataUpdater) - { - metadataUpdaters.add(metadataUpdater); - } - - public List getMetadataUpdaters() - { - return metadataUpdaters; - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/TaskStateMachine.java b/presto-main-base/src/main/java/com/facebook/presto/execution/TaskStateMachine.java index 65d11fcc6d22f..3b74d8c06baa6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/TaskStateMachine.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/TaskStateMachine.java @@ -16,8 +16,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.execution.StateMachine.StateChangeListener; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.concurrent.Executor; import java.util.concurrent.LinkedBlockingQueue; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/TaskThresholdMemoryRevokingScheduler.java b/presto-main-base/src/main/java/com/facebook/presto/execution/TaskThresholdMemoryRevokingScheduler.java index 
2dd21f43c936a..9a549af356a4b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/TaskThresholdMemoryRevokingScheduler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/TaskThresholdMemoryRevokingScheduler.java @@ -23,11 +23,10 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; +import jakarta.annotation.Nullable; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/TimeoutThread.java b/presto-main-base/src/main/java/com/facebook/presto/execution/TimeoutThread.java index 473aa1c82f6a2..e6b077999f1d8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/TimeoutThread.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/TimeoutThread.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.execution; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/TruncateTableTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/TruncateTableTask.java index 34df5768bebf1..9136084915565 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/TruncateTableTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/TruncateTableTask.java @@ -47,7 +47,7 @@ public String getName() public ListenableFuture execute(TruncateTable statement, TransactionManager transactionManager, Metadata metadata, AccessControl accessControl, Session 
session, List parameters, WarningCollector warningCollector, String query) { MetadataResolver metadataResolver = metadata.getMetadataResolver(session); - QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getTableName(), metadata); if (metadataResolver.isMaterializedView(tableName)) { throw new SemanticException(NOT_SUPPORTED, statement, "Cannot truncate a materialized view"); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/UseTask.java b/presto-main-base/src/main/java/com/facebook/presto/execution/UseTask.java index c1ba8bc5991e1..86fe1eebbfb3b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/UseTask.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/UseTask.java @@ -92,7 +92,9 @@ private void checkAndSetSchema(Use statement, Metadata metadata, QueryStateMachi String catalog = statement.getCatalog() .map(Identifier::getValueLowerCase) .orElseGet(() -> session.getCatalog().map(String::toLowerCase).get()); - String schema = statement.getSchema().getValueLowerCase(); + + Identifier schemaIdentifier = statement.getSchema(); + String schema = metadata.normalizeIdentifier(session, catalog, schemaIdentifier.getValue()); if (!metadata.getMetadataResolver(session).schemaExists(new CatalogSchemaName(catalog, schema))) { throw new SemanticException(MISSING_SCHEMA, format("Schema does not exist: %s.%s", catalog, schema)); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/ArbitraryOutputBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/ArbitraryOutputBuffer.java index 7a4ec56b5626c..05c4ea7651f9d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/ArbitraryOutputBuffer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/ArbitraryOutputBuffer.java 
@@ -26,9 +26,8 @@ import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.Collection; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/BroadcastOutputBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/BroadcastOutputBuffer.java index 13b34ab20d553..467d61493bf98 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/BroadcastOutputBuffer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/BroadcastOutputBuffer.java @@ -24,8 +24,7 @@ import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.Collection; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/ClientBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/ClientBuffer.java index 05804e4ff47c9..839caf2055c26 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/ClientBuffer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/ClientBuffer.java @@ -19,10 +19,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.Immutable; -import javax.annotation.concurrent.ThreadSafe; +import 
com.google.errorprone.annotations.Immutable; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.Collection; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/LazyOutputBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/LazyOutputBuffer.java index 8ca9c2021b24b..9292064a85d07 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/LazyOutputBuffer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/LazyOutputBuffer.java @@ -24,9 +24,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.util.ArrayList; import java.util.HashSet; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/LifespanSerializedPageTracker.java b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/LifespanSerializedPageTracker.java index 99fbc224b9acc..7940b84ecf442 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/LifespanSerializedPageTracker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/LifespanSerializedPageTracker.java @@ -15,8 +15,7 @@ import com.facebook.presto.execution.Lifespan; import com.facebook.presto.execution.buffer.SerializedPageReference.PagesReleasedListener; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Optional; import java.util.Set; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/OutputBufferMemoryManager.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/OutputBufferMemoryManager.java index 6028adbc55e80..4c5dab92763b7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/OutputBufferMemoryManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/OutputBufferMemoryManager.java @@ -18,10 +18,9 @@ import com.google.common.base.Suppliers; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SerializedPageReference.java b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SerializedPageReference.java index 2bbe3427ecef7..1c8137815e6ac 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SerializedPageReference.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SerializedPageReference.java @@ -15,8 +15,7 @@ import com.facebook.presto.execution.Lifespan; import com.facebook.presto.spi.page.SerializedPage; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.List; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SpoolingOutputBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SpoolingOutputBuffer.java index c8a2b7637c25d..5b7de4725e2e9 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SpoolingOutputBuffer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SpoolingOutputBuffer.java @@ -33,12 +33,11 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.Immutable; +import com.google.errorprone.annotations.concurrent.GuardedBy; import io.airlift.slice.InputStreamSliceInput; import io.airlift.slice.SliceInput; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.Immutable; - import java.io.IOException; import java.util.ArrayDeque; import java.util.Collection; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SpoolingOutputBufferFactory.java b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SpoolingOutputBufferFactory.java index 86c10ef0cf940..7ef22a83bb186 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SpoolingOutputBufferFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/buffer/SpoolingOutputBufferFactory.java @@ -24,9 +24,8 @@ import com.google.common.io.Closer; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.inject.Inject; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; import java.io.IOException; import java.util.concurrent.ExecutorService; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/MultilevelSplitQueue.java b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/MultilevelSplitQueue.java index c6e550a91fec8..028e10f07a6d9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/MultilevelSplitQueue.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/MultilevelSplitQueue.java @@ -17,13 +17,12 @@ import com.facebook.presto.execution.TaskManagerConfig; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.Collection; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/PrioritizedSplitRunner.java b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/PrioritizedSplitRunner.java index d4017069c1fb8..682ea8abef9f4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/PrioritizedSplitRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/PrioritizedSplitRunner.java @@ -16,11 +16,11 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.CounterStat; import com.facebook.airlift.stats.TimeStat; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.SplitRunner; import com.google.common.base.Ticker; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.Duration; import java.lang.management.ManagementFactory; import java.lang.management.ThreadMXBean; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/Priority.java b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/Priority.java index 6bb3609a492fd..969bbfb92578f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/Priority.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/Priority.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.execution.executor; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import static com.google.common.base.MoreObjects.toStringHelper; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java index a2de9b34d1b58..5e906b430cb3c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskExecutor.java @@ -19,6 +19,7 @@ import com.facebook.airlift.stats.CounterStat; import com.facebook.airlift.stats.TimeDistribution; import com.facebook.airlift.stats.TimeStat; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.SplitRunner; import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskManagerConfig; @@ -35,16 +36,14 @@ import com.google.common.cache.LoadingCache; import com.google.common.collect.ComparisonChain; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; @@ -68,7 +67,6 @@ import java.util.function.Predicate; import static 
com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; -import static com.facebook.airlift.concurrent.Threads.threadsNamed; import static com.facebook.presto.execution.executor.MultilevelSplitQueue.computeLevel; import static com.facebook.presto.util.MoreMath.min; import static com.google.common.base.MoreObjects.toStringHelper; @@ -267,7 +265,7 @@ public TaskExecutor( checkArgument(interruptSplitInterval.getValue(SECONDS) >= 1.0, "interruptSplitInterval must be at least 1 second"); // we manage thread pool size directly, so create an unlimited pool - this.executor = newCachedThreadPool(threadsNamed("task-processor-%s")); + this.executor = newCachedThreadPool(daemonThreadsNamed("task-processor-%s")); this.executorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) executor); this.runnerThreads = runnerThreads; this.embedVersion = requireNonNull(embedVersion, "embedVersion is null"); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskHandle.java b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskHandle.java index 38f0847da414e..4616aced43015 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskHandle.java @@ -13,13 +13,12 @@ */ package com.facebook.presto.execution.executor; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.SplitConcurrencyController; import com.facebook.presto.execution.TaskId; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayDeque; import java.util.ArrayList; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskPriorityTracker.java b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskPriorityTracker.java index d7ac789a179a4..c6a7c55f6bd9a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskPriorityTracker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/executor/TaskPriorityTracker.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.execution.executor; -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java b/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java index 571798230fbf6..9bd91021d34cc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroup.java @@ -14,8 +14,11 @@ package com.facebook.presto.execution.resourceGroups; import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.ManagedQueryExecution; import com.facebook.presto.execution.resourceGroups.WeightedFairQueue.Usage; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterResourceChecker; import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.server.QueryStateInfo; import com.facebook.presto.server.ResourceGroupInfo; @@ -26,14 +29,11 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroupState; import com.facebook.presto.spi.resourceGroups.SchedulingPolicy; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; -import 
io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; - import java.util.Collection; import java.util.HashMap; import java.util.HashSet; @@ -50,6 +50,7 @@ import java.util.function.Function; import java.util.function.Predicate; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.SystemSessionProperties.getQueryPriority; import static com.facebook.presto.common.ErrorType.USER_ERROR; import static com.facebook.presto.server.QueryStateInfo.createQueryStateInfo; @@ -69,7 +70,6 @@ import static com.google.common.math.LongMath.saturatedAdd; import static com.google.common.math.LongMath.saturatedMultiply; import static com.google.common.math.LongMath.saturatedSubtract; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.Math.min; import static java.lang.String.format; import static java.lang.System.currentTimeMillis; @@ -97,6 +97,8 @@ public class InternalResourceGroup private final Function> additionalRuntimeInfo; private final Predicate shouldWaitForResourceManagerUpdate; private final InternalNodeManager nodeManager; + private final ClusterResourceChecker clusterResourceChecker; + private final QueryPacingContext queryPacingContext; // Configuration // ============= @@ -167,12 +169,16 @@ protected InternalResourceGroup( boolean staticResourceGroup, Function> additionalRuntimeInfo, Predicate shouldWaitForResourceManagerUpdate, - InternalNodeManager nodeManager) + InternalNodeManager nodeManager, + ClusterResourceChecker clusterResourceChecker, + QueryPacingContext queryPacingContext) { this.parent = requireNonNull(parent, "parent is null"); this.jmxExportListener = requireNonNull(jmxExportListener, "jmxExportListener is null"); this.executor = 
requireNonNull(executor, "executor is null"); this.nodeManager = requireNonNull(nodeManager, "node manager is null"); + this.clusterResourceChecker = requireNonNull(clusterResourceChecker, "clusterResourceChecker is null"); + this.queryPacingContext = requireNonNull(queryPacingContext, "queryPacingContext is null"); requireNonNull(name, "name is null"); if (parent.isPresent()) { id = new ResourceGroupId(parent.get().id, name); @@ -672,7 +678,9 @@ public InternalResourceGroup getOrCreateSubGroup(String name, boolean staticSegm staticResourceGroup && staticSegment, additionalRuntimeInfo, shouldWaitForResourceManagerUpdate, - nodeManager); + nodeManager, + clusterResourceChecker, + queryPacingContext); // Sub group must use query priority to ensure ordering if (schedulingPolicy == QUERY_PRIORITY) { subGroup.setSchedulingPolicy(QUERY_PRIORITY); @@ -731,12 +739,14 @@ public void run(ManagedQueryExecution query) } else { query.setResourceGroupQueryLimits(perQueryLimits); - if (canRun && queuedQueries.isEmpty()) { + boolean immediateStartCandidate = canRun && queuedQueries.isEmpty(); + if (immediateStartCandidate && queryPacingContext.tryAcquireAdmissionSlot()) { startInBackground(query); } else { enqueueQuery(query); } + query.addStateChangeListener(state -> { if (state.isDone()) { queryFinished(query); @@ -771,7 +781,7 @@ private void enqueueQuery(ManagedQueryExecution query) } // This method must be called whenever the group's eligibility to run more queries may have changed. 
- private void updateEligibility() + protected void updateEligibility() { checkState(Thread.holdsLock(root), "Must hold lock to update eligibility"); synchronized (root) { @@ -803,6 +813,8 @@ private void startInBackground(ManagedQueryExecution query) group = group.parent.get(); } updateEligibility(); + // Increment global running query counter for pacing + queryPacingContext.onQueryStarted(); executor.execute(query::startWaitingForResources); group = this; long lastRunningQueryStartTimeMillis = currentTimeMillis(); @@ -836,6 +848,8 @@ private void queryFinished(ManagedQueryExecution query) group.parent.get().descendantRunningQueries--; group = group.parent.get(); } + // Decrement global running query counter for pacing + queryPacingContext.onQueryFinished(); } else { queuedQueries.remove(query); @@ -904,8 +918,13 @@ protected boolean internalStartNext() return false; } - ManagedQueryExecution query = queuedQueries.poll(); + ManagedQueryExecution query = queuedQueries.peek(); if (query != null) { + if (!queryPacingContext.tryAcquireAdmissionSlot()) { + return false; + } + + queuedQueries.poll(); // Remove from queue; use query from peek() above startInBackground(query); return true; } @@ -1020,6 +1039,11 @@ private boolean canRunMore() { checkState(Thread.holdsLock(root), "Must hold lock"); synchronized (root) { + // Check if more queries can be run on the cluster based on cluster overload + if (clusterResourceChecker.isClusterCurrentlyOverloaded()) { + return false; + } + if (cpuUsageMillis >= hardCpuLimitMillis) { return false; } @@ -1136,7 +1160,9 @@ public RootInternalResourceGroup( Executor executor, Function> additionalRuntimeInfo, Predicate shouldWaitForResourceManagerUpdate, - InternalNodeManager nodeManager) + InternalNodeManager nodeManager, + ClusterResourceChecker clusterResourceChecker, + QueryPacingContext queryPacingContext) { super(Optional.empty(), name, @@ -1145,7 +1171,17 @@ public RootInternalResourceGroup( true, additionalRuntimeInfo, 
shouldWaitForResourceManagerUpdate, - nodeManager); + nodeManager, + clusterResourceChecker, + queryPacingContext); + } + + public synchronized void updateEligibilityRecursively(InternalResourceGroup group) + { + group.updateEligibility(); + for (InternalResourceGroup subGroup : group.subGroups()) { + updateEligibilityRecursively(subGroup); + } } public synchronized void processQueuedQueries() @@ -1153,7 +1189,7 @@ public synchronized void processQueuedQueries() internalRefreshStats(); while (internalStartNext()) { - // start all the queries we can + // start all the queries we can (subject to limits and pacing) } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroupManager.java b/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroupManager.java index 96b7951e365ce..806e82e1e64d9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroupManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/InternalResourceGroupManager.java @@ -15,9 +15,12 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.node.NodeInfo; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.ManagedQueryExecution; import com.facebook.presto.execution.QueryManagerConfig; import com.facebook.presto.execution.resourceGroups.InternalResourceGroup.RootInternalResourceGroup; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterOverloadStateListener; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterResourceChecker; import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.resourcemanager.ResourceGroupService; import com.facebook.presto.server.ResourceGroupInfo; @@ -35,17 +38,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import 
com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.JmxException; import org.weakref.jmx.MBeanExporter; import org.weakref.jmx.Managed; import org.weakref.jmx.ObjectNames; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.io.File; import java.util.HashMap; import java.util.List; @@ -58,6 +59,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.LongSupplier; @@ -82,13 +84,25 @@ @ThreadSafe public final class InternalResourceGroupManager - implements ResourceGroupManager + implements ResourceGroupManager, ClusterOverloadStateListener { private static final Logger log = Logger.get(InternalResourceGroupManager.class); private static final File RESOURCE_GROUPS_CONFIGURATION = new File("etc/resource-groups.properties"); private static final String CONFIGURATION_MANAGER_PROPERTY_NAME = "resource-groups.configuration-manager"; private static final int REFRESH_EXECUTOR_POOL_SIZE = 2; + private final int maxQueryAdmissionsPerSecond; + private final int minRunningQueriesForPacing; + private final long queryAdmissionIntervalNanos; + private final AtomicLong lastAdmittedQueryNanos = new AtomicLong(0L); + + // Pacing metrics - use AtomicLong/AtomicInteger for lock-free updates to avoid deadlock + // with resource group locks (see tryAcquireAdmissionSlot for details) + private final AtomicLong totalAdmissionAttempts = new AtomicLong(0L); + private final AtomicLong totalAdmissionsGranted = 
new AtomicLong(0L); + private final AtomicLong totalAdmissionsDenied = new AtomicLong(0L); + private final AtomicInteger totalRunningQueriesCounter = new AtomicInteger(0); + private final ScheduledExecutorService refreshExecutor = newScheduledThreadPool(REFRESH_EXECUTOR_POOL_SIZE, daemonThreadsNamed("resource-group-manager-refresher-%d-" + REFRESH_EXECUTOR_POOL_SIZE)); private final PeriodicTaskExecutor resourceGroupRuntimeExecutor; private final List rootGroups = new CopyOnWriteArrayList<>(); @@ -113,6 +127,8 @@ public final class InternalResourceGroupManager private final QueryManagerConfig queryManagerConfig; private final InternalNodeManager nodeManager; private AtomicBoolean isConfigurationManagerLoaded; + private final ClusterResourceChecker clusterResourceChecker; + private final QueryPacingContext queryPacingContext; @Inject public InternalResourceGroupManager( @@ -122,7 +138,8 @@ public InternalResourceGroupManager( MBeanExporter exporter, ResourceGroupService resourceGroupService, ServerConfig serverConfig, - InternalNodeManager nodeManager) + InternalNodeManager nodeManager, + ClusterResourceChecker clusterResourceChecker) { this.queryManagerConfig = requireNonNull(queryManagerConfig, "queryManagerConfig is null"); this.exporter = requireNonNull(exporter, "exporter is null"); @@ -138,6 +155,101 @@ public InternalResourceGroupManager( this.resourceGroupRuntimeExecutor = new PeriodicTaskExecutor(resourceGroupRuntimeInfoRefreshInterval.toMillis(), refreshExecutor, this::refreshResourceGroupRuntimeInfo); configurationManagerFactories.putIfAbsent(LegacyResourceGroupConfigurationManager.NAME, new LegacyResourceGroupConfigurationManager.Factory()); this.isConfigurationManagerLoaded = new AtomicBoolean(false); + this.clusterResourceChecker = requireNonNull(clusterResourceChecker, "clusterResourceChecker is null"); + this.maxQueryAdmissionsPerSecond = queryManagerConfig.getMaxQueryAdmissionsPerSecond(); + this.minRunningQueriesForPacing = 
queryManagerConfig.getMinRunningQueriesForPacing(); + this.queryAdmissionIntervalNanos = (maxQueryAdmissionsPerSecond == Integer.MAX_VALUE) + ? 0L + : 1_000_000_000L / maxQueryAdmissionsPerSecond; + this.queryPacingContext = new QueryPacingContext() + { + @Override + public boolean tryAcquireAdmissionSlot() + { + return InternalResourceGroupManager.this.tryAcquireAdmissionSlot(); + } + + @Override + public void onQueryStarted() + { + incrementRunningQueries(); + } + + @Override + public void onQueryFinished() + { + decrementRunningQueries(); + } + }; + } + + /** + * Global rate limiter for query admissions. Enforces maxQueryAdmissionsPerSecond + * when running queries exceed minRunningQueriesForPacing threshold. + * + * @return true if query can be admitted, false if rate limit exceeded + */ + boolean tryAcquireAdmissionSlot() + { + // Pacing disabled - return early without tracking metrics + if (queryAdmissionIntervalNanos == 0L) { + return true; + } + + // Running queries below threshold - bypass pacing + int currentRunningQueries = getTotalRunningQueries(); + if (currentRunningQueries < minRunningQueriesForPacing) { + return true; + } + + totalAdmissionAttempts.incrementAndGet(); + + // Atomic update for global rate limiting. With multiple root resource groups, + // concurrent threads may call this method simultaneously (each holding their + // own root group's lock). Compare-and-swap ensures correctness in that scenario. + // With a single root group, the root lock serializes access, making the atomic + // update redundant but harmless. 
+ for (int attempt = 0; attempt < 10; attempt++) { + long now = System.nanoTime(); + long last = lastAdmittedQueryNanos.get(); + + // Check if enough time has elapsed since last admission + if (last != 0L && (now - last) < queryAdmissionIntervalNanos) { + totalAdmissionsDenied.incrementAndGet(); + return false; + } + + // Atomically update timestamp if unchanged; retry if another thread won + if (lastAdmittedQueryNanos.compareAndSet(last, now)) { + totalAdmissionsGranted.incrementAndGet(); + return true; + } + } + + // Exhausted retries - deny to prevent starvation under extreme contention + totalAdmissionsDenied.incrementAndGet(); + return false; + } + + /** + * Returns total running queries across all resource groups. + * Uses atomic counter updated via callbacks to avoid locking resource groups. + */ + private int getTotalRunningQueries() + { + return totalRunningQueriesCounter.get(); + } + + /** Called by InternalResourceGroup when a query starts execution. */ + public void incrementRunningQueries() + { + totalRunningQueriesCounter.incrementAndGet(); + } + + /** Called by InternalResourceGroup when a query finishes execution. 
*/ + public void decrementRunningQueries() + { + totalRunningQueriesCounter.decrementAndGet(); } @Override @@ -255,6 +367,8 @@ public ResourceGroupConfigurationManager getConfigurationManager() @PreDestroy public void destroy() { + // Unregister from cluster overload state changes + clusterResourceChecker.removeListener(this); refreshExecutor.shutdownNow(); resourceGroupRuntimeExecutor.stop(); } @@ -276,6 +390,9 @@ public void start() if (isResourceManagerEnabled) { resourceGroupRuntimeExecutor.start(); } + + // Register as listener for cluster overload state changes + clusterResourceChecker.addListener(this); } } @@ -397,7 +514,15 @@ private synchronized void createGroupIfNecessary(SelectionContext context, Ex else { RootInternalResourceGroup root; if (!isResourceManagerEnabled) { - root = new RootInternalResourceGroup(id.getSegments().get(0), this::exportGroup, executor, ignored -> Optional.empty(), rg -> false, nodeManager); + root = new RootInternalResourceGroup( + id.getSegments().get(0), + this::exportGroup, + executor, + ignored -> Optional.empty(), + rg -> false, + nodeManager, + clusterResourceChecker, + queryPacingContext); } else { root = new RootInternalResourceGroup( @@ -410,7 +535,9 @@ private synchronized void createGroupIfNecessary(SelectionContext context, Ex resourceGroupRuntimeInfosSnapshot::get, lastUpdatedResourceGroupRuntimeInfo::get, concurrencyThreshold), - nodeManager); + nodeManager, + clusterResourceChecker, + queryPacingContext); } group = root; rootGroups.add(root); @@ -464,6 +591,24 @@ public int getQueriesQueuedOnInternal() return queriesQueuedInternal; } + @Override + public void onClusterEnteredOverloadedState() + { + // Resource groups will handle overload state through their existing admission control logic + // No additional action needed here as queries will be queued automatically + } + + @Override + public void onClusterExitedOverloadedState() + { + log.info("Cluster exited overloaded state, updating eligibility for all 
resource groups"); + for (RootInternalResourceGroup rootGroup : rootGroups) { + synchronized (rootGroup) { + rootGroup.updateEligibilityRecursively(rootGroup); + } + } + } + @Managed public long getLastSchedulingCycleRuntimeDelayMs() { @@ -472,6 +617,57 @@ public long getLastSchedulingCycleRuntimeDelayMs() return lastSchedulingCycleRunTimeMs.get() == 0L ? lastSchedulingCycleRunTimeMs.get() : currentTimeMillis() - lastSchedulingCycleRunTimeMs.get(); } + @Managed + public int getMaxQueryAdmissionsPerSecond() + { + return maxQueryAdmissionsPerSecond; + } + + @Managed + public long getTotalAdmissionAttempts() + { + return totalAdmissionAttempts.get(); + } + + @Managed + public long getTotalAdmissionsGranted() + { + return totalAdmissionsGranted.get(); + } + + @Managed + public long getTotalAdmissionsDenied() + { + return totalAdmissionsDenied.get(); + } + + @Managed + public int getMinRunningQueriesForPacing() + { + return minRunningQueriesForPacing; + } + + @Managed + public double getAdmissionGrantRate() + { + long attempts = totalAdmissionAttempts.get(); + return attempts > 0 ? (double) totalAdmissionsGranted.get() / attempts : 0.0; + } + + @Managed + public double getAdmissionDenyRate() + { + long attempts = totalAdmissionAttempts.get(); + return attempts > 0 ? (double) totalAdmissionsDenied.get() / attempts : 0.0; + } + + @Managed + public long getMillisSinceLastAdmission() + { + long last = lastAdmittedQueryNanos.get(); + return last == 0L ? 
-1L : (System.nanoTime() - last) / 1_000_000; + } + private int getQueriesQueuedOnInternal(InternalResourceGroup resourceGroup) { if (resourceGroup.subGroups().isEmpty()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/QueryPacingContext.java b/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/QueryPacingContext.java new file mode 100644 index 0000000000000..02eca0b4ac607 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/QueryPacingContext.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.resourceGroups; + +/** + * Context for query admission pacing. Provides a single interface for + * global rate limiting and running query tracking to prevent worker overload. + *

+ * This interface consolidates the pacing-related callbacks that are shared + * across all resource groups, keeping resource group objects smaller. + */ +public interface QueryPacingContext +{ + /** + * A no-op implementation that allows all queries and tracks nothing. + */ + QueryPacingContext NOOP = new QueryPacingContext() + { + @Override + public boolean tryAcquireAdmissionSlot() + { + return true; + } + + @Override + public void onQueryStarted() + { + } + + @Override + public void onQueryFinished() + { + } + }; + + /** + * Attempts to acquire an admission slot for starting a new query. + * Enforces global rate limiting when running queries exceed threshold. + * + * @return true if query can be admitted, false if rate limit exceeded + */ + boolean tryAcquireAdmissionSlot(); + + /** + * Called when a query starts running. Used to track global running query count. + */ + void onQueryStarted(); + + /** + * Called when a query finishes (success or failure). Used to track global running query count. 
+ */ + void onQueryFinished(); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java b/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java index d3d97d118756b..8afd4cf95c569 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/resourceGroups/ResourceGroupManager.java @@ -19,8 +19,7 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.resourceGroups.SelectionContext; import com.facebook.presto.spi.resourceGroups.SelectionCriteria; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.List; import java.util.concurrent.Executor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/AllAtOnceExecutionSchedule.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/AllAtOnceExecutionSchedule.java index 98ece04e5ad83..9cd917a8ff0cf 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/AllAtOnceExecutionSchedule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/AllAtOnceExecutionSchedule.java @@ -15,6 +15,7 @@ import com.facebook.presto.execution.SqlStageExecution; import com.facebook.presto.execution.StageExecutionState; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.MergeJoinNode; import com.facebook.presto.spi.plan.PlanFragmentId; @@ -24,7 +25,6 @@ import com.facebook.presto.spi.plan.UnionNode; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.plan.ExchangeNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import 
com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.google.common.annotations.VisibleForTesting; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/BroadcastOutputBufferManager.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/BroadcastOutputBufferManager.java index 92d390e09426d..4ea566158358c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/BroadcastOutputBufferManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/BroadcastOutputBufferManager.java @@ -15,9 +15,8 @@ import com.facebook.presto.execution.buffer.OutputBuffers; import com.facebook.presto.execution.buffer.OutputBuffers.OutputBufferId; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.List; import java.util.function.Consumer; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java index a82ef6508d954..f24cc813491f4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/CTEMaterializationTracker.java @@ -13,11 +13,14 @@ */ package com.facebook.presto.execution.scheduler; +import com.facebook.presto.execution.SqlStageExecution; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; +import java.util.List; import 
java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -60,4 +63,25 @@ public void markCTEAsMaterialized(String cteName) return existingFuture; }); } + + public List> waitForCteMaterialization(SqlStageExecution stage) + { + if (stage.requiresMaterializedCTE()) { + ImmutableList.Builder> blockedFutures = new ImmutableList.Builder<>(); + boolean blocked = false; + List requiredCTEIds = stage.getRequiredCTEList(); + for (String cteId : requiredCTEIds) { + ListenableFuture cteFuture = getFutureForCTE(cteId); + if (!cteFuture.isDone()) { + // Add CTE materialization future to the blocked list + blockedFutures.add(cteFuture); + blocked = true; + } + } + if (blocked) { + return blockedFutures.build(); + } + } + return ImmutableList.of(); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTarget.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTarget.java index 2b0c9e5167dac..c890b215ceff4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTarget.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTarget.java @@ -14,7 +14,12 @@ package com.facebook.presto.execution.scheduler; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.metadata.DeleteTableHandle; +import com.facebook.presto.metadata.DistributedProcedureHandle; import com.facebook.presto.metadata.InsertTableHandle; import com.facebook.presto.metadata.OutputTableHandle; import com.facebook.presto.spi.SchemaTableName; @@ -32,10 +37,14 @@ @JsonSubTypes.Type(value = ExecutionWriterTarget.InsertHandle.class, name = "InsertHandle"), @JsonSubTypes.Type(value = ExecutionWriterTarget.DeleteHandle.class, name = "DeleteHandle"), 
@JsonSubTypes.Type(value = ExecutionWriterTarget.RefreshMaterializedViewHandle.class, name = "RefreshMaterializedViewHandle"), - @JsonSubTypes.Type(value = ExecutionWriterTarget.UpdateHandle.class, name = "UpdateHandle")}) + @JsonSubTypes.Type(value = ExecutionWriterTarget.UpdateHandle.class, name = "UpdateHandle"), + @JsonSubTypes.Type(value = ExecutionWriterTarget.ExecuteProcedureHandle.class, name = "ExecuteProcedureHandle"), + @JsonSubTypes.Type(value = ExecutionWriterTarget.MergeHandle.class, name = "MergeHandle") +}) @SuppressWarnings({"EmptyClass", "ClassMayBeInterface"}) public abstract class ExecutionWriterTarget { + @ThriftStruct public static class CreateHandle extends ExecutionWriterTarget { @@ -43,6 +52,7 @@ public static class CreateHandle private final SchemaTableName schemaTableName; @JsonCreator + @ThriftConstructor public CreateHandle( @JsonProperty("handle") OutputTableHandle handle, @JsonProperty("schemaTableName") SchemaTableName schemaTableName) @@ -52,12 +62,14 @@ public CreateHandle( } @JsonProperty + @ThriftField(1) public OutputTableHandle getHandle() { return handle; } @JsonProperty + @ThriftField(2) public SchemaTableName getSchemaTableName() { return schemaTableName; @@ -70,6 +82,7 @@ public String toString() } } + @ThriftStruct public static class InsertHandle extends ExecutionWriterTarget { @@ -77,6 +90,7 @@ public static class InsertHandle private final SchemaTableName schemaTableName; @JsonCreator + @ThriftConstructor public InsertHandle( @JsonProperty("handle") InsertTableHandle handle, @JsonProperty("schemaTableName") SchemaTableName schemaTableName) @@ -86,12 +100,14 @@ public InsertHandle( } @JsonProperty + @ThriftField(1) public InsertTableHandle getHandle() { return handle; } @JsonProperty + @ThriftField(2) public SchemaTableName getSchemaTableName() { return schemaTableName; @@ -104,6 +120,7 @@ public String toString() } } + @ThriftStruct public static class DeleteHandle extends ExecutionWriterTarget { @@ -111,6 +128,7 @@ 
public static class DeleteHandle private final SchemaTableName schemaTableName; @JsonCreator + @ThriftConstructor public DeleteHandle( @JsonProperty("handle") DeleteTableHandle handle, @JsonProperty("schemaTableName") SchemaTableName schemaTableName) @@ -120,12 +138,14 @@ public DeleteHandle( } @JsonProperty + @ThriftField(1) public DeleteTableHandle getHandle() { return handle; } @JsonProperty + @ThriftField(2) public SchemaTableName getSchemaTableName() { return schemaTableName; @@ -138,6 +158,7 @@ public String toString() } } + @ThriftStruct public static class RefreshMaterializedViewHandle extends ExecutionWriterTarget { @@ -145,6 +166,7 @@ public static class RefreshMaterializedViewHandle private final SchemaTableName schemaTableName; @JsonCreator + @ThriftConstructor public RefreshMaterializedViewHandle( @JsonProperty("handle") InsertTableHandle handle, @JsonProperty("schemaTableName") SchemaTableName schemaTableName) @@ -154,12 +176,14 @@ public RefreshMaterializedViewHandle( } @JsonProperty + @ThriftField(1) public InsertTableHandle getHandle() { return handle; } @JsonProperty + @ThriftField(2) public SchemaTableName getSchemaTableName() { return schemaTableName; @@ -172,6 +196,7 @@ public String toString() } } + @ThriftStruct public static class UpdateHandle extends ExecutionWriterTarget { @@ -179,6 +204,7 @@ public static class UpdateHandle private final SchemaTableName schemaTableName; @JsonCreator + @ThriftConstructor public UpdateHandle( @JsonProperty("handle") TableHandle handle, @JsonProperty("schemaTableName") SchemaTableName schemaTableName) @@ -188,12 +214,14 @@ public UpdateHandle( } @JsonProperty + @ThriftField(1) public TableHandle getHandle() { return handle; } @JsonProperty + @ThriftField(2) public SchemaTableName getSchemaTableName() { return schemaTableName; @@ -205,4 +233,74 @@ public String toString() return handle.toString(); } } + + public static class ExecuteProcedureHandle + extends ExecutionWriterTarget + { + private final 
DistributedProcedureHandle handle; + private final SchemaTableName schemaTableName; + private final QualifiedObjectName procedureName; + + @JsonCreator + public ExecuteProcedureHandle( + @JsonProperty("handle") DistributedProcedureHandle handle, + @JsonProperty("schemaTableName") SchemaTableName schemaTableName, + @JsonProperty("procedureName") QualifiedObjectName procedureName) + { + this.handle = requireNonNull(handle, "handle is null"); + this.schemaTableName = requireNonNull(schemaTableName, "schemaTableName is null"); + this.procedureName = requireNonNull(procedureName, "procedureName is null"); + } + + @JsonProperty + public DistributedProcedureHandle getHandle() + { + return handle; + } + + @JsonProperty + public SchemaTableName getSchemaTableName() + { + return schemaTableName; + } + + @JsonProperty + public QualifiedObjectName getProcedureName() + { + return procedureName; + } + + @Override + public String toString() + { + return handle.toString(); + } + } + + @ThriftStruct + public static class MergeHandle + extends ExecutionWriterTarget + { + private final com.facebook.presto.spi.MergeHandle handle; + + @JsonCreator + @ThriftConstructor + public MergeHandle(@JsonProperty("handle") com.facebook.presto.spi.MergeHandle handle) + { + this.handle = requireNonNull(handle, "tableHandle is null"); + } + + @JsonProperty + @ThriftField(1) + public com.facebook.presto.spi.MergeHandle getHandle() + { + return handle; + } + + @Override + public String toString() + { + return handle.toString(); + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTargetUnion.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTargetUnion.java new file mode 100644 index 0000000000000..26ea661d30505 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ExecutionWriterTargetUnion.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.scheduler; + +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftUnion; +import com.facebook.drift.annotations.ThriftUnionId; + +import static java.util.Objects.requireNonNull; + +@ThriftUnion +public class ExecutionWriterTargetUnion +{ + private short id; + private ExecutionWriterTarget.CreateHandle createHandle; + private ExecutionWriterTarget.InsertHandle insertHandle; + private ExecutionWriterTarget.DeleteHandle deleteHandle; + private ExecutionWriterTarget.RefreshMaterializedViewHandle refreshMaterializedViewHandle; + private ExecutionWriterTarget.UpdateHandle updateHandle; + private ExecutionWriterTarget.MergeHandle mergeHandle; + + @ThriftConstructor + public ExecutionWriterTargetUnion() + { + this.id = 0; + } + + @ThriftConstructor + public ExecutionWriterTargetUnion(ExecutionWriterTarget.CreateHandle createHandle) + { + this.id = 1; + this.createHandle = createHandle; + } + + @ThriftField(1) + public ExecutionWriterTarget.CreateHandle getCreateHandle() + { + return createHandle; + } + + @ThriftConstructor + public ExecutionWriterTargetUnion(ExecutionWriterTarget.InsertHandle insertHandle) + { + this.id = 2; + this.insertHandle = insertHandle; + } + + @ThriftField(2) + public ExecutionWriterTarget.InsertHandle getInsertHandle() + { + return insertHandle; + } + + @ThriftConstructor + public 
ExecutionWriterTargetUnion(ExecutionWriterTarget.DeleteHandle deleteHandle) + { + this.id = 3; + this.deleteHandle = deleteHandle; + } + + @ThriftField(3) + public ExecutionWriterTarget.DeleteHandle getDeleteHandle() + { + return deleteHandle; + } + + @ThriftConstructor + public ExecutionWriterTargetUnion(ExecutionWriterTarget.RefreshMaterializedViewHandle refreshMaterializedViewHandle) + { + this.id = 4; + this.refreshMaterializedViewHandle = refreshMaterializedViewHandle; + } + + @ThriftField(4) + public ExecutionWriterTarget.RefreshMaterializedViewHandle getRefreshMaterializedViewHandle() + { + return refreshMaterializedViewHandle; + } + + @ThriftConstructor + public ExecutionWriterTargetUnion(ExecutionWriterTarget.UpdateHandle updateHandle) + { + this.id = 5; + this.updateHandle = updateHandle; + } + + @ThriftField(5) + public ExecutionWriterTarget.UpdateHandle getUpdateHandle() + { + return updateHandle; + } + + @ThriftConstructor + public ExecutionWriterTargetUnion(ExecutionWriterTarget.MergeHandle mergeHandle) + { + this.id = 6; + this.mergeHandle = mergeHandle; + } + + @ThriftField(6) + public ExecutionWriterTarget.MergeHandle getMergeHandle() + { + return mergeHandle; + } + + @ThriftUnionId + public short getId() + { + return id; + } + + public static ExecutionWriterTarget toExecutionWriterTarget(ExecutionWriterTargetUnion executionWriterTargetUnion) + { + requireNonNull(executionWriterTargetUnion, "executionWriterTargetUnion is null"); + if (executionWriterTargetUnion.getCreateHandle() != null) { + return executionWriterTargetUnion.getCreateHandle(); + } + else if (executionWriterTargetUnion.getInsertHandle() != null) { + return executionWriterTargetUnion.getInsertHandle(); + } + else if (executionWriterTargetUnion.getDeleteHandle() != null) { + return executionWriterTargetUnion.getDeleteHandle(); + } + else if (executionWriterTargetUnion.getRefreshMaterializedViewHandle() != null) { + return executionWriterTargetUnion.getRefreshMaterializedViewHandle(); 
+ } + else if (executionWriterTargetUnion.getUpdateHandle() != null) { + return executionWriterTargetUnion.getUpdateHandle(); + } + else if (executionWriterTargetUnion.getMergeHandle() != null) { + return executionWriterTargetUnion.getMergeHandle(); + } + else { + throw new IllegalArgumentException("Unrecognized execution writer target: " + executionWriterTargetUnion); + } + } + + public static ExecutionWriterTargetUnion fromExecutionWriterTarget(ExecutionWriterTarget executionWriterTarget) + { + requireNonNull(executionWriterTarget, "executionWriterTarget is null"); + + if (executionWriterTarget instanceof ExecutionWriterTarget.CreateHandle) { + return new ExecutionWriterTargetUnion((ExecutionWriterTarget.CreateHandle) executionWriterTarget); + } + else if (executionWriterTarget instanceof ExecutionWriterTarget.InsertHandle) { + return new ExecutionWriterTargetUnion((ExecutionWriterTarget.InsertHandle) executionWriterTarget); + } + else if (executionWriterTarget instanceof ExecutionWriterTarget.DeleteHandle) { + return new ExecutionWriterTargetUnion((ExecutionWriterTarget.DeleteHandle) executionWriterTarget); + } + else if (executionWriterTarget instanceof ExecutionWriterTarget.RefreshMaterializedViewHandle) { + return new ExecutionWriterTargetUnion((ExecutionWriterTarget.RefreshMaterializedViewHandle) executionWriterTarget); + } + else if (executionWriterTarget instanceof ExecutionWriterTarget.UpdateHandle) { + return new ExecutionWriterTargetUnion((ExecutionWriterTarget.UpdateHandle) executionWriterTarget); + } + else if (executionWriterTarget instanceof ExecutionWriterTarget.MergeHandle) { + return new ExecutionWriterTargetUnion((ExecutionWriterTarget.MergeHandle) executionWriterTarget); + } + else { + throw new IllegalArgumentException("Unsupported execution writer target: " + executionWriterTarget); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java index 9febc345114f2..b87629a8a0326 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -33,8 +33,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.Iterator; @@ -181,30 +180,18 @@ private ConnectorPartitionHandle partitionHandleFor(Lifespan lifespan) @Override public ScheduleResult schedule() { + List> cteMaterializationFutures = cteMaterializationTracker.waitForCteMaterialization(stage); + if (!cteMaterializationFutures.isEmpty()) { + return ScheduleResult.blocked( + false, + ImmutableList.of(), + whenAnyComplete(cteMaterializationFutures), + ScheduleResult.BlockedReason.WAITING_FOR_CTE_MATERIALIZATION, + 0); + } + // schedule a task on every node in the distribution List newTasks = ImmutableList.of(); - - // CTE Materialization Check - if (stage.requiresMaterializedCTE()) { - List> blocked = new ArrayList<>(); - List requiredCTEIds = stage.getRequiredCTEList(); - for (String cteId : requiredCTEIds) { - ListenableFuture cteFuture = cteMaterializationTracker.getFutureForCTE(cteId); - if (!cteFuture.isDone()) { - // Add CTE materialization future to the blocked list - blocked.add(cteFuture); - } - } - // If any CTE is not materialized, return a blocked ScheduleResult - if (!blocked.isEmpty()) { - return ScheduleResult.blocked( - false, - newTasks, - whenAnyComplete(blocked), - BlockedReason.WAITING_FOR_CTE_MATERIALIZATION, - 0); - } - } // schedule a task on every node in the distribution if (!scheduledTasks) { newTasks 
= Streams.mapWithIndex( diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NetworkLocationCache.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NetworkLocationCache.java index d4c75c4c72364..84c982f229be7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NetworkLocationCache.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NetworkLocationCache.java @@ -14,12 +14,12 @@ package com.facebook.presto.execution.scheduler; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.HostAddress; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; -import io.airlift.units.Duration; import java.util.concurrent.ExecutorService; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NetworkTopology.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NetworkTopology.java index 583d5d3207998..f0472f63686cc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NetworkTopology.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NetworkTopology.java @@ -14,8 +14,7 @@ package com.facebook.presto.execution.scheduler; import com.facebook.presto.spi.HostAddress; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeAssignmentStats.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeAssignmentStats.java index 7b4c7707e1292..feab5116aa8fa 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeAssignmentStats.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeAssignmentStats.java @@ -70,6 +70,12 @@ public long getQueuedSplitsWeightForStage(InternalNode node) return stageInfo == null ? 0 : stageInfo.getQueuedSplitsWeight(); } + public long getAssignedSplitsWeightForStage(InternalNode node) + { + PendingSplitInfo stageInfo = stageQueuedSplitInfo.get(node.getNodeIdentifier()); + return stageInfo == null ? 0 : stageInfo.getAssignedSplitsWeight(); + } + public int getUnacknowledgedSplitCountForStage(InternalNode node) { PendingSplitInfo stageInfo = stageQueuedSplitInfo.get(node.getNodeIdentifier()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeScheduler.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeScheduler.java index 8c2ee85d75f9d..1f3f45f269b47 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeScheduler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeScheduler.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution.scheduler; import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.execution.NodeTaskMap; import com.facebook.presto.execution.QueryManager; @@ -42,10 +43,8 @@ import com.google.common.collect.ImmutableSetMultimap; import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; - -import javax.annotation.PreDestroy; -import javax.inject.Inject; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.net.InetAddress; import java.net.UnknownHostException; @@ -62,6 +61,7 @@ import static com.facebook.airlift.concurrent.MoreFutures.whenAnyCompleteCancelOthers; import static com.facebook.presto.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask; import static 
com.facebook.presto.SystemSessionProperties.getResourceAwareSchedulingStrategy; +import static com.facebook.presto.SystemSessionProperties.isScheduleSplitsBasedOnTaskLoad; import static com.facebook.presto.execution.scheduler.NodeSchedulerConfig.NetworkTopologyType; import static com.facebook.presto.execution.scheduler.NodeSchedulerConfig.ResourceAwareSchedulingStrategy; import static com.facebook.presto.execution.scheduler.NodeSchedulerConfig.ResourceAwareSchedulingStrategy.TTL; @@ -92,6 +92,7 @@ public class NodeScheduler private final int minCandidates; private final boolean includeCoordinator; private final long maxSplitsWeightPerNode; + private final long maxSplitsWeightPerTask; private final long maxPendingSplitsWeightPerTask; private final NodeTaskMap nodeTaskMap; private final boolean useNetworkTopology; @@ -147,6 +148,7 @@ public NodeScheduler( int maxPendingSplitsPerTask = config.getMaxPendingSplitsPerTask(); checkArgument(maxSplitsPerNode >= maxPendingSplitsPerTask, "maxSplitsPerNode must be > maxPendingSplitsPerTask"); this.maxSplitsWeightPerNode = SplitWeight.rawValueForStandardSplitCount(maxSplitsPerNode); + this.maxSplitsWeightPerTask = SplitWeight.rawValueForStandardSplitCount(config.getMaxSplitsPerTask()); this.maxPendingSplitsWeightPerTask = SplitWeight.rawValueForStandardSplitCount(maxPendingSplitsPerTask); this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); this.useNetworkTopology = !config.getNetworkTopology().equals(NetworkTopologyType.LEGACY); @@ -232,9 +234,11 @@ public NodeSelector createNodeSelector(Session session, ConnectorId connectorId, nodeSelectionStats, nodeTaskMap, includeCoordinator, + isScheduleSplitsBasedOnTaskLoad(session), nodeMap, minCandidates, maxSplitsWeightPerNode, + maxSplitsWeightPerTask, maxPendingSplitsWeightPerTask, maxUnacknowledgedSplitsPerTask, maxTasksPerStage, diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeSchedulerConfig.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeSchedulerConfig.java index 1bc0af2e4cccb..34aa69d9be896 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeSchedulerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeSchedulerConfig.java @@ -17,9 +17,8 @@ import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.DefunctConfig; import com.facebook.airlift.configuration.LegacyConfig; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; @DefunctConfig({"node-scheduler.location-aware-scheduling-enabled", "node-scheduler.multiple-tasks-per-node-enabled"}) public class NodeSchedulerConfig @@ -34,6 +33,8 @@ public static class NetworkTopologyType private int minCandidates = 10; private boolean includeCoordinator = true; private int maxSplitsPerNode = 100; + private int maxSplitsPerTask = 10; + private boolean scheduleSplitsBasedOnTaskLoad; private int maxPendingSplitsPerTask = 10; private int maxUnacknowledgedSplitsPerTask = 500; private String networkTopology = NetworkTopologyType.LEGACY; @@ -107,6 +108,33 @@ public NodeSchedulerConfig setMaxSplitsPerNode(int maxSplitsPerNode) return this; } + public int getMaxSplitsPerTask() + { + return maxSplitsPerTask; + } + + @Config("node-scheduler.max-splits-per-task") + @ConfigDescription("The number of splits weighted at the standard split weight that are allowed to be scheduled for each task " + + "when scheduling splits based on the task load.") + public NodeSchedulerConfig setMaxSplitsPerTask(int maxSplitsPerTask) + { + this.maxSplitsPerTask = maxSplitsPerTask; + return this; + } + + public boolean isScheduleSplitsBasedOnTaskLoad() + { + return scheduleSplitsBasedOnTaskLoad; + } + + @Config("node-scheduler.schedule-splits-based-on-task-load") + 
@ConfigDescription("Schedule splits based on task load, rather than on the node load") + public NodeSchedulerConfig setScheduleSplitsBasedOnTaskLoad(boolean scheduleSplitsBasedOnTaskLoad) + { + this.scheduleSplitsBasedOnTaskLoad = scheduleSplitsBasedOnTaskLoad; + return this; + } + @Min(1) public int getMaxUnacknowledgedSplitsPerTask() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeSchedulerExporter.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeSchedulerExporter.java index d76c6bd4971e6..c9fdacd1709de 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeSchedulerExporter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/NodeSchedulerExporter.java @@ -14,14 +14,13 @@ package com.facebook.presto.execution.scheduler; import com.facebook.airlift.stats.CounterStat; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.JmxException; import org.weakref.jmx.MBeanExporter; import org.weakref.jmx.ObjectNames; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/PartitionedOutputBufferManager.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/PartitionedOutputBufferManager.java index fc793e37837a1..fb9727eb6c3fc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/PartitionedOutputBufferManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/PartitionedOutputBufferManager.java @@ -17,8 +17,7 @@ import com.facebook.presto.execution.buffer.OutputBuffers.OutputBufferId; import com.facebook.presto.spi.plan.PartitioningHandle; 
import com.google.common.collect.ImmutableMap; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/PhasedExecutionSchedule.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/PhasedExecutionSchedule.java index 05c8d2789a6d9..162a8fe89ddaf 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/PhasedExecutionSchedule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/PhasedExecutionSchedule.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.execution.scheduler; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.execution.SqlStageExecution; import com.facebook.presto.execution.StageExecutionState; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.MergeJoinNode; import com.facebook.presto.spi.plan.PlanFragmentId; @@ -24,7 +26,6 @@ import com.facebook.presto.spi.plan.UnionNode; import com.facebook.presto.sql.planner.PlanFragment; import com.facebook.presto.sql.planner.plan.ExchangeNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.google.common.annotations.VisibleForTesting; @@ -36,8 +37,6 @@ import org.jgrapht.graph.DefaultEdge; import org.jgrapht.traverse.TopologicalOrderIterator; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ScaledOutputBufferManager.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ScaledOutputBufferManager.java index 
c158ae526b481..ac938f9cb2314 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ScaledOutputBufferManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ScaledOutputBufferManager.java @@ -15,8 +15,7 @@ import com.facebook.presto.execution.buffer.OutputBuffers; import com.facebook.presto.execution.buffer.OutputBuffers.OutputBufferId; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.List; import java.util.function.Consumer; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ScaledWriterScheduler.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ScaledWriterScheduler.java index 4ffc5da5a25a4..3d63bda706327 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ScaledWriterScheduler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/ScaledWriterScheduler.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution.scheduler; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.execution.RemoteTask; import com.facebook.presto.execution.SqlStageExecution; import com.facebook.presto.execution.TaskStatus; @@ -20,7 +21,6 @@ import com.facebook.presto.metadata.InternalNode; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; import java.util.Collection; import java.util.HashSet; @@ -34,6 +34,7 @@ import static com.facebook.presto.execution.scheduler.ScheduleResult.BlockedReason.WRITER_SCALING; import static com.facebook.presto.spi.StandardErrorCode.NO_NODES_AVAILABLE; import static com.facebook.presto.util.Failures.checkCondition; +import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; import static java.util.Objects.requireNonNull; import static 
java.util.concurrent.TimeUnit.MILLISECONDS; @@ -73,6 +74,8 @@ public ScaledWriterScheduler( this.writerMinSizeBytes = requireNonNull(writerMinSize, "minWriterSize is null").toBytes(); this.optimizedScaleWriterProducerBuffer = optimizedScaleWriterProducerBuffer; this.initialTaskCount = requireNonNull(initialTaskCount, "initialTaskCount is null"); + + future.set(null); } public void finish() @@ -84,13 +87,20 @@ public void finish() @Override public ScheduleResult schedule() { - List writers = scheduleTasks(getNewTaskCount()); + List writers = ImmutableList.of(); - future.set(null); - future = SettableFuture.create(); - executor.schedule(() -> future.set(null), 200, MILLISECONDS); + if (future.isDone()) { + writers = scheduleTasks(getNewTaskCount()); + future = SettableFuture.create(); + executor.schedule(() -> future.set(null), 200, MILLISECONDS); + } - return ScheduleResult.blocked(done.get(), writers, future, WRITER_SCALING, 0); + return ScheduleResult.blocked( + done.get(), + writers, + nonCancellationPropagating(future), + WRITER_SCALING, + 0); } private int getNewTaskCount() diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SectionExecution.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SectionExecution.java index a829aa8542070..3241cb617b62f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SectionExecution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SectionExecution.java @@ -15,8 +15,7 @@ import com.facebook.presto.execution.StageExecutionState; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java index 26b77ef163ca5..79204a8601f6c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SectionExecutionFactory.java @@ -51,8 +51,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.Collection; @@ -305,7 +304,7 @@ private StageScheduler createStageScheduler( SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeSelector, stageExecution::getAllTasks); checkArgument(!plan.getFragment().getStageExecutionDescriptor().isStageGroupedExecution()); - return newSourcePartitionedSchedulerAsStageScheduler(stageExecution, planNodeId, splitSource, placementPolicy, splitBatchSize); + return newSourcePartitionedSchedulerAsStageScheduler(stageExecution, planNodeId, splitSource, placementPolicy, splitBatchSize, cteMaterializationTracker); } else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { Supplier> sourceTasksProvider = () -> childStageExecutions.stream() diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SourcePartitionedScheduler.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SourcePartitionedScheduler.java index e742881690cda..55e8b38d496b3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SourcePartitionedScheduler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SourcePartitionedScheduler.java @@ -135,15 +135,26 @@ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( PlanNodeId partitionedNode, SplitSource splitSource, 
SplitPlacementPolicy splitPlacementPolicy, - int splitBatchSize) + int splitBatchSize, + CTEMaterializationTracker cteMaterializationTracker) { SourcePartitionedScheduler sourcePartitionedScheduler = new SourcePartitionedScheduler(stage, partitionedNode, splitSource, splitPlacementPolicy, splitBatchSize, false); sourcePartitionedScheduler.startLifespan(Lifespan.taskWide(), NOT_PARTITIONED); - return new StageScheduler() { + return new StageScheduler() + { @Override public ScheduleResult schedule() { + List> cteMaterializationFutures = cteMaterializationTracker.waitForCteMaterialization(stage); + if (!cteMaterializationFutures.isEmpty()) { + return ScheduleResult.blocked( + false, + ImmutableList.of(), + whenAnyComplete(cteMaterializationFutures), + ScheduleResult.BlockedReason.WAITING_FOR_CTE_MATERIALIZATION, + 0); + } ScheduleResult scheduleResult = sourcePartitionedScheduler.schedule(); sourcePartitionedScheduler.drainCompletelyScheduledLifespans(); return scheduleResult; @@ -258,6 +269,7 @@ else if (scheduleGroup.pendingSplits.isEmpty()) { Multimap splitAssignment = ImmutableMultimap.of(); if (!scheduleGroup.pendingSplits.isEmpty()) { if (!scheduleGroup.placementFuture.isDone()) { + overallBlockedFutures.add(scheduleGroup.placementFuture); anyBlockedOnPlacements = true; continue; } @@ -375,6 +387,8 @@ else if (scheduleGroup.pendingSplits.isEmpty()) { blockedReason = anyBlockedOnPlacements ? 
SPLIT_QUEUES_FULL : NO_ACTIVE_DRIVER_GROUP; } + verify(!overallBlockedFutures.isEmpty() || blockedReason == NO_ACTIVE_DRIVER_GROUP, "overallBlockedFutures is expected to be not empty when blocked on placement or splits"); + overallBlockedFutures.add(whenFinishedOrNewLifespanAdded); return ScheduleResult.blocked( false, diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java index 83b4bbaa8c69c..af137ef9c514d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SplitSchedulerStats.java @@ -16,11 +16,10 @@ import com.facebook.airlift.stats.CounterStat; import com.facebook.airlift.stats.DistributionStat; import com.facebook.airlift.stats.TimeStat; +import com.google.errorprone.annotations.ThreadSafe; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.concurrent.ThreadSafe; - import static java.util.concurrent.TimeUnit.MILLISECONDS; @ThreadSafe diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java index 93a707e4e2da1..a3b0ebf06b787 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQueryScheduler.java @@ -16,6 +16,7 @@ import com.facebook.airlift.concurrent.SetThreadName; import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.TimeStat; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.execution.BasicStageExecutionStats; @@ -53,7 +54,6 @@ import 
com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListenableFuture; import com.sun.management.ThreadMXBean; -import io.airlift.units.Duration; import org.apache.http.client.utils.URIBuilder; import java.lang.management.ManagementFactory; @@ -451,6 +451,7 @@ private void schedule() .flatMap(schedule -> schedule.getStagesToSchedule().stream()) .collect(toImmutableList()); + boolean allBlocked = true; for (StageExecutionAndScheduler stageExecutionAndScheduler : executionsToSchedule) { long startCpuNanos = THREAD_MX_BEAN.getCurrentThreadCpuTime(); long startWallNanos = System.nanoTime(); @@ -477,6 +478,9 @@ private void schedule() else if (!result.getBlocked().isDone()) { blockedStages.add(result.getBlocked()); } + else { + allBlocked = false; + } stageExecutionAndScheduler.getStageLinkage() .processScheduleResults(stageExecution.getState(), result.getNewTasks()); schedulerStats.getSplitsScheduledPerIteration().add(result.getSplitsScheduled()); @@ -535,13 +539,13 @@ else if (!result.getBlocked().isDone()) { } // wait for a state change and then schedule again - if (!blockedStages.isEmpty()) { + if (allBlocked && !blockedStages.isEmpty()) { try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) { tryGetFutureValue(whenAnyComplete(blockedStages), 1, SECONDS); } - for (ListenableFuture blockedStage : blockedStages) { - blockedStage.cancel(true); - } + } + for (ListenableFuture blockedStage : blockedStages) { + blockedStage.cancel(true); } } } @@ -665,6 +669,7 @@ private Optional performRuntimeOptimizations(StreamingSubPlan subP fragment.getPartitioning(), scheduleOrder(newRoot), fragment.getPartitioningScheme(), + fragment.getOutputOrderingScheme(), fragment.getStageExecutionDescriptor(), fragment.isOutputTableWriterFragment(), estimatedStatsAndCosts, diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQuerySchedulerInterface.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQuerySchedulerInterface.java index 03a2d0e7e62d7..94aa0f7444e0f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQuerySchedulerInterface.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/SqlQuerySchedulerInterface.java @@ -14,10 +14,10 @@ package com.facebook.presto.execution.scheduler; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.BasicStageExecutionStats; import com.facebook.presto.execution.StageId; import com.facebook.presto.execution.StageInfo; -import io.airlift.units.Duration; public interface SqlQuerySchedulerInterface { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java index 63d952c1c0ce1..cff7bec518dde 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/TableWriteInfo.java @@ -14,12 +14,18 @@ package com.facebook.presto.execution.scheduler; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.Session; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.ExecuteProcedureHandle; import com.facebook.presto.metadata.AnalyzeTableHandle; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher; import 
com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.fasterxml.jackson.annotation.JsonCreator; @@ -38,6 +44,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; +@ThriftStruct public class TableWriteInfo { private final Optional writerTarget; @@ -53,6 +60,12 @@ public TableWriteInfo( checkArgument(!analyzeTableHandle.isPresent() || !writerTarget.isPresent(), "analyzeTableHandle is present, so no other fields should be present"); } + @ThriftConstructor + public TableWriteInfo(ExecutionWriterTargetUnion writerTargetUnion, Optional analyzeTableHandle) + { + this(Optional.ofNullable(writerTargetUnion).map(ExecutionWriterTargetUnion::toExecutionWriterTarget), analyzeTableHandle == null ? Optional.empty() : analyzeTableHandle); + } + public static TableWriteInfo createTableWriteInfo(StreamingSubPlan plan, Metadata metadata, Session session) { Optional writerTarget = createWriterTarget(plan, metadata, session); @@ -91,6 +104,24 @@ private static Optional createWriterTarget(Optional

mergeHandle = mergeTarget.getMergeHandle(); + return Optional.of(new ExecutionWriterTarget.MergeHandle(mergeHandle.orElseThrow( + () -> new VerifyException("mergeHandle is absent: " + target.getClass().getSimpleName())))); + } throw new IllegalArgumentException("Unhandled target type: " + target.getClass().getSimpleName()); } @@ -159,7 +190,14 @@ public Optional getWriterTarget() return writerTarget; } + @ThriftField(value = 1, name = "writerTargetUnion") + public ExecutionWriterTargetUnion getWriterTargetUnion() + { + return writerTarget.map(ExecutionWriterTargetUnion::fromExecutionWriterTarget).orElse(null); + } + @JsonProperty + @ThriftField(2) public Optional getAnalyzeTableHandle() { return analyzeTableHandle; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadPolicy.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadPolicy.java new file mode 100644 index 0000000000000..175c18afe5405 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadPolicy.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.execution.scheduler.clusterOverload; + +import com.facebook.presto.metadata.InternalNodeManager; + +/** + * Interface for policies that determine if cluster is overloaded. 
+ * Implementations can check various metrics from NodeStats to determine + * if a worker is overloaded and queries should be throttled. + */ +public interface ClusterOverloadPolicy +{ + /** + * Checks if cluster is overloaded. + * + * @param nodeManager The node manager to get node information + * @return true if cluster is overloaded, false otherwise + */ + boolean isClusterOverloaded(InternalNodeManager nodeManager); + + /** + * Gets the name of the policy. + * + * @return The name of the policy + */ + String getName(); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadPolicyFactory.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadPolicyFactory.java new file mode 100644 index 0000000000000..2d6b9b82a8962 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadPolicyFactory.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.scheduler.clusterOverload; + +import com.google.common.collect.ImmutableMap; +import jakarta.inject.Inject; + +import java.util.Map; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +/** + * Factory for creating ClusterOverloadPolicy instances. + * This allows for extensible policy creation based on configuration. 
+ */ +public class ClusterOverloadPolicyFactory +{ + private final Map policies; + + @Inject + public ClusterOverloadPolicyFactory(ClusterOverloadPolicy clusterOverloadPolicy) + { + requireNonNull(clusterOverloadPolicy, "clusterOverloadPolicy is null"); + + // Register available policies + ImmutableMap.Builder policiesBuilder = ImmutableMap.builder(); + + // Add the default overload policy - use the injected instance + policiesBuilder.put(clusterOverloadPolicy.getName(), clusterOverloadPolicy); + + // Add more policies here as they are implemented + this.policies = policiesBuilder.build(); + } + + /** + * Get a policy by name. + * + * @param name The name of the policy to get + * @return The policy, or empty if no policy with that name exists + */ + public Optional getPolicy(String name) + { + return Optional.ofNullable(policies.get(name)); + } + + /** + * Get the default policy. + * + * @return The default policy + */ + public ClusterOverloadPolicy getDefaultPolicy() + { + // Default to CPU/Memory policy + return policies.get("cpu-memory-overload"); + } + + /** + * Get all available policies. + * + * @return Map of policy name to policy + */ + public Map getPolicies() + { + return policies; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadPolicyModule.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadPolicyModule.java new file mode 100644 index 0000000000000..207a0fd6dd198 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadPolicyModule.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.scheduler.clusterOverload; + +import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; +import com.facebook.presto.execution.ClusterOverloadConfig; +import com.google.inject.Binder; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static com.google.inject.Scopes.SINGLETON; + +/** + * Provides bindings for the node overload policy and cluster resource checker. + */ +public class ClusterOverloadPolicyModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + // Bind the default node overload policy + binder.bind(ClusterOverloadPolicy.class).to(CpuMemoryOverloadPolicy.class).in(SINGLETON); + + // Bind the node overload policy factory + binder.bind(ClusterOverloadPolicyFactory.class).in(SINGLETON); + + // Bind the cluster resource checker + binder.bind(ClusterResourceChecker.class).in(SINGLETON); + + // Bind the cluster overload config + configBinder(binder).bindConfig(ClusterOverloadConfig.class); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadStateListener.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadStateListener.java new file mode 100644 index 0000000000000..86d0845545e01 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterOverloadStateListener.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.scheduler.clusterOverload; + +/** + * Listener interface for receiving notifications about cluster overload state changes. + * This interface allows components to react to changes in cluster capacity and + * resource availability, particularly when the cluster transitions from an overloaded + * state back to a normal state. + */ +public interface ClusterOverloadStateListener +{ + /** + * Called when the cluster enters an overloaded state. + * This indicates that the cluster resources are under stress and + * new query admissions should be restricted. + */ + void onClusterEnteredOverloadedState(); + + /** + * Called when the cluster exits the overloaded state. + * This indicates that cluster resources have recovered and + * normal query processing can resume. + */ + void onClusterExitedOverloadedState(); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterResourceChecker.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterResourceChecker.java new file mode 100644 index 0000000000000..683fd703e970d --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/ClusterResourceChecker.java @@ -0,0 +1,248 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.scheduler.clusterOverload; + +import com.facebook.airlift.log.Logger; +import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.stats.TimeStat; +import com.facebook.presto.execution.ClusterOverloadConfig; +import com.facebook.presto.metadata.InternalNodeManager; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.inject.Singleton; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; + +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +import static com.facebook.airlift.concurrent.Threads.threadsNamed; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; + +/** + * Provides methods to check if more queries can be run on the cluster + * based on various resource constraints. 
+ */ +@Singleton +@ThreadSafe +public class ClusterResourceChecker +{ + private static final Logger log = Logger.get(ClusterResourceChecker.class); + + private final ClusterOverloadPolicy clusterOverloadPolicy; + private final ClusterOverloadConfig config; + private final AtomicBoolean cachedOverloadState = new AtomicBoolean(false); + private final AtomicLong lastCheckTimeMillis = new AtomicLong(0); + private final CounterStat overloadDetectionCount = new CounterStat(); + private final TimeStat timeSinceLastCheck = new TimeStat(); + private final AtomicLong overloadStartTimeMillis = new AtomicLong(0); + private final CopyOnWriteArrayList listeners = new CopyOnWriteArrayList<>(); + + private final ScheduledExecutorService overloadCheckerExecutor; + private final InternalNodeManager nodeManager; + + @Inject + public ClusterResourceChecker(ClusterOverloadPolicy clusterOverloadPolicy, ClusterOverloadConfig config, InternalNodeManager nodeManager) + { + this.clusterOverloadPolicy = requireNonNull(clusterOverloadPolicy, "clusterOverloadPolicy is null"); + this.config = requireNonNull(config, "config is null"); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.overloadCheckerExecutor = newSingleThreadScheduledExecutor(threadsNamed("cluster-overload-checker-%s")); + } + + @PostConstruct + public void start() + { + if (config.isClusterOverloadThrottlingEnabled()) { + long checkIntervalMillis = Math.max(1000, config.getOverloadCheckCacheTtlInSecs() * 1000L); + overloadCheckerExecutor.scheduleWithFixedDelay(() -> { + try { + performPeriodicOverloadCheck(); + } + catch (Exception e) { + log.error(e, "Error polling cluster overload state"); + } + }, checkIntervalMillis, checkIntervalMillis, TimeUnit.MILLISECONDS); + log.info("Started periodic cluster overload checker with interval: %d milliseconds", checkIntervalMillis); + // Perform initial check + performPeriodicOverloadCheck(); + } + } + + @PreDestroy + public void stop() + { + 
overloadCheckerExecutor.shutdownNow(); + } + + /** + * Registers a listener to be notified when the cluster exits the overloaded state. + * + * @param listener the listener to register + */ + public void addListener(ClusterOverloadStateListener listener) + { + requireNonNull(listener, "listener is null"); + listeners.add(listener); + } + + /** + * Removes a previously registered listener. + * + * @param listener the listener to remove + */ + public void removeListener(ClusterOverloadStateListener listener) + { + listeners.remove(listener); + } + + /** + * Performs a periodic check of cluster overload state. + * This method is called by the periodic task when throttling is enabled. + * Updates JMX metrics and notifies listeners when cluster exits overloaded state. + */ + private void performPeriodicOverloadCheck() + { + try { + long currentTimeMillis = System.currentTimeMillis(); + long lastCheckTime = lastCheckTimeMillis.get(); + + if (lastCheckTime > 0) { + timeSinceLastCheck.add(currentTimeMillis - lastCheckTime, TimeUnit.MILLISECONDS); + } + + boolean isOverloaded = clusterOverloadPolicy.isClusterOverloaded(nodeManager); + synchronized (this) { + boolean wasOverloaded = cachedOverloadState.getAndSet(isOverloaded); + lastCheckTimeMillis.set(currentTimeMillis); + + if (isOverloaded && !wasOverloaded) { + overloadDetectionCount.update(1); + overloadStartTimeMillis.set(currentTimeMillis); + log.info("Cluster entered overloaded state via periodic check"); + } + else if (!isOverloaded && wasOverloaded) { + long overloadDuration = currentTimeMillis - overloadStartTimeMillis.get(); + log.info("Cluster exited overloaded state after %d ms via periodic check", overloadDuration); + overloadStartTimeMillis.set(0); + // Notify listeners that cluster exited overload state + notifyClusterExitedOverloadedState(); + } + } + + log.debug("Periodic overload check completed: %s", isOverloaded ? 
"OVERLOADED" : "NOT OVERLOADED"); + } + catch (Exception e) { + log.error(e, "Error during periodic cluster overload check"); + } + } + + /** + * Returns the current overload state of the cluster. + * @return true if cluster is overloaded, false otherwise + */ + public boolean isClusterCurrentlyOverloaded() + { + if (!config.isClusterOverloadThrottlingEnabled()) { + return false; + } + + return cachedOverloadState.get(); + } + + /** + * Notifies all registered listeners that the cluster has exited the overloaded state. + */ + private void notifyClusterExitedOverloadedState() + { + for (ClusterOverloadStateListener listener : listeners) { + listener.onClusterExitedOverloadedState(); + } + } + + /** + * Returns whether cluster overload throttling is enabled. + * When disabled, the cluster overload check will be bypassed. + * + * @return true if throttling is enabled, false otherwise + */ + @Managed + public boolean isClusterOverloadThrottlingEnabled() + { + return config.isClusterOverloadThrottlingEnabled(); + } + + /** + * Returns whether the cluster is currently in an overloaded state. + * This is exposed as a JMX metric for monitoring. + * + * @return true if the cluster is overloaded, false otherwise + */ + @Managed + public boolean isClusterOverloaded() + { + return cachedOverloadState.get(); + } + + /** + * Returns the number of times the cluster has entered an overloaded state. + * + * @return counter of overload detections + */ + @Managed + @Nested + public CounterStat getOverloadDetectionCount() + { + return overloadDetectionCount; + } + + /** + * Returns statistics about the time between overload checks. + * + * @return time statistics for overload checks + */ + @Managed + @Nested + public TimeStat getTimeSinceLastCheck() + { + return timeSinceLastCheck; + } + + /** + * Returns the duration in milliseconds that the cluster has been in an overloaded state. + * Returns 0 if the cluster is not currently overloaded. 
+ * + * Note: This method reads two atomic fields but doesn't need synchronization because: + * 1. Single writer (periodic task) ensures consistent updates + * 2. Atomic fields provide memory visibility guarantees + * 3. Slight inconsistency in edge cases is acceptable for monitoring metrics + * + * @return duration in milliseconds of current overload state + */ + @Managed + public long getOverloadDurationMillis() + { + long startTime = overloadStartTimeMillis.get(); + if (startTime == 0 || !cachedOverloadState.get()) { + return 0; + } + return System.currentTimeMillis() - startTime; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/CpuMemoryOverloadPolicy.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/CpuMemoryOverloadPolicy.java new file mode 100644 index 0000000000000..14d5387454178 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/clusterOverload/CpuMemoryOverloadPolicy.java @@ -0,0 +1,169 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.execution.scheduler.clusterOverload; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.execution.ClusterOverloadConfig; +import com.facebook.presto.metadata.InternalNode; +import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.spi.NodeLoadMetrics; +import jakarta.inject.Inject; + +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.NodeState.ACTIVE; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +/** + * A policy that checks if cluster is overloaded based on CPU or memory metrics. + * Supports two modes of operation: + * - Percentage-based: Checks if the percentage of overloaded workers exceeds a threshold + * - Count-based: Checks if the absolute count of overloaded workers exceeds a threshold + */ +public class CpuMemoryOverloadPolicy + implements ClusterOverloadPolicy +{ + private static final Logger log = Logger.get(CpuMemoryOverloadPolicy.class); + + private final double allowedOverloadWorkersPct; + private final double allowedOverloadWorkersCnt; + private final String policyType; + + @Inject + public CpuMemoryOverloadPolicy(ClusterOverloadConfig config) + { + this.allowedOverloadWorkersPct = config.getAllowedOverloadWorkersPct(); + this.allowedOverloadWorkersCnt = config.getAllowedOverloadWorkersCnt(); + this.policyType = requireNonNull(config.getOverloadPolicyType(), "policyType is null"); + } + + @Override + public boolean isClusterOverloaded(InternalNodeManager nodeManager) + { + Set activeNodes = nodeManager.getNodes(ACTIVE); + if (activeNodes.isEmpty()) { + return false; + } + + OverloadStats stats = collectOverloadStats(activeNodes, nodeManager); + return evaluateOverload(stats, activeNodes.size()); + } + + private OverloadStats collectOverloadStats(Set activeNodes, InternalNodeManager nodeManager) + { + Set overloadedNodeIds = new HashSet<>(); + int 
overloadedWorkersCnt = 0; + + for (InternalNode node : activeNodes) { + Optional metricsOptional = nodeManager.getNodeLoadMetrics(node.getNodeIdentifier()); + String nodeId = node.getNodeIdentifier(); + + if (!metricsOptional.isPresent()) { + continue; + } + + NodeLoadMetrics metrics = metricsOptional.get(); + + // Check for CPU overload + if (metrics.getCpuOverload()) { + overloadedNodeIds.add(String.format("%s (CPU overloaded)", nodeId)); + overloadedWorkersCnt++; + continue; + } + + // Check for memory overload + if (metrics.getMemoryOverload()) { + overloadedNodeIds.add(String.format("%s (Memory overloaded)", nodeId)); + overloadedWorkersCnt++; + } + } + + return new OverloadStats(overloadedNodeIds, overloadedWorkersCnt); + } + + private boolean evaluateOverload(OverloadStats stats, int totalNodes) + { + boolean isOverloaded; + + if (ClusterOverloadConfig.OVERLOAD_POLICY_PCT_BASED.equals(policyType)) { + double overloadedWorkersPct = (double) stats.getOverloadedWorkersCnt() / totalNodes; + isOverloaded = overloadedWorkersPct > allowedOverloadWorkersPct; + + if (isOverloaded) { + logOverload( + String.format("%s%% of workers are overloaded (threshold: %s%%)", + format("%.2f", overloadedWorkersPct * 100), + format("%.2f", allowedOverloadWorkersPct * 100)), + stats.getOverloadedNodeIds()); + } + } + else if (ClusterOverloadConfig.OVERLOAD_POLICY_CNT_BASED.equals(policyType)) { + isOverloaded = stats.getOverloadedWorkersCnt() > allowedOverloadWorkersCnt; + + if (isOverloaded) { + logOverload( + String.format("%s workers are overloaded (threshold: %s workers)", + stats.getOverloadedWorkersCnt(), allowedOverloadWorkersCnt), + stats.getOverloadedNodeIds()); + } + } + else { + throw new IllegalStateException("Unknown cluster overload policy type: " + policyType); + } + + return isOverloaded; + } + + private void logOverload(String message, Set overloadedNodeIds) + { + log.warn("Cluster is overloaded: " + message); + if (!overloadedNodeIds.isEmpty()) { + 
log.warn("Overloaded nodes: %s", String.join(", ", overloadedNodeIds)); + } + } + + @Override + public String getName() + { + return "cpu-memory-overload-" + + (ClusterOverloadConfig.OVERLOAD_POLICY_PCT_BASED.equals(policyType) ? "pct" : "cnt"); + } + + // Helper class to encapsulate overload statistics + private static class OverloadStats + { + private final Set overloadedNodeIds; + private final int overloadedWorkersCnt; + + public OverloadStats(Set overloadedNodeIds, int overloadedWorkersCnt) + { + this.overloadedNodeIds = overloadedNodeIds; + this.overloadedWorkersCnt = overloadedWorkersCnt; + } + + public Set getOverloadedNodeIds() + { + return overloadedNodeIds; + } + + public int getOverloadedWorkersCnt() + { + return overloadedWorkersCnt; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/group/DynamicLifespanScheduler.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/group/DynamicLifespanScheduler.java index e07610a509a25..41f1be5c02c5e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/group/DynamicLifespanScheduler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/group/DynamicLifespanScheduler.java @@ -20,12 +20,11 @@ import com.facebook.presto.spi.connector.ConnectorPartitionHandle; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import it.unimi.dsi.fastutil.ints.IntArrayFIFOQueue; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; -import javax.annotation.concurrent.GuardedBy; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/group/FixedLifespanScheduler.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/group/FixedLifespanScheduler.java index ba2f542e18903..ce0c6c6bd4a5b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/group/FixedLifespanScheduler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/group/FixedLifespanScheduler.java @@ -20,14 +20,13 @@ import com.facebook.presto.spi.connector.ConnectorPartitionHandle; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntList; import it.unimi.dsi.fastutil.ints.IntListIterator; -import javax.annotation.concurrent.GuardedBy; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/NodeSelectionStats.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/NodeSelectionStats.java index a5fcc78a01ec5..4de5b91586be8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/NodeSelectionStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/NodeSelectionStats.java @@ -14,11 +14,10 @@ package com.facebook.presto.execution.scheduler.nodeSelection; import com.facebook.airlift.stats.CounterStat; +import com.google.errorprone.annotations.ThreadSafe; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.concurrent.ThreadSafe; - @ThreadSafe public class NodeSelectionStats { diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleNodeSelector.java 
b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleNodeSelector.java index a7205852ea4f5..3345ce52c5be0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleNodeSelector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleNodeSelector.java @@ -16,6 +16,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.execution.NodeTaskMap; import com.facebook.presto.execution.RemoteTask; +import com.facebook.presto.execution.TaskStatus; import com.facebook.presto.execution.scheduler.BucketNodeMap; import com.facebook.presto.execution.scheduler.InternalNodeInfo; import com.facebook.presto.execution.scheduler.NodeAssignmentStats; @@ -35,8 +36,10 @@ import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; @@ -71,9 +74,11 @@ public class SimpleNodeSelector private final NodeSelectionStats nodeSelectionStats; private final NodeTaskMap nodeTaskMap; private final boolean includeCoordinator; + private final boolean scheduleSplitsBasedOnTaskLoad; private final AtomicReference> nodeMap; private final int minCandidates; private final long maxSplitsWeightPerNode; + private final long maxSplitsWeightPerTask; private final long maxPendingSplitsWeightPerTask; private final int maxUnacknowledgedSplitsPerTask; private final int maxTasksPerStage; @@ -84,9 +89,11 @@ public SimpleNodeSelector( NodeSelectionStats nodeSelectionStats, NodeTaskMap nodeTaskMap, boolean includeCoordinator, + boolean scheduleSplitsBasedOnTaskLoad, Supplier nodeMap, int minCandidates, long maxSplitsWeightPerNode, + long maxSplitsWeightPerTask, long maxPendingSplitsWeightPerTask, int maxUnacknowledgedSplitsPerTask, int maxTasksPerStage, @@ 
-96,9 +103,11 @@ public SimpleNodeSelector( this.nodeSelectionStats = requireNonNull(nodeSelectionStats, "nodeSelectionStats is null"); this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); this.includeCoordinator = includeCoordinator; + this.scheduleSplitsBasedOnTaskLoad = scheduleSplitsBasedOnTaskLoad; this.nodeMap = new AtomicReference<>(nodeMap); this.minCandidates = minCandidates; this.maxSplitsWeightPerNode = maxSplitsWeightPerNode; + this.maxSplitsWeightPerTask = maxSplitsWeightPerTask; this.maxPendingSplitsWeightPerTask = maxPendingSplitsWeightPerTask; this.maxUnacknowledgedSplitsPerTask = maxUnacknowledgedSplitsPerTask; checkArgument(maxUnacknowledgedSplitsPerTask > 0, "maxUnacknowledgedSplitsPerTask must be > 0, found: %s", maxUnacknowledgedSplitsPerTask); @@ -149,6 +158,11 @@ public SplitPlacementResult computeAssignments(Set splits, List blockedExactNodes = new HashSet<>(); boolean splitWaitingForAnyNode = false; + Optional> taskLoadSplitWeightProvider = Optional.empty(); + if (this.scheduleSplitsBasedOnTaskLoad) { + taskLoadSplitWeightProvider = Optional.of(createTaskLoadSplitWeightProvider(existingTasks, assignmentStats)); + } + NodeProvider nodeProvider = nodeMap.getNodeProvider(maxPreferredNodes); OptionalInt preferredNodeCount = OptionalInt.empty(); for (Split split : splits) { @@ -179,9 +193,16 @@ public SplitPlacementResult computeAssignments(Set splits, List chosenNodeInfo = chooseLeastBusyNode(splitWeight, candidateNodes, assignmentStats::getTotalSplitsWeight, preferredNodeCount, maxSplitsWeightPerNode, assignmentStats); - if (!chosenNodeInfo.isPresent()) { - chosenNodeInfo = chooseLeastBusyNode(splitWeight, candidateNodes, assignmentStats::getQueuedSplitsWeightForStage, preferredNodeCount, maxPendingSplitsWeightPerTask, assignmentStats); + Optional chosenNodeInfo = Optional.empty(); + + if (taskLoadSplitWeightProvider.isPresent()) { + chosenNodeInfo = chooseLeastBusyNode(splitWeight, candidateNodes, 
taskLoadSplitWeightProvider.get(), preferredNodeCount, maxSplitsWeightPerTask, assignmentStats); + } + else { + chosenNodeInfo = chooseLeastBusyNode(splitWeight, candidateNodes, assignmentStats::getTotalSplitsWeight, preferredNodeCount, maxSplitsWeightPerNode, assignmentStats); + if (!chosenNodeInfo.isPresent()) { + chosenNodeInfo = chooseLeastBusyNode(splitWeight, candidateNodes, assignmentStats::getQueuedSplitsWeightForStage, preferredNodeCount, maxPendingSplitsWeightPerTask, assignmentStats); + } } if (chosenNodeInfo.isPresent()) { @@ -223,6 +244,26 @@ public SplitPlacementResult computeAssignments(Set splits, List createTaskLoadSplitWeightProvider(List existingTasks, NodeAssignmentStats assignmentStats) + { + // Create a map from nodeId to RemoteTask for efficient lookup + Map tasksByNodeId = new HashMap<>(); + for (RemoteTask task : existingTasks) { + tasksByNodeId.put(task.getNodeId(), task); + } + + return node -> { + RemoteTask remoteTask = tasksByNodeId.get(node.getNodeIdentifier()); + if (remoteTask == null) { + // No task for this node, return only the queued splits weight for the stage + return assignmentStats.getQueuedSplitsWeightForStage(node); + } + + TaskStatus taskStatus = remoteTask.getTaskStatus(); + return taskStatus.getRunningPartitionedSplitsWeight() + assignmentStats.getQueuedSplitsWeightForStage(node); + }; + } + protected Optional chooseLeastBusyNode(SplitWeight splitWeight, List candidateNodes, ToLongFunction splitWeightProvider, OptionalInt preferredNodeCount, long maxSplitsWeight, NodeAssignmentStats assignmentStats) { long minWeight = Long.MAX_VALUE; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleTtlNodeSelector.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleTtlNodeSelector.java index cb3c7291909d3..1b2592d59a009 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleTtlNodeSelector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleTtlNodeSelector.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution.scheduler.nodeSelection; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.execution.NodeTaskMap; import com.facebook.presto.execution.QueryManager; @@ -39,7 +40,6 @@ import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import java.time.Instant; import java.util.Comparator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleTtlNodeSelectorConfig.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleTtlNodeSelectorConfig.java index 71462289c5460..eea9b72195e53 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleTtlNodeSelectorConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/SimpleTtlNodeSelectorConfig.java @@ -14,7 +14,7 @@ package com.facebook.presto.execution.scheduler.nodeSelection; import com.facebook.airlift.configuration.Config; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; import java.util.concurrent.TimeUnit; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/TopologyAwareNodeSelector.java b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/TopologyAwareNodeSelector.java index a4a351cd3f777..cd7e69c63f8f7 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/TopologyAwareNodeSelector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/scheduler/nodeSelection/TopologyAwareNodeSelector.java @@ -37,8 +37,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.HashSet; import java.util.Iterator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/execution/warnings/DefaultWarningCollector.java b/presto-main-base/src/main/java/com/facebook/presto/execution/warnings/DefaultWarningCollector.java index f0ab689802a33..71f863dc69b04 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/execution/warnings/DefaultWarningCollector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/execution/warnings/DefaultWarningCollector.java @@ -21,9 +21,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.LinkedHashMultimap; import com.google.common.collect.Multimap; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/index/IndexHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/index/IndexHandleJacksonModule.java index 730b1869a471a..499f3f58088af 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/index/IndexHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/index/IndexHandleJacksonModule.java @@ -13,20 +13,47 @@ */ package com.facebook.presto.index; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.metadata.AbstractTypedJacksonModule; 
import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorIndexHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import java.util.function.Function; public class IndexHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public IndexHandleJacksonModule(HandleResolver handleResolver) + public IndexHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorIndexHandle.class, handleResolver::getId, - handleResolver::getIndexHandleClass); + handleResolver::getIndexHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorIndexHandleCodec)); + } + + public IndexHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorIndexHandle.class, + handleResolver::getId, + handleResolver::getIndexHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/ClusterMemoryLeakDetector.java b/presto-main-base/src/main/java/com/facebook/presto/memory/ClusterMemoryLeakDetector.java index 74b0b2e3c71aa..d1364f2bac4b6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/ClusterMemoryLeakDetector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/ClusterMemoryLeakDetector.java @@ -14,13 +14,12 @@ package com.facebook.presto.memory; import com.facebook.airlift.log.Logger; +import 
com.facebook.airlift.units.DataSize; import com.facebook.presto.server.BasicQueryInfo; import com.facebook.presto.spi.QueryId; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.Map; import java.util.Map.Entry; diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/ClusterMemoryPool.java b/presto-main-base/src/main/java/com/facebook/presto/memory/ClusterMemoryPool.java index b192e320d88be..c7cb9ae504e6d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/ClusterMemoryPool.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/ClusterMemoryPool.java @@ -19,11 +19,10 @@ import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.memory.MemoryPoolInfo; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.weakref.jmx.Managed; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/ForMemoryManager.java b/presto-main-base/src/main/java/com/facebook/presto/memory/ForMemoryManager.java index 469a13312882a..84dbe6115c5a7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/ForMemoryManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/ForMemoryManager.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.memory; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/memory/LocalMemoryManager.java b/presto-main-base/src/main/java/com/facebook/presto/memory/LocalMemoryManager.java index db836ab7fa3cc..844ab5088001e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/LocalMemoryManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/LocalMemoryManager.java @@ -13,26 +13,25 @@ */ package com.facebook.presto.memory; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.memory.MemoryPoolInfo; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; import java.util.Optional; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.memory.NodeMemoryConfig.QUERY_MAX_MEMORY_PER_NODE_CONFIG; import static com.facebook.presto.memory.NodeMemoryConfig.QUERY_MAX_TOTAL_MEMORY_PER_NODE_CONFIG; import static com.facebook.presto.memory.NodeMemoryConfig.QUERY_SOFT_MAX_MEMORY_PER_NODE_CONFIG; import static com.facebook.presto.memory.NodeMemoryConfig.QUERY_SOFT_MAX_TOTAL_MEMORY_PER_NODE_CONFIG; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.String.format; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/LocalMemoryManagerExporter.java b/presto-main-base/src/main/java/com/facebook/presto/memory/LocalMemoryManagerExporter.java index d4c8c17ac9ddc..19ac4af1f36b6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/LocalMemoryManagerExporter.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/memory/LocalMemoryManagerExporter.java @@ -13,14 +13,13 @@ */ package com.facebook.presto.memory; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.JmxException; import org.weakref.jmx.MBeanExporter; import org.weakref.jmx.ObjectNames; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; - import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/LowMemoryMonitor.java b/presto-main-base/src/main/java/com/facebook/presto/memory/LowMemoryMonitor.java index 634e26648fb9f..67c942728e0f7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/LowMemoryMonitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/LowMemoryMonitor.java @@ -16,10 +16,9 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.execution.executor.TaskExecutor; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryInfo.java b/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryInfo.java index 86f7dd737927b..73cbe09c9a9c4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryInfo.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.memory; +import com.facebook.airlift.units.DataSize; import com.facebook.drift.annotations.ThriftConstructor; import 
com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; @@ -21,7 +22,6 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryManagerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryManagerConfig.java index 68a14787a8601..dd0ac9ebc47fe 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryManagerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryManagerConfig.java @@ -16,14 +16,13 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.DefunctConfig; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.NotNull; -import javax.validation.constraints.NotNull; - -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.succinctBytes; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static java.util.concurrent.TimeUnit.MINUTES; @DefunctConfig({ diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryPool.java b/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryPool.java index a3f18d8d03ffb..fa693a6ec643b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryPool.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryPool.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.memory; +import 
com.facebook.airlift.units.DataSize; import com.facebook.presto.execution.TaskId; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.memory.MemoryAllocation; @@ -21,12 +22,10 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.AbstractFuture; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import org.weakref.jmx.Managed; -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/NodeMemoryConfig.java b/presto-main-base/src/main/java/com/facebook/presto/memory/NodeMemoryConfig.java index af4c8f187494c..9b7fd353885a5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/NodeMemoryConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/NodeMemoryConfig.java @@ -15,11 +15,10 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize; +import jakarta.validation.constraints.NotNull; -import javax.validation.constraints.NotNull; - -import static io.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; // This is separate from MemoryManagerConfig because it's difficult to test the default value of maxQueryMemoryPerNode public class NodeMemoryConfig diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/QueryContext.java b/presto-main-base/src/main/java/com/facebook/presto/memory/QueryContext.java index 358fb55ccda4c..74f9b7b501b9d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/QueryContext.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/memory/QueryContext.java @@ -15,6 +15,7 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.stats.GcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskStateMachine; @@ -30,10 +31,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.Collection; import java.util.Comparator; @@ -49,6 +48,8 @@ import java.util.function.BiPredicate; import java.util.function.Predicate; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalBroadcastMemoryLimit; import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalRevocableMemoryLimit; import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalTotalMemoryLimit; @@ -63,8 +64,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; import static java.util.Comparator.comparing; import static java.util.Map.Entry.comparingByValue; diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/ReservedSystemMemoryConfig.java 
b/presto-main-base/src/main/java/com/facebook/presto/memory/ReservedSystemMemoryConfig.java index c868373c04506..8116eb4a84bbd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/ReservedSystemMemoryConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/memory/ReservedSystemMemoryConfig.java @@ -14,11 +14,10 @@ package com.facebook.presto.memory; import com.facebook.airlift.configuration.Config; -import io.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize; +import jakarta.validation.constraints.NotNull; -import javax.validation.constraints.NotNull; - -import static io.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; // This is separate from MemoryManagerConfig because it's difficult to test the default value of reservedSystemMemory public class ReservedSystemMemoryConfig diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractPropertyManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractPropertyManager.java index 11b4b805d6a59..b432a9557b9c6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractPropertyManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractPropertyManager.java @@ -93,22 +93,46 @@ public final ImmutableMap.Builder getUserSpecifiedProperties( propertyName)); } - Object sqlObjectValue; + PropertyMetadata.AdditionalSqlTypeHandler usedSqlTypeHandler = null; + Object sqlObjectValue = null; try { sqlObjectValue = evaluatePropertyValue(sqlProperty.getValue(), property.getSqlType(), session, metadata, parameters); } catch (SemanticException e) { - throw new PrestoException(propertyError, - format("Invalid value for %s property '%s': Cannot convert '%s' to %s", - propertyType, - property.getName(), - sqlProperty.getValue(), - property.getSqlType()), e); + for (PropertyMetadata.AdditionalSqlTypeHandler additionalSqlTypeHandler : 
property.getAdditionalSqlTypeHandlers()) { + try { + sqlObjectValue = evaluatePropertyValue(sqlProperty.getValue(), additionalSqlTypeHandler.getSqlType(), session, metadata, parameters); + usedSqlTypeHandler = additionalSqlTypeHandler; + break; + } + catch (Exception ex) { + // ignored + } + } + + if (usedSqlTypeHandler == null) { + String additionalTypesInfo = property.getAdditionalSqlTypeHandlers().stream() + .map(handler -> handler.getSqlType().getDisplayName()) + .reduce((type, type2) -> type + ", " + type2) + .map(message -> " or any of [" + message + "]") + .orElse(""); + throw new PrestoException(propertyError, + format("Invalid value for %s property '%s': Cannot convert '%s' to %s", + propertyType, + property.getName(), + sqlProperty.getValue(), + property.getSqlType()) + additionalTypesInfo, e); + } } Object value; try { - value = property.decode(sqlObjectValue); + if (usedSqlTypeHandler != null) { + value = property.decode(sqlObjectValue, usedSqlTypeHandler.getDecoder()); + } + else { + value = property.decode(sqlObjectValue); + } } catch (Exception e) { throw new PrestoException(propertyError, diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java index 489bb076d764c..3176c95664419 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/AbstractTypedJacksonModule.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; import com.fasterxml.jackson.annotation.JsonTypeInfo.Id; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonParser; @@ -38,6 +40,7 @@ import com.google.common.cache.CacheBuilder; import java.io.IOException; +import java.util.Optional; import 
java.util.concurrent.ExecutionException; import java.util.function.Function; @@ -49,18 +52,38 @@ public abstract class AbstractTypedJacksonModule extends SimpleModule { private static final String TYPE_PROPERTY = "@type"; + private static final String DATA_PROPERTY = "customSerializedValue"; protected AbstractTypedJacksonModule( Class baseClass, Function nameResolver, - Function> classResolver) + Function> classResolver, + boolean binarySerializationEnabled, + Function>> codecExtractor) { super(baseClass.getSimpleName() + "Module", Version.unknownVersion()); - TypeIdResolver typeResolver = new InternalTypeResolver<>(nameResolver, classResolver); + requireNonNull(baseClass, "baseClass is null"); + requireNonNull(nameResolver, "nameResolver is null"); + requireNonNull(classResolver, "classResolver is null"); + requireNonNull(codecExtractor, "codecExtractor is null"); - addSerializer(baseClass, new InternalTypeSerializer<>(baseClass, typeResolver)); - addDeserializer(baseClass, new InternalTypeDeserializer<>(baseClass, typeResolver)); + if (binarySerializationEnabled) { + // Use codec serialization + addSerializer(baseClass, new CodecSerializer<>( + TYPE_PROPERTY, + DATA_PROPERTY, + codecExtractor, + nameResolver, + new InternalTypeResolver<>(nameResolver, classResolver))); + addDeserializer(baseClass, new CodecDeserializer<>(TYPE_PROPERTY, DATA_PROPERTY, codecExtractor, classResolver)); + } + else { + // Use legacy typed serialization + TypeIdResolver typeResolver = new InternalTypeResolver<>(nameResolver, classResolver); + addSerializer(baseClass, new InternalTypeSerializer<>(baseClass, typeResolver)); + addDeserializer(baseClass, new InternalTypeDeserializer<>(baseClass, typeResolver)); + } } private static class InternalTypeDeserializer diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/AnalyzeTableHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/AnalyzeTableHandle.java index 5e77adeaa1818..df97d5f1eb735 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/AnalyzeTableHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/AnalyzeTableHandle.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.metadata; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -23,6 +26,7 @@ import static java.util.Objects.requireNonNull; +@ThriftStruct public class AnalyzeTableHandle { private final ConnectorId connectorId; @@ -30,6 +34,7 @@ public class AnalyzeTableHandle private final ConnectorTableHandle connectorHandle; @JsonCreator + @ThriftConstructor public AnalyzeTableHandle( @JsonProperty("connectorId") ConnectorId connectorId, @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, @@ -41,18 +46,21 @@ public AnalyzeTableHandle( } @JsonProperty + @ThriftField(1) public ConnectorId getConnectorId() { return connectorId; } @JsonProperty + @ThriftField(3) public ConnectorTableHandle getConnectorHandle() { return connectorHandle; } @JsonProperty + @ThriftField(2) public ConnectorTransactionHandle getTransactionHandle() { return transactionHandle; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java index 40de75c838169..9c5ed82af2d1a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java @@ -24,18 +24,26 @@ import java.util.List; import java.util.Objects; +import static com.facebook.presto.metadata.BuiltInFunctionKind.ENGINE; import static java.util.Objects.requireNonNull; public 
class BuiltInFunctionHandle implements FunctionHandle { private final Signature signature; + private final BuiltInFunctionKind builtInFunctionKind; + + public BuiltInFunctionHandle(Signature signature) + { + this(signature, ENGINE); + } @JsonCreator - public BuiltInFunctionHandle(@JsonProperty("signature") Signature signature) + public BuiltInFunctionHandle(@JsonProperty("signature") Signature signature, @JsonProperty("builtInFunctionKind") BuiltInFunctionKind builtInFunctionKind) { this.signature = requireNonNull(signature, "signature is null"); checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); + this.builtInFunctionKind = requireNonNull(builtInFunctionKind, "builtInFunctionKind is null"); } @JsonProperty @@ -68,6 +76,12 @@ public CatalogSchemaName getCatalogSchemaName() return signature.getName().getCatalogSchemaName(); } + @JsonProperty + public BuiltInFunctionKind getBuiltInFunctionKind() + { + return builtInFunctionKind; + } + @Override public boolean equals(Object o) { @@ -78,13 +92,14 @@ public boolean equals(Object o) return false; } BuiltInFunctionHandle that = (BuiltInFunctionHandle) o; - return Objects.equals(signature, that.signature); + return Objects.equals(signature, that.signature) + && Objects.equals(builtInFunctionKind, that.builtInFunctionKind); } @Override public int hashCode() { - return Objects.hash(signature); + return Objects.hash(signature, builtInFunctionKind); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUpdateJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionKind.java similarity index 60% rename from presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUpdateJacksonModule.java rename to presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionKind.java index a0d5bb3d0c1ab..71ca25dd468cc 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUpdateJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionKind.java @@ -13,16 +13,26 @@ */ package com.facebook.presto.metadata; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; +import com.facebook.drift.annotations.ThriftEnum; +import com.facebook.drift.annotations.ThriftEnumValue; -import javax.inject.Inject; - -public class MetadataUpdateJacksonModule - extends AbstractTypedJacksonModule +@ThriftEnum +public enum BuiltInFunctionKind { - @Inject - public MetadataUpdateJacksonModule(HandleResolver handleResolver) + ENGINE(0), + PLUGIN(1), + WORKER(2); + + private final int value; + + BuiltInFunctionKind(int value) + { + this.value = value; + } + + @ThriftEnumValue + public int getValue() { - super(ConnectorMetadataUpdateHandle.class, handleResolver::getId, handleResolver::getMetadataUpdateHandleClass); + return value; } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInPluginFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInPluginFunctionNamespaceManager.java new file mode 100644 index 0000000000000..e0cf7dc25472f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInPluginFunctionNamespaceManager.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.function.FunctionImplementationType; +import com.facebook.presto.spi.function.SqlFunction; + +import java.util.Collection; +import java.util.List; + +import static com.facebook.presto.metadata.BuiltInFunctionKind.PLUGIN; +import static com.facebook.presto.spi.function.FunctionImplementationType.SQL; +import static com.google.common.base.Preconditions.checkArgument; + +public class BuiltInPluginFunctionNamespaceManager + extends BuiltInSpecialFunctionNamespaceManager +{ + public BuiltInPluginFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) + { + super(functionAndTypeManager); + } + + public void triggerConflictCheckWithBuiltInFunctions() + { + checkForNamingConflicts(this.getFunctionsFromDefaultNamespace()); + } + + @Override + public synchronized void registerBuiltInSpecialFunctions(List functions) + { + checkForNamingConflicts(functions); + this.functions = new FunctionMap(this.functions, functions); + } + + @Override + protected synchronized void checkForNamingConflicts(Collection functions) + { + for (SqlFunction function : functions) { + for (SqlFunction existingFunction : this.functions.list()) { + checkArgument(!function.getSignature().equals(existingFunction.getSignature()), "Function already registered: %s", function.getSignature()); + } + } + } + + @Override + protected BuiltInFunctionKind getBuiltInFunctionKind() + { + return PLUGIN; + } + + @Override + protected FunctionImplementationType getDefaultFunctionMetadataImplementationType() + { + return SQL; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/ProcedureRegistry.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInProcedureRegistry.java similarity index 67% rename from presto-main-base/src/main/java/com/facebook/presto/metadata/ProcedureRegistry.java rename to presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInProcedureRegistry.java 
index 52f63cf7ff6da..978cbcfb09efd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/ProcedureRegistry.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInProcedureRegistry.java @@ -21,11 +21,14 @@ import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure; import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.google.common.collect.Maps; import com.google.common.primitives.Primitives; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; @@ -40,48 +43,53 @@ import static com.facebook.presto.common.type.StandardTypes.MAP; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.StandardErrorCode.PROCEDURE_NOT_FOUND; -import static com.facebook.presto.spi.procedure.Procedure.Argument; +import static com.facebook.presto.spi.procedure.BaseProcedure.BaseArgument; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @ThreadSafe -public class ProcedureRegistry +public class BuiltInProcedureRegistry + implements ProcedureRegistry { - private final Map> connectorProcedures = new ConcurrentHashMap<>(); + private final Map>> connectorProcedures = new ConcurrentHashMap<>(); private final TypeManager typeManager; - public ProcedureRegistry(TypeManager typeManager) + @Inject + public BuiltInProcedureRegistry(TypeManager typeManager) { this.typeManager = requireNonNull(typeManager, "typeManager is null"); } - 
public void addProcedures(ConnectorId connectorId, Collection procedures) + @Override + public void addProcedures(ConnectorId connectorId, Collection> procedures) { requireNonNull(connectorId, "connectorId is null"); requireNonNull(procedures, "procedures is null"); procedures.forEach(this::validateProcedure); - Map proceduresByName = Maps.uniqueIndex( + Map> proceduresByName = Maps.uniqueIndex( procedures, procedure -> new SchemaTableName(procedure.getSchema(), procedure.getName())); checkState(connectorProcedures.putIfAbsent(connectorId, proceduresByName) == null, "Procedures already registered for connector: %s", connectorId); } + @Override public void removeProcedures(ConnectorId connectorId) { connectorProcedures.remove(connectorId); } - public Procedure resolve(ConnectorId connectorId, SchemaTableName name) + @Override + public BaseProcedure resolve(ConnectorId connectorId, SchemaTableName name) { - Map procedures = connectorProcedures.get(connectorId); + Map> procedures = connectorProcedures.get(connectorId); if (procedures != null) { - Procedure procedure = procedures.get(name); + BaseProcedure procedure = procedures.get(name); if (procedure != null) { return procedure; } @@ -89,14 +97,41 @@ public Procedure resolve(ConnectorId connectorId, SchemaTableName name) throw new PrestoException(PROCEDURE_NOT_FOUND, "Procedure not registered: " + name); } - private void validateProcedure(Procedure procedure) + @Override + public DistributedProcedure resolveDistributed(ConnectorId connectorId, SchemaTableName name) + { + Map> procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + BaseProcedure procedure = procedures.get(name); + if (procedure instanceof DistributedProcedure) { + return (DistributedProcedure) procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Distributed procedure not registered: " + name); + } + + @Override + public boolean isDistributedProcedure(ConnectorId connectorId, SchemaTableName name) + { + Map> 
procedures = connectorProcedures.get(connectorId); + return procedures != null && + procedures.containsKey(name) && + procedures.get(name) instanceof DistributedProcedure; + } + + private void validateProcedure(BaseProcedure procedure) { - List> parameters = procedure.getMethodHandle().type().parameterList().stream() + if (procedure instanceof DistributedProcedure) { + return; + } + + Procedure innerProcedure = (Procedure) procedure; + List> parameters = innerProcedure.getMethodHandle().type().parameterList().stream() .filter(type -> !ConnectorSession.class.isAssignableFrom(type)) .collect(toList()); for (int i = 0; i < procedure.getArguments().size(); i++) { - Argument argument = procedure.getArguments().get(i); + BaseArgument argument = innerProcedure.getArguments().get(i); Type type = typeManager.getType(argument.getType()); Class argumentType = Primitives.unwrap(parameters.get(i)); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInSpecialFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInSpecialFunctionNamespaceManager.java new file mode 100644 index 0000000000000..6a4e716076cd9 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInSpecialFunctionNamespaceManager.java @@ -0,0 +1,301 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.metadata; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.function.SqlFunctionResult; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.common.type.UserDefinedType; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.AggregationFunctionImplementation; +import com.facebook.presto.spi.function.AggregationFunctionMetadata; +import com.facebook.presto.spi.function.AlterRoutineCharacteristics; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionImplementationType; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.FunctionNamespaceManager; +import com.facebook.presto.spi.function.FunctionNamespaceTransactionHandle; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.ScalarFunctionImplementation; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.SqlFunction; +import com.facebook.presto.spi.function.SqlInvokedAggregationFunctionImplementation; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.util.concurrent.UncheckedExecutionException; + +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; +import static 
com.facebook.presto.spi.function.FunctionKind.SCALAR; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.HOURS; + +public abstract class BuiltInSpecialFunctionNamespaceManager + implements FunctionNamespaceManager +{ + protected volatile FunctionMap functions = new FunctionMap(); + private final FunctionAndTypeManager functionAndTypeManager; + private final LoadingCache specializedFunctionKeyCache; + private final LoadingCache specializedScalarCache; + private final LoadingCache specializedAggregationCache; + + public BuiltInSpecialFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + specializedFunctionKeyCache = CacheBuilder.newBuilder() + .maximumSize(1000) + .expireAfterWrite(1, HOURS) + .build(CacheLoader.from(this::doGetSpecializedFunctionKey)); + specializedScalarCache = CacheBuilder.newBuilder() + .maximumSize(1000) + .expireAfterWrite(1, HOURS) + .build(CacheLoader.from(key -> { + checkArgument( + key.getFunction() instanceof SqlInvokedFunction, + "Unsupported scalar function class: %s", + key.getFunction().getClass()); + return new SqlInvokedScalarFunctionImplementation(((SqlInvokedFunction) key.getFunction()).getBody()); + })); + + specializedAggregationCache = CacheBuilder.newBuilder() + .maximumSize(1000) + .expireAfterWrite(1, HOURS) + .build(CacheLoader.from(key -> (sqlInvokedFunctionToAggregationImplementation((SqlInvokedFunction) key.getFunction(), functionAndTypeManager)))); + } + + @Override + public Collection getFunctions(Optional transactionHandle, QualifiedObjectName functionName) + { + return functions.get(functionName); + } + + /** + * likePattern / escape is not used 
for optimization, returning all functions. + */ + @Override + public Collection listFunctions(Optional likePattern, Optional escape) + { + return functions.list(); + } + + @Override + public FunctionHandle getFunctionHandle(Optional transactionHandle, Signature signature) + { + return new BuiltInFunctionHandle(signature, getBuiltInFunctionKind()); + } + + public FunctionMetadata getFunctionMetadata(FunctionHandle functionHandle) + { + checkArgument(functionHandle instanceof BuiltInFunctionHandle, "Expect BuiltInFunctionHandle"); + Signature signature = ((BuiltInFunctionHandle) functionHandle).getSignature(); + SpecializedFunctionKey functionKey; + try { + functionKey = specializedFunctionKeyCache.getUnchecked(signature); + } + catch (UncheckedExecutionException e) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + throw e; + } + SqlFunction function = functionKey.getFunction(); + checkArgument(function instanceof SqlInvokedFunction, "BuiltInPluginFunctionNamespaceManager only support SqlInvokedFunctions"); + SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function; + List argumentNames = sqlFunction.getParameters().stream().map(Parameter::getName).collect(toImmutableList()); + return new FunctionMetadata( + signature.getName(), + signature.getArgumentTypes(), + argumentNames, + signature.getReturnType(), + signature.getKind(), + sqlFunction.getRoutineCharacteristics().getLanguage(), + getDefaultFunctionMetadataImplementationType(), + function.isDeterministic(), + function.isCalledOnNullInput(), + sqlFunction.getVersion(), + sqlFunction.getComplexTypeFunctionDescriptor()); + } + + @Override + public void setBlockEncodingSerde(BlockEncodingSerde blockEncodingSerde) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support setting block encoding"); + } + + @Override + public FunctionNamespaceTransactionHandle beginTransaction() + { + throw new 
UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support setting block encoding"); + } + + @Override + public void commit(FunctionNamespaceTransactionHandle transactionHandle) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support setting block encoding"); + } + + @Override + public void abort(FunctionNamespaceTransactionHandle transactionHandle) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support setting block encoding"); + } + + @Override + public void createFunction(SqlInvokedFunction function, boolean replace) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support setting block encoding"); + } + + @Override + public void dropFunction(QualifiedObjectName functionName, Optional parameterTypes, boolean exists) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support drop function"); + } + + @Override + public void alterFunction(QualifiedObjectName functionName, Optional parameterTypes, AlterRoutineCharacteristics alterRoutineCharacteristics) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not alter function"); + } + + @Override + public void addUserDefinedType(UserDefinedType userDefinedType) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support adding user defined types"); + } + + @Override + public Optional getUserDefinedType(QualifiedObjectName typeName) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support getting user defined types"); + } + + @Override + public CompletableFuture executeFunction(String source, FunctionHandle functionHandle, Page input, List channels, TypeManager typeManager) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not execute function"); + } + + 
protected abstract void checkForNamingConflicts(Collection functions); + + protected abstract BuiltInFunctionKind getBuiltInFunctionKind(); + + protected abstract FunctionImplementationType getDefaultFunctionMetadataImplementationType(); + + public abstract void registerBuiltInSpecialFunctions(List functions); + + @Override + public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionHandle functionHandle) + { + checkArgument(functionHandle instanceof BuiltInFunctionHandle, "Expect BuiltInFunctionHandle"); + return getScalarFunctionImplementation(((BuiltInFunctionHandle) functionHandle).getSignature()); + } + + @Override + public AggregationFunctionImplementation getAggregateFunctionImplementation(FunctionHandle functionHandle, TypeManager typeManager) + { + checkArgument(functionHandle instanceof BuiltInFunctionHandle, "Expect BuiltInFunctionHandle"); + Signature signature = ((BuiltInFunctionHandle) functionHandle).getSignature(); + checkArgument(signature.getKind() == AGGREGATE, "%s is not an aggregate function", signature); + checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); + + try { + return specializedAggregationCache.getUnchecked(getSpecializedFunctionKey(signature)); + } + catch (UncheckedExecutionException e) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + throw e; + } + } + + @VisibleForTesting + public synchronized void registerAggregateFunctions(List aggregateFunctions) + { + this.functions = new FunctionMap(this.functions, aggregateFunctions); + } + + private AggregationFunctionImplementation sqlInvokedFunctionToAggregationImplementation( + SqlInvokedFunction function, + TypeManager typeManager) + { + checkArgument( + function.getAggregationMetadata().isPresent(), + "Need aggregationMetadata to get aggregation function implementation"); + + AggregationFunctionMetadata aggregationMetadata = function.getAggregationMetadata().get(); + List parameters = 
function.getSignature().getArgumentTypes().stream().map( + (typeManager::getType)).collect(toImmutableList()); + return new SqlInvokedAggregationFunctionImplementation( + typeManager.getType(aggregationMetadata.getIntermediateType()), + typeManager.getType(function.getSignature().getReturnType()), + aggregationMetadata.isOrderSensitive(), + parameters); + } + + protected Collection getFunctionsFromDefaultNamespace() + { + Optional> functionNamespaceManager = + functionAndTypeManager.getServingFunctionNamespaceManager(functionAndTypeManager.getDefaultNamespace()); + checkArgument(functionNamespaceManager.isPresent(), "Cannot find function namespace for catalog '%s'", functionAndTypeManager.getDefaultNamespace().getCatalogName()); + return functionNamespaceManager.get().listFunctions(Optional.empty(), Optional.empty()); + } + + private synchronized FunctionMap createFunctionMap() + { + return functions; + } + + private ScalarFunctionImplementation getScalarFunctionImplementation(Signature signature) + { + checkArgument(signature.getKind() == SCALAR, "%s is not a scalar function", signature); + checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); + + try { + return specializedScalarCache.getUnchecked(getSpecializedFunctionKey(signature)); + } + catch (UncheckedExecutionException e) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + throw e; + } + } + + private SpecializedFunctionKey doGetSpecializedFunctionKey(Signature signature) + { + return functionAndTypeManager.getSpecializedFunctionKey(signature, getFunctions(Optional.empty(), signature.getName())); + } + + private SpecializedFunctionKey getSpecializedFunctionKey(Signature signature) + { + try { + return specializedFunctionKeyCache.getUnchecked(signature); + } + catch (UncheckedExecutionException e) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + throw e; + } + } +} diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 897f6657f4f3a..4984cd8fe065d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.metadata; -import com.facebook.presto.UnknownTypeException; import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.Page; import com.facebook.presto.common.QualifiedObjectName; @@ -68,6 +67,7 @@ import com.facebook.presto.operator.aggregation.DoubleCovarianceAggregation; import com.facebook.presto.operator.aggregation.DoubleHistogramAggregation; import com.facebook.presto.operator.aggregation.DoubleRegressionAggregation; +import com.facebook.presto.operator.aggregation.DoubleRegressionExtendedAggregation; import com.facebook.presto.operator.aggregation.DoubleSumAggregation; import com.facebook.presto.operator.aggregation.EntropyAggregation; import com.facebook.presto.operator.aggregation.GeometricMeanAggregations; @@ -85,6 +85,7 @@ import com.facebook.presto.operator.aggregation.RealGeometricMeanAggregations; import com.facebook.presto.operator.aggregation.RealHistogramAggregation; import com.facebook.presto.operator.aggregation.RealRegressionAggregation; +import com.facebook.presto.operator.aggregation.RealRegressionExtendedAggregation; import com.facebook.presto.operator.aggregation.RealSumAggregation; import com.facebook.presto.operator.aggregation.ReduceAggregationFunction; import com.facebook.presto.operator.aggregation.SumDataSizeForStats; @@ -109,6 +110,7 @@ import com.facebook.presto.operator.aggregation.sketch.kll.KllSketchAggregationFunction; import 
com.facebook.presto.operator.aggregation.sketch.kll.KllSketchWithKAggregationFunction; import com.facebook.presto.operator.aggregation.sketch.theta.ThetaSketchAggregationFunction; +import com.facebook.presto.operator.scalar.AbstractArraySortByKeyFunction; import com.facebook.presto.operator.scalar.ArrayAllMatchFunction; import com.facebook.presto.operator.scalar.ArrayAnyMatchFunction; import com.facebook.presto.operator.scalar.ArrayCardinalityFunction; @@ -204,11 +206,6 @@ import com.facebook.presto.operator.scalar.WilsonInterval; import com.facebook.presto.operator.scalar.WordStemFunction; import com.facebook.presto.operator.scalar.queryplan.JsonPrestoQueryPlanFunctions; -import com.facebook.presto.operator.scalar.sql.ArraySqlFunctions; -import com.facebook.presto.operator.scalar.sql.MapNormalizeFunction; -import com.facebook.presto.operator.scalar.sql.MapSqlFunctions; -import com.facebook.presto.operator.scalar.sql.SimpleSamplingPercent; -import com.facebook.presto.operator.scalar.sql.StringSqlFunctions; import com.facebook.presto.operator.window.CumulativeDistributionFunction; import com.facebook.presto.operator.window.DenseRankFunction; import com.facebook.presto.operator.window.FirstValueFunction; @@ -236,6 +233,7 @@ import com.facebook.presto.spi.function.SqlFunctionVisibility; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation; +import com.facebook.presto.spi.type.UnknownTypeException; import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.facebook.presto.type.BigintOperators; import com.facebook.presto.type.BooleanOperators; @@ -292,14 +290,10 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.Multimap; -import com.google.common.collect.Multimaps; import 
com.google.common.util.concurrent.UncheckedExecutionException; +import com.google.errorprone.annotations.ThreadSafe; import io.airlift.slice.Slice; -import javax.annotation.concurrent.ThreadSafe; - import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.util.ArrayList; @@ -556,6 +550,15 @@ public BuiltInTypeAndFunctionNamespaceManager( FunctionsConfig functionsConfig, Set types, FunctionAndTypeManager functionAndTypeManager) + { + this(blockEncodingSerde, functionsConfig, types, functionAndTypeManager, true); + } + public BuiltInTypeAndFunctionNamespaceManager( + BlockEncodingSerde blockEncodingSerde, + FunctionsConfig functionsConfig, + Set types, + FunctionAndTypeManager functionAndTypeManager, + boolean registerFunctions) { this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); this.magicLiteralFunction = new MagicLiteralFunction(blockEncodingSerde); @@ -609,7 +612,9 @@ public BuiltInTypeAndFunctionNamespaceManager( .expireAfterWrite(1, HOURS) .build(CacheLoader.from(this::instantiateParametricType)); - registerBuiltInFunctions(getBuiltInFunctions(functionsConfig)); + if (registerFunctions) { + registerBuiltInFunctions(getBuiltInFunctions(functionsConfig)); + } registerBuiltInTypes(functionsConfig); for (Type type : requireNonNull(types, "types is null")) { @@ -742,7 +747,9 @@ private List getBuiltInFunctions(FunctionsConfig function .aggregates(DoubleCovarianceAggregation.class) .aggregates(RealCovarianceAggregation.class) .aggregates(DoubleRegressionAggregation.class) + .aggregates(DoubleRegressionExtendedAggregation.class) .aggregates(RealRegressionAggregation.class) + .aggregates(RealRegressionExtendedAggregation.class) .aggregates(DoubleCorrelationAggregation.class) .aggregates(RealCorrelationAggregation.class) .aggregates(BitwiseOrAggregation.class) @@ -876,6 +883,8 @@ private List getBuiltInFunctions(FunctionsConfig function .scalar(ArrayGreaterThanOrEqualOperator.class) 
.scalar(ArrayElementAtFunction.class) .scalar(ArraySortFunction.class) + .function(AbstractArraySortByKeyFunction.ArraySortByKeyFunction.ARRAY_SORT_BY_KEY_FUNCTION) + .function(AbstractArraySortByKeyFunction.ArraySortDescByKeyFunction.ARRAY_SORT_DESC_BY_KEY_FUNCTION) .scalar(MapSubsetFunction.class) .scalar(ArraySortComparatorFunction.class) .scalar(ArrayShuffleFunction.class) @@ -987,12 +996,6 @@ private List getBuiltInFunctions(FunctionsConfig function .aggregate(ThetaSketchAggregationFunction.class) .scalars(ThetaSketchFunctions.class) .function(MergeTDigestFunction.MERGE) - .sqlInvokedScalar(MapNormalizeFunction.class) - .sqlInvokedScalars(ArraySqlFunctions.class) - .sqlInvokedScalars(ArrayIntersectFunction.class) - .sqlInvokedScalars(MapSqlFunctions.class) - .sqlInvokedScalars(SimpleSamplingPercent.class) - .sqlInvokedScalars(StringSqlFunctions.class) .scalar(DynamicFilterPlaceholderFunction.class) .scalars(EnumCasts.class) .scalars(LongEnumOperators.class) @@ -1278,6 +1281,18 @@ public Type getType(TypeSignature typeSignature) } } + @Override + public boolean hasType(TypeSignature typeSignature) + { + try { + getType(typeSignature); + return true; + } + catch (UnknownTypeException e) { + return false; + } + } + @Override public Type getParameterizedType(String baseTypeName, List typeParameters) { @@ -1396,44 +1411,6 @@ private static class EmptyTransactionHandle { } - private static class FunctionMap - { - private final Multimap functions; - - public FunctionMap() - { - functions = ImmutableListMultimap.of(); - } - - public FunctionMap(FunctionMap map, Iterable functions) - { - this.functions = ImmutableListMultimap.builder() - .putAll(map.functions) - .putAll(Multimaps.index(functions, function -> function.getSignature().getName())) - .build(); - - // Make sure all functions with the same name are aggregations or none of them are - for (Map.Entry> entry : this.functions.asMap().entrySet()) { - Collection values = entry.getValue(); - long aggregations = 
values.stream() - .map(function -> function.getSignature().getKind()) - .filter(kind -> kind == AGGREGATE) - .count(); - checkState(aggregations == 0 || aggregations == values.size(), "'%s' is both an aggregation and a scalar function", entry.getKey()); - } - } - - public List list() - { - return ImmutableList.copyOf(functions.values()); - } - - public Collection get(QualifiedObjectName name) - { - return functions.get(name); - } - } - /** * TypeSignature but has overridden equals(). Here, we compare exact signature of any underlying distinct * types. Some distinct types may have extra information on their lazily loaded parents, and same parent diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInWorkerFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInWorkerFunctionNamespaceManager.java new file mode 100644 index 0000000000000..41ef943001cb3 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInWorkerFunctionNamespaceManager.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.function.FunctionImplementationType; +import com.facebook.presto.spi.function.SqlFunction; + +import java.util.Collection; +import java.util.List; + +import static com.facebook.presto.metadata.BuiltInFunctionKind.WORKER; +import static com.facebook.presto.spi.function.FunctionImplementationType.CPP; + +public class BuiltInWorkerFunctionNamespaceManager + extends BuiltInSpecialFunctionNamespaceManager +{ + public BuiltInWorkerFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) + { + super(functionAndTypeManager); + } + + @Override + public synchronized void registerBuiltInSpecialFunctions(List functions) + { + // only register functions once + if (!this.functions.list().isEmpty()) { + return; + } + this.functions = new FunctionMap(this.functions, functions); + } + + @Override + protected synchronized void checkForNamingConflicts(Collection functions) + { + } + + @Override + protected BuiltInFunctionKind getBuiltInFunctionKind() + { + return WORKER; + } + + @Override + protected FunctionImplementationType getDefaultFunctionMetadataImplementationType() + { + return CPP; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/Catalog.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/Catalog.java index 4ee7ed05f963b..5cbdb9bfb1308 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/Catalog.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/Catalog.java @@ -32,6 +32,8 @@ public class Catalog private final ConnectorId systemTablesId; private final Connector systemTables; + private final CatalogContext catalogContext; + public Catalog( String catalogName, ConnectorId connectorId, @@ -40,6 +42,26 @@ public Catalog( Connector informationSchema, ConnectorId systemTablesId, Connector systemTables) + { + this(catalogName, + connectorId, + connector, + informationSchemaId, + informationSchema, + 
systemTablesId, + systemTables, + catalogName); + } + + public Catalog( + String catalogName, + ConnectorId connectorId, + Connector connector, + ConnectorId informationSchemaId, + Connector informationSchema, + ConnectorId systemTablesId, + Connector systemTables, + String connectorName) { this.catalogName = checkCatalogName(catalogName); this.connectorId = requireNonNull(connectorId, "connectorId is null"); @@ -48,8 +70,9 @@ public Catalog( this.informationSchema = requireNonNull(informationSchema, "informationSchema is null"); this.systemTablesId = requireNonNull(systemTablesId, "systemTablesId is null"); this.systemTables = requireNonNull(systemTables, "systemTables is null"); + requireNonNull(connectorName, "connectorName is null"); + this.catalogContext = new CatalogContext(catalogName, connectorName); } - public String getCatalogName() { return catalogName; @@ -60,6 +83,11 @@ public ConnectorId getConnectorId() return connectorId; } + public CatalogContext getCatalogContext() + { + return catalogContext; + } + public ConnectorId getInformationSchemaId() { return informationSchemaId; @@ -92,4 +120,26 @@ public String toString() .add("connectorId", connectorId) .toString(); } + + public class CatalogContext + { + private final String catalogName; + private final String connectorName; + + public CatalogContext(String catalogName, String connectorName) + { + this.catalogName = catalogName; + this.connectorName = connectorName; + } + + public String getCatalogName() + { + return catalogName; + } + + public String getConnectorName() + { + return connectorName; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/CatalogManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/CatalogManager.java index 4a4c34f3d5ffa..ade9364a7d42f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/CatalogManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/CatalogManager.java @@ -15,8 
+15,7 @@ import com.facebook.presto.spi.ConnectorId; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/CodecDeserializer.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/CodecDeserializer.java new file mode 100644 index 0000000000000..1e9d5a8785316 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/CodecDeserializer.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.core.TreeNode; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.jsontype.TypeDeserializer; +import com.fasterxml.jackson.databind.node.ObjectNode; + +import java.io.IOException; +import java.util.Base64; +import java.util.Optional; +import java.util.function.Function; + +import static java.util.Objects.requireNonNull; + +class CodecDeserializer + extends JsonDeserializer +{ + private final Function> classResolver; + private final Function>> codecExtractor; + private final String typePropertyName; + private final String dataPropertyName; + + public CodecDeserializer( + String typePropertyName, + String dataPropertyName, + Function>> codecExtractor, + Function> classResolver) + { + this.classResolver = requireNonNull(classResolver, "classResolver is null"); + this.codecExtractor = requireNonNull(codecExtractor, "codecExtractor is null"); + this.typePropertyName = requireNonNull(typePropertyName, "typePropertyName is null"); + this.dataPropertyName = requireNonNull(dataPropertyName, "dataPropertyName is null"); + } + + @Override + public T deserialize(JsonParser parser, DeserializationContext context) + throws IOException + { + if (parser.getCurrentToken() == JsonToken.VALUE_NULL) { + return null; + } + + if (parser.getCurrentToken() != JsonToken.START_OBJECT) { + throw new IOException("Expected START_OBJECT, got " + parser.getCurrentToken()); + } + + // Parse the JSON tree + TreeNode tree = parser.readValueAsTree(); + + if (tree instanceof ObjectNode) { + ObjectNode node = (ObjectNode) tree; + + // Get the @type field + if (!node.has(typePropertyName)) { + throw new IOException("Missing " + 
typePropertyName + " field"); + } + String connectorIdString = node.get(typePropertyName).asText(); + // Check if @data field is present (binary serialization) + if (node.has(dataPropertyName)) { + // Binary data is present, we need a codec to deserialize it + // Special handling for internal handles like "$remote" + if (!connectorIdString.startsWith("$")) { + ConnectorId connectorId = new ConnectorId(connectorIdString); + Optional> codec = codecExtractor.apply(connectorId); + if (codec.isPresent()) { + String base64Data = node.get(dataPropertyName).asText(); + byte[] data = Base64.getDecoder().decode(base64Data); + return codec.get().deserialize(data); + } + } + // @data field present but no codec available or internal handle + throw new IOException("Type " + connectorIdString + " has binary data (" + dataPropertyName + " field) but no codec available to deserialize it"); + } + + // No @data field - use standard JSON deserialization + Class handleClass = classResolver.apply(connectorIdString); + + // Remove the @type field and deserialize the remaining content + node.remove(typePropertyName); + return context.readTreeAsValue(node, handleClass); + } + + throw new IOException("Unable to deserialize"); + } + + @Override + public T deserializeWithType(JsonParser p, DeserializationContext ctxt, + TypeDeserializer typeDeserializer) + throws IOException + { + // We handle the type ourselves + return deserialize(p, ctxt); + } + + @Override + public T deserializeWithType(JsonParser p, DeserializationContext ctxt, + TypeDeserializer typeDeserializer, T intoValue) + throws IOException + { + // We handle the type ourselves + return deserialize(p, ctxt); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/CodecSerializer.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/CodecSerializer.java new file mode 100644 index 0000000000000..9948e95a9c323 --- /dev/null +++ 
b/presto-main-base/src/main/java/com/facebook/presto/metadata/CodecSerializer.java @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.jsontype.TypeIdResolver; +import com.fasterxml.jackson.databind.jsontype.TypeSerializer; +import com.fasterxml.jackson.databind.jsontype.impl.AsPropertyTypeSerializer; +import com.fasterxml.jackson.databind.ser.BeanSerializerFactory; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; + +import java.io.IOException; +import java.util.Base64; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; + +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static java.util.Objects.requireNonNull; + +class CodecSerializer + extends JsonSerializer +{ + private final Function nameResolver; + private final Function>> codecExtractor; + private final TypeIdResolver typeResolver; + private final TypeSerializer typeSerializer; + private final Cache, JsonSerializer> 
serializerCache = CacheBuilder.newBuilder().build(); + private final String typePropertyName; + private final String dataPropertyName; + + public CodecSerializer( + String typePropertyName, + String dataPropertyName, + Function>> codecExtractor, + Function nameResolver, + TypeIdResolver typeIdResolver) + { + this.typePropertyName = requireNonNull(typePropertyName, "typePropertyName is null"); + this.dataPropertyName = requireNonNull(dataPropertyName, "dataPropertyName is null"); + this.nameResolver = requireNonNull(nameResolver, "nameResolver is null"); + this.codecExtractor = requireNonNull(codecExtractor, "codecExtractor is null"); + this.typeResolver = requireNonNull(typeIdResolver, "typeIdResolver is null"); + this.typeSerializer = new AsPropertyTypeSerializer(typeResolver, null, typePropertyName); + } + + @Override + public void serialize(T value, JsonGenerator jsonGenerator, SerializerProvider provider) + throws IOException + { + if (value == null) { + jsonGenerator.writeNull(); + return; + } + + String connectorIdString = nameResolver.apply(value); + + // Only try binary serialization for actual connectors (not internal handles like "$remote") + if (!connectorIdString.startsWith("$")) { + ConnectorId connectorId = new ConnectorId(connectorIdString); + + // Check if connector has a binary codec + Optional> codec = codecExtractor.apply(connectorId); + if (codec.isPresent()) { + // Use binary serialization with flat structure + jsonGenerator.writeStartObject(); + jsonGenerator.writeStringField(typePropertyName, connectorIdString); + byte[] data = codec.get().serialize(value); + jsonGenerator.writeStringField(dataPropertyName, Base64.getEncoder().encodeToString(data)); + jsonGenerator.writeEndObject(); + return; + } + } + + // Fall back to legacy typed JSON serialization + try { + Class type = value.getClass(); + JsonSerializer serializer = serializerCache.get(type, () -> createSerializer(provider, type)); + serializer.serializeWithType(value, jsonGenerator, 
provider, typeSerializer); + } + catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause != null) { + throwIfInstanceOf(cause, IOException.class); + } + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + private static JsonSerializer createSerializer(SerializerProvider provider, Class type) + throws JsonMappingException + { + JavaType javaType = provider.constructType(type); + return BeanSerializerFactory.instance.createSerializer(provider, javaType); + } + + @Override + public void serializeWithType(T value, JsonGenerator gen, + SerializerProvider serializers, TypeSerializer typeSer) + throws IOException + { + serialize(value, gen, serializers); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/ColumnHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/ColumnHandleJacksonModule.java index a8375a2e90666..26de084ff7bac 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/ColumnHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/ColumnHandleJacksonModule.java @@ -13,18 +13,45 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import java.util.function.Function; public class ColumnHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public ColumnHandleJacksonModule(HandleResolver handleResolver) + public ColumnHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ColumnHandle.class, 
handleResolver::getId, - handleResolver::getColumnHandleClass); + handleResolver::getColumnHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getColumnHandleCodec)); + } + + public ColumnHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ColumnHandle.class, + handleResolver::getId, + handleResolver::getColumnHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/ConnectorMetadataUpdaterManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/ConnectorMetadataUpdaterManager.java deleted file mode 100644 index a93bbb67ded48..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/ConnectorMetadataUpdaterManager.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.metadata; - -import com.facebook.presto.spi.ConnectorId; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdater; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdaterProvider; -import com.google.inject.Inject; - -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; - -import static com.google.common.base.Preconditions.checkArgument; -import static java.util.Objects.requireNonNull; - -public class ConnectorMetadataUpdaterManager -{ - private final Map metadataUpdaterProviderMap = new ConcurrentHashMap<>(); - - @Inject - public ConnectorMetadataUpdaterManager() {} - - public void addMetadataUpdaterProvider(ConnectorId connectorId, ConnectorMetadataUpdaterProvider metadataUpdaterProvider) - { - requireNonNull(connectorId, "connectorId is null"); - requireNonNull(metadataUpdaterProvider, "metadataUpdaterProvider is null"); - checkArgument(metadataUpdaterProviderMap.putIfAbsent(connectorId, metadataUpdaterProvider) == null, - "ConnectorMetadataUpdaterProvider for connector '%s' is already registered", connectorId); - } - - public void removeMetadataUpdaterProvider(ConnectorId connectorId) - { - requireNonNull(connectorId, "connectorId is null"); - metadataUpdaterProviderMap.remove(connectorId); - } - - public Optional getMetadataUpdater(ConnectorId connectorId) - { - requireNonNull(connectorId, "connectorId is null"); - return Optional.ofNullable(metadataUpdaterProviderMap.get(connectorId)).map(ConnectorMetadataUpdaterProvider::getMetadataUpdater); - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java index 63b8782214e1a..efab521446259 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/metadata/DelegatingMetadataManager.java @@ -20,15 +20,15 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; -import com.facebook.presto.execution.QueryManager; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.MaterializedViewStatus; +import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.NewTableLayout; -import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.TableMetadata; @@ -36,9 +36,12 @@ import com.facebook.presto.spi.connector.ConnectorCapabilities; import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.facebook.presto.spi.connector.ConnectorTableVersion; +import com.facebook.presto.spi.connector.RowChangeParadigm; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; @@ -48,8 +51,7 @@ import com.facebook.presto.spi.statistics.TableStatisticsMetadata; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Collection; import java.util.List; @@ -89,6 +91,12 @@ public void 
registerBuiltInFunctions(List functionInfos) delegate.registerBuiltInFunctions(functionInfos); } + @Override + public void registerConnectorFunctions(String catalogName, List functionInfos) + { + delegate.registerConnectorFunctions(catalogName, functionInfos); + } + @Override public List listSchemaNames(Session session, String catalogName) { @@ -361,15 +369,15 @@ public Optional finishInsert( } @Override - public ColumnHandle getDeleteRowIdColumnHandle(Session session, TableHandle tableHandle) + public Optional getDeleteRowIdColumn(Session session, TableHandle tableHandle) { - return delegate.getDeleteRowIdColumnHandle(session, tableHandle); + return delegate.getDeleteRowIdColumn(session, tableHandle); } @Override - public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tableHandle, List updatedColumns) + public Optional getUpdateRowIdColumn(Session session, TableHandle tableHandle, List updatedColumns) { - return delegate.getUpdateRowIdColumnHandle(session, tableHandle, updatedColumns); + return delegate.getUpdateRowIdColumn(session, tableHandle, updatedColumns); } @Override @@ -391,9 +399,22 @@ public DeleteTableHandle beginDelete(Session session, TableHandle tableHandle) } @Override - public void finishDelete(Session session, DeleteTableHandle tableHandle, Collection fragments) + public Optional finishDeleteWithOutput(Session session, DeleteTableHandle tableHandle, Collection fragments) + { + return delegate.finishDeleteWithOutput(session, tableHandle, fragments); + } + + @Override + public DistributedProcedureHandle beginCallDistributedProcedure(Session session, QualifiedObjectName procedureName, + TableHandle tableHandle, Object[] arguments, boolean sourceTableEliminated) { - delegate.finishDelete(session, tableHandle, fragments); + return delegate.beginCallDistributedProcedure(session, procedureName, tableHandle, arguments, sourceTableEliminated); + } + + @Override + public void finishCallDistributedProcedure(Session session, 
DistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments) + { + delegate.finishCallDistributedProcedure(session, procedureHandle, procedureName, fragments); } @Override @@ -408,6 +429,30 @@ public void finishUpdate(Session session, TableHandle tableHandle, Collection fragments, Collection computedStatistics) + { + delegate.finishMerge(session, tableHandle, fragments, computedStatistics); + } + @Override public Optional getCatalogHandle(Session session, String catalogName) { @@ -432,6 +477,26 @@ public Map getViews(Session session, Qualif return delegate.getViews(session, prefix); } + @Override + public List listMaterializedViews(Session session, QualifiedTablePrefix prefix) + { + return delegate.listMaterializedViews(session, prefix); + } + + @Override + public Map getMaterializedViews( + Session session, + QualifiedTablePrefix prefix) + { + return delegate.getMaterializedViews(session, prefix); + } + + @Override + public MaterializedViewStatus getMaterializedViewStatus(Session session, QualifiedObjectName viewName, TupleDomain baseQueryDomain) + { + return delegate.getMaterializedViewStatus(session, viewName, baseQueryDomain); + } + @Override public void createView(Session session, String catalogName, ConnectorTableMetadata viewMetadata, String viewData, boolean replace) { @@ -583,12 +648,6 @@ public ListenableFuture commitPageSinkAsync(Session session, DeleteTableHa return delegate.commitPageSinkAsync(session, tableHandle, fragments); } - @Override - public MetadataUpdates getMetadataUpdateResults(Session session, QueryManager queryManager, MetadataUpdates metadataUpdates, QueryId queryId) - { - return delegate.getMetadataUpdateResults(session, queryManager, metadataUpdates, queryId); - } - @Override public FunctionAndTypeManager getFunctionAndTypeManager() { @@ -625,6 +684,12 @@ public TablePropertyManager getTablePropertyManager() return delegate.getTablePropertyManager(); } + @Override + public 
MaterializedViewPropertyManager getMaterializedViewPropertyManager() + { + return delegate.getMaterializedViewPropertyManager(); + } + @Override public ColumnPropertyManager getColumnPropertyManager() { @@ -643,6 +708,18 @@ public Set getConnectorCapabilities(Session session, Conn return delegate.getConnectorCapabilities(session, catalogName); } + @Override + public void dropBranch(Session session, TableHandle tableHandle, String branchName, boolean branchExists) + { + delegate.dropBranch(session, tableHandle, branchName, branchExists); + } + + @Override + public void dropTag(Session session, TableHandle tableHandle, String tagName, boolean tagExists) + { + delegate.dropTag(session, tableHandle, tagName, tagExists); + } + @Override public void dropConstraint(Session session, TableHandle tableHandle, Optional constraintName, Optional columnName) { @@ -654,4 +731,16 @@ public void addConstraint(Session session, TableHandle tableHandle, TableConstra { delegate.addConstraint(session, tableHandle, tableConstraint); } + + @Override + public String normalizeIdentifier(Session session, String catalogName, String identifier) + { + return delegate.normalizeIdentifier(session, catalogName, identifier); + } + + @Override + public Optional> applyTableFunction(Session session, TableFunctionHandle handle) + { + return delegate.applyTableFunction(session, handle); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandle.java index ea05b223843d5..de2868576bb3f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandle.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.metadata; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; 
import com.facebook.presto.spi.ConnectorDeleteTableHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -23,6 +26,7 @@ import static java.util.Objects.requireNonNull; +@ThriftStruct public final class DeleteTableHandle { private final ConnectorId connectorId; @@ -30,6 +34,7 @@ public final class DeleteTableHandle private final ConnectorDeleteTableHandle connectorHandle; @JsonCreator + @ThriftConstructor public DeleteTableHandle( @JsonProperty("connectorId") ConnectorId connectorId, @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, @@ -41,18 +46,21 @@ public DeleteTableHandle( } @JsonProperty + @ThriftField(1) public ConnectorId getConnectorId() { return connectorId; } @JsonProperty + @ThriftField(2) public ConnectorTransactionHandle getTransactionHandle() { return transactionHandle; } @JsonProperty + @ThriftField(3) public ConnectorDeleteTableHandle getConnectorHandle() { return connectorHandle; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandleJacksonModule.java index 513348cf0fff4..5545f91d40037 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/DeleteTableHandleJacksonModule.java @@ -13,18 +13,45 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import 
java.util.function.Function; public class DeleteTableHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public DeleteTableHandleJacksonModule(HandleResolver handleResolver) + public DeleteTableHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorDeleteTableHandle.class, handleResolver::getId, - handleResolver::getDeleteTableHandleClass); + handleResolver::getDeleteTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorDeleteTableHandleCodec)); + } + + public DeleteTableHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorDeleteTableHandle.class, + handleResolver::getId, + handleResolver::getDeleteTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandle.java new file mode 100644 index 0000000000000..1d3776b3ecca8 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandle.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public final class DistributedProcedureHandle +{ + private final ConnectorId connectorId; + private final ConnectorTransactionHandle transactionHandle; + private final ConnectorDistributedProcedureHandle connectorHandle; + + @JsonCreator + public DistributedProcedureHandle( + @JsonProperty("connectorId") ConnectorId connectorId, + @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, + @JsonProperty("connectorHandle") ConnectorDistributedProcedureHandle connectorHandle) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); + this.connectorHandle = requireNonNull(connectorHandle, "connectorHandle is null"); + } + + @JsonProperty + public ConnectorId getConnectorId() + { + return connectorId; + } + + @JsonProperty + public ConnectorTransactionHandle getTransactionHandle() + { + return transactionHandle; + } + + @JsonProperty + public ConnectorDistributedProcedureHandle getConnectorHandle() + { + return connectorHandle; + } + + @Override + public int hashCode() + { + return Objects.hash(connectorId, transactionHandle, connectorHandle); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + DistributedProcedureHandle o = (DistributedProcedureHandle) obj; + return Objects.equals(this.connectorId, 
o.connectorId) && + Objects.equals(this.transactionHandle, o.transactionHandle) && + Objects.equals(this.connectorHandle, o.connectorHandle); + } + + @Override + public String toString() + { + return connectorId + ":" + connectorHandle; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandleJacksonModule.java new file mode 100644 index 0000000000000..e60a686404036 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/DistributedProcedureHandleJacksonModule.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Provider; + +import javax.inject.Inject; + +import java.util.Optional; +import java.util.function.Function; + +public class DistributedProcedureHandleJacksonModule + extends AbstractTypedJacksonModule +{ + @Inject + public DistributedProcedureHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) + { + super(ConnectorDistributedProcedureHandle.class, + handleResolver::getId, + handleResolver::getDistributedProcedureHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorDistributedProcedureHandleCodec)); + } + + public DistributedProcedureHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorDistributedProcedureHandle.class, + handleResolver::getId, + handleResolver::getDistributedProcedureHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/ForNodeManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/ForNodeManager.java index 345b9e4744066..13f3e92f77af3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/ForNodeManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/ForNodeManager.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.metadata; -import javax.inject.Qualifier; 
+import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java index 4ea19e25698cd..23d724dbcc54d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.metadata; +import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.Page; @@ -33,11 +34,14 @@ import com.facebook.presto.common.type.TypeWithName; import com.facebook.presto.common.type.UserDefinedType; import com.facebook.presto.operator.window.WindowFunctionSupplier; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.function.AggregationFunctionImplementation; import com.facebook.presto.spi.function.AlterRoutineCharacteristics; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionImplementationType; import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.function.FunctionMetadataManager; import com.facebook.presto.spi.function.FunctionNamespaceManager; @@ -49,14 +53,18 @@ import com.facebook.presto.spi.function.ScalarFunctionImplementation; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunction; +import com.facebook.presto.spi.function.SqlFunctionHandle; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlFunctionSupplier; import 
com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.type.TypeManagerContext; import com.facebook.presto.spi.type.TypeManagerFactory; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver; import com.facebook.presto.sql.analyzer.FunctionsConfig; +import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; import com.facebook.presto.sql.gen.CacheStatsMBean; import com.facebook.presto.sql.tree.QualifiedName; @@ -68,14 +76,15 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.UncheckedExecutionException; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; @@ -86,16 +95,23 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import java.util.regex.Pattern; +import static com.facebook.presto.SystemSessionProperties.getNonBuiltInFunctionNamespacesToListFunctions; import static com.facebook.presto.SystemSessionProperties.isExperimentalFunctionsEnabled; import static com.facebook.presto.SystemSessionProperties.isListBuiltInFunctionsOnly; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static 
com.facebook.presto.metadata.BuiltInFunctionKind.PLUGIN; +import static com.facebook.presto.metadata.BuiltInFunctionKind.WORKER; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.metadata.CastType.toOperatorType; import static com.facebook.presto.metadata.FunctionSignatureMatcher.constructFunctionNotFoundErrorMessage; +import static com.facebook.presto.metadata.FunctionSignatureMatcher.decideAndThrow; import static com.facebook.presto.metadata.SessionFunctionHandle.SESSION_NAMESPACE; import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; +import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; +import static com.facebook.presto.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_CALL; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR; @@ -128,7 +144,9 @@ public class FunctionAndTypeManager implements FunctionMetadataManager, TypeManager { private static final Pattern DEFAULT_NAMESPACE_PREFIX_PATTERN = Pattern.compile("[a-z]+\\.[a-z]+"); + private static final Logger log = Logger.get(FunctionAndTypeManager.class); private final TransactionManager transactionManager; + private final TableFunctionRegistry tableFunctionRegistry; private final BlockEncodingSerde blockEncodingSerde; private final BuiltInTypeAndFunctionNamespaceManager builtInTypeAndFunctionNamespaceManager; private final FunctionInvokerProvider functionInvokerProvider; @@ -142,13 +160,20 @@ public class FunctionAndTypeManager private final LoadingCache functionCache; private final CacheStatsMBean cacheStatsMBean; private final boolean nativeExecution; + private final boolean isBuiltInSidecarFunctionsEnabled; private final CatalogSchemaName defaultNamespace; private final AtomicReference 
servingTypeManager; private final AtomicReference>> servingTypeManagerParametricTypesSupplier; + private final BuiltInWorkerFunctionNamespaceManager builtInWorkerFunctionNamespaceManager; + private final BuiltInPluginFunctionNamespaceManager builtInPluginFunctionNamespaceManager; + private final ConcurrentHashMap> tableFunctionProcessorProviderMap = new ConcurrentHashMap<>(); + private final FunctionsConfig functionsConfig; + private final Set types; @Inject public FunctionAndTypeManager( TransactionManager transactionManager, + TableFunctionRegistry tableFunctionRegistry, BlockEncodingSerde blockEncodingSerde, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, @@ -156,7 +181,10 @@ public FunctionAndTypeManager( Set types) { this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); + this.tableFunctionRegistry = requireNonNull(tableFunctionRegistry, "tableFunctionRegistry is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); + this.functionsConfig = requireNonNull(functionsConfig, "functionsConfig is null"); + this.types = requireNonNull(types, "types is null"); this.builtInTypeAndFunctionNamespaceManager = new BuiltInTypeAndFunctionNamespaceManager(blockEncodingSerde, functionsConfig, types, this); this.functionNamespaceManagers.put(JAVA_BUILTIN_NAMESPACE.getCatalogName(), builtInTypeAndFunctionNamespaceManager); this.functionInvokerProvider = new FunctionInvokerProvider(this); @@ -173,15 +201,19 @@ public FunctionAndTypeManager( this.functionSignatureMatcher = new FunctionSignatureMatcher(this); this.typeCoercer = new TypeCoercer(functionsConfig, this); this.nativeExecution = featuresConfig.isNativeExecutionEnabled(); + this.isBuiltInSidecarFunctionsEnabled = featuresConfig.isBuiltInSidecarFunctionsEnabled(); this.defaultNamespace = configureDefaultNamespace(functionsConfig.getDefaultNamespacePrefix()); this.servingTypeManager = new 
AtomicReference<>(builtInTypeAndFunctionNamespaceManager); this.servingTypeManagerParametricTypesSupplier = new AtomicReference<>(this::getServingTypeManagerParametricTypes); + this.builtInWorkerFunctionNamespaceManager = new BuiltInWorkerFunctionNamespaceManager(this); + this.builtInPluginFunctionNamespaceManager = new BuiltInPluginFunctionNamespaceManager(this); } public static FunctionAndTypeManager createTestFunctionAndTypeManager() { return new FunctionAndTypeManager( createTestTransactionManager(), + new TableFunctionRegistry(), new BlockEncodingManager(), new FeaturesConfig(), new FunctionsConfig(), @@ -264,6 +296,12 @@ public Collection getParametricTypes() return FunctionAndTypeManager.this.getParametricTypes(); } + @Override + public boolean hasType(TypeSignature signature) + { + return FunctionAndTypeManager.this.hasType(signature); + } + @Override public Collection listBuiltInFunctions() { @@ -288,6 +326,12 @@ public FunctionHandle lookupCast(String castType, Type fromType, Type toType) return FunctionAndTypeManager.this.lookupCast(CastType.valueOf(castType), fromType, toType); } + @Override + public void validateFunctionCall(FunctionHandle functionHandle, List arguments) + { + FunctionAndTypeManager.this.validateFunctionCall(functionHandle, arguments); + } + public QualifiedObjectName qualifyObjectName(QualifiedName name) { if (name.getSuffix().startsWith("$internal")) { @@ -344,8 +388,15 @@ public FunctionMetadata getFunctionMetadata(FunctionHandle functionHandle) if (functionHandle.getCatalogSchemaName().equals(SESSION_NAMESPACE)) { return ((SessionFunctionHandle) functionHandle).getFunctionMetadata(); } + if (isBuiltInPluginFunctionHandle(functionHandle)) { + return builtInPluginFunctionNamespaceManager.getFunctionMetadata(functionHandle); + } + if (isBuiltInWorkerFunctionHandle(functionHandle)) { + return builtInWorkerFunctionNamespaceManager.getFunctionMetadata(functionHandle); + } Optional> functionNamespaceManager = 
getServingFunctionNamespaceManager(functionHandle.getCatalogSchemaName()); checkArgument(functionNamespaceManager.isPresent(), "Cannot find function namespace for '%s'", functionHandle.getCatalogSchemaName()); + return functionNamespaceManager.get().getFunctionMetadata(functionHandle); } @@ -378,6 +429,12 @@ public Type getType(TypeSignature signature) return getUserDefinedType(signature); } + @Override + public boolean hasType(TypeSignature signature) + { + return servingTypeManager.get().hasType(signature); + } + @Override public Type getParameterizedType(String baseTypeName, List typeParameters) { @@ -400,7 +457,17 @@ public void addFunctionNamespaceFactory(FunctionNamespaceManagerFactory factory) if (functionNamespaceManagerFactories.putIfAbsent(factory.getName(), factory) != null) { throw new IllegalArgumentException(format("Resource group configuration manager '%s' is already registered", factory.getName())); } - handleResolver.addFunctionNamespace(factory.getName(), factory.getHandleResolver()); + String name = factory.getName(); + // SqlFunctionHandle is in SPI and used by multiple function namespace managers, use the same name for it. 
+ if (factory.getHandleResolver().getFunctionHandleClass().equals(SqlFunctionHandle.class)) { + name = "sql_function_handle"; + } + handleResolver.addFunctionNamespace(name, factory.getHandleResolver()); + } + + public TableFunctionRegistry getTableFunctionRegistry() + { + return tableFunctionRegistry; } public void loadTypeManager(String typeManagerName) @@ -418,13 +485,6 @@ public void loadTypeManager(String typeManagerName) servingTypeManagerParametricTypesSupplier.set(this::getServingTypeManagerParametricTypes); } - public void loadTypeManagers() - { - for (String typeManagerName : typeManagerFactories.keySet()) { - loadTypeManager(typeManagerName); - } - } - public void addTypeManagerFactory(TypeManagerFactory factory) { if (typeManagerFactories.putIfAbsent(factory.getName(), factory) != null) { @@ -432,11 +492,44 @@ public void addTypeManagerFactory(TypeManagerFactory factory) } } + public TransactionManager getTransactionManager() + { + return transactionManager; + } + public void registerBuiltInFunctions(List functions) { builtInTypeAndFunctionNamespaceManager.registerBuiltInFunctions(functions); } + public void registerWorkerFunctions(List functions) + { + if (isBuiltInSidecarFunctionsEnabled) { + builtInWorkerFunctionNamespaceManager.registerBuiltInSpecialFunctions(functions); + } + } + + @VisibleForTesting + public void registerWorkerAggregateFunctions(List aggregateFunctions) + { + builtInWorkerFunctionNamespaceManager.registerAggregateFunctions(aggregateFunctions); + } + + public void registerPluginFunctions(List functions) + { + builtInPluginFunctionNamespaceManager.registerBuiltInSpecialFunctions(functions); + } + + public void registerConnectorFunctions(String catalogName, List functions) + { + FunctionNamespaceManager builtInPluginFunctionNamespaceManager = functionNamespaceManagers.get(catalogName); + if (builtInPluginFunctionNamespaceManager == null) { + builtInPluginFunctionNamespaceManager = new 
BuiltInTypeAndFunctionNamespaceManager(blockEncodingSerde, functionsConfig, types, this, false); + addFunctionNamespace(catalogName, builtInPluginFunctionNamespaceManager); + } + ((BuiltInTypeAndFunctionNamespaceManager) builtInPluginFunctionNamespaceManager).registerBuiltInFunctions(functions); + } + /** * likePattern / escape is an opportunistic optimization push down to function namespace managers. * Not all function namespace managers can handle it, thus the returned function list could @@ -454,12 +547,18 @@ public List listFunctions(Session session, Optional likePat functions.addAll(functionNamespaceManagers.get( defaultNamespace.getCatalogName()).listFunctions(likePattern, escape).stream() .collect(toImmutableList())); + functions.addAll(builtInPluginFunctionNamespaceManager.listFunctions(likePattern, escape).stream().collect(toImmutableList())); + functions.addAll(builtInWorkerFunctionNamespaceManager.listFunctions(likePattern, escape).stream().collect(toImmutableList())); } else { + Set catalogsToListFunction = getNonBuiltInFunctionNamespacesToListFunctions(session); functions.addAll(SessionFunctionUtils.listFunctions(session.getSessionFunctions())); functions.addAll(functionNamespaceManagers.values().stream() + .filter(x -> x instanceof BuiltInTypeAndFunctionNamespaceManager || catalogsToListFunction.isEmpty() || catalogsToListFunction.contains(x.getCatalogName())) .flatMap(manager -> manager.listFunctions(likePattern, escape).stream()) .collect(toImmutableList())); + functions.addAll(builtInPluginFunctionNamespaceManager.listFunctions(likePattern, escape).stream().collect(toImmutableList())); + functions.addAll(builtInWorkerFunctionNamespaceManager.listFunctions(likePattern, escape).stream().collect(toImmutableList())); } return functions.build().stream() @@ -487,7 +586,7 @@ public Collection getFunctions(Session session, Qualified Optional transactionHandle = session.getTransactionId().map( id -> transactionManager.getFunctionNamespaceTransaction(id, 
functionName.getCatalogName())); - return functionNamespaceManager.get().getFunctions(transactionHandle, functionName); + return getFunctions(functionName, transactionHandle, functionNamespaceManager.get()); } public void createFunction(SqlInvokedFunction function, boolean replace) @@ -602,13 +701,45 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionHand if (functionHandle.getCatalogSchemaName().equals(SESSION_NAMESPACE)) { return ((SessionFunctionHandle) functionHandle).getScalarFunctionImplementation(); } + if (isBuiltInPluginFunctionHandle(functionHandle)) { + return builtInPluginFunctionNamespaceManager.getScalarFunctionImplementation(functionHandle); + } + if (isBuiltInWorkerFunctionHandle(functionHandle)) { + return builtInWorkerFunctionNamespaceManager.getScalarFunctionImplementation(functionHandle); + } + Optional> functionNamespaceManager = getServingFunctionNamespaceManager(functionHandle.getCatalogSchemaName()); checkArgument(functionNamespaceManager.isPresent(), "Cannot find function namespace for '%s'", functionHandle.getCatalogSchemaName()); return functionNamespaceManager.get().getScalarFunctionImplementation(functionHandle); } + public TableFunctionProcessorProvider getTableFunctionProcessorProvider(TableFunctionHandle tableFunctionHandle) + { + return tableFunctionProcessorProviderMap.get(tableFunctionHandle.getConnectorId()).apply(tableFunctionHandle.getFunctionHandle()); + } + + public void addTableFunctionProcessorProvider(ConnectorId connectorId, Function tableFunctionProcessorProvider) + { + if (tableFunctionProcessorProviderMap.putIfAbsent(connectorId, tableFunctionProcessorProvider) != null) { + throw new PrestoException(ALREADY_EXISTS, + format("TableFuncitonProcessorProvider already exists for connectorId %s. 
Overwriting is not supported.", connectorId.getCatalogName())); + } + } + + public void removeTableFunctionProcessorProvider(ConnectorId connectorId) + { + tableFunctionProcessorProviderMap.remove(connectorId); + } + public AggregationFunctionImplementation getAggregateFunctionImplementation(FunctionHandle functionHandle) { + if (isBuiltInPluginFunctionHandle(functionHandle)) { + return builtInPluginFunctionNamespaceManager.getAggregateFunctionImplementation(functionHandle, this); + } + if (isBuiltInWorkerFunctionHandle(functionHandle)) { + return builtInWorkerFunctionNamespaceManager.getAggregateFunctionImplementation(functionHandle, this); + } + Optional> functionNamespaceManager = getServingFunctionNamespaceManager(functionHandle.getCatalogSchemaName()); checkArgument(functionNamespaceManager.isPresent(), "Cannot find function namespace for '%s'", functionHandle.getCatalogSchemaName()); return functionNamespaceManager.get().getAggregateFunctionImplementation(functionHandle, this); @@ -621,6 +752,19 @@ public CompletableFuture executeFunction(String source, Funct return functionNamespaceManager.get().executeFunction(source, functionHandle, inputPage, channels, this); } + public void validateFunctionCall(FunctionHandle functionHandle, List arguments) + { + // Built-in functions don't need validation + if (functionHandle instanceof BuiltInFunctionHandle) { + return; + } + + Optional> functionNamespaceManager = getServingFunctionNamespaceManager(functionHandle.getCatalogSchemaName()); + if (functionNamespaceManager.isPresent()) { + functionNamespaceManager.get().validateFunctionCall(functionHandle, arguments); + } + } + public WindowFunctionSupplier getWindowFunctionImplementation(FunctionHandle functionHandle) { return builtInTypeAndFunctionNamespaceManager.getWindowFunctionImplementation(functionHandle); @@ -656,6 +800,12 @@ public List listOperators() .collect(toImmutableList()); } + @VisibleForTesting + public Map> getFunctionNamespaceManagers() + { + return 
ImmutableMap.copyOf(functionNamespaceManagers); + } + public FunctionHandle resolveOperator(OperatorType operatorType, List argumentTypes) { try { @@ -704,13 +854,7 @@ public FunctionHandle lookupFunction(QualifiedObjectName functionName, List candidates = functionNamespaceManager.get().getFunctions(Optional.empty(), functionName); - Optional match = functionSignatureMatcher.match(candidates, parameterTypes, false); - if (!match.isPresent()) { - throw new PrestoException(FUNCTION_NOT_FOUND, constructFunctionNotFoundErrorMessage(functionName, parameterTypes, candidates)); - } - - return functionNamespaceManager.get().getFunctionHandle(Optional.empty(), match.get()); + return getMatchingFunctionHandle(functionName, Optional.empty(), functionNamespaceManager.get(), parameterTypes, false); } public FunctionHandle lookupCast(CastType castType, Type fromType, Type toType) @@ -738,13 +882,18 @@ public CatalogSchemaName getDefaultNamespace() return defaultNamespace; } + public HandleResolver getHandleResolver() + { + return handleResolver; + } + protected Type getType(UserDefinedType userDefinedType) { // Distinct type if (userDefinedType.isDistinctType()) { return getDistinctType(userDefinedType.getPhysicalTypeSignature().getParameters().get(0).getDistinctTypeInfo()); } - // Enum type + // Enum type or primitive type with name return getType(new TypeSignature(userDefinedType)); } @@ -780,11 +929,14 @@ private FunctionHandle resolveFunctionInternal(Optional transacti return functionNamespaceManager.resolveFunction(transactionHandle, functionName, parameterTypes.stream().map(TypeSignatureProvider::getTypeSignature).collect(toImmutableList())); } - Collection candidates = functionNamespaceManager.getFunctions(transactionHandle, functionName); - - Optional match = functionSignatureMatcher.match(candidates, parameterTypes, true); - if (match.isPresent()) { - return functionNamespaceManager.getFunctionHandle(transactionHandle, match.get()); + try { + return 
getMatchingFunctionHandle(functionName, transactionHandle, functionNamespaceManager, parameterTypes, true); + } + catch (PrestoException e) { + // Could still match to a magic literal function + if (e.getErrorCode().getCode() != StandardErrorCode.FUNCTION_NOT_FOUND.toErrorCode().getCode()) { + throw e; + } } if (functionName.getObjectName().startsWith(MAGIC_LITERAL_FUNCTION_PREFIX)) { @@ -800,7 +952,8 @@ private FunctionHandle resolveFunctionInternal(Optional transacti return new BuiltInFunctionHandle(getMagicLiteralFunctionSignature(type)); } - throw new PrestoException(FUNCTION_NOT_FOUND, constructFunctionNotFoundErrorMessage(functionName, parameterTypes, candidates)); + throw new PrestoException(FUNCTION_NOT_FOUND, constructFunctionNotFoundErrorMessage( + functionName, parameterTypes, getFunctions(functionName, transactionHandle, functionNamespaceManager))); } private FunctionHandle resolveBuiltInFunction(QualifiedObjectName functionName, List parameterTypes) @@ -824,7 +977,7 @@ private FunctionHandle lookupCachedFunction(QualifiedObjectName functionName, Li } } - private Optional> getServingFunctionNamespaceManager(CatalogSchemaName functionNamespace) + public Optional> getServingFunctionNamespaceManager(CatalogSchemaName functionNamespace) { return Optional.ofNullable(functionNamespaceManagers.get(functionNamespace.getCatalogName())); } @@ -835,7 +988,6 @@ private Optional> getServingFunc } @Override - @SuppressWarnings("unchecked") public SpecializedFunctionKey getSpecializedFunctionKey(Signature signature) { QualifiedObjectName functionName = signature.getName(); @@ -844,8 +996,13 @@ public SpecializedFunctionKey getSpecializedFunctionKey(Signature signature) throw new PrestoException(FUNCTION_NOT_FOUND, format("Cannot find function namespace for signature '%s'", functionName)); } - Collection candidates = (Collection) functionNamespaceManager.get().getFunctions(Optional.empty(), functionName); + Collection candidates = 
functionNamespaceManager.get().getFunctions(Optional.empty(), functionName); + return getSpecializedFunctionKey(signature, candidates); + } + + public SpecializedFunctionKey getSpecializedFunctionKey(Signature signature, Collection candidates) + { // search for exact match Type returnType = getType(signature.getReturnType()); List argumentTypeSignatureProviders = fromTypeSignatures(signature.getArgumentTypes()); @@ -894,7 +1051,12 @@ public SpecializedFunctionKey getSpecializedFunctionKey(Signature signature) return builtInTypeAndFunctionNamespaceManager.doGetSpecializedFunctionKeyForMagicLiteralFunctions(signature, this); } - public CatalogSchemaName configureDefaultNamespace(String defaultNamespacePrefixString) + public BuiltInPluginFunctionNamespaceManager getBuiltInPluginFunctionNamespaceManager() + { + return builtInPluginFunctionNamespaceManager; + } + + private CatalogSchemaName configureDefaultNamespace(String defaultNamespacePrefixString) { if (!defaultNamespacePrefixString.matches(DEFAULT_NAMESPACE_PREFIX_PATTERN.pattern())) { throw new PrestoException(GENERIC_USER_ERROR, format("Default namespace prefix string should be in the form of 'catalog.schema', found: %s", defaultNamespacePrefixString)); @@ -909,6 +1071,142 @@ private Map getServingTypeManagerParametricTypes() .collect(toImmutableMap(ParametricType::getName, parametricType -> parametricType)); } + private Collection getFunctions( + QualifiedObjectName functionName, + Optional transactionHandle, + FunctionNamespaceManager functionNamespaceManager) + { + return ImmutableList.builder() + .addAll(functionNamespaceManager.getFunctions(transactionHandle, functionName)) + .addAll(builtInPluginFunctionNamespaceManager.getFunctions(transactionHandle, functionName)) + .addAll(builtInWorkerFunctionNamespaceManager.getFunctions(transactionHandle, functionName)) + .build(); + } + + /** + * Gets the function handle of the function if there is a match. We enforce explicit naming for dynamic function namespaces. 
+ * All unqualified function names will only be resolved against the built-in default function namespace. We get all the candidates + * from the current default namespace and additionally all the candidates from builtInPluginFunctionNamespaceManager and + * builtInWorkerFunctionNamespaceManager. + * + * @throws PrestoException if there are no matches or multiple matches + */ + private FunctionHandle getMatchingFunctionHandle( + QualifiedObjectName functionName, + Optional transactionHandle, + FunctionNamespaceManager functionNamespaceManager, + List parameterTypes, + boolean coercionAllowed) + { + boolean foundMatch = false; + List exceptions = new ArrayList<>(); + List allCandidates = new ArrayList<>(); + Optional matchingDefaultFunctionSignature = Optional.empty(); + Optional matchingPluginFunctionSignature = Optional.empty(); + Optional matchingWorkerFunctionSignature = Optional.empty(); + + try { + Collection defaultCandidates = functionNamespaceManager.getFunctions(transactionHandle, functionName); + allCandidates.addAll(defaultCandidates); + matchingDefaultFunctionSignature = + getMatchingFunction(defaultCandidates, parameterTypes, coercionAllowed); + if (matchingDefaultFunctionSignature.isPresent()) { + foundMatch = true; + } + } + catch (SemanticException e) { + exceptions.add(e); + } + + try { + Collection pluginCandidates = builtInPluginFunctionNamespaceManager.getFunctions(transactionHandle, functionName); + allCandidates.addAll(pluginCandidates); + matchingPluginFunctionSignature = + getMatchingFunction(pluginCandidates, parameterTypes, coercionAllowed); + if (matchingPluginFunctionSignature.isPresent()) { + foundMatch = true; + } + } + catch (SemanticException e) { + exceptions.add(e); + } + + try { + Collection workerCandidates = builtInWorkerFunctionNamespaceManager.getFunctions(transactionHandle, functionName); + allCandidates.addAll(workerCandidates); + matchingWorkerFunctionSignature = + getMatchingFunction(workerCandidates, parameterTypes, 
coercionAllowed); + if (matchingWorkerFunctionSignature.isPresent()) { + foundMatch = true; + } + } + catch (SemanticException e) { + exceptions.add(e); + } + + if (!foundMatch && !exceptions.isEmpty()) { + decideAndThrow(exceptions, + allCandidates.stream().findFirst() + .map(function -> function.getSignature().getName().getObjectName()) + .orElse("")); + } + + if (matchingDefaultFunctionSignature.isPresent() && matchingPluginFunctionSignature.isPresent()) { + throw new PrestoException(AMBIGUOUS_FUNCTION_CALL, format("Function '%s' has two matching signatures. Please specify parameter types. \n" + + "First match : '%s', Second match: '%s'", functionName, matchingDefaultFunctionSignature.get(), matchingPluginFunctionSignature.get())); + } + + if (matchingDefaultFunctionSignature.isPresent() && matchingWorkerFunctionSignature.isPresent()) { + FunctionHandle defaultFunctionHandle = functionNamespaceManager.getFunctionHandle(transactionHandle, matchingDefaultFunctionSignature.get()); + FunctionHandle workerFunctionHandle = builtInWorkerFunctionNamespaceManager.getFunctionHandle(transactionHandle, matchingWorkerFunctionSignature.get()); + + if (functionNamespaceManager.getFunctionMetadata(defaultFunctionHandle).getImplementationType().equals(FunctionImplementationType.JAVA)) { + return defaultFunctionHandle; + } + if (functionNamespaceManager.getFunctionMetadata(defaultFunctionHandle).getImplementationType().equals(FunctionImplementationType.SQL)) { + return workerFunctionHandle; + } + } + + if (matchingPluginFunctionSignature.isPresent() && matchingWorkerFunctionSignature.isPresent()) { + // built in plugin function namespace manager always has SQL as implementation type + return builtInWorkerFunctionNamespaceManager.getFunctionHandle(transactionHandle, matchingWorkerFunctionSignature.get()); + } + + if (matchingWorkerFunctionSignature.isPresent()) { + return builtInWorkerFunctionNamespaceManager.getFunctionHandle(transactionHandle, 
matchingWorkerFunctionSignature.get()); + } + + if (matchingPluginFunctionSignature.isPresent()) { + return builtInPluginFunctionNamespaceManager.getFunctionHandle(transactionHandle, matchingPluginFunctionSignature.get()); + } + + if (matchingDefaultFunctionSignature.isPresent()) { + return functionNamespaceManager.getFunctionHandle(transactionHandle, matchingDefaultFunctionSignature.get()); + } + + throw new PrestoException(FUNCTION_NOT_FOUND, constructFunctionNotFoundErrorMessage(functionName, parameterTypes, + getFunctions(functionName, transactionHandle, functionNamespaceManager))); + } + + private Optional getMatchingFunction( + Collection candidates, + List parameterTypes, + boolean coercionAllowed) + { + return functionSignatureMatcher.match(candidates, parameterTypes, coercionAllowed); + } + + private boolean isBuiltInPluginFunctionHandle(FunctionHandle functionHandle) + { + return (functionHandle instanceof BuiltInFunctionHandle) && ((BuiltInFunctionHandle) functionHandle).getBuiltInFunctionKind().equals(PLUGIN); + } + + private boolean isBuiltInWorkerFunctionHandle(FunctionHandle functionHandle) + { + return (functionHandle instanceof BuiltInFunctionHandle) && ((BuiltInFunctionHandle) functionHandle).getBuiltInFunctionKind().equals(WORKER); + } + private static class FunctionResolutionCacheKey { private final QualifiedObjectName functionName; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionExtractor.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionExtractor.java index 9521027f30812..7cdbd512a7ce3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionExtractor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionExtractor.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.operator.scalar.annotations.CodegenScalarFromAnnotationsParser; import 
com.facebook.presto.operator.scalar.annotations.ScalarFromAnnotationsParser; import com.facebook.presto.operator.scalar.annotations.SqlInvokedScalarFromAnnotationsParser; @@ -28,46 +29,62 @@ import java.util.Collection; import java.util.List; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; public final class FunctionExtractor { private FunctionExtractor() {} public static List extractFunctions(Collection> classes) + { + return extractFunctions(classes, JAVA_BUILTIN_NAMESPACE); + } + + public static List extractFunctions(Collection> classes, CatalogSchemaName functionNamespace) { return classes.stream() - .map(FunctionExtractor::extractFunctions) + .map(c -> extractFunctions(c, functionNamespace)) .flatMap(Collection::stream) .collect(toImmutableList()); } public static List extractFunctions(Class clazz) + { + return extractFunctions(clazz, JAVA_BUILTIN_NAMESPACE); + } + + public static List extractFunctions(Class clazz, CatalogSchemaName defaultNamespace) { if (WindowFunction.class.isAssignableFrom(clazz)) { + checkArgument(defaultNamespace.equals(JAVA_BUILTIN_NAMESPACE), format("Connector specific Window functions are not supported: Class [%s], Namespace [%s]", clazz.getName(), defaultNamespace)); @SuppressWarnings("unchecked") Class windowClazz = (Class) clazz; return WindowAnnotationsParser.parseFunctionDefinition(windowClazz); } if (clazz.isAnnotationPresent(AggregationFunction.class)) { - return SqlAggregationFunction.createFunctionsByAnnotations(clazz); + return SqlAggregationFunction.createFunctionsByAnnotations(clazz, defaultNamespace); } - if (clazz.isAnnotationPresent(ScalarFunction.class) || - clazz.isAnnotationPresent(ScalarOperator.class)) { + if (clazz.isAnnotationPresent(ScalarFunction.class)) { + return 
ScalarFromAnnotationsParser.parseFunctionDefinition(clazz, defaultNamespace); + } + if (clazz.isAnnotationPresent(ScalarOperator.class)) { + checkArgument(defaultNamespace.equals(JAVA_BUILTIN_NAMESPACE), format("Connector specific Scalar Operator functions are not supported: Class [%s], Namespace [%s]", clazz.getName(), defaultNamespace)); return ScalarFromAnnotationsParser.parseFunctionDefinition(clazz); } if (clazz.isAnnotationPresent(SqlInvokedScalarFunction.class)) { - return SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinition(clazz); + return SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinition(clazz, defaultNamespace); } List scalarFunctions = ImmutableList.builder() - .addAll(ScalarFromAnnotationsParser.parseFunctionDefinitions(clazz)) - .addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(clazz)) - .addAll(CodegenScalarFromAnnotationsParser.parseFunctionDefinitions(clazz)) + .addAll(ScalarFromAnnotationsParser.parseFunctionDefinitions(clazz, defaultNamespace)) + .addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(clazz, defaultNamespace)) + .addAll(CodegenScalarFromAnnotationsParser.parseFunctionDefinitions(clazz, defaultNamespace)) .build(); checkArgument(!scalarFunctions.isEmpty(), "Class [%s] does not define any scalar functions", clazz.getName()); return scalarFunctions; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionHandleJacksonModule.java index f8a8a90053508..ef7b7529c76d8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionHandleJacksonModule.java @@ -14,8 +14,7 @@ package com.facebook.presto.metadata; import com.facebook.presto.spi.function.FunctionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; public class 
FunctionHandleJacksonModule extends AbstractTypedJacksonModule @@ -23,6 +22,13 @@ public class FunctionHandleJacksonModule @Inject public FunctionHandleJacksonModule(HandleResolver handleResolver) { - super(FunctionHandle.class, handleResolver::getId, handleResolver::getFunctionHandleClass); + // Functions are internal to Presto and don't need binary serialization + super(FunctionHandle.class, + handleResolver::getId, + handleResolver::getFunctionHandleClass, + false, // Always disabled for functions + connectorId -> { + throw new UnsupportedOperationException("Function handles do not support binary serialization"); + }); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionListBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionListBuilder.java index 13a90da951278..799dd5f8074fe 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionListBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionListBuilder.java @@ -15,7 +15,6 @@ import com.facebook.presto.operator.scalar.annotations.CodegenScalarFromAnnotationsParser; import com.facebook.presto.operator.scalar.annotations.ScalarFromAnnotationsParser; -import com.facebook.presto.operator.scalar.annotations.SqlInvokedScalarFromAnnotationsParser; import com.facebook.presto.operator.window.WindowAnnotationsParser; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.function.WindowFunction; @@ -60,18 +59,6 @@ public FunctionListBuilder scalars(Class clazz) return this; } - public FunctionListBuilder sqlInvokedScalar(Class clazz) - { - functions.addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinition(clazz)); - return this; - } - - public FunctionListBuilder sqlInvokedScalars(Class clazz) - { - functions.addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(clazz)); - return this; - } - public FunctionListBuilder codegenScalars(Class clazz) { 
functions.addAll(CodegenScalarFromAnnotationsParser.parseFunctionDefinitions(clazz)); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionMap.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionMap.java new file mode 100644 index 0000000000000..d9f2f22f43900 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionMap.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.function.SqlFunction; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Multimap; +import com.google.common.collect.Multimaps; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class FunctionMap +{ + private final Multimap functions; + + public FunctionMap() + { + functions = ImmutableListMultimap.of(); + } + + public FunctionMap(FunctionMap map, Iterable functions) + { + requireNonNull(map, "map is null"); + requireNonNull(functions, "functions is null"); + this.functions = ImmutableListMultimap.builder() + .putAll(map.functions) + 
.putAll(Multimaps.index(functions, function -> function.getSignature().getName())) + .build(); + + // Make sure all functions with the same name are aggregations or none of them are + for (Map.Entry> entry : this.functions.asMap().entrySet()) { + Collection values = entry.getValue(); + long aggregations = values.stream() + .map(function -> function.getSignature().getKind()) + .filter(kind -> kind == AGGREGATE) + .count(); + checkState(aggregations == 0 || aggregations == values.size(), "'%s' is both an aggregation and a scalar function", entry.getKey()); + } + } + + public List list() + { + return ImmutableList.copyOf(functions.values()); + } + + public Collection get(QualifiedObjectName name) + { + return functions.get(name); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionSignatureMatcher.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionSignatureMatcher.java index 2ac1ee920a13e..bba1d4fb1e1b4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionSignatureMatcher.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionSignatureMatcher.java @@ -71,7 +71,7 @@ public Optional match(Collection candidates, L .filter(function -> !function.getSignature().getTypeVariableConstraints().isEmpty()) .collect(Collectors.toList()); - match = matchFunctionExact(genericCandidates, parameterTypes); + match = matchFunctionGeneric(genericCandidates, parameterTypes); if (match.isPresent()) { return match; } @@ -91,6 +91,28 @@ private Optional matchFunctionExact(List candidates, Lis return matchFunction(candidates, actualParameters, false); } + private Optional matchFunctionGeneric(List candidates, List actualParameters) + { + List applicableFunctions = identifyApplicableFunctions(candidates, actualParameters, false); + if (applicableFunctions.isEmpty()) { + return Optional.empty(); + } + + if (applicableFunctions.size() == 1) { + return 
Optional.of(getOnlyElement(applicableFunctions).getBoundSignature()); + } + + List deduplicatedSignatures = applicableFunctions.stream() + .map(applicableFunction -> applicableFunction.boundSignature) + .distinct() + .collect(toImmutableList()); + if (deduplicatedSignatures.size() == 1) { + return Optional.of(getOnlyElement(deduplicatedSignatures)); + } + + throw new PrestoException(AMBIGUOUS_FUNCTION_CALL, getErrorMessage(applicableFunctions)); + } + private Optional matchFunctionWithCoercion(Collection candidates, List actualParameters) { return matchFunction(candidates, actualParameters, true); @@ -112,6 +134,11 @@ private Optional matchFunction(Collection cand return Optional.of(getOnlyElement(applicableFunctions).getBoundSignature()); } + throw new PrestoException(AMBIGUOUS_FUNCTION_CALL, getErrorMessage(applicableFunctions)); + } + + private String getErrorMessage(List applicableFunctions) + { StringBuilder errorMessageBuilder = new StringBuilder(); errorMessageBuilder.append("Could not choose a best candidate operator. Explicit type casts must be added.\n"); errorMessageBuilder.append("Candidates are:\n"); @@ -120,7 +147,8 @@ private Optional matchFunction(Collection cand errorMessageBuilder.append(function.getBoundSignature().toString()); errorMessageBuilder.append("\n"); } - throw new PrestoException(AMBIGUOUS_FUNCTION_CALL, errorMessageBuilder.toString()); + + return errorMessageBuilder.toString(); } private List identifyApplicableFunctions(Collection candidates, List actualParameters, boolean allowCoercion) @@ -203,9 +231,10 @@ private List selectMostSpecificFunctions(List selectMostSpecificFunctions(List getUnknownOnlyCastFunctions(List applicableFunction, List actualParameters) @@ -307,7 +338,7 @@ private static boolean returnsNullOnGivenInputTypes(ApplicableFunction applicabl * If there's only one SemanticException, it throws that SemanticException directly. * If there are multiple SemanticExceptions, it throws the SignatureMatchingException. 
*/ - private static void decideAndThrow(List failedExceptions, String functionName) + public static void decideAndThrow(List failedExceptions, String functionName) throws SemanticException { if (failedExceptions.size() == 1) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java index 61eae56f41895..43c0ad100f528 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleJsonModule.java @@ -23,6 +23,18 @@ public class HandleJsonModule implements Module { + private final HandleResolver handleResolver; + + public HandleJsonModule() + { + this(null); + } + + public HandleJsonModule(HandleResolver handleResolver) + { + this.handleResolver = handleResolver; + } + @Override public void configure(Binder binder) { @@ -33,12 +45,19 @@ public void configure(Binder binder) jsonBinder(binder).addModuleBinding().to(OutputTableHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(InsertTableHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(DeleteTableHandleJacksonModule.class); + jsonBinder(binder).addModuleBinding().to(MergeTableHandleJacksonModule.class); + jsonBinder(binder).addModuleBinding().to(DistributedProcedureHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(IndexHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(TransactionHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(PartitioningHandleJacksonModule.class); jsonBinder(binder).addModuleBinding().to(FunctionHandleJacksonModule.class); - jsonBinder(binder).addModuleBinding().to(MetadataUpdateJacksonModule.class); + jsonBinder(binder).addModuleBinding().to(TableFunctionJacksonHandleModule.class); - binder.bind(HandleResolver.class).in(Scopes.SINGLETON); + if (handleResolver == null) { + 
binder.bind(HandleResolver.class).in(Scopes.SINGLETON); + } + else { + binder.bind(HandleResolver.class).toInstance(handleResolver); + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java index 09992ef314575..7d039db3d8d48 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/HandleResolver.java @@ -15,12 +15,15 @@ import com.facebook.presto.connector.informationSchema.InformationSchemaHandleResolver; import com.facebook.presto.connector.system.SystemHandleResolver; +import com.facebook.presto.operator.table.ExcludeColumns; +import com.facebook.presto.operator.table.Sequence; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.ConnectorIndexHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; @@ -29,13 +32,17 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.split.EmptySplitHandleResolver; - -import javax.inject.Inject; +import 
com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Function; @@ -50,6 +57,8 @@ public class HandleResolver { private final ConcurrentMap handleResolvers = new ConcurrentHashMap<>(); private final ConcurrentMap functionHandleResolvers = new ConcurrentHashMap<>(); + private final ConcurrentMap> tableFunctionHandleResolvers = new ConcurrentHashMap<>(); + private final ConcurrentMap> tableFunctionSplitResolvers = new ConcurrentHashMap<>(); @Inject public HandleResolver() @@ -61,6 +70,17 @@ public HandleResolver() functionHandleResolvers.put("$static", new MaterializedFunctionHandleResolver(new BuiltInFunctionNamespaceHandleResolver())); functionHandleResolvers.put("$session", new MaterializedFunctionHandleResolver(new SessionFunctionHandleResolver())); + + tableFunctionHandleResolvers.put( + "$system", + new MaterializedResolver<>(() -> ImmutableSet.of( + ExcludeColumns.ExcludeColumnsFunctionHandle.class, + Sequence.SequenceFunctionHandle.class))); + + tableFunctionSplitResolvers.put( + "$system", + new MaterializedResolver<>(() -> + ImmutableSet.of(Sequence.SequenceFunctionSplit.class))); } public void addConnectorName(String name, ConnectorHandleResolver resolver) @@ -72,12 +92,39 @@ public void addConnectorName(String name, ConnectorHandleResolver resolver) "Connector '%s' is already assigned to resolver: %s", name, existingResolver); } + public void addTableFunctionNamespace(String name, TableFunctionHandleResolver resolver) + { + addNamespace(name, resolver::getTableFunctionHandleClasses, tableFunctionHandleResolvers); + } + + public void addTableFunctionSplitNamespace(String name, TableFunctionSplitResolver resolver) + { + addNamespace(name, resolver::getTableFunctionSplitClasses, tableFunctionSplitResolvers); + } + + private 
void addNamespace( + String name, + Supplier>> classSupplier, + ConcurrentMap> resolverMap) + { + requireNonNull(name, "name is null"); + requireNonNull(classSupplier, "classSupplier is null"); + + MaterializedResolver newResolver = new MaterializedResolver<>(classSupplier); + MaterializedResolver existingResolver = resolverMap.putIfAbsent(name, newResolver); + + checkState( + existingResolver == null || existingResolver.equals(newResolver), + "Name %s is already assigned to table function resolver: %s", name, existingResolver); + } + public void addFunctionNamespace(String name, FunctionHandleResolver resolver) { requireNonNull(name, "name is null"); requireNonNull(resolver, "resolver is null"); - MaterializedFunctionHandleResolver existingResolver = functionHandleResolvers.putIfAbsent(name, new MaterializedFunctionHandleResolver(resolver)); - checkState(existingResolver == null || existingResolver.equals(resolver), "Name %s is already assigned to function resolver: %s", name, existingResolver); + MaterializedFunctionHandleResolver materializedFunctionHandleResolver = new MaterializedFunctionHandleResolver(resolver); + MaterializedFunctionHandleResolver existingResolver = functionHandleResolvers.putIfAbsent(name, materializedFunctionHandleResolver); + checkState(existingResolver == null || existingResolver.equals(materializedFunctionHandleResolver), "Name %s is already assigned to function resolver: %s", name, existingResolver); } public String getId(ConnectorTableHandle tableHandle) @@ -97,6 +144,18 @@ public String getId(ColumnHandle columnHandle) public String getId(ConnectorSplit split) { + // First check if this is a table function split + for (Entry> entry : tableFunctionSplitResolvers.entrySet()) { + Optional id = entry.getValue().getClasses().stream() + .filter(clazz -> clazz.isInstance(split)) + .map(Class::getName) + .findFirst(); + if (id.isPresent()) { + return entry.getKey() + ":" + id.get(); + } + } + + // Fall back to regular connector splits return 
getId(split, MaterializedHandleResolver::getSplitClass); } @@ -120,6 +179,11 @@ public String getId(ConnectorDeleteTableHandle deleteHandle) return getId(deleteHandle, MaterializedHandleResolver::getDeleteTableHandleClass); } + public String getId(ConnectorDistributedProcedureHandle distributedProcedureHandle) + { + return getId(distributedProcedureHandle, MaterializedHandleResolver::getDistributedProcedureHandleClass); + } + public String getId(ConnectorPartitioningHandle partitioningHandle) { return getId(partitioningHandle, MaterializedHandleResolver::getPartitioningHandleClass); @@ -135,9 +199,23 @@ public String getId(FunctionHandle functionHandle) return getFunctionNamespaceId(functionHandle, MaterializedFunctionHandleResolver::getFunctionHandleClass); } - public String getId(ConnectorMetadataUpdateHandle metadataUpdateHandle) + public String getId(ConnectorMergeTableHandle mergeHandle) { - return getId(metadataUpdateHandle, MaterializedHandleResolver::getMetadataUpdateHandleClass); + return getId(mergeHandle, MaterializedHandleResolver::getMergeTableHandleClass); + } + + public String getId(ConnectorTableFunctionHandle tableFunctionHandle) + { + for (Entry> entry : tableFunctionHandleResolvers.entrySet()) { + Optional id = entry.getValue().getClasses().stream() + .filter(clazz -> clazz.isInstance(tableFunctionHandle)) + .map(Class::getName) + .findFirst(); + if (id.isPresent()) { + return entry.getKey() + ":" + id.get(); + } + } + throw new IllegalArgumentException("No function namespace for table function handle: " + tableFunctionHandle); } public Class getTableHandleClass(String id) @@ -157,7 +235,17 @@ public Class getColumnHandleClass(String id) public Class getSplitClass(String id) { - return resolverFor(id).getSplitClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); + for (Entry> entry : tableFunctionSplitResolvers.entrySet()) { + MaterializedResolver resolver = entry.getValue(); + Optional> tableFunctionSplit = 
resolver.getClasses().stream() + .filter(handle -> (entry.getKey() + ":" + handle.getName()).equals(id)) + .findFirst(); + if (tableFunctionSplit.isPresent()) { + return tableFunctionSplit.get(); + } + } + return resolverFor(id).getSplitClass() + .orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } public Class getIndexHandleClass(String id) @@ -180,6 +268,16 @@ public Class getDeleteTableHandleClass(Str return resolverFor(id).getDeleteTableHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } + public Class getMergeTableHandleClass(String id) + { + return resolverFor(id).getMergeTableHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); + } + + public Class getDistributedProcedureHandleClass(String id) + { + return resolverFor(id).getDistributedProcedureHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); + } + public Class getPartitioningHandleClass(String id) { return resolverFor(id).getPartitioningHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); @@ -195,9 +293,18 @@ public Class getFunctionHandleClass(String id) return resolverForFunctionNamespace(id).getFunctionHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); } - public Class getMetadataUpdateHandleClass(String id) + public Class getTableFunctionHandleClass(String id) { - return resolverFor(id).getMetadataUpdateHandleClass().orElseThrow(() -> new IllegalArgumentException("No resolver for " + id)); + for (Entry> entry : tableFunctionHandleResolvers.entrySet()) { + MaterializedResolver resolver = entry.getValue(); + Optional> tableFunctionHandle = resolver.getClasses().stream() + .filter(handle -> (entry.getKey() + ":" + handle.getName()).equals(id)) + .findFirst(); + if (tableFunctionHandle.isPresent()) { + return tableFunctionHandle.get(); + } + } + throw new IllegalArgumentException("No handle resolver for 
table function namespace: " + id); } private MaterializedHandleResolver resolverFor(String id) @@ -252,9 +359,11 @@ private static class MaterializedHandleResolver private final Optional> outputTableHandle; private final Optional> insertTableHandle; private final Optional> deleteTableHandle; + private final Optional> mergeTableHandle; + private final Optional> distributedProcedureHandle; private final Optional> partitioningHandle; private final Optional> transactionHandle; - private final Optional> metadataUpdateHandle; + private final Optional> tableFunctionHandle; public MaterializedHandleResolver(ConnectorHandleResolver resolver) { @@ -266,9 +375,11 @@ public MaterializedHandleResolver(ConnectorHandleResolver resolver) outputTableHandle = getHandleClass(resolver::getOutputTableHandleClass); insertTableHandle = getHandleClass(resolver::getInsertTableHandleClass); deleteTableHandle = getHandleClass(resolver::getDeleteTableHandleClass); + mergeTableHandle = getHandleClass(resolver::getMergeTableHandleClass); partitioningHandle = getHandleClass(resolver::getPartitioningHandleClass); transactionHandle = getHandleClass(resolver::getTransactionHandleClass); - metadataUpdateHandle = getHandleClass(resolver::getMetadataUpdateHandleClass); + distributedProcedureHandle = getHandleClass(resolver::getDistributedProcedureHandleClass); + tableFunctionHandle = getHandleClass(resolver::getTableFunctionHandleClass); } private static Optional> getHandleClass(Supplier> callable) @@ -321,6 +432,16 @@ public Optional> getDeleteTableHandl return deleteTableHandle; } + public Optional> getMergeTableHandleClass() + { + return mergeTableHandle; + } + + public Optional> getDistributedProcedureHandleClass() + { + return distributedProcedureHandle; + } + public Optional> getPartitioningHandleClass() { return partitioningHandle; @@ -331,9 +452,9 @@ public Optional> getTransactionHandl return transactionHandle; } - public Optional> getMetadataUpdateHandleClass() + public Optional> 
getTableFunctionHandleClass() { - return metadataUpdateHandle; + return tableFunctionHandle; } @Override @@ -354,15 +475,16 @@ public boolean equals(Object o) Objects.equals(outputTableHandle, that.outputTableHandle) && Objects.equals(insertTableHandle, that.insertTableHandle) && Objects.equals(deleteTableHandle, that.deleteTableHandle) && + Objects.equals(mergeTableHandle, that.mergeTableHandle) && Objects.equals(partitioningHandle, that.partitioningHandle) && Objects.equals(transactionHandle, that.transactionHandle) && - Objects.equals(metadataUpdateHandle, that.metadataUpdateHandle); + Objects.equals(tableFunctionHandle, that.tableFunctionHandle); } @Override public int hashCode() { - return Objects.hash(tableHandle, layoutHandle, columnHandle, split, indexHandle, outputTableHandle, insertTableHandle, deleteTableHandle, partitioningHandle, transactionHandle, metadataUpdateHandle); + return Objects.hash(tableHandle, layoutHandle, columnHandle, split, indexHandle, outputTableHandle, insertTableHandle, deleteTableHandle, mergeTableHandle, partitioningHandle, transactionHandle, tableFunctionHandle); } } @@ -409,4 +531,48 @@ public int hashCode() return Objects.hash(functionHandle); } } + + private static class MaterializedResolver + { + private final Set> classes; + + public MaterializedResolver(Supplier>> classSupplier) + { + this.classes = getSafe(classSupplier); + } + + private static Set> getSafe(Supplier>> classSupplier) + { + try { + return classSupplier.get(); + } + catch (UnsupportedOperationException e) { + return ImmutableSet.of(); + } + } + + public Set> getClasses() + { + return classes; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MaterializedResolver that = (MaterializedResolver) o; + return Objects.equals(classes, that.classes); + } + + @Override + public int hashCode() + { + return Objects.hash(classes); + } + } } diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/InMemoryNodeManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/InMemoryNodeManager.java index d99d860f59aa8..f0a3ca28578fa 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/InMemoryNodeManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/InMemoryNodeManager.java @@ -15,19 +15,20 @@ import com.facebook.presto.client.NodeVersion; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.NodeLoadMetrics; import com.facebook.presto.spi.NodeState; import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimaps; import com.google.common.collect.SetMultimap; - -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.inject.Inject; import java.net.URI; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import java.util.stream.Stream; @@ -180,4 +181,10 @@ public synchronized void removeNodeChangeListener(Consumer listener) { listeners.remove(requireNonNull(listener, "listener is null")); } + + @Override + public Optional getNodeLoadMetrics(String nodeIdentifier) + { + return Optional.empty(); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandle.java index 6cf3e0d2525d9..bb286e048185f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandle.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.metadata; +import com.facebook.drift.annotations.ThriftConstructor; +import 
com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -23,6 +26,7 @@ import static java.util.Objects.requireNonNull; +@ThriftStruct public final class InsertTableHandle { private final ConnectorId connectorId; @@ -30,6 +34,7 @@ public final class InsertTableHandle private final ConnectorInsertTableHandle connectorHandle; @JsonCreator + @ThriftConstructor public InsertTableHandle( @JsonProperty("connectorId") ConnectorId connectorId, @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, @@ -41,18 +46,21 @@ public InsertTableHandle( } @JsonProperty + @ThriftField(1) public ConnectorId getConnectorId() { return connectorId; } @JsonProperty + @ThriftField(2) public ConnectorTransactionHandle getTransactionHandle() { return transactionHandle; } @JsonProperty + @ThriftField(3) public ConnectorInsertTableHandle getConnectorHandle() { return connectorHandle; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandleJacksonModule.java index 6e6795ede2f56..5eebd9311d4bf 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/InsertTableHandleJacksonModule.java @@ -13,18 +13,45 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import 
jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import java.util.function.Function; public class InsertTableHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public InsertTableHandleJacksonModule(HandleResolver handleResolver) + public InsertTableHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorInsertTableHandle.class, handleResolver::getId, - handleResolver::getInsertTableHandleClass); + handleResolver::getInsertTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorInsertTableHandleCodec)); + } + + public InsertTableHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorInsertTableHandle.class, + handleResolver::getId, + handleResolver::getInsertTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/InternalNodeManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/InternalNodeManager.java index 81fa60c8165d7..c080dc40fd42d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/InternalNodeManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/InternalNodeManager.java @@ -14,8 +14,10 @@ package com.facebook.presto.metadata; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.NodeLoadMetrics; import com.facebook.presto.spi.NodeState; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; @@ -46,4 +48,6 @@ public interface InternalNodeManager void addNodeChangeListener(Consumer listener); void removeNodeChangeListener(Consumer listener); + + Optional 
getNodeLoadMetrics(String nodeIdentifier); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MaterializedViewPropertyManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MaterializedViewPropertyManager.java new file mode 100644 index 0000000000000..58202510e723e --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MaterializedViewPropertyManager.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import static com.facebook.presto.spi.StandardErrorCode.INVALID_MATERIALIZED_VIEW_PROPERTY; + +public class MaterializedViewPropertyManager + extends AbstractPropertyManager +{ + public MaterializedViewPropertyManager() + { + super("materialized view", INVALID_MATERIALIZED_VIEW_PROPERTY); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MergeTableHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MergeTableHandleJacksonModule.java new file mode 100644 index 0000000000000..1669a6248ad04 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MergeTableHandleJacksonModule.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.ConnectorMergeTableHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Provider; + +import javax.inject.Inject; + +import java.util.Optional; +import java.util.function.Function; + +public class MergeTableHandleJacksonModule + extends AbstractTypedJacksonModule +{ + @Inject + public MergeTableHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) + { + super(ConnectorMergeTableHandle.class, + handleResolver::getId, + handleResolver::getMergeTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorMergeTableHandleCodec)); + } + + public MergeTableHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorMergeTableHandle.class, + handleResolver::getId, + handleResolver::getMergeTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java 
b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java index a316277f8a96d..41c53f31d775f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/Metadata.java @@ -20,7 +20,7 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; -import com.facebook.presto.execution.QueryManager; +import com.facebook.presto.metadata.Catalog.CatalogContext; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; @@ -28,9 +28,10 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.MaterializedViewStatus; +import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.NewTableLayout; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.TableLayoutFilterCoverage; @@ -42,9 +43,12 @@ import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; import com.facebook.presto.spi.connector.ConnectorTableVersion; +import com.facebook.presto.spi.connector.RowChangeParadigm; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import 
com.facebook.presto.spi.security.GrantInfo; @@ -54,6 +58,7 @@ import com.facebook.presto.spi.statistics.ComputedStatistics; import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.spi.statistics.TableStatisticsMetadata; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; @@ -74,6 +79,8 @@ public interface Metadata void registerBuiltInFunctions(List functions); + void registerConnectorFunctions(String catalogName, List functionInfos); + List listSchemaNames(Session session, String catalogName); /** @@ -309,12 +316,12 @@ public interface Metadata /** * Get the row ID column handle used with UpdatablePageSource#deleteRows. */ - ColumnHandle getDeleteRowIdColumnHandle(Session session, TableHandle tableHandle); + Optional getDeleteRowIdColumn(Session session, TableHandle tableHandle); /** * Get the row ID column handle used with UpdatablePageSource. */ - ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tableHandle, List updatedColumns); + Optional getUpdateRowIdColumn(Session session, TableHandle tableHandle, List updatedColumns); /** * @return whether delete without table scan is supported @@ -336,7 +343,18 @@ public interface Metadata /** * Finish delete query */ - void finishDelete(Session session, DeleteTableHandle tableHandle, Collection fragments); + Optional finishDeleteWithOutput(Session session, DeleteTableHandle tableHandle, Collection fragments); + + /** + * Begin call distributed procedure + */ + DistributedProcedureHandle beginCallDistributedProcedure(Session session, QualifiedObjectName procedureName, + TableHandle tableHandle, Object[] arguments, boolean sourceTableEliminated); + + /** + * Finish call distributed procedure + */ + void finishCallDistributedProcedure(Session session, DistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments); /** * Begin update query @@ -348,6 
+366,29 @@ public interface Metadata */ void finishUpdate(Session session, TableHandle tableHandle, Collection fragments); + /** + * Return the row update paradigm supported by the connector on the table or throw + * an exception if row change is not supported. + */ + RowChangeParadigm getRowChangeParadigm(Session session, TableHandle tableHandle); + + /** + * Get the column handle that will generate row IDs for the merge operation. + * These IDs will be passed to the {@code storeMergedRows()} method of the + * {@link com.facebook.presto.spi.ConnectorMergeSink} that created them. + */ + ColumnHandle getMergeTargetTableRowIdColumnHandle(Session session, TableHandle tableHandle); + + /** + * Begin merge query + */ + MergeHandle beginMerge(Session session, TableHandle tableHandle); + + /** + * Finish merge query + */ + void finishMerge(Session session, MergeHandle tableHandle, Collection fragments, Collection computedStatistics); + /** * Returns a connector id for the specified catalog name. */ @@ -360,6 +401,11 @@ public interface Metadata */ Map getCatalogNames(Session session); + default Map getCatalogNamesWithConnectorContext(Session session) + { + return ImmutableMap.of(); + } + /** * Get the names that match the specified table prefix (never null). */ @@ -395,6 +441,19 @@ public interface Metadata */ void dropMaterializedView(Session session, QualifiedObjectName viewName); + /** + * List materialized views in the specified schema prefix. + */ + List listMaterializedViews(Session session, QualifiedTablePrefix prefix); + + /** + * Get materialized view definitions for all materialized views matching the prefix. + * This is used by information_schema to efficiently retrieve view definitions. 
+ */ + Map getMaterializedViews( + Session session, + QualifiedTablePrefix prefix); + /** * Begin refresh materialized view */ @@ -410,6 +469,11 @@ public interface Metadata */ List getReferencedMaterializedViews(Session session, QualifiedObjectName tableName); + /** + * Gets the status of a materialized view (freshness state) + */ + MaterializedViewStatus getMaterializedViewStatus(Session session, QualifiedObjectName viewName, TupleDomain baseQueryDomain); + /** * Try to locate a table index that can lookup results by indexableColumns and provide the requested outputColumns. */ @@ -494,8 +558,6 @@ public interface Metadata @Experimental ListenableFuture commitPageSinkAsync(Session session, DeleteTableHandle tableHandle, Collection fragments); - MetadataUpdates getMetadataUpdateResults(Session session, QueryManager queryManager, MetadataUpdates metadataUpdates, QueryId queryId); - // TODO: metadata should not provide FunctionAndTypeManager FunctionAndTypeManager getFunctionAndTypeManager(); @@ -509,6 +571,8 @@ public interface Metadata TablePropertyManager getTablePropertyManager(); + MaterializedViewPropertyManager getMaterializedViewPropertyManager(); + ColumnPropertyManager getColumnPropertyManager(); AnalyzePropertyManager getAnalyzePropertyManager(); @@ -527,6 +591,10 @@ default TableLayoutFilterCoverage getTableLayoutFilterCoverage(Session session, return NOT_APPLICABLE; } + void dropBranch(Session session, TableHandle tableHandle, String branchName, boolean branchExists); + + void dropTag(Session session, TableHandle tableHandle, String tagName, boolean tagExists); + void dropConstraint(Session session, TableHandle tableHandle, Optional constraintName, Optional columnName); void addConstraint(Session session, TableHandle tableHandle, TableConstraint tableConstraint); @@ -540,4 +608,13 @@ default boolean isPushdownSupportedForFilter(Session session, TableHandle tableH { return false; } + + String normalizeIdentifier(Session session, String catalogName, String 
identifier); + + /** + * Attempt to push down the table function invocation into the connector. + * @return {@link Optional#empty()} if the connector doesn't support table function invocation pushdown, + * or an {@code Optional>} containing the table handle that will be used in place of the table function invocation. + */ + Optional> applyTableFunction(Session session, TableFunctionHandle handle); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataListing.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataListing.java index 15aff30e715c8..d71bafb9adf05 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataListing.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataListing.java @@ -14,6 +14,8 @@ package com.facebook.presto.metadata; import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.metadata.Catalog.CatalogContext; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.SchemaTableName; @@ -53,6 +55,20 @@ public static SortedMap listCatalogs(Session session, Metad return result.build(); } + public static SortedMap listCatalogsWithConnectorContext(Session session, Metadata metadata, AccessControl accessControl) + { + Map catalogNamesWithConnectorContext = metadata.getCatalogNamesWithConnectorContext(session); + Set allowedCatalogs = accessControl.filterCatalogs(session.getIdentity(), session.getAccessControlContext(), catalogNamesWithConnectorContext.keySet()); + + ImmutableSortedMap.Builder result = ImmutableSortedMap.naturalOrder(); + for (Map.Entry entry : catalogNamesWithConnectorContext.entrySet()) { + if (allowedCatalogs.contains(entry.getKey())) { + result.put(entry); + } + } + return result.build(); + } + public static SortedSet listSchemas(Session session, Metadata metadata, AccessControl accessControl, String 
catalogName) { Set schemaNames = ImmutableSet.copyOf(metadata.listSchemaNames(session, catalogName)); @@ -75,6 +91,14 @@ public static Set listViews(Session session, Metadata metadata, return accessControl.filterTables(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), prefix.getCatalogName(), tableNames); } + public static Set listMaterializedViews(Session session, Metadata metadata, AccessControl accessControl, QualifiedTablePrefix prefix) + { + Set tableNames = metadata.listMaterializedViews(session, prefix).stream() + .map(MetadataUtil::toSchemaTableName) + .collect(toImmutableSet()); + return accessControl.filterTables(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), prefix.getCatalogName(), tableNames); + } + public static Set listTablePrivileges(Session session, Metadata metadata, AccessControl accessControl, QualifiedTablePrefix prefix) { List grants = metadata.listTablePrivileges(session, prefix); @@ -104,7 +128,15 @@ public static Map> listTableColumns(Sessio ImmutableMap.Builder> result = ImmutableMap.builder(); for (Entry> entry : tableColumns.entrySet()) { if (allowedTables.contains(entry.getKey())) { - result.put(entry); + result.put(entry.getKey(), accessControl.filterColumns( + session.getRequiredTransactionId(), + session.getIdentity(), + session.getAccessControlContext(), + new QualifiedObjectName( + prefix.getCatalogName(), + entry.getKey().getSchemaName(), + entry.getKey().getTableName()), + entry.getValue())); } } return result.build(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java index 9e1f0bcb41651..6c8ef22ebb878 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataManager.java @@ -26,13 +26,14 @@ import 
com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; -import com.facebook.presto.execution.QueryManager; +import com.facebook.presto.metadata.Catalog.CatalogContext; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorInsertTableHandle; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorResolvedIndex; import com.facebook.presto.spi.ConnectorSession; @@ -44,6 +45,7 @@ import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.MaterializedViewStatus; +import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.NewTableLayout; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.QueryId; @@ -62,9 +64,12 @@ import com.facebook.presto.spi.connector.ConnectorPartitioningMetadata; import com.facebook.presto.spi.connector.ConnectorTableVersion; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.connector.RowChangeParadigm; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.security.GrantInfo; @@ -78,6 
+83,7 @@ import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; +import com.facebook.presto.testing.TestProcedureRegistry; import com.facebook.presto.transaction.TransactionManager; import com.facebook.presto.type.TypeDeserializer; import com.google.common.annotations.VisibleForTesting; @@ -90,8 +96,7 @@ import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.Collection; @@ -99,7 +104,6 @@ import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -107,9 +111,11 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; import static com.facebook.presto.SystemSessionProperties.isIgnoreStatsCalculatorFailures; +import static com.facebook.presto.common.RuntimeMetricName.GET_IDENTIFIER_NORMALIZATION_TIME_NANOS; import static com.facebook.presto.common.RuntimeMetricName.GET_LAYOUT_TIME_NANOS; import static com.facebook.presto.common.RuntimeMetricName.GET_MATERIALIZED_VIEW_STATUS_TIME_NANOS; import static com.facebook.presto.common.RuntimeUnit.NANO; @@ -155,6 +161,7 @@ public class MetadataManager private final SessionPropertyManager sessionPropertyManager; private final SchemaPropertyManager schemaPropertyManager; private final TablePropertyManager tablePropertyManager; + private final MaterializedViewPropertyManager materializedViewPropertyManager; private final ColumnPropertyManager columnPropertyManager; private final AnalyzePropertyManager analyzePropertyManager; 
private final TransactionManager transactionManager; @@ -169,9 +176,36 @@ public MetadataManager( SessionPropertyManager sessionPropertyManager, SchemaPropertyManager schemaPropertyManager, TablePropertyManager tablePropertyManager, + MaterializedViewPropertyManager materializedViewPropertyManager, ColumnPropertyManager columnPropertyManager, AnalyzePropertyManager analyzePropertyManager, TransactionManager transactionManager) + { + this( + functionAndTypeManager, + blockEncodingSerde, + sessionPropertyManager, + schemaPropertyManager, + tablePropertyManager, + materializedViewPropertyManager, + columnPropertyManager, + analyzePropertyManager, + transactionManager, + new BuiltInProcedureRegistry(functionAndTypeManager)); + } + + @VisibleForTesting + public MetadataManager( + FunctionAndTypeManager functionAndTypeManager, + BlockEncodingSerde blockEncodingSerde, + SessionPropertyManager sessionPropertyManager, + SchemaPropertyManager schemaPropertyManager, + TablePropertyManager tablePropertyManager, + MaterializedViewPropertyManager materializedViewPropertyManager, + ColumnPropertyManager columnPropertyManager, + AnalyzePropertyManager analyzePropertyManager, + TransactionManager transactionManager, + ProcedureRegistry procedureRegistry) { this( createTestingViewCodec(functionAndTypeManager), @@ -179,10 +213,12 @@ public MetadataManager( sessionPropertyManager, schemaPropertyManager, tablePropertyManager, + materializedViewPropertyManager, columnPropertyManager, analyzePropertyManager, transactionManager, - functionAndTypeManager); + functionAndTypeManager, + procedureRegistry); } @Inject @@ -192,21 +228,24 @@ public MetadataManager( SessionPropertyManager sessionPropertyManager, SchemaPropertyManager schemaPropertyManager, TablePropertyManager tablePropertyManager, + MaterializedViewPropertyManager materializedViewPropertyManager, ColumnPropertyManager columnPropertyManager, AnalyzePropertyManager analyzePropertyManager, TransactionManager transactionManager, - 
FunctionAndTypeManager functionAndTypeManager) + FunctionAndTypeManager functionAndTypeManager, + ProcedureRegistry procedureRegistry) { this.viewCodec = requireNonNull(viewCodec, "viewCodec is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); this.schemaPropertyManager = requireNonNull(schemaPropertyManager, "schemaPropertyManager is null"); this.tablePropertyManager = requireNonNull(tablePropertyManager, "tablePropertyManager is null"); + this.materializedViewPropertyManager = requireNonNull(materializedViewPropertyManager, "materializedViewPropertyManager is null"); this.columnPropertyManager = requireNonNull(columnPropertyManager, "columnPropertyManager is null"); this.analyzePropertyManager = requireNonNull(analyzePropertyManager, "analyzePropertyManager is null"); this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionManager is null"); - this.procedures = new ProcedureRegistry(functionAndTypeManager); + this.procedures = requireNonNull(procedureRegistry, "procedureRegistry is null"); verifyComparableOrderableContract(); } @@ -245,16 +284,49 @@ public static MetadataManager createTestMetadataManager(TransactionManager trans { BlockEncodingManager blockEncodingManager = new BlockEncodingManager(); return new MetadataManager( - new FunctionAndTypeManager(transactionManager, blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()), + new FunctionAndTypeManager(transactionManager, new TableFunctionRegistry(), blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()), blockEncodingManager, createTestingSessionPropertyManager(), new SchemaPropertyManager(), new TablePropertyManager(), + new MaterializedViewPropertyManager(), 
new ColumnPropertyManager(), new AnalyzePropertyManager(), transactionManager); } + public static MetadataManager createTestMetadataManager(TransactionManager transactionManager, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, ProcedureRegistry procedureRegistry) + { + BlockEncodingManager blockEncodingManager = new BlockEncodingManager(); + return new MetadataManager( + new FunctionAndTypeManager(transactionManager, new TableFunctionRegistry(), blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()), + blockEncodingManager, + createTestingSessionPropertyManager(), + new SchemaPropertyManager(), + new TablePropertyManager(), + new MaterializedViewPropertyManager(), + new ColumnPropertyManager(), + new AnalyzePropertyManager(), + transactionManager, + procedureRegistry); + } + + public static MetadataManager createTestMetadataManager(FunctionAndTypeManager functionAndTypeManager) + { + BlockEncodingManager blockEncodingManager = new BlockEncodingManager(); + return new MetadataManager( + functionAndTypeManager, + blockEncodingManager, + createTestingSessionPropertyManager(), + new SchemaPropertyManager(), + new TablePropertyManager(), + new MaterializedViewPropertyManager(), + new ColumnPropertyManager(), + new AnalyzePropertyManager(), + functionAndTypeManager.getTransactionManager(), + new TestProcedureRegistry()); + } + @Override public final void verifyComparableOrderableContract() { @@ -304,6 +376,12 @@ public void registerBuiltInFunctions(List functionInfos) functionAndTypeManager.registerBuiltInFunctions(functionInfos); } + @Override + public void registerConnectorFunctions(String catalogName, List functionInfos) + { + functionAndTypeManager.registerConnectorFunctions(catalogName, functionInfos); + } + @Override public List listSchemaNames(Session session, String catalogName) { @@ -316,7 +394,7 @@ public List listSchemaNames(Session session, String catalogName) for (ConnectorId connectorId : 
catalogMetadata.listConnectorIds()) { ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); metadata.listSchemaNames(connectorSession).stream() - .map(schema -> schema.toLowerCase(Locale.ENGLISH)) + .map(schema -> normalizeIdentifier(session, connectorId.getCatalogName(), schema)) .forEach(schemaNames::add); } } @@ -349,7 +427,7 @@ public Optional getTableHandleForStatisticsCollection(Session sessi ConnectorId connectorId = catalogMetadata.getConnectorId(session, table); ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); - ConnectorTableHandle tableHandle = metadata.getTableHandleForStatisticsCollection(session.toConnectorSession(connectorId), toSchemaTableName(table), analyzeProperties); + ConnectorTableHandle tableHandle = metadata.getTableHandleForStatisticsCollection(session.toConnectorSession(connectorId), toSchemaTableName(table.getSchemaName(), table.getObjectName()), analyzeProperties); if (tableHandle != null) { return Optional.of(new TableHandle( connectorId, @@ -381,7 +459,7 @@ public Optional getSystemTable(Session session, QualifiedObjectName ConnectorId connectorId = catalogMetadata.getConnectorId(); ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); - return metadata.getSystemTable(session.toConnectorSession(connectorId), toSchemaTableName(tableName)); + return metadata.getSystemTable(session.toConnectorSession(connectorId), toSchemaTableName(tableName.getSchemaName(), tableName.getObjectName())); } return Optional.empty(); } @@ -390,7 +468,6 @@ public Optional getSystemTable(Session session, QualifiedObjectName public TableLayoutResult getLayout(Session session, TableHandle table, Constraint constraint, Optional> desiredColumns) { long startTime = System.nanoTime(); - checkArgument(!constraint.getSummary().isNone(), "Cannot get Layout if constraint is none"); ConnectorId connectorId = table.getConnectorId(); ConnectorTableHandle connectorTable = table.getConnectorHandle(); @@ -530,7 
+607,7 @@ public Map getColumnHandles(Session session, TableHandle t ImmutableMap.Builder map = ImmutableMap.builder(); for (Entry mapEntry : handles.entrySet()) { - map.put(mapEntry.getKey().toLowerCase(ENGLISH), mapEntry.getValue()); + map.put(normalizeIdentifier(session, connectorId.getCatalogName(), mapEntry.getKey()), mapEntry.getValue()); } return map.build(); } @@ -543,7 +620,10 @@ public ColumnMetadata getColumnMetadata(Session session, TableHandle tableHandle ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadata(session, connectorId); - return metadata.getColumnMetadata(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), columnHandle); + ColumnMetadata columnMetadata = metadata.getColumnMetadata(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), columnHandle); + ColumnMetadata normalizedColumnMetadata = normalizedColumnMetadata(session, connectorId.getCatalogName(), columnMetadata); + + return normalizedColumnMetadata; } @Override @@ -564,12 +644,15 @@ public List listTables(Session session, QualifiedTablePrefi Set tables = new LinkedHashSet<>(); if (catalog.isPresent()) { CatalogMetadata catalogMetadata = catalog.get(); - for (ConnectorId connectorId : catalogMetadata.listConnectorIds()) { ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); ConnectorSession connectorSession = session.toConnectorSession(connectorId); metadata.listTables(connectorSession, prefix.getSchemaName()).stream() .map(convertFromSchemaTableName(prefix.getCatalogName())) + .map(name -> new QualifiedObjectName( + name.getCatalogName(), + normalizeIdentifier(session, connectorId.getCatalogName(), name.getSchemaName()), + normalizeIdentifier(session, connectorId.getCatalogName(), name.getObjectName()))) .filter(prefix::matches) .forEach(tables::add); } @@ -597,7 +680,12 @@ public Map> listTableColumns(Session s prefix.getCatalogName(), entry.getKey().getSchemaName(), 
entry.getKey().getTableName()); - tableColumns.put(tableName, entry.getValue()); + + ImmutableList.Builder normalizedColumns = ImmutableList.builder(); + for (ColumnMetadata column : entry.getValue()) { + normalizedColumns.add(normalizedColumnMetadata(session, connectorId.getCatalogName(), column)); + } + tableColumns.put(tableName, normalizedColumns.build()); } // if table and view names overlap, the view wins @@ -610,7 +698,7 @@ public Map> listTableColumns(Session s ImmutableList.Builder columns = ImmutableList.builder(); for (ViewColumn column : deserializeView(entry.getValue().getViewData()).getColumns()) { columns.add(ColumnMetadata.builder() - .setName(column.getName()) + .setName(normalizeIdentifier(session, connectorId.getCatalogName(), column.getName())) .setType(column.getType()) .build()); } @@ -664,9 +752,13 @@ public TableHandle createTemporaryTable(Session session, String catalogName, Lis CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, catalogName); ConnectorId connectorId = catalogMetadata.getConnectorId(); ConnectorMetadata metadata = catalogMetadata.getMetadata(); + List normalizedColumns = columns.stream() + .map(column -> normalizedColumnMetadata(session, connectorId.getCatalogName(), column)) + .collect(Collectors.toList()); + ConnectorTableHandle connectorTableHandle = metadata.createTemporaryTable( session.toConnectorSession(connectorId), - columns, + normalizedColumns, partitioningMetadata.map(partitioning -> createConnectorPartitioningMetadata(connectorId, partitioning))); return new TableHandle(connectorId, connectorTableHandle, catalogMetadata.getTransactionHandleFor(connectorId), Optional.empty()); } @@ -694,7 +786,9 @@ public void renameTable(Session session, TableHandle tableHandle, QualifiedObjec } ConnectorMetadata metadata = catalogMetadata.getMetadata(); - metadata.renameTable(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), toSchemaTableName(newTableName)); + + 
metadata.renameTable(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), + toSchemaTableName(newTableName.getSchemaName(), newTableName.getObjectName())); } @Override @@ -710,7 +804,7 @@ public void renameColumn(Session session, TableHandle tableHandle, ColumnHandle { ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadataForWrite(session, connectorId); - metadata.renameColumn(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), source, target.toLowerCase(ENGLISH)); + metadata.renameColumn(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), source, normalizeIdentifier(session, connectorId.getCatalogName(), target)); } @Override @@ -883,19 +977,27 @@ public Optional finishInsert(Session session, InsertTab } @Override - public ColumnHandle getDeleteRowIdColumnHandle(Session session, TableHandle tableHandle) + public Optional getDeleteRowIdColumn(Session session, TableHandle tableHandle) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + return metadata.getDeleteRowIdColumn(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle()); + } + + @Override + public Optional getUpdateRowIdColumn(Session session, TableHandle tableHandle, List updatedColumns) { ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadata(session, connectorId); - return metadata.getDeleteRowIdColumnHandle(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle()); + return metadata.getUpdateRowIdColumn(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), updatedColumns); } @Override - public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tableHandle, List updatedColumns) + public ColumnHandle getMergeTargetTableRowIdColumnHandle(Session session, TableHandle tableHandle) { ConnectorId connectorId = 
tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadata(session, connectorId); - return metadata.getUpdateRowIdColumnHandle(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), updatedColumns); + return metadata.getMergeTargetTableRowIdColumnHandle(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle()); } @Override @@ -930,11 +1032,47 @@ public DeleteTableHandle beginDelete(Session session, TableHandle tableHandle) } @Override - public void finishDelete(Session session, DeleteTableHandle tableHandle, Collection fragments) + public Optional finishDeleteWithOutput(Session session, DeleteTableHandle tableHandle, Collection fragments) { ConnectorId connectorId = tableHandle.getConnectorId(); ConnectorMetadata metadata = getMetadata(session, connectorId); - metadata.finishDelete(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), fragments); + return metadata.finishDeleteWithOutput(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), fragments); + } + + @Override + public DistributedProcedureHandle beginCallDistributedProcedure(Session session, QualifiedObjectName procedureName, + TableHandle tableHandle, Object[] arguments, + boolean sourceTableEliminated) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, connectorId); + + ConnectorTableLayoutHandle layout; + if (!tableHandle.getLayout().isPresent()) { + TableLayoutResult result = getLayout(session, tableHandle, sourceTableEliminated ? 
Constraint.alwaysFalse() : Constraint.alwaysTrue(), Optional.empty()); + layout = result.getLayout().getLayoutHandle(); + } + else { + layout = tableHandle.getLayout().get(); + } + + ConnectorDistributedProcedureHandle procedureHandle = catalogMetadata.getMetadata().beginCallDistributedProcedure( + session.toConnectorSession(connectorId), + procedureName, + layout, + arguments); + return new DistributedProcedureHandle( + tableHandle.getConnectorId(), + tableHandle.getTransaction(), + procedureHandle); + } + + @Override + public void finishCallDistributedProcedure(Session session, DistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments) + { + ConnectorId connectorId = procedureHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + metadata.finishCallDistributedProcedure(session.toConnectorSession(connectorId), procedureHandle.getConnectorHandle(), procedureName, fragments); } @Override @@ -954,6 +1092,35 @@ public void finishUpdate(Session session, TableHandle tableHandle, Collection fragments, + Collection computedStatistics) + { + ConnectorId connectorId = mergeHandle.getTableHandle().getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + metadata.finishMerge(session.toConnectorSession(connectorId), mergeHandle.getConnectorMergeTableHandle(), fragments, computedStatistics); + } + @Override public Optional getCatalogHandle(Session session, String catalogName) { @@ -966,6 +1133,12 @@ public Map getCatalogNames(Session session) return transactionManager.getCatalogNames(session.getRequiredTransactionId()); } + @Override + public Map getCatalogNamesWithConnectorContext(Session session) + { + return transactionManager.getCatalogNamesWithConnectorContext(session.getRequiredTransactionId()); + } + @Override public List listViews(Session session, QualifiedTablePrefix prefix) { @@ -982,6 +1155,10 @@ public List listViews(Session session, QualifiedTablePrefix 
ConnectorSession connectorSession = session.toConnectorSession(connectorId); metadata.listViews(connectorSession, prefix.getSchemaName()).stream() .map(convertFromSchemaTableName(prefix.getCatalogName())) + .map(name -> new QualifiedObjectName( + name.getCatalogName(), + normalizeIdentifier(session, connectorId.getCatalogName(), name.getSchemaName()), + normalizeIdentifier(session, connectorId.getCatalogName(), name.getObjectName()))) .filter(prefix::matches) .forEach(views::add); } @@ -1033,7 +1210,9 @@ public void renameView(Session session, QualifiedObjectName source, QualifiedObj ConnectorId connectorId = catalogMetadata.getConnectorId(); ConnectorMetadata metadata = catalogMetadata.getMetadata(); - metadata.renameView(session.toConnectorSession(connectorId), toSchemaTableName(source), toSchemaTableName(target)); + metadata.renameView(session.toConnectorSession(connectorId), + toSchemaTableName(source.getSchemaName(), source.getObjectName()), + toSchemaTableName(target.getSchemaName(), target.getObjectName())); } @Override @@ -1043,7 +1222,7 @@ public void dropView(Session session, QualifiedObjectName viewName) ConnectorId connectorId = catalogMetadata.getConnectorId(); ConnectorMetadata metadata = catalogMetadata.getMetadata(); - metadata.dropView(session.toConnectorSession(connectorId), toSchemaTableName(viewName)); + metadata.dropView(session.toConnectorSession(connectorId), toSchemaTableName(viewName.getSchemaName(), viewName.getObjectName())); } @Override @@ -1063,10 +1242,102 @@ public void dropMaterializedView(Session session, QualifiedObjectName viewName) ConnectorId connectorId = catalogMetadata.getConnectorId(); ConnectorMetadata metadata = catalogMetadata.getMetadata(); - metadata.dropMaterializedView(session.toConnectorSession(connectorId), toSchemaTableName(viewName)); + metadata.dropMaterializedView(session.toConnectorSession(connectorId), toSchemaTableName(viewName.getSchemaName(), viewName.getObjectName())); } - private MaterializedViewStatus 
getMaterializedViewStatus(Session session, QualifiedObjectName materializedViewName, TupleDomain baseQueryDomain) + @Override + public List listMaterializedViews(Session session, QualifiedTablePrefix prefix) + { + requireNonNull(prefix, "prefix is null"); + + Optional catalog = getOptionalCatalogMetadata(session, transactionManager, prefix.getCatalogName()); + Set materializedViews = new LinkedHashSet<>(); + if (catalog.isPresent()) { + CatalogMetadata catalogMetadata = catalog.get(); + ConnectorId connectorId = catalogMetadata.getConnectorId(); + ConnectorMetadata metadata = catalogMetadata.getMetadata(); + ConnectorSession connectorSession = session.toConnectorSession(connectorId); + + List viewNames; + if (prefix.getSchemaName().isPresent()) { + viewNames = metadata.listMaterializedViews(connectorSession, prefix.getSchemaName().get()); + } + else { + viewNames = new ArrayList<>(); + for (String schemaName : metadata.listSchemaNames(connectorSession)) { + viewNames.addAll(metadata.listMaterializedViews(connectorSession, schemaName)); + } + } + + // Convert to QualifiedObjectName and filter by prefix + for (SchemaTableName viewName : viewNames) { + QualifiedObjectName qualifiedName = new QualifiedObjectName( + prefix.getCatalogName(), + viewName.getSchemaName(), + viewName.getTableName()); + if (prefix.matches(qualifiedName)) { + materializedViews.add(qualifiedName); + } + } + } + + return ImmutableList.copyOf(materializedViews); + } + + @Override + public Map getMaterializedViews( + Session session, + QualifiedTablePrefix prefix) + { + requireNonNull(prefix, "prefix is null"); + + Optional catalog = getOptionalCatalogMetadata(session, transactionManager, prefix.getCatalogName()); + Map views = new LinkedHashMap<>(); + + if (catalog.isPresent()) { + CatalogMetadata catalogMetadata = catalog.get(); + ConnectorId connectorId = catalogMetadata.getConnectorId(); + ConnectorMetadata metadata = catalogMetadata.getMetadata(); + ConnectorSession connectorSession = 
session.toConnectorSession(connectorId); + + List viewNames; + if (prefix.getSchemaName().isPresent()) { + viewNames = metadata.listMaterializedViews(connectorSession, prefix.getSchemaName().get()); + + if (prefix.getTableName().isPresent()) { + String tableName = prefix.getTableName().get(); + viewNames = viewNames.stream() + .filter(name -> name.getTableName().equals(tableName)) + .collect(toImmutableList()); + } + } + else { + viewNames = new ArrayList<>(); + for (String schemaName : metadata.listSchemaNames(connectorSession)) { + viewNames.addAll(metadata.listMaterializedViews(connectorSession, schemaName)); + } + } + + // Bulk retrieve definitions + if (!viewNames.isEmpty()) { + Map definitions = metadata.getMaterializedViews(connectorSession, viewNames); + + definitions.forEach((viewName, definition) -> { + views.put( + new QualifiedObjectName( + prefix.getCatalogName(), + viewName.getSchemaName(), + viewName.getTableName()), + definition); + }); + } + } + + return ImmutableMap.copyOf(views); + } + + @Override + public MaterializedViewStatus getMaterializedViewStatus(Session session, QualifiedObjectName materializedViewName, TupleDomain baseQueryDomain) { Optional materializedViewHandle = getOptionalTableHandle(session, transactionManager, materializedViewName, Optional.empty()); @@ -1075,7 +1346,9 @@ private MaterializedViewStatus getMaterializedViewStatus(Session session, Qualif return session.getRuntimeStats().recordWallTime( GET_MATERIALIZED_VIEW_STATUS_TIME_NANOS, - () -> metadata.getMaterializedViewStatus(session.toConnectorSession(connectorId), toSchemaTableName(materializedViewName), baseQueryDomain)); + () -> metadata.getMaterializedViewStatus(session.toConnectorSession(connectorId), + toSchemaTableName(materializedViewName.getSchemaName(), materializedViewName.getObjectName()), + baseQueryDomain)); } @Override @@ -1106,7 +1379,8 @@ public List getReferencedMaterializedViews(Session session, if (catalog.isPresent()) { ConnectorMetadata metadata = 
catalog.get().getMetadata(); ConnectorSession connectorSession = session.toConnectorSession(catalog.get().getConnectorId()); - Optional> materializedViews = metadata.getReferencedMaterializedViews(connectorSession, toSchemaTableName(tableName)); + + Optional> materializedViews = metadata.getReferencedMaterializedViews(connectorSession, toSchemaTableName(tableName.getSchemaName(), tableName.getObjectName())); if (materializedViews.isPresent()) { return materializedViews.get().stream().map(convertFromSchemaTableName(tableName.getCatalogName())).collect(toImmutableList()); } @@ -1227,7 +1501,7 @@ public void grantTablePrivileges(Session session, QualifiedObjectName tableName, ConnectorId connectorId = catalogMetadata.getConnectorId(); ConnectorMetadata metadata = catalogMetadata.getMetadata(); - metadata.grantTablePrivileges(session.toConnectorSession(connectorId), toSchemaTableName(tableName), privileges, grantee, grantOption); + metadata.grantTablePrivileges(session.toConnectorSession(connectorId), toSchemaTableName(tableName.getSchemaName(), tableName.getObjectName()), privileges, grantee, grantOption); } @Override @@ -1237,7 +1511,7 @@ public void revokeTablePrivileges(Session session, QualifiedObjectName tableName ConnectorId connectorId = catalogMetadata.getConnectorId(); ConnectorMetadata metadata = catalogMetadata.getMetadata(); - metadata.revokeTablePrivileges(session.toConnectorSession(connectorId), toSchemaTableName(tableName), privileges, grantee, grantOption); + metadata.revokeTablePrivileges(session.toConnectorSession(connectorId), toSchemaTableName(tableName.getSchemaName(), tableName.getObjectName()), privileges, grantee, grantOption); } @Override @@ -1293,28 +1567,6 @@ public ListenableFuture commitPageSinkAsync(Session session, DeleteTableHa return toListenableFuture(metadata.commitPageSinkAsync(connectorSession, tableHandle.getConnectorHandle(), fragments)); } - @Override - public MetadataUpdates getMetadataUpdateResults(Session session, 
QueryManager queryManager, MetadataUpdates metadataUpdateRequests, QueryId queryId) - { - ConnectorId connectorId = metadataUpdateRequests.getConnectorId(); - ConnectorMetadata metadata = getCatalogMetadata(session, connectorId).getMetadata(); - - if (queryManager != null && !queriesWithRegisteredCallbacks.contains(queryId)) { - // This is the first time we are getting requests for queryId. - // Register a callback, so the we do the cleanup when query fails/finishes. - queryManager.addStateChangeListener(queryId, state -> { - if (state.isDone()) { - metadata.doMetadataUpdateCleanup(queryId); - queriesWithRegisteredCallbacks.remove(queryId); - } - }); - queriesWithRegisteredCallbacks.add(queryId); - } - - List metadataResults = metadata.getMetadataUpdateResults(metadataUpdateRequests.getMetadataUpdates(), queryId); - return new MetadataUpdates(connectorId, metadataResults); - } - @Override public FunctionAndTypeManager getFunctionAndTypeManager() { @@ -1352,6 +1604,12 @@ public TablePropertyManager getTablePropertyManager() return tablePropertyManager; } + @Override + public MaterializedViewPropertyManager getMaterializedViewPropertyManager() + { + return materializedViewPropertyManager; + } + @Override public ColumnPropertyManager getColumnPropertyManager() { @@ -1423,8 +1681,8 @@ public Optional getView(QualifiedObjectName viewName) Map views = metadata.getViews( session.toConnectorSession(connectorId), - toSchemaTableName(viewName).toSchemaTablePrefix()); - ConnectorViewDefinition view = views.get(toSchemaTableName(viewName)); + toSchemaTableName(viewName.getSchemaName(), viewName.getObjectName()).toSchemaTablePrefix()); + ConnectorViewDefinition view = views.get(toSchemaTableName(viewName.getSchemaName(), viewName.getObjectName())); if (view != null) { ViewDefinition definition = deserializeView(view.getViewData()); if (view.getOwner().isPresent() && !definition.isRunAsInvoker()) { @@ -1445,7 +1703,7 @@ public Optional getMaterializedView(QualifiedObjectN 
ConnectorId connectorId = catalogMetadata.getConnectorId(session, viewName); ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); - return metadata.getMaterializedView(session.toConnectorSession(connectorId), toSchemaTableName(viewName)); + return metadata.getMaterializedView(session.toConnectorSession(connectorId), toSchemaTableName(viewName.getSchemaName(), viewName.getObjectName())); } return Optional.empty(); } @@ -1480,6 +1738,22 @@ public TableLayoutFilterCoverage getTableLayoutFilterCoverage(Session session, T return metadata.getTableLayoutFilterCoverage(tableHandle.getLayout().get(), relevantPartitionColumns); } + @Override + public void dropBranch(Session session, TableHandle tableHandle, String branchName, boolean branchExists) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadataForWrite(session, connectorId); + metadata.dropBranch(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), branchName, branchExists); + } + + @Override + public void dropTag(Session session, TableHandle tableHandle, String tagName, boolean tagExists) + { + ConnectorId connectorId = tableHandle.getConnectorId(); + ConnectorMetadata metadata = getMetadataForWrite(session, connectorId); + metadata.dropTag(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), tagName, tagExists); + } + @Override public void dropConstraint(Session session, TableHandle tableHandle, Optional constraintName, Optional columnName) { @@ -1496,6 +1770,46 @@ public void addConstraint(Session session, TableHandle tableHandle, TableConstra metadata.addConstraint(session.toConnectorSession(connectorId), tableHandle.getConnectorHandle(), tableConstraint); } + @Override + public String normalizeIdentifier(Session session, String catalogName, String identifier) + { + long startTime = System.nanoTime(); + String normalizedString = identifier.toLowerCase(ENGLISH); + Optional catalogMetadata = 
getOptionalCatalogMetadata(session, transactionManager, catalogName); + if (catalogMetadata.isPresent()) { + ConnectorId connectorId = catalogMetadata.get().getConnectorId(); + ConnectorMetadata metadata = catalogMetadata.get().getMetadataFor(connectorId); + normalizedString = metadata.normalizeIdentifier(session.toConnectorSession(connectorId), identifier); + } + session.getRuntimeStats().addMetricValue(GET_IDENTIFIER_NORMALIZATION_TIME_NANOS, NANO, System.nanoTime() - startTime); + return normalizedString; + } + + private ColumnMetadata normalizedColumnMetadata(Session session, String catalogName, ColumnMetadata columnMetadata) + { + return ColumnMetadata.builder() + .setName(normalizeIdentifier(session, catalogName, columnMetadata.getName())) + .setType(columnMetadata.getType()) + .setHidden(columnMetadata.isHidden()) + .setNullable(columnMetadata.isNullable()) + .setComment(columnMetadata.getComment().orElse(null)) + .setProperties(columnMetadata.getProperties()) + .setExtraInfo(columnMetadata.getExtraInfo().orElse(null)) + .build(); + } + + @Override + public Optional> applyTableFunction(Session session, TableFunctionHandle handle) + { + ConnectorId connectorId = handle.getConnectorId(); + ConnectorMetadata metadata = getMetadata(session, connectorId); + + return metadata.applyTableFunction(session.toConnectorSession(connectorId), handle.getFunctionHandle()) + .map(result -> new TableFunctionApplicationResult<>( + new TableHandle(connectorId, result.getTableHandle(), handle.getTransactionHandle(), Optional.empty()), + result.getColumnHandles())); + } + private ViewDefinition deserializeView(String data) { try { diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUpdates.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUpdates.java deleted file mode 100644 index 2e7ea6ec167eb..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUpdates.java +++ /dev/null @@ -1,90 +0,0 @@ 
-/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.metadata; - -import com.facebook.drift.annotations.ThriftConstructor; -import com.facebook.drift.annotations.ThriftField; -import com.facebook.drift.annotations.ThriftStruct; -import com.facebook.presto.server.thrift.Any; -import com.facebook.presto.spi.ConnectorId; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; - -import java.util.List; - -import static java.util.Objects.requireNonNull; - -@ThriftStruct -public class MetadataUpdates -{ - public static final MetadataUpdates DEFAULT_METADATA_UPDATES = new MetadataUpdates(null, ImmutableList.of()); - - private final ConnectorId connectorId; - private List metadataUpdates; - private List metadataUpdatesAny; - private boolean dummy; - - @JsonCreator - public MetadataUpdates( - @JsonProperty("connectorId") @Nullable ConnectorId connectorId, - @JsonProperty("metadataUpdates") List metadataUpdates) - { - this.connectorId = connectorId; - this.metadataUpdates = ImmutableList.copyOf(requireNonNull(metadataUpdates, "metadataUpdates is null")); - } - - /** - * Thrift constructor - * - * @param connectorId id of the connector - * @param metadataUpdatesAny Any representation of ConnectorMetadataUpdateHandle - * @param dummy dummy 
boolean for disambiguating between the JSON constructor - */ - @ThriftConstructor - public MetadataUpdates(@Nullable ConnectorId connectorId, List metadataUpdatesAny, boolean dummy) - { - this.connectorId = connectorId; - this.metadataUpdatesAny = ImmutableList.copyOf(requireNonNull(metadataUpdatesAny, "metadataUpdatesAny is null")); - this.dummy = dummy; - } - - @JsonProperty - @ThriftField(1) - public ConnectorId getConnectorId() - { - return connectorId; - } - - @JsonProperty - public List getMetadataUpdates() - { - return metadataUpdates; - } - - @ThriftField(2) - public List getMetadataUpdatesAny() - { - return metadataUpdatesAny; - } - - @ThriftField(3) - public boolean getDummy() - { - return dummy; - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUtil.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUtil.java index d33ed25c2cbf8..d970d12216012 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/MetadataUtil.java @@ -21,14 +21,15 @@ import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableHandle; import com.facebook.presto.spi.ConnectorTableMetadata; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.connector.ConnectorTableVersion; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.analyzer.utils.MetadataUtils; import com.facebook.presto.sql.tree.GrantorSpecification; +import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.PrincipalSpecification; import com.facebook.presto.sql.tree.QualifiedName; @@ -36,12 +37,12 @@ 
import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Lists; import java.util.List; import java.util.Optional; +import java.util.function.BiFunction; -import static com.facebook.presto.spi.StandardErrorCode.SYNTAX_ERROR; +import static com.facebook.presto.connector.informationSchema.InformationSchemaMetadata.INFORMATION_SCHEMA; import static com.facebook.presto.spi.security.PrincipalType.ROLE; import static com.facebook.presto.spi.security.PrincipalType.USER; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED; @@ -60,34 +61,22 @@ private MetadataUtil() {} public static final String likeTableCatalogError = "LIKE table catalog '%s' does not exist"; public static final String catalogError = "Catalog %s does not exist"; public static final String targetTableCatalogError = "Target catalog '%s' does not exist"; - - public static void checkTableName(String catalogName, Optional schemaName, Optional tableName) - { - checkCatalogName(catalogName); - schemaName.ifPresent(name -> checkLowerCase(name, "schemaName")); - tableName.ifPresent(name -> checkLowerCase(name, "tableName")); - - checkArgument(schemaName.isPresent() || !tableName.isPresent(), "tableName specified but schemaName is missing"); - } - public static String checkCatalogName(String catalogName) { return checkLowerCase(catalogName, "catalogName"); } - public static String checkSchemaName(String schemaName) - { - return checkLowerCase(schemaName, "schemaName"); - } - - public static String checkTableName(String tableName) + public static SchemaTableName toSchemaTableName(QualifiedObjectName qualifiedObjectName) { - return checkLowerCase(tableName, "tableName"); + return new SchemaTableName(qualifiedObjectName.getSchemaName(), qualifiedObjectName.getObjectName()); } - public static SchemaTableName toSchemaTableName(QualifiedObjectName 
qualifiedObjectName) + public static SchemaTableName toSchemaTableName(String schemaName, String tableName) { - return new SchemaTableName(qualifiedObjectName.getSchemaName(), qualifiedObjectName.getObjectName()); + if (schemaName.equalsIgnoreCase(INFORMATION_SCHEMA)) { + return new SchemaTableName(schemaName.toLowerCase(ENGLISH), tableName.toLowerCase(ENGLISH)); + } + return new SchemaTableName(schemaName, tableName); } public static ConnectorId getConnectorIdOrThrow(Session session, Metadata metadata, String catalogName) @@ -132,20 +121,23 @@ public static String createCatalogName(Session session, Node node) return sessionCatalog.get(); } - public static CatalogSchemaName createCatalogSchemaName(Session session, Node node, Optional schema) + public static CatalogSchemaName createCatalogSchemaName(Session session, Node node, Optional schema, Metadata metadata) { String catalogName = session.getCatalog().orElse(null); String schemaName = session.getSchema().orElse(null); if (schema.isPresent()) { - List parts = schema.get().getParts(); + List parts = schema.get().getOriginalParts(); if (parts.size() > 2) { throw new SemanticException(INVALID_SCHEMA_NAME, node, "Too many parts in schema name: %s", schema.get()); } if (parts.size() == 2) { - catalogName = parts.get(0); + catalogName = parts.get(0).getValue(); + } + if (catalogName == null) { + throw new SemanticException(CATALOG_NOT_SPECIFIED, node, "Catalog must be specified when session catalog is not set"); } - schemaName = schema.get().getSuffix(); + schemaName = metadata.normalizeIdentifier(session, catalogName, schema.get().getOriginalSuffix().getValue()); } if (catalogName == null) { @@ -158,27 +150,10 @@ public static CatalogSchemaName createCatalogSchemaName(Session session, Node no return new CatalogSchemaName(catalogName, schemaName); } - public static QualifiedObjectName createQualifiedObjectName(Session session, Node node, QualifiedName name) - { - requireNonNull(session, "session is null"); - 
requireNonNull(name, "name is null"); - if (name.getParts().size() > 3) { - throw new PrestoException(SYNTAX_ERROR, format("Too many dots in table name: %s", name)); - } - - List parts = Lists.reverse(name.getParts()); - String objectName = parts.get(0); - String schemaName = (parts.size() > 1) ? parts.get(1) : session.getSchema().orElseThrow(() -> - new SemanticException(SCHEMA_NOT_SPECIFIED, node, "Schema must be specified when session schema is not set")); - String catalogName = (parts.size() > 2) ? parts.get(2) : session.getCatalog().orElseThrow(() -> - new SemanticException(CATALOG_NOT_SPECIFIED, node, "Catalog must be specified when session catalog is not set")); - - return new QualifiedObjectName(catalogName, schemaName, objectName); - } - - public static QualifiedName createQualifiedName(QualifiedObjectName name) + public static QualifiedObjectName createQualifiedObjectName(Session session, Node node, QualifiedName name, Metadata metadata) { - return QualifiedName.of(name.getCatalogName(), name.getSchemaName(), name.getObjectName()); + BiFunction normalizer = (catalogName, objectName) -> metadata.normalizeIdentifier(session, catalogName, objectName); + return MetadataUtils.createQualifiedObjectName(session.getCatalog(), session.getSchema(), node, name, normalizer); } public static Optional getOptionalCatalogMetadata(Session session, TransactionManager transactionManager, String catalogName) @@ -197,6 +172,7 @@ public static Optional getOptionalTableHandle(Session session, Tran ConnectorMetadata metadata = catalogMetadata.getMetadataFor(connectorId); ConnectorTableHandle tableHandle; + tableHandle = tableVersion .map(expression -> metadata.getTableHandle(session.toConnectorSession(connectorId), toSchemaTableName(table), Optional.of(expression))) .orElseGet(() -> metadata.getTableHandle(session.toConnectorSession(connectorId), toSchemaTableName(table))); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandle.java 
b/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandle.java index 8d3d3d2f6dc87..f88dabb0c243b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandle.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.metadata; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -23,6 +26,7 @@ import static java.util.Objects.requireNonNull; +@ThriftStruct public final class OutputTableHandle { private final ConnectorId connectorId; @@ -30,6 +34,7 @@ public final class OutputTableHandle private final ConnectorOutputTableHandle connectorHandle; @JsonCreator + @ThriftConstructor public OutputTableHandle( @JsonProperty("connectorId") ConnectorId connectorId, @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, @@ -41,18 +46,21 @@ public OutputTableHandle( } @JsonProperty + @ThriftField(1) public ConnectorId getConnectorId() { return connectorId; } @JsonProperty + @ThriftField(2) public ConnectorTransactionHandle getTransactionHandle() { return transactionHandle; } @JsonProperty + @ThriftField(3) public ConnectorOutputTableHandle getConnectorHandle() { return connectorHandle; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandleJacksonModule.java index ddd13969cd2bd..c2701ef082f90 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/OutputTableHandleJacksonModule.java @@ 
-13,18 +13,45 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorOutputTableHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import java.util.function.Function; public class OutputTableHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public OutputTableHandleJacksonModule(HandleResolver handleResolver) + public OutputTableHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorOutputTableHandle.class, handleResolver::getId, - handleResolver::getOutputTableHandleClass); + handleResolver::getOutputTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorOutputTableHandleCodec)); + } + + public OutputTableHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorOutputTableHandle.class, + handleResolver::getId, + handleResolver::getOutputTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/PartitioningHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/PartitioningHandleJacksonModule.java index 6331b837977f7..f26221b19040c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/PartitioningHandleJacksonModule.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/metadata/PartitioningHandleJacksonModule.java @@ -13,18 +13,45 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import java.util.function.Function; public class PartitioningHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public PartitioningHandleJacksonModule(HandleResolver handleResolver) + public PartitioningHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorPartitioningHandle.class, handleResolver::getId, - handleResolver::getPartitioningHandleClass); + handleResolver::getPartitioningHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorPartitioningHandleCodec)); + } + + public PartitioningHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorPartitioningHandle.class, + handleResolver::getId, + handleResolver::getPartitioningHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/QualifiedTablePrefix.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/QualifiedTablePrefix.java index c473aff3ae3eb..3ebd8739be989 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/QualifiedTablePrefix.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/QualifiedTablePrefix.java @@ -20,15 +20,12 @@ import com.facebook.presto.spi.SchemaTablePrefix; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Objects; import java.util.Optional; import static com.facebook.presto.metadata.MetadataUtil.checkCatalogName; -import static com.facebook.presto.metadata.MetadataUtil.checkSchemaName; -import static com.facebook.presto.metadata.MetadataUtil.checkTableName; @Immutable @ThriftStruct @@ -48,15 +45,15 @@ public QualifiedTablePrefix(String catalogName) public QualifiedTablePrefix(String catalogName, String schemaName) { this.catalogName = checkCatalogName(catalogName); - this.schemaName = Optional.of(checkSchemaName(schemaName)); + this.schemaName = Optional.of(schemaName); this.tableName = Optional.empty(); } public QualifiedTablePrefix(String catalogName, String schemaName, String tableName) { this.catalogName = checkCatalogName(catalogName); - this.schemaName = Optional.of(checkSchemaName(schemaName)); - this.tableName = Optional.of(checkTableName(tableName)); + this.schemaName = Optional.of(schemaName); + this.tableName = Optional.of(tableName); } @JsonCreator @@ -66,7 +63,7 @@ public QualifiedTablePrefix( @JsonProperty("schemaName") Optional schemaName, @JsonProperty("tableName") Optional tableName) { - checkTableName(catalogName, schemaName, tableName); + checkCatalogName(catalogName); this.catalogName = catalogName; this.schemaName = schemaName; this.tableName = tableName; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/RemoteNodeStats.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/RemoteNodeStats.java new file mode 100644 index 
0000000000000..a9645cd524491 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/RemoteNodeStats.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.NodeStats; + +import java.util.Optional; + +/** + * Interface for retrieving statistics from remote nodes in a Presto cluster. + *

+ * This interface provides a mechanism to asynchronously fetch and cache node statistics + * from remote Presto worker nodes. Implementations handle the communication protocol + * details (HTTP, Thrift, etc.) and provide a unified way to access node health and + * performance metrics. + *

+ * The interface supports lazy loading and caching of node statistics to minimize + * network overhead while ensuring that cluster management components have access + * to current node state information for scheduling and health monitoring decisions. + */ +public interface RemoteNodeStats +{ + /** + * Returns the cached node statistics if available. + *

+ * This method returns the most recently fetched statistics for the remote node. + * If no statistics have been fetched yet or if the last fetch failed, this + * method returns an empty Optional. + * + * @return an Optional containing the node statistics if available, empty otherwise + */ + Optional getNodeStats(); + + /** + * Triggers an asynchronous refresh of the node statistics. + *

+ * This method initiates a background request to fetch the latest statistics + * from the remote node. The operation is non-blocking and the results will + * be available through subsequent calls to {@link #getNodeStats()}. + *

+ * Implementations should handle network failures gracefully and avoid + * overwhelming the remote node with excessive requests. + */ + void asyncRefresh(); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/RemoteTransactionHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/RemoteTransactionHandle.java index 2f89509efa0ee..9f2e60061dcdf 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/RemoteTransactionHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/RemoteTransactionHandle.java @@ -13,14 +13,18 @@ */ package com.facebook.presto.metadata; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +@ThriftStruct public class RemoteTransactionHandle implements ConnectorTransactionHandle { @JsonCreator + @ThriftConstructor public RemoteTransactionHandle() { } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java index 3e9e23e1a6e9a..d013edafb92ac 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/SessionPropertyManager.java @@ -15,6 +15,7 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.block.BlockBuilder; @@ -44,14 +45,16 @@ import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Parameter; +import 
com.google.common.annotations.VisibleForTesting; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import jakarta.annotation.Nullable; +import jakarta.inject.Inject; -import javax.annotation.Nullable; -import javax.inject.Inject; - +import java.io.File; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -59,14 +62,18 @@ import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; import static com.facebook.presto.common.type.TypeUtils.writeNativeValue; import static com.facebook.presto.spi.StandardErrorCode.INVALID_SESSION_PROPERTY; import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression; +import static com.facebook.presto.util.PropertiesUtil.loadProperties; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.io.Files.getNameWithoutExtension; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.HOURS; @@ -74,6 +81,9 @@ public final class SessionPropertyManager { private static final JsonCodecFactory JSON_CODEC_FACTORY = new JsonCodecFactory(); + private static final Logger log = Logger.get(SessionPropertyManager.class); + private static final String SESSION_PROPERTY_PROVIDER_NAME = "session-property-provider.name"; + private final ConcurrentMap> systemSessionProperties = new ConcurrentHashMap<>(); private final ConcurrentMap>> connectorSessionProperties = new ConcurrentHashMap<>(); private final Map 
workerSessionPropertyProviders; @@ -81,28 +91,33 @@ public final class SessionPropertyManager private final Supplier>> memoizedWorkerSessionProperties; private final Optional nodeManager; private final Optional functionAndTypeManager; + private final File configDir; + private final AtomicBoolean sessionPropertyProvidersLoading = new AtomicBoolean(); @Inject public SessionPropertyManager( SystemSessionProperties systemSessionProperties, Map workerSessionPropertyProviders, FunctionAndTypeManager functionAndTypeManager, - NodeManager nodeManager) + NodeManager nodeManager, + SessionPropertyProviderConfig config) { - this(systemSessionProperties.getSessionProperties(), workerSessionPropertyProviders, Optional.ofNullable(functionAndTypeManager), Optional.ofNullable(nodeManager)); + this(systemSessionProperties.getSessionProperties(), workerSessionPropertyProviders, Optional.ofNullable(functionAndTypeManager), Optional.ofNullable(nodeManager), config); } public SessionPropertyManager( List> sessionProperties, Map workerSessionPropertyProviders, Optional functionAndTypeManager, - Optional nodeManager) + Optional nodeManager, + SessionPropertyProviderConfig config) { this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); this.memoizedWorkerSessionProperties = Suppliers.memoizeWithExpiration(this::getWorkerSessionProperties, 1, HOURS); this.workerSessionPropertyProviders = new ConcurrentHashMap<>(workerSessionPropertyProviders); + this.configDir = requireNonNull(config, "config is null").getSessionPropertyProvidersConfigurationDir(); addSystemSessionProperties(sessionProperties); } @@ -125,34 +140,67 @@ public static SessionPropertyManager createTestingSessionPropertyManager( List> sessionProperties, JavaFeaturesConfig javaFeaturesConfig, NodeSpillConfig nodeSpillConfig) + { + return createTestingSessionPropertyManager(sessionProperties, new 
FeaturesConfig(), javaFeaturesConfig, nodeSpillConfig); + } + + public static SessionPropertyManager createTestingSessionPropertyManager( + List> sessionProperties, + FeaturesConfig featuresConfig, + JavaFeaturesConfig javaFeaturesConfig, + NodeSpillConfig nodeSpillConfig) { return new SessionPropertyManager( sessionProperties, ImmutableMap.of( "java-worker", new JavaWorkerSessionPropertyProvider( - new FeaturesConfig(), + featuresConfig, javaFeaturesConfig, nodeSpillConfig)), Optional.empty(), - Optional.empty()); + Optional.empty(), + new SessionPropertyProviderConfig()); } - public void loadSessionPropertyProvider(String sessionPropertyProviderName, Optional typeManager, Optional nodeManager) + public void loadSessionPropertyProviders() + throws Exception { + if (!sessionPropertyProvidersLoading.compareAndSet(false, true)) { + return; + } + + for (File file : listFiles(configDir)) { + if (file.isFile() && file.getName().endsWith(".properties")) { + String sessionPropertyProviderName = getNameWithoutExtension(file.getName()); + Map properties = loadProperties(file); + checkState(!isNullOrEmpty(properties.get(SESSION_PROPERTY_PROVIDER_NAME)), + "Session property manager configuration %s does not contain %s", + file.getAbsoluteFile(), + SESSION_PROPERTY_PROVIDER_NAME); + properties = new HashMap<>(properties); + properties.remove(SESSION_PROPERTY_PROVIDER_NAME); + loadSessionPropertyProvider(sessionPropertyProviderName, properties, functionAndTypeManager, nodeManager); + } + } + } + + public void loadSessionPropertyProvider(String sessionPropertyProviderName, Map properties, Optional typeManager, Optional nodeManager) + { + log.info("-- Loading %s session property provider --", sessionPropertyProviderName); WorkerSessionPropertyProviderFactory factory = workerSessionPropertyProviderFactories.get(sessionPropertyProviderName); checkState(factory != null, "No factory for session property provider : " + sessionPropertyProviderName); - WorkerSessionPropertyProvider 
sessionPropertyProvider = factory.create(new SessionPropertyContext(typeManager, nodeManager)); + WorkerSessionPropertyProvider sessionPropertyProvider = factory.create(new SessionPropertyContext(typeManager, nodeManager), properties); if (workerSessionPropertyProviders.putIfAbsent(sessionPropertyProviderName, sessionPropertyProvider) != null) { throw new IllegalArgumentException("System session property provider is already registered for property provider : " + sessionPropertyProviderName); } + log.info("-- Added session property provider [%s] --", sessionPropertyProviderName); } - public void loadSessionPropertyProviders() + @VisibleForTesting + public Map getWorkerSessionPropertyProviders() { - for (String sessionPropertyProviderName : workerSessionPropertyProviderFactories.keySet()) { - loadSessionPropertyProvider(sessionPropertyProviderName, functionAndTypeManager, nodeManager); - } + return ImmutableMap.copyOf(workerSessionPropertyProviders); } public void addSessionPropertyProviderFactory(WorkerSessionPropertyProviderFactory factory) @@ -224,6 +272,17 @@ private Map> getWorkerSessionProperties() return workerSessionProperties; } + private static List listFiles(File dir) + { + if (dir != null && dir.isDirectory()) { + File[] files = dir.listFiles(); + if (files != null) { + return ImmutableList.copyOf(files); + } + } + return ImmutableList.of(); + } + public List getAllSessionProperties(Session session, Map catalogs) { requireNonNull(session, "session is null"); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/SessionPropertyProviderConfig.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/SessionPropertyProviderConfig.java new file mode 100644 index 0000000000000..ebd765d9a99f8 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/SessionPropertyProviderConfig.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.airlift.configuration.Config; +import jakarta.validation.constraints.NotNull; + +import java.io.File; + +public class SessionPropertyProviderConfig +{ + private File sessionPropertyProvidersConfigurationDir = new File("etc/session-property-providers/"); + + @NotNull + public File getSessionPropertyProvidersConfigurationDir() + { + return sessionPropertyProvidersConfigurationDir; + } + + @Config("session-property-provider.config-dir") + public SessionPropertyProviderConfig setSessionPropertyProvidersConfigurationDir(File dir) + { + this.sessionPropertyProvidersConfigurationDir = dir; + return this; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/SignatureBinder.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/SignatureBinder.java index a71844c54b7d4..bd383aa6ef87b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/SignatureBinder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/SignatureBinder.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.metadata; -import com.facebook.presto.UnknownTypeException; import com.facebook.presto.common.type.FunctionType; import com.facebook.presto.common.type.NamedTypeSignature; import com.facebook.presto.common.type.ParameterKind; @@ -25,6 +24,7 @@ import com.facebook.presto.spi.function.LongVariableConstraint; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.TypeVariableConstraint; 
+import com.facebook.presto.spi.type.UnknownTypeException; import com.facebook.presto.sql.analyzer.TypeSignatureProvider; import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/Split.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/Split.java index d8581dc1730db..ba93e58a16efc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/Split.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/Split.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.metadata; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.execution.Lifespan; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSplit; @@ -33,6 +36,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; +@ThriftStruct public final class Split { private final ConnectorId connectorId; @@ -48,6 +52,7 @@ public Split(ConnectorId connectorId, ConnectorTransactionHandle transactionHand } @JsonCreator + @ThriftConstructor public Split( @JsonProperty("connectorId") ConnectorId connectorId, @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle, @@ -63,30 +68,35 @@ public Split( } @JsonProperty + @ThriftField(1) public ConnectorId getConnectorId() { return connectorId; } @JsonProperty + @ThriftField(2) public ConnectorTransactionHandle getTransactionHandle() { return transactionHandle; } @JsonProperty + @ThriftField(3) public ConnectorSplit getConnectorSplit() { return connectorSplit; } @JsonProperty + @ThriftField(4) public Lifespan getLifespan() { return lifespan; } @JsonProperty + @ThriftField(5) public SplitContext getSplitContext() { return splitContext; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/SplitJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/SplitJacksonModule.java index 0caaff86c54b1..5d950534e4f5c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/SplitJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/SplitJacksonModule.java @@ -13,18 +13,49 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import java.util.function.Function; public class SplitJacksonModule extends AbstractTypedJacksonModule { @Inject - public SplitJacksonModule(HandleResolver handleResolver) + public SplitJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorSplit.class, handleResolver::getId, - handleResolver::getSplitClass); + handleResolver::getSplitClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorSplitCodec)); + } + + /** + * Test-friendly constructor that accepts a codec extractor function directly, + * avoiding the need to create a full ConnectorManager with all its dependencies. 
+ */ + public SplitJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorSplit.class, + handleResolver::getId, + handleResolver::getSplitClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/SqlAggregationFunction.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/SqlAggregationFunction.java index 9d2b6c0273632..38f8f0e9a8ef8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/SqlAggregationFunction.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/SqlAggregationFunction.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.operator.aggregation.AggregationFromAnnotationsParser; @@ -46,7 +47,12 @@ public static List createFunctionByAnnotations(Class public static List createFunctionsByAnnotations(Class aggregationDefinition) { - return AggregationFromAnnotationsParser.parseFunctionDefinitions(aggregationDefinition) + return createFunctionsByAnnotations(aggregationDefinition, JAVA_BUILTIN_NAMESPACE); + } + + public static List createFunctionsByAnnotations(Class aggregationDefinition, CatalogSchemaName functionNamespace) + { + return AggregationFromAnnotationsParser.parseFunctionDefinitions(aggregationDefinition, functionNamespace) .stream() .map(x -> (SqlAggregationFunction) x) .collect(toImmutableList()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticCatalogStore.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticCatalogStore.java index db96e2877efed..ca50bb18e6e0b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticCatalogStore.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticCatalogStore.java @@ -19,8 +19,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.Files; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.File; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticCatalogStoreConfig.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticCatalogStoreConfig.java index 32d125aa2e214..48cf865297c9d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticCatalogStoreConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticCatalogStoreConfig.java @@ -17,8 +17,7 @@ import com.facebook.airlift.configuration.LegacyConfig; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.io.File; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticFunctionNamespaceStore.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticFunctionNamespaceStore.java index 85e2d1118734a..a1b7ae89f0033 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticFunctionNamespaceStore.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticFunctionNamespaceStore.java @@ -16,8 +16,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.spi.NodeManager; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.File; import java.util.HashMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticFunctionNamespaceStoreConfig.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticFunctionNamespaceStoreConfig.java index 
3d7292e854f93..8ee29461b8ca0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticFunctionNamespaceStoreConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticFunctionNamespaceStoreConfig.java @@ -14,8 +14,7 @@ package com.facebook.presto.metadata; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.io.File; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticTypeManagerStore.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticTypeManagerStore.java new file mode 100644 index 0000000000000..aefe38c7fea48 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticTypeManagerStore.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.airlift.log.Logger; +import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; + +import java.io.File; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.facebook.presto.util.PropertiesUtil.loadProperties; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.io.Files.getNameWithoutExtension; + +public class StaticTypeManagerStore +{ + private static final Logger log = Logger.get(StaticTypeManagerStore.class); + private static final String TYPE_MANAGER_NAME = "type-manager.name"; + private final FunctionAndTypeManager functionAndTypeManager; + private final File configDir; + private final AtomicBoolean typeManagersLoading = new AtomicBoolean(); + + @Inject + public StaticTypeManagerStore(FunctionAndTypeManager functionAndTypeManager, StaticTypeManagerStoreConfig config) + { + this.functionAndTypeManager = functionAndTypeManager; + this.configDir = config.getTypeManagerConfigurationDir(); + } + + public void loadTypeManagers() + throws Exception + { + if (!typeManagersLoading.compareAndSet(false, true)) { + return; + } + + for (File file : listFiles(configDir)) { + if (file.isFile() && file.getName().endsWith(".properties")) { + String catalogName = getNameWithoutExtension(file.getName()); + Map properties = loadProperties(file); + checkState(!isNullOrEmpty(properties.get(TYPE_MANAGER_NAME)), + "Type manager configuration %s does not contain %s", + file.getAbsoluteFile(), + TYPE_MANAGER_NAME); + loadTypeManager(catalogName, properties); + } + } + } + + public void loadTypeManagers(Map> catalogProperties) + { + catalogProperties.entrySet().stream() + .forEach(entry -> loadTypeManager(entry.getKey(), entry.getValue())); + } + + private void loadTypeManager(String catalogName, Map 
properties) + { + log.info("-- Loading %s type manager --", catalogName); + properties = new HashMap<>(properties); + String typeManagerName = properties.remove(TYPE_MANAGER_NAME); + checkState(!isNullOrEmpty(typeManagerName), "%s property must be present", TYPE_MANAGER_NAME); + functionAndTypeManager.loadTypeManager(typeManagerName); + log.info("-- Added type manager [%s] --", catalogName); + } + + private static List listFiles(File dir) + { + if (dir != null && dir.isDirectory()) { + File[] files = dir.listFiles(); + if (files != null) { + return ImmutableList.copyOf(files); + } + } + return ImmutableList.of(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticTypeManagerStoreConfig.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticTypeManagerStoreConfig.java new file mode 100644 index 0000000000000..c1b039a234c80 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/StaticTypeManagerStoreConfig.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.airlift.configuration.Config; +import jakarta.validation.constraints.NotNull; + +import java.io.File; + +public class StaticTypeManagerStoreConfig +{ + private File typeManagerConfigurationDir = new File("etc/type-managers/"); + + @NotNull + public File getTypeManagerConfigurationDir() + { + return typeManagerConfigurationDir; + } + + @Config("type-manager.config-dir") + public StaticTypeManagerStoreConfig setTypeManagerConfigurationDir(File dir) + { + this.typeManagerConfigurationDir = dir; + return this; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionHandle.java new file mode 100644 index 0000000000000..20e65066d4320 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionHandle.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static java.util.Objects.requireNonNull; + +public class TableFunctionHandle +{ + private final ConnectorId connectorId; + private final ConnectorTableFunctionHandle functionHandle; + private final ConnectorTransactionHandle transactionHandle; + + @JsonCreator + public TableFunctionHandle( + @JsonProperty("connectorId") ConnectorId connectorId, + @JsonProperty("functionHandle") ConnectorTableFunctionHandle functionHandle, + @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); + } + + @JsonProperty + public ConnectorId getConnectorId() + { + return connectorId; + } + + @JsonProperty + public ConnectorTableFunctionHandle getFunctionHandle() + { + return functionHandle; + } + + @JsonProperty + public ConnectorTransactionHandle getTransactionHandle() + { + return transactionHandle; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java new file mode 100644 index 0000000000000..9f289f4ac491f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionJacksonHandleModule.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; + +public class TableFunctionJacksonHandleModule + extends AbstractTypedJacksonModule +{ + @Inject + public TableFunctionJacksonHandleModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) + { + super(ConnectorTableFunctionHandle.class, + handleResolver::getId, + handleResolver::getTableFunctionHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorTableFunctionHandleCodec)); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionMetadata.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionMetadata.java new file mode 100644 index 0000000000000..806215927b736 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionMetadata.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; + +import static java.util.Objects.requireNonNull; + +public class TableFunctionMetadata +{ + private final ConnectorId connectorId; + private final ConnectorTableFunction function; + + public TableFunctionMetadata(ConnectorId connectorId, ConnectorTableFunction function) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.function = requireNonNull(function, "function is null"); + } + + public ConnectorId getConnectorId() + { + return connectorId; + } + + public ConnectorTableFunction getFunction() + { + return function; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java new file mode 100644 index 0000000000000..d624b8364e35b --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableFunctionRegistry.java @@ -0,0 +1,165 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.StandardErrorCode; +import com.facebook.presto.spi.function.CatalogSchemaFunctionName; +import com.facebook.presto.spi.function.SchemaFunctionName; +import com.facebook.presto.spi.function.table.ArgumentSpecification; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification.DescribedTable; +import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.ThreadSafe; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import static com.facebook.presto.spi.StandardErrorCode.SESSION_CATALOG_NOT_SET; +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SCHEMA_NOT_SPECIFIED; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +@ThreadSafe +public class TableFunctionRegistry +{ + // catalog name in the original case; schema and function name in lowercase + private final Map> tableFunctions = new ConcurrentHashMap<>(); + + public void addTableFunctions(ConnectorId catalogName, Collection 
functions) + { + requireNonNull(catalogName, "catalogName is null"); + requireNonNull(functions, "functions is null"); + checkState(!tableFunctions.containsKey(catalogName), "Table functions already registered for catalog: " + catalogName); + + functions.stream() + .forEach(TableFunctionRegistry::validateTableFunction); + + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (ConnectorTableFunction function : functions) { + builder.put( + new SchemaFunctionName( + function.getSchema().toLowerCase(ENGLISH), + function.getName().toLowerCase(ENGLISH)), + new TableFunctionMetadata(catalogName, function)); + } + tableFunctions.putIfAbsent(catalogName, builder.buildOrThrow()); + } + + public void removeTableFunctions(ConnectorId catalogName) + { + tableFunctions.remove(catalogName); + } + + public static List toPath(Session session, QualifiedName name) + { + List parts = name.getParts(); + if (parts.size() > 3) { + throw new PrestoException(StandardErrorCode.FUNCTION_NOT_FOUND, "Invalid function name: " + name); + } + if (parts.size() == 3) { + return ImmutableList.of(new CatalogSchemaFunctionName(parts.get(0), parts.get(1), parts.get(2))); + } + + if (parts.size() == 2) { + String currentCatalog = session.getCatalog() + .orElseThrow(() -> new PrestoException(SESSION_CATALOG_NOT_SET, "Session default catalog must be set to resolve a partial function name: " + name)); + return ImmutableList.of(new CatalogSchemaFunctionName(currentCatalog, parts.get(0), parts.get(1))); + } + + ImmutableList.Builder names = ImmutableList.builder(); + + String currentCatalog = session.getCatalog() + .orElseThrow(() -> new SemanticException(CATALOG_NOT_SPECIFIED, "Catalog must be specified when session catalog is not set")); + String currentSchema = session.getSchema() + .orElseThrow(() -> new SemanticException(SCHEMA_NOT_SPECIFIED, "Schema must be specified when session schema is not set")); + + // add resolved path items + names.add(new CatalogSchemaFunctionName(currentCatalog, 
currentSchema, parts.get(0))); + + // add builtin path items + names.add(new CatalogSchemaFunctionName("system", "builtin", parts.get(0))); + return names.build(); + } + + /** + * Resolve table function with given qualified name. + * Table functions are resolved case-insensitive for consistency with existing scalar function resolution. + */ + public Optional resolve(Session session, QualifiedName qualifiedName) + { + for (CatalogSchemaFunctionName name : toPath(session, qualifiedName)) { + ConnectorId connectorId = new ConnectorId(name.getCatalogName()); + Map catalogFunctions = tableFunctions.get(connectorId); + if (catalogFunctions != null) { + String lowercasedSchemaName = name.getSchemaFunctionName().getSchemaName().toLowerCase(ENGLISH); + String lowercasedFunctionName = name.getSchemaFunctionName().getFunctionName().toLowerCase(ENGLISH); + TableFunctionMetadata function = catalogFunctions.get(new SchemaFunctionName(lowercasedSchemaName, lowercasedFunctionName)); + if (function != null) { + return Optional.of(function); + } + } + } + + return Optional.empty(); + } + + private static void validateTableFunction(ConnectorTableFunction tableFunction) + { + requireNonNull(tableFunction, "tableFunction is null"); + requireNonNull(tableFunction.getName(), "table function name is null"); + requireNonNull(tableFunction.getSchema(), "table function schema name is null"); + requireNonNull(tableFunction.getArguments(), "table function arguments is null"); + requireNonNull(tableFunction.getReturnTypeSpecification(), "table function returnTypeSpecification is null"); + + checkArgument(!tableFunction.getName().isEmpty(), "table function name is empty"); + checkArgument(!tableFunction.getSchema().isEmpty(), "table function schema name is empty"); + + Set argumentNames = new HashSet<>(); + int tableArgumentsWithRowSemantics = 0; + for (ArgumentSpecification specification : tableFunction.getArguments()) { + if (!argumentNames.add(specification.getName())) { + throw new 
IllegalArgumentException("duplicate argument name: " + specification.getName()); + } + + if (specification instanceof TableArgumentSpecification && + ((TableArgumentSpecification) specification).isRowSemantics()) { + tableArgumentsWithRowSemantics++; + } + } + checkArgument(tableArgumentsWithRowSemantics <= 1, "more than one table argument with row semantics"); + // The 'keep when empty' or 'prune when empty' property must not be explicitly specified for a table argument with row semantics. + // Such a table argument is implicitly 'prune when empty'. The TableArgumentSpecification.Builder enforces the 'prune when empty' property + // for a table argument with row semantics. + + if (tableFunction.getReturnTypeSpecification() instanceof DescribedTable) { + DescribedTable describedTable = (DescribedTable) tableFunction.getReturnTypeSpecification(); + checkArgument(describedTable.getDescriptor().isTyped(), "field types missing in returned type specification"); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableHandleJacksonModule.java index 1a0fb2090ecb6..3dc4c534c2c39 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableHandleJacksonModule.java @@ -13,18 +13,45 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import java.util.function.Function; public class TableHandleJacksonModule 
extends AbstractTypedJacksonModule { @Inject - public TableHandleJacksonModule(HandleResolver handleResolver) + public TableHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorTableHandle.class, handleResolver::getId, - handleResolver::getTableHandleClass); + handleResolver::getTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorTableHandleCodec)); + } + + public TableHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorTableHandle.class, + handleResolver::getId, + handleResolver::getTableHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayout.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayout.java index 553f5b5f8da1c..19026f586c035 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayout.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayout.java @@ -78,6 +78,11 @@ public List> getLocalProperties() return layout.getLocalProperties(); } + public Optional getUniqueColumn() + { + return layout.getUniqueColumn(); + } + public ConnectorTableLayoutHandle getLayoutHandle() { return layout.getHandle(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayoutHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayoutHandleJacksonModule.java index 4d9e6aeea7d19..505c65e42d4d0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayoutHandleJacksonModule.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/metadata/TableLayoutHandleJacksonModule.java @@ -13,18 +13,45 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import java.util.function.Function; public class TableLayoutHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public TableLayoutHandleJacksonModule(HandleResolver handleResolver) + public TableLayoutHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorTableLayoutHandle.class, handleResolver::getId, - handleResolver::getTableLayoutHandleClass); + handleResolver::getTableLayoutHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorTableLayoutHandleCodec)); + } + + public TableLayoutHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorTableLayoutHandle.class, + handleResolver::getId, + handleResolver::getTableLayoutHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/ThriftRemoteNodeState.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/ThriftRemoteNodeState.java index 011ed41bbde49..20bdf929e723b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/ThriftRemoteNodeState.java 
+++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/ThriftRemoteNodeState.java @@ -13,16 +13,15 @@ */ package com.facebook.presto.metadata; +import com.facebook.airlift.units.Duration; import com.facebook.drift.client.DriftClient; import com.facebook.presto.server.thrift.ThriftServerInfoClient; import com.facebook.presto.spi.NodeState; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.Nullable; import java.net.URI; import java.util.Optional; @@ -30,9 +29,9 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.Duration.nanosSince; import static java.util.Objects.requireNonNull; @ThreadSafe diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/ThriftRemoteNodeStats.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/ThriftRemoteNodeStats.java new file mode 100644 index 0000000000000..185df3be1c143 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/ThriftRemoteNodeStats.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; +import com.facebook.drift.client.DriftClient; +import com.facebook.presto.server.thrift.ThriftServerInfoClient; +import com.facebook.presto.spi.NodeState; +import com.facebook.presto.spi.NodeStats; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.Nullable; + +import java.net.URI; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.airlift.units.Duration.nanosSince; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +@ThreadSafe +public class ThriftRemoteNodeStats + implements RemoteNodeStats +{ + private static final Logger log = Logger.get(ThriftRemoteNodeStats.class); + + private final ThriftServerInfoClient thriftClient; + private final long refreshIntervalMillis; + private final AtomicReference> nodeStats = new AtomicReference<>(Optional.empty()); + private final AtomicBoolean requestInflight = new AtomicBoolean(); + private final AtomicLong lastUpdateNanos = new 
AtomicLong(); + private final URI stateInfoUri; + private final AtomicLong lastWarningLogged = new AtomicLong(); + + public ThriftRemoteNodeStats(DriftClient thriftClient, URI stateInfoUri, long refreshIntervalMillis) + { + requireNonNull(stateInfoUri, "stateInfoUri is null"); + checkArgument(stateInfoUri.getScheme().equals("thrift"), "unexpected scheme %s", stateInfoUri.getScheme()); + + this.stateInfoUri = stateInfoUri; + this.refreshIntervalMillis = refreshIntervalMillis; + this.thriftClient = requireNonNull(thriftClient, "thriftClient is null").get(Optional.of(stateInfoUri.getAuthority())); + } + + @Override + public Optional getNodeStats() + { + return nodeStats.get(); + } + + @Override + public void asyncRefresh() + { + Duration sinceUpdate = nanosSince(lastUpdateNanos.get()); + if (nanosSince(lastWarningLogged.get()).toMillis() > 1_000 && + sinceUpdate.toMillis() > 10_000 && + requestInflight.get()) { + log.warn("Node state update request to %s has not returned in %s", stateInfoUri, sinceUpdate.toString(SECONDS)); + lastWarningLogged.set(System.nanoTime()); + } + + if (sinceUpdate.toMillis() > refreshIntervalMillis && requestInflight.compareAndSet(false, true)) { + ListenableFuture responseFuture = thriftClient.getServerState(); + + Futures.addCallback(responseFuture, new FutureCallback() + { + @Override + public void onSuccess(@Nullable Integer result) + { + lastUpdateNanos.set(System.nanoTime()); + requestInflight.compareAndSet(true, false); + if (result != null) { + NodeStats nodeStats1 = new NodeStats(NodeState.valueOf(result), null); + nodeStats.set(Optional.of(nodeStats1)); + } + else { + log.warn("Node statistics endpoint %s returned null response, using cached statistics", stateInfoUri); + } + } + + @Override + public void onFailure(Throwable t) + { + log.error("Error fetching node stats from %s: %s", stateInfoUri, t.getMessage()); + lastUpdateNanos.set(System.nanoTime()); + requestInflight.compareAndSet(true, false); + } + }, directExecutor()); + } 
+ } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/TransactionHandleJacksonModule.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/TransactionHandleJacksonModule.java index 08048a9bb405e..fe650be2ad7e7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/TransactionHandleJacksonModule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/TransactionHandleJacksonModule.java @@ -13,18 +13,45 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import jakarta.inject.Inject; +import jakarta.inject.Provider; -import javax.inject.Inject; +import java.util.Optional; +import java.util.function.Function; public class TransactionHandleJacksonModule extends AbstractTypedJacksonModule { @Inject - public TransactionHandleJacksonModule(HandleResolver handleResolver) + public TransactionHandleJacksonModule( + HandleResolver handleResolver, + Provider connectorManagerProvider, + FeaturesConfig featuresConfig) { super(ConnectorTransactionHandle.class, handleResolver::getId, - handleResolver::getTransactionHandleClass); + handleResolver::getTransactionHandleClass, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> connectorManagerProvider.get() + .getConnectorCodecProvider(connectorId) + .flatMap(ConnectorCodecProvider::getConnectorTransactionHandleCodec)); + } + + public TransactionHandleJacksonModule( + HandleResolver handleResolver, + FeaturesConfig featuresConfig, + Function>> codecExtractor) + { + super(ConnectorTransactionHandle.class, + handleResolver::getId, + handleResolver::getTransactionHandleClass, + 
featuresConfig.isUseConnectorProvidedSerializationCodecs(), + codecExtractor); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/AbstractRowChangeOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/AbstractRowChangeOperator.java index 1fd421543539c..f26ba51630fba 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/AbstractRowChangeOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/AbstractRowChangeOperator.java @@ -48,7 +48,7 @@ protected enum State protected State state = State.RUNNING; protected long rowCount; private boolean closed; - private ListenableFuture> finishFuture; + protected ListenableFuture> finishFuture; private Supplier> pageSource = Optional::empty; private final JsonCodec tableCommitContextCodec; @@ -158,6 +158,7 @@ public void close() } else { pageSource.get().ifPresent(UpdatablePageSource::abort); + abort(); } } } @@ -173,4 +174,6 @@ protected UpdatablePageSource pageSource() // empty source can occur if the source operator doesn't output any rows return source.orElseGet(EmptySplitPageSource::new); } + + protected void abort() {} } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java new file mode 100644 index 0000000000000..71c3919fa6900 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/ChangeOnlyUpdatedColumnsMergeProcessor.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.RunLengthEncodedBlock; + +import java.util.ArrayList; +import java.util.List; + +import static com.facebook.presto.common.Utils.nativeValueToBlock; +import static com.facebook.presto.common.block.RowBlock.getRowFieldsFromBlock; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +/** + * The transformPage() method in this class does two things: + *

    + *
  • Transform the input page into an "update" page format
  • + *
  • Removes all rows whose operation number is DEFAULT_CASE_OPERATION_NUMBER
  • + *
+ */ +public class ChangeOnlyUpdatedColumnsMergeProcessor + implements MergeRowChangeProcessor +{ + private static final Block INSERT_FROM_UPDATE_BLOCK = nativeValueToBlock(TINYINT, 0L); + + private final int rowIdChannel; + private final int mergeRowChannel; + private final List dataColumnChannels; + + public ChangeOnlyUpdatedColumnsMergeProcessor( + int rowIdChannel, + int mergeRowChannel, + List targetColumnChannels) + { + this.rowIdChannel = rowIdChannel; + this.mergeRowChannel = mergeRowChannel; + this.dataColumnChannels = requireNonNull(targetColumnChannels, "targetColumnChannels is null"); + } + + @Override + public Page transformPage(Page inputPage) + { + requireNonNull(inputPage, "inputPage is null"); + + int inputChannelCount = inputPage.getChannelCount(); + checkArgument(inputChannelCount >= 2, "inputPage channelCount (%s)", inputChannelCount); + int positionCount = inputPage.getPositionCount(); + checkArgument(positionCount > 0, "positionCount should be > 0, but is %s", positionCount); + + Block mergeRow = inputPage.getBlock(mergeRowChannel).getLoadedBlock(); + if (mergeRow.mayHaveNull()) { + for (int position = 0; position < positionCount; position++) { + checkArgument(!mergeRow.isNull(position), "The mergeRow may not have null rows"); + } + } + + List fields = getRowFieldsFromBlock(mergeRow); + List builder = new ArrayList<>(dataColumnChannels.size() + 3); + for (int channel : dataColumnChannels) { + builder.add(fields.get(channel)); + } + Block operationChannelBlock = fields.get(fields.size() - 2); + builder.add(operationChannelBlock); + builder.add(inputPage.getBlock(rowIdChannel)); + builder.add(new RunLengthEncodedBlock(INSERT_FROM_UPDATE_BLOCK, positionCount)); + + Page result = new Page(builder.toArray(new Block[0])); + + int defaultCaseCount = 0; + for (int position = 0; position < positionCount; position++) { + if (TINYINT.getByte(operationChannelBlock, position) == DEFAULT_CASE_OPERATION_NUMBER) { + defaultCaseCount++; + } + } + if 
(defaultCaseCount == 0) { + return result; + } + + int usedCases = 0; + int[] positions = new int[positionCount - defaultCaseCount]; + for (int position = 0; position < positionCount; position++) { + if (TINYINT.getByte(operationChannelBlock, position) != DEFAULT_CASE_OPERATION_NUMBER) { + positions[usedCases] = position; + usedCases++; + } + } + + checkArgument(usedCases + defaultCaseCount == positionCount, "usedCases (%s) + defaultCaseCount (%s) != positionCount (%s)", usedCases, defaultCaseCount, positionCount); + + return result.getPositions(positions, 0, usedCases); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/CompletedWork.java b/presto-main-base/src/main/java/com/facebook/presto/operator/CompletedWork.java index 8b636096bd8e9..12b569c69954d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/CompletedWork.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/CompletedWork.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator; -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/DeleteAndInsertMergeProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/operator/DeleteAndInsertMergeProcessor.java new file mode 100644 index 0000000000000..9f55ea6d574c1 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/DeleteAndInsertMergeProcessor.java @@ -0,0 +1,179 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.ColumnarRow; +import com.facebook.presto.common.type.Type; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +import static com.facebook.presto.common.block.ColumnarRow.toColumnarRow; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static com.facebook.presto.spi.ConnectorMergeSink.DELETE_OPERATION_NUMBER; +import static com.facebook.presto.spi.ConnectorMergeSink.INSERT_OPERATION_NUMBER; +import static com.facebook.presto.spi.ConnectorMergeSink.UPDATE_OPERATION_NUMBER; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static java.util.Objects.requireNonNull; + +public class DeleteAndInsertMergeProcessor + implements MergeRowChangeProcessor +{ + private final List targetColumnTypes; + private final Type rowIdType; + private final int targetTableRowIdChannel; + private final int mergeRowChannel; + private final List targetColumnChannels; + + public DeleteAndInsertMergeProcessor( + List targetColumnTypes, + Type rowIdType, + int targetTableRowIdChannel, + int mergeRowChannel, + List targetColumnChannels) + { + this.targetColumnTypes = requireNonNull(targetColumnTypes, "targetColumnTypes is 
null"); + this.rowIdType = requireNonNull(rowIdType, "rowIdType is null"); + this.targetTableRowIdChannel = targetTableRowIdChannel; + this.mergeRowChannel = mergeRowChannel; + this.targetColumnChannels = requireNonNull(targetColumnChannels, "targetColumnChannels is null"); + } + + @JsonProperty + public Type getRowIdType() + { + return rowIdType; + } + + /** + * Transform UPDATE operations into an INSERT and DELETE operation. + * See {@link MergeRowChangeProcessor#transformPage} for details. + * @param inputPage It has 5 channels/blocks:
+ * 1. Unique ID
+ * 2. Target Table Row ID (_file:varchar, _pos:bigint, partition_spec_id:integer, partition_data:varchar)
+ * 3. Merge Row (source table columns, operation, case number)
+ * 4. Merge case number
+ * 5. Is distinct row: it is 1 if no other row has the same unique id and WHEN clause number, 0 otherwise.
+ */ + @Override + public Page transformPage(Page inputPage) + { + requireNonNull(inputPage, "inputPage is null"); + int inputChannelCount = inputPage.getChannelCount(); + checkArgument(inputChannelCount >= 2, "inputPage channelCount (%s) should be >= 2", inputChannelCount); + + int originalPositionCount = inputPage.getPositionCount(); + checkArgument(originalPositionCount > 0, "originalPositionCount should be > 0, but is %s", originalPositionCount); + + ColumnarRow mergeRow = toColumnarRow(inputPage.getBlock(mergeRowChannel)); + Block operationChannelBlock = mergeRow.getField(mergeRow.getFieldCount() - 2); + + int updatePositions = 0; + int insertPositions = 0; + int deletePositions = 0; + for (int position = 0; position < originalPositionCount; position++) { + byte operation = TINYINT.getByte(operationChannelBlock, position); + switch (operation) { + case DEFAULT_CASE_OPERATION_NUMBER:/* ignored */ + break; + case INSERT_OPERATION_NUMBER: + insertPositions++; + break; + case DELETE_OPERATION_NUMBER: + deletePositions++; + break; + case UPDATE_OPERATION_NUMBER: + updatePositions++; + break; + default: + throw new IllegalArgumentException("Unknown operator number: " + operation); + } + } + + int totalPositions = insertPositions + deletePositions + (2 * updatePositions); + List pageTypes = ImmutableList.builder() + .addAll(targetColumnTypes) + .add(TINYINT) // Operation: INSERT(1), DELETE(2), UPDATE(3). More info: ConnectorMergeSink + .add(rowIdType) + .add(TINYINT) // Insert from update: it is 1 if the cause of the insert is an UPDATE, 0 otherwise. 
+ .build(); + + PageBuilder pageBuilder = new PageBuilder(totalPositions, pageTypes); + for (int position = 0; position < originalPositionCount; position++) { + byte operation = TINYINT.getByte(operationChannelBlock, position); + if (operation != DEFAULT_CASE_OPERATION_NUMBER) { + // Delete and Update because both create a delete row + if (operation == DELETE_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { + addDeleteRow(pageBuilder, inputPage, position); + } + // Insert and update because both create an insert row + if (operation == INSERT_OPERATION_NUMBER || operation == UPDATE_OPERATION_NUMBER) { + addInsertRow(pageBuilder, mergeRow, position, operation == UPDATE_OPERATION_NUMBER); + } + } + } + + Page page = pageBuilder.build(); + verify(page.getPositionCount() == totalPositions, "page positions (%s) is not equal to (%s)", page.getPositionCount(), totalPositions); + return page; + } + + private void addDeleteRow(PageBuilder pageBuilder, Page originalPage, int position) + { + // Delete doesn't care about the data columns. 
+ for (int targetChannel : targetColumnChannels) { + BlockBuilder targetBlock = pageBuilder.getBlockBuilder(targetChannel); + targetBlock.appendNull(); + } + + // Add the operation column == deleted + TINYINT.writeLong(pageBuilder.getBlockBuilder(targetColumnChannels.size()), DELETE_OPERATION_NUMBER); + + // Copy target table row ID column + rowIdType.appendTo(originalPage.getBlock(targetTableRowIdChannel), position, pageBuilder.getBlockBuilder(targetColumnChannels.size() + 1)); + + // Write 0, meaning this row is not an insert derived from an update + TINYINT.writeLong(pageBuilder.getBlockBuilder(targetColumnChannels.size() + 2), 0); + + pageBuilder.declarePosition(); + } + + private void addInsertRow(PageBuilder pageBuilder, ColumnarRow mergeCaseBlock, int position, boolean causedByUpdate) + { + // Copy the values from the merge block + for (int targetChannel : targetColumnChannels) { + Type columnType = targetColumnTypes.get(targetChannel); + BlockBuilder targetBlock = pageBuilder.getBlockBuilder(targetChannel); + // The value comes from that column of the page + columnType.appendTo(mergeCaseBlock.getField(targetChannel), position, targetBlock); + } + + // Add the operation column == insert + TINYINT.writeLong(pageBuilder.getBlockBuilder(targetColumnChannels.size()), INSERT_OPERATION_NUMBER); + + // Add null target table row ID column + pageBuilder.getBlockBuilder(targetColumnChannels.size() + 1).appendNull(); + + // Write 1 if this row is an insert derived from an update, 0 otherwise + TINYINT.writeLong(pageBuilder.getBlockBuilder(targetColumnChannels.size() + 2), causedByUpdate ? 
1 : 0); + + pageBuilder.declarePosition(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/Driver.java b/presto-main-base/src/main/java/com/facebook/presto/operator/Driver.java index 7dbc4e7485687..9cf41c0f563ca 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/Driver.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/Driver.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.Page; import com.facebook.presto.execution.FragmentResultCacheContext; import com.facebook.presto.execution.ScheduledSplit; @@ -30,9 +31,7 @@ import com.google.common.collect.Sets; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.io.Closeable; import java.util.ArrayList; @@ -82,6 +81,7 @@ public class Driver private final Optional sourceOperator; private final Optional deleteOperator; private final Optional updateOperator; + private final Optional mergeOperator; // This variable acts as a staging area. When new splits (encapsulated in TaskSource) are // provided to a Driver, the Driver will not process them right away. 
Instead, the splits are @@ -142,6 +142,7 @@ private Driver(DriverContext driverContext, List operators) Optional sourceOperator = Optional.empty(); Optional deleteOperator = Optional.empty(); Optional updateOperator = Optional.empty(); + Optional mergeOperator = Optional.empty(); for (Operator operator : operators) { if (operator instanceof SourceOperator) { checkArgument(!sourceOperator.isPresent(), "There must be at most one SourceOperator"); @@ -155,10 +156,15 @@ else if (operator instanceof UpdateOperator) { checkArgument(!updateOperator.isPresent(), "There must be at most one UpdateOperator"); updateOperator = Optional.of((UpdateOperator) operator); } + else if (operator instanceof MergeWriterOperator) { + checkArgument(!mergeOperator.isPresent(), "There must be at most one MergeWriterOperator"); + mergeOperator = Optional.of((MergeWriterOperator) operator); + } } this.sourceOperator = sourceOperator; this.deleteOperator = deleteOperator; this.updateOperator = updateOperator; + this.mergeOperator = mergeOperator; currentTaskSource = sourceOperator.map(operator -> new TaskSource(operator.getSourceId(), ImmutableSet.of(), false)).orElse(null); // initially the driverBlockedFuture is not blocked (it is completed) @@ -290,6 +296,7 @@ private void processNewSources() Supplier> pageSource = sourceOperator.addSplit(newSplit); deleteOperator.ifPresent(deleteOperator -> deleteOperator.setPageSource(pageSource)); updateOperator.ifPresent(updateOperator -> updateOperator.setPageSource(pageSource)); + mergeOperator.ifPresent(mergeOperator -> mergeOperator.setPageSource(pageSource)); } // set no more splits diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/DriverContext.java b/presto-main-base/src/main/java/com/facebook/presto/operator/DriverContext.java index 95bac0519e9ef..f6165fcf4449f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/DriverContext.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/operator/DriverContext.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator; import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.execution.FragmentResultCacheContext; import com.facebook.presto.execution.Lifespan; @@ -26,7 +27,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import java.util.List; import java.util.Optional; @@ -37,11 +37,11 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import static com.facebook.airlift.units.Duration.succinctNanos; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.getFirst; import static com.google.common.collect.Iterables.getLast; import static com.google.common.collect.Iterables.transform; -import static io.airlift.units.Duration.succinctNanos; import static java.lang.Math.max; import static java.lang.System.currentTimeMillis; import static java.lang.System.nanoTime; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/DriverStats.java b/presto-main-base/src/main/java/com/facebook/presto/operator/DriverStats.java index 473ae5971e0de..675cef78b4624 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/DriverStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/DriverStats.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; @@ -21,10 +22,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Set; @@ -156,12 +155,9 @@ public DriverStats( this.queuedTime = requireNonNull(queuedTime, "queuedTime is null"); this.elapsedTime = requireNonNull(elapsedTime, "elapsedTime is null"); - checkArgument(userMemoryReservationInBytes >= 0, "userMemoryReservationInBytes is negative"); - this.userMemoryReservationInBytes = userMemoryReservationInBytes; - checkArgument(revocableMemoryReservationInBytes >= 0, "revocableMemoryReservationInBytes is negative"); - this.revocableMemoryReservationInBytes = revocableMemoryReservationInBytes; - checkArgument(systemMemoryReservationInBytes >= 0, "systemMemoryReservationInBytes is negative"); - this.systemMemoryReservationInBytes = systemMemoryReservationInBytes; + this.userMemoryReservationInBytes = (userMemoryReservationInBytes >= 0) ? userMemoryReservationInBytes : Long.MAX_VALUE; + this.revocableMemoryReservationInBytes = (revocableMemoryReservationInBytes >= 0) ? revocableMemoryReservationInBytes : Long.MAX_VALUE; + this.systemMemoryReservationInBytes = (systemMemoryReservationInBytes >= 0) ? systemMemoryReservationInBytes : Long.MAX_VALUE; this.totalScheduledTime = requireNonNull(totalScheduledTime, "totalScheduledTime is null"); this.totalCpuTime = requireNonNull(totalCpuTime, "totalCpuTime is null"); @@ -169,30 +165,22 @@ public DriverStats( this.fullyBlocked = fullyBlocked; this.blockedReasons = ImmutableSet.copyOf(requireNonNull(blockedReasons, "blockedReasons is null")); - checkArgument(totalAllocationInBytes >= 0, "totalAllocationInBytes is negative"); - this.totalAllocationInBytes = totalAllocationInBytes; + this.totalAllocationInBytes = (totalAllocationInBytes >= 0) ? 
totalAllocationInBytes : Long.MAX_VALUE; - checkArgument(rawInputDataSizeInBytes >= 0, "rawInputDataSizeInBytes is negative"); - this.rawInputDataSizeInBytes = rawInputDataSizeInBytes; + this.rawInputDataSizeInBytes = (rawInputDataSizeInBytes >= 0) ? rawInputDataSizeInBytes : Long.MAX_VALUE; - checkArgument(rawInputPositions >= 0, "rawInputPositions is negative"); - this.rawInputPositions = rawInputPositions; + this.rawInputPositions = (rawInputPositions >= 0) ? rawInputPositions : Long.MAX_VALUE; this.rawInputReadTime = requireNonNull(rawInputReadTime, "rawInputReadTime is null"); - checkArgument(processedInputDataSizeInBytes >= 0, "processedInputDataSizeInBytes is negative"); - this.processedInputDataSizeInBytes = processedInputDataSizeInBytes; + this.processedInputDataSizeInBytes = (processedInputDataSizeInBytes >= 0) ? processedInputDataSizeInBytes : Long.MAX_VALUE; - checkArgument(processedInputPositions >= 0, "processedInputPositions is negative"); - this.processedInputPositions = processedInputPositions; + this.processedInputPositions = (processedInputPositions >= 0) ? processedInputPositions : Long.MAX_VALUE; - // An overflow could have occurred on this stat - handle this gracefully. this.outputDataSizeInBytes = (outputDataSizeInBytes >= 0) ? outputDataSizeInBytes : Long.MAX_VALUE; - checkArgument(outputPositions >= 0, "outputPositions is negative"); - this.outputPositions = outputPositions; + this.outputPositions = (outputPositions >= 0) ? outputPositions : Long.MAX_VALUE; - checkArgument(physicalWrittenDataSizeInBytes >= 0, "writtenDataSizeInBytes is negative"); - this.physicalWrittenDataSizeInBytes = physicalWrittenDataSizeInBytes; + this.physicalWrittenDataSizeInBytes = (physicalWrittenDataSizeInBytes >= 0) ? 
physicalWrittenDataSizeInBytes : Long.MAX_VALUE; this.operatorStats = ImmutableList.copyOf(requireNonNull(operatorStats, "operatorStats is null")); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/DriverYieldSignal.java b/presto-main-base/src/main/java/com/facebook/presto/operator/DriverYieldSignal.java index e5a1bbadff17a..37b9c7d19e14d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/DriverYieldSignal.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/DriverYieldSignal.java @@ -14,9 +14,8 @@ package com.facebook.presto.operator; import com.google.common.annotations.VisibleForTesting; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/DynamicFilterSourceOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/DynamicFilterSourceOperator.java index 1f0fd5e9dd5db..d4221e6b1eb9c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/DynamicFilterSourceOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/DynamicFilterSourceOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; @@ -24,9 +25,7 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Optional; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java new file mode 100644 index 0000000000000..bda83ae6319d4 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/EmptyTableFunctionPartition.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; + +import java.util.List; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +/** + * This is a class representing empty input to a table function. 
An EmptyTableFunctionPartition is created + * when the table function has KEEP WHEN EMPTY property, which means that the function should be executed + * even if the input is empty, and all the table arguments are empty relations. + *

+ * An EmptyTableFunctionPartition is created and processed once per node. To avoid duplicated execution, + * a table function having KEEP WHEN EMPTY property must have single distribution. + */ +public class EmptyTableFunctionPartition + implements TableFunctionPartition +{ + private final TableFunctionDataProcessor tableFunction; + private final int properChannelsCount; + private final int passThroughSourcesCount; + private final Type[] passThroughTypes; + + public EmptyTableFunctionPartition(TableFunctionDataProcessor tableFunction, int properChannelsCount, int passThroughSourcesCount, List passThroughTypes) + { + this.tableFunction = requireNonNull(tableFunction, "tableFunction is null"); + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.passThroughTypes = passThroughTypes.toArray(new Type[] {}); + } + + @Override + public WorkProcessor toOutputPages() + { + return WorkProcessor.create(() -> { + TableFunctionProcessorState state = tableFunction.process(null); + if (state == FINISHED) { + return WorkProcessor.ProcessState.finished(); + } + if (state instanceof TableFunctionProcessorState.Blocked) { + return WorkProcessor.ProcessState.blocked(toListenableFuture(((TableFunctionProcessorState.Blocked) state).getFuture())); + } + TableFunctionProcessorState.Processed processed = (TableFunctionProcessorState.Processed) state; + if (processed.getResult() != null) { + return WorkProcessor.ProcessState.ofResult(appendNullsForPassThroughColumns(processed.getResult())); + } + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "When function got no input, it should either produce output or return Blocked state"); + }); + } + + private Page appendNullsForPassThroughColumns(Page page) + { + if (page.getChannelCount() != properChannelsCount + passThroughSourcesCount) { + throw new PrestoException( + FUNCTION_IMPLEMENTATION_ERROR, + format( + "Table function returned a page containing %s channels. 
Expected channel number: %s (%s proper columns, %s pass-through index columns)", + page.getChannelCount(), + properChannelsCount + passThroughSourcesCount, + properChannelsCount, + passThroughSourcesCount)); + } + + Block[] resultBlocks = new Block[properChannelsCount + passThroughTypes.length]; + + // proper outputs first + for (int channel = 0; channel < properChannelsCount; channel++) { + resultBlocks[channel] = page.getBlock(channel); + } + + // pass-through columns next + // because no input was processed, all pass-through indexes in the result page must be null (there are no input rows they could refer to). + // for performance reasons this is not checked. All pass-through columns are filled with nulls. + int channel = properChannelsCount; + for (Type type : passThroughTypes) { + resultBlocks[channel] = RunLengthEncodedBlock.create(type, null, page.getPositionCount()); + channel++; + } + + // pass the position count so that the Page can be successfully created in the case when there are no output channels (resultBlocks is empty) + return new Page(page.getPositionCount(), resultBlocks); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClient.java b/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClient.java index 6da865456582b..77540227a7c16 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClient.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClient.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.TaskId; import com.facebook.presto.memory.context.LocalMemoryContext; import com.facebook.presto.operator.PageBufferClient.ClientCallback; @@ -22,12 +24,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import 
com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.io.Closeable; import java.net.URI; @@ -44,6 +43,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.common.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -51,7 +51,6 @@ import static com.google.common.collect.Sets.newConcurrentHashSet; import static com.google.common.util.concurrent.Futures.immediateFuture; import static io.airlift.slice.Slices.EMPTY_SLICE; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.Math.max; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClientConfig.java b/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClientConfig.java index c80c7bbe3c4db..33ed5473598dc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClientConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClientConfig.java @@ -15,18 +15,17 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.DefunctConfig; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; -import io.airlift.units.MinDataSize; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; -import 
javax.validation.constraints.NotNull; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDataSize; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.util.concurrent.TimeUnit; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.succinctDataSize; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctDataSize; @DefunctConfig("exchange.async-page-transport-enabled") public class ExchangeClientConfig diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClientFactory.java b/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClientFactory.java index 764ebd727fe05..abd07ed290194 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClientFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/ExchangeClientFactory.java @@ -14,22 +14,21 @@ package com.facebook.presto.operator; import com.facebook.airlift.concurrent.ThreadPoolExecutorMBean; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.memory.context.LocalMemoryContext; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.PreDestroy; -import javax.inject.Inject; - import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static 
com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newFixedThreadPool; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/FileFragmentResultCacheConfig.java b/presto-main-base/src/main/java/com/facebook/presto/operator/FileFragmentResultCacheConfig.java index c900ba362f2b9..ebdf0baca862c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/FileFragmentResultCacheConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/FileFragmentResultCacheConfig.java @@ -15,18 +15,17 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDataSize; +import com.facebook.airlift.units.MinDuration; import com.facebook.presto.CompressionCodec; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; -import io.airlift.units.MinDataSize; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; +import jakarta.validation.constraints.Min; import java.net.URI; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.TimeUnit.DAYS; public class FileFragmentResultCacheConfig diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/FileFragmentResultCacheManager.java b/presto-main-base/src/main/java/com/facebook/presto/operator/FileFragmentResultCacheManager.java index 18a4e7d026d62..e04529f84fbb7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/FileFragmentResultCacheManager.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/operator/FileFragmentResultCacheManager.java @@ -28,10 +28,9 @@ import io.airlift.slice.InputStreamSliceInput; import io.airlift.slice.OutputStreamSliceOutput; import io.airlift.slice.SliceOutput; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; -import javax.inject.Inject; - import java.io.Closeable; import java.io.File; import java.io.IOException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/FilterAndProjectOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/FilterAndProjectOperator.java index b06428a276639..1c2cb20363034 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/FilterAndProjectOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/FilterAndProjectOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.function.SqlFunctionProperties; import com.facebook.presto.common.type.Type; @@ -21,7 +22,6 @@ import com.facebook.presto.operator.project.PageProcessor; import com.facebook.presto.spi.plan.PlanNodeId; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.util.List; import java.util.function.Supplier; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/ForExchange.java b/presto-main-base/src/main/java/com/facebook/presto/operator/ForExchange.java index f49adfe1fd4ea..2adbe320d52f6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/ForExchange.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/ForExchange.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/ForScheduler.java b/presto-main-base/src/main/java/com/facebook/presto/operator/ForScheduler.java index 7d33c10082c03..e04fe43abab9f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/ForScheduler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/ForScheduler.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java index 1d4e8cdb19568..92d6f37de0b53 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/HashAggregationOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.block.Block; @@ -35,7 +36,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import java.util.List; import java.util.Optional; @@ -43,6 +43,7 @@ import java.util.OptionalLong; import java.util.stream.Collectors; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.operator.aggregation.builder.InMemoryHashAggregationBuilder.toTypes; import static com.facebook.presto.sql.planner.PlannerUtils.INITIAL_HASH_VALUE; import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE; @@ -50,7 +51,6 @@ import static com.google.common.base.Preconditions.checkState; import static 
com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; public class HashAggregationOperator diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java index 18c4767dc8fbf..6e7917a53cb7b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/HashBuilderOperator.java @@ -26,9 +26,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.io.Closer; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.Nullable; import java.io.IOException; import java.util.ArrayDeque; @@ -41,6 +40,7 @@ import java.util.Queue; import static com.facebook.airlift.concurrent.MoreFutures.getDone; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.ExceededMemoryLimitException.exceededLocalUserMemoryLimit; import static com.facebook.presto.SystemSessionProperties.getQueryMaxMemoryPerNode; import static com.facebook.presto.operator.SpillingUtils.checkSpillSucceeded; @@ -48,7 +48,6 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.util.concurrent.Futures.immediateFuture; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/InterpretedHashGenerator.java 
b/presto-main-base/src/main/java/com/facebook/presto/operator/InterpretedHashGenerator.java index bd14aba4548dc..8efe28ce44c0f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/InterpretedHashGenerator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/InterpretedHashGenerator.java @@ -18,8 +18,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.scalar.CombineHashFunction; import com.facebook.presto.type.TypeUtils; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Arrays; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java b/presto-main-base/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java index c41a5ebd4e171..d4b6995ccb401 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/JoinFilterFunction.java @@ -13,10 +13,9 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; -import javax.annotation.concurrent.NotThreadSafe; - @NotThreadSafe public interface JoinFilterFunction { diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/JoinHash.java b/presto-main-base/src/main/java/com/facebook/presto/operator/JoinHash.java index 057dba50b07b2..d2eca91d09d88 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/JoinHash.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/JoinHash.java @@ -15,10 +15,9 @@ import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Optional; import static java.lang.Math.toIntExact; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/JoinProbe.java b/presto-main-base/src/main/java/com/facebook/presto/operator/JoinProbe.java index 245176d8988e7..872e9a877a8f2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/JoinProbe.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/JoinProbe.java @@ -16,8 +16,7 @@ import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; import com.google.common.primitives.Ints; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.OptionalInt; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java new file mode 100644 index 0000000000000..3eb272cef09ec --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/LeafTableFunctionOperator.java @@ -0,0 +1,205 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.execution.ScheduledSplit; +import com.facebook.presto.metadata.Split; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.UpdatablePageSource; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState.Blocked; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.google.common.util.concurrent.ListenableFuture; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class LeafTableFunctionOperator + implements SourceOperator +{ + public static class LeafTableFunctionOperatorFactory + implements SourceOperatorFactory + { + private final int operatorId; + private final PlanNodeId sourceId; + private final TableFunctionProcessorProvider tableFunctionProvider; + private final ConnectorTableFunctionHandle functionHandle; + private boolean closed; + + public LeafTableFunctionOperatorFactory(int operatorId, PlanNodeId sourceId, TableFunctionProcessorProvider tableFunctionProvider, ConnectorTableFunctionHandle functionHandle) + { + this.operatorId = operatorId; + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.tableFunctionProvider = 
requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public SourceOperator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, sourceId, LeafTableFunctionOperator.class.getSimpleName()); + return new LeafTableFunctionOperator(operatorContext, sourceId, tableFunctionProvider, functionHandle); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + } + + private final OperatorContext operatorContext; + private final PlanNodeId sourceId; + private final TableFunctionProcessorProvider tableFunctionProvider; + private final ConnectorTableFunctionHandle functionHandle; + + private ConnectorSplit currentSplit; + private final List pendingSplits = new ArrayList<>(); + private boolean noMoreSplits; + + private TableFunctionSplitProcessor processor; + private boolean processorUsedData; + private boolean processorFinishedSplit = true; + private ListenableFuture processorBlocked = NOT_BLOCKED; + + public LeafTableFunctionOperator(OperatorContext operatorContext, PlanNodeId sourceId, TableFunctionProcessorProvider tableFunctionProvider, ConnectorTableFunctionHandle functionHandle) + { + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.tableFunctionProvider = requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); + } + + private void resetProcessor() + { + this.processor = tableFunctionProvider.getSplitProcessor(functionHandle); + this.processorUsedData = false; + this.processorFinishedSplit = false; + this.processorBlocked = NOT_BLOCKED; 
+ } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public PlanNodeId getSourceId() + { + return sourceId; + } + + @Override + public boolean needsInput() + { + return false; + } + + @Override + public void addInput(Page page) + { + throw new UnsupportedOperationException(getClass().getName() + " does not take input"); + } + + @Override + public Supplier> addSplit(ScheduledSplit split) + { + Split curSplit = requireNonNull(split, "split is null").getSplit(); + checkState(!noMoreSplits, "no more splits expected"); + ConnectorSplit curConnectorSplit = curSplit.getConnectorSplit(); + pendingSplits.add(curConnectorSplit); + return Optional::empty; + } + + @Override + public void noMoreSplits() + { + noMoreSplits = true; + } + + @Override + public Page getOutput() + { + if (processorFinishedSplit) { + // start processing a new split + if (pendingSplits.isEmpty()) { + // no more splits to process at the moment + return null; + } + currentSplit = pendingSplits.remove(0); + resetProcessor(); + } + else { + // a split is being processed + requireNonNull(currentSplit, "currentSplit is null"); + } + + TableFunctionProcessorState state = processor.process(processorUsedData ? null : currentSplit); + if (state == FINISHED) { + processorFinishedSplit = true; + } + if (state instanceof Blocked) { + Blocked blocked = (Blocked) state; + processorBlocked = toListenableFuture(blocked.getFuture()); + } + if (state instanceof Processed) { + Processed processed = (Processed) state; + if (processed.isUsedInput()) { + processorUsedData = true; + } + if (processed.getResult() != null) { + return processed.getResult(); + } + } + return null; + } + + @Override + public ListenableFuture isBlocked() + { + return processorBlocked; + } + + @Override + public void finish() + { + // this method is redundant. the operator takes no input at all. noMoreSplits() should be called instead. 
+ } + + @Override + public boolean isFinished() + { + return processorFinishedSplit && pendingSplits.isEmpty() && noMoreSplits; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/LookupJoinOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/LookupJoinOperator.java index 276a0b428f760..9a9d8f86e13b9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/LookupJoinOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/LookupJoinOperator.java @@ -26,8 +26,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.io.Closer; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.io.IOException; import java.util.HashMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/LookupJoinOperators.java b/presto-main-base/src/main/java/com/facebook/presto/operator/LookupJoinOperators.java index e3918b42be0fc..634bace39c159 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/LookupJoinOperators.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/LookupJoinOperators.java @@ -19,8 +19,7 @@ import com.facebook.presto.operator.JoinProbe.JoinProbeFactory; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spiller.PartitioningSpillerFactory; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/LookupSource.java b/presto-main-base/src/main/java/com/facebook/presto/operator/LookupSource.java index b69281505fc05..68b50a826c56f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/LookupSource.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/LookupSource.java @@ -13,11 +13,10 @@ */ package com.facebook.presto.operator; +import 
com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; -import javax.annotation.concurrent.NotThreadSafe; - import java.io.Closeable; @NotThreadSafe diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/MergeHashSort.java b/presto-main-base/src/main/java/com/facebook/presto/operator/MergeHashSort.java index 36c3c15738b14..31242de7480be 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/MergeHashSort.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/MergeHashSort.java @@ -18,8 +18,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.memory.context.AggregatedMemoryContext; import com.facebook.presto.util.MergeSortedPages.PageWithPosition; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.function.BiPredicate; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/MergeProcessorOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/MergeProcessorOperator.java new file mode 100644 index 0000000000000..cd1cd35716db2 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/MergeProcessorOperator.java @@ -0,0 +1,172 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.plan.TableWriterNode.MergeParadigmAndTypes; + +import java.util.List; + +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +/** + * This operator is used by operations like SQL MERGE. It is used + * for all {@link com.facebook.presto.spi.connector.RowChangeParadigm}s. This operator + * creates the {@link MergeRowChangeProcessor}. + */ +public class MergeProcessorOperator + implements Operator +{ + public static class MergeProcessorOperatorFactory + implements OperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + private final MergeRowChangeProcessor rowChangeProcessor; + private boolean closed; + + private MergeProcessorOperatorFactory( + int operatorId, + PlanNodeId planNodeId, + MergeRowChangeProcessor rowChangeProcessor) + { + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.rowChangeProcessor = requireNonNull(rowChangeProcessor, "rowChangeProcessor is null"); + } + + public MergeProcessorOperatorFactory( + int operatorId, + PlanNodeId planNodeId, + MergeParadigmAndTypes merge, + int rowIdChannel, + int mergeRowChannel, + List targetColumnChannels) + { + MergeRowChangeProcessor rowChangeProcessor = createRowChangeProcessor(merge, rowIdChannel, mergeRowChannel, targetColumnChannels); + + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.rowChangeProcessor = requireNonNull(rowChangeProcessor, "rowChangeProcessor is null"); + } + + private static MergeRowChangeProcessor createRowChangeProcessor( + MergeParadigmAndTypes merge, + int rowIdChannel, + int mergeRowChannel, 
+ List targetColumnChannels) + { + switch (merge.getParadigm()) { + case DELETE_ROW_AND_INSERT_ROW: + return new DeleteAndInsertMergeProcessor( + merge.getColumnTypes(), + merge.getTargetTableRowIdColumnType(), + rowIdChannel, + mergeRowChannel, + targetColumnChannels); + case CHANGE_ONLY_UPDATED_COLUMNS: + return new ChangeOnlyUpdatedColumnsMergeProcessor( + rowIdChannel, + mergeRowChannel, + targetColumnChannels); + default: + throw new PrestoException(NOT_SUPPORTED, "Merge paradigm not supported: " + merge.getParadigm()); + } + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, MergeProcessorOperator.class.getSimpleName()); + return new MergeProcessorOperator(context, rowChangeProcessor); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new MergeProcessorOperatorFactory(operatorId, planNodeId, rowChangeProcessor); + } + } + + private final OperatorContext operatorContext; + private final MergeRowChangeProcessor rowChangeProcessor; + + private Page currentPage; + private boolean finishing; + + public MergeProcessorOperator( + OperatorContext operatorContext, + MergeRowChangeProcessor rowChangeProcessor) + { + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.rowChangeProcessor = requireNonNull(rowChangeProcessor, "rowChangeProcessor is null"); + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public void finish() + { + finishing = true; + } + + @Override + public boolean isFinished() + { + return finishing && currentPage == null; + } + + @Override + public boolean needsInput() + { + return !finishing && currentPage == null; + } + + @Override + public void addInput(Page page) + { + checkState(!finishing, 
"Operator is already finishing"); + checkState(currentPage == null, "currentPage must be null to add a new page"); + + currentPage = requireNonNull(page, "page is null"); + } + + @Override + public Page getOutput() + { + if (currentPage == null) { + return null; + } + + Page transformedPage = rowChangeProcessor.transformPage(currentPage); + currentPage = null; + + return transformedPage; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/MergeRowChangeProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/operator/MergeRowChangeProcessor.java new file mode 100644 index 0000000000000..dc7208211a53f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/MergeRowChangeProcessor.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.spi.ConnectorMergeSink; + +public interface MergeRowChangeProcessor +{ + int DEFAULT_CASE_OPERATION_NUMBER = -1; + + /** + * Transform a page generated by an SQL MERGE operation into page of data columns and + * operations. The SQL MERGE input page consists of the following: + *

    + *
  • The write redistribution columns, if any
  • + *
  • For partitioned or bucketed tables, a hash value column
  • + *
  • The rowId column for the row from the target table if matched, or null if not matched
  • + *
  • The merge case row block
  • + *
+ * The output page consists of the following: + *
    + *
  • All data columns, in table column order
  • + *
  • {@link ConnectorMergeSink#storeMergedRows The operation block}
  • + *
  • The rowId block
  • + *
  • The last column in the resulting page is 1 if the row is an insert + * derived from an update, and zero otherwise.
  • + *
+ *

+ * The {@link DeleteAndInsertMergeProcessor} implementation will transform each UPDATE + * row into multiple rows: an INSERT row and a DELETE row. + * + * @param inputPage It has 5 channels/blocks:
+ * 1. Unique ID
+ * 2. Target Table Row ID (_file:varchar, _pos:bigint, partition_spec_id:integer, partition_data:varchar)
+ * 3. Merge Row (source table columns, operation, case number)
+ * 4. Merge case number
+ * 5. Is distinct row: it is 1 if no other row has the same unique id and WHEN clause number, 0 otherwise.
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator;

import com.facebook.airlift.json.JsonCodec;
import com.facebook.presto.Session;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.spi.ConnectorMergeSink;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.TableWriterNode.MergeTarget;
import com.facebook.presto.split.PageSinkManager;

import java.util.stream.IntStream;

import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

/**
 * Terminal write-side operator of a MERGE plan: forwards the merged rows
 * produced upstream to the target table's {@link ConnectorMergeSink} and
 * maintains the count of rows affected by the MERGE (inherited {@code rowCount}
 * from {@link AbstractRowChangeOperator}).
 */
public class MergeWriterOperator
        extends AbstractRowChangeOperator
{
    public static class MergeWriterOperatorFactory
            implements OperatorFactory
    {
        private final int operatorId;
        private final PlanNodeId planNodeId;
        private final PageSinkManager pageSinkManager;
        private final MergeTarget target;
        private final Session session;
        // NOTE(review): the generic parameter was stripped in extraction; restored as
        // TableCommitContext to match the AbstractRowChangeOperator constructor
        // argument name -- confirm against that class.
        private final JsonCodec<TableCommitContext> tableCommitContextCodec;
        private boolean closed;

        public MergeWriterOperatorFactory(
                int operatorId,
                PlanNodeId planNodeId,
                PageSinkManager pageSinkManager,
                MergeTarget target,
                Session session,
                JsonCodec<TableCommitContext> tableCommitContextCodec)
        {
            this.operatorId = operatorId;
            this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
            this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null");
            this.target = requireNonNull(target, "target is null");
            this.session = requireNonNull(session, "session is null");
            this.tableCommitContextCodec = requireNonNull(tableCommitContextCodec, "tableCommitContextCodec is null");
        }

        @Override
        public Operator createOperator(DriverContext driverContext)
        {
            checkState(!closed, "Factory is already closed");
            OperatorContext context = driverContext.addOperatorContext(operatorId, planNodeId, MergeWriterOperator.class.getSimpleName());
            // The merge handle must be present by the time the plan is scheduled.
            ConnectorMergeSink mergeSink = pageSinkManager.createMergeSink(session, target.getMergeHandle().get());
            return new MergeWriterOperator(context, mergeSink, tableCommitContextCodec);
        }

        @Override
        public void noMoreOperators()
        {
            closed = true;
        }

        @Override
        public OperatorFactory duplicate()
        {
            return new MergeWriterOperatorFactory(operatorId, planNodeId, pageSinkManager, target, session, tableCommitContextCodec);
        }
    }

    private final ConnectorMergeSink mergeSink;

    public MergeWriterOperator(OperatorContext operatorContext, ConnectorMergeSink mergeSink, JsonCodec<TableCommitContext> tableCommitContextCodec)
    {
        super(operatorContext, tableCommitContextCodec);
        this.mergeSink = requireNonNull(mergeSink, "mergeSink is null");
    }

    /**
     * @param page It has N + 3 channels/blocks, where N is the number of columns in the source table.
     * 1: Source table column 1.
     * 2: Source table column 2.
     * N: Source table column N.
     * N + 1: Operation: INSERT(1), DELETE(2), UPDATE(3). More info: {@link ConnectorMergeSink}
     * N + 2: Target Table Row ID (_file:varchar, _pos:bigint, partition_spec_id:integer, partition_data:varchar).
     * N + 3: Insert from update: it is 1 if the cause of the insert is an UPDATE, 0 otherwise.
     */
    @Override
    public void addInput(Page page)
    {
        requireNonNull(page, "page is null");
        checkState(state == State.RUNNING, "Operator is %s", state);

        // Copy all but the last block to a new page.
        // The last block exists only to get the rowCount right.
        int outputChannelCount = page.getChannelCount() - 1;

        int[] columns = IntStream.range(0, outputChannelCount).toArray();
        Page newPage = page.extractChannels(columns);

        // Store the page
        mergeSink.storeMergedRows(newPage);

        // Calculate the amount to increment the rowCount: an UPDATE is surfaced as a
        // DELETE + INSERT pair, so the extra insert rows are subtracted to avoid
        // double-counting the affected rows.
        Block insertFromUpdateColumn = page.getBlock(page.getChannelCount() - 1);
        long insertsFromUpdates = 0;
        int positionCount = page.getPositionCount();
        for (int position = 0; position < positionCount; position++) {
            insertsFromUpdates += TINYINT.getByte(insertFromUpdateColumn, position);
        }
        rowCount += positionCount - insertsFromUpdates;
    }

    @Override
    public void finish()
    {
        if (state == State.RUNNING) {
            state = State.FINISHING;
            // finish() is asynchronous on the sink; completion is tracked via finishFuture.
            finishFuture = toListenableFuture(mergeSink.finish());
        }
    }

    @Override
    protected void abort()
    {
        mergeSink.abort();
    }
}
index e807c6e84f931..a20f4cf10e3ed 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/NestedLoopJoinPagesBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/NestedLoopJoinPagesBuilder.java @@ -13,17 +13,17 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.operator.project.PageProcessor; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.util.ArrayList; import java.util.List; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.Math.addExact; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/OperationTimer.java b/presto-main-base/src/main/java/com/facebook/presto/operator/OperationTimer.java index 45d76b26b3d97..d1bedb9a06b9b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/OperationTimer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/OperationTimer.java @@ -13,11 +13,10 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.concurrent.NotThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import com.sun.management.ThreadMXBean; -import javax.annotation.concurrent.NotThreadSafe; -import javax.annotation.concurrent.ThreadSafe; - import java.lang.management.ManagementFactory; import java.util.concurrent.atomic.AtomicLong; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorContext.java b/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorContext.java index df29b4f5e6dd5..bb60267c0f0c4 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorContext.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorContext.java @@ -28,10 +28,9 @@ import com.google.common.base.Suppliers; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.util.HashSet; import java.util.Optional; @@ -40,13 +39,13 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import static com.facebook.airlift.units.Duration.succinctNanos; import static com.facebook.presto.operator.BlockedReason.WAITING_FOR_MEMORY; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.Duration.succinctNanos; import static java.lang.Math.max; import static java.lang.String.format; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorInfoUnion.java b/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorInfoUnion.java index 887eca18bec98..abd3c8370f31c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorInfoUnion.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorInfoUnion.java @@ -45,6 +45,12 @@ public class OperatorInfoUnion private short id; + @ThriftConstructor + public OperatorInfoUnion() + { + this.id 
= 0; + } + @ThriftConstructor public OperatorInfoUnion(ExchangeClientStatus exchangeClientStatus) { @@ -251,7 +257,7 @@ else if (infoUnion.getTableWriterMergeInfo() != null) { return infoUnion.getTableWriterMergeInfo(); } else { - throw new IllegalArgumentException("OperatorInfoUnion is of an unknown type: " + infoUnion.getClass().getName()); + return null; } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorMemoryReservationSummary.java b/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorMemoryReservationSummary.java index eb8557f563074..a0687e56f9561 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorMemoryReservationSummary.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorMemoryReservationSummary.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.spi.plan.PlanNodeId; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorStats.java b/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorStats.java index 0208a2b10d49d..71e85bc6db33f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/OperatorStats.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; @@ -21,17 +22,15 @@ import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; import 
com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.Duration; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.util.HashSet; import java.util.List; import java.util.Optional; +import static com.facebook.airlift.units.Duration.succinctNanos; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.Duration.succinctNanos; import static java.lang.Math.max; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -244,7 +243,7 @@ public OperatorStats( this.blockedReason = blockedReason; this.info = info; - this.infoUnion = null; + this.infoUnion = (info != null) ? OperatorInfoUnion.convertToOperatorInfoUnion(info) : null; this.nullJoinBuildKeyCount = nullJoinBuildKeyCount; this.joinBuildKeyCount = joinBuildKeyCount; this.nullJoinProbeKeyCount = nullJoinProbeKeyCount; @@ -391,7 +390,7 @@ public OperatorStats( this.blockedReason = blockedReason; this.infoUnion = infoUnion; - this.info = null; + this.info = (infoUnion != null) ? 
OperatorInfoUnion.convertToOperatorInfo(infoUnion) : null; this.nullJoinBuildKeyCount = nullJoinBuildKeyCount; this.joinBuildKeyCount = joinBuildKeyCount; this.nullJoinProbeKeyCount = nullJoinProbeKeyCount; @@ -669,6 +668,9 @@ public Optional getBlockedReason() @JsonProperty public OperatorInfo getInfo() { + if (info == null && infoUnion != null) { + return OperatorInfoUnion.convertToOperatorInfo(infoUnion); + } return info; } @@ -676,6 +678,9 @@ public OperatorInfo getInfo() @ThriftField(39) public OperatorInfoUnion getInfoUnion() { + if (infoUnion == null && info != null) { + return OperatorInfoUnion.convertToOperatorInfoUnion(info); + } return infoUnion; } @@ -886,7 +891,9 @@ else if (info != null && info.getClass() == base.getClass()) { nullJoinProbeKeyCount += operator.getNullJoinProbeKeyCount(); joinProbeKeyCount += operator.getJoinProbeKeyCount(); } - + if (finishCpu < 0) { + finishCpu = Long.MAX_VALUE; + } return Optional.of(new OperatorStats( stageId, stageExecutionId, diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/OuterLookupSource.java b/presto-main-base/src/main/java/com/facebook/presto/operator/OuterLookupSource.java index 884a8ff9548a0..00542ead8c7f4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/OuterLookupSource.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/OuterLookupSource.java @@ -13,12 +13,11 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.NotThreadSafe; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.function.Supplier; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java 
b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java new file mode 100644 index 0000000000000..ed68df6f8d0c9 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBuffer.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import jakarta.annotation.Nullable; + +import static com.facebook.presto.operator.WorkProcessor.ProcessState.finished; +import static com.facebook.presto.operator.WorkProcessor.ProcessState.ofResult; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class PageBuffer +{ + @Nullable + private Page page; + private boolean finished; + + public WorkProcessor pages() + { + return WorkProcessor.create(() -> { + if (isFinished() && isEmpty()) { + return finished(); + } + + if (!isEmpty()) { + Page result = page; + page = null; + return ofResult(result); + } + + return WorkProcessor.ProcessState.yield(); + }); + } + + public boolean isEmpty() + { + return page == null; + } + + public boolean isFinished() + { + return finished; + } + + public void add(Page page) + { + checkState(isEmpty(), "page buffer is not empty"); + checkState(!isFinished(), "page buffer is finished"); + requireNonNull(page, "page is null"); + + if (page.getPositionCount() == 0) { + return; + } + + this.page = page; + } + + public void finish() + { 
+ finished = true; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PageBufferClient.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBufferClient.java index b7b524608d843..fd3f0839db4a9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PageBufferClient.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PageBufferClient.java @@ -14,6 +14,8 @@ package com.facebook.presto.operator; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.server.remotetask.Backoff; import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.PrestoException; @@ -23,14 +25,11 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import org.apache.http.client.utils.URIBuilder; -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; - import java.io.Closeable; import java.net.URI; import java.net.URISyntaxException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesHash.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesHash.java index 713c99e79047a..c301db4fecb23 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesHash.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesHash.java @@ -13,20 +13,20 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import 
com.facebook.presto.common.array.AdaptiveLongBigArray; -import io.airlift.units.DataSize; import it.unimi.dsi.fastutil.HashCommon; import org.openjdk.jol.info.ClassLayout; import java.util.Arrays; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; import static com.facebook.presto.operator.SyntheticAddress.decodePosition; import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; import static com.facebook.presto.util.HashCollisionsEstimator.estimateNumberOfHashCollisions; import static io.airlift.slice.SizeOf.sizeOf; -import static io.airlift.units.DataSize.Unit.KILOBYTE; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java index bd2f5ffbc817a..b0b82f340b5b3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesIndex.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; @@ -36,13 +37,11 @@ import com.google.common.collect.AbstractIterator; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; -import io.airlift.units.DataSize; import it.unimi.dsi.fastutil.Swapper; import it.unimi.dsi.fastutil.objects.ObjectArrayList; +import jakarta.inject.Inject; import org.openjdk.jol.info.ClassLayout; -import javax.inject.Inject; - import java.util.Arrays; import java.util.Iterator; import java.util.List; @@ -54,6 +53,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static 
com.facebook.presto.operator.SyntheticAddress.decodePosition; import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; import static com.facebook.presto.operator.SyntheticAddress.encodeSyntheticAddress; @@ -61,7 +61,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.SizeOf.sizeOf; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; /** @@ -271,9 +270,12 @@ public void swap(int a, int b) valueAddresses.swap(a, b); } - public int buildPage(int position, int[] outputChannels, PageBuilder pageBuilder) + public int buildPage(int position, int endPosition, int[] outputChannels, PageBuilder pageBuilder) { - while (!pageBuilder.isFull() && position < positionCount) { + // Check both endPosition (for range-based iteration) and positionCount (to handle concurrent clear()). + // If clear() is called while an iterator is consuming pages, positionCount becomes 0, + // allowing the loop to exit gracefully instead of accessing cleared internal arrays. + while (!pageBuilder.isFull() && position < endPosition && position < positionCount) { long pageAddress = valueAddresses.get(position); int blockIndex = decodeSliceIndex(pageAddress); int blockPosition = decodePosition(pageAddress); @@ -563,10 +565,29 @@ protected Page computeNext() } public Iterator getSortedPages() + { + return getSortedPagesFromRange(0, positionCount); + } + + /** + * Get sorted pages from the specified section of the PagesIndex. 
+ * + * @param start start position of the section, inclusive + * @param end end position of the section, exclusive + * @return iterator of pages + */ + public Iterator getSortedPages(int start, int end) + { + checkArgument(start >= 0 && end <= positionCount, "position range out of bounds"); + checkArgument(start <= end, "invalid position range"); + return getSortedPagesFromRange(start, end); + } + + private Iterator getSortedPagesFromRange(int start, int end) { return new AbstractIterator() { - private int currentPosition; + private int currentPosition = start; private final PageBuilder pageBuilder = new PageBuilder(types); private final int[] outputChannels = new int[types.size()]; @@ -577,7 +598,7 @@ public Iterator getSortedPages() @Override public Page computeNext() { - currentPosition = buildPage(currentPosition, outputChannels, pageBuilder); + currentPosition = buildPage(currentPosition, end, outputChannels, pageBuilder); if (pageBuilder.isEmpty()) { return endOfData(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesSpatialIndexFactory.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesSpatialIndexFactory.java index 26efb15fae6c8..a48b3b2e3af88 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesSpatialIndexFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesSpatialIndexFactory.java @@ -17,10 +17,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.util.ArrayList; import java.util.List; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesSpatialIndexSupplier.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesSpatialIndexSupplier.java index 8ec0b2661c321..c57c79705a216 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PagesSpatialIndexSupplier.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PagesSpatialIndexSupplier.java @@ -16,6 +16,7 @@ import com.esri.core.geometry.Operator; import com.esri.core.geometry.OperatorFactoryLocal; import com.esri.core.geometry.ogc.OGCGeometry; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.array.AdaptiveLongBigArray; import com.facebook.presto.common.block.Block; @@ -27,7 +28,6 @@ import com.facebook.presto.operator.SpatialIndexBuilderOperator.SpatialPredicate; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; import io.airlift.slice.Slice; -import io.airlift.units.DataSize; import it.unimi.dsi.fastutil.objects.ObjectArrayList; import org.openjdk.jol.info.ClassLayout; @@ -36,6 +36,7 @@ import java.util.Optional; import java.util.function.Supplier; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.geospatial.GeometryUtils.accelerateGeometry; @@ -44,7 +45,6 @@ import static com.facebook.presto.operator.SyntheticAddress.decodePosition; import static com.facebook.presto.operator.SyntheticAddress.decodeSliceIndex; import static com.google.common.base.Verify.verify; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedConsumption.java 
b/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedConsumption.java index 1461f74717fc6..cea9e48b17d39 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedConsumption.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedConsumption.java @@ -18,10 +18,9 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.util.ArrayDeque; import java.util.Iterator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedLookupSource.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedLookupSource.java index 847460adbe8e0..a58f19abebc71 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedLookupSource.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedLookupSource.java @@ -13,15 +13,14 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.exchange.LocalPartitionGenerator; import com.google.common.io.Closer; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.NotThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.io.IOException; import java.io.UncheckedIOException; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java index 2143c3558ee4e..b8270e9028c90 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PartitionedLookupSourceFactory.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.type.Type; @@ -23,10 +24,8 @@ import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.Immutable; -import javax.annotation.concurrent.NotThreadSafe; +import com.google.errorprone.annotations.Immutable; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.Arrays; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineContext.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineContext.java index 99851036dc02d..e7102570ad9d7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineContext.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineContext.java @@ -26,8 +26,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.ArrayList; import java.util.Arrays; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineStats.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineStats.java index ad5f7212116f6..4ddd5fd1e19ec 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineStats.java @@ -21,9 +21,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Set; @@ -178,16 +177,13 @@ public PipelineStats( this.totalAllocationInBytes = totalAllocationInBytes; this.rawInputDataSizeInBytes = rawInputDataSizeInBytes; - checkArgument(rawInputPositions >= 0, "rawInputPositions is negative"); - this.rawInputPositions = rawInputPositions; + this.rawInputPositions = (rawInputPositions >= 0) ? rawInputPositions : Long.MAX_VALUE; this.processedInputDataSizeInBytes = processedInputDataSizeInBytes; - checkArgument(processedInputPositions >= 0, "processedInputPositions is negative"); - this.processedInputPositions = processedInputPositions; + this.processedInputPositions = (processedInputPositions >= 0) ? processedInputPositions : Long.MAX_VALUE; this.outputDataSizeInBytes = outputDataSizeInBytes; - checkArgument(outputPositions >= 0, "outputPositions is negative"); - this.outputPositions = outputPositions; + this.outputPositions = (outputPositions >= 0) ? 
outputPositions : Long.MAX_VALUE; this.physicalWrittenDataSizeInBytes = physicalWrittenDataSizeInBytes; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineStatus.java b/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineStatus.java index 4f399b6f92ccb..9b2f68fb0fe47 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineStatus.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/PipelineStatus.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.operator; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; @Immutable public final class PipelineStatus diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/ReferenceCount.java b/presto-main-base/src/main/java/com/facebook/presto/operator/ReferenceCount.java index d0168daf345fd..194e292672d1e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/ReferenceCount.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/ReferenceCount.java @@ -15,9 +15,8 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java new file mode 100644 index 0000000000000..5d0376f3be7ee --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/RegularTableFunctionPartition.java @@ -0,0 +1,438 @@ +/* + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.RunLengthEncodedBlock; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.Ints; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.presto.common.Utils.checkState; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static 
com.google.common.util.concurrent.Futures.immediateFuture; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class RegularTableFunctionPartition + implements TableFunctionPartition +{ + private final PagesIndex pagesIndex; + private final int partitionStart; + private final int partitionEnd; + private final Iterator sortedPages; + + private final TableFunctionDataProcessor tableFunction; + private final int properChannelsCount; + private final int passThroughSourcesCount; + + // channels required by the table function, listed by source in order of argument declarations + private final int[][] requiredChannels; + + // for each input channel, the end position of actual data in that channel (exclusive) relative to partition. The remaining rows are "filler" rows, and should not be passed to table function or passed-through + private final int[] endOfData; + + // a builder for each pass-through column, in order of argument declarations + private final PassThroughColumnProvider[] passThroughProviders; + + // number of processed input positions from partition start. all sources have been processed up to this position, except the sources whose partitions ended earlier. 
+ private int processedPositions; + + public RegularTableFunctionPartition( + PagesIndex pagesIndex, + int partitionStart, + int partitionEnd, + TableFunctionDataProcessor tableFunction, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications) + + { + checkArgument(pagesIndex.getPositionCount() != 0, "PagesIndex is empty for regular table function partition"); + this.pagesIndex = pagesIndex; + this.partitionStart = partitionStart; + this.partitionEnd = partitionEnd; + this.sortedPages = pagesIndex.getSortedPages(partitionStart, partitionEnd); + this.tableFunction = requireNonNull(tableFunction, "tableFunction is null"); + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.requiredChannels = requiredChannels.stream() + .map(Ints::toArray) + .toArray(int[][]::new); + this.endOfData = findEndOfData(markerChannels, requiredChannels, passThroughSpecifications); + for (List channels : requiredChannels) { + checkState( + channels.stream() + .mapToInt(channel -> endOfData[channel]) + .distinct() + .count() <= 1, + "end-of-data position is inconsistent within a table function source"); + } + this.passThroughProviders = new PassThroughColumnProvider[passThroughSpecifications.size()]; + for (int i = 0; i < passThroughSpecifications.size(); i++) { + passThroughProviders[i] = createColumnProvider(passThroughSpecifications.get(i)); + } + } + + @Override + public WorkProcessor toOutputPages() + { + return WorkProcessor.create(new WorkProcessor.Process() + { + List> inputPages = prepareInputPages(); + + @Override + public WorkProcessor.ProcessState process() + { + TableFunctionProcessorState state = tableFunction.process(inputPages); + boolean functionGotNoData = inputPages == null; + if (state == FINISHED) { + return WorkProcessor.ProcessState.finished(); + } + if (state instanceof TableFunctionProcessorState.Blocked) { + return 
WorkProcessor.ProcessState.blocked(toListenableFuture(((TableFunctionProcessorState.Blocked) state).getFuture())); + } + TableFunctionProcessorState.Processed processed = (TableFunctionProcessorState.Processed) state; + if (processed.isUsedInput()) { + inputPages = prepareInputPages(); + } + if (processed.getResult() != null) { + return WorkProcessor.ProcessState.ofResult(appendPassThroughColumns(processed.getResult())); + } + if (functionGotNoData) { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "When function got no input, it should either produce output or return Blocked state"); + } + return WorkProcessor.ProcessState.blocked(immediateFuture(null)); + } + }); + } + + /** + * Iterate over the partition by page and extract pages for each table function source from the input page. + * For each source, project the columns required by the table function. + * If for some source all data in the partition has been consumed, Optional.empty() is returned for that source. + * It happens when the partition of this source is shorter than the partition of some other source. + * The overall length of the table function partition is equal to the length of the longest source partition. + * When all sources are fully consumed, this method returns null. + *

+ * NOTE: There are two types of table function's source semantics: set and row. The two types of sources should be handled + * by the TableFunctionDataProcessor in different ways. For a source with set semantics, the whole partition can be used for computations, + * while for a source with row semantics, each row should be processed independently from all other rows. + * To enforce that behavior, we could pass to the TableFunctionDataProcessor only one row from a table with row semantics. + * However, for performance reasons, we handle sources with row and set semantics in the same way: the TableFunctionDataProcessor + * gets a page of data from each source. The TableFunctionDataProcessor is responsible for using the provided data according + * to the declared source semantics (set or row). + * + * @return A List containing: + * - Optional Page for every source that is not fully consumed + * - Optional.empty() for every source that is fully consumed + * or null if all sources are fully consumed. 
+ */ + private List> prepareInputPages() + { + if (!sortedPages.hasNext()) { + return null; + } + + Page inputPage = sortedPages.next(); + ImmutableList.Builder> sourcePages = ImmutableList.builder(); + + for (int[] channelsForSource : requiredChannels) { + if (channelsForSource.length == 0) { + sourcePages.add(Optional.of(new Page(inputPage.getPositionCount()))); + } + else { + int endOfDataForSource = endOfData[channelsForSource[0]]; // end-of-data position is validated to be consistent for all channels from source + if (endOfDataForSource <= processedPositions) { + // all data for this source was already processed + sourcePages.add(Optional.empty()); + } + else { + Block[] sourceBlocks = new Block[channelsForSource.length]; + if (endOfDataForSource < processedPositions + inputPage.getPositionCount()) { + // data for this source ends within the current page + for (int i = 0; i < channelsForSource.length; i++) { + int inputChannel = channelsForSource[i]; + sourceBlocks[i] = inputPage.getBlock(inputChannel).getRegion(0, endOfDataForSource - processedPositions); + } + } + else { + // data for this source does not end within the current page + for (int i = 0; i < channelsForSource.length; i++) { + int inputChannel = channelsForSource[i]; + sourceBlocks[i] = inputPage.getBlock(inputChannel); + } + } + sourcePages.add(Optional.of(new Page(sourceBlocks))); + } + } + } + + processedPositions += inputPage.getPositionCount(); + + return sourcePages.build(); + } + + /** + * There are two types of table function's source semantics: set and row. + *

+ * For a source with set semantics, the table function result depends on the whole partition, + * so it is not always possible to associate an output row with a specific input row. + * The TableFunctionDataProcessor can return null as the pass-through index to indicate that + * the output row is not associated with any row from the given source. + *

+ * For a source with row semantics, the output is determined on a row-by-row basis, so every + * output row is associated with a specific input row. In such a case, the pass-through index + * should never be null. + *

+ * In our implementation, we handle sources with row and set semantics in the same way. + * For performance reasons, we do not validate the null pass-through indexes. + * The TableFunctionDataProcessor is responsible for using the pass-through capability + * accordingly to the declared source semantics (set or rows). + */ + private Page appendPassThroughColumns(Page page) + { + if (page.getChannelCount() != properChannelsCount + passThroughSourcesCount) { + throw new PrestoException( + FUNCTION_IMPLEMENTATION_ERROR, + format( + "Table function returned a page containing %s channels. Expected channel number: %s (%s proper columns, %s pass-through index columns)", + page.getChannelCount(), + properChannelsCount + passThroughSourcesCount, + properChannelsCount, + passThroughSourcesCount)); + } + // TODO is it possible to verify types of columns returned by TF? + + Block[] resultBlocks = new Block[properChannelsCount + passThroughProviders.length]; + + // proper outputs first + for (int channel = 0; channel < properChannelsCount; channel++) { + resultBlocks[channel] = page.getBlock(channel); + } + + // pass-through columns next + int channel = properChannelsCount; + for (PassThroughColumnProvider provider : passThroughProviders) { + resultBlocks[channel] = provider.getPassThroughColumn(page); + channel++; + } + + // pass the position count so that the Page can be successfully created in the case when there are no output channels (resultBlocks is empty) + return new Page(page.getPositionCount(), resultBlocks); + } + + private int[] findEndOfData(Optional> markerChannels, List> requiredChannels, List passThroughSpecifications) + { + Set referencedChannels = ImmutableSet.builder() + .addAll(requiredChannels.stream() + .flatMap(Collection::stream) + .collect(toImmutableList())) + .addAll(passThroughSpecifications.stream() + .map(PassThroughColumnSpecification::getInputChannel) + .collect(toImmutableList())) + .build(); + + if (referencedChannels.isEmpty()) { + // no 
required or pass-through channels + return null; + } + + int maxInputChannel = referencedChannels.stream() + .mapToInt(Integer::intValue) + .max() + .orElseThrow(NoSuchElementException::new); + + int[] result = new int[maxInputChannel + 1]; + Arrays.fill(result, -1); + + // if table function had one source, adding a marker channel was not necessary. + // end-of-data position is equal to partition end for each input channel + if (!markerChannels.isPresent()) { + referencedChannels.stream() + .forEach(channel -> result[channel] = partitionEnd - partitionStart); + return result; + } + + // if table function had more than one source, the markers map shall be present, and it shall contain mapping for each input channel + ImmutableMap.Builder endOfDataPerMarkerBuilder = ImmutableMap.builder(); + for (int markerChannel : ImmutableSet.copyOf(markerChannels.orElseThrow(NoSuchElementException::new).values())) { + endOfDataPerMarkerBuilder.put(markerChannel, findFirstNullPosition(markerChannel)); + } + Map endOfDataPerMarker = endOfDataPerMarkerBuilder.buildOrThrow(); + referencedChannels.stream() + .forEach(channel -> result[channel] = endOfDataPerMarker.get(markerChannels.orElseThrow(NoSuchElementException::new).get(channel)) - partitionStart); + + return result; + } + + private int findFirstNullPosition(int markerChannel) + { + if (pagesIndex.isNull(markerChannel, partitionStart)) { + return partitionStart; + } + if (!pagesIndex.isNull(markerChannel, partitionEnd - 1)) { + return partitionEnd; + } + + int start = partitionStart; + int end = partitionEnd; + // value at start is not null, value at end is null + while (end - start > 1) { + int mid = (start + end) >>> 1; + if (pagesIndex.isNull(markerChannel, mid)) { + end = mid; + } + else { + start = mid; + } + } + return end; + } + + public static class PassThroughColumnSpecification + { + private final boolean isPartitioningColumn; + private final int inputChannel; + private final int indexChannel; + + public 
PassThroughColumnSpecification(boolean isPartitioningColumn, int inputChannel, int indexChannel) + { + this.isPartitioningColumn = isPartitioningColumn; + this.inputChannel = inputChannel; + this.indexChannel = indexChannel; + } + + public boolean isPartitioningColumn() + { + return isPartitioningColumn; + } + + public int getInputChannel() + { + return inputChannel; + } + + public int getIndexChannel() + { + return indexChannel; + } + } + + private PassThroughColumnProvider createColumnProvider(PassThroughColumnSpecification specification) + { + if (specification.isPartitioningColumn()) { + return new PartitioningColumnProvider(pagesIndex.getSingleValueBlock(specification.getInputChannel(), partitionStart)); + } + return new NonPartitioningColumnProvider(specification.getInputChannel(), specification.getIndexChannel()); + } + + private interface PassThroughColumnProvider + { + Block getPassThroughColumn(Page page); + } + + private static class PartitioningColumnProvider + implements PassThroughColumnProvider + { + private final Block partitioningValue; + + private PartitioningColumnProvider(Block partitioningValue) + { + this.partitioningValue = requireNonNull(partitioningValue, "partitioningValue is null"); + } + + @Override + public Block getPassThroughColumn(Page page) + { + return new RunLengthEncodedBlock(partitioningValue, page.getPositionCount()); + } + + public Block getPartitioningValue() + { + return partitioningValue; + } + } + + private final class NonPartitioningColumnProvider + implements PassThroughColumnProvider + { + private final int inputChannel; + private final Type type; + private final int indexChannel; + + public NonPartitioningColumnProvider(int inputChannel, int indexChannel) + { + this.inputChannel = inputChannel; + this.type = pagesIndex.getType(inputChannel); + this.indexChannel = indexChannel; + } + + @Override + public Block getPassThroughColumn(Page page) + { + Block indexes = page.getBlock(indexChannel); + BlockBuilder builder = 
type.createBlockBuilder(null, page.getPositionCount()); + for (int position = 0; position < page.getPositionCount(); position++) { + if (indexes.isNull(position)) { + builder.appendNull(); + } + else { + // table function returns index from partition start + long index = BIGINT.getLong(indexes, position); + // validate index + if (index < 0 || index >= endOfData[inputChannel] || index >= processedPositions) { + int end = min(endOfData[inputChannel], processedPositions) - 1; + if (end >= 0) { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, format("Index of a pass-through row: %s out of processed portion of partition [0, %s]", index, end)); + } + else { + throw new PrestoException(FUNCTION_IMPLEMENTATION_ERROR, "Index of a pass-through row must be null when no input data from the partition was processed. Actual: " + index); + } + } + // index in PagesIndex + long absoluteIndex = partitionStart + index; + pagesIndex.appendTo(inputChannel, toIntExact(absoluteIndex), builder); + } + } + + return builder.build(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/RpcShuffleClient.java b/presto-main-base/src/main/java/com/facebook/presto/operator/RpcShuffleClient.java index b03ea89e042bf..6eadc430f3653 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/RpcShuffleClient.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/RpcShuffleClient.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.operator.PageBufferClient.PagesResponse; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; /** * All methods in this class should be async diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/ScanFilterAndProjectOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/ScanFilterAndProjectOperator.java index 56d24cdb00292..f45d8faccafd4 
100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/ScanFilterAndProjectOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/ScanFilterAndProjectOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.RuntimeStats; @@ -43,7 +44,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; import java.io.Closeable; import java.io.IOException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/SetBuilderOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/SetBuilderOperator.java index f6afd147f3748..4f6a75851e8f8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/SetBuilderOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/SetBuilderOperator.java @@ -22,9 +22,8 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.Nullable; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/SimpleArrayAllocator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/SimpleArrayAllocator.java index be53c2f7b953a..4c917044cf3ec 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/SimpleArrayAllocator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/SimpleArrayAllocator.java @@ -14,10 +14,9 @@ package com.facebook.presto.operator; +import 
com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.block.ArrayAllocator; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.ArrayDeque; import java.util.Deque; import java.util.IdentityHashMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/SpatialJoinOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/SpatialJoinOperator.java index e8970ffd90105..1440ebee07027 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/SpatialJoinOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/SpatialJoinOperator.java @@ -22,15 +22,14 @@ import com.facebook.presto.spi.plan.SpatialJoinNode; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Optional; import static com.facebook.airlift.concurrent.MoreFutures.getDone; -import static com.facebook.presto.spi.plan.SpatialJoinNode.Type.INNER; -import static com.facebook.presto.spi.plan.SpatialJoinNode.Type.LEFT; +import static com.facebook.presto.spi.plan.SpatialJoinNode.SpatialJoinType.INNER; +import static com.facebook.presto.spi.plan.SpatialJoinNode.SpatialJoinType.LEFT; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; @@ -45,7 +44,7 @@ public static final class SpatialJoinOperatorFactory { private final int operatorId; private final PlanNodeId planNodeId; - private final SpatialJoinNode.Type joinType; + private final SpatialJoinNode.SpatialJoinType joinType; private final List probeTypes; private final List probeOutputChannels; private final int probeGeometryChannel; @@ -57,7 +56,7 @@ public static final class SpatialJoinOperatorFactory public SpatialJoinOperatorFactory( int operatorId, 
PlanNodeId planNodeId, - SpatialJoinNode.Type joinType, + SpatialJoinNode.SpatialJoinType joinType, List probeTypes, List probeOutputChannels, int probeGeometryChannel, @@ -115,7 +114,7 @@ public OperatorFactory duplicate() private final OperatorContext operatorContext; private final LocalMemoryContext localUserMemoryContext; - private final SpatialJoinNode.Type joinType; + private final SpatialJoinNode.SpatialJoinType joinType; private final List probeTypes; private final List probeOutputChannels; private final int probeGeometryChannel; @@ -140,7 +139,7 @@ public OperatorFactory duplicate() public SpatialJoinOperator( OperatorContext operatorContext, - SpatialJoinNode.Type joinType, + SpatialJoinNode.SpatialJoinType joinType, List probeTypes, List probeOutputChannels, int probeGeometryChannel, diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/SpilledLookupSourceHandle.java b/presto-main-base/src/main/java/com/facebook/presto/operator/SpilledLookupSourceHandle.java index ea73e82b595a1..65dd66e037f72 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/SpilledLookupSourceHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/SpilledLookupSourceHandle.java @@ -16,10 +16,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.util.function.Supplier; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFinishInfo.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFinishInfo.java index 0362d366676d2..c4038f20146d8 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFinishInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFinishInfo.java @@ -14,19 +14,19 @@ package com.facebook.presto.operator; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.util.Optional; import static com.facebook.airlift.json.JsonCodec.jsonCodec; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFinishOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFinishOperator.java index 80d989c9bdeb4..ddf2c0e102a42 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFinishOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFinishOperator.java @@ -46,6 +46,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import static com.facebook.airlift.units.Duration.succinctNanos; import static com.facebook.presto.SystemSessionProperties.isStatisticsCpuTimerEnabled; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; @@ -61,7 +62,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static 
com.google.common.util.concurrent.Futures.whenAllSucceed; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.Duration.succinctNanos; import static java.lang.Thread.currentThread; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java new file mode 100644 index 0000000000000..947389f342ed5 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionOperator.java @@ -0,0 +1,635 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.memory.context.LocalMemoryContext; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFuture; +import jakarta.annotation.Nullable; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkPositionIndex; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.concat; +import static java.util.Collections.nCopies; +import static java.util.Objects.requireNonNull; + +public class TableFunctionOperator + implements Operator +{ + public static class TableFunctionOperatorFactory + implements OperatorFactory + { + private final int operatorId; + private final PlanNodeId planNodeId; + + // a provider of table function processor to be called once per partition + private final TableFunctionProcessorProvider tableFunctionProvider; + + // all information necessary to execute the table function collected during analysis + private final ConnectorTableFunctionHandle functionHandle; + + // number of proper 
columns produced by the table function + private final int properChannelsCount; + + // number of input tables declared as pass-through + private final int passThroughSourcesCount; + + // columns required by the table function, in order of input tables + private final List> requiredChannels; + + // map from input channel to marker channel + // for each input table, there is a channel that marks which rows contain original data, and which are "filler" rows. + // the "filler" rows are part of the algorithm, and they should not be processed by the table function, or passed-through. + // In this map, every original column from the input table is associated with the appropriate marker. + private final Optional> markerChannels; + + // necessary information to build a pass-through column, for all pass-through columns, ordered as expected on the output + // it includes columns from sources declared as pass-through as well as partitioning columns from other sources + private final List passThroughSpecifications; + + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + + // partitioning channels from all sources + private final List partitionChannels; + + // subset of partition channels that are already grouped + private final List prePartitionedChannels; + + // channels necessary to sort all sources: + // - for a single source, these are the source's sort channels + // - for multiple sources, this is a single synthesized row number channel + private final List sortChannels; + private final List sortOrders; + + // number of leading sort channels that are already sorted + private final int preSortedPrefix; + + private final List sourceTypes; + private final int expectedPositions; + private final PagesIndex.Factory pagesIndexFactory; + + private boolean closed; + + public TableFunctionOperatorFactory( + int 
operatorId, + PlanNodeId planNodeId, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean pruneWhenEmpty, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix, + List sourceTypes, + int expectedPositions, + PagesIndex.Factory pagesIndexFactory) + { + requireNonNull(planNodeId, "planNodeId is null"); + requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(functionHandle, "functionHandle is null"); + requireNonNull(requiredChannels, "requiredChannels is null"); + requireNonNull(markerChannels, "markerChannels is null"); + requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"); + requireNonNull(partitionChannels, "partitionChannels is null"); + requireNonNull(prePartitionedChannels, "prePartitionedChannels is null"); + checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels"); + requireNonNull(sortChannels, "sortChannels is null"); + requireNonNull(sortOrders, "sortOrders is null"); + checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders"); + checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels"); + checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped"); + requireNonNull(sourceTypes, "sourceTypes is null"); + requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + + this.operatorId = operatorId; + this.planNodeId = planNodeId; + 
this.tableFunctionProvider = tableFunctionProvider; + this.functionHandle = functionHandle; + this.properChannelsCount = properChannelsCount; + this.passThroughSourcesCount = passThroughSourcesCount; + this.requiredChannels = requiredChannels.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerChannels = markerChannels.map(ImmutableMap::copyOf); + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + this.pruneWhenEmpty = pruneWhenEmpty; + this.partitionChannels = ImmutableList.copyOf(partitionChannels); + this.prePartitionedChannels = ImmutableList.copyOf(prePartitionedChannels); + this.sortChannels = ImmutableList.copyOf(sortChannels); + this.sortOrders = ImmutableList.copyOf(sortOrders); + this.preSortedPrefix = preSortedPrefix; + this.sourceTypes = ImmutableList.copyOf(sourceTypes); + this.expectedPositions = expectedPositions; + this.pagesIndexFactory = pagesIndexFactory; + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, TableFunctionOperator.class.getSimpleName()); + return new TableFunctionOperator( + operatorContext, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + pruneWhenEmpty, + partitionChannels, + prePartitionedChannels, + sortChannels, + sortOrders, + preSortedPrefix, + sourceTypes, + expectedPositions, + pagesIndexFactory); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new TableFunctionOperatorFactory( + operatorId, + planNodeId, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + pruneWhenEmpty, + 
partitionChannels, + prePartitionedChannels, + sortChannels, + sortOrders, + preSortedPrefix, + sourceTypes, + expectedPositions, + pagesIndexFactory); + } + } + + private final OperatorContext operatorContext; + + private final PageBuffer pageBuffer = new PageBuffer(); + private final WorkProcessor outputPages; + private final boolean processEmptyInput; + + @Nullable + private Page pendingInput; + private boolean operatorFinishing; + + public TableFunctionOperator( + OperatorContext operatorContext, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean pruneWhenEmpty, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix, + List sourceTypes, + int expectedPositions, + PagesIndex.Factory pagesIndexFactory) + { + requireNonNull(operatorContext, "operatorContext is null"); + requireNonNull(tableFunctionProvider, "tableFunctionProvider is null"); + requireNonNull(functionHandle, "functionHandle is null"); + requireNonNull(requiredChannels, "requiredChannels is null"); + requireNonNull(markerChannels, "markerChannels is null"); + requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"); + requireNonNull(partitionChannels, "partitionChannels is null"); + requireNonNull(prePartitionedChannels, "prePartitionedChannels is null"); + checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels"); + requireNonNull(sortChannels, "sortChannels is null"); + requireNonNull(sortOrders, "sortOrders is null"); + checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders"); + checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted 
channels must be lower or equal to the number of sort channels"); + checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped"); + requireNonNull(sourceTypes, "sourceTypes is null"); + requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + + this.operatorContext = operatorContext; + + this.processEmptyInput = !pruneWhenEmpty; + + PagesIndex pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions); + HashStrategies hashStrategies = new HashStrategies(pagesIndex, partitionChannels, prePartitionedChannels, sortChannels, sortOrders, preSortedPrefix); + + this.outputPages = pageBuffer.pages() + .transform(new PartitionAndSort(pagesIndex, hashStrategies, processEmptyInput)) + .flatMap(groupPagesIndex -> pagesIndexToTableFunctionPartitions( + groupPagesIndex, + hashStrategies, + tableFunctionProvider, + functionHandle, + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications, + processEmptyInput)) + .flatMap(TableFunctionPartition::toOutputPages); + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public void finish() + { + pageBuffer.finish(); + } + + @Override + public boolean isFinished() + { + return outputPages.isFinished(); + } + + @Override + public ListenableFuture isBlocked() + { + if (outputPages.isBlocked()) { + return outputPages.getBlockedFuture(); + } + + return NOT_BLOCKED; + } + + @Override + public boolean needsInput() + { + return pageBuffer.isEmpty() && !pageBuffer.isFinished(); + } + + @Override + public void addInput(Page page) + { + pageBuffer.add(page); + } + + @Override + public Page getOutput() + { + if (!outputPages.process()) { + return null; + } + + if (outputPages.isFinished()) { + return null; + } + + return outputPages.getResult(); + 
} + + private static class HashStrategies + { + final PagesHashStrategy prePartitionedStrategy; + final PagesHashStrategy remainingPartitionStrategy; + final PagesHashStrategy preSortedStrategy; + final List remainingPartitionAndSortChannels; + final List remainingSortOrders; + final int[] prePartitionedChannelsArray; + + public HashStrategies( + PagesIndex pagesIndex, + List partitionChannels, + List prePartitionedChannels, + List sortChannels, + List sortOrders, + int preSortedPrefix) + { + this.prePartitionedStrategy = pagesIndex.createPagesHashStrategy(prePartitionedChannels, OptionalInt.empty()); + + List remainingPartitionChannels = partitionChannels.stream() + .filter(channel -> !prePartitionedChannels.contains(channel)) + .collect(toImmutableList()); + this.remainingPartitionStrategy = pagesIndex.createPagesHashStrategy(remainingPartitionChannels, OptionalInt.empty()); + + List preSortedChannels = sortChannels.stream() + .limit(preSortedPrefix) + .collect(toImmutableList()); + this.preSortedStrategy = pagesIndex.createPagesHashStrategy(preSortedChannels, OptionalInt.empty()); + + if (preSortedPrefix > 0) { + // preSortedPrefix > 0 implies that all partition channels are already pre-partitioned (enforced by check in the constructor), so we only need to do the remaining sort + this.remainingPartitionAndSortChannels = ImmutableList.copyOf(Iterables.skip(sortChannels, preSortedPrefix)); + this.remainingSortOrders = ImmutableList.copyOf(Iterables.skip(sortOrders, preSortedPrefix)); + } + else { + // we need to sort by the remaining partition channels so that the input is fully partitioned, + // and then need to we sort by all the sort channels so that the input is fully sorted + this.remainingPartitionAndSortChannels = ImmutableList.copyOf(concat(remainingPartitionChannels, sortChannels)); + this.remainingSortOrders = ImmutableList.copyOf(concat(nCopies(remainingPartitionChannels.size(), ASC_NULLS_LAST), sortOrders)); + } + + this.prePartitionedChannelsArray = 
Ints.toArray(prePartitionedChannels); + } + } + + private class PartitionAndSort + implements WorkProcessor.Transformation + { + private final PagesIndex pagesIndex; + private final HashStrategies hashStrategies; + private final LocalMemoryContext memoryContext; + + private boolean resetPagesIndex; + private int inputPosition; + private boolean processEmptyInput; + + public PartitionAndSort(PagesIndex pagesIndex, HashStrategies hashStrategies, boolean processEmptyInput) + { + this.pagesIndex = pagesIndex; + this.hashStrategies = hashStrategies; + this.memoryContext = operatorContext.aggregateUserMemoryContext().newLocalMemoryContext(PartitionAndSort.class.getSimpleName()); + this.processEmptyInput = processEmptyInput; + } + + @Override + public WorkProcessor.TransformationState process(Optional input) + { + if (resetPagesIndex) { + pagesIndex.clear(); + updateMemoryUsage(); + resetPagesIndex = false; + } + + if (!input.isPresent() && pagesIndex.getPositionCount() == 0) { + if (processEmptyInput) { + // it can only happen at the first call to process(), which implies that there is no input. Empty PagesIndex can be passed on only once. + processEmptyInput = false; + return WorkProcessor.TransformationState.ofResult(pagesIndex, false); + } + else { + memoryContext.close(); + return WorkProcessor.TransformationState.finished(); + } + } + + // there is input, so we are not interested in processing empty input + processEmptyInput = false; + + if (input.isPresent()) { + // append rows from input which belong to the current group wrt pre-partitioned columns + // it might be one or more partitions + inputPosition = appendCurrentGroup(pagesIndex, hashStrategies, input.get(), inputPosition); + updateMemoryUsage(); + + if (inputPosition >= input.get().getPositionCount()) { + inputPosition = 0; + return WorkProcessor.TransformationState.needsMoreData(); + } + } + + // we have unused input or the input is finished. 
we have buffered a full group + // the group contains one or more partitions, as it was determined by the pre-partitioned columns + // sorting serves two purposes: + // - sort by the remaining partition channels so that the input is fully partitioned, + // - sort by all the sort channels so that the input is fully sorted + sortCurrentGroup(pagesIndex, hashStrategies); + resetPagesIndex = true; + return WorkProcessor.TransformationState.ofResult(pagesIndex, false); + } + + void updateMemoryUsage() + { + memoryContext.setBytes(pagesIndex.getEstimatedSize().toBytes()); + } + } + + private static int appendCurrentGroup(PagesIndex pagesIndex, HashStrategies hashStrategies, Page page, int startPosition) + { + checkArgument(page.getPositionCount() > startPosition); + + PagesHashStrategy prePartitionedStrategy = hashStrategies.prePartitionedStrategy; + Page prePartitionedPage = page.extractChannels(hashStrategies.prePartitionedChannelsArray); + + if (pagesIndex.getPositionCount() == 0 || pagesIndex.positionEqualsRow(prePartitionedStrategy, 0, startPosition, prePartitionedPage)) { + // we are within the current group. 
find the position where the pre-grouped columns change + int groupEnd = findGroupEnd(prePartitionedPage, prePartitionedStrategy, startPosition); + + // add the section of the page that contains values for the current group + pagesIndex.addPage(page.getRegion(startPosition, groupEnd - startPosition)); + + if (page.getPositionCount() - groupEnd > 0) { + // the remaining prt of the page contains the next group + return groupEnd; + } + // page fully consumed: it contains the current group only + return page.getPositionCount(); + } + + // we had previous results buffered, but the remaining page starts with new group values + return startPosition; + } + + private static void sortCurrentGroup(PagesIndex pagesIndex, HashStrategies hashStrategies) + { + PagesHashStrategy preSortedStrategy = hashStrategies.preSortedStrategy; + List remainingPartitionAndSortChannels = hashStrategies.remainingPartitionAndSortChannels; + List remainingSortOrders = hashStrategies.remainingSortOrders; + + if (pagesIndex.getPositionCount() > 1 && !remainingPartitionAndSortChannels.isEmpty()) { + int startPosition = 0; + while (startPosition < pagesIndex.getPositionCount()) { + int endPosition = findGroupEnd(pagesIndex, preSortedStrategy, startPosition); + pagesIndex.sort(remainingPartitionAndSortChannels, remainingSortOrders, startPosition, endPosition); + startPosition = endPosition; + } + } + } + + // Assumes input grouped on relevant pagesHashStrategy columns + private static int findGroupEnd(Page page, PagesHashStrategy pagesHashStrategy, int startPosition) + { + checkArgument(page.getPositionCount() > 0, "Must have at least one position"); + checkPositionIndex(startPosition, page.getPositionCount(), "startPosition out of bounds"); + + return findEndPosition(startPosition, page.getPositionCount(), (firstPosition, secondPosition) -> pagesHashStrategy.rowEqualsRow(firstPosition, page, secondPosition, page)); + } + + // Assumes input grouped on relevant pagesHashStrategy columns + private static 
int findGroupEnd(PagesIndex pagesIndex, PagesHashStrategy pagesHashStrategy, int startPosition) + { + checkArgument(pagesIndex.getPositionCount() > 0, "Must have at least one position"); + checkPositionIndex(startPosition, pagesIndex.getPositionCount(), "startPosition out of bounds"); + + return findEndPosition(startPosition, pagesIndex.getPositionCount(), (firstPosition, secondPosition) -> pagesIndex.positionEqualsPosition(pagesHashStrategy, firstPosition, secondPosition)); + } + + /** + * @param startPosition - inclusive + * @param endPosition - exclusive + * @param comparator - returns true if positions given as parameters are equal + * @return the end of the group position exclusive (the position the very next group starts) + */ + @VisibleForTesting + static int findEndPosition(int startPosition, int endPosition, PositionComparator comparator) + { + checkArgument(startPosition >= 0, "startPosition must be greater or equal than zero: %s", startPosition); + checkArgument(startPosition < endPosition, "startPosition (%s) must be less than endPosition (%s)", startPosition, endPosition); + + int left = startPosition; + int right = endPosition; + + while (right - left > 1) { + int middle = (left + right) >>> 1; + + if (comparator.test(startPosition, middle)) { + left = middle; + } + else { + right = middle; + } + } + + return right; + } + + private interface PositionComparator + { + boolean test(int first, int second); + } + + private WorkProcessor pagesIndexToTableFunctionPartitions( + PagesIndex pagesIndex, + HashStrategies hashStrategies, + TableFunctionProcessorProvider tableFunctionProvider, + ConnectorTableFunctionHandle functionHandle, + int properChannelsCount, + int passThroughSourcesCount, + List> requiredChannels, + Optional> markerChannels, + List passThroughSpecifications, + boolean processEmptyInput) + { + // pagesIndex contains the full grouped and sorted data for one or more partitions + + PagesHashStrategy remainingPartitionStrategy = 
hashStrategies.remainingPartitionStrategy; + + return WorkProcessor.create(new WorkProcessor.Process() + { + private int partitionStart; + private boolean processEmpty = processEmptyInput; + + @Override + public WorkProcessor.ProcessState process() + { + if (partitionStart == pagesIndex.getPositionCount()) { + if (processEmpty && pagesIndex.getPositionCount() == 0) { + // empty PagesIndex can only be passed once as the result of PartitionAndSort. Neither this nor any future instance of Process will ever get an empty PagesIndex again. + processEmpty = false; + return WorkProcessor.ProcessState.ofResult(new EmptyTableFunctionPartition( + tableFunctionProvider.getDataProcessor(functionHandle), + properChannelsCount, + passThroughSourcesCount, + passThroughSpecifications.stream() + .map(RegularTableFunctionPartition.PassThroughColumnSpecification::getInputChannel) + .map(pagesIndex::getType) + .collect(toImmutableList()))); + } + return WorkProcessor.ProcessState.finished(); + } + + // there is input, so we are not interested in processing empty input + processEmpty = false; + + int partitionEnd = findGroupEnd(pagesIndex, remainingPartitionStrategy, partitionStart); + + RegularTableFunctionPartition partition = new RegularTableFunctionPartition( + pagesIndex, + partitionStart, + partitionEnd, + tableFunctionProvider.getDataProcessor(functionHandle), + properChannelsCount, + passThroughSourcesCount, + requiredChannels, + markerChannels, + passThroughSpecifications); + + partitionStart = partitionEnd; + return WorkProcessor.ProcessState.ofResult(partition); + } + }); + } + + private class PagesSource + implements WorkProcessor.Process + { + @Override + public WorkProcessor.ProcessState process() + { + if (operatorFinishing && pendingInput == null) { + return WorkProcessor.ProcessState.finished(); + } + + if (pendingInput != null) { + Page result = pendingInput; + pendingInput = null; + return WorkProcessor.ProcessState.ofResult(result); + } + + return 
WorkProcessor.ProcessState.yield(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java new file mode 100644 index 0000000000000..1876b352bd251 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableFunctionPartition.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.presto.common.Page; + +public interface TableFunctionPartition +{ + WorkProcessor toOutputPages(); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterMergeInfo.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterMergeInfo.java index a09362c5494c8..9f5ef5ea0ead6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterMergeInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterMergeInfo.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import 
io.airlift.units.Duration; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterMergeOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterMergeOperator.java index 174392dc56ae5..a34db48c9381c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterMergeOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterMergeOperator.java @@ -32,6 +32,7 @@ import java.util.Queue; import java.util.function.Supplier; +import static com.facebook.airlift.units.Duration.succinctNanos; import static com.facebook.presto.SystemSessionProperties.isStatisticsCpuTimerEnabled; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.operator.TableWriterUtils.CONTEXT_CHANNEL; @@ -43,7 +44,6 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static io.airlift.slice.Slices.wrappedBuffer; -import static io.airlift.units.Duration.succinctNanos; import static java.util.Objects.requireNonNull; public class TableWriterMergeOperator diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterOperator.java index 527785d184d98..781b822dab73d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TableWriterOperator.java @@ -14,29 +14,29 @@ package com.facebook.presto.operator; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; import 
com.facebook.presto.Session; import com.facebook.presto.common.Page; +import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.block.RunLengthEncodedBlock; import com.facebook.presto.common.type.Type; import com.facebook.presto.execution.TaskId; -import com.facebook.presto.execution.TaskMetadataContext; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.CreateHandle; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.ExecuteProcedureHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.InsertHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.RefreshMaterializedViewHandle; import com.facebook.presto.memory.context.LocalMemoryContext; -import com.facebook.presto.metadata.ConnectorMetadataUpdaterManager; import com.facebook.presto.operator.OperationTimer.OperationTiming; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorPageSink; import com.facebook.presto.spi.PageSinkContext; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.connector.ConnectorMetadataUpdater; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.split.PageSinkManager; import com.facebook.presto.util.AutoCloseableCloser; @@ -47,17 +47,16 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; -import io.airlift.units.Duration; import java.util.Collection; import java.util.List; -import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; import static 
com.facebook.airlift.concurrent.MoreFutures.toListenableFuture; +import static com.facebook.airlift.units.Duration.succinctNanos; import static com.facebook.presto.SystemSessionProperties.isStatisticsCpuTimerEnabled; import static com.facebook.presto.common.RuntimeMetricName.WRITTEN_FILES_COUNT; import static com.facebook.presto.common.RuntimeUnit.NONE; @@ -71,7 +70,6 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.allAsList; import static io.airlift.slice.Slices.wrappedBuffer; -import static io.airlift.units.Duration.succinctNanos; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -79,14 +77,13 @@ public class TableWriterOperator implements Operator { public static final String OPERATOR_TYPE = "TableWriterOperator"; + public static class TableWriterOperatorFactory implements OperatorFactory { private final int operatorId; private final PlanNodeId planNodeId; private final PageSinkManager pageSinkManager; - private final ConnectorMetadataUpdaterManager metadataUpdaterManager; - private final TaskMetadataContext taskMetadataContext; private final ExecutionWriterTarget target; private final List columnChannels; private final List notNullChannelColumnNames; @@ -101,8 +98,6 @@ public TableWriterOperatorFactory( int operatorId, PlanNodeId planNodeId, PageSinkManager pageSinkManager, - ConnectorMetadataUpdaterManager metadataUpdaterManager, - TaskMetadataContext taskMetadataContext, ExecutionWriterTarget writerTarget, List columnChannels, List notNullChannelColumnNames, @@ -117,11 +112,12 @@ public TableWriterOperatorFactory( this.columnChannels = requireNonNull(columnChannels, "columnChannels is null"); this.notNullChannelColumnNames = requireNonNull(notNullChannelColumnNames, "notNullChannelColumnNames is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); - this.metadataUpdaterManager = 
requireNonNull(metadataUpdaterManager, "metadataUpdaterManager is null"); - this.taskMetadataContext = requireNonNull(taskMetadataContext, "taskMetadataContext is null"); checkArgument( - writerTarget instanceof CreateHandle || writerTarget instanceof InsertHandle || writerTarget instanceof RefreshMaterializedViewHandle, - "writerTarget must be CreateHandle or InsertHandle or RefreshMaterializedViewHandle"); + writerTarget instanceof CreateHandle || + writerTarget instanceof InsertHandle || + writerTarget instanceof RefreshMaterializedViewHandle || + writerTarget instanceof ExecuteProcedureHandle, + "writerTarget must be CreateHandle or InsertHandle or RefreshMaterializedViewHandle or TableExecuteHandle"); this.target = requireNonNull(writerTarget, "writerTarget is null"); this.session = session; this.statisticsAggregationOperatorFactory = requireNonNull(statisticsAggregationOperatorFactory, "statisticsAggregationOperatorFactory is null"); @@ -139,7 +135,7 @@ public Operator createOperator(DriverContext driverContext) boolean statisticsCpuTimerEnabled = !(statisticsAggregationOperator instanceof DevNullOperator) && isStatisticsCpuTimerEnabled(session); return new TableWriterOperator( context, - createPageSink(), + createPageSink(context), columnChannels, notNullChannelColumnNames, statisticsAggregationOperator, @@ -149,27 +145,24 @@ public Operator createOperator(DriverContext driverContext) pageSinkCommitStrategy); } - private ConnectorPageSink createPageSink() + private ConnectorPageSink createPageSink(OperatorContext operatorContext) { - ConnectorId connectorId = getConnectorId(target); - Optional metadataUpdater = metadataUpdaterManager.getMetadataUpdater(connectorId); - if (metadataUpdater.isPresent()) { - taskMetadataContext.setConnectorId(connectorId); - taskMetadataContext.addMetadataUpdater(metadataUpdater.get()); - } - PageSinkContext.Builder pageSinkContextBuilder = PageSinkContext.builder() .setCommitRequired(pageSinkCommitStrategy.isCommitRequired()); 
- metadataUpdater.ifPresent(pageSinkContextBuilder::setConnectorMetadataUpdater); + + RuntimeStats runtimeStats = operatorContext.getRuntimeStats(); if (target instanceof CreateHandle) { - return pageSinkManager.createPageSink(session, ((CreateHandle) target).getHandle(), pageSinkContextBuilder.build()); + return pageSinkManager.createPageSink(session, ((CreateHandle) target).getHandle(), pageSinkContextBuilder.build(), runtimeStats); } if (target instanceof InsertHandle) { - return pageSinkManager.createPageSink(session, ((InsertHandle) target).getHandle(), pageSinkContextBuilder.build()); + return pageSinkManager.createPageSink(session, ((InsertHandle) target).getHandle(), pageSinkContextBuilder.build(), runtimeStats); } if (target instanceof RefreshMaterializedViewHandle) { - return pageSinkManager.createPageSink(session, ((RefreshMaterializedViewHandle) target).getHandle(), pageSinkContextBuilder.build()); + return pageSinkManager.createPageSink(session, ((RefreshMaterializedViewHandle) target).getHandle(), pageSinkContextBuilder.build(), runtimeStats); + } + if (target instanceof ExecuteProcedureHandle) { + return pageSinkManager.createPageSink(session, ((ExecuteProcedureHandle) target).getHandle(), pageSinkContextBuilder.build()); } throw new UnsupportedOperationException("Unhandled target type: " + target.getClass().getName()); } @@ -188,6 +181,9 @@ private static ConnectorId getConnectorId(ExecutionWriterTarget handle) return ((RefreshMaterializedViewHandle) handle).getHandle().getConnectorId(); } + if (handle instanceof ExecuteProcedureHandle) { + return ((ExecuteProcedureHandle) handle).getHandle().getConnectorId(); + } throw new UnsupportedOperationException("Unhandled target type: " + handle.getClass().getName()); } @@ -204,8 +200,6 @@ public OperatorFactory duplicate() operatorId, planNodeId, pageSinkManager, - metadataUpdaterManager, - taskMetadataContext, target, columnChannels, notNullChannelColumnNames, diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/TaskContext.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TaskContext.java index b75c5ead5bede..03b22aa178655 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/TaskContext.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TaskContext.java @@ -15,11 +15,12 @@ import com.facebook.airlift.stats.CounterStat; import com.facebook.airlift.stats.GcMonitor; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.execution.Lifespan; import com.facebook.presto.execution.TaskId; -import com.facebook.presto.execution.TaskMetadataContext; import com.facebook.presto.execution.TaskState; import com.facebook.presto.execution.TaskStateMachine; import com.facebook.presto.execution.buffer.LazyOutputBuffer; @@ -39,11 +40,8 @@ import com.google.common.collect.ListMultimap; import com.google.common.util.concurrent.AtomicDouble; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.Collection; @@ -55,6 +53,9 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicLong; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.succinctBytes; +import static com.facebook.presto.common.RuntimeUnit.NONE; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.google.common.base.Preconditions.checkArgument; import static 
com.google.common.base.Preconditions.checkState; @@ -62,8 +63,6 @@ import static com.google.common.collect.Iterables.getLast; import static com.google.common.collect.Iterables.transform; import static com.google.common.collect.Sets.newConcurrentHashSet; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Math.max; import static java.lang.Math.toIntExact; import static java.lang.System.currentTimeMillis; @@ -128,8 +127,6 @@ public class TaskContext private final MemoryTrackingContext taskMemoryContext; - private final TaskMetadataContext taskMetadataContext; - private final Optional taskPlan; // Only contains metrics exposed in this task. Doesn't contain the metrics exposed in the operators. @@ -199,7 +196,6 @@ private TaskContext( this.perOperatorAllocationTrackingEnabled = perOperatorAllocationTrackingEnabled; this.allocationTrackingEnabled = allocationTrackingEnabled; this.legacyLifespanCompletionCondition = legacyLifespanCompletionCondition; - this.taskMetadataContext = new TaskMetadataContext(); } // the state change listener is added here in a separate initialize() method @@ -287,11 +283,6 @@ public TaskState getState() return taskStateMachine.getState(); } - public TaskMetadataContext getTaskMetadataContext() - { - return taskMetadataContext; - } - public DataSize getMemoryReservation() { return new DataSize(taskMemoryContext.getUserMemory(), BYTE); @@ -572,12 +563,20 @@ public TaskStats getTaskStats() boolean fullyBlocked = hasRunningPipelines && runningPipelinesFullyBlocked; + // Add createTime and endTime metrics to RuntimeStats to match native execution behavior + long createTimeInMillis = taskStateMachine.getCreatedTimeInMillis(); + long endTimeInMillis = executionEndTime.get(); + mergedRuntimeStats.addMetricValue("createTime", NONE, createTimeInMillis); + if (endTimeInMillis > 0) { + mergedRuntimeStats.addMetricValue("endTime", NONE, endTimeInMillis); + } + return new 
TaskStats( - taskStateMachine.getCreatedTimeInMillis(), + createTimeInMillis, executionStartTime.get(), lastExecutionStartTime.get(), lastExecutionEndTime, - executionEndTime.get(), + endTimeInMillis, elapsedTimeInNanos, queuedTimeInNanos, totalDrivers, @@ -589,6 +588,16 @@ public TaskStats getTaskStats() runningPartitionedSplitsWeight, blockedDrivers, completedDrivers, + // Report driver and split stats separately. Since there's a 1:1 mapping between drivers and splits + // in the Java worker, we can safely reuse the driver stats to represent the split stats. + totalDrivers, + queuedDrivers, + runningDrivers, + completedDrivers, + totalDrivers, + queuedDrivers, + runningDrivers, + completedDrivers, cumulativeUserMemory.get(), cumulativeTotalMemory.get(), userMemory, diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TaskExchangeClientManager.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TaskExchangeClientManager.java index fb98c9d67890b..59608bcfdd94a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/TaskExchangeClientManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TaskExchangeClientManager.java @@ -16,8 +16,7 @@ import com.facebook.presto.memory.context.LocalMemoryContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TaskMemoryReservationSummary.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TaskMemoryReservationSummary.java index b6ce6dcce05d7..94e8c2a76ff10 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/TaskMemoryReservationSummary.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/operator/TaskMemoryReservationSummary.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/TaskStats.java b/presto-main-base/src/main/java/com/facebook/presto/operator/TaskStats.java index 71704074c1ed2..1595e860dc15d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/TaskStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/TaskStats.java @@ -52,6 +52,16 @@ public class TaskStats private final int blockedDrivers; private final int completedDrivers; + private final int totalNewDrivers; + private final int queuedNewDrivers; + private final int runningNewDrivers; + private final int completedNewDrivers; + + private final int totalSplits; + private final int queuedSplits; + private final int runningSplits; + private final int completedSplits; + private final double cumulativeUserMemory; private final double cumulativeTotalMemory; private final long userMemoryReservationInBytes; @@ -107,6 +117,14 @@ public TaskStats(long createTimeInMillis, long endTimeInMillis) 0L, 0, 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, 0.0, 0.0, 0L, @@ -152,6 +170,14 @@ public TaskStats(DateTime createTime, DateTime endTime) 0L, 0, 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, 0.0, 0.0, 0L, @@ -200,6 +226,16 @@ public TaskStats( @JsonProperty("blockedDrivers") int blockedDrivers, @JsonProperty("completedDrivers") int completedDrivers, + @JsonProperty("totalNewDrivers") int totalNewDrivers, + @JsonProperty("queuedNewDrivers") int queuedNewDrivers, + @JsonProperty("runningNewDrivers") int runningNewDrivers, + @JsonProperty("completedNewDrivers") int 
completedNewDrivers, + + @JsonProperty("totalSplits") int totalSplits, + @JsonProperty("queuedSplits") int queuedSplits, + @JsonProperty("runningSplits") int runningSplits, + @JsonProperty("completedSplits") int completedSplits, + @JsonProperty("cumulativeUserMemory") double cumulativeUserMemory, @JsonProperty("cumulativeTotalMemory") double cumulativeTotalMemory, @JsonProperty("userMemoryReservationInBytes") long userMemoryReservationInBytes, @@ -267,6 +303,30 @@ public TaskStats( checkArgument(completedDrivers >= 0, "completedDrivers is negative"); this.completedDrivers = completedDrivers; + checkArgument(totalNewDrivers >= 0, "totalNewDrivers is negative"); + this.totalNewDrivers = totalNewDrivers; + + checkArgument(queuedNewDrivers >= 0, "queuedNewDrivers is negative"); + this.queuedNewDrivers = queuedNewDrivers; + + checkArgument(runningNewDrivers >= 0, "runningNewDrivers is negative"); + this.runningNewDrivers = runningNewDrivers; + + checkArgument(completedNewDrivers >= 0, "completedNewDrivers is negative"); + this.completedNewDrivers = completedNewDrivers; + + checkArgument(totalSplits >= 0, "totalSplits is negative"); + this.totalSplits = totalSplits; + + checkArgument(queuedSplits >= 0, "queuedSplits is negative"); + this.queuedSplits = queuedSplits; + + checkArgument(runningSplits >= 0, "runningSplits is negative"); + this.runningSplits = runningSplits; + + checkArgument(completedSplits >= 0, "completedSplits is negative"); + this.completedSplits = completedSplits; + this.cumulativeUserMemory = cumulativeUserMemory; this.cumulativeTotalMemory = cumulativeTotalMemory; this.userMemoryReservationInBytes = userMemoryReservationInBytes; @@ -286,16 +346,13 @@ public TaskStats( this.totalAllocationInBytes = totalAllocationInBytes; this.rawInputDataSizeInBytes = rawInputDataSizeInBytes; - checkArgument(rawInputPositions >= 0, "rawInputPositions is negative"); - this.rawInputPositions = rawInputPositions; + this.rawInputPositions = (rawInputPositions >= 0) ? 
rawInputPositions : Long.MAX_VALUE; this.processedInputDataSizeInBytes = processedInputDataSizeInBytes; - checkArgument(processedInputPositions >= 0, "processedInputPositions is negative"); - this.processedInputPositions = processedInputPositions; + this.processedInputPositions = (processedInputPositions >= 0) ? processedInputPositions : Long.MAX_VALUE; this.outputDataSizeInBytes = outputDataSizeInBytes; - checkArgument(outputPositions >= 0, "outputPositions is negative"); - this.outputPositions = outputPositions; + this.outputPositions = (outputPositions >= 0) ? outputPositions : Long.MAX_VALUE; this.physicalWrittenDataSizeInBytes = physicalWrittenDataSizeInBytes; @@ -619,6 +676,62 @@ public RuntimeStats getRuntimeStats() return runtimeStats; } + @JsonProperty + @ThriftField(42) + public int getTotalSplits() + { + return totalSplits; + } + + @JsonProperty + @ThriftField(43) + public int getQueuedSplits() + { + return queuedSplits; + } + + @JsonProperty + @ThriftField(44) + public int getRunningSplits() + { + return runningSplits; + } + + @JsonProperty + @ThriftField(45) + public int getCompletedSplits() + { + return completedSplits; + } + + @JsonProperty + @ThriftField(46) + public int getTotalNewDrivers() + { + return totalNewDrivers; + } + + @JsonProperty + @ThriftField(47) + public int getQueuedNewDrivers() + { + return queuedNewDrivers; + } + + @JsonProperty + @ThriftField(48) + public int getRunningNewDrivers() + { + return runningNewDrivers; + } + + @JsonProperty + @ThriftField(49) + public int getCompletedNewDrivers() + { + return completedNewDrivers; + } + public TaskStats summarize() { return new TaskStats( @@ -638,6 +751,14 @@ public TaskStats summarize() runningPartitionedSplitsWeight, blockedDrivers, completedDrivers, + totalNewDrivers, + queuedNewDrivers, + runningNewDrivers, + completedNewDrivers, + totalSplits, + queuedSplits, + runningSplits, + completedSplits, cumulativeUserMemory, cumulativeTotalMemory, userMemoryReservationInBytes, @@ -684,6 
+805,14 @@ public TaskStats summarizeFinal() runningPartitionedSplitsWeight, blockedDrivers, completedDrivers, + totalNewDrivers, + queuedNewDrivers, + runningNewDrivers, + completedNewDrivers, + totalSplits, + queuedSplits, + runningSplits, + completedSplits, cumulativeUserMemory, cumulativeTotalMemory, userMemoryReservationInBytes, diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/UncheckedStackArrayAllocator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/UncheckedStackArrayAllocator.java index fa3a79a0e898d..ea965e784f917 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/UncheckedStackArrayAllocator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/UncheckedStackArrayAllocator.java @@ -14,10 +14,9 @@ package com.facebook.presto.operator; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.block.ArrayAllocator; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.Arrays; import static com.google.common.base.MoreObjects.toStringHelper; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/WindowInfo.java b/presto-main-base/src/main/java/com/facebook/presto/operator/WindowInfo.java index e1eaba05dc5b5..06e651dc96829 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/WindowInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/WindowInfo.java @@ -21,8 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/WindowOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/WindowOperator.java index 6688002d64dfe..1a63ac5d89c7b 
100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/WindowOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/WindowOperator.java @@ -36,8 +36,7 @@ import com.google.common.collect.PeekingIterator; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/WorkProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/operator/WorkProcessor.java index e7c40697dd442..5c1ab5431961f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/WorkProcessor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/WorkProcessor.java @@ -15,8 +15,7 @@ import com.google.common.collect.Iterators; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Comparator; import java.util.Iterator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/WorkProcessorUtils.java b/presto-main-base/src/main/java/com/facebook/presto/operator/WorkProcessorUtils.java index 749472513db2a..ec4eed5c4cbc4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/WorkProcessorUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/WorkProcessorUtils.java @@ -18,8 +18,7 @@ import com.facebook.presto.operator.WorkProcessor.TransformationState; import com.google.common.collect.AbstractIterator; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Comparator; import java.util.Iterator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationFromAnnotationsParser.java 
b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationFromAnnotationsParser.java index dfb1899359d12..696303dcb014c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationFromAnnotationsParser.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationFromAnnotationsParser.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.annotations.FunctionsParserHelper; @@ -25,8 +26,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Method; @@ -35,6 +35,7 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.operator.aggregation.AggregationImplementation.Parser.parseImplementation; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseDescription; import static com.google.common.base.Preconditions.checkArgument; @@ -55,7 +56,7 @@ public static ParametricAggregation parseFunctionDefinitionWithTypesConstraint(C { requireNonNull(returnType, "returnType is null"); requireNonNull(argumentTypes, "argumentTypes is null"); - for (ParametricAggregation aggregation : parseFunctionDefinitions(clazz)) { + for (ParametricAggregation aggregation : parseFunctionDefinitions(clazz, JAVA_BUILTIN_NAMESPACE)) { if (aggregation.getSignature().getReturnType().equals(returnType) && aggregation.getSignature().getArgumentTypes().equals(argumentTypes)) { 
return aggregation; @@ -65,6 +66,11 @@ public static ParametricAggregation parseFunctionDefinitionWithTypesConstraint(C } public static List parseFunctionDefinitions(Class aggregationDefinition) + { + return parseFunctionDefinitions(aggregationDefinition, JAVA_BUILTIN_NAMESPACE); + } + + public static List parseFunctionDefinitions(Class aggregationDefinition, CatalogSchemaName functionNamespace) { AggregationFunction aggregationAnnotation = aggregationDefinition.getAnnotation(AggregationFunction.class); requireNonNull(aggregationAnnotation, "aggregationAnnotation is null"); @@ -77,7 +83,7 @@ public static List parseFunctionDefinitions(Class aggr for (Method outputFunction : getOutputFunctions(aggregationDefinition, stateClass)) { for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) { for (AggregationHeader header : parseHeaders(aggregationDefinition, outputFunction)) { - AggregationImplementation onlyImplementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory); + AggregationImplementation onlyImplementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory, functionNamespace); ParametricImplementationsGroup implementations = ParametricImplementationsGroup.of(onlyImplementation); builder.add(new ParametricAggregation(implementations.getSignature(), header, implementations)); } @@ -98,7 +104,7 @@ public static ParametricAggregation parseFunctionDefinition(Class aggregation Optional aggregationStateSerializerFactory = getAggregationStateSerializerFactory(aggregationDefinition, stateClass); Method outputFunction = getOnlyElement(getOutputFunctions(aggregationDefinition, stateClass)); for (Method inputFunction : getInputFunctions(aggregationDefinition, stateClass)) { - AggregationImplementation implementation = 
parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory); + AggregationImplementation implementation = parseImplementation(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, aggregationStateSerializerFactory, JAVA_BUILTIN_NAMESPACE); implementationsBuilder.addImplementation(implementation); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java index 8cb42304cd67f..5ffbd15f4d3de 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationImplementation.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.function.SqlFunctionProperties; @@ -46,7 +47,6 @@ import java.util.stream.Stream; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.containsAnnotation; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.createTypeVariableConstraints; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseLiteralParameters; @@ -246,6 +246,7 @@ public static final class Parser private final AggregationHeader header; private final Set literalParameters; private final List typeParameters; + private final CatalogSchemaName functionNamespace; private Parser( Class 
aggregationDefinition, @@ -254,7 +255,8 @@ private Parser( Method inputFunction, Method outputFunction, Method combineFunction, - Optional stateSerializerFactoryFunction) + Optional stateSerializerFactoryFunction, + CatalogSchemaName functionNamespace) { // rewrite data passed directly this.aggregationDefinition = aggregationDefinition; @@ -301,12 +303,13 @@ private Parser( inputHandle = methodHandle(inputFunction); combineHandle = methodHandle(combineFunction); outputHandle = methodHandle(outputFunction); + this.functionNamespace = requireNonNull(functionNamespace, "functionNamespace is null"); } private AggregationImplementation get() { Signature signature = new Signature( - QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, header.getName()), + QualifiedObjectName.valueOf(functionNamespace, header.getName()), FunctionKind.AGGREGATE, typeVariableConstraints, longVariableConstraints, @@ -336,9 +339,10 @@ public static AggregationImplementation parseImplementation( Method inputFunction, Method outputFunction, Method combineFunction, - Optional stateSerializerFactoryFunction) + Optional stateSerializerFactoryFunction, + CatalogSchemaName functionNamespace) { - return new Parser(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, stateSerializerFactoryFunction).get(); + return new Parser(aggregationDefinition, header, stateClass, inputFunction, outputFunction, combineFunction, stateSerializerFactoryFunction, functionNamespace).get(); } private static List parseParameterMetadataTypes(Method method) diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java index 578e782bc8d91..c78186ebb2417 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/AggregationUtils.java @@ -22,6 +22,7 @@ import com.facebook.presto.operator.aggregation.state.CentralMomentsState; import com.facebook.presto.operator.aggregation.state.CorrelationState; import com.facebook.presto.operator.aggregation.state.CovarianceState; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; import com.facebook.presto.operator.aggregation.state.RegressionState; import com.facebook.presto.operator.aggregation.state.VarianceState; import com.facebook.presto.spi.function.AggregationFunctionImplementation; @@ -145,9 +146,14 @@ public static double getCorrelation(CorrelationState state) public static void updateRegressionState(RegressionState state, double x, double y) { double oldMeanX = state.getMeanX(); - double oldMeanY = state.getMeanY(); updateCovarianceState(state, x, y); state.setM2X(state.getM2X() + (x - oldMeanX) * (x - state.getMeanX())); + } + + public static void updateExtendedRegressionState(ExtendedRegressionState state, double x, double y) + { + double oldMeanY = state.getMeanY(); + updateRegressionState(state, x, y); state.setM2Y(state.getM2Y() + (y - oldMeanY) * (y - state.getMeanY())); } @@ -189,12 +195,12 @@ public static double getRegressionSxy(RegressionState state) return state.getC2(); } - public static double getRegressionSyy(RegressionState state) + public static double getRegressionSyy(ExtendedRegressionState state) { return state.getM2Y(); } - public static double getRegressionR2(RegressionState state) + public static double getRegressionR2(ExtendedRegressionState state) { if (state.getM2X() != 0 && state.getM2Y() == 0) { return 1.0; @@ -311,10 +317,21 @@ public static void mergeRegressionState(RegressionState state, RegressionState o long na = state.getCount(); long nb = otherState.getCount(); state.setM2X(state.getM2X() + otherState.getM2X() + na * nb * Math.pow(state.getMeanX() - otherState.getMeanX(), 2) / (double) (na + 
nb)); - state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); updateCovarianceState(state, otherState); } + public static void mergeExtendedRegressionState(ExtendedRegressionState state, ExtendedRegressionState otherState) + { + if (otherState.getCount() == 0) { + return; + } + + long na = state.getCount(); + long nb = otherState.getCount(); + state.setM2Y(state.getM2Y() + otherState.getM2Y() + na * nb * Math.pow(state.getMeanY() - otherState.getMeanY(), 2) / (double) (na + nb)); + mergeRegressionState(state, otherState); + } + public static String generateAggregationName(String baseName, TypeSignature outputType, List inputTypes) { StringBuilder sb = new StringBuilder(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java index 24d1c6e61fcf5..db3ad26ec5d6d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionAggregation.java @@ -24,15 +24,8 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.DoubleType.DOUBLE; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; -import static 
com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeRegressionState; import static com.facebook.presto.operator.aggregation.AggregationUtils.updateRegressionState; @@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } - - @AggregationFunction("regr_sxy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_sxx") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_syy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSyy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_r2") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionR2(state); - if 
(Double.isFinite(result)) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_count") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionCount(state); - if (Double.isFinite(result) && result > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgy") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgx") - @OutputFunction(StandardTypes.DOUBLE) - public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - DOUBLE.writeDouble(out, result); - } - else { - out.appendNull(); - } - } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java new file mode 100644 index 0000000000000..3550cd0936949 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/DoubleRegressionExtendedAggregation.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; + +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.mergeExtendedRegressionState; +import static com.facebook.presto.operator.aggregation.AggregationUtils.updateExtendedRegressionState; + +@AggregationFunction +public class 
DoubleRegressionExtendedAggregation +{ + private DoubleRegressionExtendedAggregation() {} + + @InputFunction + public static void input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.DOUBLE) double dependentValue, @SqlType(StandardTypes.DOUBLE) double independentValue) + { + updateExtendedRegressionState(state, independentValue, dependentValue); + } + + @CombineFunction + public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState) + { + mergeExtendedRegressionState(state, otherState); + } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { 
+ double result = getRegressionR2(state); + if (Double.isFinite(result)) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result) && result > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.DOUBLE) + public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + DOUBLE.writeDouble(out, result); + } + else { + out.appendNull(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactory.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactory.java index d83db731c3588..4f3ac2e2e15cb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/GenericAccumulatorFactory.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import 
com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; @@ -54,11 +55,9 @@ import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; import io.airlift.slice.Slice; -import io.airlift.units.DataSize; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/NumericHistogram.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/NumericHistogram.java index 9e5f17426faa0..b89d9c4cbe918 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/NumericHistogram.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/NumericHistogram.java @@ -184,7 +184,7 @@ private static PriorityQueue mergeBuckets(double[] values, double[] weigh Entry right = current.getRight(); - // right is guaranteed to exist because we set the penalty of the last bucket to infinity + // right is guaranteed to exist because we set the penalty of the last bucket to NaN // so the first current in the queue can never be the last bucket checkState(right != null, "Expected right to be != null"); checkState(right.isValid(), "Expected right to be valid"); @@ -250,7 +250,7 @@ private static int mergeSameBuckets(double[] values, double[] weights, int nextI int current = 0; for (int i = 1; i < nextIndex; i++) { - if (values[current] == values[i]) { + if (values[current] == values[i] || (Double.isNaN(values[current]) && Double.isNaN(values[i]))) { weights[current] += weights[i]; } else { @@ -311,11 +311,11 @@ public void swap(int a, int b) }); } - private static double computePenalty(double value1, double value2, double weight1, double weight2) + private static double computePenalty(double value1, double weight1, double value2, 
double weight2) { - double weight = value2 + weight2; - double squaredDifference = (value1 - weight1) * (value1 - weight1); - double proportionsProduct = (value2 * weight2) / ((value2 + weight2) * (value2 + weight2)); + double weight = weight1 + weight2; + double squaredDifference = (value1 - value2) * (value1 - value2); + double proportionsProduct = (weight1 * weight2) / ((weight1 + weight2) * (weight1 + weight2)); return weight * squaredDifference * proportionsProduct; } @@ -350,7 +350,7 @@ private Entry(int id, double value, double weight, Entry left, Entry right) penalty = computePenalty(value, weight, right.value, right.weight); } else { - penalty = Double.POSITIVE_INFINITY; + penalty = Double.NaN; } if (left != null) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java index 1fe5d006da1a9..a75222bfa93c4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionAggregation.java @@ -24,15 +24,8 @@ import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.common.type.RealType.REAL; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionIntercept; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSlope; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; -import 
static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; -import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; import static java.lang.Float.floatToRawIntBits; import static java.lang.Float.intBitsToFloat; @@ -78,100 +71,4 @@ public static void regrIntercept(@AggregationState RegressionState state, BlockB out.appendNull(); } } - - @AggregationFunction("regr_sxy") - @OutputFunction(StandardTypes.REAL) - public static void regrSxy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_sxx") - @OutputFunction(StandardTypes.REAL) - public static void regrSxx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSxx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_syy") - @OutputFunction(StandardTypes.REAL) - public static void regrSyy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionSyy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_r2") - @OutputFunction(StandardTypes.REAL) - public static void regrR2(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionR2(state); - if (Double.isFinite(result)) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - 
} - } - - @AggregationFunction("regr_count") - @OutputFunction(StandardTypes.REAL) - public static void regrCount(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionCount(state); - if (Double.isFinite(result) && result > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgy") - @OutputFunction(StandardTypes.REAL) - public static void regrAvgy(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgy(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } - - @AggregationFunction("regr_avgx") - @OutputFunction(StandardTypes.REAL) - public static void regrAvgx(@AggregationState RegressionState state, BlockBuilder out) - { - double result = getRegressionAvgx(state); - double count = getRegressionCount(state); - if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { - REAL.writeLong(out, floatToRawIntBits((float) result)); - } - else { - out.appendNull(); - } - } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java new file mode 100644 index 0000000000000..2d0335ae9aca6 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/RealRegressionExtendedAggregation.java @@ -0,0 +1,149 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation; + +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.operator.aggregation.state.ExtendedRegressionState; +import com.facebook.presto.spi.function.AggregationFunction; +import com.facebook.presto.spi.function.AggregationState; +import com.facebook.presto.spi.function.CombineFunction; +import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.OutputFunction; +import com.facebook.presto.spi.function.SqlType; + +import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionAvgy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionCount; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionR2; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxx; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSxy; +import static com.facebook.presto.operator.aggregation.AggregationUtils.getRegressionSyy; +import static java.lang.Float.floatToRawIntBits; +import static java.lang.Float.intBitsToFloat; + +@AggregationFunction +public class RealRegressionExtendedAggregation +{ + private RealRegressionExtendedAggregation() {} + + @InputFunction + public static void 
input(@AggregationState ExtendedRegressionState state, @SqlType(StandardTypes.REAL) long dependentValue, @SqlType(StandardTypes.REAL) long independentValue) + { + DoubleRegressionExtendedAggregation.input(state, intBitsToFloat((int) dependentValue), intBitsToFloat((int) independentValue)); + } + + @CombineFunction + public static void combine(@AggregationState ExtendedRegressionState state, @AggregationState ExtendedRegressionState otherState) + { + DoubleRegressionExtendedAggregation.combine(state, otherState); + } + + @AggregationFunction("regr_sxy") + @OutputFunction(StandardTypes.REAL) + public static void regrSxy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_sxx") + @OutputFunction(StandardTypes.REAL) + public static void regrSxx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSxx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_syy") + @OutputFunction(StandardTypes.REAL) + public static void regrSyy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionSyy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_r2") + @OutputFunction(StandardTypes.REAL) + public static void regrR2(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + 
double result = getRegressionR2(state); + if (Double.isFinite(result)) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_count") + @OutputFunction(StandardTypes.REAL) + public static void regrCount(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionCount(state); + if (Double.isFinite(result) && result > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgy") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgy(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgy(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } + + @AggregationFunction("regr_avgx") + @OutputFunction(StandardTypes.REAL) + public static void regrAvgx(@AggregationState ExtendedRegressionState state, BlockBuilder out) + { + double result = getRegressionAvgx(state); + double count = getRegressionCount(state); + if (Double.isFinite(result) && Double.isFinite(count) && count > 0) { + REAL.writeLong(out, floatToRawIntBits((float) result)); + } + else { + out.appendNull(); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java index 0ddd89e74779d..0dca29f836107 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/TypedSet.java @@ -13,18 +13,19 @@ */ package com.facebook.presto.operator.aggregation; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.block.Block; 
import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.PrestoException; import com.google.common.annotations.VisibleForTesting; -import io.airlift.units.DataSize; import it.unimi.dsi.fastutil.ints.IntArrayList; import org.openjdk.jol.info.ClassLayout; import java.lang.invoke.MethodHandle; import java.util.Optional; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.common.type.TypeUtils.readNativeValue; import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_FUNCTION_MEMORY_LIMIT; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; @@ -33,7 +34,6 @@ import static com.facebook.presto.util.Failures.internalError; import static com.google.common.base.Defaults.defaultValue; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static it.unimi.dsi.fastutil.HashCommon.arraySize; import static java.lang.String.format; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java index b66343ae40bdd..127eac1555b03 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/InMemoryHashAggregationBuilder.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.aggregation.builder; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.block.BlockBuilder; @@ -37,7 +38,6 @@ import com.google.common.collect.ImmutableList; import 
com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import it.unimi.dsi.fastutil.ints.IntIterator; import it.unimi.dsi.fastutil.ints.IntIterators; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java index 302692c0a9525..85bb894a556f3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/MergingHashAggregationBuilder.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.aggregation.builder; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.Type; import com.facebook.presto.memory.context.LocalMemoryContext; @@ -25,7 +26,6 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.sql.gen.JoinCompiler; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.io.Closeable; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java index ba258cff0d8e3..c3d0bf74fc066 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/builder/SpillableHashAggregationBuilder.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.aggregation.builder; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import 
com.facebook.presto.common.type.Type; import com.facebook.presto.memory.context.LocalMemoryContext; @@ -29,7 +30,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.io.Closer; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import java.io.IOException; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SfmSketch.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SfmSketch.java index 891d538c9544f..2a82520a9a19a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SfmSketch.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/noisyaggregation/sketch/SfmSketch.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.aggregation.noisyaggregation.sketch; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.airlift.slice.BasicSliceInput; @@ -22,8 +23,6 @@ import io.airlift.slice.Slice; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.concurrent.NotThreadSafe; - import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/partial/PartialAggregationController.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/partial/PartialAggregationController.java index 917afe5dac25f..a4d159ef9b3e3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/partial/PartialAggregationController.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/partial/PartialAggregationController.java @@ -13,7 +13,7 @@ */ package 
com.facebook.presto.operator.aggregation.partial; -import io.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize; import java.util.OptionalLong; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/partial/SkipAggregationBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/partial/SkipAggregationBuilder.java index 28ef06478e558..6532b6d0df87e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/partial/SkipAggregationBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/partial/SkipAggregationBuilder.java @@ -28,8 +28,7 @@ import com.facebook.presto.spi.function.aggregation.GroupByIdBlock; import com.facebook.presto.spi.function.aggregation.GroupedAccumulator; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSample.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSample.java index 8728f01f6b084..d2c7c63088f80 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSample.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/reservoirsample/ReservoirSample.java @@ -19,10 +19,9 @@ import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.Type; import io.airlift.slice.SizeOf; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.ArrayList; import java.util.Collections; import java.util.concurrent.ThreadLocalRandom; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchAggregationState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchAggregationState.java index 0aa8c20c7780e..dc1b0d2b3f3dc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchAggregationState.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/sketch/kll/KllSketchAggregationState.java @@ -24,6 +24,7 @@ import com.facebook.presto.spi.function.AccumulatorStateMetadata; import com.google.common.collect.ImmutableMap; import io.airlift.slice.Slice; +import jakarta.annotation.Nullable; import org.apache.datasketches.common.ArrayOfBooleansSerDe; import org.apache.datasketches.common.ArrayOfDoublesSerDe; import org.apache.datasketches.common.ArrayOfItemsSerDe; @@ -32,8 +33,6 @@ import org.apache.datasketches.kll.KllItemsSketch; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.util.Comparator; import java.util.Map; import java.util.function.Function; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java new file mode 100644 index 0000000000000..64a9883174158 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/ExtendedRegressionState.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.aggregation.state; + +public interface ExtendedRegressionState + extends RegressionState +{ + double getM2Y(); + + void setM2Y(double value); +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java index c6ac99e042149..1d73576e4bffd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/LongDecimalWithOverflowStateFactory.java @@ -17,10 +17,9 @@ import com.facebook.presto.common.array.ObjectBigArray; import com.facebook.presto.spi.function.AccumulatorStateFactory; import io.airlift.slice.Slice; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java index 79837f90c0c11..ae3af6f46dc43 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/operator/aggregation/state/RegressionState.java @@ -19,8 +19,4 @@ public interface RegressionState double getM2X(); void setM2X(double value); - - double getM2Y(); - - void setM2Y(double value); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/annotations/FunctionsParserHelper.java b/presto-main-base/src/main/java/com/facebook/presto/operator/annotations/FunctionsParserHelper.java index acb63e7abe084..60784f765b3ec 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/annotations/FunctionsParserHelper.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/annotations/FunctionsParserHelper.java @@ -15,7 +15,9 @@ import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.FunctionDescriptor; import com.facebook.presto.spi.function.IsNull; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.LongVariableConstraint; @@ -29,8 +31,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.lang.annotation.Annotation; import java.lang.reflect.AnnotatedElement; @@ -58,12 +59,16 @@ import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.common.type.StandardTypes.PARAMETRIC_TYPES; import static com.facebook.presto.operator.annotations.ImplementationDependency.isImplementationDependencyAnnotation; +import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.util.Failures.checkCondition; import static 
com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.lang.reflect.Modifier.isPublic; import static java.lang.reflect.Modifier.isStatic; import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; public class FunctionsParserHelper { @@ -254,6 +259,30 @@ public static Optional parseDescription(AnnotatedElement base) return (description == null) ? Optional.empty() : Optional.of(description.value()); } + public static ComplexTypeFunctionDescriptor parseFunctionDescriptor(AnnotatedElement base) + { + FunctionDescriptor descriptor = base.getAnnotation(FunctionDescriptor.class); + if (descriptor == null) { + return ComplexTypeFunctionDescriptor.DEFAULT; + } + + int pushdownSubfieldArgIndex = descriptor.pushdownSubfieldArgIndex(); + Optional descriptorPushdownIndex; + if (pushdownSubfieldArgIndex < 0) { + descriptorPushdownIndex = Optional.empty(); + } + else { + descriptorPushdownIndex = Optional.of(pushdownSubfieldArgIndex); + } + + return new ComplexTypeFunctionDescriptor( + true, + emptyList(), + Optional.of(emptySet()), + Optional.of(ComplexTypeFunctionDescriptor::allSubfieldsRequired), + descriptorPushdownIndex); + } + public static List parseLongVariableConstraints(Method inputFunction) { return Stream.of(inputFunction.getAnnotationsByType(Constraint.class)) @@ -277,4 +306,25 @@ public static Map> getDeclaredSpecializedTypeParameters(Method } return specializedTypeParameters; } + + public static void checkPushdownSubfieldArgIndex(Method method, Signature signature, Optional pushdownSubfieldArgIndex) + { + if (pushdownSubfieldArgIndex.isPresent()) { + Map typeConstraintMapping = new HashMap<>(); + for (TypeVariableConstraint constraint : signature.getTypeVariableConstraints()) { + typeConstraintMapping.put(constraint.getName(), 
constraint); + } + checkCondition(signature.getArgumentTypes().size() > pushdownSubfieldArgIndex.get(), FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] has out of range pushdown subfield arg index", method); + String typeVariableName = signature.getArgumentTypes().get(pushdownSubfieldArgIndex.get()).toString(); + + // The type variable must be directly a ROW type + // or (it is a type alias that is not bounded by a type) + // or (it is a type alias that maps to a row type) + boolean meetsTypeConstraint = (!typeConstraintMapping.containsKey(typeVariableName) && typeVariableName.equals(com.facebook.presto.common.type.StandardTypes.ROW)) || + (typeConstraintMapping.containsKey(typeVariableName) && typeConstraintMapping.get(typeVariableName).getVariadicBound() == null && !typeConstraintMapping.get(typeVariableName).isNonDecimalNumericRequired()) || + (typeConstraintMapping.containsKey(typeVariableName) && typeConstraintMapping.get(typeVariableName).getVariadicBound().equals(com.facebook.presto.common.type.StandardTypes.ROW)); + + checkCondition(meetsTypeConstraint, FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] does not have a struct or row type as pushdown subfield arg", method); + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/annotations/ImplementationDependency.java b/presto-main-base/src/main/java/com/facebook/presto/operator/annotations/ImplementationDependency.java index 4f30469473589..06e897574a1d4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/annotations/ImplementationDependency.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/annotations/ImplementationDependency.java @@ -20,9 +20,9 @@ import com.facebook.presto.spi.function.Convention; import com.facebook.presto.spi.function.FunctionDependency; import com.facebook.presto.spi.function.InvocationConvention; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.OperatorDependency; import 
com.facebook.presto.spi.function.TypeParameter; -import com.facebook.presto.type.LiteralParameter; import java.lang.annotation.Annotation; import java.lang.reflect.AnnotatedElement; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchange.java b/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchange.java index e1b40098777b2..37ab3ee42bdac 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchange.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchange.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.exchange; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.type.Type; import com.facebook.presto.execution.Lifespan; @@ -28,10 +29,8 @@ import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.SystemPartitioningHandle; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.io.Closeable; import java.util.ArrayList; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeMemoryManager.java b/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeMemoryManager.java index 5775dcac4d52a..66dcc97ed8946 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeMemoryManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeMemoryManager.java @@ -15,10 +15,9 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.Nullable; 
-import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.util.concurrent.atomic.AtomicLong; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeSource.java b/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeSource.java index a66bef9f3e8de..3eb8535f2bb27 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeSource.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/LocalExchangeSource.java @@ -18,10 +18,9 @@ import com.facebook.presto.operator.WorkProcessor.ProcessState; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/PageReference.java b/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/PageReference.java index a4698ef93818f..90c2a888d8ca5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/PageReference.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/exchange/PageReference.java @@ -14,8 +14,7 @@ package com.facebook.presto.operator.exchange; import com.facebook.presto.common.Page; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/index/DynamicTupleFilterFactory.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/DynamicTupleFilterFactory.java index 942eebedf0427..246eaad3f8c63 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/DynamicTupleFilterFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/DynamicTupleFilterFactory.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.index; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.function.SqlFunctionProperties; import com.facebook.presto.common.type.Type; @@ -27,7 +28,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; import java.util.List; import java.util.Map; @@ -36,10 +36,10 @@ import java.util.function.Supplier; import java.util.stream.IntStream; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.operator.FilterAndProjectOperator.FilterAndProjectOperatorFactory; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; public class DynamicTupleFilterFactory diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexJoinLookupStats.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexJoinLookupStats.java index 449dbbd2b6577..ac9d13c8d63fa 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexJoinLookupStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexJoinLookupStats.java @@ -14,11 +14,10 @@ package com.facebook.presto.operator.index; import 
com.facebook.airlift.stats.CounterStat; +import com.google.errorprone.annotations.ThreadSafe; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.concurrent.ThreadSafe; - @ThreadSafe public class IndexJoinLookupStats { diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLoader.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLoader.java index 498fe4a21029d..9d29195940009 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLoader.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLoader.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.operator.index; +import com.facebook.airlift.concurrent.NotThreadSafe; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.type.Type; @@ -34,12 +37,8 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.NotThreadSafe; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLookupSource.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLookupSource.java index 746de5d912578..6c8ae8f70ecd1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLookupSource.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLookupSource.java @@ -13,12 +13,11 @@ 
*/ package com.facebook.presto.operator.index; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.operator.LookupSource; -import javax.annotation.concurrent.NotThreadSafe; - import static com.facebook.presto.operator.index.IndexSnapshot.UNLOADED_INDEX_KEY; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java index 55c3f5a2fe1f4..5e7d1dbf3f2e6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexLookupSourceFactory.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.operator.index; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.type.Type; import com.facebook.presto.operator.LookupSourceFactory; import com.facebook.presto.operator.LookupSourceProvider; @@ -27,8 +29,6 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexSnapshot.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexSnapshot.java index a8c302f708f2d..11c1c5474b093 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexSnapshot.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexSnapshot.java @@ -16,8 +16,7 @@ import 
com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.operator.LookupSource; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexSnapshotBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexSnapshotBuilder.java index f37f524117b7d..bb014869c60a2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexSnapshotBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/IndexSnapshotBuilder.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.index; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; @@ -23,7 +24,6 @@ import com.facebook.presto.operator.PagesIndex; import com.facebook.presto.operator.index.UnloadedIndexKeyRecordSet.UnloadedIndexKeyRecordCursor; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/index/PageBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/PageBuffer.java index 1481689d2e85a..94be981ec5064 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/PageBuffer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/PageBuffer.java @@ -17,8 +17,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.ArrayDeque; import java.util.Queue; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/index/PagesIndexBuilderOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/PagesIndexBuilderOperator.java index b06852b3be1f7..4837053199e64 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/PagesIndexBuilderOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/PagesIndexBuilderOperator.java @@ -19,8 +19,7 @@ import com.facebook.presto.operator.OperatorContext; import com.facebook.presto.operator.OperatorFactory; import com.facebook.presto.spi.plan.PlanNodeId; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/index/StreamingIndexedData.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/StreamingIndexedData.java index 6456a7a8e3474..cd173d38c5a76 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/StreamingIndexedData.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/StreamingIndexedData.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.index; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.common.block.Block; @@ -21,8 +22,6 @@ import com.facebook.presto.operator.Driver; import com.google.common.collect.ImmutableList; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.List; import static com.google.common.base.Preconditions.checkArgument; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/index/UpdateRequest.java b/presto-main-base/src/main/java/com/facebook/presto/operator/index/UpdateRequest.java index 3df8c51ba63bf..b1797681b0455 
100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/index/UpdateRequest.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/index/UpdateRequest.java @@ -16,8 +16,7 @@ import com.facebook.airlift.concurrent.MoreFutures; import com.facebook.presto.common.Page; import com.google.common.util.concurrent.SettableFuture; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java b/presto-main-base/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java index 4b01efcba8d21..6a8f311b1714f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/project/DictionaryAwarePageProjection.java @@ -22,8 +22,7 @@ import com.facebook.presto.operator.CompletedWork; import com.facebook.presto.operator.DriverYieldSignal; import com.facebook.presto.operator.Work; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/project/MergingPageOutput.java b/presto-main-base/src/main/java/com/facebook/presto/operator/project/MergingPageOutput.java index 8d8e89deb1f75..17f3a6758c42f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/project/MergingPageOutput.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/project/MergingPageOutput.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.project; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import 
com.facebook.presto.common.block.Block; @@ -20,11 +21,9 @@ import com.facebook.presto.common.type.Type; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.Iterator; import java.util.LinkedList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/project/PageProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/operator/project/PageProcessor.java index 929d9b0db8a85..736662d471eaf 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/project/PageProcessor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/project/PageProcessor.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.project; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.common.array.ReferenceCountMap; import com.facebook.presto.common.block.Block; @@ -29,8 +30,6 @@ import com.google.common.annotations.VisibleForTesting; import io.airlift.slice.SizeOf; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/AbstractBlockEncodingBuffer.java b/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/AbstractBlockEncodingBuffer.java index 068be8bfaf36a..f89e1093fcde4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/AbstractBlockEncodingBuffer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/AbstractBlockEncodingBuffer.java @@ -42,8 +42,7 @@ import com.facebook.presto.common.block.VariableWidthBlock; import com.google.common.annotations.VisibleForTesting; 
import io.airlift.slice.SliceOutput; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Arrays; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/OptimizedPartitionedOutputOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/OptimizedPartitionedOutputOperator.java index 77605767e610f..b125f573bfc96 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/OptimizedPartitionedOutputOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/OptimizedPartitionedOutputOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.repartition; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.ArrayAllocator; import com.facebook.presto.common.block.ArrayBlock; @@ -51,11 +52,9 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.SliceOutput; -import io.airlift.units.DataSize; +import jakarta.annotation.Nullable; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.Nullable; - import java.io.IOException; import java.io.UncheckedIOException; import java.util.Arrays; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/PartitionedOutputOperator.java b/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/PartitionedOutputOperator.java index b249f76b8961b..a77a6829b5d08 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/PartitionedOutputOperator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/repartition/PartitionedOutputOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.repartition; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; 
import com.facebook.presto.common.block.Block; @@ -35,9 +36,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Arrays; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/AbstractArraySortByKeyFunction.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/AbstractArraySortByKeyFunction.java new file mode 100644 index 0000000000000..30313bca7d438 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/AbstractArraySortByKeyFunction.java @@ -0,0 +1,420 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.bytecode.BytecodeBlock; +import com.facebook.presto.bytecode.CallSiteBinder; +import com.facebook.presto.bytecode.ClassDefinition; +import com.facebook.presto.bytecode.MethodDefinition; +import com.facebook.presto.bytecode.Parameter; +import com.facebook.presto.bytecode.Scope; +import com.facebook.presto.bytecode.Variable; +import com.facebook.presto.bytecode.control.IfStatement; +import com.facebook.presto.common.NotSupportedException; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.function.SqlFunctionProperties; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.BoundVariables; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; +import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.LambdaArgumentDescriptor; +import com.facebook.presto.spi.function.LambdaDescriptor; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.SqlFunctionVisibility; +import com.facebook.presto.sql.gen.lambda.UnaryFunctionInterface; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.primitives.Primitives; +import it.unimi.dsi.fastutil.ints.IntComparator; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.util.Optional; + +import static com.facebook.presto.bytecode.Access.FINAL; +import static com.facebook.presto.bytecode.Access.PUBLIC; +import static com.facebook.presto.bytecode.Access.a; +import static 
com.facebook.presto.bytecode.Parameter.arg; +import static com.facebook.presto.bytecode.ParameterizedType.type; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull; +import static com.facebook.presto.bytecode.expression.BytecodeExpressions.equal; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.functionTypeArgumentProperty; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.valueTypeArgumentProperty; +import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.RETURN_NULL_ON_NULL; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.function.Signature.typeVariable; +import static com.facebook.presto.sql.gen.SqlTypeBytecodeExpression.constantType; +import static com.facebook.presto.util.CompilerUtils.defineClass; +import static com.facebook.presto.util.CompilerUtils.makeClassName; +import static com.facebook.presto.util.Reflection.methodHandle; +import static it.unimi.dsi.fastutil.ints.IntArrays.quickSort; + +public abstract class AbstractArraySortByKeyFunction + extends SqlScalarFunction +{ + private final ComplexTypeFunctionDescriptor descriptor; + + protected AbstractArraySortByKeyFunction(String functionName) + { + super(new Signature( + QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, functionName), + FunctionKind.SCALAR, + ImmutableList.of(typeVariable("T"), typeVariable("K")), + ImmutableList.of(), + parseTypeSignature("array(T)"), + ImmutableList.of(parseTypeSignature("array(T)"), 
parseTypeSignature("function(T,K)")), + false)); + descriptor = new ComplexTypeFunctionDescriptor( + true, + ImmutableList.of(new LambdaDescriptor(1, ImmutableMap.of(0, new LambdaArgumentDescriptor(0, ComplexTypeFunctionDescriptor::prependAllSubscripts)))), + Optional.of(ImmutableSet.of(0)), + Optional.of(ComplexTypeFunctionDescriptor::clearRequiredSubfields), + getSignature()); + } + + @Override + public SqlFunctionVisibility getVisibility() + { + return SqlFunctionVisibility.PUBLIC; + } + + @Override + public boolean isDeterministic() + { + return true; + } + + @Override + public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager) + { + Type elementType = boundVariables.getTypeVariable("T"); + Type keyType = boundVariables.getTypeVariable("K"); + + // Generate the specialized key extractor instance once + KeyExtractor keyExtractor = generateKeyExtractor(elementType, keyType); + + MethodHandle raw = methodHandle( + AbstractArraySortByKeyFunction.class, + "sortByKey", + AbstractArraySortByKeyFunction.class, + Type.class, + Type.class, + KeyExtractor.class, + SqlFunctionProperties.class, + Block.class, + UnaryFunctionInterface.class); + + MethodHandle bound = MethodHandles.insertArguments(raw, 0, this, elementType, keyType, keyExtractor); + + return new BuiltInScalarFunctionImplementation( + false, + ImmutableList.of( + valueTypeArgumentProperty(RETURN_NULL_ON_NULL), // array parameter + functionTypeArgumentProperty(UnaryFunctionInterface.class)), // keyFunction parameter + bound); + } + + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } + + public static Block sortByKey( + AbstractArraySortByKeyFunction function, + Type elementType, + Type keyType, + KeyExtractor keyExtractor, + SqlFunctionProperties properties, + Block array, + UnaryFunctionInterface keyFunction) + { + int arrayLength = array.getPositionCount(); + if 
(arrayLength < 2) { + return array; + } + + // Create array of indices and extracted keys + int[] indices = new int[arrayLength]; + BlockBuilder keyBlockBuilder = keyType.createBlockBuilder(null, arrayLength); + + // Extract keys for all elements + for (int i = 0; i < arrayLength; i++) { + indices[i] = i; + if (array.isNull(i)) { + keyBlockBuilder.appendNull(); + } + else { + try { + // Use the generated KeyExtractor implementation (direct virtual call) + keyExtractor.extract(properties, array, i, keyFunction, keyBlockBuilder); + } + catch (Throwable t) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, String.format("Error applying key function to element at position %d", i), t); + } + } + } + + Block keysBlock = keyBlockBuilder.build(); + + // Sort indices based on extracted keys using Type's compareTo + try { + if (array.mayHaveNull() || keysBlock.mayHaveNull()) { + quickSort(indices, new NullableComparator(array, keysBlock, keyType, function)); + } + else { + quickSort(indices, new NonNullableComparator(keysBlock, keyType, function)); + } + } + catch (NotSupportedException | UnsupportedOperationException e) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Key type does not support comparison", e); + } + catch (PrestoException e) { + if (e.getErrorCode() == NOT_SUPPORTED.toErrorCode()) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Key type does not support comparison", e); + } + throw e; + } + + // Build result block with sorted elements + BlockBuilder resultBuilder = elementType.createBlockBuilder(null, arrayLength); + for (int i = 0; i < arrayLength; i++) { + elementType.appendTo(array, indices[i], resultBuilder); + } + + return resultBuilder.build(); + } + + /** + * KeyExtractor is a simple interface implemented by generated classes. + * Implementations must write the extracted key into the provided BlockBuilder + * (or appendNull) for the given position. 
+ */ + public interface KeyExtractor + { + void extract(SqlFunctionProperties properties, Block array, int position, UnaryFunctionInterface keyFunction, BlockBuilder keyBlockBuilder) throws Throwable; + } + + // Generate just the key extraction logic + public static KeyExtractor generateKeyExtractor(Type elementType, Type keyType) + { + CallSiteBinder binder = new CallSiteBinder(); + Class elementJavaType = Primitives.wrap(elementType.getJavaType()); + Class keyJavaType = Primitives.wrap(keyType.getJavaType()); + + String className = "ArraySortKeyExtractorImpl_" + elementType.getTypeSignature() + "_" + keyType.getTypeSignature(); + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName(className), + type(Object.class), + type(KeyExtractor.class)); + definition.declareDefaultConstructor(a(PUBLIC)); + + Parameter properties = arg("properties", SqlFunctionProperties.class); + Parameter array = arg("array", Block.class); + Parameter position = arg("position", int.class); + Parameter keyFunction = arg("keyFunction", UnaryFunctionInterface.class); + Parameter keyBlockBuilder = arg("keyBlockBuilder", BlockBuilder.class); + + MethodDefinition method = definition.declareMethod( + a(PUBLIC), + "extract", + type(void.class), + ImmutableList.of(properties, array, position, keyFunction, keyBlockBuilder)); + + BytecodeBlock body = method.getBody(); + Scope scope = method.getScope(); + Variable element = scope.declareVariable(elementJavaType, "element"); + Variable key = scope.declareVariable(keyJavaType, "key"); + + // Load element with correct primitive handling + if (!elementType.equals(UNKNOWN)) { + // generates the correct getLong/getDouble/getBoolean/getSlice/getObject call + body.append(element.set(constantType(binder, elementType).getValue(array, position).cast(elementJavaType))); + } + else { + body.append(element.set(constantNull(elementJavaType))); + } + + body.append(key.set(keyFunction.invoke("apply", Object.class, 
element.cast(Object.class)).cast(keyJavaType))); + + // Write the key to the block builder + if (!keyType.equals(UNKNOWN)) { + body.append(new IfStatement() + .condition(equal(key, constantNull(keyJavaType))) + .ifTrue(keyBlockBuilder.invoke("appendNull", BlockBuilder.class).pop()) + .ifFalse(constantType(binder, keyType).writeValue(keyBlockBuilder, key.cast(keyType.getJavaType())))); + } + else { + body.append(keyBlockBuilder.invoke("appendNull", BlockBuilder.class).pop()); + } + + body.ret(); + + Class generatedClass = defineClass(definition, Object.class, binder.getBindings(), AbstractArraySortByKeyFunction.class.getClassLoader()); + + try { + // instantiate generated class and cast to KeyExtractor for direct virtual call + return (KeyExtractor) generatedClass.getConstructor().newInstance(); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException("Failed to instantiate generated key extractor", e); + } + } + + // Abstract method to be implemented by subclasses to define comparison direction + protected abstract int compareKeys(Type keyType, Block keysBlock, int leftIndex, int rightIndex); + + private static class NullableComparator + implements IntComparator + { + private final Block array; + private final Block keysBlock; + private final Type keyType; + private final AbstractArraySortByKeyFunction function; + + public NullableComparator(Block array, Block keysBlock, Type keyType, AbstractArraySortByKeyFunction function) + { + this.array = array; + this.keysBlock = keysBlock; + this.keyType = keyType; + this.function = function; + } + + @Override + public int compare(int leftIndex, int rightIndex) + { + boolean leftArrayNull = array.isNull(leftIndex); + boolean rightArrayNull = array.isNull(rightIndex); + + if (leftArrayNull && rightArrayNull) { + return 0; + } + if (leftArrayNull) { + return 1; + } + if (rightArrayNull) { + return -1; + } + + boolean leftKeyNull = keysBlock.isNull(leftIndex); + boolean rightKeyNull = 
keysBlock.isNull(rightIndex); + + if (leftKeyNull && rightKeyNull) { + return 0; + } + if (leftKeyNull) { + return 1; + } + if (rightKeyNull) { + return -1; + } + + int result = function.compareKeys(keyType, keysBlock, leftIndex, rightIndex); + + // If keys are equal, maintain original order + if (result == 0) { + return Integer.compare(leftIndex, rightIndex); + } + + return result; + } + } + + private static class NonNullableComparator + implements IntComparator + { + private final Block keysBlock; + private final Type keyType; + private final AbstractArraySortByKeyFunction function; + + public NonNullableComparator(Block keysBlock, Type keyType, AbstractArraySortByKeyFunction function) + { + this.keysBlock = keysBlock; + this.keyType = keyType; + this.function = function; + } + + @Override + public int compare(int leftIndex, int rightIndex) + { + int result = function.compareKeys(keyType, keysBlock, leftIndex, rightIndex); + + // If keys are equal, maintain original order + if (result == 0) { + return Integer.compare(leftIndex, rightIndex); + } + + return result; + } + } + + public static class ArraySortByKeyFunction + extends AbstractArraySortByKeyFunction + { + public static final ArraySortByKeyFunction ARRAY_SORT_BY_KEY_FUNCTION = new ArraySortByKeyFunction(); + + private ArraySortByKeyFunction() + { + super("array_sort"); + } + + @Override + public String getDescription() + { + return "Sorts the given array using a lambda function to extract sorting keys. " + + "Null array elements and null keys are placed at the end. 
" + + "Example: array_sort(ARRAY['apple', 'banana', 'cherry'], x -> length(x))"; + } + + @Override + protected int compareKeys(Type keyType, Block keysBlock, int leftIndex, int rightIndex) + { + return keyType.compareTo(keysBlock, leftIndex, keysBlock, rightIndex); + } + } + + public static class ArraySortDescByKeyFunction + extends AbstractArraySortByKeyFunction + { + public static final ArraySortDescByKeyFunction ARRAY_SORT_DESC_BY_KEY_FUNCTION = new ArraySortDescByKeyFunction(); + + private ArraySortDescByKeyFunction() + { + super("array_sort_desc"); + } + + @Override + public String getDescription() + { + return "Sorts the given array in descending order using a lambda function to extract sorting keys"; + } + + @Override + protected int compareKeys(Type keyType, Block keysBlock, int leftIndex, int rightIndex) + { + return keyType.compareTo(keysBlock, rightIndex, keysBlock, leftIndex); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java index d977420d28339..4ee510a86c467 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ArrayIntersectFunction.java @@ -19,8 +19,6 @@ import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.OperatorDependency; import com.facebook.presto.spi.function.ScalarFunction; -import com.facebook.presto.spi.function.SqlInvokedScalarFunction; -import com.facebook.presto.spi.function.SqlParameter; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; @@ -60,14 +58,4 @@ public static Block intersect( return typedSet.getBlock(); } - - @SqlInvokedScalarFunction(value = "array_intersect", deterministic = true, calledOnNullInput = false) - @Description("Intersects elements of 
all arrays in the given array") - @TypeParameter("T") - @SqlParameter(name = "input", type = "array>") - @SqlType("array") - public static String arrayIntersectArray() - { - return "RETURN reduce(input, IF((cardinality(input) = 0), ARRAY[], input[1]), (s, x) -> array_intersect(s, x), (s) -> s)"; - } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/CharacterStringCasts.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/CharacterStringCasts.java index 0ecf2917fd5cb..b5198a36699ae 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/CharacterStringCasts.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/CharacterStringCasts.java @@ -14,10 +14,10 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.type.LiteralParameter; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.slice.SliceUtf8; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/DateTimeFunctions.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/DateTimeFunctions.java index 4d5cab1a36648..8db043aab9da2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/DateTimeFunctions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/DateTimeFunctions.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator.scalar; import com.facebook.airlift.concurrent.ThreadLocalCache; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.NotSupportedException; import com.facebook.presto.common.function.SqlFunctionProperties; import 
com.facebook.presto.common.type.StandardTypes; @@ -26,7 +27,6 @@ import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.type.TimestampOperators; import io.airlift.slice.Slice; -import io.airlift.units.Duration; import org.joda.time.DateTime; import org.joda.time.DateTimeField; import org.joda.time.DateTimeZone; @@ -38,8 +38,11 @@ import org.joda.time.format.DateTimeFormatterBuilder; import org.joda.time.format.ISODateTimeFormat; +import java.math.BigDecimal; import java.util.Locale; import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static com.facebook.presto.common.type.DateTimeEncoding.packDateTimeWithZone; import static com.facebook.presto.common.type.DateTimeEncoding.unpackMillisUtc; @@ -83,6 +86,7 @@ public final class DateTimeFunctions private static final DateTimeField MONTH_OF_YEAR = UTC_CHRONOLOGY.monthOfYear(); private static final DateTimeField QUARTER = QUARTER_OF_YEAR.getField(UTC_CHRONOLOGY); private static final DateTimeField YEAR = UTC_CHRONOLOGY.year(); + private static final Pattern PATTERN = Pattern.compile("^\\s*(\\d+(?:\\.\\d+)?)\\s*([a-zA-Z]+)\\s*$"); private static final int MILLISECONDS_IN_SECOND = 1000; private static final int MILLISECONDS_IN_MINUTE = 60 * MILLISECONDS_IN_SECOND; private static final int MILLISECONDS_IN_HOUR = 60 * MILLISECONDS_IN_MINUTE; @@ -113,13 +117,12 @@ public static long currentTime(SqlFunctionProperties properties) // and we need to have UTC millis for packDateTimeWithZone long millis = UTC_CHRONOLOGY.millisOfDay().get(properties.getSessionStartTime()); - if (!properties.isLegacyTimestamp()) { - // However, those UTC millis are pointing to the correct UTC timestamp - // Our TIME WITH TIME ZONE representation does use UTC 1970-01-01 representation - // So we have to hack here in order to get valid representation - // of TIME WITH TIME ZONE - millis -= valueToSessionTimeZoneOffsetDiff(properties.getSessionStartTime(), 
getDateTimeZone(properties.getTimeZoneKey())); - } + // However, those UTC millis are pointing to the correct UTC timestamp + // Our TIME WITH TIME ZONE representation does use UTC 1970-01-01 representation + // So we have to hack here in order to get valid representation + // of TIME WITH TIME ZONE + millis -= valueToSessionTimeZoneOffsetDiff(properties.getSessionStartTime(), getDateTimeZone(properties.getTimeZoneKey())); + try { return packDateTimeWithZone(millis, properties.getTimeZoneKey()); } @@ -140,7 +143,8 @@ public static long currentTime(SqlFunctionProperties properties) public static long localTime(SqlFunctionProperties properties) { if (properties.isLegacyTimestamp()) { - return UTC_CHRONOLOGY.millisOfDay().get(properties.getSessionStartTime()); + long millis = UTC_CHRONOLOGY.millisOfDay().get(properties.getSessionStartTime()); + return millis - valueToSessionTimeZoneOffsetDiff(properties.getSessionStartTime(), getDateTimeZone(properties.getTimeZoneKey())); } ISOChronology localChronology = getChronology(properties.getTimeZoneKey()); return localChronology.millisOfDay().get(properties.getSessionStartTime()); @@ -1437,14 +1441,53 @@ else if (character == '%') { @SqlType(StandardTypes.INTERVAL_DAY_TO_SECOND) public static long parseDuration(@SqlType("varchar(x)") Slice duration) { + String durationStr = duration.toStringUtf8(); + + if (durationStr.isEmpty()) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "duration is empty"); + } + try { - return Duration.valueOf(duration.toStringUtf8()).toMillis(); + Matcher matcher = PATTERN.matcher(durationStr); + + if (!matcher.matches()) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, + "duration is not a valid data duration string: " + durationStr); + } + + BigDecimal value = new BigDecimal(matcher.group(1)); + TimeUnit timeUnit = Duration.valueOfTimeUnit(matcher.group(2)); + + return value.multiply(millisPerTimeUnit(timeUnit)) + .add(BigDecimal.valueOf(0.5)).longValue(); } - catch 
(IllegalArgumentException e) { + catch (IllegalArgumentException | ArithmeticException e) { throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e); } } + private static BigDecimal millisPerTimeUnit(TimeUnit timeUnit) + { + switch (timeUnit) { + case NANOSECONDS: + return new BigDecimal("0.000001"); + case MICROSECONDS: + return new BigDecimal("0.001"); + case MILLISECONDS: + return BigDecimal.ONE; + case SECONDS: + return BigDecimal.valueOf(1000); + case MINUTES: + return BigDecimal.valueOf(60_000); + case HOURS: + return BigDecimal.valueOf(3_600_000); + case DAYS: + return BigDecimal.valueOf(86_400_000); + default: + throw new AssertionError("Unknown TimeUnit: " + timeUnit); + } + } + private static long timeAtTimeZone(SqlFunctionProperties properties, long timeWithTimeZone, TimeZoneKey timeZoneKey) { DateTimeZone sourceTimeZone = getDateTimeZone(unpackZoneKey(timeWithTimeZone)); diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/JoniRegexpCasts.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/JoniRegexpCasts.java index c0f551099672c..eb6264e1a0a1d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/JoniRegexpCasts.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/JoniRegexpCasts.java @@ -15,11 +15,11 @@ import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.type.JoniRegexpType; -import com.facebook.presto.type.LiteralParameter; import io.airlift.jcodings.specific.NonStrictUTF8Encoding; import io.airlift.joni.Option; import io.airlift.joni.Regex; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/JsonFunctions.java 
b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/JsonFunctions.java index 339ff5c10dba9..41242a9fcbf50 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/JsonFunctions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/JsonFunctions.java @@ -21,13 +21,13 @@ import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.type.JsonPathType; -import com.facebook.presto.type.LiteralParameter; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java index bb8f445713c85..9d6048f484b83 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java @@ -21,13 +21,13 @@ import com.facebook.presto.operator.aggregation.TypedSet; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; import 
com.facebook.presto.type.Constraint; -import com.facebook.presto.type.LiteralParameter; import com.facebook.presto.util.SecureRandomGeneration; import com.google.common.primitives.Doubles; import io.airlift.slice.Slice; @@ -39,6 +39,7 @@ import org.apache.commons.math3.distribution.GammaDistribution; import org.apache.commons.math3.distribution.LaplaceDistribution; import org.apache.commons.math3.distribution.PoissonDistribution; +import org.apache.commons.math3.distribution.TDistribution; import org.apache.commons.math3.distribution.WeibullDistribution; import org.apache.commons.math3.special.Erf; @@ -49,6 +50,7 @@ import static com.facebook.presto.common.type.Decimals.longTenToNth; import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.RealType.REAL; import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.add; import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.isNegative; import static com.facebook.presto.common.type.UnscaledDecimal128Arithmetic.isZero; @@ -980,6 +982,31 @@ public static double poissonCdf( return distribution.cumulativeProbability((int) value); } + @Description("inverse of Student's t cdf given degrees of freedom and probability") + @ScalarFunction + @SqlType(StandardTypes.DOUBLE) + public static double inverseTCdf( + @SqlType(StandardTypes.DOUBLE) double df, + @SqlType(StandardTypes.DOUBLE) double p) + { + checkCondition(df > 0, INVALID_FUNCTION_ARGUMENT, "df must be greater than 0"); + checkCondition(p >= 0.0 && p <= 1.0, INVALID_FUNCTION_ARGUMENT, "p must be in the interval [0, 1]"); + TDistribution distribution = new TDistribution(null, df, TDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); + return distribution.inverseCumulativeProbability(p); + } + + @Description("Student's t cdf given degrees of freedom and value") + @ScalarFunction + @SqlType(StandardTypes.DOUBLE) + public static double tCdf( + @SqlType(StandardTypes.DOUBLE) double df, + 
@SqlType(StandardTypes.DOUBLE) double value) + { + checkCondition(df > 0, INVALID_FUNCTION_ARGUMENT, "df must be greater than 0"); + TDistribution distribution = new TDistribution(null, df, TDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); + return distribution.cumulativeProbability(value); + } + @Description("Inverse of Weibull cdf given a, b parameters and probability") @ScalarFunction @SqlType(StandardTypes.DOUBLE) @@ -1631,6 +1658,11 @@ public static Double arrayCosineSimilarity(@SqlType("array(double)") Block leftA INVALID_FUNCTION_ARGUMENT, "Both array arguments need to have identical size"); + checkCondition( + !(leftArray.mayHaveNull() || rightArray.mayHaveNull()), + INVALID_FUNCTION_ARGUMENT, + "Both arrays must not have nulls"); + Double normLeftArray = array2Norm(leftArray); Double normRightArray = array2Norm(rightArray); @@ -1643,6 +1675,113 @@ public static Double arrayCosineSimilarity(@SqlType("array(double)") Block leftA return dotProduct / (normLeftArray * normRightArray); } + @Description("squared Euclidean distance between the given identical sized vectors represented as arrays") + @ScalarFunction("l2_squared") + @SqlType(StandardTypes.REAL) + public static long arrayL2Squared(@SqlType("array(real)") Block leftArray, @SqlType("array(real)") Block rightArray) + { + checkCondition( + leftArray.getPositionCount() == rightArray.getPositionCount(), + INVALID_FUNCTION_ARGUMENT, + "Both array arguments need to have identical size"); + + checkCondition( + !(leftArray.mayHaveNull() || rightArray.mayHaveNull()), + INVALID_FUNCTION_ARGUMENT, + "Both arrays must not have nulls"); + + float sum = 0.0f; + for (int i = 0; i < leftArray.getPositionCount(); i++) { + float left = intBitsToFloat((int) leftArray.getInt(i)); + float right = intBitsToFloat((int) rightArray.getInt(i)); + float diff = left - right; + sum += diff * diff; + } + + return floatToRawIntBits(sum); + } + + @Description("squared Euclidean distance between the given identical sized vectors 
represented as arrays") + @ScalarFunction("l2_squared") + @SqlType(StandardTypes.DOUBLE) + public static double arrayL2SquaredDouble( + @SqlType("array(double)") Block leftArray, + @SqlType("array(double)") Block rightArray) + { + checkCondition( + leftArray.getPositionCount() == rightArray.getPositionCount(), + INVALID_FUNCTION_ARGUMENT, + "Both array arguments need to have identical size"); + + checkCondition( + !(leftArray.mayHaveNull() || rightArray.mayHaveNull()), + INVALID_FUNCTION_ARGUMENT, + "Both arrays must not have nulls"); + + double sum = 0.0; + for (int i = 0; i < leftArray.getPositionCount(); i++) { + double left = DOUBLE.getDouble(leftArray, i); + double right = DOUBLE.getDouble(rightArray, i); + double diff = left - right; + sum += diff * diff; + } + return sum; + } + + @Description("Dot Product distance between the given identical sized vectors represented as DOUBLE arrays") + @ScalarFunction("dot_product") + @SqlNullable + @SqlType(StandardTypes.DOUBLE) + public static Double arrayDotProduct( + @SqlType("array(double)") Block leftArray, + @SqlType("array(double)") Block rightArray) + { + checkCondition( + leftArray.getPositionCount() == rightArray.getPositionCount(), + INVALID_FUNCTION_ARGUMENT, + "Both array arguments must have identical sizes"); + + checkCondition( + !(leftArray.mayHaveNull() || rightArray.mayHaveNull()), + INVALID_FUNCTION_ARGUMENT, + "Both array arguments must not have nulls"); + + double result = 0.0; + + for (int i = 0; i < leftArray.getPositionCount(); i++) { + result += DOUBLE.getDouble(leftArray, i) * DOUBLE.getDouble(rightArray, i); + } + + return result; + } + + @Description("Dot Product distance between the given identical sized vectors represented as REAL arrays") + @ScalarFunction("dot_product") + @SqlNullable + @SqlType(StandardTypes.REAL) + public static Long arrayDotProductReal( + @SqlType("array(real)") Block leftArray, + @SqlType("array(real)") Block rightArray) + { + checkCondition( + 
leftArray.getPositionCount() == rightArray.getPositionCount(), + INVALID_FUNCTION_ARGUMENT, + "Both array arguments must have identical sizes"); + + checkCondition( + !(leftArray.mayHaveNull() || rightArray.mayHaveNull()), + INVALID_FUNCTION_ARGUMENT, + "Both array arguments must not have nulls"); + + float dotProduct = 0.0F; + + for (int i = 0; i < leftArray.getPositionCount(); i++) { + dotProduct += intBitsToFloat((int) REAL.getLong(leftArray, i)) * Float.intBitsToFloat((int) REAL.getLong(rightArray, i)); + } + + return ((long) floatToRawIntBits(dotProduct)); + } + private static double mapDotProduct(Block leftMap, Block rightMap) { TypedSet rightMapKeys = new TypedSet(VARCHAR, rightMap.getPositionCount(), "cosine_similarity"); @@ -1665,17 +1804,6 @@ private static double mapDotProduct(Block leftMap, Block rightMap) return result; } - private static double arrayDotProduct(Block leftArray, Block rightArray) - { - double result = 0.0; - - for (int i = 0; i < leftArray.getPositionCount(); i++) { - result += DOUBLE.getDouble(leftArray, i) * DOUBLE.getDouble(rightArray, i); - } - - return result; - } - private static Double mapL2Norm(Block map) { double norm = 0.0; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index 3950c075e9fe4..12a80dbb9d164 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -19,6 +19,7 @@ import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Signature; import 
com.facebook.presto.spi.function.SqlFunctionVisibility; import com.google.common.annotations.VisibleForTesting; @@ -28,7 +29,7 @@ import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; import static com.facebook.presto.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_IMPLEMENTATION; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; -import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.util.Failures.checkCondition; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -67,6 +68,12 @@ public boolean isCalledOnNullInput() return details.isCalledOnNullInput(); } + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return details.getComplexTypeFunctionDescriptor(); + } + @Override public String getDescription() { @@ -112,6 +119,6 @@ public BuiltInScalarFunctionImplementation specialize(BoundVariables boundVariab return selectedImplementation; } - throw new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("Unsupported type parameters (%s) for %s", boundVariables, getSignature())); + throw new PrestoException(NOT_SUPPORTED, format("Unsupported type parameters (%s) for %s", boundVariables, getSignature())); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/QuantileDigestFunctions.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/QuantileDigestFunctions.java index e5a6ad4de1c20..098489bb18a2b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/QuantileDigestFunctions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/QuantileDigestFunctions.java @@ -107,8 +107,16 @@ public static Block valuesAtQuantilesDouble(@SqlType("qdigest(double)") Slice in { QuantileDigest digest = new 
QuantileDigest(input); BlockBuilder output = DOUBLE.createBlockBuilder(null, percentilesArrayBlock.getPositionCount()); - for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) { - DOUBLE.writeDouble(output, sortableLongToDouble(digest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i)))); + if (percentilesArrayBlock.mayHaveNull()) { + for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) { + checkCondition(!percentilesArrayBlock.isNull(i), INVALID_FUNCTION_ARGUMENT, "All quantiles should be non-null."); + DOUBLE.writeDouble(output, sortableLongToDouble(digest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i)))); + } + } + else { + for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) { + DOUBLE.writeDouble(output, sortableLongToDouble(digest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i)))); + } } return output.build(); } @@ -120,8 +128,16 @@ public static Block valuesAtQuantilesReal(@SqlType("qdigest(real)") Slice input, { QuantileDigest digest = new QuantileDigest(input); BlockBuilder output = REAL.createBlockBuilder(null, percentilesArrayBlock.getPositionCount()); - for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) { - REAL.writeLong(output, floatToRawIntBits(sortableIntToFloat((int) digest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i))))); + if (percentilesArrayBlock.mayHaveNull()) { + for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) { + checkCondition(!percentilesArrayBlock.isNull(i), INVALID_FUNCTION_ARGUMENT, "All quantiles should be non-null."); + REAL.writeLong(output, floatToRawIntBits(sortableIntToFloat((int) digest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i))))); + } + } + else { + for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) { + REAL.writeLong(output, floatToRawIntBits(sortableIntToFloat((int) digest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i))))); + } } return output.build(); } @@ -133,8 +149,16 @@ public 
static Block valuesAtQuantilesBigint(@SqlType("qdigest(bigint)") Slice in { QuantileDigest digest = new QuantileDigest(input); BlockBuilder output = BIGINT.createBlockBuilder(null, percentilesArrayBlock.getPositionCount()); - for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) { - BIGINT.writeLong(output, digest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i))); + if (percentilesArrayBlock.mayHaveNull()) { + for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) { + checkCondition(!percentilesArrayBlock.isNull(i), INVALID_FUNCTION_ARGUMENT, "All quantiles should be non-null."); + BIGINT.writeLong(output, digest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i))); + } + } + else { + for (int i = 0; i < percentilesArrayBlock.getPositionCount(); i++) { + BIGINT.writeLong(output, digest.getQuantile(DOUBLE.getDouble(percentilesArrayBlock, i))); + } } return output.build(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java index b6b5e33fa6c00..bcecb71cb43f3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.SqlFunctionVisibility; import java.util.Optional; @@ -23,13 +24,15 @@ public class ScalarHeader private final SqlFunctionVisibility visibility; private final boolean deterministic; private final boolean calledOnNullInput; + private final ComplexTypeFunctionDescriptor complexTypeFunctionDescriptor; - public ScalarHeader(Optional description, SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput) + public ScalarHeader(Optional description, 
SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput, ComplexTypeFunctionDescriptor complexTypeFunctionDescriptor) { this.description = description; this.visibility = visibility; this.deterministic = deterministic; this.calledOnNullInput = calledOnNullInput; + this.complexTypeFunctionDescriptor = complexTypeFunctionDescriptor; } public Optional getDescription() @@ -51,4 +54,9 @@ public boolean isCalledOnNullInput() { return calledOnNullInput; } + + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return complexTypeFunctionDescriptor; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/StringFunctions.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/StringFunctions.java index 2a55fe9485f29..ba7722722f609 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/StringFunctions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/StringFunctions.java @@ -19,6 +19,7 @@ import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.ScalarOperator; @@ -26,7 +27,6 @@ import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.type.CodePointsType; import com.facebook.presto.type.Constraint; -import com.facebook.presto.type.LiteralParameter; import com.google.common.primitives.Ints; import io.airlift.slice.InvalidCodePointException; import io.airlift.slice.InvalidUtf8Exception; diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/TryFunction.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/TryFunction.java index 3fe0d689e9536..9aebc92850fd3 
100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/TryFunction.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/TryFunction.java @@ -26,10 +26,6 @@ import java.util.function.Supplier; -import static com.facebook.presto.spi.StandardErrorCode.DIVISION_BY_ZERO; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; -import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static com.facebook.presto.spi.function.SqlFunctionVisibility.HIDDEN; @Description("internal try function for desugaring TRY") @@ -162,11 +158,7 @@ public static T evaluate(Supplier supplier, T defaultValue) private static void propagateIfUnhandled(PrestoException e) throws PrestoException { - int errorCode = e.getErrorCode().getCode(); - if (errorCode == DIVISION_BY_ZERO.toErrorCode().getCode() - || errorCode == INVALID_CAST_ARGUMENT.toErrorCode().getCode() - || errorCode == INVALID_FUNCTION_ARGUMENT.toErrorCode().getCode() - || errorCode == NUMERIC_VALUE_OUT_OF_RANGE.toErrorCode().getCode()) { + if (e.getErrorCode().isCatchableByTry()) { return; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/UrlFunctions.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/UrlFunctions.java index ecdfb3a21cbf3..5f1d386b3b86c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/UrlFunctions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/UrlFunctions.java @@ -26,8 +26,7 @@ import com.google.common.net.UrlEscapers; import io.airlift.slice.Slice; import io.airlift.slice.Slices; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.io.UnsupportedEncodingException; import java.net.URI; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/VarbinaryFunctions.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/VarbinaryFunctions.java index 6f977b34e2f8b..c01d55bcd5aab 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/VarbinaryFunctions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/VarbinaryFunctions.java @@ -355,6 +355,18 @@ public static Slice xxhash64(@SqlType(StandardTypes.VARBINARY) Slice slice) return hash; } + @Description("compute xxhash64 hash with a seed") + @ScalarFunction + @SqlType(StandardTypes.VARBINARY) + public static Slice xxhash64( + @SqlType(StandardTypes.VARBINARY) Slice slice, + @SqlType(StandardTypes.BIGINT) long seed) + { + Slice hash = Slices.allocate(Long.BYTES); + hash.setLong(0, Long.reverseBytes(XxHash64.hash(seed, slice))); + return hash; + } + @Description("compute SpookyHashV2 32-bit hash") @ScalarFunction @SqlType(StandardTypes.VARBINARY) diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/CodegenScalarFromAnnotationsParser.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/CodegenScalarFromAnnotationsParser.java index ae83d270b8308..ae9464001d024 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/CodegenScalarFromAnnotationsParser.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/CodegenScalarFromAnnotationsParser.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator.scalar.annotations; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; @@ -25,6 +26,7 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.BlockPosition; import 
com.facebook.presto.spi.function.CodegenScalarFunction; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.IsNull; @@ -51,7 +53,9 @@ import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; +import static com.facebook.presto.operator.annotations.FunctionsParserHelper.checkPushdownSubfieldArgIndex; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.findPublicStaticMethods; +import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseFunctionDescriptor; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.functionTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.valueTypeArgumentProperty; import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.RETURN_NULL_ON_NULL; @@ -67,9 +71,14 @@ public class CodegenScalarFromAnnotationsParser private CodegenScalarFromAnnotationsParser() {} public static List parseFunctionDefinitions(Class clazz) + { + return parseFunctionDefinitions(clazz, JAVA_BUILTIN_NAMESPACE); + } + + public static List parseFunctionDefinitions(Class clazz, CatalogSchemaName functionNamespace) { return findScalarsInFunctionDefinitionClass(clazz).stream() - .map(method -> createSqlScalarFunction(method)) + .map(method -> createSqlScalarFunction(method, functionNamespace)) .collect(toImmutableList()); } @@ -109,12 +118,12 @@ private static List getArgumentProperties(Method method) .collect(toImmutableList()); } - private static SqlScalarFunction createSqlScalarFunction(Method 
method) + private static SqlScalarFunction createSqlScalarFunction(Method method, CatalogSchemaName functionNamespace) { CodegenScalarFunction codegenScalarFunction = method.getAnnotation(CodegenScalarFunction.class); Signature signature = new Signature( - QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, codegenScalarFunction.value()), + QualifiedObjectName.valueOf(functionNamespace, codegenScalarFunction.value()), FunctionKind.SCALAR, Arrays.stream(method.getAnnotationsByType(TypeParameter.class)).map(t -> withVariadicBound(t.value(), t.boundedBy().isEmpty() ? null : t.boundedBy())).collect(toImmutableList()), ImmutableList.of(), @@ -122,6 +131,8 @@ private static SqlScalarFunction createSqlScalarFunction(Method method) Arrays.stream(method.getParameters()).map(p -> parseTypeSignature(p.getAnnotation(SqlType.class).value())).collect(toImmutableList()), false); + ComplexTypeFunctionDescriptor descriptor = parseAndCheckFunctionDescriptor(method, signature); + return new SqlScalarFunction(signature) { @Override @@ -166,6 +177,19 @@ public boolean isCalledOnNullInput() { return codegenScalarFunction.calledOnNullInput(); } + + @Override + public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor() + { + return descriptor; + } }; } + + private static ComplexTypeFunctionDescriptor parseAndCheckFunctionDescriptor(Method method, Signature signature) + { + ComplexTypeFunctionDescriptor descriptor = parseFunctionDescriptor(method); + checkPushdownSubfieldArgIndex(method, signature, descriptor.getPushdownSubfieldArgIndex()); + return descriptor; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java index eb2f5a3ae346f..a0b04fd003cd5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java @@ -13,12 +13,14 @@ */ package com.facebook.presto.operator.scalar.annotations; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.metadata.SqlScalarFunction; import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.annotations.FunctionsParserHelper; import com.facebook.presto.operator.scalar.ParametricScalar; import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation.SpecializedSignature; import com.facebook.presto.spi.function.CodegenScalarFunction; +import com.facebook.presto.spi.function.FunctionDescriptor; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.Signature; @@ -35,10 +37,13 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; +import static com.facebook.presto.operator.annotations.FunctionsParserHelper.checkPushdownSubfieldArgIndex; import static com.facebook.presto.operator.scalar.annotations.OperatorValidator.validateOperator; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static com.facebook.presto.util.Failures.checkCondition; import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; public final class ScalarFromAnnotationsParser @@ -46,28 +51,38 @@ public final class ScalarFromAnnotationsParser private ScalarFromAnnotationsParser() {} public static List parseFunctionDefinition(Class clazz) + { + return parseFunctionDefinition(clazz, JAVA_BUILTIN_NAMESPACE); + } + + public static List parseFunctionDefinition(Class clazz, CatalogSchemaName functionNamespace) { ImmutableList.Builder builder 
= ImmutableList.builder(); - for (ScalarHeaderAndMethods scalar : findScalarsInFunctionDefinitionClass(clazz)) { + for (ScalarHeaderAndMethods scalar : findScalarsInFunctionDefinitionClass(clazz, functionNamespace)) { builder.add(parseParametricScalar(scalar, FunctionsParserHelper.findConstructor(clazz))); } return builder.build(); } public static List parseFunctionDefinitions(Class clazz) + { + return parseFunctionDefinitions(clazz, JAVA_BUILTIN_NAMESPACE); + } + + public static List parseFunctionDefinitions(Class clazz, CatalogSchemaName functionNamespace) { ImmutableList.Builder builder = ImmutableList.builder(); - for (ScalarHeaderAndMethods methods : findScalarsInFunctionSetClass(clazz)) { + for (ScalarHeaderAndMethods methods : findScalarsInFunctionSetClass(clazz, functionNamespace)) { // Non-static function only makes sense in classes annotated @ScalarFunction. builder.add(parseParametricScalar(methods, Optional.empty())); } return builder.build(); } - private static List findScalarsInFunctionDefinitionClass(Class annotated) + private static List findScalarsInFunctionDefinitionClass(Class annotated, CatalogSchemaName functionNamespace) { ImmutableList.Builder builder = ImmutableList.builder(); - List classHeaders = ScalarImplementationHeader.fromAnnotatedElement(annotated); + List classHeaders = ScalarImplementationHeader.fromAnnotatedElement(annotated, functionNamespace); checkArgument(!classHeaders.isEmpty(), "Class [%s] that defines function must be annotated with @ScalarFunction or @ScalarOperator", annotated.getName()); for (ScalarImplementationHeader header : classHeaders) { @@ -83,16 +98,19 @@ private static List findScalarsInFunctionDefinitionClass return builder.build(); } - private static List findScalarsInFunctionSetClass(Class annotated) + private static List findScalarsInFunctionSetClass(Class annotated, CatalogSchemaName functionNamespace) { ImmutableList.Builder builder = ImmutableList.builder(); for (Method method : 
FunctionsParserHelper.findPublicMethods( annotated, - ImmutableSet.of(SqlType.class, ScalarFunction.class, ScalarOperator.class), + ImmutableSet.of(SqlType.class, ScalarFunction.class, ScalarOperator.class, FunctionDescriptor.class), ImmutableSet.of(SqlInvokedScalarFunction.class, CodegenScalarFunction.class))) { checkCondition((method.getAnnotation(ScalarFunction.class) != null) || (method.getAnnotation(ScalarOperator.class) != null), FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] annotated with @SqlType is missing @ScalarFunction or @ScalarOperator", method); - for (ScalarImplementationHeader header : ScalarImplementationHeader.fromAnnotatedElement(method)) { + if (method.getAnnotation(ScalarOperator.class) != null) { + checkArgument(functionNamespace.equals(JAVA_BUILTIN_NAMESPACE), format("Connector specific Scalar operator functions are not supported: Class [%s], Namespace [%s]", annotated.getName(), functionNamespace)); + } + for (ScalarImplementationHeader header : ScalarImplementationHeader.fromAnnotatedElement(method, functionNamespace)) { builder.add(new ScalarHeaderAndMethods(header, ImmutableSet.of(method))); } } @@ -106,6 +124,7 @@ private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods sc Map signatures = new HashMap<>(); for (Method method : scalar.getMethods()) { ParametricScalarImplementation implementation = ParametricScalarImplementation.Parser.parseImplementation(header, method, constructor); + checkPushdownSubfieldArgIndex(method, implementation.getSignature(), header.getHeader().getComplexTypeFunctionDescriptor().getPushdownSubfieldArgIndex()); if (!signatures.containsKey(implementation.getSpecializedSignature())) { ParametricScalarImplementation.Builder builder = new ParametricScalarImplementation.Builder( implementation.getSignature(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementationHeader.java 
b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementationHeader.java index b79b0b7f05219..101225c5dc1af 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementationHeader.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarImplementationHeader.java @@ -13,9 +13,11 @@ */ package com.facebook.presto.operator.scalar.annotations; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.operator.scalar.ScalarHeader; +import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlFunctionVisibility; @@ -28,6 +30,7 @@ import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseDescription; +import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseFunctionDescriptor; import static com.facebook.presto.spi.function.SqlFunctionVisibility.HIDDEN; import static com.google.common.base.CaseFormat.LOWER_CAMEL; import static com.google.common.base.CaseFormat.LOWER_UNDERSCORE; @@ -42,7 +45,12 @@ public class ScalarImplementationHeader private ScalarImplementationHeader(String name, ScalarHeader header) { - this.name = QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, requireNonNull(name)); + this(name, header, JAVA_BUILTIN_NAMESPACE); + } + + private ScalarImplementationHeader(String name, ScalarHeader header, CatalogSchemaName functionNamespace) + { + this.name = QualifiedObjectName.valueOf(requireNonNull(functionNamespace), requireNonNull(name)); this.operatorType = Optional.empty(); this.header = 
requireNonNull(header); } @@ -71,25 +79,26 @@ private static String camelToSnake(String name) return LOWER_CAMEL.to(LOWER_UNDERSCORE, name); } - public static List fromAnnotatedElement(AnnotatedElement annotated) + public static List fromAnnotatedElement(AnnotatedElement annotated, CatalogSchemaName functionNamespace) { ScalarFunction scalarFunction = annotated.getAnnotation(ScalarFunction.class); ScalarOperator scalarOperator = annotated.getAnnotation(ScalarOperator.class); Optional description = parseDescription(annotated); + ComplexTypeFunctionDescriptor descriptor = parseFunctionDescriptor(annotated); ImmutableList.Builder builder = ImmutableList.builder(); if (scalarFunction != null) { String baseName = scalarFunction.value().isEmpty() ? camelToSnake(annotatedName(annotated)) : scalarFunction.value(); - builder.add(new ScalarImplementationHeader(baseName, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput()))); + builder.add(new ScalarImplementationHeader(baseName, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput(), descriptor), functionNamespace)); for (String alias : scalarFunction.alias()) { - builder.add(new ScalarImplementationHeader(alias, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput()))); + builder.add(new ScalarImplementationHeader(alias, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput(), descriptor), functionNamespace)); } } if (scalarOperator != null) { - builder.add(new ScalarImplementationHeader(scalarOperator.value(), new ScalarHeader(description, HIDDEN, true, scalarOperator.value().isCalledOnNullInput()))); + builder.add(new ScalarImplementationHeader(scalarOperator.value(), new ScalarHeader(description, HIDDEN, true, 
scalarOperator.value().isCalledOnNullInput(), descriptor))); } List result = builder.build(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java index e34ede50a57f3..1c5f621cb8d05 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.scalar.annotations; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.spi.PrestoException; @@ -39,7 +40,6 @@ import java.util.stream.Stream; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.findPublicStaticMethods; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static com.facebook.presto.spi.function.FunctionKind.SCALAR; @@ -60,7 +60,7 @@ public final class SqlInvokedScalarFromAnnotationsParser { private SqlInvokedScalarFromAnnotationsParser() {} - public static List parseFunctionDefinition(Class clazz) + public static List parseFunctionDefinition(Class clazz, CatalogSchemaName defaultNamespace) { checkArgument(clazz.isAnnotationPresent(SqlInvokedScalarFunction.class), "Class is not annotated with SqlInvokedScalarFunction: %s", clazz.getName()); @@ -68,15 +68,15 @@ public static List parseFunctionDefinition(Class clazz) Optional description = 
Optional.ofNullable(clazz.getAnnotation(Description.class)).map(Description::value); return findScalarsInFunctionDefinitionClass(clazz).stream() - .map(method -> createSqlInvokedFunctions(method, Optional.of(header), description)) + .map(method -> createSqlInvokedFunctions(method, Optional.of(header), description, defaultNamespace)) .flatMap(List::stream) .collect(toImmutableList()); } - public static List parseFunctionDefinitions(Class clazz) + public static List parseFunctionDefinitions(Class clazz, CatalogSchemaName defaultNamespace) { return findScalarsInFunctionSetClass(clazz).stream() - .map(method -> createSqlInvokedFunctions(method, Optional.empty(), Optional.empty())) + .map(method -> createSqlInvokedFunctions(method, Optional.empty(), Optional.empty(), defaultNamespace)) .flatMap(List::stream) .collect(toImmutableList()); } @@ -121,7 +121,7 @@ private static List findScalarsInFunctionSetClass(Class clazz) return ImmutableList.copyOf(methods); } - private static List createSqlInvokedFunctions(Method method, Optional header, Optional description) + private static List createSqlInvokedFunctions(Method method, Optional header, Optional description, CatalogSchemaName defaultNamespace) { SqlInvokedScalarFunction functionHeader = header.orElseGet(() -> method.getAnnotation(SqlInvokedScalarFunction.class)); String functionDescription = description.orElseGet(() -> method.isAnnotationPresent(Description.class) ? 
method.getAnnotation(Description.class).value() : ""); @@ -167,7 +167,7 @@ else if (method.isAnnotationPresent(SqlParameters.class)) { return Stream.concat(Stream.of(functionHeader.value()), stream(functionHeader.alias())) .map(name -> new SqlInvokedFunction( - QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, name), + QualifiedObjectName.valueOf(defaultNamespace, name), parameters, typeVariableConstraints, emptyList(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/table/ExcludeColumns.java b/presto-main-base/src/main/java/com/facebook/presto/operator/table/ExcludeColumns.java new file mode 100644 index 0000000000000..a098602b07bac --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/table/ExcludeColumns.java @@ -0,0 +1,169 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator.table; + +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.DescriptorArgumentSpecification; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; + +import javax.inject.Provider; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import 
static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.stream.Collectors.joining; + +public class ExcludeColumns + implements Provider +{ + public static final String NAME = "exclude_columns"; + + @Override + public ConnectorTableFunction get() + { + return new ExcludeColumnsFunction(); + } + + public static class ExcludeColumnsFunction + extends AbstractConnectorTableFunction + { + private static final String TABLE_ARGUMENT_NAME = "INPUT"; + private static final String DESCRIPTOR_ARGUMENT_NAME = "COLUMNS"; + + public ExcludeColumnsFunction() + { + super( + "builtin", + NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name(TABLE_ARGUMENT_NAME) + .rowSemantics() + .build(), + DescriptorArgumentSpecification.builder() + .name(DESCRIPTOR_ARGUMENT_NAME) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + DescriptorArgument excludedColumns = (DescriptorArgument) arguments.get(DESCRIPTOR_ARGUMENT_NAME); + if (excludedColumns.equals(NULL_DESCRIPTOR)) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "COLUMNS descriptor is null"); + } + Descriptor excludedColumnsDescriptor = excludedColumns.getDescriptor().orElseThrow(() -> new PrestoException(INVALID_ARGUMENTS, "Missing exclude columns descriptor")); + if (excludedColumnsDescriptor.getFields().stream().anyMatch(field -> field.getType().isPresent())) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "COLUMNS descriptor contains types"); + } + + // column names in DescriptorArgument are canonical wrt SQL identifier semantics. + // column names in TableArgument are not canonical wrt SQL identifier semantics, as they are taken from the corresponding RelationType. 
+ // because of that, we match the excluded columns names case-insensitive + // TODO: apply proper identifier semantics + Set excludedNames = excludedColumnsDescriptor.getFields().stream() + .map(Descriptor.Field::getName) + .map(name -> name.orElseThrow(() -> new PrestoException(INVALID_ARGUMENTS, "Missing Descriptor field name")).toLowerCase(ENGLISH)) + .collect(toImmutableSet()); + + List inputSchema = ((TableArgument) arguments.get(TABLE_ARGUMENT_NAME)).getRowType().getFields(); + Set inputNames = inputSchema.stream() + .map(RowType.Field::getName) + .filter(Optional::isPresent) + .map(Optional::get) + .map(name -> name.toLowerCase(ENGLISH)) + .collect(toImmutableSet()); + + if (!inputNames.containsAll(excludedNames)) { + String missingColumns = Sets.difference(excludedNames, inputNames).stream() + .collect(joining(", ", "[", "]")); + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Excluded columns: %s not present in the table", missingColumns)); + } + + ImmutableList.Builder requiredColumns = ImmutableList.builder(); + ImmutableList.Builder returnedColumns = ImmutableList.builder(); + + for (int i = 0; i < inputSchema.size(); i++) { + Optional name = inputSchema.get(i).getName(); + if (!name.isPresent() || !excludedNames.contains(name.orElseThrow(() -> new PrestoException(INVALID_FUNCTION_ARGUMENT, "Missing schema name")).toLowerCase(ENGLISH))) { + requiredColumns.add(i); + // per SQL standard, all columns produced by a table function must be named. We allow anonymous columns. 
+ returnedColumns.add(new Descriptor.Field(name, Optional.of(inputSchema.get(i).getType()))); + } + } + + List returnedType = returnedColumns.build(); + if (returnedType.isEmpty()) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "All columns are excluded"); + } + + return TableFunctionAnalysis.builder() + .requiredColumns(TABLE_ARGUMENT_NAME, requiredColumns.build()) + .returnedType(new Descriptor(returnedType)) + .handle(new ExcludeColumnsFunctionHandle()) + .build(); + } + } + + public static TableFunctionProcessorProvider getExcludeColumnsFunctionProcessorProvider() + { + return new TableFunctionProcessorProvider() + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(getOnlyElement(input).orElseThrow(() -> new PrestoException(INVALID_ARGUMENTS, "Missing data processor input"))); + }; + } + }; + } + + public static class ExcludeColumnsFunctionHandle + implements ConnectorTableFunctionHandle + { + // there's no information to remember. All logic is effectively delegated to the engine via `requiredColumns`. + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/table/Sequence.java b/presto-main-base/src/main/java/com/facebook/presto/operator/table/Sequence.java new file mode 100644 index 0000000000000..f32f850e1632a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/table/Sequence.java @@ -0,0 +1,325 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.table; + +import com.facebook.presto.common.PageBuilder; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification.DescribedTable; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import javax.inject.Provider; + +import java.math.BigInteger; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static 
com.facebook.presto.operator.table.Sequence.SequenceFunctionSplit.MAX_SPLIT_SIZE; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.table.Descriptor.descriptor; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.produced; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; + +public class Sequence + implements Provider +{ + public static final String NAME = "sequence"; + + @Override + public ConnectorTableFunction get() + { + return new SequenceFunction(); + } + + public static class SequenceFunction + extends AbstractConnectorTableFunction + { + private static final String START_ARGUMENT_NAME = "START"; + private static final String STOP_ARGUMENT_NAME = "STOP"; + private static final String STEP_ARGUMENT_NAME = "STEP"; + + public SequenceFunction() + { + super( + "builtin", + NAME, + ImmutableList.of( + ScalarArgumentSpecification.builder() + .name(START_ARGUMENT_NAME) + .type(BIGINT) + .defaultValue(0L) + .build(), + ScalarArgumentSpecification.builder() + .name(STOP_ARGUMENT_NAME) + .type(BIGINT) + .build(), + ScalarArgumentSpecification.builder() + .name(STEP_ARGUMENT_NAME) + .type(BIGINT) + .defaultValue(1L) + .build()), + new DescribedTable(descriptor(ImmutableList.of("sequential_number"), ImmutableList.of(BIGINT)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + Object startValue = ((ScalarArgument) arguments.get(START_ARGUMENT_NAME)).getValue(); + if (startValue == null) { + throw new 
PrestoException(INVALID_FUNCTION_ARGUMENT, "Start is null"); + } + + Object stopValue = ((ScalarArgument) arguments.get(STOP_ARGUMENT_NAME)).getValue(); + if (stopValue == null) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Stop is null"); + } + + Object stepValue = ((ScalarArgument) arguments.get(STEP_ARGUMENT_NAME)).getValue(); + if (stepValue == null) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Step is null"); + } + + long start = (long) startValue; + long stop = (long) stopValue; + long step = (long) stepValue; + + if (start < stop && step <= 0) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Step must be positive for sequence [%s, %s]", start, stop)); + } + + if (start > stop && step >= 0) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Step must be negative for sequence [%s, %s]", start, stop)); + } + + return TableFunctionAnalysis.builder() + .handle(new SequenceFunctionHandle(start, stop, start == stop ? 0 : step)) + .build(); + } + } + + public static class SequenceFunctionHandle + implements ConnectorTableFunctionHandle + { + private final long start; + private final long stop; + private final long step; + + @JsonCreator + public SequenceFunctionHandle(@JsonProperty("start") long start, @JsonProperty("stop") long stop, @JsonProperty("step") long step) + { + this.start = start; + this.stop = stop; + this.step = step; + } + + @JsonProperty + public long start() + { + return start; + } + + @JsonProperty + public long stop() + { + return stop; + } + + @JsonProperty + public long step() + { + return step; + } + } + + public static ConnectorSplitSource getSequenceFunctionSplitSource(SequenceFunctionHandle handle) + { + // using BigInteger to avoid long overflow since it's not in the main data processing loop + BigInteger start = BigInteger.valueOf(handle.start()); + BigInteger stop = BigInteger.valueOf(handle.stop()); + BigInteger step = BigInteger.valueOf(handle.step()); + + if 
(step.equals(BigInteger.ZERO)) { + checkArgument(start.equals(stop), "start is not equal to stop for step = 0"); + return new FixedSplitSource(ImmutableList.of(new SequenceFunctionSplit(start.longValueExact(), stop.longValueExact()))); + } + + ImmutableList.Builder splits = ImmutableList.builder(); + + BigInteger totalSteps = stop.subtract(start).divide(step).add(BigInteger.ONE); + BigInteger totalSplits = totalSteps.divide(BigInteger.valueOf(MAX_SPLIT_SIZE)).add(BigInteger.ONE); + BigInteger[] stepsPerSplit = totalSteps.divideAndRemainder(totalSplits); + BigInteger splitJump = stepsPerSplit[0].subtract(BigInteger.ONE).multiply(step); + + BigInteger splitStart = start; + for (BigInteger i = BigInteger.ZERO; i.compareTo(totalSplits) < 0; i = i.add(BigInteger.ONE)) { + BigInteger splitStop = splitStart.add(splitJump); + // distribute the remaining steps between the initial splits, one step per split + if (i.compareTo(stepsPerSplit[1]) < 0) { + splitStop = splitStop.add(step); + } + splits.add(new SequenceFunctionSplit(splitStart.longValueExact(), splitStop.longValueExact())); + splitStart = splitStop.add(step); + } + + return new FixedSplitSource(splits.build()); + } + + public static class SequenceFunctionSplit + implements ConnectorSplit + { + public static final int DEFAULT_SPLIT_SIZE = 1000000; + public static final int MAX_SPLIT_SIZE = 1000000; + + // the first value of sub-sequence + private final long start; + + // the last value of sub-sequence. this value is aligned so that it belongs to the sequence. 
+ private final long stop; + + @JsonCreator + public SequenceFunctionSplit(@JsonProperty("start") long start, @JsonProperty("stop") long stop) + { + this.start = start; + this.stop = stop; + } + + @JsonProperty + public long getStart() + { + return start; + } + + @JsonProperty + public long getStop() + { + return stop; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return ImmutableMap.builder() + .put("start", start) + .put("stop", stop) + .buildOrThrow(); + } + } + + public static TableFunctionProcessorProvider getSequenceFunctionProcessorProvider() + { + return new TableFunctionProcessorProvider() { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new SequenceFunctionProcessor(((SequenceFunctionHandle) handle).step()); + } + }; + } + + public static class SequenceFunctionProcessor + implements TableFunctionSplitProcessor + { + private final PageBuilder page = new PageBuilder(ImmutableList.of(BIGINT)); + private final long step; + private long start; + private long stop; + private boolean finished; + + public SequenceFunctionProcessor(long step) + { + this.step = step; + } + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + if (split != null) { + SequenceFunctionSplit sequenceSplit = (SequenceFunctionSplit) split; + start = sequenceSplit.getStart(); + stop = sequenceSplit.getStop(); + BlockBuilder block = page.getBlockBuilder(0); + while (start != stop && !page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + start += step; + } + if (!page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + finished = true; + return usedInputAndProduced(page.build()); + } + return usedInputAndProduced(page.build()); 
+ } + + if (finished) { + return FINISHED; + } + + page.reset(); + BlockBuilder block = page.getBlockBuilder(0); + while (start != stop && !page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + start += step; + } + if (!page.isFull()) { + page.declarePosition(); + BIGINT.writeLong(block, start); + finished = true; + return produced(page.build()); + } + return produced(page.build()); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/unnest/UnnestBlockBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/operator/unnest/UnnestBlockBuilder.java index 416b244882048..575dc823f8077 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/unnest/UnnestBlockBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/unnest/UnnestBlockBuilder.java @@ -15,8 +15,7 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.DictionaryBlock; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import static com.facebook.presto.operator.unnest.UnnestBlockBuilder.NullElementFinder.NULL_NOT_FOUND; import static com.google.common.base.Verify.verify; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/AggregatedResourceGroupInfoBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/AggregatedResourceGroupInfoBuilder.java index 1d035af361de9..5463780328727 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/AggregatedResourceGroupInfoBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/AggregatedResourceGroupInfoBuilder.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.resourcemanager; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.server.QueryStateInfo; import com.facebook.presto.server.ResourceGroupInfo; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; @@ -20,7 +21,6 @@ 
import com.facebook.presto.spi.resourceGroups.SchedulingPolicy; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import java.util.HashMap; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ClusterMemoryManagerService.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ClusterMemoryManagerService.java index efa189e8d7afc..51eb8a89941b5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ClusterMemoryManagerService.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ClusterMemoryManagerService.java @@ -20,10 +20,9 @@ import com.facebook.presto.spi.memory.MemoryPoolInfo; import com.facebook.presto.util.PeriodicTaskExecutor; import com.google.common.collect.ImmutableMap; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ClusterQueryTrackerService.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ClusterQueryTrackerService.java index 69650289b06dd..8e8fa9d4396b5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ClusterQueryTrackerService.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ClusterQueryTrackerService.java @@ -15,10 +15,9 @@ import com.facebook.drift.client.DriftClient; import com.facebook.presto.util.PeriodicTaskExecutor; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import 
java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicInteger; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ForResourceManager.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ForResourceManager.java index 8cb358309e76f..c5ced1a63d381 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ForResourceManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ForResourceManager.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.resourcemanager; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/RandomResourceManagerAddressSelector.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/RandomResourceManagerAddressSelector.java index d924d16e74115..6a250e75d3ffa 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/RandomResourceManagerAddressSelector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/RandomResourceManagerAddressSelector.java @@ -18,8 +18,7 @@ import com.facebook.presto.spi.HostAddress; import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HostAndPort; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/RatisServer.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/RatisServer.java index 317482435717b..4380990db5928 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/RatisServer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/RatisServer.java @@ -15,6 +15,8 @@ import com.facebook.presto.metadata.InternalNode; import 
com.facebook.presto.metadata.InternalNodeManager; +import jakarta.annotation.PostConstruct; +import jakarta.inject.Inject; import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.grpc.GrpcConfigKeys; import org.apache.ratis.protocol.RaftGroup; @@ -25,9 +27,6 @@ import org.apache.ratis.server.RaftServerConfigKeys; import org.apache.ratis.statemachine.impl.BaseStateMachine; -import javax.annotation.PostConstruct; -import javax.inject.Inject; - import java.io.File; import java.util.Collections; import java.util.Set; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStateProvider.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStateProvider.java index 6ee9dac448107..828b1a2f71b05 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStateProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStateProvider.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.resourcemanager; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.resourceGroups.ResourceGroupRuntimeInfo; import com.facebook.presto.memory.ClusterMemoryPool; import com.facebook.presto.memory.MemoryInfo; @@ -31,10 +32,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.inject.Inject; import java.net.URI; import java.util.Collection; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerConfig.java index 0a013130cff27..b8e5793873a95 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerConfig.java @@ -15,10 +15,9 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.Min; import java.util.concurrent.TimeUnit; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerResourceGroupService.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerResourceGroupService.java index 37c6810e0bf06..cb50153fb958f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerResourceGroupService.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerResourceGroupService.java @@ -13,15 +13,14 @@ */ package com.facebook.presto.resourcemanager; +import com.facebook.airlift.units.Duration; import com.facebook.drift.client.DriftClient; import com.facebook.presto.execution.resourceGroups.ResourceGroupRuntimeInfo; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.InternalNodeManager; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; -import io.airlift.units.Duration; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.concurrent.Executor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerServer.java b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerServer.java index 21b70e4a2bb9d..e08702ad1ca69 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerServer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerServer.java @@ -23,8 +23,7 @@ import com.facebook.presto.spi.memory.MemoryPoolId; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlManager.java b/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlManager.java index 987e8cb26cc9c..6bc0a9b1a197d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlManager.java @@ -42,11 +42,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.io.File; import java.security.Principal; import java.security.cert.X509Certificate; @@ -57,7 +56,6 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; import static com.facebook.presto.spi.StandardErrorCode.INVALID_COLUMN_MASK; @@ -81,13 +79,14 @@ public class AccessControlManager private final Map systemAccessControlFactories = new ConcurrentHashMap<>(); private final Map connectorAccessControl = new ConcurrentHashMap<>(); - private final AtomicReference systemAccessControl = new AtomicReference<>(new InitializingSystemAccessControl()); + private final 
StatsRecordingSystemAccessControl systemAccessControl = new StatsRecordingSystemAccessControl(new InitializingSystemAccessControl()); private final AtomicBoolean systemAccessControlLoading = new AtomicBoolean(); private final CounterStat authenticationSuccess = new CounterStat(); private final CounterStat authenticationFail = new CounterStat(); private final CounterStat authorizationSuccess = new CounterStat(); private final CounterStat authorizationFail = new CounterStat(); + private StatsRecordingSystemAccessControl.Stats detailedStats = new StatsRecordingSystemAccessControl.Stats(); @Inject public AccessControlManager(TransactionManager transactionManager) @@ -96,6 +95,7 @@ public AccessControlManager(TransactionManager transactionManager) addSystemAccessControlFactory(new AllowAllSystemAccessControl.Factory()); addSystemAccessControlFactory(new ReadOnlySystemAccessControl.Factory()); addSystemAccessControlFactory(new FileBasedSystemAccessControl.Factory()); + addSystemAccessControlFactory(new DenyQueryIntegrityCheckSystemAccessControl.Factory()); } public void addSystemAccessControlFactory(SystemAccessControlFactory accessControlFactory) @@ -159,8 +159,7 @@ protected void setSystemAccessControl(String name, Map propertie SystemAccessControlFactory systemAccessControlFactory = systemAccessControlFactories.get(name); checkState(systemAccessControlFactory != null, "Access control %s is not registered", name); - SystemAccessControl systemAccessControl = systemAccessControlFactory.create(ImmutableMap.copyOf(properties)); - this.systemAccessControl.set(systemAccessControl); + systemAccessControl.updateDelegate(systemAccessControlFactory.create(ImmutableMap.copyOf(properties))); log.info("-- Loaded system access control %s --", name); } @@ -171,7 +170,7 @@ public void checkCanSetUser(Identity identity, AccessControlContext context, Opt requireNonNull(principal, "principal is null"); requireNonNull(userName, "userName is null"); - authenticationCheck(() -> 
systemAccessControl.get().checkCanSetUser(identity, context, principal, userName)); + authenticationCheck(() -> systemAccessControl.checkCanSetUser(identity, context, principal, userName)); } @Override @@ -180,16 +179,16 @@ public AuthorizedIdentity selectAuthorizedIdentity(Identity identity, AccessCont requireNonNull(userName, "userName is null"); requireNonNull(certificates, "certificates is null"); - return systemAccessControl.get().selectAuthorizedIdentity(identity, context, userName, certificates); + return systemAccessControl.selectAuthorizedIdentity(identity, context, userName, certificates); } @Override - public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map viewDefinitions, Map materializedViewDefinitions) + public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) { requireNonNull(identity, "identity is null"); requireNonNull(query, "query is null"); - authenticationCheck(() -> systemAccessControl.get().checkQueryIntegrity(identity, context, query, viewDefinitions, materializedViewDefinitions)); + authenticationCheck(() -> systemAccessControl.checkQueryIntegrity(identity, context, query, preparedStatements, viewDefinitions, materializedViewDefinitions)); } @Override @@ -198,7 +197,7 @@ public Set filterCatalogs(Identity identity, AccessControlContext contex requireNonNull(identity, "identity is null"); requireNonNull(catalogs, "catalogs is null"); - return systemAccessControl.get().filterCatalogs(identity, context, catalogs); + return systemAccessControl.filterCatalogs(identity, context, catalogs); } @Override @@ -207,7 +206,7 @@ public void checkCanAccessCatalog(Identity identity, AccessControlContext contex requireNonNull(identity, "identity is null"); requireNonNull(catalogName, "catalog is null"); - authenticationCheck(() -> systemAccessControl.get().checkCanAccessCatalog(identity, context, 
catalogName)); + authenticationCheck(() -> systemAccessControl.checkCanAccessCatalog(identity, context, catalogName)); } @Override @@ -218,7 +217,7 @@ public void checkCanCreateSchema(TransactionId transactionId, Identity identity, authenticationCheck(() -> checkCanAccessCatalog(identity, context, schemaName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanCreateSchema(identity, context, schemaName)); + authorizationCheck(() -> systemAccessControl.checkCanCreateSchema(identity, context, schemaName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, schemaName.getCatalogName()); if (entry != null) { @@ -234,7 +233,7 @@ public void checkCanDropSchema(TransactionId transactionId, Identity identity, A authenticationCheck(() -> checkCanAccessCatalog(identity, context, schemaName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanDropSchema(identity, context, schemaName)); + authorizationCheck(() -> systemAccessControl.checkCanDropSchema(identity, context, schemaName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, schemaName.getCatalogName()); if (entry != null) { @@ -250,7 +249,7 @@ public void checkCanRenameSchema(TransactionId transactionId, Identity identity, authenticationCheck(() -> checkCanAccessCatalog(identity, context, schemaName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanRenameSchema(identity, context, schemaName, newSchemaName)); + authorizationCheck(() -> systemAccessControl.checkCanRenameSchema(identity, context, schemaName, newSchemaName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, schemaName.getCatalogName()); if (entry != null) { @@ -266,7 +265,7 @@ public void checkCanShowSchemas(TransactionId transactionId, Identity identity, authenticationCheck(() -> checkCanAccessCatalog(identity, context, catalogName)); - authorizationCheck(() -> 
systemAccessControl.get().checkCanShowSchemas(identity, context, catalogName)); + authorizationCheck(() -> systemAccessControl.checkCanShowSchemas(identity, context, catalogName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, catalogName); if (entry != null) { @@ -285,7 +284,7 @@ public Set filterSchemas(TransactionId transactionId, Identity identity, return ImmutableSet.of(); } - schemaNames = systemAccessControl.get().filterSchemas(identity, context, catalogName, schemaNames); + schemaNames = systemAccessControl.filterSchemas(identity, context, catalogName, schemaNames); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, catalogName); if (entry != null) { @@ -294,6 +293,22 @@ public Set filterSchemas(TransactionId transactionId, Identity identity, return schemaNames; } + @Override + public void checkCanShowCreateTable(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) + { + requireNonNull(identity, "identity is null"); + requireNonNull(tableName, "tableName is null"); + + authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); + + authorizationCheck(() -> systemAccessControl.checkCanShowCreateTable(identity, context, toCatalogSchemaTableName(tableName))); + + CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); + if (entry != null) { + authorizationCheck(() -> entry.getAccessControl().checkCanShowCreateTable(entry.getTransactionHandle(transactionId), identity.toConnectorIdentity(tableName.getCatalogName()), context, toSchemaTableName(tableName))); + } + } + @Override public void checkCanCreateTable(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) { @@ -302,7 +317,7 @@ public void checkCanCreateTable(TransactionId transactionId, Identity identity, authenticationCheck(() -> checkCanAccessCatalog(identity, 
context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanCreateTable(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanCreateTable(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -318,7 +333,7 @@ public void checkCanDropTable(TransactionId transactionId, Identity identity, Ac authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanDropTable(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanDropTable(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -335,7 +350,7 @@ public void checkCanRenameTable(TransactionId transactionId, Identity identity, authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanRenameTable(identity, context, toCatalogSchemaTableName(tableName), toCatalogSchemaTableName(newTableName))); + authorizationCheck(() -> systemAccessControl.checkCanRenameTable(identity, context, toCatalogSchemaTableName(tableName), toCatalogSchemaTableName(newTableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -349,7 +364,7 @@ public void checkCanSetTableProperties(TransactionId transactionId, Identity ide requireNonNull(identity, "identity is null"); requireNonNull(tableName, "tableName is null"); authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> 
systemAccessControl.get().checkCanSetTableProperties(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanSetTableProperties(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { authorizationCheck(() -> entry.getAccessControl().checkCanSetTableProperties(entry.getTransactionHandle(transactionId), identity.toConnectorIdentity(tableName.getCatalogName()), context, toSchemaTableName(tableName), properties)); @@ -364,7 +379,7 @@ public void checkCanShowTablesMetadata(TransactionId transactionId, Identity ide authenticationCheck(() -> checkCanAccessCatalog(identity, context, schema.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanShowTablesMetadata(identity, context, schema)); + authorizationCheck(() -> systemAccessControl.checkCanShowTablesMetadata(identity, context, schema)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, schema.getCatalogName()); if (entry != null) { @@ -383,7 +398,7 @@ public Set filterTables(TransactionId transactionId, Identity i return ImmutableSet.of(); } - tableNames = systemAccessControl.get().filterTables(identity, context, catalogName, tableNames); + tableNames = systemAccessControl.filterTables(identity, context, catalogName, tableNames); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, catalogName); if (entry != null) { @@ -392,6 +407,46 @@ public Set filterTables(TransactionId transactionId, Identity i return tableNames; } + @Override + public void checkCanShowColumnsMetadata(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) + { + requireNonNull(identity, "identity is null"); + requireNonNull(tableName, "table is null"); + + CatalogSchemaTableName catalogSchemaTableName = toCatalogSchemaTableName(tableName); + 
+ authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); + + authorizationCheck(() -> systemAccessControl.checkCanShowColumnsMetadata(identity, context, catalogSchemaTableName)); + + CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); + if (entry != null) { + authorizationCheck(() -> entry.getAccessControl().checkCanShowColumnsMetadata(entry.getTransactionHandle(transactionId), identity.toConnectorIdentity(), context, catalogSchemaTableName.getSchemaTableName())); + } + } + + @Override + public List filterColumns(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName, List columns) + { + requireNonNull(transactionId, "transaction is null"); + requireNonNull(identity, "identity is null"); + requireNonNull(tableName, "tableName is null"); + + SchemaTableName schemaTableName = new SchemaTableName(tableName.getSchemaName(), tableName.getObjectName()); + + if (filterTables(transactionId, identity, context, tableName.getCatalogName(), ImmutableSet.of(schemaTableName)).isEmpty()) { + return ImmutableList.of(); + } + + columns = systemAccessControl.filterColumns(identity, context, toCatalogSchemaTableName(tableName), columns); + + CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); + if (entry != null) { + columns = entry.getAccessControl().filterColumns(entry.getTransactionHandle(transactionId), identity.toConnectorIdentity(), context, schemaTableName, columns); + } + return columns; + } + @Override public void checkCanAddColumns(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) { @@ -400,7 +455,7 @@ public void checkCanAddColumns(TransactionId transactionId, Identity identity, A authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> 
systemAccessControl.get().checkCanAddColumn(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanAddColumn(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -416,7 +471,7 @@ public void checkCanDropColumn(TransactionId transactionId, Identity identity, A authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanDropColumn(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanDropColumn(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -432,7 +487,7 @@ public void checkCanRenameColumn(TransactionId transactionId, Identity identity, authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanRenameColumn(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanRenameColumn(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -448,7 +503,7 @@ public void checkCanInsertIntoTable(TransactionId transactionId, Identity identi authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanInsertIntoTable(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanInsertIntoTable(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry 
entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -464,7 +519,7 @@ public void checkCanDeleteFromTable(TransactionId transactionId, Identity identi authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanDeleteFromTable(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanDeleteFromTable(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -480,7 +535,7 @@ public void checkCanTruncateTable(TransactionId transactionId, Identity identity authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanTruncateTable(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanTruncateTable(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -496,7 +551,7 @@ public void checkCanUpdateTableColumns(TransactionId transactionId, Identity ide authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanUpdateTableColumns(identity, context, toCatalogSchemaTableName(tableName), updatedColumnNames)); + authorizationCheck(() -> systemAccessControl.checkCanUpdateTableColumns(identity, context, toCatalogSchemaTableName(tableName), updatedColumnNames)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -512,7 +567,7 @@ public void checkCanCreateView(TransactionId transactionId, Identity 
identity, A authenticationCheck(() -> checkCanAccessCatalog(identity, context, viewName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanCreateView(identity, context, toCatalogSchemaTableName(viewName))); + authorizationCheck(() -> systemAccessControl.checkCanCreateView(identity, context, toCatalogSchemaTableName(viewName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, viewName.getCatalogName()); if (entry != null) { @@ -529,7 +584,7 @@ public void checkCanRenameView(TransactionId transactionId, Identity identity, A authenticationCheck(() -> checkCanAccessCatalog(identity, context, viewName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanRenameView(identity, context, toCatalogSchemaTableName(viewName), toCatalogSchemaTableName(newViewName))); + authorizationCheck(() -> systemAccessControl.checkCanRenameView(identity, context, toCatalogSchemaTableName(viewName), toCatalogSchemaTableName(newViewName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, viewName.getCatalogName()); if (entry != null) { @@ -545,7 +600,7 @@ public void checkCanDropView(TransactionId transactionId, Identity identity, Acc authenticationCheck(() -> checkCanAccessCatalog(identity, context, viewName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanDropView(identity, context, toCatalogSchemaTableName(viewName))); + authorizationCheck(() -> systemAccessControl.checkCanDropView(identity, context, toCatalogSchemaTableName(viewName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, viewName.getCatalogName()); if (entry != null) { @@ -561,7 +616,7 @@ public void checkCanCreateViewWithSelectFromColumns(TransactionId transactionId, authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> 
systemAccessControl.get().checkCanCreateViewWithSelectFromColumns(identity, context, toCatalogSchemaTableName(tableName), columnNames)); + authorizationCheck(() -> systemAccessControl.checkCanCreateViewWithSelectFromColumns(identity, context, toCatalogSchemaTableName(tableName), columnNames)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -578,7 +633,7 @@ public void checkCanGrantTablePrivilege(TransactionId transactionId, Identity id authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanGrantTablePrivilege(identity, context, privilege, toCatalogSchemaTableName(tableName), grantee, withGrantOption)); + authorizationCheck(() -> systemAccessControl.checkCanGrantTablePrivilege(identity, context, privilege, toCatalogSchemaTableName(tableName), grantee, withGrantOption)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -595,7 +650,7 @@ public void checkCanRevokeTablePrivilege(TransactionId transactionId, Identity i authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanRevokeTablePrivilege(identity, context, privilege, toCatalogSchemaTableName(tableName), revokee, grantOptionFor)); + authorizationCheck(() -> systemAccessControl.checkCanRevokeTablePrivilege(identity, context, privilege, toCatalogSchemaTableName(tableName), revokee, grantOptionFor)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -609,7 +664,7 @@ public void checkCanSetSystemSessionProperty(Identity identity, AccessControlCon requireNonNull(identity, "identity is null"); requireNonNull(propertyName, "propertyName is null"); - authorizationCheck(() -> 
systemAccessControl.get().checkCanSetSystemSessionProperty(identity, context, propertyName)); + authorizationCheck(() -> systemAccessControl.checkCanSetSystemSessionProperty(identity, context, propertyName)); } @Override @@ -621,7 +676,7 @@ public void checkCanSetCatalogSessionProperty(TransactionId transactionId, Ident authenticationCheck(() -> checkCanAccessCatalog(identity, context, catalogName)); - authorizationCheck(() -> systemAccessControl.get().checkCanSetCatalogSessionProperty(identity, context, catalogName, propertyName)); + authorizationCheck(() -> systemAccessControl.checkCanSetCatalogSessionProperty(identity, context, catalogName, propertyName)); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, catalogName); if (entry != null) { @@ -638,7 +693,7 @@ public void checkCanSelectFromColumns(TransactionId transactionId, Identity iden authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanSelectFromColumns( + authorizationCheck(() -> systemAccessControl.checkCanSelectFromColumns( identity, context, toCatalogSchemaTableName(tableName), @@ -650,6 +705,26 @@ public void checkCanSelectFromColumns(TransactionId transactionId, Identity iden } } + @Override + public void checkCanCallProcedure(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName procedureName) + { + requireNonNull(identity, "identity is null"); + requireNonNull(procedureName, "procedureName is null"); + + authenticationCheck(() -> checkCanAccessCatalog(identity, context, procedureName.getCatalogName())); + + authorizationCheck(() -> systemAccessControl.checkCanCallProcedure( + identity, + context, + toCatalogSchemaTableName(procedureName))); + + CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, procedureName.getCatalogName()); + if (entry != null) { + authorizationCheck(() -> 
entry.getAccessControl().checkCanCallProcedure(entry.getTransactionHandle(transactionId), + identity.toConnectorIdentity(procedureName.getCatalogName()), context, toSchemaTableName(procedureName))); + } + } + @Override public void checkCanCreateRole(TransactionId transactionId, Identity identity, AccessControlContext context, String role, Optional grantor, String catalogName) { @@ -772,6 +847,38 @@ public void checkCanShowRoleGrants(TransactionId transactionId, Identity identit } } + @Override + public void checkCanDropBranch(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) + { + requireNonNull(identity, "identity is null"); + requireNonNull(tableName, "tableName is null"); + + authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); + + authorizationCheck(() -> systemAccessControl.checkCanDropBranch(identity, context, toCatalogSchemaTableName(tableName))); + + CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); + if (entry != null) { + authorizationCheck(() -> entry.getAccessControl().checkCanDropBranch(entry.getTransactionHandle(transactionId), identity.toConnectorIdentity(tableName.getCatalogName()), context, toSchemaTableName(tableName))); + } + } + + @Override + public void checkCanDropTag(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) + { + requireNonNull(identity, "identity is null"); + requireNonNull(tableName, "tableName is null"); + + authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); + + authorizationCheck(() -> systemAccessControl.checkCanDropTag(identity, context, toCatalogSchemaTableName(tableName))); + + CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); + if (entry != null) { + authorizationCheck(() -> 
entry.getAccessControl().checkCanDropTag(entry.getTransactionHandle(transactionId), identity.toConnectorIdentity(tableName.getCatalogName()), context, toSchemaTableName(tableName))); + } + } + @Override public void checkCanDropConstraint(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) { @@ -780,7 +887,7 @@ public void checkCanDropConstraint(TransactionId transactionId, Identity identit authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanDropConstraint(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanDropConstraint(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -796,7 +903,7 @@ public void checkCanAddConstraints(TransactionId transactionId, Identity identit authenticationCheck(() -> checkCanAccessCatalog(identity, context, tableName.getCatalogName())); - authorizationCheck(() -> systemAccessControl.get().checkCanAddConstraint(identity, context, toCatalogSchemaTableName(tableName))); + authorizationCheck(() -> systemAccessControl.checkCanAddConstraint(identity, context, toCatalogSchemaTableName(tableName))); CatalogAccessControlEntry entry = getConnectorAccessControl(transactionId, tableName.getCatalogName()); if (entry != null) { @@ -818,7 +925,7 @@ public List getRowFilters(TransactionId transactionId, Identity .forEach(filters::add); } - systemAccessControl.get().getRowFilters(identity, context, toCatalogSchemaTableName(tableName)) + systemAccessControl.getRowFilters(identity, context, toCatalogSchemaTableName(tableName)) .forEach(filters::add); return filters.build(); @@ -841,7 +948,7 @@ public Map getColumnMasks(TransactionId transact columnMasksBuilder.putAll(connectorMasks); } - Map systemMasks = 
systemAccessControl.get().getColumnMasks(identity, context, toCatalogSchemaTableName(tableName), columns); + Map systemMasks = systemAccessControl.getColumnMasks(identity, context, toCatalogSchemaTableName(tableName), columns); columnMasksBuilder.putAll(systemMasks); try { @@ -887,6 +994,13 @@ public CounterStat getAuthorizationFail() return authorizationFail; } + @Managed + @Nested + public StatsRecordingSystemAccessControl.Stats getDetailedStats() + { + return systemAccessControl.getStats(); + } + private void authenticationCheck(Runnable runnable) { try { @@ -947,7 +1061,7 @@ private static class InitializingSystemAccessControl implements SystemAccessControl { @Override - public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map viewDefinitions, Map materializedViewDefinitions) + public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) { throw new PrestoException(SERVER_STARTING_UP, "Presto server is still initializing"); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlUtils.java b/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlUtils.java index a22344d988a17..0800e5f5115c3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/security/AccessControlUtils.java @@ -22,7 +22,9 @@ import com.facebook.presto.spi.security.AccessDeniedException; import com.facebook.presto.spi.security.AuthorizedIdentity; import com.facebook.presto.spi.security.Identity; +import com.google.common.collect.ImmutableMap; +import java.util.Map; import java.util.Optional; public class AccessControlUtils @@ -49,7 +51,9 @@ public static void checkPermissions(AccessControl accessControl, SecurityConfig sessionContext.getRuntimeStats(), Optional.empty(), 
Optional.ofNullable(sessionContext.getCatalog()), - Optional.ofNullable(sessionContext.getSchema())), + Optional.ofNullable(sessionContext.getSchema()), + getSqlText(sessionContext, securityConfig), + getPreparedStatements(sessionContext, securityConfig)), identity.getPrincipal(), identity.getUser()); } @@ -77,11 +81,29 @@ public static Optional getAuthorizedIdentity(AccessControl a sessionContext.getRuntimeStats(), Optional.empty(), Optional.ofNullable(sessionContext.getCatalog()), - Optional.ofNullable(sessionContext.getSchema())), + Optional.ofNullable(sessionContext.getSchema()), + getSqlText(sessionContext, securityConfig), + getPreparedStatements(sessionContext, securityConfig)), identity.getUser(), sessionContext.getCertificates()); return Optional.of(authorizedIdentity); } return Optional.empty(); } + + private static Optional getSqlText(SessionContext sessionContext, SecurityConfig securityConfig) + { + if (securityConfig.isEnableSqlQueryTextContextField()) { + return Optional.of(sessionContext.getSqlText()); + } + return Optional.empty(); + } + + private static Map getPreparedStatements(SessionContext sessionContext, SecurityConfig securityConfig) + { + if (securityConfig.isEnableSqlQueryTextContextField()) { + return sessionContext.getPreparedStatements(); + } + return ImmutableMap.of(); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java b/presto-main-base/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java index 84160517c6d97..2b7459ef6b1c6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java +++ b/presto-main-base/src/main/java/com/facebook/presto/security/AllowAllSystemAccessControl.java @@ -67,7 +67,7 @@ public SystemAccessControl create(Map config) } @Override - public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map viewDefinitions, Map materializedViewDefinitions) + 
public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) { } @@ -124,6 +124,11 @@ public Set filterSchemas(Identity identity, AccessControlContext context return schemaNames; } + @Override + public void checkCanShowCreateTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + } + @Override public void checkCanCreateTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) { @@ -154,6 +159,17 @@ public Set filterTables(Identity identity, AccessControlContext return tableNames; } + @Override + public void checkCanShowColumnsMetadata(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + } + + @Override + public List filterColumns(Identity identity, AccessControlContext context, CatalogSchemaTableName table, List columns) + { + return columns; + } + @Override public void checkCanAddColumn(Identity identity, AccessControlContext context, CatalogSchemaTableName table) { @@ -174,6 +190,11 @@ public void checkCanSelectFromColumns(Identity identity, AccessControlContext co { } + @Override + public void checkCanCallProcedure(Identity identity, AccessControlContext context, CatalogSchemaTableName procedure) + { + } + @Override public void checkCanInsertIntoTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) { @@ -229,6 +250,16 @@ public void checkCanRevokeTablePrivilege(Identity identity, AccessControlContext { } + @Override + public void checkCanDropBranch(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanDropTag(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + } + @Override public void checkCanDropConstraint(Identity identity, AccessControlContext context, CatalogSchemaTableName table) { diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/security/DenyQueryIntegrityCheckSystemAccessControl.java b/presto-main-base/src/main/java/com/facebook/presto/security/DenyQueryIntegrityCheckSystemAccessControl.java new file mode 100644 index 0000000000000..07412f651b902 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/security/DenyQueryIntegrityCheckSystemAccessControl.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.security; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.security.AccessControlContext; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.security.SystemAccessControl; +import com.facebook.presto.spi.security.SystemAccessControlFactory; + +import java.security.Principal; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.security.AccessDeniedException.denyQueryIntegrityCheck; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class DenyQueryIntegrityCheckSystemAccessControl + implements SystemAccessControl +{ + public static final String NAME = "deny-query-integrity-check"; + + 
private static final DenyQueryIntegrityCheckSystemAccessControl INSTANCE = new DenyQueryIntegrityCheckSystemAccessControl(); + + public static class Factory + implements SystemAccessControlFactory + { + @Override + public String getName() + { + return NAME; + } + + @Override + public SystemAccessControl create(Map config) + { + requireNonNull(config, "config is null"); + checkArgument(config.isEmpty(), "This access controller does not support any configuration properties"); + return INSTANCE; + } + } + + @Override + public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) + { + denyQueryIntegrityCheck(); + } + + @Override + public void checkCanSetUser(Identity identity, AccessControlContext context, Optional principal, String userName) + { + } + + @Override + public void checkCanSetSystemSessionProperty(Identity identity, AccessControlContext context, String propertyName) + { + } + + @Override + public void checkCanAccessCatalog(Identity identity, AccessControlContext context, String catalogName) + { + } + + @Override + public Set filterCatalogs(Identity identity, AccessControlContext context, Set catalogs) + { + return catalogs; + } + + @Override + public Set filterSchemas(Identity identity, AccessControlContext context, String catalogName, Set schemaNames) + { + return schemaNames; + } + + @Override + public void checkCanCreateTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanCreateView(Identity identity, AccessControlContext context, CatalogSchemaTableName view) + { + } + + @Override + public void checkCanShowCreateTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + } + + @Override + public void checkCanSelectFromColumns(Identity identity, AccessControlContext context, CatalogSchemaTableName table, Set columns) + { + } + + @Override 
+ public void checkCanCallProcedure(Identity identity, AccessControlContext context, CatalogSchemaTableName procedure) + { + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java b/presto-main-base/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java index d0ae4c0451f93..d248e838ff8ad 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java +++ b/presto-main-base/src/main/java/com/facebook/presto/security/FileBasedSystemAccessControl.java @@ -14,6 +14,7 @@ package com.facebook.presto.security; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.plugin.base.security.ForwardingSystemAccessControl; @@ -36,7 +37,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import java.nio.file.Paths; import java.security.Principal; @@ -55,16 +55,19 @@ import static com.facebook.presto.spi.StandardErrorCode.CONFIGURATION_INVALID; import static com.facebook.presto.spi.security.AccessDeniedException.denyAddColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyAddConstraint; +import static com.facebook.presto.spi.security.AccessDeniedException.denyCallProcedure; import static com.facebook.presto.spi.security.AccessDeniedException.denyCatalogAccess; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateViewWithSelect; import static 
com.facebook.presto.spi.security.AccessDeniedException.denyDeleteTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropBranch; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropConstraint; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTable; +import static com.facebook.presto.spi.security.AccessDeniedException.denyDropTag; import static com.facebook.presto.spi.security.AccessDeniedException.denyDropView; import static com.facebook.presto.spi.security.AccessDeniedException.denyGrantTablePrivilege; import static com.facebook.presto.spi.security.AccessDeniedException.denyInsertTable; @@ -75,6 +78,8 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyRevokeTablePrivilege; import static com.facebook.presto.spi.security.AccessDeniedException.denySetTableProperties; import static com.facebook.presto.spi.security.AccessDeniedException.denySetUser; +import static com.facebook.presto.spi.security.AccessDeniedException.denyShowColumnsMetadata; +import static com.facebook.presto.spi.security.AccessDeniedException.denyShowCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyTruncateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyUpdateTableColumns; import static com.google.common.base.Preconditions.checkState; @@ -201,7 +206,7 @@ public AuthorizedIdentity selectAuthorizedIdentity(Identity identity, AccessCont } @Override - public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map viewDefinitions, Map materializedViewDefinitions) + public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map 
materializedViewDefinitions) { } @@ -280,6 +285,14 @@ public Set filterSchemas(Identity identity, AccessControlContext context return schemaNames; } + @Override + public void checkCanShowCreateTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + if (!canAccessCatalog(identity, table.getCatalogName(), READ_ONLY)) { + denyShowCreateTable(table.toString()); + } + } + @Override public void checkCanCreateTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) { @@ -334,6 +347,24 @@ public Set filterTables(Identity identity, AccessControlContext return tableNames; } + @Override + public void checkCanShowColumnsMetadata(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + if (!canAccessCatalog(identity, table.getCatalogName(), READ_ONLY)) { + denyShowColumnsMetadata(table.toString()); + } + } + + @Override + public List filterColumns(Identity identity, AccessControlContext context, CatalogSchemaTableName table, List columns) + { + if (!canAccessCatalog(identity, table.getCatalogName(), READ_ONLY)) { + return ImmutableList.of(); + } + + return columns; + } + @Override public void checkCanAddColumn(Identity identity, AccessControlContext context, CatalogSchemaTableName table) { @@ -363,6 +394,14 @@ public void checkCanSelectFromColumns(Identity identity, AccessControlContext co { } + @Override + public void checkCanCallProcedure(Identity identity, AccessControlContext context, CatalogSchemaTableName procedure) + { + if (!canAccessCatalog(identity, procedure.getCatalogName(), ALL)) { + denyCallProcedure(procedure.toString()); + } + } + @Override public void checkCanInsertIntoTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) { @@ -440,6 +479,22 @@ public void checkCanRevokeTablePrivilege(Identity identity, AccessControlContext } } + @Override + public void checkCanDropBranch(Identity identity, AccessControlContext context, 
CatalogSchemaTableName table) + { + if (!canAccessCatalog(identity, table.getCatalogName(), ALL)) { + denyDropBranch(table.toString()); + } + } + + @Override + public void checkCanDropTag(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + if (!canAccessCatalog(identity, table.getCatalogName(), ALL)) { + denyDropTag(table.toString()); + } + } + @Override public void checkCanDropConstraint(Identity identity, AccessControlContext context, CatalogSchemaTableName table) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/ReadOnlySystemAccessControl.java b/presto-main-base/src/main/java/com/facebook/presto/security/ReadOnlySystemAccessControl.java index 477fb4a581a57..32452d9b39d70 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/security/ReadOnlySystemAccessControl.java +++ b/presto-main-base/src/main/java/com/facebook/presto/security/ReadOnlySystemAccessControl.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.analyzer.ViewDefinition; @@ -25,6 +26,7 @@ import com.facebook.presto.spi.security.SystemAccessControlFactory; import java.security.Principal; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -63,7 +65,7 @@ public void checkCanSetUser(Identity identity, AccessControlContext context, Opt } @Override - public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map viewDefinitions, Map materializedViewDefinitions) + public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) { } @@ 
-82,6 +84,11 @@ public void checkCanSelectFromColumns(Identity identity, AccessControlContext co { } + @Override + public void checkCanCallProcedure(Identity identity, AccessControlContext context, CatalogSchemaTableName procedure) + { + } + @Override public void checkCanSetCatalogSessionProperty(Identity identity, AccessControlContext context, String catalogName, String propertyName) { @@ -104,12 +111,28 @@ public Set filterSchemas(Identity identity, AccessControlContext context return schemaNames; } + @Override + public void checkCanShowCreateTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + } + @Override public Set filterTables(Identity identity, AccessControlContext context, String catalogName, Set tableNames) { return tableNames; } + @Override + public void checkCanShowColumnsMetadata(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + } + + @Override + public List filterColumns(Identity identity, AccessControlContext context, CatalogSchemaTableName table, List columns) + { + return columns; + } + @Override public void checkCanShowSchemas(Identity identity, AccessControlContext context, String catalogName) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/StatsRecordingSystemAccessControl.java b/presto-main-base/src/main/java/com/facebook/presto/security/StatsRecordingSystemAccessControl.java new file mode 100644 index 0000000000000..6c57e4204f063 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/security/StatsRecordingSystemAccessControl.java @@ -0,0 +1,1146 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.security; + +import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.RuntimeUnit; +import com.facebook.presto.spi.CatalogSchemaTableName; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.security.AccessControlContext; +import com.facebook.presto.spi.security.AuthorizedIdentity; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.security.PrestoPrincipal; +import com.facebook.presto.spi.security.Privilege; +import com.facebook.presto.spi.security.SystemAccessControl; +import com.facebook.presto.spi.security.ViewExpression; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; + +import java.security.Principal; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.Objects.requireNonNull; + +public final class StatsRecordingSystemAccessControl + implements SystemAccessControl +{ + private final Stats stats = new Stats(); + private final AtomicReference delegate = new AtomicReference<>(); + + public StatsRecordingSystemAccessControl(SystemAccessControl delegate) + { + updateDelegate(delegate); + } + + public void 
updateDelegate(SystemAccessControl delegate) + { + this.delegate.set(requireNonNull(delegate, "delegate is null")); + } + + public Stats getStats() + { + return stats; + } + + @Override + public void checkCanSetUser(Identity identity, AccessControlContext context, Optional principal, String userName) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanSetUser(identity, context, principal, userName); + } + catch (RuntimeException e) { + stats.checkCanSetUser.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanSetUser", RuntimeUnit.NANO, duration); + stats.checkCanSetUser.record(duration); + } + } + + @Override + public AuthorizedIdentity selectAuthorizedIdentity(Identity identity, AccessControlContext context, String userName, List certificates) + { + long start = System.nanoTime(); + try { + return delegate.get().selectAuthorizedIdentity(identity, context, userName, certificates); + } + catch (RuntimeException e) { + stats.selectAuthorizedIdentity.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.selectAuthorizedIdentity", RuntimeUnit.NANO, duration); + stats.selectAuthorizedIdentity.record(duration); + } + } + + @Override + public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) + { + long start = System.nanoTime(); + try { + delegate.get().checkQueryIntegrity(identity, context, query, preparedStatements, viewDefinitions, materializedViewDefinitions); + } + catch (RuntimeException e) { + stats.checkQueryIntegrity.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkQueryIntegrity", RuntimeUnit.NANO, duration); + 
stats.checkQueryIntegrity.record(duration); + } + } + + @Override + public void checkCanSetSystemSessionProperty(Identity identity, AccessControlContext context, String propertyName) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanSetSystemSessionProperty(identity, context, propertyName); + } + catch (RuntimeException e) { + stats.checkCanSetSystemSessionProperty.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanSetSystemSessionProperty", RuntimeUnit.NANO, duration); + stats.checkCanSetSystemSessionProperty.record(duration); + } + } + + @Override + public void checkCanAccessCatalog(Identity identity, AccessControlContext context, String catalogName) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanAccessCatalog(identity, context, catalogName); + } + catch (RuntimeException e) { + stats.checkCanAccessCatalog.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanAccessCatalog", RuntimeUnit.NANO, duration); + stats.checkCanAccessCatalog.record(duration); + } + } + + @Override + public Set filterCatalogs(Identity identity, AccessControlContext context, Set catalogs) + { + long start = System.nanoTime(); + try { + return delegate.get().filterCatalogs(identity, context, catalogs); + } + catch (RuntimeException e) { + stats.filterCatalogs.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.filterCatalogs", RuntimeUnit.NANO, duration); + stats.filterCatalogs.record(duration); + } + } + + @Override + public void checkCanCreateSchema(Identity identity, AccessControlContext context, CatalogSchemaName schema) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanCreateSchema(identity, context, schema); 
+ } + catch (RuntimeException e) { + stats.checkCanCreateSchema.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanCreateSchema", RuntimeUnit.NANO, duration); + stats.checkCanCreateSchema.record(duration); + } + } + + @Override + public void checkCanDropSchema(Identity identity, AccessControlContext context, CatalogSchemaName schema) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanDropSchema(identity, context, schema); + } + catch (RuntimeException e) { + stats.checkCanDropSchema.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanDropSchema", RuntimeUnit.NANO, duration); + stats.checkCanDropSchema.record(duration); + } + } + + @Override + public void checkCanRenameSchema(Identity identity, AccessControlContext context, CatalogSchemaName schema, String newSchemaName) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanRenameSchema(identity, context, schema, newSchemaName); + } + catch (RuntimeException e) { + stats.checkCanRenameSchema.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanRenameSchema", RuntimeUnit.NANO, duration); + stats.checkCanRenameSchema.record(duration); + } + } + + @Override + public void checkCanShowSchemas(Identity identity, AccessControlContext context, String catalogName) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanShowSchemas(identity, context, catalogName); + } + catch (RuntimeException e) { + stats.checkCanShowSchemas.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanShowSchemas", RuntimeUnit.NANO, duration); + 
stats.checkCanShowSchemas.record(duration); + } + } + + @Override + public Set filterSchemas(Identity identity, AccessControlContext context, String catalogName, Set schemaNames) + { + long start = System.nanoTime(); + try { + return delegate.get().filterSchemas(identity, context, catalogName, schemaNames); + } + catch (RuntimeException e) { + stats.filterSchemas.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.filterSchemas", RuntimeUnit.NANO, duration); + stats.filterSchemas.record(duration); + } + } + + @Override + public void checkCanShowCreateTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanShowCreateTable(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanShowCreateTable.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanShowCreateTable", RuntimeUnit.NANO, duration); + stats.checkCanShowCreateTable.record(duration); + } + } + + @Override + public void checkCanCreateTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanCreateTable(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanCreateTable.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanCreateTable", RuntimeUnit.NANO, duration); + stats.checkCanCreateTable.record(duration); + } + } + + @Override + public void checkCanSetTableProperties(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanSetTableProperties(identity, context, table); 
+ } + catch (RuntimeException e) { + stats.checkCanSetTableProperties.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanSetTableProperties", RuntimeUnit.NANO, duration); + stats.checkCanSetTableProperties.record(duration); + } + } + + @Override + public void checkCanDropTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanDropTable(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanDropTable.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanDropTable", RuntimeUnit.NANO, duration); + stats.checkCanDropTable.record(duration); + } + } + + @Override + public void checkCanRenameTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table, CatalogSchemaTableName newTable) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanRenameTable(identity, context, table, newTable); + } + catch (RuntimeException e) { + stats.checkCanRenameTable.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanRenameTable", RuntimeUnit.NANO, duration); + stats.checkCanRenameTable.record(duration); + } + } + + @Override + public void checkCanShowTablesMetadata(Identity identity, AccessControlContext context, CatalogSchemaName schema) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanShowTablesMetadata(identity, context, schema); + } + catch (RuntimeException e) { + stats.checkCanShowTablesMetadata.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanShowTablesMetadata", 
RuntimeUnit.NANO, duration); + stats.checkCanShowTablesMetadata.record(duration); + } + } + + @Override + public Set filterTables(Identity identity, AccessControlContext context, String catalogName, Set tableNames) + { + long start = System.nanoTime(); + try { + return delegate.get().filterTables(identity, context, catalogName, tableNames); + } + catch (RuntimeException e) { + stats.filterTables.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.filterTables", RuntimeUnit.NANO, duration); + stats.filterTables.record(duration); + } + } + + @Override + public void checkCanShowColumnsMetadata(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanShowColumnsMetadata(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanShowColumnsMetadata.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanShowColumnsMetadata", RuntimeUnit.NANO, duration); + stats.checkCanShowColumnsMetadata.record(duration); + } + } + + @Override + public List filterColumns(Identity identity, AccessControlContext context, CatalogSchemaTableName table, List columns) + { + long start = System.nanoTime(); + try { + return delegate.get().filterColumns(identity, context, table, columns); + } + catch (RuntimeException e) { + stats.filterColumns.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.filterColumns", RuntimeUnit.NANO, duration); + stats.filterColumns.record(duration); + } + } + + @Override + public void checkCanAddColumn(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + 
delegate.get().checkCanAddColumn(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanAddColumn.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanAddColumn", RuntimeUnit.NANO, duration); + stats.checkCanAddColumn.record(duration); + } + } + + @Override + public void checkCanDropColumn(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanDropColumn(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanDropColumn.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanDropColumn", RuntimeUnit.NANO, duration); + stats.checkCanDropColumn.record(duration); + } + } + + @Override + public void checkCanRenameColumn(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanRenameColumn(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanRenameColumn.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanRenameColumn", RuntimeUnit.NANO, duration); + stats.checkCanRenameColumn.record(duration); + } + } + + @Override + public void checkCanSelectFromColumns(Identity identity, AccessControlContext context, CatalogSchemaTableName table, Set columns) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanSelectFromColumns(identity, context, table, columns); + } + catch (RuntimeException e) { + stats.checkCanSelectFromColumns.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + 
context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanSelectFromColumns", RuntimeUnit.NANO, duration); + stats.checkCanSelectFromColumns.record(duration); + } + } + + @Override + public void checkCanCallProcedure(Identity identity, AccessControlContext context, CatalogSchemaTableName procedure) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanCallProcedure(identity, context, procedure); + } + catch (RuntimeException e) { + stats.checkCanCallProcedure.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanCallProcedure", RuntimeUnit.NANO, duration); + stats.checkCanCallProcedure.record(duration); + } + } + + @Override + public void checkCanInsertIntoTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanInsertIntoTable(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanInsertIntoTable.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanInsertIntoTable", RuntimeUnit.NANO, duration); + stats.checkCanInsertIntoTable.record(duration); + } + } + + @Override + public void checkCanDeleteFromTable(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanDeleteFromTable(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanDeleteFromTable.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanDeleteFromTable", RuntimeUnit.NANO, duration); + stats.checkCanDeleteFromTable.record(duration); + } + } + + @Override + public void checkCanTruncateTable(Identity identity, 
AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanTruncateTable(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanTruncateTable.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanTruncateTable", RuntimeUnit.NANO, duration); + stats.checkCanTruncateTable.record(duration); + } + } + + @Override + public void checkCanUpdateTableColumns(Identity identity, AccessControlContext context, CatalogSchemaTableName table, Set updatedColumnNames) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanUpdateTableColumns(identity, context, table, updatedColumnNames); + } + catch (RuntimeException e) { + stats.checkCanUpdateTableColumns.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanUpdateTableColumns", RuntimeUnit.NANO, duration); + stats.checkCanUpdateTableColumns.record(duration); + } + } + + @Override + public void checkCanCreateView(Identity identity, AccessControlContext context, CatalogSchemaTableName view) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanCreateView(identity, context, view); + } + catch (RuntimeException e) { + stats.checkCanCreateView.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanCreateView", RuntimeUnit.NANO, duration); + stats.checkCanCreateView.record(duration); + } + } + + @Override + public void checkCanRenameView(Identity identity, AccessControlContext context, CatalogSchemaTableName view, CatalogSchemaTableName newView) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanRenameView(identity, context, view, newView); + } + catch (RuntimeException e) { + 
stats.checkCanRenameView.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanRenameView", RuntimeUnit.NANO, duration); + stats.checkCanRenameView.record(duration); + } + } + + @Override + public void checkCanDropView(Identity identity, AccessControlContext context, CatalogSchemaTableName view) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanDropView(identity, context, view); + } + catch (RuntimeException e) { + stats.checkCanDropView.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanDropView", RuntimeUnit.NANO, duration); + stats.checkCanDropView.record(duration); + } + } + + @Override + public void checkCanCreateViewWithSelectFromColumns(Identity identity, AccessControlContext context, CatalogSchemaTableName table, Set columns) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanCreateViewWithSelectFromColumns(identity, context, table, columns); + } + catch (RuntimeException e) { + stats.checkCanCreateViewWithSelectFromColumns.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanCreateViewWithSelectFromColumns", RuntimeUnit.NANO, duration); + stats.checkCanCreateViewWithSelectFromColumns.record(duration); + } + } + + @Override + public void checkCanSetCatalogSessionProperty(Identity identity, AccessControlContext context, String catalogName, String propertyName) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanSetCatalogSessionProperty(identity, context, catalogName, propertyName); + } + catch (RuntimeException e) { + stats.checkCanSetCatalogSessionProperty.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + 
context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanSetCatalogSessionProperty", RuntimeUnit.NANO, duration); + stats.checkCanSetCatalogSessionProperty.record(duration); + } + } + + @Override + public void checkCanGrantTablePrivilege(Identity identity, AccessControlContext context, Privilege privilege, CatalogSchemaTableName table, PrestoPrincipal grantee, boolean withGrantOption) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanGrantTablePrivilege(identity, context, privilege, table, grantee, withGrantOption); + } + catch (RuntimeException e) { + stats.checkCanGrantTablePrivilege.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanGrantTablePrivilege", RuntimeUnit.NANO, duration); + stats.checkCanGrantTablePrivilege.record(duration); + } + } + + @Override + public void checkCanRevokeTablePrivilege(Identity identity, AccessControlContext context, Privilege privilege, CatalogSchemaTableName table, PrestoPrincipal revokee, boolean grantOptionFor) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanRevokeTablePrivilege(identity, context, privilege, table, revokee, grantOptionFor); + } + catch (RuntimeException e) { + stats.checkCanRevokeTablePrivilege.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanRevokeTablePrivilege", RuntimeUnit.NANO, duration); + stats.checkCanRevokeTablePrivilege.record(duration); + } + } + + @Override + public void checkCanDropBranch(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanDropBranch(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanDropBranch.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + 
context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanDropBranch", RuntimeUnit.NANO, duration); + stats.checkCanDropBranch.record(duration); + } + } + + @Override + public void checkCanDropTag(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanDropTag(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanDropTag.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanDropTag", RuntimeUnit.NANO, duration); + stats.checkCanDropTag.record(duration); + } + } + + @Override + public void checkCanDropConstraint(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanDropConstraint(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanDropConstraint.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanDropConstraint", RuntimeUnit.NANO, duration); + stats.checkCanDropConstraint.record(duration); + } + } + + @Override + public void checkCanAddConstraint(Identity identity, AccessControlContext context, CatalogSchemaTableName table) + { + long start = System.nanoTime(); + try { + delegate.get().checkCanAddConstraint(identity, context, table); + } + catch (RuntimeException e) { + stats.checkCanAddConstraint.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.checkCanAddConstraint", RuntimeUnit.NANO, duration); + stats.checkCanAddConstraint.record(duration); + } + } + + @Override + public List getRowFilters(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName) + { + long start = 
System.nanoTime(); + try { + return delegate.get().getRowFilters(identity, context, tableName); + } + catch (RuntimeException e) { + stats.getRowFilters.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.getRowFilters", RuntimeUnit.NANO, duration); + stats.getRowFilters.record(duration); + } + } + + @Override + public Map getColumnMasks(Identity identity, AccessControlContext context, CatalogSchemaTableName tableName, List columns) + { + long start = System.nanoTime(); + try { + return delegate.get().getColumnMasks(identity, context, tableName, columns); + } + catch (RuntimeException e) { + stats.getColumnMasks.recordFailure(); + throw e; + } + finally { + long duration = System.nanoTime() - start; + context.getRuntimeStats().addMetricValue("systemAccessControl.getColumnMasks", RuntimeUnit.NANO, duration); + stats.getColumnMasks.record(duration); + } + } + + public static class Stats + { + final SystemAccessControlStats checkCanSetUser = new SystemAccessControlStats(); + final SystemAccessControlStats selectAuthorizedIdentity = new SystemAccessControlStats(); + final SystemAccessControlStats checkQueryIntegrity = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanSetSystemSessionProperty = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanAccessCatalog = new SystemAccessControlStats(); + final SystemAccessControlStats filterCatalogs = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanCreateSchema = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanDropSchema = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanRenameSchema = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanShowSchemas = new SystemAccessControlStats(); + final SystemAccessControlStats filterSchemas = new SystemAccessControlStats(); + final SystemAccessControlStats 
checkCanShowCreateTable = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanCreateTable = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanSetTableProperties = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanDropTable = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanRenameTable = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanShowTablesMetadata = new SystemAccessControlStats(); + final SystemAccessControlStats filterTables = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanShowColumnsMetadata = new SystemAccessControlStats(); + final SystemAccessControlStats filterColumns = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanAddColumn = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanDropColumn = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanRenameColumn = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanSelectFromColumns = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanCallProcedure = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanInsertIntoTable = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanDeleteFromTable = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanTruncateTable = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanUpdateTableColumns = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanCreateView = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanRenameView = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanDropView = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanCreateViewWithSelectFromColumns = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanSetCatalogSessionProperty = 
new SystemAccessControlStats(); + final SystemAccessControlStats checkCanGrantTablePrivilege = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanRevokeTablePrivilege = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanDropBranch = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanDropTag = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanDropConstraint = new SystemAccessControlStats(); + final SystemAccessControlStats checkCanAddConstraint = new SystemAccessControlStats(); + final SystemAccessControlStats getRowFilters = new SystemAccessControlStats(); + final SystemAccessControlStats getColumnMasks = new SystemAccessControlStats(); + + @Managed + @Nested + public SystemAccessControlStats getCheckCanSetUser() + { + return checkCanSetUser; + } + + @Managed + @Nested + public SystemAccessControlStats getSelectAuthorizedIdentity() + { + return selectAuthorizedIdentity; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckQueryIntegrity() + { + return checkQueryIntegrity; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanSetSystemSessionProperty() + { + return checkCanSetSystemSessionProperty; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanAccessCatalog() + { + return checkCanAccessCatalog; + } + + @Managed + @Nested + public SystemAccessControlStats getFilterCatalogs() + { + return filterCatalogs; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanCreateSchema() + { + return checkCanCreateSchema; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanDropSchema() + { + return checkCanDropSchema; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanRenameSchema() + { + return checkCanRenameSchema; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanShowSchemas() + { + return checkCanShowSchemas; + } + + @Managed + @Nested + public 
SystemAccessControlStats getFilterSchemas() + { + return filterSchemas; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanShowCreateTable() + { + return checkCanShowCreateTable; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanCreateTable() + { + return checkCanCreateTable; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanSetTableProperties() + { + return checkCanSetTableProperties; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanDropTable() + { + return checkCanDropTable; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanRenameTable() + { + return checkCanRenameTable; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanShowTablesMetadata() + { + return checkCanShowTablesMetadata; + } + + @Managed + @Nested + public SystemAccessControlStats getFilterTables() + { + return filterTables; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanShowColumnsMetadata() + { + return checkCanShowColumnsMetadata; + } + + @Managed + @Nested + public SystemAccessControlStats getFilterColumns() + { + return filterColumns; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanAddColumn() + { + return checkCanAddColumn; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanDropColumn() + { + return checkCanDropColumn; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanRenameColumn() + { + return checkCanRenameColumn; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanSelectFromColumns() + { + return checkCanSelectFromColumns; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanCallProcedure() + { + return checkCanCallProcedure; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanInsertIntoTable() + { + return checkCanInsertIntoTable; + } + + @Managed + @Nested + public SystemAccessControlStats 
getCheckCanDeleteFromTable() + { + return checkCanDeleteFromTable; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanTruncateTable() + { + return checkCanTruncateTable; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanUpdateTableColumns() + { + return checkCanUpdateTableColumns; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanCreateView() + { + return checkCanCreateView; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanRenameView() + { + return checkCanRenameView; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanDropView() + { + return checkCanDropView; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanCreateViewWithSelectFromColumns() + { + return checkCanCreateViewWithSelectFromColumns; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanSetCatalogSessionProperty() + { + return checkCanSetCatalogSessionProperty; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanGrantTablePrivilege() + { + return checkCanGrantTablePrivilege; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanRevokeTablePrivilege() + { + return checkCanRevokeTablePrivilege; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanDropConstraint() + { + return checkCanDropConstraint; + } + + @Managed + @Nested + public SystemAccessControlStats getCheckCanAddConstraint() + { + return checkCanAddConstraint; + } + + @Managed + @Nested + public SystemAccessControlStats getGetRowFilters() + { + return getRowFilters; + } + + @Managed + @Nested + public SystemAccessControlStats getGetColumnMasks() + { + return getColumnMasks; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/security/SystemAccessControlStats.java b/presto-main-base/src/main/java/com/facebook/presto/security/SystemAccessControlStats.java new file mode 100644 index 0000000000000..95babfbb9e80d --- 
/dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/security/SystemAccessControlStats.java @@ -0,0 +1,52 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.security; + +import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.stats.TimeStat; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; + +import static java.util.concurrent.TimeUnit.MICROSECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +public class SystemAccessControlStats +{ + private final CounterStat failures = new CounterStat(); + private final TimeStat time = new TimeStat(MICROSECONDS); + + public void record(long nanos) + { + time.add(nanos, NANOSECONDS); + } + + public void recordFailure() + { + failures.update(1); + } + + @Managed + @Nested + public TimeStat getTime() + { + return time; + } + + @Managed + @Nested + public CounterStat getFailures() + { + return failures; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/AsyncHttpExecutionMBean.java b/presto-main-base/src/main/java/com/facebook/presto/server/AsyncHttpExecutionMBean.java index fda3e1bbfd55f..e27d93edb97f6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/AsyncHttpExecutionMBean.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/AsyncHttpExecutionMBean.java @@ -14,11 +14,10 @@ package com.facebook.presto.server; import 
com.facebook.airlift.concurrent.ThreadPoolExecutorMBean; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/BasicQueryInfo.java b/presto-main-base/src/main/java/com/facebook/presto/server/BasicQueryInfo.java index 5615e99c8bbb6..e77f380aa880b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/BasicQueryInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/BasicQueryInfo.java @@ -31,9 +31,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; +import jakarta.annotation.Nullable; import java.net.URI; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/BasicQueryStats.java b/presto-main-base/src/main/java/com/facebook/presto/server/BasicQueryStats.java index 3da71c71af74d..f12a7851233c7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/BasicQueryStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/BasicQueryStats.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; @@ -21,18 +23,15 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import 
com.google.errorprone.annotations.Immutable; import org.joda.time.DateTime; -import javax.annotation.concurrent.Immutable; - import java.util.OptionalDouble; import java.util.Set; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.util.DateTimeUtils.toTimeStampInMillis; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.System.currentTimeMillis; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -62,6 +61,16 @@ public class BasicQueryStats private final int runningDrivers; private final int completedDrivers; + private final int totalNewDrivers; + private final int queuedNewDrivers; + private final int runningNewDrivers; + private final int completedNewDrivers; + + private final int totalSplits; + private final int queuedSplits; + private final int runningSplits; + private final int completedSplits; + private final DataSize rawInputDataSize; private final long rawInputPositions; @@ -97,6 +106,14 @@ public BasicQueryStats( int queuedDrivers, int runningDrivers, int completedDrivers, + int totalNewDrivers, + int queuedNewDrivers, + int runningNewDrivers, + int completedNewDrivers, + int totalSplits, + int queuedSplits, + int runningSplits, + int completedSplits, DataSize rawInputDataSize, long rawInputPositions, double cumulativeUserMemory, @@ -133,6 +150,22 @@ public BasicQueryStats( this.runningDrivers = runningDrivers; checkArgument(completedDrivers >= 0, "completedDrivers is negative"); this.completedDrivers = completedDrivers; + checkArgument(totalNewDrivers >= 0, "totalNewDrivers is negative"); + this.totalNewDrivers = totalNewDrivers; + checkArgument(queuedNewDrivers >= 0, "queuedNewDrivers is negative"); + this.queuedNewDrivers = queuedNewDrivers; + checkArgument(runningNewDrivers >= 0, "runningNewDrivers is negative"); + this.runningNewDrivers = runningNewDrivers; + 
checkArgument(completedNewDrivers >= 0, "completedNewDrivers is negative"); + this.completedNewDrivers = completedNewDrivers; + checkArgument(totalSplits >= 0, "totalSplits is negative"); + this.totalSplits = totalSplits; + checkArgument(queuedSplits >= 0, "queuedSplits is negative"); + this.queuedSplits = queuedSplits; + checkArgument(runningSplits >= 0, "runningSplits is negative"); + this.runningSplits = runningSplits; + checkArgument(completedSplits >= 0, "completedSplits is negative"); + this.completedSplits = completedSplits; this.rawInputDataSize = requireNonNull(rawInputDataSize); this.rawInputPositions = rawInputPositions; @@ -172,6 +205,14 @@ public BasicQueryStats( @JsonProperty("queuedDrivers") int queuedDrivers, @JsonProperty("runningDrivers") int runningDrivers, @JsonProperty("completedDrivers") int completedDrivers, + @JsonProperty("totalNewDrivers") int totalNewDrivers, + @JsonProperty("queuedNewDrivers") int queuedNewDrivers, + @JsonProperty("runningNewDrivers") int runningNewDrivers, + @JsonProperty("completedNewDrivers") int completedNewDrivers, + @JsonProperty("totalSplits") int totalSplits, + @JsonProperty("queuedSplits") int queuedSplits, + @JsonProperty("runningSplits") int runningSplits, + @JsonProperty("completedSplits") int completedSplits, @JsonProperty("rawInputDataSize") DataSize rawInputDataSize, @JsonProperty("rawInputPositions") long rawInputPositions, @JsonProperty("cumulativeUserMemory") double cumulativeUserMemory, @@ -202,6 +243,14 @@ public BasicQueryStats( queuedDrivers, runningDrivers, completedDrivers, + totalNewDrivers, + queuedNewDrivers, + runningNewDrivers, + completedNewDrivers, + totalSplits, + queuedSplits, + runningSplits, + completedSplits, rawInputDataSize, rawInputPositions, cumulativeUserMemory, @@ -235,6 +284,14 @@ public BasicQueryStats(QueryStats queryStats) queryStats.getQueuedDrivers(), queryStats.getRunningDrivers(), queryStats.getCompletedDrivers(), + queryStats.getTotalNewDrivers(), + 
queryStats.getQueuedNewDrivers(), + queryStats.getRunningNewDrivers(), + queryStats.getCompletedNewDrivers(), + queryStats.getTotalSplits(), + queryStats.getQueuedSplits(), + queryStats.getRunningSplits(), + queryStats.getCompletedSplits(), queryStats.getRawInputDataSize(), queryStats.getRawInputPositions(), queryStats.getCumulativeUserMemory(), @@ -270,6 +327,14 @@ public static BasicQueryStats immediateFailureQueryStats() 0, 0, 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, new DataSize(0, BYTE), 0, 0, @@ -499,4 +564,60 @@ public Duration getAnalysisTime() { return analysisTime; } + + @ThriftField(30) + @JsonProperty + public int getTotalSplits() + { + return totalSplits; + } + + @ThriftField(31) + @JsonProperty + public int getQueuedSplits() + { + return queuedSplits; + } + + @ThriftField(32) + @JsonProperty + public int getRunningSplits() + { + return runningSplits; + } + + @ThriftField(33) + @JsonProperty + public int getCompletedSplits() + { + return completedSplits; + } + + @ThriftField(34) + @JsonProperty + public int getTotalNewDrivers() + { + return totalNewDrivers; + } + + @ThriftField(35) + @JsonProperty + public int getQueuedNewDrivers() + { + return queuedNewDrivers; + } + + @ThriftField(36) + @JsonProperty + public int getRunningNewDrivers() + { + return runningNewDrivers; + } + + @ThriftField(37) + @JsonProperty + public int getCompletedNewDrivers() + { + return completedNewDrivers; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ConnectorMetadataUpdateHandleJsonSerde.java b/presto-main-base/src/main/java/com/facebook/presto/server/ConnectorMetadataUpdateHandleJsonSerde.java deleted file mode 100644 index 9afc3d23e1a07..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ConnectorMetadataUpdateHandleJsonSerde.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.server; - -import com.facebook.airlift.json.JsonCodec; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.ConnectorTypeSerde; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.CacheLoader; -import com.google.common.cache.LoadingCache; -import io.airlift.slice.Slices; - -import static com.facebook.airlift.json.JsonCodec.jsonCodec; - -/** - * Json based connector handle serde - */ -public class ConnectorMetadataUpdateHandleJsonSerde - implements ConnectorTypeSerde -{ - private LoadingCache jsonCodecCache = CacheBuilder.newBuilder() - .recordStats() - .maximumSize(10_000) - .build(CacheLoader.from(cacheKey -> jsonCodec(cacheKey))); - - @Override - public byte[] serialize(ConnectorMetadataUpdateHandle value) - { - JsonCodec jsonCodec = jsonCodecCache.getUnchecked(value.getClass()); - return jsonCodec.toBytes(value); - } - - @Override - public ConnectorMetadataUpdateHandle deserialize(Class connectorTypeClass, byte[] bytes) - { - JsonCodec jsonCodec = jsonCodec(connectorTypeClass); - return jsonCodec.readBytes(Slices.wrappedBuffer(bytes).getInput()); - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/DataSizeSerializer.java b/presto-main-base/src/main/java/com/facebook/presto/server/DataSizeSerializer.java index 906c3111c148b..d826e872938e1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/DataSizeSerializer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/DataSizeSerializer.java @@ 
-13,10 +13,10 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.DataSize; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.SerializerProvider; -import io.airlift.units.DataSize; import java.io.IOException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/DurationSerializer.java b/presto-main-base/src/main/java/com/facebook/presto/server/DurationSerializer.java index a5f77d4cdfbac..bc8118caec2f0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/DurationSerializer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/DurationSerializer.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.Duration; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.SerializerProvider; -import io.airlift.units.Duration; import java.io.IOException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ExchangeExecutionMBean.java b/presto-main-base/src/main/java/com/facebook/presto/server/ExchangeExecutionMBean.java index a6e13e247d9bd..f85dc7ac47f44 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ExchangeExecutionMBean.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/ExchangeExecutionMBean.java @@ -15,11 +15,10 @@ import com.facebook.airlift.concurrent.ThreadPoolExecutorMBean; import com.facebook.presto.operator.ForExchange; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ForAsyncRpc.java b/presto-main-base/src/main/java/com/facebook/presto/server/ForAsyncRpc.java index 
ccc896f05ed46..cd4c4840ec67a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ForAsyncRpc.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/ForAsyncRpc.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.server; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ForStatementResource.java b/presto-main-base/src/main/java/com/facebook/presto/server/ForStatementResource.java index eaac0cbd8f82a..cf39ad05b7132 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ForStatementResource.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/ForStatementResource.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.server; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ForWorkerInfo.java b/presto-main-base/src/main/java/com/facebook/presto/server/ForWorkerInfo.java index 1ce219599df7d..58d71f0067943 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ForWorkerInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/ForWorkerInfo.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.server; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/GracefulShutdownHandler.java b/presto-main-base/src/main/java/com/facebook/presto/server/GracefulShutdownHandler.java index e73b745d32510..187cb7d39f600 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/GracefulShutdownHandler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/GracefulShutdownHandler.java @@ 
-15,14 +15,13 @@ import com.facebook.airlift.bootstrap.LifeCycleManager; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.QueryManager; import com.facebook.presto.execution.TaskInfo; import com.facebook.presto.execution.TaskManager; -import io.airlift.units.Duration; - -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.util.List; import java.util.concurrent.CountDownLatch; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java b/presto-main-base/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java index d76ff77dd1a03..47812dc99f384 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/InternalCommunicationConfig.java @@ -16,15 +16,14 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.ConfigSecuritySensitive; +import com.facebook.airlift.units.DataSize; import com.facebook.drift.transport.netty.codec.Protocol; -import io.airlift.units.DataSize; - -import javax.validation.constraints.AssertTrue; -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.AssertTrue; +import jakarta.validation.constraints.NotNull; import java.util.Optional; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class InternalCommunicationConfig { @@ -50,6 +49,8 @@ public class InternalCommunicationConfig private CommunicationProtocol serverInfoCommunicationProtocol = CommunicationProtocol.HTTP; private boolean 
memoizeDeadNodesEnabled; private String sharedSecret; + private long nodeStatsRefreshIntervalMillis = 1_000; + private long nodeDiscoveryPollingIntervalMillis = 5_000; private boolean internalJwtEnabled; @@ -313,6 +314,32 @@ public InternalCommunicationConfig setSharedSecret(String sharedSecret) return this; } + public long getNodeStatsRefreshIntervalMillis() + { + return nodeStatsRefreshIntervalMillis; + } + + @Config("internal-communication.node-stats-refresh-interval-millis") + @ConfigDescription("Interval in milliseconds for refreshing node statistics") + public InternalCommunicationConfig setNodeStatsRefreshIntervalMillis(long nodeStatsRefreshIntervalMillis) + { + this.nodeStatsRefreshIntervalMillis = nodeStatsRefreshIntervalMillis; + return this; + } + + public long getNodeDiscoveryPollingIntervalMillis() + { + return nodeDiscoveryPollingIntervalMillis; + } + + @Config("internal-communication.node-discovery-polling-interval-millis") + @ConfigDescription("Interval in milliseconds for polling node discovery and refreshing node states") + public InternalCommunicationConfig setNodeDiscoveryPollingIntervalMillis(long nodeDiscoveryPollingIntervalMillis) + { + this.nodeDiscoveryPollingIntervalMillis = nodeDiscoveryPollingIntervalMillis; + return this; + } + public boolean isInternalJwtEnabled() { return internalJwtEnabled; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/NodeResourceStatus.java b/presto-main-base/src/main/java/com/facebook/presto/server/NodeResourceStatus.java index c53e79ce1d601..868b7df74137b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/NodeResourceStatus.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/NodeResourceStatus.java @@ -14,8 +14,7 @@ package com.facebook.presto.server; import com.facebook.presto.execution.ClusterSizeMonitor; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/server/NodeStatus.java b/presto-main-base/src/main/java/com/facebook/presto/server/NodeStatus.java index 14262e205f8c3..0845a70bacdea 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/NodeStatus.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/NodeStatus.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; @@ -20,7 +21,6 @@ import com.facebook.presto.memory.MemoryInfo; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.Duration; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java index 0693d90bb9cc3..d072762d14dbc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java @@ -63,10 +63,9 @@ import com.facebook.presto.ttl.nodettlfetchermanagers.NodeTtlFetcherManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.ThreadSafe; import io.airlift.resolver.ArtifactResolver; - -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.File; import java.util.List; @@ -307,6 +306,12 @@ public void installPlugin(Plugin plugin) log.info("Registering client request filter factory"); clientRequestFilterManager.registerClientRequestFilterFactory(clientRequestFilterFactory); } + + for (Class functionClass : plugin.getSqlInvokedFunctions()) { + log.info("Registering 
functions from %s", functionClass.getName()); + metadata.getFunctionAndTypeManager().registerPluginFunctions( + extractFunctions(functionClass, metadata.getFunctionAndTypeManager().getDefaultNamespace())); + } } public void installCoordinatorPlugin(CoordinatorPlugin plugin) diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManagerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManagerConfig.java index e58ac4a0e6634..a985a4c3365ba 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManagerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManagerConfig.java @@ -19,8 +19,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.airlift.resolver.ArtifactResolver; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.io.File; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManagerUtil.java b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManagerUtil.java index 9be5916ff4692..20cddf1264fb0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManagerUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManagerUtil.java @@ -60,7 +60,7 @@ public class PluginManagerUtil .add("com.fasterxml.jackson.annotation.") .add("com.fasterxml.jackson.module.afterburner.") .add("io.airlift.slice.") - .add("io.airlift.units.") + .add("com.facebook.airlift.units.") .add("org.openjdk.jol.") .add("com.facebook.presto.common") .add("com.facebook.drift.annotations.") diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/PrestoSystemRequirements.java b/presto-main-base/src/main/java/com/facebook/presto/server/PrestoSystemRequirements.java index daae41fe39933..b2870934a0bc0 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/server/PrestoSystemRequirements.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/PrestoSystemRequirements.java @@ -98,15 +98,11 @@ private static void verifyJavaVersion() } JavaVersion version = JavaVersion.parse(javaVersion); - if (version.getMajor() == 8 && version.getUpdate().isPresent() && version.getUpdate().getAsInt() >= 151) { + if (version.getMajor() >= 17) { return; } - if (version.getMajor() >= 9) { - return; - } - - failRequirement("Presto requires Java 8u151+ (found %s)", javaVersion); + failRequirement("Presto requires Java 17 (found %s)", javaVersion); } private static void verifyUsingG1Gc() diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/QueryProgressStats.java b/presto-main-base/src/main/java/com/facebook/presto/server/QueryProgressStats.java index c897888febc70..f833073fcf954 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/QueryProgressStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/QueryProgressStats.java @@ -51,6 +51,14 @@ public class QueryProgressStats private final int runningDrivers; private final int completedDrivers; + private final int queuedNewDrivers; + private final int runningNewDrivers; + private final int completedNewDrivers; + + private final int queuedSplits; + private final int runningSplits; + private final int completedSplits; + @JsonCreator @ThriftConstructor public QueryProgressStats( @@ -72,7 +80,14 @@ public QueryProgressStats( @JsonProperty("progressPercentage") OptionalDouble progressPercentage, @JsonProperty("queuedDrivers") int queuedDrivers, @JsonProperty("runningDrivers") int runningDrivers, - @JsonProperty("completedDrivers") int completedDrivers) + @JsonProperty("completedDrivers") int completedDrivers, + @JsonProperty("queuedNewDrivers") int queuedNewDrivers, + @JsonProperty("runningNewDrivers") int runningNewDrivers, + @JsonProperty("completedNewDrivers") int 
completedNewDrivers, + @JsonProperty("queuedSplits") int queuedSplits, + @JsonProperty("runningSplits") int runningSplits, + @JsonProperty("completedSplits") int completedSplits) + { this.elapsedTimeMillis = elapsedTimeMillis; this.queuedTimeMillis = queuedTimeMillis; @@ -93,6 +108,12 @@ public QueryProgressStats( this.queuedDrivers = queuedDrivers; this.runningDrivers = runningDrivers; this.completedDrivers = completedDrivers; + this.queuedNewDrivers = queuedNewDrivers; + this.runningNewDrivers = runningNewDrivers; + this.completedNewDrivers = completedNewDrivers; + this.queuedSplits = queuedSplits; + this.runningSplits = runningSplits; + this.completedSplits = completedSplits; } public static QueryProgressStats createQueryProgressStats(BasicQueryStats queryStats) @@ -116,7 +137,13 @@ public static QueryProgressStats createQueryProgressStats(BasicQueryStats queryS queryStats.getProgressPercentage(), queryStats.getQueuedDrivers(), queryStats.getRunningDrivers(), - queryStats.getCompletedDrivers()); + queryStats.getCompletedDrivers(), + queryStats.getQueuedNewDrivers(), + queryStats.getRunningNewDrivers(), + queryStats.getCompletedNewDrivers(), + queryStats.getQueuedSplits(), + queryStats.getRunningSplits(), + queryStats.getCompletedSplits()); } @ThriftField(1) @@ -251,4 +278,46 @@ public int getCompletedDrivers() { return completedDrivers; } + + @ThriftField(20) + @JsonProperty + public int getQueuedSplits() + { + return queuedSplits; + } + + @ThriftField(21) + @JsonProperty + public int getRunningSplits() + { + return runningSplits; + } + + @ThriftField(22) + @JsonProperty + public int getCompletedSplits() + { + return completedSplits; + } + + @ThriftField(23) + @JsonProperty + public int getQueuedNewDrivers() + { + return queuedNewDrivers; + } + + @ThriftField(24) + @JsonProperty + public int getRunningNewDrivers() + { + return runningNewDrivers; + } + + @ThriftField(25) + @JsonProperty + public int getCompletedNewDrivers() + { + return completedNewDrivers; + } } 
diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/QuerySessionSupplier.java b/presto-main-base/src/main/java/com/facebook/presto/server/QuerySessionSupplier.java index 9a9c23503c8e1..1da085a50cef2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/QuerySessionSupplier.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/QuerySessionSupplier.java @@ -29,9 +29,8 @@ import com.facebook.presto.spi.security.Identity; import com.facebook.presto.sql.SqlEnvironmentConfig; import com.facebook.presto.transaction.TransactionManager; - -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.inject.Inject; import java.util.Locale; import java.util.Map; @@ -144,6 +143,7 @@ private Identity authenticateIdentity(QueryId queryId, SessionContext context) context.getIdentity().getExtraCredentials(), context.getIdentity().getExtraAuthenticators(), Optional.of(identity.getUserName()), - identity.getReasonForSelect())).orElseGet(context::getIdentity); + identity.getReasonForSelect(), + context.getCertificates())).orElseGet(context::getIdentity); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ResourceGroupInfo.java b/presto-main-base/src/main/java/com/facebook/presto/server/ResourceGroupInfo.java index 76a68eb88bd91..ab5c3cb706b72 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ResourceGroupInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/ResourceGroupInfo.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.DataSize; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; @@ -21,9 +22,7 @@ import com.facebook.presto.spi.resourceGroups.SchedulingPolicy; import com.fasterxml.jackson.annotation.JsonCreator; import 
com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.DataSize; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/RetryConfig.java b/presto-main-base/src/main/java/com/facebook/presto/server/RetryConfig.java new file mode 100644 index 0000000000000..813a4b44ba60e --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/RetryConfig.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.presto.common.ErrorCode; +import com.facebook.presto.spi.StandardErrorCode; +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableSet; +import jakarta.validation.constraints.NotNull; + +import java.util.Set; + +import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR; +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +public class RetryConfig +{ + private boolean retryEnabled = true; + private Set allowedRetryDomains = ImmutableSet.of(); + private boolean requireHttps; + private Set crossClusterRetryErrorCodes = ImmutableSet.of( + REMOTE_TASK_ERROR.toErrorCode().getCode()); + + public boolean isRetryEnabled() + { + return retryEnabled; + } + + @Config("retry.enabled") + @ConfigDescription("Enable cross-cluster retry functionality") + public RetryConfig setRetryEnabled(boolean retryEnabled) + { + this.retryEnabled = retryEnabled; + return this; + } + + @NotNull + public Set getAllowedRetryDomains() + { + return allowedRetryDomains; + } + + @Config("retry.allowed-domains") + @ConfigDescription("Comma-separated list of allowed domains for retry URLs " + + "(supports wildcards like *.example.com)") + public RetryConfig setAllowedRetryDomains(String domains) + { + if (domains == null || domains.trim().isEmpty()) { + this.allowedRetryDomains = ImmutableSet.of(); + } + else { + this.allowedRetryDomains = Splitter.on(',') + .trimResults() + .omitEmptyStrings() + .splitToList(domains) + .stream() + .map(String::toLowerCase) + .collect(toImmutableSet()); + } + return this; + } + + public boolean isRequireHttps() + { + return requireHttps; + } + + @Config("retry.require-https") + @ConfigDescription("Require HTTPS for retry URLs") + public RetryConfig setRequireHttps(boolean requireHttps) + { + this.requireHttps = requireHttps; + 
return this; + } + + @NotNull + public Set getCrossClusterRetryErrorCodes() + { + return crossClusterRetryErrorCodes; + } + + @Config("retry.cross-cluster-error-codes") + @ConfigDescription("Comma-separated list of error codes that allow cross-cluster retry") + public RetryConfig setCrossClusterRetryErrorCodes(String errorCodes) + { + if (errorCodes == null || errorCodes.trim().isEmpty()) { + // Keep the default error codes + return this; + } + else { + this.crossClusterRetryErrorCodes = Splitter.on(',') + .trimResults() + .omitEmptyStrings() + .splitToList(errorCodes) + .stream() + .map(StandardErrorCode::valueOf) + .map(StandardErrorCode::toErrorCode) + .map(ErrorCode::getCode) + .collect(toImmutableSet()); + } + return this; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ServerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/server/ServerConfig.java index 30e9f0b506bce..7e7bcbbffd1d4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ServerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/ServerConfig.java @@ -14,10 +14,9 @@ package com.facebook.presto.server; import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.NodePoolType; -import io.airlift.units.Duration; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import static com.facebook.presto.spi.NodePoolType.DEFAULT; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -43,6 +42,7 @@ public class ServerConfig private Duration clusterStatsExpirationDuration = new Duration(0, MILLISECONDS); private boolean nestedDataSerializationEnabled = true; private Duration clusterResourceGroupStateInfoExpirationDuration = new Duration(0, MILLISECONDS); + private String clusterTag; public boolean isResourceManager() { @@ -241,4 +241,16 @@ public ServerConfig 
setClusterResourceGroupStateInfoExpirationDuration(Duration this.clusterResourceGroupStateInfoExpirationDuration = clusterResourceGroupStateInfoExpirationDuration; return this; } + + public String getClusterTag() + { + return clusterTag; + } + + @Config("cluster-tag") + public ServerConfig setClusterTag(String clusterTag) + { + this.clusterTag = clusterTag; + return this; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/SessionContext.java b/presto-main-base/src/main/java/com/facebook/presto/server/SessionContext.java index d40bb15482377..2bb02da49b350 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/SessionContext.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/SessionContext.java @@ -22,8 +22,7 @@ import com.facebook.presto.spi.session.ResourceEstimates; import com.facebook.presto.spi.tracing.Tracer; import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.security.cert.X509Certificate; import java.util.List; @@ -51,6 +50,8 @@ default List getCertificates() @Nullable String getSchema(); + String getSqlText(); + @Nullable String getSource(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/SessionPropertyDefaults.java b/presto-main-base/src/main/java/com/facebook/presto/server/SessionPropertyDefaults.java index aa6491b04e775..4725b173b031e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/SessionPropertyDefaults.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/SessionPropertyDefaults.java @@ -23,8 +23,7 @@ import com.facebook.presto.spi.session.SessionPropertyConfigurationManager.SystemSessionPropertyConfiguration; import com.facebook.presto.spi.session.SessionPropertyConfigurationManagerFactory; import com.google.common.annotations.VisibleForTesting; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.File; import 
java.io.IOException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/SimpleHttpResponseHandlerStats.java b/presto-main-base/src/main/java/com/facebook/presto/server/SimpleHttpResponseHandlerStats.java index 17f0ceeb0debc..77d39ca51ec56 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/SimpleHttpResponseHandlerStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/SimpleHttpResponseHandlerStats.java @@ -13,11 +13,10 @@ */ package com.facebook.presto.server; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.weakref.jmx.Managed; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; - import java.util.concurrent.atomic.AtomicLong; public class SimpleHttpResponseHandlerStats diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/SqlInvokedFunctionCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/SqlInvokedFunctionCodec.java index 73262b328fed0..0e3660d58da6b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/SqlInvokedFunctionCodec.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/SqlInvokedFunctionCodec.java @@ -19,8 +19,7 @@ import com.facebook.drift.protocol.TProtocolReader; import com.facebook.drift.protocol.TProtocolWriter; import com.facebook.presto.spi.function.SqlInvokedFunction; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/StatementHttpExecutionMBean.java b/presto-main-base/src/main/java/com/facebook/presto/server/StatementHttpExecutionMBean.java index acd7acb1022a2..3e58651b08c24 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/StatementHttpExecutionMBean.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/server/StatementHttpExecutionMBean.java @@ -14,11 +14,10 @@ package com.facebook.presto.server; import com.facebook.airlift.concurrent.ThreadPoolExecutorMBean; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/TaskResourceUtils.java b/presto-main-base/src/main/java/com/facebook/presto/server/TaskResourceUtils.java deleted file mode 100644 index ebf53381f606f..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/server/TaskResourceUtils.java +++ /dev/null @@ -1,526 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.server; - -import com.facebook.presto.connector.ConnectorTypeSerdeManager; -import com.facebook.presto.execution.TaskInfo; -import com.facebook.presto.metadata.HandleResolver; -import com.facebook.presto.metadata.MetadataUpdates; -import com.facebook.presto.operator.DriverStats; -import com.facebook.presto.operator.OperatorStats; -import com.facebook.presto.operator.PipelineStats; -import com.facebook.presto.operator.TaskStats; -import com.facebook.presto.server.thrift.Any; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.ConnectorTypeSerde; -import com.google.common.collect.ImmutableList; - -import javax.ws.rs.core.HttpHeaders; - -import java.util.List; - -import static com.facebook.presto.operator.OperatorInfoUnion.convertToOperatorInfo; -import static com.facebook.presto.operator.OperatorInfoUnion.convertToOperatorInfoUnion; -import static java.util.stream.Collectors.toList; - -public class TaskResourceUtils -{ - private TaskResourceUtils() - { - } - - public static boolean isThriftAcceptable(HttpHeaders httpHeaders) - { - return httpHeaders.getAcceptableMediaTypes().stream() - .anyMatch(mediaType -> mediaType.toString().contains("application/x-thrift")); - } - - public static TaskInfo convertToThriftTaskInfo( - TaskInfo taskInfo, - ConnectorTypeSerdeManager connectorTypeSerdeManager, - HandleResolver handleResolver) - { - return new TaskInfo( - taskInfo.getTaskId(), - taskInfo.getTaskStatus(), - taskInfo.getLastHeartbeatInMillis(), - taskInfo.getOutputBuffers(), - taskInfo.getNoMoreSplits(), - convertToThriftTaskStats(taskInfo.getStats()), - taskInfo.isNeedsPlan(), - convertToThriftMetadataUpdates(taskInfo.getMetadataUpdates(), connectorTypeSerdeManager, handleResolver), - taskInfo.getNodeId()); - } - - private static TaskStats convertToThriftTaskStats(TaskStats taskStats) - { - if (taskStats.getPipelines().isEmpty()) { - return taskStats; - } - - return new TaskStats( - 
taskStats.getCreateTimeInMillis(), - taskStats.getFirstStartTimeInMillis(), - taskStats.getLastStartTimeInMillis(), - taskStats.getLastEndTimeInMillis(), - taskStats.getEndTimeInMillis(), - taskStats.getElapsedTimeInNanos(), - taskStats.getQueuedTimeInNanos(), - taskStats.getTotalDrivers(), - taskStats.getQueuedDrivers(), - taskStats.getQueuedPartitionedDrivers(), - taskStats.getQueuedPartitionedSplitsWeight(), - taskStats.getRunningDrivers(), - taskStats.getRunningPartitionedDrivers(), - taskStats.getRunningPartitionedSplitsWeight(), - taskStats.getBlockedDrivers(), - taskStats.getCompletedDrivers(), - taskStats.getCumulativeUserMemory(), - taskStats.getCumulativeTotalMemory(), - taskStats.getUserMemoryReservationInBytes(), - taskStats.getRevocableMemoryReservationInBytes(), - taskStats.getSystemMemoryReservationInBytes(), - taskStats.getPeakUserMemoryInBytes(), - taskStats.getPeakTotalMemoryInBytes(), - taskStats.getPeakNodeTotalMemoryInBytes(), - taskStats.getTotalScheduledTimeInNanos(), - taskStats.getTotalCpuTimeInNanos(), - taskStats.getTotalBlockedTimeInNanos(), - taskStats.isFullyBlocked(), - taskStats.getBlockedReasons(), - taskStats.getTotalAllocationInBytes(), - taskStats.getRawInputDataSizeInBytes(), - taskStats.getRawInputPositions(), - taskStats.getProcessedInputDataSizeInBytes(), - taskStats.getProcessedInputPositions(), - taskStats.getOutputDataSizeInBytes(), - taskStats.getOutputPositions(), - taskStats.getPhysicalWrittenDataSizeInBytes(), - taskStats.getFullGcCount(), - taskStats.getFullGcTimeInMillis(), - convertToThriftPipeLineStatsList(taskStats.getPipelines()), - taskStats.getRuntimeStats()); - } - - private static List convertToThriftPipeLineStatsList(List pipelines) - { - return pipelines.stream() - .map(TaskResourceUtils::convertToThriftPipelineStats) - .collect(toList()); - } - - private static PipelineStats convertToThriftPipelineStats(PipelineStats pipelineStats) - { - if (pipelineStats.getDrivers().isEmpty() && 
pipelineStats.getOperatorSummaries().isEmpty()) { - return pipelineStats; - } - - return new PipelineStats( - pipelineStats.getPipelineId(), - pipelineStats.getFirstStartTimeInMillis(), - pipelineStats.getLastStartTimeInMillis(), - pipelineStats.getLastEndTimeInMillis(), - pipelineStats.isInputPipeline(), - pipelineStats.isOutputPipeline(), - pipelineStats.getTotalDrivers(), - pipelineStats.getQueuedDrivers(), - pipelineStats.getQueuedPartitionedDrivers(), - pipelineStats.getQueuedPartitionedSplitsWeight(), - pipelineStats.getRunningDrivers(), - pipelineStats.getRunningPartitionedDrivers(), - pipelineStats.getRunningPartitionedSplitsWeight(), - pipelineStats.getBlockedDrivers(), - pipelineStats.getCompletedDrivers(), - pipelineStats.getUserMemoryReservationInBytes(), - pipelineStats.getRevocableMemoryReservationInBytes(), - pipelineStats.getSystemMemoryReservationInBytes(), - pipelineStats.getQueuedTime(), - pipelineStats.getElapsedTime(), - pipelineStats.getTotalScheduledTimeInNanos(), - pipelineStats.getTotalCpuTimeInNanos(), - pipelineStats.getTotalBlockedTimeInNanos(), - pipelineStats.isFullyBlocked(), - pipelineStats.getBlockedReasons(), - pipelineStats.getTotalAllocationInBytes(), - pipelineStats.getRawInputDataSizeInBytes(), - pipelineStats.getRawInputPositions(), - pipelineStats.getProcessedInputDataSizeInBytes(), - pipelineStats.getProcessedInputPositions(), - pipelineStats.getOutputDataSizeInBytes(), - pipelineStats.getOutputPositions(), - pipelineStats.getPhysicalWrittenDataSizeInBytes(), - convertToThriftOperatorStatsList(pipelineStats.getOperatorSummaries()), - convertToThriftDriverStatsList(pipelineStats.getDrivers())); - } - - private static List convertToThriftDriverStatsList(List drivers) - { - return drivers.stream() - .map(d -> d.getOperatorStats().isEmpty() ? 
d : convertToThriftDriverStats(d)) - .collect(toList()); - } - - private static DriverStats convertToThriftDriverStats(DriverStats driverStats) - { - return new DriverStats( - driverStats.getLifespan(), - driverStats.getCreateTimeInMillis(), - driverStats.getStartTimeInMillis(), - driverStats.getEndTimeInMillis(), - driverStats.getQueuedTime(), - driverStats.getElapsedTime(), - driverStats.getUserMemoryReservationInBytes(), - driverStats.getRevocableMemoryReservationInBytes(), - driverStats.getSystemMemoryReservationInBytes(), - driverStats.getTotalScheduledTime(), - driverStats.getTotalCpuTime(), - driverStats.getTotalBlockedTime(), - driverStats.isFullyBlocked(), - driverStats.getBlockedReasons(), - driverStats.getTotalAllocationInBytes(), - driverStats.getRawInputDataSizeInBytes(), - driverStats.getRawInputPositions(), - driverStats.getRawInputReadTime(), - driverStats.getProcessedInputDataSizeInBytes(), - driverStats.getProcessedInputPositions(), - driverStats.getOutputDataSizeInBytes(), - driverStats.getOutputPositions(), - driverStats.getPhysicalWrittenDataSizeInBytes(), - convertToThriftOperatorStatsList(driverStats.getOperatorStats())); - } - - private static List convertToThriftOperatorStatsList(List operatorSummaries) - { - return operatorSummaries.stream() - .map(operatorStats -> operatorStats.getInfo() != null ? 
convertToThriftOperatorStats(operatorStats) : operatorStats) - .collect(toList()); - } - - private static OperatorStats convertToThriftOperatorStats(OperatorStats operatorStats) - { - return new OperatorStats( - operatorStats.getStageId(), - operatorStats.getStageExecutionId(), - operatorStats.getPipelineId(), - operatorStats.getOperatorId(), - operatorStats.getPlanNodeId(), - operatorStats.getOperatorType(), - operatorStats.getTotalDrivers(), - operatorStats.getIsBlockedCalls(), - operatorStats.getIsBlockedWall(), - operatorStats.getIsBlockedCpu(), - operatorStats.getIsBlockedAllocationInBytes(), - operatorStats.getAddInputCalls(), - operatorStats.getAddInputWall(), - operatorStats.getAddInputCpu(), - operatorStats.getAddInputAllocationInBytes(), - operatorStats.getRawInputDataSizeInBytes(), - operatorStats.getRawInputPositions(), - operatorStats.getInputDataSizeInBytes(), - operatorStats.getInputPositions(), - operatorStats.getSumSquaredInputPositions(), - operatorStats.getGetOutputCalls(), - operatorStats.getGetOutputWall(), - operatorStats.getGetOutputCpu(), - operatorStats.getGetOutputAllocationInBytes(), - operatorStats.getOutputDataSizeInBytes(), - operatorStats.getOutputPositions(), - operatorStats.getPhysicalWrittenDataSizeInBytes(), - operatorStats.getAdditionalCpu(), - operatorStats.getBlockedWall(), - operatorStats.getFinishCalls(), - operatorStats.getFinishWall(), - operatorStats.getFinishCpu(), - operatorStats.getFinishAllocationInBytes(), - operatorStats.getUserMemoryReservationInBytes(), - operatorStats.getRevocableMemoryReservationInBytes(), - operatorStats.getSystemMemoryReservationInBytes(), - operatorStats.getPeakUserMemoryReservationInBytes(), - operatorStats.getPeakSystemMemoryReservationInBytes(), - operatorStats.getPeakTotalMemoryReservationInBytes(), - operatorStats.getSpilledDataSizeInBytes(), - operatorStats.getBlockedReason(), - operatorStats.getRuntimeStats(), - operatorStats.getDynamicFilterStats(), - 
convertToOperatorInfoUnion(operatorStats.getInfo()), - operatorStats.getNullJoinBuildKeyCount(), - operatorStats.getJoinBuildKeyCount(), - operatorStats.getNullJoinProbeKeyCount(), - operatorStats.getJoinProbeKeyCount()); - } - - private static MetadataUpdates convertToThriftMetadataUpdates( - MetadataUpdates metadataUpdates, - ConnectorTypeSerdeManager connectorTypeSerdeManager, - HandleResolver handleResolver) - { - List metadataUpdateHandles = metadataUpdates.getMetadataUpdates(); - if (metadataUpdateHandles.isEmpty()) { - return new MetadataUpdates(metadataUpdates.getConnectorId(), ImmutableList.of(), true); - } - ConnectorTypeSerde connectorTypeSerde = - connectorTypeSerdeManager.getMetadataUpdateHandleSerde(metadataUpdates.getConnectorId()); - List anyMetadataHandles = convertToAny(metadataUpdateHandles, connectorTypeSerde, handleResolver); - return new MetadataUpdates(metadataUpdates.getConnectorId(), anyMetadataHandles, true); - } - - private static List convertToAny( - List connectorMetadataUpdateHandles, - ConnectorTypeSerde connectorTypeSerde, - HandleResolver handleResolver) - { - return connectorMetadataUpdateHandles.stream() - .map(e -> new Any(handleResolver.getId(e), connectorTypeSerde.serialize(e))) - .collect(toList()); - } - - public static TaskInfo convertFromThriftTaskInfo( - TaskInfo thriftTaskInfo, - ConnectorTypeSerdeManager connectorTypeSerdeManager, - HandleResolver handleResolver) - { - return new TaskInfo( - thriftTaskInfo.getTaskId(), - thriftTaskInfo.getTaskStatus(), - thriftTaskInfo.getLastHeartbeatInMillis(), - thriftTaskInfo.getOutputBuffers(), - thriftTaskInfo.getNoMoreSplits(), - convertFromThriftTaskStats(thriftTaskInfo.getStats()), - thriftTaskInfo.isNeedsPlan(), - convertFromThriftMetadataUpdates(thriftTaskInfo.getMetadataUpdates(), connectorTypeSerdeManager, handleResolver), - thriftTaskInfo.getNodeId()); - } - - private static TaskStats convertFromThriftTaskStats(TaskStats thriftTaskStats) - { - if 
(thriftTaskStats.getPipelines().isEmpty()) { - return thriftTaskStats; - } - - return new TaskStats( - thriftTaskStats.getCreateTimeInMillis(), - thriftTaskStats.getFirstStartTimeInMillis(), - thriftTaskStats.getLastStartTimeInMillis(), - thriftTaskStats.getLastEndTimeInMillis(), - thriftTaskStats.getEndTimeInMillis(), - thriftTaskStats.getElapsedTimeInNanos(), - thriftTaskStats.getQueuedTimeInNanos(), - thriftTaskStats.getTotalDrivers(), - thriftTaskStats.getQueuedDrivers(), - thriftTaskStats.getQueuedPartitionedDrivers(), - thriftTaskStats.getQueuedPartitionedSplitsWeight(), - thriftTaskStats.getRunningDrivers(), - thriftTaskStats.getRunningPartitionedDrivers(), - thriftTaskStats.getRunningPartitionedSplitsWeight(), - thriftTaskStats.getBlockedDrivers(), - thriftTaskStats.getCompletedDrivers(), - thriftTaskStats.getCumulativeUserMemory(), - thriftTaskStats.getCumulativeTotalMemory(), - thriftTaskStats.getUserMemoryReservationInBytes(), - thriftTaskStats.getRevocableMemoryReservationInBytes(), - thriftTaskStats.getSystemMemoryReservationInBytes(), - thriftTaskStats.getPeakUserMemoryInBytes(), - thriftTaskStats.getPeakTotalMemoryInBytes(), - thriftTaskStats.getPeakNodeTotalMemoryInBytes(), - thriftTaskStats.getTotalScheduledTimeInNanos(), - thriftTaskStats.getTotalCpuTimeInNanos(), - thriftTaskStats.getTotalBlockedTimeInNanos(), - thriftTaskStats.isFullyBlocked(), - thriftTaskStats.getBlockedReasons(), - thriftTaskStats.getTotalAllocationInBytes(), - thriftTaskStats.getRawInputDataSizeInBytes(), - thriftTaskStats.getRawInputPositions(), - thriftTaskStats.getProcessedInputDataSizeInBytes(), - thriftTaskStats.getProcessedInputPositions(), - thriftTaskStats.getOutputDataSizeInBytes(), - thriftTaskStats.getOutputPositions(), - thriftTaskStats.getPhysicalWrittenDataSizeInBytes(), - thriftTaskStats.getFullGcCount(), - thriftTaskStats.getFullGcTimeInMillis(), - convertFromThriftPipeLineStatsList(thriftTaskStats.getPipelines()), - thriftTaskStats.getRuntimeStats()); - } - 
- private static List convertFromThriftPipeLineStatsList(List pipelines) - { - return pipelines.stream() - .map(TaskResourceUtils::convertFromThriftPipelineStats) - .collect(toList()); - } - - private static PipelineStats convertFromThriftPipelineStats(PipelineStats thriftPipelineStats) - { - if (thriftPipelineStats.getDrivers().isEmpty() && thriftPipelineStats.getOperatorSummaries().isEmpty()) { - return thriftPipelineStats; - } - - return new PipelineStats( - thriftPipelineStats.getPipelineId(), - thriftPipelineStats.getFirstStartTimeInMillis(), - thriftPipelineStats.getLastStartTimeInMillis(), - thriftPipelineStats.getLastEndTimeInMillis(), - thriftPipelineStats.isInputPipeline(), - thriftPipelineStats.isOutputPipeline(), - thriftPipelineStats.getTotalDrivers(), - thriftPipelineStats.getQueuedDrivers(), - thriftPipelineStats.getQueuedPartitionedDrivers(), - thriftPipelineStats.getQueuedPartitionedSplitsWeight(), - thriftPipelineStats.getRunningDrivers(), - thriftPipelineStats.getRunningPartitionedDrivers(), - thriftPipelineStats.getRunningPartitionedSplitsWeight(), - thriftPipelineStats.getBlockedDrivers(), - thriftPipelineStats.getCompletedDrivers(), - thriftPipelineStats.getUserMemoryReservationInBytes(), - thriftPipelineStats.getRevocableMemoryReservationInBytes(), - thriftPipelineStats.getSystemMemoryReservationInBytes(), - thriftPipelineStats.getQueuedTime(), - thriftPipelineStats.getElapsedTime(), - thriftPipelineStats.getTotalScheduledTimeInNanos(), - thriftPipelineStats.getTotalCpuTimeInNanos(), - thriftPipelineStats.getTotalBlockedTimeInNanos(), - thriftPipelineStats.isFullyBlocked(), - thriftPipelineStats.getBlockedReasons(), - thriftPipelineStats.getTotalAllocationInBytes(), - thriftPipelineStats.getRawInputDataSizeInBytes(), - thriftPipelineStats.getRawInputPositions(), - thriftPipelineStats.getProcessedInputDataSizeInBytes(), - thriftPipelineStats.getProcessedInputPositions(), - thriftPipelineStats.getOutputDataSizeInBytes(), - 
thriftPipelineStats.getOutputPositions(), - thriftPipelineStats.getPhysicalWrittenDataSizeInBytes(), - convertFromThriftOperatorStatsList(thriftPipelineStats.getOperatorSummaries()), - convertFromThriftDriverStatsList(thriftPipelineStats.getDrivers())); - } - - private static List convertFromThriftDriverStatsList(List thriftDrivers) - { - return thriftDrivers.stream() - .map(driverStats -> driverStats.getOperatorStats().isEmpty() ? driverStats : convertFromThriftDriverStats(driverStats)) - .collect(toList()); - } - - private static DriverStats convertFromThriftDriverStats(DriverStats thriftDriverStats) - { - return new DriverStats( - thriftDriverStats.getLifespan(), - thriftDriverStats.getCreateTimeInMillis(), - thriftDriverStats.getStartTimeInMillis(), - thriftDriverStats.getEndTimeInMillis(), - thriftDriverStats.getQueuedTime(), - thriftDriverStats.getElapsedTime(), - thriftDriverStats.getUserMemoryReservationInBytes(), - thriftDriverStats.getRevocableMemoryReservationInBytes(), - thriftDriverStats.getSystemMemoryReservationInBytes(), - thriftDriverStats.getTotalScheduledTime(), - thriftDriverStats.getTotalCpuTime(), - thriftDriverStats.getTotalBlockedTime(), - thriftDriverStats.isFullyBlocked(), - thriftDriverStats.getBlockedReasons(), - thriftDriverStats.getTotalAllocationInBytes(), - thriftDriverStats.getRawInputDataSizeInBytes(), - thriftDriverStats.getRawInputPositions(), - thriftDriverStats.getRawInputReadTime(), - thriftDriverStats.getProcessedInputDataSizeInBytes(), - thriftDriverStats.getProcessedInputPositions(), - thriftDriverStats.getOutputDataSizeInBytes(), - thriftDriverStats.getOutputPositions(), - thriftDriverStats.getPhysicalWrittenDataSizeInBytes(), - convertFromThriftOperatorStatsList(thriftDriverStats.getOperatorStats())); - } - - private static List convertFromThriftOperatorStatsList(List thriftOperatorSummaries) - { - return thriftOperatorSummaries.stream() - .map(operatorStats -> operatorStats.getInfoUnion() != null ? 
convertFromThriftOperatorStats(operatorStats) : operatorStats) - .collect(toList()); - } - - private static OperatorStats convertFromThriftOperatorStats(OperatorStats thriftOperatorStats) - { - return new OperatorStats( - thriftOperatorStats.getStageId(), - thriftOperatorStats.getStageExecutionId(), - thriftOperatorStats.getPipelineId(), - thriftOperatorStats.getOperatorId(), - thriftOperatorStats.getPlanNodeId(), - thriftOperatorStats.getOperatorType(), - thriftOperatorStats.getTotalDrivers(), - thriftOperatorStats.getIsBlockedCalls(), - thriftOperatorStats.getIsBlockedWall(), - thriftOperatorStats.getIsBlockedCpu(), - thriftOperatorStats.getIsBlockedAllocationInBytes(), - thriftOperatorStats.getAddInputCalls(), - thriftOperatorStats.getAddInputWall(), - thriftOperatorStats.getAddInputCpu(), - thriftOperatorStats.getAddInputAllocationInBytes(), - thriftOperatorStats.getRawInputDataSizeInBytes(), - thriftOperatorStats.getRawInputPositions(), - thriftOperatorStats.getInputDataSizeInBytes(), - thriftOperatorStats.getInputPositions(), - thriftOperatorStats.getSumSquaredInputPositions(), - thriftOperatorStats.getGetOutputCalls(), - thriftOperatorStats.getGetOutputWall(), - thriftOperatorStats.getGetOutputCpu(), - thriftOperatorStats.getGetOutputAllocationInBytes(), - thriftOperatorStats.getOutputDataSizeInBytes(), - thriftOperatorStats.getOutputPositions(), - thriftOperatorStats.getPhysicalWrittenDataSizeInBytes(), - thriftOperatorStats.getAdditionalCpu(), - thriftOperatorStats.getBlockedWall(), - thriftOperatorStats.getFinishCalls(), - thriftOperatorStats.getFinishWall(), - thriftOperatorStats.getFinishCpu(), - thriftOperatorStats.getFinishAllocationInBytes(), - thriftOperatorStats.getUserMemoryReservationInBytes(), - thriftOperatorStats.getRevocableMemoryReservationInBytes(), - thriftOperatorStats.getSystemMemoryReservationInBytes(), - thriftOperatorStats.getPeakUserMemoryReservationInBytes(), - thriftOperatorStats.getPeakSystemMemoryReservationInBytes(), - 
thriftOperatorStats.getPeakTotalMemoryReservationInBytes(), - thriftOperatorStats.getSpilledDataSizeInBytes(), - thriftOperatorStats.getBlockedReason(), - convertToOperatorInfo(thriftOperatorStats.getInfoUnion()), - thriftOperatorStats.getRuntimeStats(), - thriftOperatorStats.getDynamicFilterStats(), - thriftOperatorStats.getNullJoinBuildKeyCount(), - thriftOperatorStats.getJoinBuildKeyCount(), - thriftOperatorStats.getNullJoinProbeKeyCount(), - thriftOperatorStats.getJoinProbeKeyCount()); - } - - private static MetadataUpdates convertFromThriftMetadataUpdates( - MetadataUpdates metadataUpdates, - ConnectorTypeSerdeManager connectorTypeSerdeManager, - HandleResolver handleResolver) - { - List metadataUpdateHandles = metadataUpdates.getMetadataUpdatesAny(); - if (metadataUpdateHandles.isEmpty()) { - return new MetadataUpdates(metadataUpdates.getConnectorId(), ImmutableList.of()); - } - ConnectorTypeSerde connectorTypeSerde = - connectorTypeSerdeManager.getMetadataUpdateHandleSerde(metadataUpdates.getConnectorId()); - List connectorMetadataUpdateHandles = convertToConnector(metadataUpdateHandles, connectorTypeSerde, handleResolver); - return new MetadataUpdates(metadataUpdates.getConnectorId(), connectorMetadataUpdateHandles); - } - - private static List convertToConnector( - List metadataUpdateHandles, - ConnectorTypeSerde connectorTypeSerde, - HandleResolver handleResolver) - { - return metadataUpdateHandles.stream() - .map(e -> connectorTypeSerde.deserialize(handleResolver.getMetadataUpdateHandleClass(e.getId()), e.getBytes())) - .collect(toList()); - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/WebUiResource.java b/presto-main-base/src/main/java/com/facebook/presto/server/WebUiResource.java deleted file mode 100644 index c45774edc31e7..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/server/WebUiResource.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); 
- * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.server; - -import javax.annotation.security.RolesAllowed; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.Path; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; - -import static com.facebook.presto.server.security.RoleType.ADMIN; -import static com.google.common.base.Strings.isNullOrEmpty; -import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; -import static javax.ws.rs.core.Response.Status.MOVED_PERMANENTLY; - -@Path("/") -@RolesAllowed(ADMIN) -public class WebUiResource -{ - @GET - public Response redirectIndexHtml( - @HeaderParam(X_FORWARDED_PROTO) String proto, - @Context UriInfo uriInfo) - { - if (isNullOrEmpty(proto)) { - proto = uriInfo.getRequestUri().getScheme(); - } - - return Response.status(MOVED_PERMANENTLY) - .location(uriInfo.getRequestUriBuilder().scheme(proto).path("/ui/").build()) - .build(); - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/QueryBlockingRateLimiter.java b/presto-main-base/src/main/java/com/facebook/presto/server/protocol/QueryBlockingRateLimiter.java index 7cfc2e0947e6f..130a2732b1349 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/QueryBlockingRateLimiter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/protocol/QueryBlockingRateLimiter.java @@ -15,6 +15,7 @@ import com.facebook.airlift.stats.CounterStat; import 
com.facebook.airlift.stats.TimeStat; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.QueryManagerConfig; import com.facebook.presto.spi.QueryId; import com.google.common.cache.CacheBuilder; @@ -24,12 +25,10 @@ import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.RateLimiter; import com.google.inject.Inject; -import io.airlift.units.Duration; +import jakarta.annotation.PreDestroy; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.PreDestroy; - import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadPoolExecutor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/RetryCircuitBreaker.java b/presto-main-base/src/main/java/com/facebook/presto/server/protocol/RetryCircuitBreaker.java index 8201518284837..6ae4480230fc7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/RetryCircuitBreaker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/protocol/RetryCircuitBreaker.java @@ -14,9 +14,9 @@ package com.facebook.presto.server.protocol; import com.facebook.airlift.stats.DecayCounter; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.QueryManagerConfig; import com.google.inject.Inject; -import io.airlift.units.Duration; import org.weakref.jmx.Managed; import static com.google.common.base.Preconditions.checkArgument; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/remotetask/Backoff.java b/presto-main-base/src/main/java/com/facebook/presto/server/remotetask/Backoff.java index e8041bc14a7ba..4947017880bd5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/remotetask/Backoff.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/remotetask/Backoff.java @@ -13,12 +13,11 @@ */ package com.facebook.presto.server.remotetask; 
+import com.facebook.airlift.units.Duration; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/security/PrestoAuthenticatorManager.java b/presto-main-base/src/main/java/com/facebook/presto/server/security/PrestoAuthenticatorManager.java index dcd60ae76c4d5..33f0644666c9d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/security/PrestoAuthenticatorManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/security/PrestoAuthenticatorManager.java @@ -18,8 +18,7 @@ import com.facebook.presto.spi.security.PrestoAuthenticatorFactory; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.File; import java.util.HashMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/security/SecurityConfig.java b/presto-main-base/src/main/java/com/facebook/presto/server/security/SecurityConfig.java index 9d2e74e939a0a..49a5de32345e5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/security/SecurityConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/security/SecurityConfig.java @@ -18,8 +18,7 @@ import com.facebook.airlift.configuration.DefunctConfig; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.util.List; @@ -34,6 +33,7 @@ public class SecurityConfig private List authenticationTypes = ImmutableList.of(); private boolean allowForwardedHttps; private boolean 
authorizedIdentitySelectionEnabled; + private boolean enableSqlQueryTextContextField; public enum AuthenticationType { @@ -41,7 +41,9 @@ public enum AuthenticationType KERBEROS, PASSWORD, JWT, - CUSTOM + CUSTOM, + TEST_EXTERNAL, + OAUTH2 } @NotNull @@ -57,7 +59,7 @@ public SecurityConfig setAuthenticationTypes(List authentica } @Config("http-server.authentication.type") - @ConfigDescription("Authentication types (supported types: CERTIFICATE, KERBEROS, PASSWORD, JWT, CUSTOM)") + @ConfigDescription("Authentication types (supported types: CERTIFICATE, KERBEROS, PASSWORD, JWT, CUSTOM, OAUTH2, TEST_EXTERNAL)") public SecurityConfig setAuthenticationTypes(String types) { if (types == null) { @@ -96,4 +98,17 @@ public boolean isAuthorizedIdentitySelectionEnabled() { return authorizedIdentitySelectionEnabled; } + + @Config("permissions.enable-sql-query-text-context-field") + @ConfigDescription("Allow sql query text to be stored inside access control context") + public SecurityConfig setEnableSqlQueryTextContextField(boolean enableSqlQueryTextContextField) + { + this.enableSqlQueryTextContextField = enableSqlQueryTextContextField; + return this; + } + + public boolean isEnableSqlQueryTextContextField() + { + return enableSqlQueryTextContextField; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/AbstractTypedThriftCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/AbstractTypedThriftCodec.java new file mode 100644 index 0000000000000..f5d394b6cda7d --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/AbstractTypedThriftCodec.java @@ -0,0 +1,222 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.thrift; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.log.Logger; +import com.facebook.drift.annotations.ThriftField.Requiredness; +import com.facebook.drift.codec.ThriftCodec; +import com.facebook.drift.codec.metadata.DefaultThriftTypeReference; +import com.facebook.drift.codec.metadata.FieldKind; +import com.facebook.drift.codec.metadata.ThriftFieldMetadata; +import com.facebook.drift.codec.metadata.ThriftMethodInjection; +import com.facebook.drift.codec.metadata.ThriftStructMetadata; +import com.facebook.drift.codec.metadata.ThriftType; +import com.facebook.drift.protocol.TField; +import com.facebook.drift.protocol.TProtocolReader; +import com.facebook.drift.protocol.TProtocolWriter; +import com.facebook.drift.protocol.TStruct; +import com.facebook.drift.protocol.TType; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +import static java.util.Objects.requireNonNull; + +public abstract class AbstractTypedThriftCodec + implements ThriftCodec +{ + private static final Set NON_THRIFT_CONNECTOR = new HashSet<>(); + private static final Logger log = Logger.get(AbstractTypedThriftCodec.class); + private static final String TYPE_VALUE = "connectorId"; + private static final String CUSTOM_SERIALIZED_VALUE = "customSerializedValue"; + private static final String 
JSON_VALUE = "jsonValue"; + private static final short TYPE_FIELD_ID = 1; + private static final short CUSTOM_FIELD_ID = 2; + private static final short JSON_FIELD_ID = 3; + + private final Class baseClass; + private final JsonCodec jsonCodec; + private final Function nameResolver; + private final Function> classResolver; + + protected AbstractTypedThriftCodec(Class baseClass, + JsonCodec jsonCodec, + Function nameResolver, + Function> classResolver) + { + this.baseClass = requireNonNull(baseClass, "baseClass is null"); + this.jsonCodec = requireNonNull(jsonCodec, "jsonCodec is null"); + this.nameResolver = requireNonNull(nameResolver, "nameResolver is null"); + this.classResolver = requireNonNull(classResolver, "classResolver is null"); + } + + @Override + public abstract ThriftType getType(); + + protected static ThriftType createThriftType(Class baseClass) + { + List fields = new ArrayList<>(); + try { + fields.add(new ThriftFieldMetadata( + TYPE_FIELD_ID, + false, false, Requiredness.OPTIONAL, ImmutableMap.of(), + new DefaultThriftTypeReference(ThriftType.STRING), + TYPE_VALUE, + FieldKind.THRIFT_FIELD, + ImmutableList.of(), + Optional.empty(), + // Drift requires at least one of the three arguments below, so we provide a dummy method here as a workaround. + // https://github.com/airlift/drift/blob/master/drift-codec/src/main/java/io/airlift/drift/codec/metadata/ThriftFieldMetadata.java#L99 + Optional.of(new ThriftMethodInjection(AbstractTypedThriftCodec.class.getDeclaredMethod("getTypeField"), ImmutableList.of())), + Optional.empty(), + Optional.empty())); + fields.add(new ThriftFieldMetadata( + CUSTOM_FIELD_ID, + false, false, Requiredness.OPTIONAL, ImmutableMap.of(), + new DefaultThriftTypeReference(ThriftType.BINARY), + CUSTOM_SERIALIZED_VALUE, + FieldKind.THRIFT_FIELD, + ImmutableList.of(), + Optional.empty(), + // Drift requires at least one of the three arguments below, so we provide a dummy method here as a workaround. 
+ // https://github.com/airlift/drift/blob/master/drift-codec/src/main/java/io/airlift/drift/codec/metadata/ThriftFieldMetadata.java#L99 + Optional.of(new ThriftMethodInjection(AbstractTypedThriftCodec.class.getDeclaredMethod("getCustomField"), ImmutableList.of())), + Optional.empty(), + Optional.empty())); + // TODO: This field will be cleaned up: https://github.com/prestodb/presto/issues/25671 + fields.add(new ThriftFieldMetadata( + JSON_FIELD_ID, + false, false, Requiredness.OPTIONAL, ImmutableMap.of(), + new DefaultThriftTypeReference(ThriftType.STRING), + JSON_VALUE, + FieldKind.THRIFT_FIELD, + ImmutableList.of(), + Optional.empty(), + // Drift requires at least one of the three arguments below, so we provide a dummy method here as a workaround. + // https://github.com/airlift/drift/blob/master/drift-codec/src/main/java/io/airlift/drift/codec/metadata/ThriftFieldMetadata.java#L99 + Optional.of(new ThriftMethodInjection(AbstractTypedThriftCodec.class.getDeclaredMethod("getJsonField"), ImmutableList.of())), + Optional.empty(), + Optional.empty())); + } + catch (NoSuchMethodException e) { + throw new IllegalArgumentException("Failed to create ThriftFieldMetadata", e); + } + + return ThriftType.struct(new ThriftStructMetadata( + baseClass.getSimpleName(), + ImmutableMap.of(), baseClass, null, ThriftStructMetadata.MetadataType.STRUCT, + Optional.empty(), ImmutableList.of(), fields, Optional.empty(), ImmutableList.of())); + } + + @Override + public T read(TProtocolReader reader) + throws Exception + { + String connectorId = null; + T value = null; + String jsonValue = null; + + reader.readStructBegin(); + while (true) { + TField field = reader.readFieldBegin(); + if (field.getType() == TType.STOP) { + break; + } + switch (field.getId()) { + case JSON_FIELD_ID: + jsonValue = reader.readString(); + break; + case TYPE_FIELD_ID: + connectorId = reader.readString(); + break; + case CUSTOM_FIELD_ID: + requireNonNull(connectorId, "connectorId is null"); + Class 
concreteClass = classResolver.apply(connectorId); + requireNonNull(concreteClass, "concreteClass is null"); + value = readConcreteValue(connectorId, reader); + break; + } + reader.readFieldEnd(); + } + reader.readStructEnd(); + + if (jsonValue != null) { + return jsonCodec.fromJson(jsonValue); + } + if (value != null) { + return value; + } + throw new IllegalStateException("Neither thrift nor json value was present"); + } + + public abstract T readConcreteValue(String connectorId, TProtocolReader reader) + throws Exception; + + public abstract void writeConcreteValue(String connectorId, T value, TProtocolWriter writer) + throws Exception; + + public abstract boolean isThriftCodecAvailable(String connectorId); + + @Override + public void write(T value, TProtocolWriter writer) + throws Exception + { + if (value == null) { + return; + } + String connectorId = nameResolver.apply(value); + requireNonNull(connectorId, "connectorId is null"); + + writer.writeStructBegin(new TStruct(baseClass.getSimpleName())); + if (isThriftCodecAvailable(connectorId)) { + writer.writeFieldBegin(new TField(TYPE_VALUE, TType.STRING, TYPE_FIELD_ID)); + writer.writeString(connectorId); + writer.writeFieldEnd(); + + writer.writeFieldBegin(new TField(CUSTOM_SERIALIZED_VALUE, TType.STRING, CUSTOM_FIELD_ID)); + writeConcreteValue(connectorId, value, writer); + writer.writeFieldEnd(); + } + else { + // If thrift codec is not available for this connector, fall back to its json + writer.writeFieldBegin(new TField(JSON_VALUE, TType.STRING, JSON_FIELD_ID)); + writer.writeString(jsonCodec.toJson(value)); + writer.writeFieldEnd(); + } + writer.writeFieldStop(); + writer.writeStructEnd(); + } + + private String getTypeField() + { + return "getTypeField"; + } + + private String getCustomField() + { + return "getCustomField"; + } + + private String getJsonField() + { + return "getJsonField"; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ConnectorSplitThriftCodec.java 
b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ConnectorSplitThriftCodec.java new file mode 100644 index 0000000000000..234a53611f421 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ConnectorSplitThriftCodec.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.thrift; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.drift.codec.CodecThriftType; +import com.facebook.drift.codec.metadata.ThriftType; +import com.facebook.drift.protocol.TProtocolReader; +import com.facebook.drift.protocol.TProtocolWriter; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.ConnectorSplit; + +import javax.inject.Inject; + +import java.nio.ByteBuffer; + +import static java.util.Objects.requireNonNull; + +public class ConnectorSplitThriftCodec + extends AbstractTypedThriftCodec +{ + private static final ThriftType THRIFT_TYPE = createThriftType(ConnectorSplit.class); + private final ConnectorCodecManager connectorCodecManager; + + @Inject + public ConnectorSplitThriftCodec(HandleResolver handleResolver, ConnectorCodecManager connectorCodecManager, JsonCodec jsonCodec) + { + super(ConnectorSplit.class, + requireNonNull(jsonCodec, "jsonCodec is null"), + requireNonNull(handleResolver, "handleResolver is null")::getId, + handleResolver::getSplitClass); + 
this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); + } + + @CodecThriftType + public static ThriftType getThriftType() + { + return THRIFT_TYPE; + } + + @Override + public ThriftType getType() + { + return THRIFT_TYPE; + } + + @Override + public ConnectorSplit readConcreteValue(String connectorId, TProtocolReader reader) + throws Exception + { + ByteBuffer byteBuffer = reader.readBinary(); + assert (byteBuffer.position() == 0); + byte[] bytes = byteBuffer.array(); + return connectorCodecManager.getConnectorSplitCodec(connectorId).map(codec -> codec.deserialize(bytes)).orElse(null); + } + + @Override + public void writeConcreteValue(String connectorId, ConnectorSplit value, TProtocolWriter writer) + throws Exception + { + requireNonNull(value, "value is null"); + writer.writeBinary(ByteBuffer.wrap(connectorCodecManager.getConnectorSplitCodec(connectorId).map(codec -> codec.serialize(value)).orElseThrow(() -> new IllegalArgumentException("Cannot serialize " + value)))); + } + + @Override + public boolean isThriftCodecAvailable(String connectorId) + { + return connectorCodecManager.getConnectorSplitCodec(connectorId).isPresent(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/CustomCodecUtils.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/CustomCodecUtils.java deleted file mode 100644 index 483c0c3def17d..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/CustomCodecUtils.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.server.thrift; - -import com.facebook.airlift.json.JsonCodec; -import com.facebook.drift.TException; -import com.facebook.drift.codec.metadata.DefaultThriftTypeReference; -import com.facebook.drift.codec.metadata.FieldKind; -import com.facebook.drift.codec.metadata.ThriftFieldExtractor; -import com.facebook.drift.codec.metadata.ThriftFieldMetadata; -import com.facebook.drift.codec.metadata.ThriftStructMetadata; -import com.facebook.drift.codec.metadata.ThriftType; -import com.facebook.drift.protocol.TField; -import com.facebook.drift.protocol.TProtocolException; -import com.facebook.drift.protocol.TProtocolReader; -import com.facebook.drift.protocol.TProtocolWriter; -import com.facebook.drift.protocol.TStruct; -import com.facebook.drift.protocol.TType; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; - -import java.util.Optional; - -import static com.facebook.drift.annotations.ThriftField.Requiredness.NONE; -import static java.lang.String.format; - -/*** - * When we need a custom codec for a primitive type, we need a wrapper to pass the needsCodec check within ThriftCodecByteCodeGenerator.java - */ -public class CustomCodecUtils -{ - private CustomCodecUtils() {} - - public static ThriftStructMetadata createSyntheticMetadata(short fieldId, String fieldName, Class originalType, Class referencedType, ThriftType thriftType) - { - ThriftFieldMetadata fieldMetaData = new ThriftFieldMetadata( - fieldId, - false, false, NONE, ImmutableMap.of(), - new 
DefaultThriftTypeReference(thriftType), - fieldName, - FieldKind.THRIFT_FIELD, - ImmutableList.of(), - Optional.empty(), - Optional.empty(), - Optional.of(new ThriftFieldExtractor( - fieldId, - fieldName, - FieldKind.THRIFT_FIELD, - originalType.getDeclaredFields()[0], // Any field should work since we are handing extraction in codec on our own - referencedType)), - Optional.empty()); - return new ThriftStructMetadata( - originalType.getSimpleName() + "Wrapper", - ImmutableMap.of(), - originalType, null, - ThriftStructMetadata.MetadataType.STRUCT, - Optional.empty(), ImmutableList.of(), ImmutableList.of(fieldMetaData), Optional.empty(), ImmutableList.of()); - } - - public static T readSingleJsonField(TProtocolReader protocol, JsonCodec jsonCodec, short fieldId, String fieldName) - throws TException - { - protocol.readStructBegin(); - String jsonValue = null; - TField field = protocol.readFieldBegin(); - while (field.getType() != TType.STOP) { - if (field.getId() == fieldId) { - if (field.getType() == TType.STRING) { - jsonValue = protocol.readString(); - } - else { - throw new TProtocolException(format("Unexpected field type: %s for field %s", field.getType(), fieldName)); - } - } - protocol.readFieldEnd(); - field = protocol.readFieldBegin(); - } - protocol.readStructEnd(); - - if (jsonValue == null) { - throw new TProtocolException(format("Required field '%s' was not found", fieldName)); - } - return jsonCodec.fromJson(jsonValue); - } - - public static void writeSingleJsonField(T value, TProtocolWriter protocol, JsonCodec jsonCodec, short fieldId, String fieldName, String structName) - throws TException - { - protocol.writeStructBegin(new TStruct(structName)); - - protocol.writeFieldBegin(new TField(fieldName, TType.STRING, fieldId)); - protocol.writeString(jsonCodec.toJson(value)); - protocol.writeFieldEnd(); - - protocol.writeFieldStop(); - protocol.writeStructEnd(); - } -} diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/DeleteTableHandleThriftCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/DeleteTableHandleThriftCodec.java new file mode 100644 index 0000000000000..37232f59e1fba --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/DeleteTableHandleThriftCodec.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.thrift; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.drift.codec.CodecThriftType; +import com.facebook.drift.codec.metadata.ThriftType; +import com.facebook.drift.protocol.TProtocolReader; +import com.facebook.drift.protocol.TProtocolWriter; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.ConnectorDeleteTableHandle; + +import javax.inject.Inject; + +import java.nio.ByteBuffer; + +import static java.util.Objects.requireNonNull; + +public class DeleteTableHandleThriftCodec + extends AbstractTypedThriftCodec +{ + private static final ThriftType THRIFT_TYPE = createThriftType(ConnectorDeleteTableHandle.class); + private final ConnectorCodecManager connectorCodecManager; + + @Inject + public DeleteTableHandleThriftCodec(HandleResolver handleResolver, ConnectorCodecManager connectorCodecManager, JsonCodec jsonCodec) + { + super(ConnectorDeleteTableHandle.class, + 
requireNonNull(jsonCodec, "jsonCodec is null"), + requireNonNull(handleResolver, "handleResolver is null")::getId, + handleResolver::getDeleteTableHandleClass); + this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); + } + + @CodecThriftType + public static ThriftType getThriftType() + { + return THRIFT_TYPE; + } + + @Override + public ThriftType getType() + { + return THRIFT_TYPE; + } + + @Override + public ConnectorDeleteTableHandle readConcreteValue(String connectorId, TProtocolReader reader) + throws Exception + { + ByteBuffer byteBuffer = reader.readBinary(); + assert (byteBuffer.position() == 0); + byte[] bytes = byteBuffer.array(); + return connectorCodecManager.getDeleteTableHandleCodec(connectorId).map(codec -> codec.deserialize(bytes)).orElse(null); + } + + @Override + public void writeConcreteValue(String connectorId, ConnectorDeleteTableHandle value, TProtocolWriter writer) + throws Exception + { + requireNonNull(value, "value is null"); + writer.writeBinary(ByteBuffer.wrap(connectorCodecManager.getDeleteTableHandleCodec(connectorId).map(codec -> codec.serialize(value)).orElseThrow(() -> new IllegalArgumentException("Can not serialize " + value)))); + } + + @Override + public boolean isThriftCodecAvailable(String connectorId) + { + return connectorCodecManager.getDeleteTableHandleCodec(connectorId).isPresent(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/HandleThriftModule.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/HandleThriftModule.java new file mode 100644 index 0000000000000..71112a1fd1371 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/HandleThriftModule.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.thrift; + +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; +import com.facebook.presto.spi.ConnectorOutputTableHandle; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.facebook.drift.codec.guice.ThriftCodecBinder.thriftCodecBinder; + +public class HandleThriftModule + implements Module +{ + @Override + public void configure(Binder binder) + { + thriftCodecBinder(binder).bindCustomThriftCodec(ConnectorSplitThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(TransactionHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(OutputTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(InsertTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(DeleteTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(MergeTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(TableLayoutHandleThriftCodec.class); + 
thriftCodecBinder(binder).bindCustomThriftCodec(TableHandleThriftCodec.class); + + jsonCodecBinder(binder).bindJsonCodec(ConnectorSplit.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorTransactionHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorOutputTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorInsertTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorDeleteTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorMergeTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorTableLayoutHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorTableHandle.class); + + binder.bind(HandleResolver.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/InsertTableHandleThriftCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/InsertTableHandleThriftCodec.java new file mode 100644 index 0000000000000..c01480a5c535e --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/InsertTableHandleThriftCodec.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.thrift; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.drift.codec.CodecThriftType; +import com.facebook.drift.codec.metadata.ThriftType; +import com.facebook.drift.protocol.TProtocolReader; +import com.facebook.drift.protocol.TProtocolWriter; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.ConnectorInsertTableHandle; + +import javax.inject.Inject; + +import java.nio.ByteBuffer; + +import static java.util.Objects.requireNonNull; + +public class InsertTableHandleThriftCodec + extends AbstractTypedThriftCodec +{ + private static final ThriftType THRIFT_TYPE = createThriftType(ConnectorInsertTableHandle.class); + private final ConnectorCodecManager connectorCodecManager; + + @Inject + public InsertTableHandleThriftCodec(HandleResolver handleResolver, ConnectorCodecManager connectorCodecManager, JsonCodec jsonCodec) + { + super(ConnectorInsertTableHandle.class, + requireNonNull(jsonCodec, "jsonCodec is null"), + requireNonNull(handleResolver, "handleResolver is null")::getId, + handleResolver::getInsertTableHandleClass); + this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); + } + + @CodecThriftType + public static ThriftType getThriftType() + { + return THRIFT_TYPE; + } + + @Override + public ThriftType getType() + { + return THRIFT_TYPE; + } + + @Override + public ConnectorInsertTableHandle readConcreteValue(String connectorId, TProtocolReader reader) + throws Exception + { + ByteBuffer byteBuffer = reader.readBinary(); + assert (byteBuffer.position() == 0); + byte[] bytes = byteBuffer.array(); + return connectorCodecManager.getInsertTableHandleCodec(connectorId).map(codec -> codec.deserialize(bytes)).orElse(null); + } + + @Override + public void writeConcreteValue(String connectorId, ConnectorInsertTableHandle value, TProtocolWriter writer) + throws 
Exception + { + requireNonNull(value, "value is null"); + writer.writeBinary(ByteBuffer.wrap(connectorCodecManager.getInsertTableHandleCodec(connectorId).map(codec -> codec.serialize(value)).orElseThrow(() -> new IllegalArgumentException("Can not serialize " + value)))); + } + + @Override + public boolean isThriftCodecAvailable(String connectorId) + { + return connectorCodecManager.getInsertTableHandleCodec(connectorId).isPresent(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/MergeTableHandleThriftCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/MergeTableHandleThriftCodec.java new file mode 100644 index 0000000000000..694c6e49eaf7a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/MergeTableHandleThriftCodec.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.thrift; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.drift.codec.CodecThriftType; +import com.facebook.drift.codec.metadata.ThriftType; +import com.facebook.drift.protocol.TProtocolReader; +import com.facebook.drift.protocol.TProtocolWriter; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.ConnectorMergeTableHandle; + +import javax.inject.Inject; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class MergeTableHandleThriftCodec + extends AbstractTypedThriftCodec +{ + private static final ThriftType THRIFT_TYPE = createThriftType(ConnectorMergeTableHandle.class); + private final ConnectorCodecManager connectorCodecManager; + + @Inject + public MergeTableHandleThriftCodec(HandleResolver handleResolver, ConnectorCodecManager connectorCodecManager, JsonCodec jsonCodec) + { + super(ConnectorMergeTableHandle.class, + requireNonNull(jsonCodec, "jsonCodec is null"), + requireNonNull(handleResolver, "handleResolver is null")::getId, + handleResolver::getMergeTableHandleClass); + this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); + } + + @CodecThriftType + public static ThriftType getThriftType() + { + return THRIFT_TYPE; + } + + @Override + public ThriftType getType() + { + return THRIFT_TYPE; + } + + @Override + public ConnectorMergeTableHandle readConcreteValue(String connectorId, TProtocolReader reader) + throws Exception + { + ByteBuffer byteBuffer = reader.readBinary(); + checkArgument(byteBuffer.position() == 0, "Buffer position should be 0, but is %s", byteBuffer.position()); + byte[] bytes = byteBuffer.array(); + return connectorCodecManager.getMergeTableHandleCodec(connectorId).map(codec -> codec.deserialize(bytes)).orElse(null); + } + + 
@Override + public void writeConcreteValue(String connectorId, ConnectorMergeTableHandle value, TProtocolWriter writer) + throws Exception + { + requireNonNull(value, "value is null"); + writer.writeBinary(ByteBuffer.wrap(connectorCodecManager.getMergeTableHandleCodec(connectorId).map(codec -> codec.serialize(value)).orElseThrow(() -> new IllegalArgumentException("Can not serialize " + value)))); + } + + @Override + public boolean isThriftCodecAvailable(String connectorId) + { + return connectorCodecManager.getMergeTableHandleCodec(connectorId).isPresent(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/MetadataUpdatesCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/MetadataUpdatesCodec.java deleted file mode 100644 index f14ddf4295513..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/MetadataUpdatesCodec.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.server.thrift; - -import com.facebook.airlift.json.JsonCodec; -import com.facebook.drift.codec.CodecThriftType; -import com.facebook.drift.codec.ThriftCodec; -import com.facebook.drift.codec.metadata.ThriftCatalog; -import com.facebook.drift.codec.metadata.ThriftType; -import com.facebook.drift.protocol.TProtocolReader; -import com.facebook.drift.protocol.TProtocolWriter; -import com.facebook.presto.metadata.MetadataUpdates; - -import javax.inject.Inject; - -import static com.facebook.presto.server.thrift.CustomCodecUtils.createSyntheticMetadata; -import static com.facebook.presto.server.thrift.CustomCodecUtils.readSingleJsonField; -import static com.facebook.presto.server.thrift.CustomCodecUtils.writeSingleJsonField; -import static java.util.Objects.requireNonNull; - -public class MetadataUpdatesCodec - implements ThriftCodec -{ - private static final short METADATA_UPDATES_DATA_FIELD_ID = 1; - private static final String METADATA_UPDATES_DATA_FIELD_NAME = "metadataUpdates"; - private static final String METADATA_UPDATES_STRUCT_NAME = "MetadataUpdates"; - private static final ThriftType SYNTHETIC_STRUCT_TYPE = ThriftType.struct(createSyntheticMetadata(METADATA_UPDATES_DATA_FIELD_ID, METADATA_UPDATES_DATA_FIELD_NAME, MetadataUpdates.class, String.class, ThriftType.STRING)); - - private final JsonCodec jsonCodec; - - @Inject - public MetadataUpdatesCodec(JsonCodec jsonCodec, ThriftCatalog thriftCatalog) - { - this.jsonCodec = requireNonNull(jsonCodec, "jsonCodec is null"); - thriftCatalog.addThriftType(SYNTHETIC_STRUCT_TYPE); - } - - @CodecThriftType - public static ThriftType getThriftType() - { - return SYNTHETIC_STRUCT_TYPE; - } - - @Override - public ThriftType getType() - { - return SYNTHETIC_STRUCT_TYPE; - } - - @Override - public MetadataUpdates read(TProtocolReader protocol) - throws Exception - { - return readSingleJsonField(protocol, jsonCodec, METADATA_UPDATES_DATA_FIELD_ID, METADATA_UPDATES_DATA_FIELD_NAME); - } - - 
@Override - public void write(MetadataUpdates value, TProtocolWriter protocol) - throws Exception - { - writeSingleJsonField(value, protocol, jsonCodec, METADATA_UPDATES_DATA_FIELD_ID, METADATA_UPDATES_DATA_FIELD_NAME, METADATA_UPDATES_STRUCT_NAME); - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/OutputTableHandleThriftCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/OutputTableHandleThriftCodec.java new file mode 100644 index 0000000000000..b364dc694f8b7 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/OutputTableHandleThriftCodec.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.thrift; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.drift.codec.CodecThriftType; +import com.facebook.drift.codec.metadata.ThriftType; +import com.facebook.drift.protocol.TProtocolReader; +import com.facebook.drift.protocol.TProtocolWriter; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.ConnectorOutputTableHandle; + +import javax.inject.Inject; + +import java.nio.ByteBuffer; + +import static java.util.Objects.requireNonNull; + +public class OutputTableHandleThriftCodec + extends AbstractTypedThriftCodec +{ + private static final ThriftType THRIFT_TYPE = createThriftType(ConnectorOutputTableHandle.class); + private final ConnectorCodecManager connectorCodecManager; + + @Inject + public OutputTableHandleThriftCodec(HandleResolver handleResolver, ConnectorCodecManager connectorCodecManager, JsonCodec jsonCodec) + { + super(ConnectorOutputTableHandle.class, + requireNonNull(jsonCodec, "jsonCodec is null"), + requireNonNull(handleResolver, "handleResolver is null")::getId, + handleResolver::getOutputTableHandleClass); + this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); + } + + @CodecThriftType + public static ThriftType getThriftType() + { + return THRIFT_TYPE; + } + + @Override + public ThriftType getType() + { + return THRIFT_TYPE; + } + + @Override + public ConnectorOutputTableHandle readConcreteValue(String connectorId, TProtocolReader reader) + throws Exception + { + ByteBuffer byteBuffer = reader.readBinary(); + assert (byteBuffer.position() == 0); + byte[] bytes = byteBuffer.array(); + return connectorCodecManager.getOutputTableHandleCodec(connectorId).map(codec -> codec.deserialize(bytes)).orElse(null); + } + + @Override + public void writeConcreteValue(String connectorId, ConnectorOutputTableHandle value, TProtocolWriter writer) + throws 
Exception + { + requireNonNull(value, "value is null"); + writer.writeBinary(ByteBuffer.wrap(connectorCodecManager.getOutputTableHandleCodec(connectorId).map(codec -> codec.serialize(value)).orElseThrow(() -> new IllegalArgumentException("Can not serialize " + value)))); + } + + @Override + public boolean isThriftCodecAvailable(String connectorId) + { + return connectorCodecManager.getOutputTableHandleCodec(connectorId).isPresent(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/SplitCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/SplitCodec.java deleted file mode 100644 index e47fdcf9b3e19..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/SplitCodec.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.server.thrift; - -import com.facebook.airlift.json.JsonCodec; -import com.facebook.drift.codec.CodecThriftType; -import com.facebook.drift.codec.ThriftCodec; -import com.facebook.drift.codec.metadata.ThriftCatalog; -import com.facebook.drift.codec.metadata.ThriftType; -import com.facebook.drift.protocol.TProtocolReader; -import com.facebook.drift.protocol.TProtocolWriter; -import com.facebook.presto.metadata.Split; - -import javax.inject.Inject; - -import static com.facebook.presto.server.thrift.CustomCodecUtils.createSyntheticMetadata; -import static com.facebook.presto.server.thrift.CustomCodecUtils.readSingleJsonField; -import static com.facebook.presto.server.thrift.CustomCodecUtils.writeSingleJsonField; -import static java.util.Objects.requireNonNull; - -public class SplitCodec - implements ThriftCodec -{ - private static final short SPLIT_DATA_FIELD_ID = 1; - private static final String SPLIT_DATA_FIELD_NAME = "split"; - private static final String SPLIT_DATA_STRUCT_NAME = "Split"; - private static final ThriftType SYNTHETIC_STRUCT_TYPE = ThriftType.struct(createSyntheticMetadata(SPLIT_DATA_FIELD_ID, SPLIT_DATA_FIELD_NAME, Split.class, String.class, ThriftType.STRING)); - - private final JsonCodec jsonCodec; - - @Inject - public SplitCodec(JsonCodec jsonCodec, ThriftCatalog thriftCatalog) - { - this.jsonCodec = requireNonNull(jsonCodec, "jsonCodec is null"); - thriftCatalog.addThriftType(SYNTHETIC_STRUCT_TYPE); - } - - @CodecThriftType - public static ThriftType getThriftType() - { - return SYNTHETIC_STRUCT_TYPE; - } - - @Override - public ThriftType getType() - { - return SYNTHETIC_STRUCT_TYPE; - } - - @Override - public Split read(TProtocolReader protocol) - throws Exception - { - return readSingleJsonField(protocol, jsonCodec, SPLIT_DATA_FIELD_ID, SPLIT_DATA_FIELD_NAME); - } - - @Override - public void write(Split value, TProtocolWriter protocol) - throws Exception - { - writeSingleJsonField(value, protocol, jsonCodec, 
SPLIT_DATA_FIELD_ID, SPLIT_DATA_FIELD_NAME, SPLIT_DATA_STRUCT_NAME); - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TableHandleThriftCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TableHandleThriftCodec.java new file mode 100644 index 0000000000000..95b209879c762 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TableHandleThriftCodec.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.thrift; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.drift.codec.CodecThriftType; +import com.facebook.drift.codec.metadata.ThriftType; +import com.facebook.drift.protocol.TProtocolReader; +import com.facebook.drift.protocol.TProtocolWriter; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.ConnectorTableHandle; + +import javax.inject.Inject; + +import java.nio.ByteBuffer; + +import static java.util.Objects.requireNonNull; + +public class TableHandleThriftCodec + extends AbstractTypedThriftCodec +{ + private static final ThriftType THRIFT_TYPE = createThriftType(ConnectorTableHandle.class); + private final ConnectorCodecManager connectorCodecManager; + + @Inject + public TableHandleThriftCodec(HandleResolver handleResolver, ConnectorCodecManager connectorCodecManager, JsonCodec jsonCodec) + { + super(ConnectorTableHandle.class, + requireNonNull(jsonCodec, "jsonCodec is null"), + requireNonNull(handleResolver, "handleResolver is null")::getId, + handleResolver::getTableHandleClass); + this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); + } + + @CodecThriftType + public static ThriftType getThriftType() + { + return THRIFT_TYPE; + } + + @Override + public ThriftType getType() + { + return THRIFT_TYPE; + } + + @Override + public ConnectorTableHandle readConcreteValue(String connectorId, TProtocolReader reader) + throws Exception + { + ByteBuffer byteBuffer = reader.readBinary(); + assert (byteBuffer.position() == 0); + byte[] bytes = byteBuffer.array(); + return connectorCodecManager.getTableHandleCodec(connectorId).map(codec -> codec.deserialize(bytes)).orElse(null); + } + + @Override + public void writeConcreteValue(String connectorId, ConnectorTableHandle value, TProtocolWriter writer) + throws Exception + { + requireNonNull(value, "value is null"); + 
writer.writeBinary(ByteBuffer.wrap(connectorCodecManager.getTableHandleCodec(connectorId).map(codec -> codec.serialize(value)).orElseThrow(() -> new IllegalArgumentException("Can not serialize " + value)))); + } + + @Override + public boolean isThriftCodecAvailable(String connectorId) + { + return connectorCodecManager.getTableHandleCodec(connectorId).isPresent(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TableLayoutHandleThriftCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TableLayoutHandleThriftCodec.java new file mode 100644 index 0000000000000..9387a6c0058b5 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TableLayoutHandleThriftCodec.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.thrift; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.drift.codec.CodecThriftType; +import com.facebook.drift.codec.metadata.ThriftType; +import com.facebook.drift.protocol.TProtocolReader; +import com.facebook.drift.protocol.TProtocolWriter; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; + +import javax.inject.Inject; + +import java.nio.ByteBuffer; + +import static java.util.Objects.requireNonNull; + +public class TableLayoutHandleThriftCodec + extends AbstractTypedThriftCodec +{ + private static final ThriftType THRIFT_TYPE = createThriftType(ConnectorTableLayoutHandle.class); + private final ConnectorCodecManager connectorCodecManager; + + @Inject + public TableLayoutHandleThriftCodec(HandleResolver handleResolver, ConnectorCodecManager connectorCodecManager, JsonCodec jsonCodec) + { + super(ConnectorTableLayoutHandle.class, + requireNonNull(jsonCodec, "jsonCodec is null"), + requireNonNull(handleResolver, "handleResolver is null")::getId, + handleResolver::getTableLayoutHandleClass); + this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); + } + + @CodecThriftType + public static ThriftType getThriftType() + { + return THRIFT_TYPE; + } + + @Override + public ThriftType getType() + { + return THRIFT_TYPE; + } + + @Override + public ConnectorTableLayoutHandle readConcreteValue(String connectorId, TProtocolReader reader) + throws Exception + { + ByteBuffer byteBuffer = reader.readBinary(); + assert (byteBuffer.position() == 0); + byte[] bytes = byteBuffer.array(); + return connectorCodecManager.getTableLayoutHandleCodec(connectorId).map(codec -> codec.deserialize(bytes)).orElse(null); + } + + @Override + public void writeConcreteValue(String connectorId, ConnectorTableLayoutHandle value, TProtocolWriter writer) + throws 
Exception + { + requireNonNull(value, "value is null"); + writer.writeBinary(ByteBuffer.wrap(connectorCodecManager.getTableLayoutHandleCodec(connectorId).map(codec -> codec.serialize(value)).orElseThrow(() -> new IllegalArgumentException("Can not serialize " + value)))); + } + + @Override + public boolean isThriftCodecAvailable(String connectorId) + { + return connectorCodecManager.getTableLayoutHandleCodec(connectorId).isPresent(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TableWriteInfoCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TableWriteInfoCodec.java deleted file mode 100644 index 50754a0593bae..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TableWriteInfoCodec.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.server.thrift; - -import com.facebook.airlift.json.JsonCodec; -import com.facebook.drift.codec.CodecThriftType; -import com.facebook.drift.codec.ThriftCodec; -import com.facebook.drift.codec.metadata.ThriftCatalog; -import com.facebook.drift.codec.metadata.ThriftType; -import com.facebook.drift.protocol.TProtocolReader; -import com.facebook.drift.protocol.TProtocolWriter; -import com.facebook.presto.execution.scheduler.TableWriteInfo; - -import javax.inject.Inject; - -import static com.facebook.presto.server.thrift.CustomCodecUtils.createSyntheticMetadata; -import static com.facebook.presto.server.thrift.CustomCodecUtils.readSingleJsonField; -import static com.facebook.presto.server.thrift.CustomCodecUtils.writeSingleJsonField; -import static java.util.Objects.requireNonNull; - -public class TableWriteInfoCodec - implements ThriftCodec -{ - private static final short TABLE_WRITE_INFO_DATA_FIELD_ID = 1; - private static final String TABLE_WRITE_INFO_DATA_FIELD_NAME = "tableWriteInfo"; - private static final String TABLE_WRITE_INFO_STRUCT_NAME = "TableWriteInfo"; - private static final ThriftType SYNTHETIC_STRUCT_TYPE = ThriftType.struct(createSyntheticMetadata(TABLE_WRITE_INFO_DATA_FIELD_ID, TABLE_WRITE_INFO_DATA_FIELD_NAME, TableWriteInfo.class, String.class, ThriftType.STRING)); - - private final JsonCodec jsonCodec; - - @Inject - public TableWriteInfoCodec(JsonCodec jsonCodec, ThriftCatalog thriftCatalog) - { - this.jsonCodec = requireNonNull(jsonCodec, "jsonCodec is null"); - thriftCatalog.addThriftType(SYNTHETIC_STRUCT_TYPE); - } - - @CodecThriftType - public static ThriftType getThriftType() - { - return SYNTHETIC_STRUCT_TYPE; - } - - @Override - public ThriftType getType() - { - return SYNTHETIC_STRUCT_TYPE; - } - - @Override - public TableWriteInfo read(TProtocolReader protocol) - throws Exception - { - return readSingleJsonField(protocol, jsonCodec, TABLE_WRITE_INFO_DATA_FIELD_ID, TABLE_WRITE_INFO_DATA_FIELD_NAME); - } - - 
@Override - public void write(TableWriteInfo value, TProtocolWriter protocol) - throws Exception - { - writeSingleJsonField(value, protocol, jsonCodec, TABLE_WRITE_INFO_DATA_FIELD_ID, TABLE_WRITE_INFO_DATA_FIELD_NAME, TABLE_WRITE_INFO_STRUCT_NAME); - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftCodecUtils.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftCodecUtils.java new file mode 100644 index 0000000000000..4819ac2a00a40 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftCodecUtils.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.thrift; + +import com.facebook.drift.codec.ThriftCodec; +import com.facebook.drift.protocol.TBinaryProtocol; +import com.facebook.drift.protocol.TMemoryBuffer; +import com.facebook.drift.protocol.TMemoryBufferWriteOnly; +import com.facebook.drift.protocol.TProtocolException; + +public class ThriftCodecUtils +{ + private ThriftCodecUtils() {} + + public static T fromThrift(byte[] bytes, ThriftCodec thriftCodec) + throws TProtocolException + { + try { + TMemoryBuffer transport = new TMemoryBuffer(bytes.length); + transport.write(bytes); + TBinaryProtocol protocol = new TBinaryProtocol(transport); + return thriftCodec.read(protocol); + } + catch (Exception e) { + throw new TProtocolException("Can not deserialize the data", e); + } + } + + public static byte[] toThrift(T value, ThriftCodec thriftCodec) + throws TProtocolException + { + TMemoryBufferWriteOnly transport = new TMemoryBufferWriteOnly(1024); + TBinaryProtocol protocol = new TBinaryProtocol(transport); + try { + thriftCodec.write(value, protocol); + return transport.getBytes(); + } + catch (Exception e) { + throw new TProtocolException("Can not serialize the data", e); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftServerInfoService.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftServerInfoService.java index a9abba86bc295..115473cea3d08 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftServerInfoService.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftServerInfoService.java @@ -18,8 +18,7 @@ import com.facebook.presto.server.GracefulShutdownHandler; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static 
com.facebook.presto.spi.NodeState.ACTIVE; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftTaskService.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftTaskService.java index 528f3795c5470..aa389b56b9a88 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftTaskService.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/ThriftTaskService.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.server.thrift; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftMethod; import com.facebook.drift.annotations.ThriftService; import com.facebook.presto.execution.TaskId; @@ -25,9 +26,7 @@ import com.facebook.presto.server.ForAsyncRpc; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.concurrent.ScheduledExecutorService; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TransactionHandleThriftCodec.java b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TransactionHandleThriftCodec.java new file mode 100644 index 0000000000000..65a96ec5f863d --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/server/thrift/TransactionHandleThriftCodec.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.thrift; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.drift.codec.CodecThriftType; +import com.facebook.drift.codec.metadata.ThriftType; +import com.facebook.drift.protocol.TProtocolReader; +import com.facebook.drift.protocol.TProtocolWriter; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +import javax.inject.Inject; + +import java.nio.ByteBuffer; + +import static java.util.Objects.requireNonNull; + +public class TransactionHandleThriftCodec + extends AbstractTypedThriftCodec +{ + private static final ThriftType THRIFT_TYPE = createThriftType(ConnectorTransactionHandle.class); + private final ConnectorCodecManager connectorCodecManager; + + @Inject + public TransactionHandleThriftCodec(HandleResolver handleResolver, ConnectorCodecManager connectorCodecManager, JsonCodec jsonCodec) + { + super(ConnectorTransactionHandle.class, + requireNonNull(jsonCodec, "jsonCodec is null"), + requireNonNull(handleResolver, "handleResolver is null")::getId, + handleResolver::getTransactionHandleClass); + this.connectorCodecManager = requireNonNull(connectorCodecManager, "connectorThriftCodecManager is null"); + } + + @CodecThriftType + public static ThriftType getThriftType() + { + return THRIFT_TYPE; + } + + @Override + public ThriftType getType() + { + return THRIFT_TYPE; + } + + @Override + public ConnectorTransactionHandle readConcreteValue(String connectorId, TProtocolReader reader) + throws Exception + { + ByteBuffer byteBuffer = reader.readBinary(); + assert (byteBuffer.position() == 0); + byte[] bytes = byteBuffer.array(); + return connectorCodecManager.getTransactionHandleCodec(connectorId).map(codec -> codec.deserialize(bytes)).orElse(null); + } + + @Override + public 
void writeConcreteValue(String connectorId, ConnectorTransactionHandle value, TProtocolWriter writer) + throws Exception + { + requireNonNull(value, "value is null"); + writer.writeBinary(ByteBuffer.wrap(connectorCodecManager.getTransactionHandleCodec(connectorId).map(codec -> codec.serialize(value)).orElseThrow(() -> new IllegalArgumentException("Can not serialize " + value)))); + } + + @Override + public boolean isThriftCodecAvailable(String connectorId) + { + return connectorCodecManager.getTransactionHandleCodec(connectorId).isPresent(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sessionpropertyproviders/JavaWorkerSessionPropertyProvider.java b/presto-main-base/src/main/java/com/facebook/presto/sessionpropertyproviders/JavaWorkerSessionPropertyProvider.java index d53339750ed33..0b304c37bdb46 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sessionpropertyproviders/JavaWorkerSessionPropertyProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sessionpropertyproviders/JavaWorkerSessionPropertyProvider.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sessionpropertyproviders; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.session.WorkerSessionPropertyProvider; @@ -21,7 +22,6 @@ import com.facebook.presto.sql.analyzer.JavaFeaturesConfig; import com.google.common.collect.ImmutableList; import com.google.inject.Inject; -import io.airlift.units.DataSize; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sessionpropertyproviders/NativeWorkerSessionPropertyProvider.java b/presto-main-base/src/main/java/com/facebook/presto/sessionpropertyproviders/NativeWorkerSessionPropertyProvider.java index 82e35430b5957..30d6818998cf5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sessionpropertyproviders/NativeWorkerSessionPropertyProvider.java 
+++ b/presto-main-base/src/main/java/com/facebook/presto/sessionpropertyproviders/NativeWorkerSessionPropertyProvider.java @@ -28,7 +28,6 @@ import static com.facebook.presto.spi.session.PropertyMetadata.stringProperty; import static java.util.Objects.requireNonNull; -@Deprecated public class NativeWorkerSessionPropertyProvider implements WorkerSessionPropertyProvider { @@ -53,7 +52,12 @@ public class NativeWorkerSessionPropertyProvider public static final String NATIVE_DEBUG_DISABLE_EXPRESSION_WITH_MEMOIZATION = "native_debug_disable_expression_with_memoization"; public static final String NATIVE_DEBUG_DISABLE_EXPRESSION_WITH_LAZY_INPUTS = "native_debug_disable_expression_with_lazy_inputs"; public static final String NATIVE_DEBUG_MEMORY_POOL_NAME_REGEX = "native_debug_memory_pool_name_regex"; + public static final String NATIVE_DEBUG_MEMORY_POOL_WARN_THRESHOLD_BYTES = "native_debug_memory_pool_warn_threshold_bytes"; public static final String NATIVE_SELECTIVE_NIMBLE_READER_ENABLED = "native_selective_nimble_reader_enabled"; + public static final String NATIVE_ROW_SIZE_TRACKING_ENABLED = "row_size_tracking_enabled"; + public static final String NATIVE_PREFERRED_OUTPUT_BATCH_BYTES = "preferred_output_batch_bytes"; + public static final String NATIVE_PREFERRED_OUTPUT_BATCH_ROWS = "preferred_output_batch_rows"; + public static final String NATIVE_MAX_OUTPUT_BATCH_ROWS = "max_output_batch_rows"; public static final String NATIVE_MAX_PARTIAL_AGGREGATION_MEMORY = "native_max_partial_aggregation_memory"; public static final String NATIVE_MAX_EXTENDED_PARTIAL_AGGREGATION_MEMORY = "native_max_extended_partial_aggregation_memory"; public static final String NATIVE_MAX_SPILL_BYTES = "native_max_spill_bytes"; @@ -61,7 +65,7 @@ public class NativeWorkerSessionPropertyProvider public static final String NATIVE_MAX_OUTPUT_BUFFER_SIZE = "native_max_output_buffer_size"; public static final String NATIVE_QUERY_TRACE_ENABLED = "native_query_trace_enabled"; public static final String 
NATIVE_QUERY_TRACE_DIR = "native_query_trace_dir"; - public static final String NATIVE_QUERY_TRACE_NODE_IDS = "native_query_trace_node_ids"; + public static final String NATIVE_QUERY_TRACE_NODE_ID = "native_query_trace_node_id"; public static final String NATIVE_QUERY_TRACE_MAX_BYTES = "native_query_trace_max_bytes"; public static final String NATIVE_QUERY_TRACE_FRAGMENT_ID = "native_query_trace_fragment_id"; public static final String NATIVE_QUERY_TRACE_SHARD_ID = "native_query_trace_shard_id"; @@ -78,6 +82,15 @@ public class NativeWorkerSessionPropertyProvider public static final String NATIVE_TABLE_SCAN_SCALE_UP_MEMORY_USAGE_RATIO = "native_table_scan_scale_up_memory_usage_ratio"; public static final String NATIVE_STREAMING_AGGREGATION_MIN_OUTPUT_BATCH_ROWS = "native_streaming_aggregation_min_output_batch_rows"; public static final String NATIVE_REQUEST_DATA_SIZES_MAX_WAIT_SEC = "native_request_data_sizes_max_wait_sec"; + public static final String NATIVE_QUERY_MEMORY_RECLAIMER_PRIORITY = "native_query_memory_reclaimer_priority"; + public static final String NATIVE_MAX_NUM_SPLITS_LISTENED_TO = "native_max_num_splits_listened_to"; + public static final String NATIVE_INDEX_LOOKUP_JOIN_MAX_PREFETCH_BATCHES = "native_index_lookup_join_max_prefetch_batches"; + public static final String NATIVE_INDEX_LOOKUP_JOIN_SPLIT_OUTPUT = "native_index_lookup_join_split_output"; + public static final String NATIVE_UNNEST_SPLIT_OUTPUT = "native_unnest_split_output"; + public static final String NATIVE_USE_VELOX_GEOSPATIAL_JOIN = "native_use_velox_geospatial_join"; + public static final String NATIVE_AGGREGATION_COMPACTION_BYTES_THRESHOLD = "native_aggregation_compaction_bytes_threshold"; + public static final String NATIVE_AGGREGATION_COMPACTION_UNUSED_MEMORY_RATIO = "native_aggregation_compaction_unused_memory_ratio"; + private final List> sessionProperties; @Inject @@ -149,7 +162,7 @@ public NativeWorkerSessionPropertyProvider(FeaturesConfig featuresConfig) longProperty( 
NATIVE_WRITER_FLUSH_THRESHOLD_BYTES, "Native Execution only. Minimum memory footprint size required to reclaim memory from a file " + - "writer by flushing its buffered data to disk.", + "writer by flushing its buffered data to disk.", 96L << 20, false), booleanProperty( @@ -211,6 +224,15 @@ public NativeWorkerSessionPropertyProvider(FeaturesConfig featuresConfig) " string means no match for all.", "", true), + stringProperty( + NATIVE_DEBUG_MEMORY_POOL_WARN_THRESHOLD_BYTES, + "Warning threshold in bytes for debug memory pools. When set to a " + + "non-zero value, a warning will be logged once per memory pool when " + + "allocations cause the pool to exceed this threshold. This is useful for " + + "identifying memory usage patterns during debugging. A value of " + + "0 means no warning threshold is enforced.", + "0B", + true), booleanProperty( NATIVE_SELECTIVE_NIMBLE_READER_ENABLED, "Temporary flag to control whether selective Nimble reader should be " + @@ -218,6 +240,29 @@ public NativeWorkerSessionPropertyProvider(FeaturesConfig featuresConfig) "reader is fully rolled out.", false, !nativeExecution), + booleanProperty( + NATIVE_ROW_SIZE_TRACKING_ENABLED, + "Flag to control whether row size tracking should be enabled as a fallback " + + "for reader row size estimates.", + true, + !nativeExecution), + longProperty( + NATIVE_PREFERRED_OUTPUT_BATCH_BYTES, + "Preferred memory budget for operator output batches. " + + "Used in tandem with average row size estimates when available.", + 10L << 20, + !nativeExecution), + integerProperty( + NATIVE_PREFERRED_OUTPUT_BATCH_ROWS, + "Preferred row count per operator output batch. 
Used when average row size estimates are unknown.", + 1024, + !nativeExecution), + integerProperty( + NATIVE_MAX_OUTPUT_BATCH_ROWS, + "Upper bound for row count per output batch, used together with " + + "preferred_output_batch_bytes and average row size estimates.", + 10000, + !nativeExecution), longProperty( NATIVE_MAX_PARTIAL_AGGREGATION_MEMORY, "The max partial aggregation memory when data reduction is not optimal.", @@ -241,8 +286,8 @@ public NativeWorkerSessionPropertyProvider(FeaturesConfig featuresConfig) "Base dir of a query to store tracing data.", "", !nativeExecution), - stringProperty(NATIVE_QUERY_TRACE_NODE_IDS, - "A comma-separated list of plan node ids whose input data will be traced. Empty string if only want to trace the query metadata.", + stringProperty(NATIVE_QUERY_TRACE_NODE_ID, + "The plan node id whose input data will be traced.", "", !nativeExecution), longProperty(NATIVE_QUERY_TRACE_MAX_BYTES, @@ -254,7 +299,7 @@ public NativeWorkerSessionPropertyProvider(FeaturesConfig featuresConfig) "", !nativeExecution), stringProperty(NATIVE_QUERY_TRACE_FRAGMENT_ID, - "The fragment id of the traced task.", + "The fragment id of the traced task.", "", !nativeExecution), stringProperty(NATIVE_QUERY_TRACE_SHARD_ID, @@ -350,6 +395,61 @@ public NativeWorkerSessionPropertyProvider(FeaturesConfig featuresConfig) NATIVE_REQUEST_DATA_SIZES_MAX_WAIT_SEC, "Maximum wait time for exchange long poll requests in seconds.", 10, + !nativeExecution), + integerProperty( + NATIVE_QUERY_MEMORY_RECLAIMER_PRIORITY, + "Native Execution only. Priority of memory reclaimer when deciding on memory pool to abort." 
+ + "Lower value has higher priority and less likely to be choosen for memory pool abort", + 2147483647, + !nativeExecution), + integerProperty( + NATIVE_MAX_NUM_SPLITS_LISTENED_TO, + "Maximum number of splits to listen to per table scan node per worker.", + 0, + !nativeExecution), + + integerProperty( + NATIVE_INDEX_LOOKUP_JOIN_MAX_PREFETCH_BATCHES, + "Specifies the max number of input batches to prefetch to do index lookup ahead. " + + "If it is zero, then process one input batch at a time.", + 0, + !nativeExecution), + booleanProperty( + NATIVE_INDEX_LOOKUP_JOIN_SPLIT_OUTPUT, + "If this is true, then the index join operator might split output for each input " + + "batch based on the output batch size control. Otherwise, it tries to produce a " + + "single output for each input batch.", + true, + !nativeExecution), + booleanProperty( + NATIVE_UNNEST_SPLIT_OUTPUT, + "If this is true, then the unnest operator might split output for each input " + + "batch based on the output batch size control. Otherwise, it produces a single " + + "output for each input batch.", + true, + !nativeExecution), + booleanProperty( + NATIVE_USE_VELOX_GEOSPATIAL_JOIN, + "If this is true, then the protocol::SpatialJoinNode is converted to a " + + "velox::core::SpatialJoinNode. Otherwise, it is converted to a " + + "velox::core::NestedLoopJoinNode.", + true, + !nativeExecution), + longProperty( + NATIVE_AGGREGATION_COMPACTION_BYTES_THRESHOLD, + "Memory threshold in bytes for triggering string compaction during " + + "global aggregation. When total string storage exceeds this limit with " + + "high unused memory ratio, compaction is triggered to reclaim dead strings. " + + "Disabled by default (0). NOTE: Currently only applies to approx_most_frequent " + + "aggregate with StringView type during global aggregation.", + 0L, + !nativeExecution), + doubleProperty( + NATIVE_AGGREGATION_COMPACTION_UNUSED_MEMORY_RATIO, + "Ratio of unused (evicted) bytes to total bytes that triggers compaction. 
" + + "The value is in the range of [0, 1). NOTE: Currently only applies to approx_most_frequent " + + "aggregate with StringView type during global aggregation.", + 0.25, !nativeExecution)); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/AesSpillCipher.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/AesSpillCipher.java index 7386bf8a68ede..253c59d7d281f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/AesSpillCipher.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/AesSpillCipher.java @@ -33,8 +33,8 @@ final class AesSpillCipher implements SpillCipher { - // 256-bit AES CBC mode - private static final String CIPHER_NAME = "AES/CBC/PKCS5Padding"; + // 256-bit AES CTR mode + private static final String CIPHER_NAME = "AES/CTR/NoPadding"; private static final int KEY_BITS = 256; private SecretKey key; diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/FileHolder.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/FileHolder.java index 35e320a3e0d86..92ba407300a6b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/FileHolder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/FileHolder.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.spiller; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.io.Closeable; import java.io.IOException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpiller.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpiller.java index f31eb09d1d89f..d833a75e37cd5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpiller.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpiller.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.spiller; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.memory.context.LocalMemoryContext; import com.facebook.presto.operator.SpillContext; @@ -31,8 +32,6 @@ import io.airlift.slice.OutputStreamSliceOutput; import io.airlift.slice.SliceOutput; -import javax.annotation.concurrent.NotThreadSafe; - import java.io.Closeable; import java.io.IOException; import java.io.InputStream; diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpillerFactory.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpillerFactory.java index 6d685b3cee0cf..dce184f249987 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpillerFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/FileSingleStreamSpillerFactory.java @@ -28,9 +28,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.inject.Inject; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; import java.io.IOException; import java.nio.file.DirectoryStream; diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/GenericPartitioningSpiller.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/GenericPartitioningSpiller.java index 11f5a5d7a5ecb..c9388edae15c5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/GenericPartitioningSpiller.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/GenericPartitioningSpiller.java @@ -24,10 +24,9 @@ import com.google.common.io.Closer; import com.google.common.util.concurrent.Futures; import 
com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.ThreadSafe; import it.unimi.dsi.fastutil.ints.IntArrayList; -import javax.annotation.concurrent.ThreadSafe; - import java.io.IOException; import java.util.HashSet; import java.util.Iterator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/GenericSpiller.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/GenericSpiller.java index 4f1d8dec4b3c9..a33350e3a94a1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/GenericSpiller.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/GenericSpiller.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.spiller; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.Type; import com.facebook.presto.memory.context.AggregatedMemoryContext; @@ -21,8 +22,6 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import javax.annotation.concurrent.NotThreadSafe; - import java.io.IOException; import java.util.ArrayList; import java.util.Iterator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/LocalSpillContext.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/LocalSpillContext.java index ab3d46528f9b5..80559fb73baff 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/LocalSpillContext.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/LocalSpillContext.java @@ -15,8 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.operator.SpillContext; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import static com.google.common.base.Preconditions.checkState; diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/LocalTempStorage.java 
b/presto-main-base/src/main/java/com/facebook/presto/spiller/LocalTempStorage.java index 78db62b3c8f88..35f78c1146a08 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/LocalTempStorage.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/LocalTempStorage.java @@ -27,8 +27,7 @@ import com.facebook.presto.spi.storage.TempStorageHandle; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.io.IOException; import java.io.InputStream; @@ -122,8 +121,19 @@ public void remove(TempDataOperationContext context, TempStorageHandle handle) Files.delete(((LocalTempStorageHandle) handle).getFilePath()); } + @Override + public TempStorageHandle getRootDirectoryHandle() + { + return new LocalTempStorageHandle(getNextSpillPath()); + } + @Override public byte[] serializeHandle(TempStorageHandle storageHandle) + { + return LocalTempStorage.serializeHandleStatic(storageHandle); + } + + public static byte[] serializeHandleStatic(TempStorageHandle storageHandle) { URI uri = ((LocalTempStorageHandle) storageHandle).getFilePath().toUri(); return uri.toString().getBytes(UTF_8); @@ -131,6 +141,11 @@ public byte[] serializeHandle(TempStorageHandle storageHandle) @Override public TempStorageHandle deserialize(byte[] serializedStorageHandle) + { + return LocalTempStorage.deserializeStatic(serializedStorageHandle); + } + + public static LocalTempStorageHandle deserializeStatic(byte[] serializedStorageHandle) { String uriString = new String(serializedStorageHandle, UTF_8); try { @@ -193,7 +208,7 @@ private boolean hasEnoughDiskSpace(Path path) } } - private static class LocalTempStorageHandle + public static class LocalTempStorageHandle implements TempStorageHandle { private final Path filePath; @@ -208,6 +223,12 @@ public Path getFilePath() return filePath; } + @Override + public String 
getPathAsString() + { + return filePath.toString(); + } + @Override public String toString() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/NodeSpillConfig.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/NodeSpillConfig.java index 4c7816a428b07..1e8746c2076d1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/NodeSpillConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/NodeSpillConfig.java @@ -14,10 +14,9 @@ package com.facebook.presto.spiller; import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.CompressionCodec; -import io.airlift.units.DataSize; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class NodeSpillConfig { diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/SpillSpaceTracker.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/SpillSpaceTracker.java index 441f925075453..e2bea719ef512 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/SpillSpaceTracker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/SpillSpaceTracker.java @@ -13,19 +13,18 @@ */ package com.facebook.presto.spiller; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.ExceededSpillLimitException; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.weakref.jmx.Managed; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; - +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.ExceededSpillLimitException.exceededLocalLimit; import static com.facebook.presto.operator.Operator.NOT_BLOCKED; import static 
com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.DataSize.succinctBytes; import static java.util.Objects.requireNonNull; @ThreadSafe diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageSingleStreamSpiller.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageSingleStreamSpiller.java index 49634921a69fc..c26dc42454e84 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageSingleStreamSpiller.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageSingleStreamSpiller.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.spiller; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.Session; import com.facebook.presto.common.Page; import com.facebook.presto.common.io.DataOutput; @@ -35,8 +36,6 @@ import com.google.common.util.concurrent.ListeningExecutorService; import io.airlift.slice.InputStreamSliceInput; -import javax.annotation.concurrent.NotThreadSafe; - import java.io.Closeable; import java.io.IOException; import java.io.InputStream; diff --git a/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageSingleStreamSpillerFactory.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageSingleStreamSpillerFactory.java index c5eceeb68203e..7f4266d79f4c9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageSingleStreamSpillerFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageSingleStreamSpillerFactory.java @@ -26,8 +26,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.inject.Inject; - -import javax.annotation.PreDestroy; +import jakarta.annotation.PreDestroy; import java.util.List; import java.util.Optional; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageStandaloneSpillerFactory.java b/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageStandaloneSpillerFactory.java index fc13d4ebf45c0..b52fe088f43f4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageStandaloneSpillerFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/spiller/TempStorageStandaloneSpillerFactory.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.storage.TempStorage; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.storage.TempStorageManager; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java index 69cf25fe60273..f7f6afc493848 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/CloseableSplitSourceProvider.java @@ -15,11 +15,11 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.io.Closeable; import java.util.ArrayList; @@ -33,14 +33,14 @@ public class CloseableSplitSourceProvider { private static final Logger log = Logger.get(CloseableSplitSourceProvider.class); - private final SplitSourceProvider delegate; + private final SplitManager delegate; @GuardedBy("this") private List splitSources = new ArrayList<>(); @GuardedBy("this") private boolean 
closed; - public CloseableSplitSourceProvider(SplitSourceProvider delegate) + public CloseableSplitSourceProvider(SplitManager delegate) { this.delegate = requireNonNull(delegate, "delegate is null"); } @@ -54,6 +54,15 @@ public synchronized SplitSource getSplits(Session session, TableHandle tableHand return splitSource; } + @Override + public synchronized SplitSource getSplits(Session session, TableFunctionHandle tableFunctionHandle) + { + checkState(!closed, "split source provider is closed"); + SplitSource splitSource = delegate.getSplitsForTableFunction(session, tableFunctionHandle); + splitSources.add(splitSource); + return splitSource; + } + @Override public synchronized void close() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkManager.java b/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkManager.java index aee08454aefc9..1f48983f916f7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkManager.java @@ -14,12 +14,17 @@ package com.facebook.presto.split; import com.facebook.presto.Session; +import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.metadata.DistributedProcedureHandle; import com.facebook.presto.metadata.InsertTableHandle; import com.facebook.presto.metadata.OutputTableHandle; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.ConnectorMergeSink; import com.facebook.presto.spi.ConnectorPageSink; import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.PageSinkContext; +import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import java.util.concurrent.ConcurrentHashMap; @@ -46,20 +51,47 @@ public void removeConnectorPageSinkProvider(ConnectorId connectorId) pageSinkProviders.remove(connectorId); } - @Override - public 
ConnectorPageSink createPageSink(Session session, OutputTableHandle tableHandle, PageSinkContext pageSinkContext) + public ConnectorPageSink createPageSink(Session session, OutputTableHandle tableHandle, PageSinkContext pageSinkContext, RuntimeStats runtimeStats) { // assumes connectorId and catalog are the same - ConnectorSession connectorSession = session.toConnectorSession(tableHandle.getConnectorId()); + ConnectorSession connectorSession = session.toConnectorSession(tableHandle.getConnectorId(), runtimeStats); + return providerFor(tableHandle.getConnectorId()).createPageSink(tableHandle.getTransactionHandle(), connectorSession, tableHandle.getConnectorHandle(), pageSinkContext); + } + + public ConnectorPageSink createPageSink(Session session, InsertTableHandle tableHandle, PageSinkContext pageSinkContext, RuntimeStats runtimeStats) + { + // assumes connectorId and catalog are the same + ConnectorSession connectorSession = session.toConnectorSession(tableHandle.getConnectorId(), runtimeStats); return providerFor(tableHandle.getConnectorId()).createPageSink(tableHandle.getTransactionHandle(), connectorSession, tableHandle.getConnectorHandle(), pageSinkContext); } + @Override + public ConnectorPageSink createPageSink(Session session, OutputTableHandle tableHandle, PageSinkContext pageSinkContext) + { + return createPageSink(session, tableHandle, pageSinkContext, null); + } + @Override public ConnectorPageSink createPageSink(Session session, InsertTableHandle tableHandle, PageSinkContext pageSinkContext) + { + return createPageSink(session, tableHandle, pageSinkContext, null); + } + + @Override + public ConnectorMergeSink createMergeSink(Session session, MergeHandle mergeHandle) { // assumes connectorId and catalog are the same + TableHandle tableHandle = mergeHandle.getTableHandle(); ConnectorSession connectorSession = session.toConnectorSession(tableHandle.getConnectorId()); - return 
providerFor(tableHandle.getConnectorId()).createPageSink(tableHandle.getTransactionHandle(), connectorSession, tableHandle.getConnectorHandle(), pageSinkContext); + return providerFor(tableHandle.getConnectorId()).createMergeSink(tableHandle.getTransaction(), connectorSession, mergeHandle.getConnectorMergeTableHandle()); + } + + @Override + public ConnectorPageSink createPageSink(Session session, DistributedProcedureHandle procedureHandle, PageSinkContext pageSinkContext) + { + // assumes connectorId and catalog are the same + ConnectorSession connectorSession = session.toConnectorSession(procedureHandle.getConnectorId()); + return providerFor(procedureHandle.getConnectorId()).createPageSink(procedureHandle.getTransactionHandle(), connectorSession, procedureHandle.getConnectorHandle(), pageSinkContext); } private ConnectorPageSinkProvider providerFor(ConnectorId connectorId) diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkProvider.java index 3e46127870194..e042d5c509207 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/PageSinkProvider.java @@ -14,9 +14,12 @@ package com.facebook.presto.split; import com.facebook.presto.Session; +import com.facebook.presto.metadata.DistributedProcedureHandle; import com.facebook.presto.metadata.InsertTableHandle; import com.facebook.presto.metadata.OutputTableHandle; +import com.facebook.presto.spi.ConnectorMergeSink; import com.facebook.presto.spi.ConnectorPageSink; +import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.PageSinkContext; public interface PageSinkProvider @@ -24,4 +27,11 @@ public interface PageSinkProvider ConnectorPageSink createPageSink(Session session, OutputTableHandle tableHandle, PageSinkContext pageSinkContext); ConnectorPageSink createPageSink(Session session, 
InsertTableHandle tableHandle, PageSinkContext pageSinkContext); + + /* + * Used to write the result of SQL MERGE to an existing table + */ + ConnectorMergeSink createMergeSink(Session session, MergeHandle mergeHandle); + + ConnectorPageSink createPageSink(Session session, DistributedProcedureHandle procedureHandle, PageSinkContext pageSinkContext); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/RemoteSplit.java b/presto-main-base/src/main/java/com/facebook/presto/split/RemoteSplit.java index c91d7f1e38395..b87429a880615 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/RemoteSplit.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/RemoteSplit.java @@ -13,6 +13,9 @@ */ package com.facebook.presto.split; +import com.facebook.drift.annotations.ThriftConstructor; +import com.facebook.drift.annotations.ThriftField; +import com.facebook.drift.annotations.ThriftStruct; import com.facebook.presto.execution.Location; import com.facebook.presto.execution.TaskId; import com.facebook.presto.spi.ConnectorSplit; @@ -29,6 +32,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; +@ThriftStruct public class RemoteSplit implements ConnectorSplit { @@ -36,6 +40,7 @@ public class RemoteSplit private final TaskId remoteSourceTaskId; @JsonCreator + @ThriftConstructor public RemoteSplit(@JsonProperty("location") Location location, @JsonProperty("remoteSourceTaskId") TaskId remoteSourceTaskId) { this.location = requireNonNull(location, "location is null"); @@ -43,12 +48,14 @@ public RemoteSplit(@JsonProperty("location") Location location, @JsonProperty("r } @JsonProperty + @ThriftField(1) public Location getLocation() { return location; } @JsonProperty + @ThriftField(2) public TaskId getRemoteSourceTaskId() { return remoteSourceTaskId; diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/SampledSplitSource.java 
b/presto-main-base/src/main/java/com/facebook/presto/split/SampledSplitSource.java index 02c65e7198e6b..d83e3f0e47451 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/SampledSplitSource.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/SampledSplitSource.java @@ -19,8 +19,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.concurrent.ThreadLocalRandom; diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java b/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java index adb189379ed36..d334e5a5e6a70 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/SplitManager.java @@ -18,6 +18,7 @@ import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.metadata.TableLayoutResult; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorSession; @@ -29,8 +30,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingContext; import com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; @@ -105,4 +105,17 @@ private ConnectorSplitManager getConnectorSplitManager(ConnectorId connectorId) checkArgument(result != null, "No split manager for connector '%s'", connectorId); return result; } + + public SplitSource 
getSplitsForTableFunction(Session session, TableFunctionHandle function) + { + ConnectorId connectorId = function.getConnectorId(); + ConnectorSplitManager splitManager = splitManagers.get(connectorId); + + ConnectorSplitSource source = splitManager.getSplits( + function.getTransactionHandle(), + session.toConnectorSession(connectorId), + function.getFunctionHandle()); + + return new ConnectorAwareSplitSource(connectorId, function.getTransactionHandle(), source); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java b/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java index 617fba7093613..30b54174c27b6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/split/SplitSourceProvider.java @@ -14,6 +14,7 @@ package com.facebook.presto.split; import com.facebook.presto.Session; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy; @@ -21,4 +22,5 @@ public interface SplitSourceProvider { SplitSource getSplits(Session session, TableHandle tableHandle, SplitSchedulingStrategy splitSchedulingStrategy, WarningCollector warningCollector); + SplitSource getSplits(Session session, TableFunctionHandle tableFunctionHandle); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/MaterializedViewUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/MaterializedViewUtils.java index eea0746d1841d..dcc80d61ab1ad 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/MaterializedViewUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/MaterializedViewUtils.java @@ -26,22 +26,27 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import 
com.facebook.presto.spi.security.Identity; import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.planner.ExpressionDomainTranslator; import com.facebook.presto.sql.planner.LiteralEncoder; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; import com.facebook.presto.sql.tree.BooleanLiteral; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.IsNullPredicate; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -51,15 +56,18 @@ import java.util.Set; import java.util.Stack; +import static com.facebook.presto.SystemSessionProperties.isLegacyMaterializedViews; import static com.facebook.presto.common.predicate.TupleDomain.extractFixedValues; import static com.facebook.presto.common.type.StandardTypes.HYPER_LOG_LOG; import static com.facebook.presto.common.type.StandardTypes.VARBINARY; +import static com.facebook.presto.sql.ExpressionUtils.combineDisjuncts; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.tree.ArithmeticBinaryExpression.Operator.DIVIDE; import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL; import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; import 
static com.facebook.presto.sql.tree.LogicalBinaryExpression.Operator.AND; import static com.facebook.presto.sql.tree.LogicalBinaryExpression.Operator.OR; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -88,9 +96,17 @@ private MaterializedViewUtils() {} public static Session buildOwnerSession(Session session, Optional owner, SessionPropertyManager sessionPropertyManager, String catalog, String schema) { + // When legacy_materialized_views=false, owner must be present for REFRESH operations + if (!isLegacyMaterializedViews(session) && !owner.isPresent()) { + throw new IllegalStateException( + "Materialized view owner is required when legacy_materialized_views=false. " + + "This indicates a materialized view created before the security mode feature was added. " + + "Set session property legacy_materialized_views=true to refresh this view, or drop and recreate the view."); + } + Identity identity = getOwnerIdentity(owner, session); - return Session.builder(sessionPropertyManager) + Session.SessionBuilder builder = Session.builder(sessionPropertyManager) .setQueryId(session.getQueryId()) .setTransactionId(session.getTransactionId().orElse(null)) .setIdentity(identity) @@ -102,14 +118,19 @@ public static Session buildOwnerSession(Session session, Optional owner, .setRemoteUserAddress(session.getRemoteUserAddress().orElse(null)) .setUserAgent(session.getUserAgent().orElse(null)) .setClientInfo(session.getClientInfo().orElse(null)) - .setStartTime(session.getStartTime()) - .build(); + .setStartTime(session.getStartTime()); + + for (Map.Entry property : session.getSystemProperties().entrySet()) { + builder.setSystemProperty(property.getKey(), property.getValue()); + } + + return builder.build(); } public static Identity getOwnerIdentity(Optional owner, Session session) { if (owner.isPresent() && 
!owner.get().equals(session.getIdentity().getUser())) { - return new Identity(owner.get(), Optional.empty(), session.getIdentity().getExtraCredentials()); + return new Identity(owner.get(), Optional.empty(), ImmutableMap.of(), session.getIdentity().getExtraCredentials(), ImmutableMap.of(), Optional.empty(), session.getIdentity().getReasonForSelect(), ImmutableList.of()); } return session.getIdentity(); } @@ -314,4 +335,79 @@ public boolean validate(Identifier baseTableColumn, Map return baseToViewColumnMap.containsKey(new Cast(new FunctionCall(APPROX_SET, ImmutableList.of(baseTableColumn)), VARBINARY)); } } + + /** + * Generate WHERE predicates for missing partitions from MaterializedDataPredicates. + * Used for auto-refresh of materialized views without explicit WHERE clause. + */ + public static Map generatePredicatesForMissingPartitions( + Map missingPartitionsPerTable, + Metadata metadata) + { + Map predicates = new HashMap<>(); + + for (Map.Entry entry : + missingPartitionsPerTable.entrySet()) { + SchemaTableName tableName = entry.getKey(); + MaterializedViewStatus.MaterializedDataPredicates missingPartitions = entry.getValue(); + + Expression predicate = convertMaterializedDataPredicatesToExpression(missingPartitions, metadata); + + predicates.put(tableName, predicate); + } + + return predicates; + } + + /** + * Convert MaterializedDataPredicates to a SQL Expression tree. + * Builds an OR expression of partition predicates, where each partition is an AND expression of column filters. 
+ */ + public static Expression convertMaterializedDataPredicatesToExpression( + MaterializedViewStatus.MaterializedDataPredicates predicates, + Metadata metadata) + { + List columnNames = predicates.getColumnNames(); + List> predicateDisjuncts = predicates.getPredicateDisjuncts(); + + ExpressionDomainTranslator translator = new ExpressionDomainTranslator( + new LiteralEncoder(metadata.getBlockEncodingSerde())); + + List disjuncts = new ArrayList<>(); + + for (TupleDomain tupleDomain : predicateDisjuncts) { + checkState(!tupleDomain.isAll(), "TupleDomain.isAll() should not appear in MaterializedDataPredicates"); + if (tupleDomain.isNone()) { + continue; + } + + Expression conjunction = translator.toPredicate(tupleDomain); + conjunction = convertSymbolReferencesToIdentifiers(conjunction); + + disjuncts.add(conjunction); + } + + if (disjuncts.isEmpty()) { + throw new IllegalStateException("No predicates generated for missing partitions"); + } + + if (disjuncts.size() == 1) { + return disjuncts.get(0); + } + else { + return combineDisjuncts(disjuncts); + } + } + + private static Expression convertSymbolReferencesToIdentifiers(Expression expression) + { + return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() + { + @Override + public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter treeRewriter) + { + return new Identifier(node.getName()); + } + }, expression); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/Optimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/Optimizer.java index 7356b7d464acf..7a0337b134c3a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/Optimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/Optimizer.java @@ -47,6 +47,8 @@ import static com.facebook.presto.SystemSessionProperties.isPrintStatsForNonJoinQuery; import static com.facebook.presto.SystemSessionProperties.isVerboseOptimizerInfoEnabled; import static 
com.facebook.presto.SystemSessionProperties.isVerboseOptimizerResults; +import static com.facebook.presto.common.RuntimeMetricName.VALIDATE_FINAL_PLAN_TIME_NANOS; +import static com.facebook.presto.common.RuntimeMetricName.VALIDATE_INTERMEDIATE_PLAN_TIME_NANOS; import static com.facebook.presto.common.RuntimeUnit.NANO; import static com.facebook.presto.spi.StandardErrorCode.QUERY_PLANNING_TIMEOUT; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED; @@ -100,7 +102,7 @@ public Optimizer( public Plan validateAndOptimizePlan(PlanNode root, PlanStage stage) { - planChecker.validateIntermediatePlan(root, session, metadata, warningCollector); + validateIntermediatePlanWithRuntimeStats(root); boolean enableVerboseRuntimeStats = SystemSessionProperties.isVerboseRuntimeStatsEnabled(session); if (stage.ordinal() >= OPTIMIZED.ordinal()) { @@ -120,16 +122,27 @@ public Plan validateAndOptimizePlan(PlanNode root, PlanStage stage) root = optimizerResult.getPlanNode(); } } - if (stage.ordinal() >= OPTIMIZED_AND_VALIDATED.ordinal()) { // make sure we produce a valid plan after optimizations run. 
This is mainly to catch programming errors - planChecker.validateFinalPlan(root, session, metadata, warningCollector); + validateFinalPlanWithRuntimeStats(root); } TypeProvider types = TypeProvider.viewOf(variableAllocator.getVariables()); return new Plan(root, types, computeStats(root, types)); } + private void validateIntermediatePlanWithRuntimeStats(PlanNode root) + { + session.getRuntimeStats().recordWallAndCpuTime(VALIDATE_INTERMEDIATE_PLAN_TIME_NANOS, + () -> planChecker.validateIntermediatePlan(root, session, metadata, warningCollector)); + } + + private void validateFinalPlanWithRuntimeStats(PlanNode root) + { + session.getRuntimeStats().recordWallAndCpuTime(VALIDATE_FINAL_PLAN_TIME_NANOS, + () -> planChecker.validateFinalPlan(root, session, metadata, warningCollector)); + } + private StatsAndCosts computeStats(PlanNode root, TypeProvider types) { if (explain || isPrintStatsForNonJoinQuery(session) || diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/Serialization.java b/presto-main-base/src/main/java/com/facebook/presto/sql/Serialization.java index 85ef11e4cd0e0..0eac23cf5de94 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/Serialization.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/Serialization.java @@ -25,8 +25,7 @@ import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.KeyDeserializer; import com.fasterxml.jackson.databind.SerializerProvider; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.IOException; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/SqlEnvironmentConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/SqlEnvironmentConfig.java index bcb02475c52ba..04c2c1b4779f0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/SqlEnvironmentConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/SqlEnvironmentConfig.java @@ -16,9 +16,8 
@@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.presto.common.type.TimeZoneKey; - -import javax.annotation.Nullable; -import javax.validation.constraints.NotNull; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.NotNull; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java b/presto-main-base/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java index 7108270078d07..dd46653509bfa 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/TemporaryTableUtil.java @@ -139,7 +139,11 @@ public static TableScanNode createTemporaryTableScan( cteId.map(CteMaterializationInfo::new)); } - public static Map assignTemporaryTableColumnNames(Collection outputVariables, + public static Map assignTemporaryTableColumnNames( + Metadata metadata, + Session session, + String catalogName, + Collection outputVariables, Collection constantPartitioningVariables) { ImmutableMap.Builder result = ImmutableMap.builder(); @@ -147,7 +151,7 @@ public static Map assignTemporaryTa for (VariableReferenceExpression outputVariable : concat(outputVariables, constantPartitioningVariables)) { String columnName = format("_c%d_%s", column, outputVariable.getName()); result.put(outputVariable, ColumnMetadata.builder() - .setName(columnName) + .setName(metadata.normalizeIdentifier(session, catalogName, columnName)) .setType(outputVariable.getType()) .build()); column++; @@ -155,9 +159,13 @@ public static Map assignTemporaryTa return result.build(); } - public static Map assignTemporaryTableColumnNames(Collection outputVariables) + public static Map assignTemporaryTableColumnNames( + Metadata metadata, + Session session, + String catalogName, + Collection outputVariables) { - return assignTemporaryTableColumnNames(outputVariables, 
Collections.emptyList()); + return assignTemporaryTableColumnNames(metadata, session, catalogName, outputVariables, Collections.emptyList()); } public static BasePlanFragmenter.PartitioningVariableAssignments assignPartitioningVariables(VariableAllocator variableAllocator, @@ -194,7 +202,7 @@ public static TableFinishNode createTemporaryTableWriteWithoutExchanges( Optional cteId) { SchemaTableName schemaTableName = metadata.getTableMetadata(session, tableHandle).getTable(); - TableWriterNode.InsertReference insertReference = new TableWriterNode.InsertReference(tableHandle, schemaTableName); + TableWriterNode.InsertReference insertReference = new TableWriterNode.InsertReference(tableHandle, schemaTableName, Optional.empty()); List outputColumnNames = outputs.stream() .map(variableToColumnMap::get) .map(ColumnMetadata::getName) @@ -292,7 +300,7 @@ public static TableFinishNode createTemporaryTableWriteWithExchanges( .collect(Collectors.toSet()); SchemaTableName schemaTableName = metadata.getTableMetadata(session, tableHandle).getTable(); - TableWriterNode.InsertReference insertReference = new TableWriterNode.InsertReference(tableHandle, schemaTableName); + TableWriterNode.InsertReference insertReference = new TableWriterNode.InsertReference(tableHandle, schemaTableName, Optional.empty()); PartitioningScheme partitioningScheme = new PartitioningScheme( Partitioning.create(partitioningHandle, partitioningVariables), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java index aff01ca92ec18..edf2341a77817 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/AggregationAnalyzer.java @@ -68,8 +68,7 @@ import com.facebook.presto.sql.tree.WindowFrame; import com.google.common.collect.ImmutableList; import 
com.google.common.collect.Multimap; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Collection; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java index d4cb14ac2ac68..10d6c37bdec79 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/Analyzer.java @@ -14,8 +14,10 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.parser.SqlParser; @@ -36,6 +38,7 @@ import static com.facebook.presto.SystemSessionProperties.isCheckAccessControlOnUtilizedColumnsOnly; import static com.facebook.presto.SystemSessionProperties.isCheckAccessControlWithSubfields; +import static com.facebook.presto.SystemSessionProperties.isLegacyMaterializedViews; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractAggregateFunctions; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractExpressions; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.extractExternalFunctions; @@ -43,7 +46,6 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CANNOT_HAVE_AGGREGATIONS_WINDOWS_OR_GROUPING; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.UtilizedColumnsAnalyzer.analyzeForUtilizedColumns; -import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; import static 
java.util.Objects.requireNonNull; public class Analyzer @@ -58,6 +60,7 @@ public class Analyzer private final WarningCollector warningCollector; private final MetadataExtractor metadataExtractor; private final String query; + private final ViewDefinitionReferences viewDefinitionReferences; public Analyzer( Session session, @@ -68,9 +71,10 @@ public Analyzer( List parameters, Map, Expression> parameterLookup, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { - this(session, metadata, sqlParser, accessControl, queryExplainer, parameters, parameterLookup, warningCollector, Optional.empty(), query); + this(session, metadata, sqlParser, accessControl, queryExplainer, parameters, parameterLookup, warningCollector, Optional.empty(), query, viewDefinitionReferences); } public Analyzer( @@ -83,7 +87,8 @@ public Analyzer( Map, Expression> parameterLookup, WarningCollector warningCollector, Optional metadataExtractorExecutor, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); @@ -96,34 +101,36 @@ public Analyzer( requireNonNull(metadataExtractorExecutor, "metadataExtractorExecutor is null"); this.metadataExtractor = new MetadataExtractor(session, metadata, metadataExtractorExecutor, sqlParser, warningCollector); this.query = requireNonNull(query, "query is null"); + this.viewDefinitionReferences = requireNonNull(viewDefinitionReferences, "viewDefinitionReferences is null"); } - public Analysis analyze(Statement statement) + public Analysis analyzeSemantic(Statement statement, boolean isDescribe) { - return analyze(statement, false); + return analyzeSemantic(statement, Optional.empty(), isDescribe); } - // TODO: Remove this method once all calls are moved to analyzer interface, as this call is overloaded with analyze and columnCheckPermissions - public 
Analysis analyze(Statement statement, boolean isDescribe) + public Analysis analyzeSemantic( + Statement statement, + Optional procedureName, + boolean isDescribe) { - Analysis analysis = analyzeSemantic(statement, isDescribe); - checkAccessPermissions(analysis.getAccessControlReferences(), query); - return analysis; - } - - public Analysis analyzeSemantic(Statement statement, boolean isDescribe) - { - Statement rewrittenStatement = StatementRewrite.rewrite(session, metadata, sqlParser, queryExplainer, statement, parameters, parameterLookup, accessControl, warningCollector, query); - Analysis analysis = new Analysis(rewrittenStatement, parameterLookup, isDescribe); + Statement rewrittenStatement = StatementRewrite.rewrite(session, metadata, sqlParser, queryExplainer, statement, parameters, parameterLookup, accessControl, warningCollector, query, viewDefinitionReferences); + Analysis analysis = new Analysis(rewrittenStatement, parameterLookup, isDescribe, viewDefinitionReferences); metadataExtractor.populateMetadataHandle(session, rewrittenStatement, analysis.getMetadataHandle()); + analysis.setProcedureName(procedureName); StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session, warningCollector); analyzer.analyze(rewrittenStatement, Optional.empty()); analyzeForUtilizedColumns(analysis, analysis.getStatement(), warningCollector); - analysis.populateTableColumnAndSubfieldReferencesForAccessControl(isCheckAccessControlOnUtilizedColumnsOnly(session), isCheckAccessControlWithSubfields(session)); + analysis.populateTableColumnAndSubfieldReferencesForAccessControl(isCheckAccessControlOnUtilizedColumnsOnly(session), isCheckAccessControlWithSubfields(session), isLegacyMaterializedViews(session)); return analysis; } + public ViewDefinitionReferences getViewDefinitionReferences() + { + return viewDefinitionReferences; + } + static void verifyNoAggregateWindowOrGroupingFunctions( Map, FunctionHandle> functionHandles, 
FunctionAndTypeResolver functionAndTypeResolver, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/AnalyzerProviderManager.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/AnalyzerProviderManager.java index a7048b602b273..03dd11d355a0a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/AnalyzerProviderManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/AnalyzerProviderManager.java @@ -15,8 +15,7 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.analyzer.AnalyzerProvider; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.HashMap; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalyzer.java index 9308927b49c6f..aa2002e7cfd88 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryAnalyzer.java @@ -21,6 +21,7 @@ import com.facebook.presto.spi.analyzer.MetadataResolver; import com.facebook.presto.spi.analyzer.QueryAnalysis; import com.facebook.presto.spi.analyzer.QueryAnalyzer; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.security.AccessControl; @@ -90,9 +91,13 @@ public QueryAnalysis analyze(AnalyzerContext analyzerContext, PreparedQuery prep parameterExtractor(builtInPreparedQuery.getStatement(), builtInPreparedQuery.getParameters()), session.getWarningCollector(), Optional.of(metadataExtractorExecutor), - analyzerContext.getQuery()); + analyzerContext.getQuery(), + new ViewDefinitionReferences()); - Analysis analysis = 
analyzer.analyzeSemantic(((BuiltInQueryPreparer.BuiltInPreparedQuery) preparedQuery).getStatement(), false); + Analysis analysis = analyzer.analyzeSemantic( + ((BuiltInQueryPreparer.BuiltInPreparedQuery) preparedQuery).getStatement(), + ((BuiltInQueryPreparer.BuiltInPreparedQuery) preparedQuery).getDistributedProcedureName(), + false); return new BuiltInQueryAnalysis(analysis); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparerProvider.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparerProvider.java index ab0abef0c9dc8..3e1ea38e8f632 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparerProvider.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/BuiltInQueryPreparerProvider.java @@ -15,8 +15,7 @@ import com.facebook.presto.spi.analyzer.QueryPreparer; import com.facebook.presto.spi.analyzer.QueryPreparerProvider; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java index d3ca87f4ac72e..a926c0243fdd2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionAnalyzer.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.Session; -import com.facebook.presto.UnknownTypeException; import com.facebook.presto.common.ErrorCode; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.Subfield; @@ -30,6 +29,7 @@ import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.Type; +import 
com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.common.type.TypeSignatureParameter; import com.facebook.presto.common.type.TypeUtils; import com.facebook.presto.common.type.TypeWithName; @@ -40,12 +40,14 @@ import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.DenyAllAccessControl; +import com.facebook.presto.spi.type.UnknownTypeException; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.relational.FunctionResolution; @@ -119,8 +121,7 @@ import com.google.common.collect.Multimap; import com.google.common.collect.Sets; import io.airlift.slice.SliceUtf8; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.ArrayList; import java.util.LinkedHashMap; @@ -154,6 +155,7 @@ import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.spi.StandardErrorCode.OPERATOR_NOT_FOUND; +import static com.facebook.presto.spi.StandardWarningCode.PERFORMANCE_WARNING; import static com.facebook.presto.spi.StandardWarningCode.SEMANTIC_WARNING; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions; @@ -250,6 +252,8 @@ public class ExpressionAnalyzer // This contains types of variables referenced from outer 
scopes. private final Map, Type> outerScopeSymbolTypes; + private final List sourceFields = new ArrayList<>(); + private ExpressionAnalyzer( FunctionAndTypeResolver functionAndTypeResolver, Function statementAnalyzerFactory, @@ -381,6 +385,11 @@ public Multimap getTableColumnAndSubfieldReferenc return tableColumnAndSubfieldReferences; } + public List getSourceFields() + { + return sourceFields; + } + public Multimap getTableColumnAndSubfieldReferencesForAccessControl() { return tableColumnAndSubfieldReferencesForAccessControl; @@ -470,7 +479,7 @@ protected Type visitSymbolReference(SymbolReference node, StackableAstVisitorCon @Override protected Type visitIdentifier(Identifier node, StackableAstVisitorContext context) { - QualifiedName name = QualifiedName.of(node.getValue()); + QualifiedName name = QualifiedName.of(ImmutableList.of(new Identifier(node.getValue(), node.isDelimited()))); Optional resolvedField = context.getContext().getScope().tryResolveField(node, name); if (!resolvedField.isPresent() && outerScopeSymbolTypes.containsKey(NodeRef.of(node))) { return setExpressionType(node, outerScopeSymbolTypes.get(NodeRef.of(node))); @@ -496,6 +505,8 @@ private Type handleResolvedField(Expression node, FieldId fieldId, Field field, } } + sourceFields.add(field); + // If we found a direct column reference, and we will put it in tableColumnReferencesWithSubFields if (isTopMostReference(node, context)) { Optional tableName = field.getOriginTable(); @@ -1110,6 +1121,10 @@ else if (frame.getType() == GROUPS) { FunctionHandle function = resolveFunction(sessionFunctions, transactionId, node, argumentTypes, functionAndTypeResolver); FunctionMetadata functionMetadata = functionAndTypeResolver.getFunctionMetadata(function); + // Delegate function-specific validation to the FunctionNamespaceManager + // This allows function namespaces to perform custom validation + functionAndTypeResolver.validateFunctionCall(function, node.getArguments()); + if 
(node.getOrderBy().isPresent()) { for (SortItem sortItem : node.getOrderBy().get().getSortItems()) { Type sortKeyType = process(sortItem.getSortKey(), context); @@ -1119,6 +1134,25 @@ else if (frame.getType() == GROUPS) { } } + List arguments = functionMetadata.getArgumentTypes(); + String functionName = functionMetadata.getName().toString(); + + if (!argumentTypes.isEmpty() && "map".equals(arguments.get(0).getBase()) && + "map_filter".equalsIgnoreCase(functionMetadata.getName().getObjectName()) && + arguments.size() > 1 && node.getArguments().size() >= 2) { + Expression mapArg = node.getArguments().get(0); + Expression lambdaArg = node.getArguments().get(1); + + if (containsFeatures(mapArg) && lambdaArg instanceof LambdaExpression) { + LambdaExpression lambda = (LambdaExpression) lambdaArg; + if (lambda.getArguments().size() == 2 && isKeyOnlyMembershipFilter(lambda)) { + String warningMessage = createWarningMessage(node, + String.format("Function '%s' uses a lambda on large maps which is expensive. Consider using map_subset", functionName)); + warningCollector.add(new PrestoWarning(PERFORMANCE_WARNING, warningMessage)); + } + } + } + if (node.isIgnoreNulls() && node.getWindow().isPresent()) { if (!functionResolution.isWindowValueFunction(function)) { String warningMessage = createWarningMessage(node, "IGNORE NULLS is not used for aggregate and ranking window functions. 
This will cause queries to fail in future versions."); @@ -1178,6 +1212,100 @@ private String createWarningMessage(Node node, String message) } } + private boolean isKeyOnlyMembershipFilter(LambdaExpression lambda) + { + String valueArgName = lambda.getArguments().get(1).getName().getValue(); + Expression body = lambda.getBody(); + + if (expressionReferencesName(body, valueArgName)) { + return false; + } + + return isSimpleKeyEquality(body); + } + + private boolean expressionReferencesName(Expression expression, String name) + { + if (expression == null) { + return false; + } + if (expression instanceof Identifier) { + return ((Identifier) expression).getValue().equalsIgnoreCase(name); + } + if (expression instanceof ComparisonExpression) { + ComparisonExpression comp = (ComparisonExpression) expression; + return expressionReferencesName(comp.getLeft(), name) || expressionReferencesName(comp.getRight(), name); + } + if (expression instanceof LogicalBinaryExpression) { + LogicalBinaryExpression logical = (LogicalBinaryExpression) expression; + return expressionReferencesName(logical.getLeft(), name) || expressionReferencesName(logical.getRight(), name); + } + if (expression instanceof InPredicate) { + InPredicate inPred = (InPredicate) expression; + return expressionReferencesName(inPred.getValue(), name) || expressionReferencesName(inPred.getValueList(), name); + } + if (expression instanceof InListExpression) { + InListExpression inList = (InListExpression) expression; + for (Expression value : inList.getValues()) { + if (expressionReferencesName(value, name)) { + return true; + } + } + } + if (expression instanceof ArithmeticBinaryExpression) { + ArithmeticBinaryExpression arith = (ArithmeticBinaryExpression) expression; + return expressionReferencesName(arith.getLeft(), name) || expressionReferencesName(arith.getRight(), name); + } + if (expression instanceof FunctionCall) { + FunctionCall func = (FunctionCall) expression; + for (Expression arg : 
func.getArguments()) { + if (expressionReferencesName(arg, name)) { + return true; + } + } + } + // Literals don't reference any names + return false; + } + + private boolean containsFeatures(Expression expression) + { + if (expression instanceof Identifier) { + return ((Identifier) expression).getValue().toLowerCase().contains("features"); + } + if (expression instanceof SymbolReference) { + return ((SymbolReference) expression).getName().toLowerCase().contains("features"); + } + if (expression instanceof DereferenceExpression) { + DereferenceExpression deref = (DereferenceExpression) expression; + return containsFeatures(deref.getBase()) || deref.getField().getValue().toLowerCase().contains("features"); + } + return false; + } + + private boolean isSimpleKeyEquality(Expression expression) + { + if (expression instanceof ComparisonExpression) { + ComparisonExpression comparison = (ComparisonExpression) expression; + return comparison.getOperator() == ComparisonExpression.Operator.EQUAL; + } + if (expression instanceof InPredicate) { + return true; + } + if (expression instanceof LogicalBinaryExpression) { + LogicalBinaryExpression logical = (LogicalBinaryExpression) expression; + if (logical.getOperator() == LogicalBinaryExpression.Operator.OR) { + return isSimpleKeyEquality(logical.getLeft()) && isSimpleKeyEquality(logical.getRight()); + } + } + if (expression instanceof FunctionCall) { + FunctionCall func = (FunctionCall) expression; + String funcName = func.getName().toString(); + return funcName.equalsIgnoreCase("contains") || funcName.equalsIgnoreCase("presto.default.contains"); + } + return false; + } + private void analyzeFrameRangeOffset(Expression offsetValue, FrameBound.Type boundType, StackableAstVisitorContext context, Window window) { if (!window.getOrderBy().isPresent()) { @@ -1423,7 +1551,7 @@ else if (previousNode instanceof QuantifiedComparisonExpression) { else { scalarSubqueries.add(NodeRef.of(node)); } - + 
sourceFields.add(queryScope.getRelationType().getFieldByIndex(0)); Type type = getOnlyElement(queryScope.getRelationType().getVisibleFields()).getType(); return setExpressionType(node, type); } @@ -1513,7 +1641,8 @@ protected Type visitLambdaExpression(LambdaExpression node, StackableAstVisitorC fieldToLambdaArgumentDeclaration.putAll(context.getContext().getFieldToLambdaArgumentDeclaration()); } for (LambdaArgumentDeclaration lambdaArgument : lambdaArguments) { - ResolvedField resolvedField = lambdaScope.resolveField(lambdaArgument, QualifiedName.of(lambdaArgument.getName().getValue())); + QualifiedName name = QualifiedName.of(ImmutableList.of(new Identifier(lambdaArgument.getName().getValue(), lambdaArgument.getName().isDelimited()))); + ResolvedField resolvedField = lambdaScope.resolveField(lambdaArgument, name); fieldToLambdaArgumentDeclaration.put(FieldId.from(resolvedField), lambdaArgument); } @@ -1893,7 +2022,7 @@ public static ExpressionAnalysis analyzeExpressions( { // expressions at this point can not have sub queries so deny all access checks // in the future, we will need a full access controller here to verify access to functions - Analysis analysis = new Analysis(null, parameters, isDescribe); + Analysis analysis = new Analysis(null, parameters, isDescribe, new ViewDefinitionReferences()); ExpressionAnalyzer analyzer = create(analysis, session, metadata, sqlParser, new DenyAllAccessControl(), types, warningCollector); for (Expression expression : expressions) { analyzer.analyze(expression, Scope.builder().withRelationType(RelationId.anonymous(), new RelationType()).build()); @@ -1962,6 +2091,8 @@ public static ExpressionAnalysis analyzeExpression( analyzer.getTableColumnAndSubfieldReferences(), analyzer.getTableColumnAndSubfieldReferencesForAccessControl()); + analysis.addExpressionFields(expression, analyzer.getSourceFields()); + return new ExpressionAnalysis( expressionTypes, expressionCoercions, diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java index 7ad6060e5a90d..048e57ee7138d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.sql.analyzer; -import com.facebook.presto.UnknownTypeException; import com.facebook.presto.common.type.EnumType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeWithName; import com.facebook.presto.spi.SourceLocation; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.UnknownTypeException; import com.facebook.presto.sql.tree.ArrayConstructor; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.ComparisonExpression; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 496e9b31ebaf9..34d047d0b7b84 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -17,38 +17,44 @@ import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.DefunctConfig; import com.facebook.airlift.configuration.LegacyConfig; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MaxDataSize; import com.facebook.presto.CompressionCodec; import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.resourceGroups.QueryType; +import 
com.facebook.presto.spi.MaterializedViewStaleReadBehavior; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.FunctionMetadata; -import com.facebook.presto.sql.tree.CreateView; +import com.facebook.presto.spi.security.ViewSecurity; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; -import io.airlift.units.MaxDataSize; - -import javax.validation.constraints.AssertTrue; -import javax.validation.constraints.DecimalMax; -import javax.validation.constraints.DecimalMin; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.AssertTrue; +import jakarta.validation.constraints.DecimalMax; +import jakarta.validation.constraints.DecimalMin; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.nio.file.Path; import java.nio.file.Paths; import java.util.List; +import java.util.stream.Stream; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_SESSION_PROPERTY; +import static com.facebook.presto.spi.security.ViewSecurity.DEFINER; import static com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationPartitioningMergingStrategy.LEGACY; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinNotNullInferenceStrategy.NONE; import static com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy.ORDER_BY_CREATE_TIME; import static com.facebook.presto.sql.expressions.ExpressionOptimizerManager.DEFAULT_EXPRESSION_OPTIMIZER_NAME; -import static com.facebook.presto.sql.tree.CreateView.Security.DEFINER; import static com.google.common.collect.ImmutableList.toImmutableList; -import static 
io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; +import static java.util.stream.Collectors.joining; @DefunctConfig({ "resource-group-manager", @@ -97,6 +103,7 @@ public class FeaturesConfig private boolean cteFilterAndProjectionPushdownEnabled = true; private int cteHeuristicReplicationThreshold = 4; private int maxReorderedJoins = 9; + private int maxPrefixesCount = 100; private boolean useHistoryBasedPlanStatistics; private boolean trackHistoryBasedPlanStatistics; private boolean trackHistoryStatsFromFailedQuery = true; @@ -106,6 +113,8 @@ public class FeaturesConfig private String historyBasedOptimizerPlanCanonicalizationStrategies = "IGNORE_SAFE_CONSTANTS"; private boolean logPlansUsedInHistoryBasedOptimizer; private boolean enforceTimeoutForHBOQueryRegistration; + private boolean historyBasedOptimizerEstimateSizeUsingVariables; + private List queryTypesEnabledForHbo = ImmutableList.of(QueryType.SELECT, QueryType.INSERT); private boolean redistributeWrites; private boolean scaleWriters = true; private DataSize writerMinSize = new DataSize(32, MEGABYTE); @@ -185,6 +194,7 @@ public class FeaturesConfig private boolean listBuiltInFunctionsOnly = true; private boolean experimentalFunctionsEnabled; + private boolean useConnectorProvidedSerializationCodecs; private boolean optimizeCommonSubExpressions = true; private boolean preferDistributedUnion = true; private boolean optimizeNullsInJoin; @@ -194,7 +204,7 @@ public class FeaturesConfig private boolean treatLowConfidenceZeroEstimationAsUnknownEnabled; private boolean pushdownDereferenceEnabled; private boolean inlineSqlFunctions = true; - private boolean checkAccessControlOnUtilizedColumnsOnly; + private boolean checkAccessControlOnUtilizedColumnsOnly = true; private boolean 
checkAccessControlWithSubfields; private boolean skipRedundantSort = true; private boolean isAllowWindowOrderByLiterals = true; @@ -218,6 +228,10 @@ public class FeaturesConfig private boolean materializedViewDataConsistencyEnabled = true; private boolean materializedViewPartitionFilteringEnabled = true; private boolean queryOptimizationWithMaterializedViewEnabled; + private boolean legacyMaterializedViewRefresh = true; + private boolean allowLegacyMaterializedViewsToggle; + private boolean materializedViewAllowFullRefreshEnabled; + private MaterializedViewStaleReadBehavior materializedViewStaleReadBehavior = MaterializedViewStaleReadBehavior.USE_VIEW_QUERY; private AggregationIfToFilterRewriteStrategy aggregationIfToFilterRewriteStrategy = AggregationIfToFilterRewriteStrategy.DISABLED; private String analyzerType = "BUILTIN"; @@ -226,6 +240,8 @@ public class FeaturesConfig private boolean streamingForPartialAggregationEnabled; private boolean preferMergeJoinForSortedInputs; + private boolean preferSortMergeJoin; + private boolean isSortedExchangeEnabled; private boolean segmentedAggregationEnabled; private int maxStageCountForEagerScheduling = 25; @@ -235,6 +251,7 @@ public class FeaturesConfig private boolean pushRemoteExchangeThroughGroupId; private boolean isOptimizeMultipleApproxPercentileOnSameFieldEnabled = true; + private boolean isOptimizeMultipleApproxDistinctOnSameTypeEnabled; private boolean nativeExecutionEnabled; private boolean disableTimeStampWithTimeZoneForNative; private boolean disableIPAddressForNative; @@ -244,6 +261,7 @@ public class FeaturesConfig private boolean nativeEnforceJoinBuildInputPartition = true; private boolean randomizeOuterJoinNullKey; private RandomizeOuterJoinNullKeyStrategy randomizeOuterJoinNullKeyStrategy = RandomizeOuterJoinNullKeyStrategy.DISABLED; + private RandomizeNullSourceKeyInSemiJoinStrategy randomizeNullSourceKeyInSemiJoinStrategy = RandomizeNullSourceKeyInSemiJoinStrategy.DISABLED; private ShardedJoinStrategy 
shardedJoinStrategy = ShardedJoinStrategy.DISABLED; private int joinShardCount = 100; private boolean isOptimizeConditionalAggregationEnabled; @@ -269,6 +287,7 @@ public class FeaturesConfig private boolean pullUpExpressionFromLambda; private boolean rewriteConstantArrayContainsToIn; private boolean rewriteExpressionWithConstantVariable = true; + private boolean optimizeConditionalApproxDistinct = true; private boolean preProcessMetadataCalls; private boolean handleComplexEquiJoins; @@ -283,7 +302,7 @@ public class FeaturesConfig private boolean generateDomainFilters; private boolean printEstimatedStatsFromCache; private boolean removeCrossJoinWithSingleConstantRow = true; - private CreateView.Security defaultViewSecurityMode = DEFINER; + private ViewSecurity defaultViewSecurityMode = DEFINER; private boolean useHistograms; private boolean isInlineProjectionsOnValuesEnabled; @@ -295,13 +314,27 @@ public class FeaturesConfig private int eagerPlanValidationThreadPoolSize = 20; private boolean innerJoinPushdownEnabled; private boolean inEqualityJoinPushdownEnabled; + private boolean rewriteMinMaxByToTopNEnabled; + private boolean broadcastSemiJoinForDelete = true; private boolean prestoSparkExecutionEnvironment; private boolean singleNodeExecutionEnabled; private boolean nativeExecutionScaleWritersThreadsEnabled; - private boolean nativeExecutionTypeRewriteEnabled; private String expressionOptimizerName = DEFAULT_EXPRESSION_OPTIMIZER_NAME; private boolean addExchangeBelowPartialAggregationOverGroupId; + private boolean addDistinctBelowSemiJoinBuild; + private boolean pushdownSubfieldForMapFunctions = true; + private boolean pushdownSubfieldForCardinality; + private long maxSerializableObjectSize = 1000; + private boolean utilizeUniquePropertyInQueryPlanning = true; + private String expressionOptimizerUsedInRowExpressionRewrite = ""; + private double tableScanShuffleParallelismThreshold = 0.1; + private ShuffleForTableScanStrategy tableScanShuffleStrategy = 
ShuffleForTableScanStrategy.DISABLED; + private boolean skipPushdownThroughExchangeForRemoteProjection; + private String remoteFunctionNamesForFixedParallelism = ""; + private int remoteFunctionFixedParallelismTaskCount = 10; + + private boolean builtInSidecarFunctionsEnabled; public enum PartitioningPrecisionStrategy { @@ -404,6 +437,12 @@ public enum RandomizeOuterJoinNullKeyStrategy ALWAYS } + public enum RandomizeNullSourceKeyInSemiJoinStrategy + { + DISABLED, + ALWAYS + } + public enum ShardedJoinStrategy { DISABLED, @@ -452,6 +491,27 @@ public enum LeftJoinArrayContainsToInnerJoinStrategy ALWAYS_ENABLED } + public enum ShuffleForTableScanStrategy + { + DISABLED, + ALWAYS_ENABLED, + COST_BASED + } + + @Min(1) + @Config("max-prefixes-count") + @ConfigDescription("Maximum number of prefixes (catalog/schema/table scopes used to narrow metadata lookups) that Presto generates when querying information_schema.") + public FeaturesConfig setMaxPrefixesCount(Integer maxPrefixesCount) + { + this.maxPrefixesCount = maxPrefixesCount; + return this; + } + + public int getMaxPrefixesCount() + { + return maxPrefixesCount; + } + public double getCpuCostWeight() { return cpuCostWeight; @@ -861,6 +921,33 @@ public FeaturesConfig setHistoryBasedOptimizerPlanCanonicalizationStrategies(Str return this; } + @NotNull + public List getQueryTypesEnabledForHbo() + { + return queryTypesEnabledForHbo; + } + + @Config("optimizer.query-types-enabled-for-hbo") + public FeaturesConfig setQueryTypesEnabledForHbo(String queryTypesEnabledForHbo) + { + this.queryTypesEnabledForHbo = parseQueryTypesFromString(queryTypesEnabledForHbo); + return this; + } + + public static List parseQueryTypesFromString(String queryTypes) + { + try { + return Splitter.on(",").trimResults().splitToList(queryTypes).stream() + .map(QueryType::valueOf).collect(toImmutableList()); + } + catch (Exception e) { + throw new PrestoException(INVALID_SESSION_PROPERTY, format("Allowed options for 
query_types_enabled_for_history_based_optimization are: %s", + Stream.of(QueryType.values()) + .map(QueryType::name) + .collect(joining(",")))); + } + } + public boolean isLogPlansUsedInHistoryBasedOptimizer() { return logPlansUsedInHistoryBasedOptimizer; @@ -885,6 +972,18 @@ public FeaturesConfig setEnforceTimeoutForHBOQueryRegistration(boolean enforceTi return this; } + public boolean isHistoryBasedOptimizerEstimateSizeUsingVariables() + { + return historyBasedOptimizerEstimateSizeUsingVariables; + } + + @Config("optimizer.history-based-optimizer-estimate-size-using-variables") + public FeaturesConfig setHistoryBasedOptimizerEstimateSizeUsingVariables(boolean historyBasedOptimizerEstimateSizeUsingVariables) + { + this.historyBasedOptimizerEstimateSizeUsingVariables = historyBasedOptimizerEstimateSizeUsingVariables; + return this; + } + public AggregationPartitioningMergingStrategy getAggregationPartitioningMergingStrategy() { return aggregationPartitioningMergingStrategy; @@ -1781,6 +1880,19 @@ public FeaturesConfig setExperimentalFunctionsEnabled(boolean experimentalFuncti return this; } + public boolean isUseConnectorProvidedSerializationCodecs() + { + return useConnectorProvidedSerializationCodecs; + } + + @Config("use-connector-provided-serialization-codecs") + @ConfigDescription("Enable use of custom connector-provided serialization codecs for handles") + public FeaturesConfig setUseConnectorProvidedSerializationCodecs(boolean useConnectorProvidedSerializationCodecs) + { + this.useConnectorProvidedSerializationCodecs = useConnectorProvidedSerializationCodecs; + return this; + } + public boolean isOptimizeCommonSubExpressions() { return optimizeCommonSubExpressions; @@ -2091,6 +2203,59 @@ public FeaturesConfig setQueryOptimizationWithMaterializedViewEnabled(boolean va return this; } + public boolean isLegacyMaterializedViews() + { + return legacyMaterializedViewRefresh; + } + + @Config("experimental.legacy-materialized-views") + 
@ConfigDescription("Experimental: Use legacy materialized views. This feature is under active development and may change" + + "or be removed at any time. Do not disable in production environments.") + public FeaturesConfig setLegacyMaterializedViews(boolean value) + { + this.legacyMaterializedViewRefresh = value; + return this; + } + + public boolean isAllowLegacyMaterializedViewsToggle() + { + return allowLegacyMaterializedViewsToggle; + } + + @Config("experimental.allow-legacy-materialized-views-toggle") + @ConfigDescription("Allow toggling legacy materialized views via session property. This should only be enabled in non-production environments.") + public FeaturesConfig setAllowLegacyMaterializedViewsToggle(boolean value) + { + this.allowLegacyMaterializedViewsToggle = value; + return this; + } + + public boolean isMaterializedViewAllowFullRefreshEnabled() + { + return materializedViewAllowFullRefreshEnabled; + } + + @Config("materialized-view-allow-full-refresh-enabled") + @ConfigDescription("Allow full refresh of MV when it's empty - potentially high cost.") + public FeaturesConfig setMaterializedViewAllowFullRefreshEnabled(boolean value) + { + this.materializedViewAllowFullRefreshEnabled = value; + return this; + } + + public MaterializedViewStaleReadBehavior getMaterializedViewStaleReadBehavior() + { + return materializedViewStaleReadBehavior; + } + + @Config("materialized-view-stale-read-behavior") + @ConfigDescription("Default behavior when reading from a stale materialized view (FAIL or USE_VIEW_QUERY)") + public FeaturesConfig setMaterializedViewStaleReadBehavior(MaterializedViewStaleReadBehavior value) + { + this.materializedViewStaleReadBehavior = value; + return this; + } + public boolean isVerboseRuntimeStatsEnabled() { return verboseRuntimeStatsEnabled; @@ -2214,6 +2379,32 @@ public FeaturesConfig setPreferMergeJoinForSortedInputs(boolean preferMergeJoinF return this; } + public boolean isPreferSortMergeJoin() + { + return preferSortMergeJoin; + } 
+ + @Config("experimental.optimizer.prefer-sort-merge-join") + @ConfigDescription("Prefer sort merge join for all joins. A SortNode is added if input is not already sorted.") + public FeaturesConfig setPreferSortMergeJoin(boolean preferSortMergeJoin) + { + this.preferSortMergeJoin = preferSortMergeJoin; + return this; + } + + public boolean isSortedExchangeEnabled() + { + return isSortedExchangeEnabled; + } + + @Config("experimental.optimizer.sorted-exchange-enabled") + @ConfigDescription("(Experimental) Enable pushing sort operations down to exchange nodes for distributed queries") + public FeaturesConfig setSortedExchangeEnabled(boolean isSortedExchangeEnabled) + { + this.isSortedExchangeEnabled = isSortedExchangeEnabled; + return this; + } + public boolean isSegmentedAggregationEnabled() { return segmentedAggregationEnabled; @@ -2264,6 +2455,19 @@ public FeaturesConfig setOptimizeMultipleApproxPercentileOnSameFieldEnabled(bool return this; } + public boolean isOptimizeMultipleApproxDistinctOnSameTypeEnabled() + { + return isOptimizeMultipleApproxDistinctOnSameTypeEnabled; + } + + @Config("optimizer.optimize-multiple-approx-distinct-on-same-type") + @ConfigDescription("Enable combining individual approx_distinct calls on expressions of the same type using set_agg") + public FeaturesConfig setOptimizeMultipleApproxDistinctOnSameTypeEnabled(boolean isOptimizeMultipleApproxDistinctOnSameTypeEnabled) + { + this.isOptimizeMultipleApproxDistinctOnSameTypeEnabled = isOptimizeMultipleApproxDistinctOnSameTypeEnabled; + return this; + } + @Config("native-execution-enabled") @ConfigDescription("Enable execution on native engine") public FeaturesConfig setNativeExecutionEnabled(boolean nativeExecutionEnabled) @@ -2381,6 +2585,19 @@ public FeaturesConfig setRandomizeOuterJoinNullKeyStrategy(RandomizeOuterJoinNul return this; } + public RandomizeNullSourceKeyInSemiJoinStrategy getRandomizeNullSourceKeyInSemiJoinStrategy() + { + return randomizeNullSourceKeyInSemiJoinStrategy; 
+ } + + @Config("optimizer.randomize-null-source-key-in-semi-join-strategy") + @ConfigDescription("When to apply randomization to null source keys in semi join") + public FeaturesConfig setRandomizeNullSourceKeyInSemiJoinStrategy(RandomizeNullSourceKeyInSemiJoinStrategy randomizeNullSourceKeyInSemiJoinStrategy) + { + this.randomizeNullSourceKeyInSemiJoinStrategy = randomizeNullSourceKeyInSemiJoinStrategy; + return this; + } + public ShardedJoinStrategy getShardedJoinStrategy() { return shardedJoinStrategy; @@ -2784,14 +3001,27 @@ public FeaturesConfig setRewriteExpressionWithConstantVariable(boolean rewriteEx return this; } - public CreateView.Security getDefaultViewSecurityMode() + public boolean isOptimizeConditionalApproxDistinct() + { + return this.optimizeConditionalApproxDistinct; + } + + @Config("optimizer.optimize-constant-approx-distinct") + @ConfigDescription("Optimize out APPROX_DISTINCT over conditional constant expressions") + public FeaturesConfig setOptimizeConditionalApproxDistinct(boolean optimizeConditionalApproxDistinct) + { + this.optimizeConditionalApproxDistinct = optimizeConditionalApproxDistinct; + return this; + } + + public ViewSecurity getDefaultViewSecurityMode() { return this.defaultViewSecurityMode; } @Config("default-view-security-mode") @ConfigDescription("Sets the default security mode for view creation. 
The options are definer/invoker.") - public FeaturesConfig setDefaultViewSecurityMode(CreateView.Security securityMode) + public FeaturesConfig setDefaultViewSecurityMode(ViewSecurity securityMode) { this.defaultViewSecurityMode = securityMode; return this; @@ -2909,10 +3139,37 @@ public FeaturesConfig setInEqualityJoinPushdownEnabled(boolean inEqualityJoinPus return this; } + public boolean isRewriteMinMaxByToTopNEnabled() + { + return rewriteMinMaxByToTopNEnabled; + } + + @Config("optimizer.rewrite-minBy-maxBy-to-topN-enabled") + @ConfigDescription("Rewrite min_by and max_by to topN") + public FeaturesConfig setRewriteMinMaxByToTopNEnabled(boolean rewriteMinMaxByToTopNEnabled) + { + this.rewriteMinMaxByToTopNEnabled = rewriteMinMaxByToTopNEnabled; + return this; + } + + public boolean isBroadcastSemiJoinForDelete() + { + return broadcastSemiJoinForDelete; + } + + @Config("optimizer.broadcast-semi-join-for-delete") + @ConfigDescription("Enforce broadcast semi join in delete queries") + public FeaturesConfig setBroadcastSemiJoinForDelete(boolean broadcastSemiJoinForDelete) + { + this.broadcastSemiJoinForDelete = broadcastSemiJoinForDelete; + return this; + } + public boolean isInEqualityJoinPushdownEnabled() { return inEqualityJoinPushdownEnabled; } + public boolean isPrestoSparkExecutionEnvironment() { return prestoSparkExecutionEnvironment; @@ -2950,18 +3207,6 @@ public FeaturesConfig setNativeExecutionScaleWritersThreadsEnabled(boolean nativ return this; } - public boolean isNativeExecutionTypeRewriteEnabled() - { - return nativeExecutionTypeRewriteEnabled; - } - - @Config("native-execution-type-rewrite-enabled") - public FeaturesConfig setNativeExecutionTypeRewriteEnabled(boolean nativeExecutionTypeRewriteEnabled) - { - this.nativeExecutionTypeRewriteEnabled = nativeExecutionTypeRewriteEnabled; - return this; - } - public String getExpressionOptimizerName() { return expressionOptimizerName; @@ -3000,4 +3245,161 @@ public boolean 
getAddExchangeBelowPartialAggregationOverGroupId() { return addExchangeBelowPartialAggregationOverGroupId; } + + @Config("optimizer.add-distinct-below-semi-join-build") + @ConfigDescription("Add a distinct aggregation below build side of semi join") + public FeaturesConfig setAddDistinctBelowSemiJoinBuild(boolean addDistinctBelowSemiJoinBuild) + { + this.addDistinctBelowSemiJoinBuild = addDistinctBelowSemiJoinBuild; + return this; + } + + public boolean isAddDistinctBelowSemiJoinBuild() + { + return addDistinctBelowSemiJoinBuild; + } + + @Config("optimizer.pushdown-subfield-for-map-functions") + @ConfigDescription("Enable subfield pruning for map functions, currently include map_subset and map_filter") + public FeaturesConfig setPushdownSubfieldForMapFunctions(boolean pushdownSubfieldForMapFunctions) + { + this.pushdownSubfieldForMapFunctions = pushdownSubfieldForMapFunctions; + return this; + } + + public boolean isPushdownSubfieldForMapFunctions() + { + return pushdownSubfieldForMapFunctions; + } + + @Config("optimizer.pushdown-subfield-for-cardinality") + @ConfigDescription("Enable subfield pruning for cardinality() function to skip reading keys and values") + public FeaturesConfig setPushdownSubfieldForCardinality(boolean pushdownSubfieldForCardinality) + { + this.pushdownSubfieldForCardinality = pushdownSubfieldForCardinality; + return this; + } + + public boolean isPushdownSubfieldForCardinality() + { + return pushdownSubfieldForCardinality; + } + + @Config("optimizer.utilize-unique-property-in-query-planning") + @ConfigDescription("Utilize the unique property of input columns in query planning") + public FeaturesConfig setUtilizeUniquePropertyInQueryPlanning(boolean utilizeUniquePropertyInQueryPlanning) + { + this.utilizeUniquePropertyInQueryPlanning = utilizeUniquePropertyInQueryPlanning; + return this; + } + + public boolean isUtilizeUniquePropertyInQueryPlanning() + { + return utilizeUniquePropertyInQueryPlanning; + } + + public String 
getExpressionOptimizerUsedInRowExpressionRewrite() + { + return expressionOptimizerUsedInRowExpressionRewrite; + } + + @Config("optimizer.expression-optimizer-used-in-expression-rewrite") + @ConfigDescription("The name of expression optimizer to be used in row expression rewrite") + public FeaturesConfig setExpressionOptimizerUsedInRowExpressionRewrite(String expressionOptimizerUsedInRowExpressionRewrite) + { + this.expressionOptimizerUsedInRowExpressionRewrite = expressionOptimizerUsedInRowExpressionRewrite; + return this; + } + + @Config("max_serializable_object_size") + @ConfigDescription("Configure the maximum byte size of a serializable object in expression interpreters") + public FeaturesConfig setMaxSerializableObjectSize(long maxSerializableObjectSize) + { + this.maxSerializableObjectSize = maxSerializableObjectSize; + return this; + } + + public long getMaxSerializableObjectSize() + { + return maxSerializableObjectSize; + } + + public double getTableScanShuffleParallelismThreshold() + { + return tableScanShuffleParallelismThreshold; + } + + @Config("optimizer.table-scan-shuffle-parallelism-threshold") + @ConfigDescription("Parallelism threshold for adding a shuffle above table scan. When the table's parallelism factor is below this threshold (0.0-1.0) and TABLE_SCAN_SHUFFLE_STRATEGY is COST_BASED, a round-robin shuffle exchange is added above the table scan to redistribute data.") + public FeaturesConfig setTableScanShuffleParallelismThreshold(double tableScanShuffleParallelismThreshold) + { + this.tableScanShuffleParallelismThreshold = tableScanShuffleParallelismThreshold; + return this; + } + + public ShuffleForTableScanStrategy getTableScanShuffleStrategy() + { + return tableScanShuffleStrategy; + } + + @Config("optimizer.table-scan-shuffle-strategy") + @ConfigDescription("Strategy for adding shuffle above table scan to redistribute data. 
Options are DISABLED, ALWAYS_ENABLED, COST_BASED") + public FeaturesConfig setTableScanShuffleStrategy(ShuffleForTableScanStrategy tableScanShuffleStrategy) + { + this.tableScanShuffleStrategy = tableScanShuffleStrategy; + return this; + } + + public boolean isSkipPushdownThroughExchangeForRemoteProjection() + { + return skipPushdownThroughExchangeForRemoteProjection; + } + + @Config("optimizer.skip-pushdown-through-exchange-for-remote-projection") + @ConfigDescription("Skip pushing down remote projection through exchange") + public FeaturesConfig setSkipPushdownThroughExchangeForRemoteProjection(boolean skipPushdownThroughExchangeForRemoteProjection) + { + this.skipPushdownThroughExchangeForRemoteProjection = skipPushdownThroughExchangeForRemoteProjection; + return this; + } + + @Config("built-in-sidecar-functions-enabled") + @ConfigDescription("Enable using CPP functions from sidecar over coordinator SQL implementations.") + public FeaturesConfig setBuiltInSidecarFunctionsEnabled(boolean builtInSidecarFunctionsEnabled) + { + this.builtInSidecarFunctionsEnabled = builtInSidecarFunctionsEnabled; + return this; + } + + public boolean isBuiltInSidecarFunctionsEnabled() + { + return this.builtInSidecarFunctionsEnabled; + } + + public String getRemoteFunctionNamesForFixedParallelism() + { + return remoteFunctionNamesForFixedParallelism; + } + + @Config("optimizer.remote-function-names-for-fixed-parallelism") + @ConfigDescription("Regex pattern to match remote function names that should use fixed parallelism") + public FeaturesConfig setRemoteFunctionNamesForFixedParallelism(String remoteFunctionNamesForFixedParallelism) + { + this.remoteFunctionNamesForFixedParallelism = remoteFunctionNamesForFixedParallelism; + return this; + } + + @Min(1) + public int getRemoteFunctionFixedParallelismTaskCount() + { + return remoteFunctionFixedParallelismTaskCount; + } + + @Config("optimizer.remote-function-fixed-parallelism-task-count") + @ConfigDescription("Number of tasks to use 
for remote functions matching the fixed parallelism pattern. If not set (0), the default hash partition count will be used.") + public FeaturesConfig setRemoteFunctionFixedParallelismTaskCount(int remoteFunctionFixedParallelismTaskCount) + { + this.remoteFunctionFixedParallelismTaskCount = remoteFunctionFixedParallelismTaskCount; + return this; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ForMetadataExtractor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ForMetadataExtractor.java index 6a7811d1abf65..ec27269e42e4f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ForMetadataExtractor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/ForMetadataExtractor.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.sql.analyzer; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FunctionsConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FunctionsConfig.java index 4e64a20d467dc..fbb20beccd208 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FunctionsConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FunctionsConfig.java @@ -19,8 +19,7 @@ import com.facebook.presto.operator.aggregation.histogram.HistogramGroupImplementation; import com.facebook.presto.operator.aggregation.multimapagg.MultimapAggGroupImplementation; import com.facebook.presto.spi.function.Description; - -import javax.validation.constraints.Min; +import jakarta.validation.constraints.Min; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.sql.analyzer.RegexLibrary.JONI; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/JavaFeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/JavaFeaturesConfig.java index ae5e783775650..51d238a46e610 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/JavaFeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/JavaFeaturesConfig.java @@ -15,9 +15,9 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class JavaFeaturesConfig { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewColumnMappingExtractor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewColumnMappingExtractor.java index f9cefc11117ea..80a406edf6bf4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewColumnMappingExtractor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewColumnMappingExtractor.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.CreateMaterializedView; @@ -48,6 +49,7 @@ public class MaterializedViewColumnMappingExtractor { private final Analysis analysis; private final Session session; + private final Metadata metadata; /** * We create a undirected graph where each node corresponds to a base table column. 
@@ -79,10 +81,11 @@ public class MaterializedViewColumnMappingExtractor */ private List baseTablesOnOuterJoinSide; - public MaterializedViewColumnMappingExtractor(Analysis analysis, Session session) + public MaterializedViewColumnMappingExtractor(Analysis analysis, Session session, Metadata metadata) { this.analysis = requireNonNull(analysis, "analysis is null"); this.session = requireNonNull(session, "session is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); this.mappedBaseColumns = new HashMap<>(); this.directMappedBaseColumns = new HashMap<>(); this.baseTablesOnOuterJoinSide = new ArrayList<>(); @@ -166,7 +169,7 @@ protected Void visitTable(Table node, MaterializedViewPlanValidatorContext conte super.visitTable(node, context); if (context.isWithinOuterJoin()) { - baseTablesOnOuterJoinSide.add(toSchemaTableName(createQualifiedObjectName(session, node, node.getName()))); + baseTablesOnOuterJoinSide.add(toSchemaTableName(createQualifiedObjectName(session, node, node.getName(), metadata))); } return null; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java index edeb6898a8838..3b6355bdfd317 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MaterializedViewQueryOptimizer.java @@ -28,6 +28,7 @@ import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.MetadataResolver; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.security.AccessControl; @@ -344,7 +345,7 @@ private QuerySpecification 
rewriteQuerySpecificationIfCompatible(QuerySpecificat List referencedMaterializedViews = metadata.getReferencedMaterializedViews( session, - createQualifiedObjectName(session, baseTable, baseTable.getName())); + createQualifiedObjectName(session, baseTable, baseTable.getName(), metadata)); // TODO: Select the most compatible and efficient materialized view for query rewrite optimization https://github.com/prestodb/presto/issues/16431 // TODO: Refactor query optimization code https://github.com/prestodb/presto/issues/16759 @@ -771,7 +772,7 @@ ExpressionAnalysis getExpressionAnalysis(Expression expression, Scope scope) accessControl, sqlParser, scope, - new Analysis(null, ImmutableMap.of(), false), + new Analysis(null, ImmutableMap.of(), false, new ViewDefinitionReferences()), expression, WarningCollector.NOOP); } @@ -815,7 +816,7 @@ private Expression coerceIfNecessary(Expression original, Expression rewritten) private Scope extractScope(Table table, QuerySpecification node, Expression whereClause) { - QualifiedObjectName baseTableName = createQualifiedObjectName(session, table, table.getName()); + QualifiedObjectName baseTableName = createQualifiedObjectName(session, table, table.getName(), metadata); Optional tableHandle = metadata.getMetadataResolver(session).getTableHandle(baseTableName); if (!tableHandle.isPresent()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MetadataExtractor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MetadataExtractor.java index 3eda7c8204334..f5a4b8afcf7b5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MetadataExtractor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MetadataExtractor.java @@ -189,7 +189,7 @@ public Visitor(Session session) @Override protected Void visitTable(Table table, MetadataExtractorContext context) { - QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName()); + 
QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName(), metadata); if (tableName.getObjectName().isEmpty()) { throw new SemanticException(MISSING_TABLE, table, "Table name is empty"); } @@ -207,7 +207,7 @@ protected Void visitTable(Table table, MetadataExtractorContext context) @Override protected Void visitInsert(Insert insert, MetadataExtractorContext context) { - QualifiedObjectName tableName = createQualifiedObjectName(session, insert, insert.getTarget()); + QualifiedObjectName tableName = createQualifiedObjectName(session, insert, insert.getTarget(), metadata); if (tableName.getObjectName().isEmpty()) { throw new SemanticException(MISSING_TABLE, insert, "Table name is empty"); } @@ -224,7 +224,7 @@ protected Void visitInsert(Insert insert, MetadataExtractorContext context) protected Void visitDelete(Delete node, MetadataExtractorContext context) { Table table = node.getTable(); - QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName(), metadata); if (tableName.getObjectName().isEmpty()) { throw new SemanticException(MISSING_TABLE, node, "Table name is empty"); } @@ -239,7 +239,7 @@ protected Void visitDelete(Delete node, MetadataExtractorContext context) @Override protected Void visitAnalyze(Analyze node, MetadataExtractorContext context) { - QualifiedObjectName tableName = createQualifiedObjectName(session, node, node.getTableName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, node, node.getTableName(), metadata); if (tableName.getObjectName().isEmpty()) { throw new SemanticException(MISSING_TABLE, node, "Table name is empty"); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MetadataExtractorMBean.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MetadataExtractorMBean.java index d668f854adf82..a6046231d2469 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MetadataExtractorMBean.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/MetadataExtractorMBean.java @@ -14,11 +14,10 @@ package com.facebook.presto.sql.analyzer; import com.facebook.airlift.concurrent.ThreadPoolExecutorMBean; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.concurrent.ExecutorService; import java.util.concurrent.ThreadPoolExecutor; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/PredicateStitcher.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/PredicateStitcher.java index a39c808d91238..47fd5046b57ef 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/PredicateStitcher.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/PredicateStitcher.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.analyzer; import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.AllColumns; @@ -60,11 +61,13 @@ public class PredicateStitcher { private final Map predicates; private final Session session; + private final Metadata metadata; - public PredicateStitcher(Session session, Map predicates) + public PredicateStitcher(Session session, Map predicates, Metadata metadata) { this.session = requireNonNull(session, "session is null"); this.predicates = requireNonNull(predicates, "predicates is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); } @Override @@ -185,7 +188,7 @@ protected Node visitAliasedRelation(AliasedRelation node, PredicateStitcherConte @Override protected Node visitTable(Table table, PredicateStitcherContext context) { - SchemaTableName schemaTableName = 
toSchemaTableName(createQualifiedObjectName(session, table, table.getName())); + SchemaTableName schemaTableName = toSchemaTableName(createQualifiedObjectName(session, table, table.getName(), metadata)); if (!predicates.containsKey(schemaTableName)) { return table; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java index 8b3b8dd0740db..7bb90d8619f4d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/QueryExplainer.java @@ -21,6 +21,8 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.AccessControlReferences; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.security.AccessControl; @@ -39,8 +41,7 @@ import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Statement; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; @@ -59,6 +60,7 @@ import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.graphvizLogicalPlan; import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonDistributedPlan; import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.jsonLogicalPlan; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissionsForTablesAndColumns; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -120,13 +122,17 @@ public QueryExplainer( this.planChecker = requireNonNull(planChecker, "planChecker is null"); } - public 
Analysis analyze(Session session, Statement statement, List parameters, WarningCollector warningCollector, String query) + public Analysis analyze(Session session, Statement statement, List parameters, WarningCollector warningCollector, String query, ViewDefinitionReferences viewDefinitionReferences) { - Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.of(this), parameters, parameterExtractor(statement, parameters), warningCollector, query); - return analyzer.analyze(statement); + Analyzer analyzer = new Analyzer(session, metadata, sqlParser, accessControl, Optional.of(this), parameters, parameterExtractor(statement, parameters), warningCollector, query, viewDefinitionReferences); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + AccessControlReferences accessControlReferences = analysis.getAccessControlReferences(); + checkAccessPermissionsForTablesAndColumns(accessControlReferences); + + return analysis; } - public String getPlan(Session session, Statement statement, Type planType, List parameters, boolean verbose, WarningCollector warningCollector, String query) + public String getPlan(Session session, Statement statement, Type planType, List parameters, boolean verbose, WarningCollector warningCollector, String query, ViewDefinitionReferences viewDefinitionReferences) { DataDefinitionTask task = dataDefinitionTask.get(statement.getClass()); if (task != null) { @@ -135,13 +141,13 @@ public String getPlan(Session session, Statement statement, Type planType, List< switch (planType) { case LOGICAL: - Plan plan = getLogicalPlan(session, statement, parameters, warningCollector, query); + Plan plan = getLogicalPlan(session, statement, parameters, warningCollector, query, viewDefinitionReferences); return PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), plan.getStatsAndCosts(), metadata.getFunctionAndTypeManager(), session, 0, verbose, isVerboseOptimizerInfoEnabled(session)); case DISTRIBUTED: - 
SubPlan subPlan = getDistributedPlan(session, statement, parameters, warningCollector, query); + SubPlan subPlan = getDistributedPlan(session, statement, parameters, warningCollector, query, viewDefinitionReferences); return PlanPrinter.textDistributedPlan(subPlan, metadata.getFunctionAndTypeManager(), session, verbose); case IO: - return IOPlanPrinter.textIOPlan(getLogicalPlan(session, statement, parameters, warningCollector, query).getRoot(), metadata, session); + return IOPlanPrinter.textIOPlan(getLogicalPlan(session, statement, parameters, warningCollector, query, viewDefinitionReferences).getRoot(), metadata, session); } throw new IllegalArgumentException("Unhandled plan type: " + planType); } @@ -151,7 +157,7 @@ private static String explainTask(Statement statement, Dat return task.explain((T) statement, parameters); } - public String getGraphvizPlan(Session session, Statement statement, Type planType, List parameters, WarningCollector warningCollector, String query) + public String getGraphvizPlan(Session session, Statement statement, Type planType, List parameters, WarningCollector warningCollector, String query, ViewDefinitionReferences viewDefinitionReferences) { DataDefinitionTask task = dataDefinitionTask.get(statement.getClass()); if (task != null) { @@ -161,16 +167,16 @@ public String getGraphvizPlan(Session session, Statement statement, Type planTyp switch (planType) { case LOGICAL: - Plan plan = getLogicalPlan(session, statement, parameters, warningCollector, query); + Plan plan = getLogicalPlan(session, statement, parameters, warningCollector, query, viewDefinitionReferences); return graphvizLogicalPlan(plan.getRoot(), plan.getTypes(), plan.getStatsAndCosts(), metadata.getFunctionAndTypeManager(), session); case DISTRIBUTED: - SubPlan subPlan = getDistributedPlan(session, statement, parameters, warningCollector, query); + SubPlan subPlan = getDistributedPlan(session, statement, parameters, warningCollector, query, viewDefinitionReferences); return 
graphvizDistributedPlan(subPlan, metadata.getFunctionAndTypeManager(), session); } throw new IllegalArgumentException("Unhandled plan type: " + planType); } - public String getJsonPlan(Session session, Statement statement, Type planType, List parameters, WarningCollector warningCollector, String query) + public String getJsonPlan(Session session, Statement statement, Type planType, List parameters, WarningCollector warningCollector, String query, ViewDefinitionReferences viewDefinitionReferences) { DataDefinitionTask task = dataDefinitionTask.get(statement.getClass()); if (task != null) { @@ -181,29 +187,29 @@ public String getJsonPlan(Session session, Statement statement, Type planType, L Plan plan; switch (planType) { case IO: - plan = getLogicalPlan(session, statement, parameters, warningCollector, query); + plan = getLogicalPlan(session, statement, parameters, warningCollector, query, viewDefinitionReferences); return textIOPlan(plan.getRoot(), metadata, session); case LOGICAL: - plan = getLogicalPlan(session, statement, parameters, warningCollector, query); + plan = getLogicalPlan(session, statement, parameters, warningCollector, query, viewDefinitionReferences); return jsonLogicalPlan(plan.getRoot(), plan.getTypes(), metadata.getFunctionAndTypeManager(), plan.getStatsAndCosts(), session); case DISTRIBUTED: - SubPlan subPlan = getDistributedPlan(session, statement, parameters, warningCollector, query); + SubPlan subPlan = getDistributedPlan(session, statement, parameters, warningCollector, query, viewDefinitionReferences); return jsonDistributedPlan(subPlan, metadata.getFunctionAndTypeManager(), session); default: throw new PrestoException(NOT_SUPPORTED, format("Unsupported explain plan type %s for JSON format", planType)); } } - public Plan getLogicalPlan(Session session, Statement statement, List parameters, WarningCollector warningCollector, String query) + public Plan getLogicalPlan(Session session, Statement statement, List parameters, WarningCollector 
warningCollector, String query, ViewDefinitionReferences viewDefinitionReferences) { - return getLogicalPlan(session, statement, parameters, warningCollector, new PlanNodeIdAllocator(), query); + return getLogicalPlan(session, statement, parameters, warningCollector, new PlanNodeIdAllocator(), query, viewDefinitionReferences); } - public Plan getLogicalPlan(Session session, Statement statement, List parameters, WarningCollector warningCollector, PlanNodeIdAllocator idAllocator, String query) + public Plan getLogicalPlan(Session session, Statement statement, List parameters, WarningCollector warningCollector, PlanNodeIdAllocator idAllocator, String query, ViewDefinitionReferences viewDefinitionReferences) { // analyze statement Analysis analysis = session.getRuntimeStats() - .recordWallAndCpuTime(ANALYZE_TIME_NANOS, () -> analyze(session, statement, parameters, warningCollector, query)); + .recordWallAndCpuTime(ANALYZE_TIME_NANOS, () -> analyze(session, statement, parameters, warningCollector, query, viewDefinitionReferences)); final VariableAllocator planVariableAllocator = new VariableAllocator(); LogicalPlanner logicalPlanner = new LogicalPlanner( @@ -234,10 +240,10 @@ public Plan getLogicalPlan(Session session, Statement statement, List optimizer.validateAndOptimizePlan(planNode, OPTIMIZED_AND_VALIDATED)); } - public SubPlan getDistributedPlan(Session session, Statement statement, List parameters, WarningCollector warningCollector, String query) + public SubPlan getDistributedPlan(Session session, Statement statement, List parameters, WarningCollector warningCollector, String query, ViewDefinitionReferences viewDefinitionReferences) { PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); - Plan plan = getLogicalPlan(session, statement, parameters, warningCollector, idAllocator, query); + Plan plan = getLogicalPlan(session, statement, parameters, warningCollector, idAllocator, query, viewDefinitionReferences); return session.getRuntimeStats() 
.recordWallAndCpuTime(FRAGMENT_PLAN_TIME_NANOS, () -> planFragmenter.createSubPlans(session, plan, false, idAllocator, warningCollector)); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/QueryPreparerProviderManager.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/QueryPreparerProviderManager.java index 6cd80e0586539..3580cb0a27001 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/QueryPreparerProviderManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/QueryPreparerProviderManager.java @@ -15,8 +15,7 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.analyzer.QueryPreparerProvider; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.HashMap; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/RefreshMaterializedViewPredicateAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/RefreshMaterializedViewPredicateAnalyzer.java index 13986ea999b84..9890015db50a3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/RefreshMaterializedViewPredicateAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/RefreshMaterializedViewPredicateAnalyzer.java @@ -24,6 +24,8 @@ import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.InListExpression; +import com.facebook.presto.sql.tree.InPredicate; import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.Node; @@ -31,12 +33,10 @@ import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; - -import javax.annotation.Nullable; +import 
jakarta.annotation.Nullable; import java.util.Map; import java.util.Optional; -import java.util.function.Supplier; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; @@ -46,7 +46,7 @@ /** * Map predicates on view columns in the RefreshMaterializedView where clause to predicates on base table columns, - * which could be used for predicate push-down afterwards. Mapped predicates are connected by AND. + * which could be used for predicate push-down afterwards. Mapped predicates are connected by AND or OR. * For view columns that do not have a direct mapping to a base table column, keep the predicate with the view. */ public class RefreshMaterializedViewPredicateAnalyzer @@ -105,7 +105,7 @@ public Map getTablePredicates() @Override public Void process(Node node, @Nullable Void context) { - if (!(node instanceof ComparisonExpression || node instanceof LogicalBinaryExpression)) { + if (!(node instanceof ComparisonExpression || node instanceof LogicalBinaryExpression || node instanceof InPredicate)) { throw new SemanticException(NOT_SUPPORTED, node, "Only column specifications connected by logical AND are supported in WHERE clause."); } @@ -121,37 +121,117 @@ protected Void visitExpression(Expression node, Void context) @Override protected Void visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context) { - if (!LogicalBinaryExpression.Operator.AND.equals(node.getOperator())) { - throw new SemanticException(NOT_SUPPORTED, node, "Only logical AND is supported in WHERE clause."); + if (LogicalBinaryExpression.Operator.OR.equals(node.getOperator())) { + SchemaTableName viewName = new SchemaTableName(viewDefinition.getSchema(), viewDefinition.getTable()); + tablePredicatesBuilder.put(viewName, node); + return null; } - if (!(node.getLeft() instanceof ComparisonExpression || node.getLeft() instanceof LogicalBinaryExpression)) { + + if (!(node.getLeft() 
instanceof ComparisonExpression || node.getLeft() instanceof LogicalBinaryExpression || node.getLeft() instanceof InPredicate)) { throw new SemanticException(NOT_SUPPORTED, node.getLeft(), "Only column specifications connected by logical AND are supported in WHERE clause."); } - if (!(node.getRight() instanceof ComparisonExpression || node.getRight() instanceof LogicalBinaryExpression)) { + if (!(node.getRight() instanceof ComparisonExpression || node.getRight() instanceof LogicalBinaryExpression || node.getRight() instanceof InPredicate)) { throw new SemanticException(NOT_SUPPORTED, node.getRight(), "Only column specifications connected by logical AND are supported in WHERE clause."); } return super.visitLogicalBinaryExpression(node, null); } + @Override + protected Void visitInPredicate(InPredicate node, Void context) + { + Expression value = node.getValue(); + Expression valueList = node.getValueList(); + + if (!(value instanceof Identifier || value instanceof DereferenceExpression)) { + throw new SemanticException(NOT_SUPPORTED, value, "Only column references are supported on the left side of IN predicates in WHERE clause."); + } + if (!(valueList instanceof InListExpression)) { + throw new SemanticException(NOT_SUPPORTED, valueList, "Only IN list expressions are supported in WHERE clause's IN predicates."); + } + + InListExpression inListExpression = (InListExpression) valueList; + for (Expression inValue : inListExpression.getValues()) { + if (!(inValue instanceof Literal)) { + throw new SemanticException(NOT_SUPPORTED, inValue, "Only literal values are supported in WHERE clause's IN lists."); + } + } + + QualifiedName qualifiedName = value instanceof DereferenceExpression + ? 
DereferenceExpression.getQualifiedName((DereferenceExpression) value) + : QualifiedName.of(((Identifier) value).getValue()); + + ResolvedField resolvedField = viewScope.tryResolveField(value).orElseThrow(() -> missingAttributeException(value, qualifiedName)); + String column = resolvedField.getField().getOriginColumnName().orElseThrow(() -> missingAttributeException(value, qualifiedName)); + + if (!viewDefinition.getValidRefreshColumns().orElse(emptyList()).contains(column)) { + throw new SemanticException(NOT_SUPPORTED, value, "Refresh materialized view by column %s is not supported.", value.toString()); + } + + Map baseTableColumns = viewDefinition.getColumnMappingsAsMap().get(column); + + // Convert IN predicate to OR'd equality comparisons + boolean mappedToSingleBaseTablePartition = true; + for (Expression inValue : inListExpression.getValues()) { + if (baseTableColumns != null && inValue instanceof NullLiteral) { + if (viewDefinition.getBaseTablesOnOuterJoinSide().stream().anyMatch(t -> baseTableColumns.containsKey(t))) { + mappedToSingleBaseTablePartition = false; + break; + } + } + } + + if (mappedToSingleBaseTablePartition && baseTableColumns != null) { + for (SchemaTableName baseTable : baseTableColumns.keySet()) { + Expression orExpression = null; + for (Expression inValue : inListExpression.getValues()) { + ComparisonExpression comparison = new ComparisonExpression( + ComparisonExpression.Operator.EQUAL, + new Identifier(baseTableColumns.get(baseTable)), + inValue); + if (orExpression == null) { + orExpression = comparison; + } + else { + orExpression = new LogicalBinaryExpression(LogicalBinaryExpression.Operator.OR, orExpression, comparison); + } + } + tablePredicatesBuilder.put(baseTable, orExpression); + } + } + else { + SchemaTableName viewName = new SchemaTableName(viewDefinition.getSchema(), viewDefinition.getTable()); + Expression orExpression = null; + for (Expression inValue : inListExpression.getValues()) { + ComparisonExpression comparison = 
new ComparisonExpression(ComparisonExpression.Operator.EQUAL, value, inValue); + if (orExpression == null) { + orExpression = comparison; + } + else { + orExpression = new LogicalBinaryExpression(LogicalBinaryExpression.Operator.OR, orExpression, comparison); + } + } + tablePredicatesBuilder.put(viewName, orExpression); + } + + return null; + } + @Override protected Void visitComparisonExpression(ComparisonExpression node, Void context) { if (!(node.getLeft() instanceof Identifier || node.getLeft() instanceof DereferenceExpression)) { - throw new SemanticException(NOT_SUPPORTED, node.getLeft(), "Only columns specified on literals are supported in WHERE clause."); + throw new SemanticException(NOT_SUPPORTED, node.getLeft(), "Only column references are supported on the left side of comparison expressions in WHERE clause."); } if (!(node.getRight() instanceof Literal)) { - throw new SemanticException(NOT_SUPPORTED, node.getRight(), "Only columns specified on literals are supported in WHERE clause."); + throw new SemanticException(NOT_SUPPORTED, node.getRight(), "Only literal values are supported on the right side of comparison expressions in WHERE clause."); } - Supplier qualifiedName = () -> { - if (node.getLeft() instanceof DereferenceExpression) { - return DereferenceExpression.getQualifiedName((DereferenceExpression) node.getLeft()); - } - return QualifiedName.of(((Identifier) node.getLeft()).getValue()); - }; + QualifiedName qualifiedName = node.getLeft() instanceof DereferenceExpression + ? 
DereferenceExpression.getQualifiedName((DereferenceExpression) node.getLeft()) + : QualifiedName.of(((Identifier) node.getLeft()).getValue()); - ResolvedField resolvedField = viewScope.tryResolveField(node.getLeft()).orElseThrow(() -> missingAttributeException(node.getLeft(), qualifiedName.get())); - String column = resolvedField.getField().getOriginColumnName().orElseThrow(() -> missingAttributeException(node.getLeft(), qualifiedName.get())); + ResolvedField resolvedField = viewScope.tryResolveField(node.getLeft()).orElseThrow(() -> missingAttributeException(node.getLeft(), qualifiedName)); + String column = resolvedField.getField().getOriginColumnName().orElseThrow(() -> missingAttributeException(node.getLeft(), qualifiedName)); if (!viewDefinition.getValidRefreshColumns().orElse(emptyList()).contains(column)) { throw new SemanticException(NOT_SUPPORTED, node.getLeft(), "Refresh materialized view by column %s is not supported.", node.getLeft().toString()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index c92edc1cdc57d..3b8bad43246d0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -17,6 +17,7 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.SourceColumn; import com.facebook.presto.common.Subfield; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.predicate.Domain; @@ -31,25 +32,46 @@ import com.facebook.presto.common.type.TimestampWithTimeZoneType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.metadata.CatalogMetadata; +import 
com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.OperatorNotFoundException; +import com.facebook.presto.metadata.TableFunctionMetadata; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.MaterializedViewStatus; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; -import com.facebook.presto.spi.analyzer.AccessControlInfo; import com.facebook.presto.spi.analyzer.AccessControlInfoForTable; import com.facebook.presto.spi.analyzer.MetadataResolver; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.connector.ConnectorTableVersion; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.eventlistener.Column; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ArgumentSpecification; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.DescriptorArgumentSpecification; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification; +import com.facebook.presto.spi.function.table.ScalarArgument; +import 
com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessControl; @@ -57,9 +79,13 @@ import com.facebook.presto.spi.security.Identity; import com.facebook.presto.spi.security.ViewAccessControl; import com.facebook.presto.spi.security.ViewExpression; +import com.facebook.presto.spi.security.ViewSecurity; +import com.facebook.presto.spi.type.UnknownTypeException; import com.facebook.presto.sql.ExpressionUtils; import com.facebook.presto.sql.MaterializedViewUtils; -import com.facebook.presto.sql.SqlFormatterUtil; +import com.facebook.presto.sql.analyzer.Analysis.MergeAnalysis; +import com.facebook.presto.sql.analyzer.Analysis.TableArgumentAnalysis; +import com.facebook.presto.sql.analyzer.Analysis.TableFunctionInvocationAnalysis; import com.facebook.presto.sql.parser.ParsingException; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.ExpressionInterpreter; @@ -87,13 +113,16 @@ import com.facebook.presto.sql.tree.DefaultTraversalVisitor; import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.DereferenceExpression; +import com.facebook.presto.sql.tree.DropBranch; import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropConstraint; import com.facebook.presto.sql.tree.DropFunction; import com.facebook.presto.sql.tree.DropMaterializedView; import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; +import 
com.facebook.presto.sql.tree.DropTag; import com.facebook.presto.sql.tree.DropView; +import com.facebook.presto.sql.tree.EmptyTableTreatment; import com.facebook.presto.sql.tree.Except; import com.facebook.presto.sql.tree.Execute; import com.facebook.presto.sql.tree.Explain; @@ -121,12 +150,17 @@ import com.facebook.presto.sql.tree.Literal; import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.Merge; +import com.facebook.presto.sql.tree.MergeCase; +import com.facebook.presto.sql.tree.MergeInsert; +import com.facebook.presto.sql.tree.MergeUpdate; import com.facebook.presto.sql.tree.NaturalJoin; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NodeLocation; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Offset; import com.facebook.presto.sql.tree.OrderBy; +import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.Prepare; import com.facebook.presto.sql.tree.Property; import com.facebook.presto.sql.tree.QualifiedName; @@ -158,6 +192,10 @@ import com.facebook.presto.sql.tree.StartTransaction; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.Table; +import com.facebook.presto.sql.tree.TableFunctionArgument; +import com.facebook.presto.sql.tree.TableFunctionDescriptorArgument; +import com.facebook.presto.sql.tree.TableFunctionInvocation; +import com.facebook.presto.sql.tree.TableFunctionTableArgument; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.TruncateTable; import com.facebook.presto.sql.tree.Union; @@ -171,6 +209,7 @@ import com.facebook.presto.sql.tree.With; import com.facebook.presto.sql.tree.WithQuery; import com.facebook.presto.sql.util.AstUtils; +import com.facebook.presto.transaction.TransactionManager; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import 
com.google.common.collect.ImmutableMap; @@ -178,6 +217,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; +import com.google.common.collect.Streams; import java.util.ArrayList; import java.util.Arrays; @@ -187,13 +227,16 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import static com.facebook.presto.SystemSessionProperties.getMaxGroupingSets; import static com.facebook.presto.SystemSessionProperties.isAllowWindowOrderByLiterals; +import static com.facebook.presto.SystemSessionProperties.isLegacyMaterializedViews; import static com.facebook.presto.SystemSessionProperties.isMaterializedViewDataConsistencyEnabled; import static com.facebook.presto.SystemSessionProperties.isMaterializedViewPartitionFilteringEnabled; import static com.facebook.presto.common.RuntimeMetricName.SKIP_READING_FROM_MATERIALIZED_VIEW_COUNT; @@ -204,16 +247,17 @@ import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.execution.CallTask.extractParameterValuesInOrder; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; import static com.facebook.presto.spi.StandardErrorCode.DATATYPE_MISMATCH; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; import static com.facebook.presto.spi.StandardErrorCode.INVALID_COLUMN_MASK; -import static 
com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.StandardErrorCode.INVALID_ROW_FILTER; +import static com.facebook.presto.spi.StandardErrorCode.MV_MISSING_TOO_MUCH_DATA; import static com.facebook.presto.spi.StandardWarningCode.PERFORMANCE_WARNING; import static com.facebook.presto.spi.StandardWarningCode.REDUNDANT_ORDER_BY; +import static com.facebook.presto.spi.StandardWarningCode.SEMANTIC_WARNING; import static com.facebook.presto.spi.analyzer.AccessControlRole.TABLE_CREATE; import static com.facebook.presto.spi.analyzer.AccessControlRole.TABLE_DELETE; import static com.facebook.presto.spi.analyzer.AccessControlRole.TABLE_INSERT; @@ -221,6 +265,10 @@ import static com.facebook.presto.spi.connector.ConnectorTableVersion.VersionType; import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.function.FunctionKind.WINDOW; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static com.facebook.presto.spi.security.ViewSecurity.DEFINER; +import static com.facebook.presto.spi.security.ViewSecurity.INVOKER; import static com.facebook.presto.sql.MaterializedViewUtils.buildOwnerSession; import static com.facebook.presto.sql.MaterializedViewUtils.generateBaseTablePredicates; import static com.facebook.presto.sql.MaterializedViewUtils.generateFalsePredicates; @@ -229,6 +277,8 @@ import static com.facebook.presto.sql.NodeUtils.mapFromProperties; import static com.facebook.presto.sql.QueryUtil.selectList; import static com.facebook.presto.sql.QueryUtil.simpleQuery; +import static com.facebook.presto.sql.SqlFormatter.formatSql; +import static com.facebook.presto.sql.SqlFormatterUtil.getFormattedSql; import static com.facebook.presto.sql.analyzer.AggregationAnalyzer.verifyOrderByAggregations; import static 
com.facebook.presto.sql.analyzer.AggregationAnalyzer.verifySourceAggregations; import static com.facebook.presto.sql.analyzer.Analysis.MaterializedViewAnalysisState; @@ -249,6 +299,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_PARAMETER_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_PROPERTY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_FUNCTION_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_OFFSET_ROW_COUNT; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_ORDINAL; @@ -259,6 +310,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISMATCHED_COLUMN_ALIASES; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISMATCHED_SET_COLUMN_TYPES; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_ATTRIBUTE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_CATALOG; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_COLUMN; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_MATERIALIZED_VIEW; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_SCHEMA; @@ -269,7 +321,19 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NON_NUMERIC_SAMPLE_PERCENTAGE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.ORDER_BY_MUST_BE_IN_SELECT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.PROCEDURE_NOT_FOUND; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_ALREADY_EXISTS; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_AMBIGUOUS_RETURN_TYPE; +import 
static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_COLUMN_NOT_FOUND; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_DUPLICATE_RANGE_VARIABLE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_ARGUMENTS; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_COLUMN_REFERENCE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_COPARTITIONING; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_MISSING_ARGUMENT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_MISSING_RETURN_TYPE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TOO_MANY_GROUPING_SETS; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TYPE_MISMATCH; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.VIEW_ANALYSIS_ERROR; @@ -282,6 +346,7 @@ import static com.facebook.presto.sql.planner.ExpressionDeterminismEvaluator.isDeterministic; import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression; import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionOptimizer; +import static com.facebook.presto.sql.tree.DereferenceExpression.getQualifiedName; import static com.facebook.presto.sql.tree.ExplainFormat.Type.JSON; import static com.facebook.presto.sql.tree.ExplainFormat.Type.TEXT; import static com.facebook.presto.sql.tree.ExplainType.Type.DISTRIBUTED; @@ -301,10 +366,11 @@ import static com.google.common.base.Preconditions.checkArgument; import static 
com.google.common.base.Preconditions.checkState; import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.collect.Iterables.getLast; +import static com.google.common.collect.Iterables.getOnlyElement; import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Collections.emptyList; @@ -345,8 +411,6 @@ public StatementAnalyzer( this.metadataResolver = requireNonNull(metadata.getMetadataResolver(session), "metadataResolver is null"); requireNonNull(metadata.getFunctionAndTypeManager(), "functionAndTypeManager is null"); this.functionAndTypeResolver = requireNonNull(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(), "functionAndTypeResolver is null"); - - analysis.addQueryAccessControlInfo(new AccessControlInfo(accessControl, session.getIdentity(), session.getTransactionId(), session.getAccessControlContext())); } public Scope analyze(Node node, Scope outerQueryScope) @@ -356,7 +420,7 @@ public Scope analyze(Node node, Scope outerQueryScope) public Scope analyze(Node node, Optional outerQueryScope) { - return new Visitor(outerQueryScope, warningCollector).process(node, Optional.empty()); + return new Visitor(metadata, session, outerQueryScope, warningCollector).process(node, Optional.empty()); } /** @@ -367,11 +431,19 @@ public Scope analyze(Node node, Optional outerQueryScope) private class Visitor extends DefaultTraversalVisitor> { + private final Metadata metadata; + private final Session session; private final Optional outerQueryScope; private final WarningCollector warningCollector; - private Visitor(Optional outerQueryScope, WarningCollector warningCollector) + private Visitor( + Metadata metadata, + 
Session session, + Optional outerQueryScope, + WarningCollector warningCollector) { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.session = requireNonNull(session, "session is null"); this.outerQueryScope = requireNonNull(outerQueryScope, "outerQueryScope is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); } @@ -398,7 +470,7 @@ protected Scope visitUse(Use node, Optional scope) @Override protected Scope visitInsert(Insert insert, Optional scope) { - QualifiedObjectName targetTable = createQualifiedObjectName(session, insert, insert.getTarget()); + QualifiedObjectName targetTable = createQualifiedObjectName(session, insert, insert.getTarget(), metadata); MetadataHandle metadataHandle = analysis.getMetadataHandle(); if (getViewDefinition(session, metadataResolver, metadataHandle, targetTable).isPresent()) { @@ -412,7 +484,7 @@ protected Scope visitInsert(Insert insert, Optional scope) // analyze the query that creates the data Scope queryScope = process(insert.getQuery(), scope); - analysis.setUpdateType("INSERT"); + analysis.setUpdateInfo(insert.getUpdateInfo()); TableColumnMetadata tableColumnsMetadata = getTableColumnsMetadata(session, metadataResolver, metadataHandle, targetTable); // verify the insert destination columns match the query @@ -438,7 +510,7 @@ protected Scope visitInsert(Insert insert, Optional scope) if (insert.getColumns().isPresent()) { insertColumns = insert.getColumns().get().stream() .map(Identifier::getValue) - .map(column -> column.toLowerCase(ENGLISH)) + .map(column -> metadata.normalizeIdentifier(session, targetTable.getCatalogName(), column)) .collect(toImmutableList()); Set columnNames = new HashSet<>(); @@ -466,6 +538,20 @@ protected Scope visitInsert(Insert insert, Optional scope) tableColumnsMetadata.getTableHandle().get(), insertColumns.stream().map(columnHandles::get).collect(toImmutableList()))); + List types = 
queryScope.getRelationType().getVisibleFields().stream() + .map(Field::getType) + .collect(toImmutableList()); + + Stream columnStream = Streams.zip( + insertColumns.stream(), + types.stream() + .map(Type::toString), + Column::new); + + analysis.setUpdatedSourceColumns(Optional.of(Streams.zip( + columnStream, queryScope.getRelationType().getVisibleFields().stream(), (column, field) -> new OutputColumnMetadata(column.getName(), column.getType(), analysis.getSourceColumns(field))) + .collect(toImmutableList()))); + return createAndAssignScope(insert, scope, Field.newUnqualified(insert.getLocation(), "rows", BIGINT)); } @@ -588,7 +674,7 @@ private void checkTypesMatchForNestedStructs( protected Scope visitDelete(Delete node, Optional scope) { Table table = node.getTable(); - QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName(), metadata); MetadataHandle metadataHandle = analysis.getMetadataHandle(); if (getViewDefinition(session, metadataResolver, metadataHandle, tableName).isPresent()) { @@ -612,7 +698,7 @@ protected Scope visitDelete(Delete node, Optional scope) Scope tableScope = analyzer.analyze(table, scope); node.getWhere().ifPresent(where -> analyzeWhere(node, tableScope, where)); - analysis.setUpdateType("DELETE"); + analysis.setUpdateInfo(node.getUpdateInfo()); analysis.addAccessControlCheckForTable(TABLE_DELETE, new AccessControlInfoForTable(accessControl, session.getIdentity(), session.getTransactionId(), session.getAccessControlContext(), tableName)); @@ -632,8 +718,8 @@ protected Scope visitDelete(Delete node, Optional scope) @Override protected Scope visitAnalyze(Analyze node, Optional scope) { - analysis.setUpdateType("ANALYZE"); - QualifiedObjectName tableName = createQualifiedObjectName(session, node, node.getTableName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, node, node.getTableName(), 
metadata); + analysis.setUpdateInfo(node.getUpdateInfo()); MetadataHandle metadataHandle = analysis.getMetadataHandle(); // verify the target table exists, and it's not a view @@ -667,10 +753,10 @@ protected Scope visitAnalyze(Analyze node, Optional scope) @Override protected Scope visitCreateTableAsSelect(CreateTableAsSelect node, Optional scope) { - analysis.setUpdateType("CREATE TABLE"); + analysis.setUpdateInfo(node.getUpdateInfo()); // turn this into a query that has a new table writer node on top. - QualifiedObjectName targetTable = createQualifiedObjectName(session, node, node.getName()); + QualifiedObjectName targetTable = createQualifiedObjectName(session, node, node.getName(), metadata); analysis.setCreateTableDestination(targetTable); if (metadataResolver.tableExists(targetTable)) { @@ -694,29 +780,41 @@ protected Scope visitCreateTableAsSelect(CreateTableAsSelect node, Optional outputColumns = ImmutableList.builder(); + if (node.getColumnAliases().isPresent()) { validateColumnAliases(node.getColumnAliases().get(), queryScope.getRelationType().getVisibleFieldCount()); - + int aliasPosition = 0; // analyze only column types in subquery if column alias exists for (Field field : queryScope.getRelationType().getVisibleFields()) { if (field.getType().equals(UNKNOWN)) { throw new SemanticException(COLUMN_TYPE_UNKNOWN, node, "Column type is unknown at position %s", queryScope.getRelationType().indexOf(field) + 1); } + String columnName = node.getColumnAliases().get().get(aliasPosition).getValue(); + outputColumns.add(new OutputColumnMetadata(columnName, field.getType().toString(), analysis.getSourceColumns(field))); + aliasPosition++; } } else { validateColumns(node, queryScope.getRelationType()); + queryScope.getRelationType().getVisibleFields().stream() + .map(this::createOutputColumn) + .forEach(outputColumns::add); } - + analysis.setUpdatedSourceColumns(Optional.of(outputColumns.build())); return createAndAssignScope(node, scope, 
Field.newUnqualified(node.getLocation(), "rows", BIGINT)); } + private OutputColumnMetadata createOutputColumn(Field field) + { + return new OutputColumnMetadata(field.getName().get(), field.getType().toString(), analysis.getSourceColumns(field)); + } + @Override protected Scope visitCreateView(CreateView node, Optional scope) { - analysis.setUpdateType("CREATE VIEW"); - - QualifiedObjectName viewName = createQualifiedObjectName(session, node, node.getName()); + QualifiedObjectName viewName = createQualifiedObjectName(session, node, node.getName(), metadata); + analysis.setUpdateInfo(node.getUpdateInfo()); // analyze the query that creates the view StatementAnalyzer analyzer = new StatementAnalyzer(analysis, metadata, sqlParser, accessControl, session, warningCollector); @@ -734,9 +832,8 @@ protected Scope visitCreateView(CreateView node, Optional scope) @Override protected Scope visitCreateMaterializedView(CreateMaterializedView node, Optional scope) { - analysis.setUpdateType("CREATE MATERIALIZED VIEW"); - - QualifiedObjectName viewName = createQualifiedObjectName(session, node, node.getName()); + QualifiedObjectName viewName = createQualifiedObjectName(session, node, node.getName(), metadata); + analysis.setUpdateInfo(node.getUpdateInfo()); analysis.setCreateTableDestination(viewName); if (metadataResolver.tableExists(viewName)) { @@ -766,18 +863,29 @@ protected Scope visitCreateMaterializedView(CreateMaterializedView node, Optiona @Override protected Scope visitRefreshMaterializedView(RefreshMaterializedView node, Optional scope) { - analysis.setUpdateType("INSERT"); - - QualifiedObjectName viewName = createQualifiedObjectName(session, node.getTarget(), node.getTarget().getName()); + QualifiedObjectName viewName = createQualifiedObjectName(session, node.getTarget(), node.getTarget().getName(), metadata); + analysis.setUpdateInfo(node.getUpdateInfo()); MaterializedViewDefinition view = getMaterializedViewDefinition(session, metadataResolver, 
analysis.getMetadataHandle(), viewName) .orElseThrow(() -> new SemanticException(MISSING_MATERIALIZED_VIEW, node, "Materialized view '%s' does not exist", viewName)); // the original refresh statement will always be one line analysis.setExpandedQuery(format("-- Expanded Query: %s%nINSERT INTO %s %s", - SqlFormatterUtil.getFormattedSql(node, sqlParser, Optional.empty()), + getFormattedSql(node, sqlParser, Optional.empty()), viewName.getObjectName(), view.getOriginalSql())); + + if (!isLegacyMaterializedViews(session)) { + analysis.addAccessControlCheckForTable( + TABLE_DELETE, + new AccessControlInfoForTable( + accessControl, + getOwnerIdentity(view.getOwner(), session), + session.getTransactionId(), + session.getAccessControlContext(), + viewName)); + } + analysis.addAccessControlCheckForTable( TABLE_INSERT, new AccessControlInfoForTable( @@ -790,11 +898,13 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView node, Optio // Use AllowAllAccessControl; otherwise Analyzer will check SELECT permission on the materialized view, which is not necessary. StatementAnalyzer viewAnalyzer = new StatementAnalyzer(analysis, metadata, sqlParser, new AllowAllAccessControl(), session, warningCollector); Scope viewScope = viewAnalyzer.analyze(node.getTarget(), scope); - Map tablePredicates = extractTablePredicates(viewName, node.getWhere(), viewScope, metadata, session); + + Map tablePredicates = getTablePredicatesForMaterializedViewRefresh( + session, node, viewName, viewScope, metadata); Query viewQuery = parseView(view.getOriginalSql(), viewName, node); Query refreshQuery = tablePredicates.containsKey(toSchemaTableName(viewName)) ? 
- buildQueryWithPredicate(viewQuery, tablePredicates.get(toSchemaTableName(viewName))) + buildSubqueryWithPredicate(viewQuery, tablePredicates.get(toSchemaTableName(viewName))) : viewQuery; // Check if the owner has SELECT permission on the base tables StatementAnalyzer queryAnalyzer = new StatementAnalyzer( @@ -823,21 +933,46 @@ protected Scope visitRefreshMaterializedView(RefreshMaterializedView node, Optio return createAndAssignScope(node, scope, Field.newUnqualified(node.getLocation(), "rows", BIGINT)); } + private Map analyzeAutoRefreshMaterializedView( + RefreshMaterializedView node, + QualifiedObjectName viewName) + { + MaterializedViewStatus viewStatus = metadataResolver.getMaterializedViewStatus(viewName, TupleDomain.all()); + Map missingPartitionsPerTable = + viewStatus.getPartitionsFromBaseTables(); + + if (viewStatus.isFullyMaterialized() || missingPartitionsPerTable.isEmpty()) { + warningCollector.add(new PrestoWarning(SEMANTIC_WARNING, + format("Materialized view %s is already fully refreshed", viewName))); + return ImmutableMap.of(); + } + if ((viewStatus.isNotMaterialized() || viewStatus.isTooManyPartitionsMissing()) && + !SystemSessionProperties.isMaterializedViewAllowFullRefreshEnabled(session)) { + throw new PrestoException(MV_MISSING_TOO_MUCH_DATA, + format("%s misses too many partitions or is never refreshed and may incur high cost. 
" + + "Consider refreshing with predicates first.", viewName.toString())); + } + + return MaterializedViewUtils.generatePredicatesForMissingPartitions(missingPartitionsPerTable, metadata); + } + private Optional analyzeBaseTableForRefreshMaterializedView(Table baseTable, Optional scope) { checkState(analysis.getStatement() instanceof RefreshMaterializedView, "Not analyzing RefreshMaterializedView statement"); RefreshMaterializedView refreshMaterializedView = (RefreshMaterializedView) analysis.getStatement(); - QualifiedObjectName viewName = createQualifiedObjectName(session, refreshMaterializedView.getTarget(), refreshMaterializedView.getTarget().getName()); + QualifiedObjectName viewName = createQualifiedObjectName(session, refreshMaterializedView.getTarget(), refreshMaterializedView.getTarget().getName(), metadata); // Use AllowAllAccessControl; otherwise Analyzer will check SELECT permission on the materialized view, which is not necessary. StatementAnalyzer viewAnalyzer = new StatementAnalyzer(analysis, metadata, sqlParser, new AllowAllAccessControl(), session, warningCollector); Scope viewScope = viewAnalyzer.analyze(refreshMaterializedView.getTarget(), scope); - Map tablePredicates = extractTablePredicates(viewName, refreshMaterializedView.getWhere(), viewScope, metadata, session); - SchemaTableName baseTableName = toSchemaTableName(createQualifiedObjectName(session, baseTable, baseTable.getName())); + Map tablePredicates = getTablePredicatesForMaterializedViewRefresh( + session, refreshMaterializedView, viewName, viewScope, metadata); + + SchemaTableName baseTableName = toSchemaTableName(createQualifiedObjectName(session, baseTable, baseTable.getName(), metadata)); if (tablePredicates.containsKey(baseTableName)) { - Query tableSubquery = buildQueryWithPredicate(baseTable, tablePredicates.get(baseTableName)); + Query tableSubquery = buildTableQueryWithPredicate(baseTable, tablePredicates.get(baseTableName)); analysis.registerNamedQuery(baseTable, 
tableSubquery, true); Scope subqueryScope = process(tableSubquery, scope); @@ -848,23 +983,51 @@ private Optional analyzeBaseTableForRefreshMaterializedView(Table return Optional.empty(); } - private Query buildQueryWithPredicate(Table table, Expression predicate) + private Map getTablePredicatesForMaterializedViewRefresh( + Session session, + RefreshMaterializedView node, + QualifiedObjectName viewName, + Scope viewScope, + Metadata metadata) + { + if (isLegacyMaterializedViews(session)) { + // There are some duplicated logic for where.isPresent condition across existing/rfc approaches, but let's keep them separate + // for now so it's cleaner for the two paths + if (!node.getWhere().isPresent()) { + return analyzeAutoRefreshMaterializedView(node, viewName); + } + else { + return extractTablePredicates(viewName, node.getWhere().get(), viewScope, metadata, session); + } + } + else { + if (node.getWhere().isPresent()) { + throw new SemanticException(NOT_SUPPORTED, node, "WHERE clause in REFRESH MATERIALIZED VIEW is not supported. 
" + + "Connectors automatically determine which data needs refreshing based on staleness detection."); + } + return ImmutableMap.of(); + } + } + + private Query buildTableQueryWithPredicate(Table table, Expression predicate) { Query query = simpleQuery(selectList(new AllColumns()), table, predicate); - return (Query) sqlParser.createStatement( - SqlFormatterUtil.getFormattedSql(query, sqlParser, Optional.empty()), - createParsingOptions(session, warningCollector)); + String formattedSql = formatSql(query, Optional.empty()); + return (Query) sqlParser.createStatement(formattedSql, createParsingOptions(session, warningCollector)); } - private Query buildQueryWithPredicate(Query originalQuery, Expression predicate) + private Query buildSubqueryWithPredicate(Query originalQuery, Expression predicate) { - return simpleQuery(selectList(new AllColumns()), new TableSubquery(originalQuery), predicate); + Query query = simpleQuery(selectList(new AllColumns()), new TableSubquery(originalQuery), predicate); + return (Query) sqlParser.createStatement( + getFormattedSql(query, sqlParser, Optional.empty()), + createParsingOptions(session, warningCollector)); } @Override protected Scope visitCreateFunction(CreateFunction node, Optional scope) { - analysis.setUpdateType("CREATE FUNCTION"); + analysis.setUpdateInfo(node.getUpdateInfo()); // Check function name checkFunctionName(node, node.getFunctionName(), node.isTemporary()); @@ -941,12 +1104,14 @@ protected Scope visitResetSession(ResetSession node, Optional scope) @Override protected Scope visitAddColumn(AddColumn node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitCreateSchema(CreateSchema node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); validateProperties(node.getProperties(), scope); return createAndAssignScope(node, scope); } @@ -954,18 +1119,21 @@ protected Scope visitCreateSchema(CreateSchema node, Optional 
scope) @Override protected Scope visitDropSchema(DropSchema node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitRenameSchema(RenameSchema node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitCreateTable(CreateTable node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); validateProperties(node.getProperties(), scope); return createAndAssignScope(node, scope); } @@ -982,18 +1150,21 @@ protected Scope visitProperty(Property node, Optional scope) @Override protected Scope visitTruncateTable(TruncateTable node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitDropTable(DropTable node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitRenameTable(RenameTable node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @@ -1006,11 +1177,25 @@ protected Scope visitSetProperties(SetProperties node, Optional scope) @Override protected Scope visitRenameColumn(RenameColumn node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitDropColumn(DropColumn node, Optional scope) + { + analysis.setUpdateInfo(node.getUpdateInfo()); + return createAndAssignScope(node, scope); + } + + @Override + protected Scope visitDropBranch(DropBranch node, Optional scope) + { + return createAndAssignScope(node, scope); + } + + @Override + protected Scope visitDropTag(DropTag node, Optional scope) { return createAndAssignScope(node, scope); } @@ -1018,36 +1203,42 @@ protected Scope visitDropColumn(DropColumn node, Optional scope) @Override protected Scope 
visitDropConstraint(DropConstraint node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitAddConstraint(AddConstraint node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitAlterColumnNotNull(AlterColumnNotNull node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitRenameView(RenameView node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitDropView(DropView node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @Override protected Scope visitDropMaterializedView(DropMaterializedView node, Optional scope) { + analysis.setUpdateInfo(node.getUpdateInfo()); return createAndAssignScope(node, scope); } @@ -1100,9 +1291,78 @@ protected Scope visitRevoke(Revoke node, Optional scope) } @Override - protected Scope visitCall(Call node, Optional scope) + protected Scope visitCall(Call call, Optional scope) { - return createAndAssignScope(node, scope); + if (analysis.isDescribe()) { + return createAndAssignScope(call, scope); + } + Optional procedureNameOptional = analysis.getProcedureName(); + QualifiedObjectName procedureName; + if (!procedureNameOptional.isPresent()) { + procedureName = createQualifiedObjectName(session, call, call.getName(), metadata); + analysis.setProcedureName(Optional.of(procedureName)); + } + else { + procedureName = procedureNameOptional.get(); + } + ConnectorId connectorId = metadata.getCatalogHandle(session, procedureName.getCatalogName()) + .orElseThrow(() -> new SemanticException(MISSING_CATALOG, call, "Catalog %s does not exist", procedureName.getCatalogName())); + + if 
(!metadata.getProcedureRegistry().isDistributedProcedure(connectorId, toSchemaTableName(procedureName))) { + throw new SemanticException(PROCEDURE_NOT_FOUND, "Distributed procedure not registered: " + procedureName); + } + DistributedProcedure procedure = metadata.getProcedureRegistry().resolveDistributed(connectorId, toSchemaTableName(procedureName)); + Object[] values = extractParameterValuesInOrder(call, procedure, metadata, session, analysis.getParameters()); + accessControl.checkCanCallProcedure(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), procedureName); + + analysis.setUpdateInfo(call.getUpdateInfo()); + analysis.setDistributedProcedureType(Optional.of(procedure.getType())); + analysis.setProcedureArguments(Optional.of(values)); + switch (procedure.getType()) { + case TABLE_DATA_REWRITE: + TableDataRewriteDistributedProcedure tableDataRewriteDistributedProcedure = (TableDataRewriteDistributedProcedure) procedure; + QualifiedName qualifiedName = QualifiedName.of(tableDataRewriteDistributedProcedure.getSchema(values), tableDataRewriteDistributedProcedure.getTableName(values)); + QualifiedObjectName tableName = createQualifiedObjectName(session, call, qualifiedName, metadata); + + analysis.addAccessControlCheckForTable( + TABLE_INSERT, + new AccessControlInfoForTable( + accessControl, + session.getIdentity(), + session.getTransactionId(), + session.getAccessControlContext(), + tableName)); + analysis.addAccessControlCheckForTable( + TABLE_DELETE, + new AccessControlInfoForTable( + accessControl, + session.getIdentity(), + session.getTransactionId(), + session.getAccessControlContext(), + tableName)); + + String filter = tableDataRewriteDistributedProcedure.getFilter(values); + Expression filterExpression = sqlParser.createExpression(filter); + QuerySpecification querySpecification = new QuerySpecification( + selectList(new AllColumns()), + Optional.of(new Table(qualifiedName)), + Optional.of(filterExpression), + 
Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + analyze(querySpecification, scope); + analysis.setTargetQuery(querySpecification); + + TableHandle tableHandle = metadata.getHandleVersion(session, tableName, Optional.empty()) + .orElseThrow(() -> (new SemanticException(MISSING_TABLE, call, "Table '%s' does not exist", tableName))); + analysis.setCallTarget(tableHandle); + break; + default: + throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Unsupported distributed procedure type: " + procedure.getType()); + } + return createAndAssignScope(call, scope, Field.newUnqualified(Optional.empty(), "rows", BIGINT)); } private void validateProperties(List properties, Optional scope) @@ -1159,7 +1419,7 @@ private void validateColumnAliases(List columnAliases, int sourceCol private void validateBaseTables(List

baseTables, Node node) { for (Table baseTable : baseTables) { - QualifiedObjectName baseName = createQualifiedObjectName(session, baseTable, baseTable.getName()); + QualifiedObjectName baseName = createQualifiedObjectName(session, baseTable, baseTable.getName(), metadata); Optional optionalMaterializedView = getMaterializedViewDefinition(session, metadataResolver, analysis.getMetadataHandle(), baseName); if (optionalMaterializedView.isPresent()) { @@ -1181,7 +1441,7 @@ protected Scope visitExplain(Explain node, Optional scope) .filter(option -> option instanceof ExplainFormat) .map(ExplainFormat.class::cast) .map(ExplainFormat::getType) - .collect(Collectors.toList()); + .collect(toImmutableList()); checkState(formats.size() <= 1, "only a single format option is supported in EXPLAIN ANALYZE"); formats.stream().findFirst().ifPresent(format -> checkState(format.equals(TEXT) || format.equals(JSON), "only TEXT and JSON formats are supported in EXPLAIN ANALYZE")); @@ -1192,7 +1452,7 @@ protected Scope visitExplain(Explain node, Optional scope) .map(ExplainType::getType) .orElse(DISTRIBUTED).equals(DISTRIBUTED), "only DISTRIBUTED type is supported in EXPLAIN ANALYZE"); process(node.getStatement(), scope); - analysis.setUpdateType(null); + analysis.setUpdateInfo(null); return createAndAssignScope(node, scope, Field.newUnqualified(node.getLocation(), "Query Plan", VARCHAR)); } @@ -1256,7 +1516,7 @@ else if (expressionType instanceof MapType) { outputFields.add(Field.newUnqualified(expression.getLocation(), Optional.empty(), ((MapType) expressionType).getValueType())); } else { - throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Cannot unnest type: " + expressionType); + throw new PrestoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Cannot unnest type: " + expressionType); } } if (node.isWithOrdinality()) { @@ -1273,6 +1533,587 @@ protected Scope visitLateral(Lateral node, Optional scope) return createAndAssignScope(node, scope, queryScope.getRelationType()); } 
+ @Override + protected Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optional scope) + { + TableFunctionMetadata tableFunctionMetadata = metadata.getFunctionAndTypeManager() + .getTableFunctionRegistry() + .resolve(session, node.getName()) + .orElseThrow(() -> new SemanticException( + FUNCTION_NOT_FOUND, + node, + "Table function %s not registered", + node.getName())); + + ConnectorTableFunction function = tableFunctionMetadata.getFunction(); + ConnectorId connectorId = tableFunctionMetadata.getConnectorId(); + + ArgumentsAnalysis argumentsAnalysis = analyzeArguments(node, function.getArguments(), scope); + + TransactionManager transactionManager = metadata.getFunctionAndTypeManager().getTransactionManager(); + CatalogMetadata registrationCatalogMetadata = transactionManager.getOptionalCatalogMetadata(session.getRequiredTransactionId(), connectorId.getCatalogName()).orElseThrow(() -> new IllegalStateException("Missing catalog metadata")); + // a call to getRequiredCatalogHandle() is necessary so that the catalog is recorded by the TransactionManager + ConnectorTransactionHandle transactionHandle = transactionManager.getConnectorTransaction( + session.getRequiredTransactionId(), registrationCatalogMetadata.getConnectorId()); + + TableFunctionAnalysis functionAnalysis = function.analyze(session.toConnectorSession(connectorId), transactionHandle, argumentsAnalysis.getPassedArguments()); + List> copartitioningLists = analyzeCopartitioning(node.getCopartitioning(), argumentsAnalysis.getTableArgumentAnalyses()); + + // determine the result relation type per SQL standard ISO/IEC 9075-2, 4.33 SQL-invoked routines, p. 
123, 413, 414 + ReturnTypeSpecification returnTypeSpecification = function.getReturnTypeSpecification(); + if (returnTypeSpecification == GENERIC_TABLE || !argumentsAnalysis.getTableArgumentAnalyses().isEmpty()) { + analysis.addPolymorphicTableFunction(node); + } + Optional analyzedProperColumnsDescriptor = functionAnalysis.getReturnedType(); + Descriptor properColumnsDescriptor = verifyProperColumnsDescriptor(node, function, returnTypeSpecification, analyzedProperColumnsDescriptor); + + Map tableArgumentsByName = argumentsAnalysis.getTableArgumentAnalyses().stream() + .collect(toImmutableMap(TableArgumentAnalysis::getArgumentName, Function.identity())); + verifyRequiredColumns(node, functionAnalysis.getRequiredColumns(), tableArgumentsByName); + + // The result relation type of a table function consists of: + // 1. columns created by the table function, called the proper columns. + // 2. passed columns from input tables: + // - for tables with the "pass through columns" option, these are all columns of the table, + // - for tables without the "pass through columns" option, these are the partitioning columns of the table, if any. + ImmutableList.Builder fields = ImmutableList.builder(); + + // proper columns first + if (properColumnsDescriptor != null) { + properColumnsDescriptor.getFields().stream() + // per spec, field names are mandatory. We support anonymous fields. 
+ .map(field -> Field.newUnqualified(Optional.empty(), field.getName(), field.getType().orElseThrow(() -> new IllegalStateException("missing returned type for proper field")))) + .forEach(fields::add); + } + + // next, columns derived from table arguments, in order of argument declarations + List tableArgumentNames = function.getArguments().stream() + .filter(argumentSpecification -> argumentSpecification instanceof TableArgumentSpecification) + .map(ArgumentSpecification::getName) + .collect(toImmutableList()); + + // table arguments in order of argument declarations + ImmutableList.Builder orderedTableArguments = ImmutableList.builder(); + + for (String name : tableArgumentNames) { + TableArgumentAnalysis argument = tableArgumentsByName.get(name); + orderedTableArguments.add(argument); + Scope argumentScope = analysis.getScope(argument.getRelation()); + if (argument.isPassThroughColumns()) { + argumentScope.getRelationType().getAllFields().stream() + .forEach(fields::add); + } + else if (argument.getPartitionBy().isPresent()) { + argument.getPartitionBy().get().stream() + .map(expression -> validateAndGetInputField(expression, argumentScope)) + .forEach(fields::add); + } + } + + analysis.setTableFunctionAnalysis(node, new TableFunctionInvocationAnalysis( + connectorId, + function.getName(), + argumentsAnalysis.getPassedArguments(), + orderedTableArguments.build(), + functionAnalysis.getRequiredColumns(), + copartitioningLists, + properColumnsDescriptor == null ? 
0 : properColumnsDescriptor.getFields().size(), + functionAnalysis.getHandle(), + transactionHandle)); + + return createAndAssignScope(node, scope, fields.build()); + } + + private void verifyRequiredColumns(TableFunctionInvocation node, Map> requiredColumns, Map tableArgumentsByName) + { + Set allInputs = ImmutableSet.copyOf(tableArgumentsByName.keySet()); + requiredColumns.forEach((name, columns) -> { + if (!allInputs.contains(name)) { + throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "Table function %s specifies required columns from table argument %s which cannot be found", node.getName(), name); + } + if (columns.isEmpty()) { + throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "Table function %s specifies empty list of required columns from table argument %s", node.getName(), name); + } + // the scope is recorded, because table arguments are already analyzed + Scope inputScope = analysis.getScope(tableArgumentsByName.get(name).getRelation()); + columns.stream() + .filter(column -> column < 0 || column >= inputScope.getRelationType().getVisibleFieldCount()) + .findFirst() + .ifPresent(column -> { + throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "Invalid index: %s of required column from table argument %s", column, name); + }); + }); + Set requiredInputs = ImmutableSet.copyOf(requiredColumns.keySet()); + allInputs.stream() + .filter(input -> !requiredInputs.contains(input)) + .findFirst() + .ifPresent(input -> { + throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "Table function %s does not specify required input columns from table argument %s", node.getName(), input); + }); + } + + private Descriptor verifyProperColumnsDescriptor(TableFunctionInvocation node, ConnectorTableFunction function, ReturnTypeSpecification returnTypeSpecification, Optional analyzedProperColumnsDescriptor) + { + switch (returnTypeSpecification.getReturnType()) { + case ReturnTypeSpecification.OnlyPassThrough.returnType: + 
if (analysis.isAliased(node)) { + // According to SQL standard ISO/IEC 9075-2, 7.6
, p. 409, + // table alias is prohibited for a table function with ONLY PASS THROUGH returned type. + throw new SemanticException(TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION, node, "Alias specified for table function with ONLY PASS THROUGH return type"); + } + if (analyzedProperColumnsDescriptor.isPresent()) { + // If a table function has ONLY PASS THROUGH returned type, it does not produce any proper columns, + // so the function's analyze() method should not return the proper columns descriptor. + throw new SemanticException(TABLE_FUNCTION_AMBIGUOUS_RETURN_TYPE, node, "Returned relation type for table function %s is ambiguous", node.getName()); + } + if (function.getArguments().stream() + .filter(TableArgumentSpecification.class::isInstance) + .map(TableArgumentSpecification.class::cast) + .noneMatch(TableArgumentSpecification::isPassThroughColumns)) { + // According to SQL standard ISO/IEC 9075-2, 10.4 , p. 764, + // if there is no generic table parameter that specifies PASS THROUGH, then number of proper columns shall be positive. + // For GENERIC_TABLE and DescribedTable returned types, this is enforced by the Descriptor constructor, which requires positive number of fields. + // Here we enforce it for the remaining returned type specification: ONLY_PASS_THROUGH. + throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "A table function with ONLY_PASS_THROUGH return type must have a table argument with pass-through columns."); + } + return null; + case ReturnTypeSpecification.GenericTable.returnType: + // According to SQL standard ISO/IEC 9075-2, 7.6
, p. 409, + // table alias is mandatory for a polymorphic table function invocation which produces proper columns. + // We don't enforce this requirement. + return analyzedProperColumnsDescriptor + .orElseThrow(() -> new SemanticException(TABLE_FUNCTION_MISSING_RETURN_TYPE, node, "Cannot determine returned relation type for table function " + node.getName())); + default: + // returned type is statically declared at function declaration and cannot be overridden + // According to SQL standard ISO/IEC 9075-2, 7.6
, p. 409, + // table alias is mandatory for a polymorphic table function invocation which produces proper columns. + // We don't enforce this requirement. + if (analyzedProperColumnsDescriptor.isPresent()) { + // If a table function has statically declared returned type, it is returned in TableFunctionMetadata + // so the function's analyze() method should not return the proper columns descriptor. + throw new SemanticException(TABLE_FUNCTION_AMBIGUOUS_RETURN_TYPE, node, "Returned relation type for table function %s is ambiguous", node.getName()); + } + return ((ReturnTypeSpecification.DescribedTable) returnTypeSpecification).getDescriptor(); + } + } + + private ArgumentsAnalysis analyzeArguments(TableFunctionInvocation node, List argumentSpecifications, Optional scope) + { + List arguments = node.getArguments(); + Node errorLocation = node; + if (!arguments.isEmpty()) { + errorLocation = arguments.get(0); + } + if (argumentSpecifications.size() < arguments.size()) { + throw new SemanticException(TABLE_FUNCTION_INVALID_ARGUMENTS, errorLocation, "Too many arguments. 
Expected at most %s arguments, got %s arguments", argumentSpecifications.size(), arguments.size()); + } + + if (argumentSpecifications.isEmpty()) { + return new ArgumentsAnalysis(ImmutableMap.of(), ImmutableList.of()); + } + + boolean argumentsPassedByName = !arguments.isEmpty() && arguments.stream().allMatch(argument -> argument.getName().isPresent()); + boolean argumentsPassedByPosition = arguments.stream().allMatch(argument -> !argument.getName().isPresent()); + if (!argumentsPassedByName && !argumentsPassedByPosition) { + throw new SemanticException(TABLE_FUNCTION_INVALID_ARGUMENTS, errorLocation, "All arguments must be passed by name or all must be passed positionally"); + } + + if (argumentsPassedByName) { + return mapTableFunctionsArgsByName(argumentSpecifications, arguments, errorLocation, scope); + } + else { + return mapTableFunctionArgsByPosition(argumentSpecifications, arguments, errorLocation, scope); + } + } + + private ArgumentsAnalysis mapTableFunctionsArgsByName(List argumentSpecifications, List arguments, Node errorLocation, Optional scope) + { + ImmutableMap.Builder passedArguments = ImmutableMap.builder(); + ImmutableList.Builder tableArgumentAnalyses = ImmutableList.builder(); + Map argumentSpecificationsByName = new HashMap<>(); + for (ArgumentSpecification argumentSpecification : argumentSpecifications) { + if (argumentSpecificationsByName.put(argumentSpecification.getName(), argumentSpecification) != null) { + // this should never happen, because the argument names are validated at function registration time + throw new IllegalStateException("Duplicate argument specification for name: " + argumentSpecification.getName()); + } + } + Set uniqueArgumentNames = new HashSet<>(); + for (TableFunctionArgument argument : arguments) { + String argumentName = argument.getName().orElseThrow(() -> new IllegalStateException("Missing table function argument name")).getCanonicalValue(); + if (!uniqueArgumentNames.add(argumentName)) { + throw new 
SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "Duplicate argument name: %s", argumentName); + } + ArgumentSpecification argumentSpecification = argumentSpecificationsByName.remove(argumentName); + if (argumentSpecification == null) { + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "Unexpected argument name: %s", argumentName); + } + ArgumentAnalysis argumentAnalysis = analyzeArgument(argumentSpecification, argument, scope); + passedArguments.put(argumentSpecification.getName(), argumentAnalysis.getArgument()); + argumentAnalysis.getTableArgumentAnalysis().ifPresent(tableArgumentAnalyses::add); + } + // apply defaults for not specified arguments + for (Map.Entry entry : argumentSpecificationsByName.entrySet()) { + ArgumentSpecification argumentSpecification = entry.getValue(); + passedArguments.put(argumentSpecification.getName(), analyzeDefault(argumentSpecification, errorLocation)); + } + return new ArgumentsAnalysis(passedArguments.buildOrThrow(), tableArgumentAnalyses.build()); + } + + private ArgumentsAnalysis mapTableFunctionArgsByPosition(List argumentSpecifications, List arguments, Node errorLocation, Optional scope) + { + ImmutableMap.Builder passedArguments = ImmutableMap.builder(); + ImmutableList.Builder tableArgumentAnalyses = ImmutableList.builder(); + for (int i = 0; i < arguments.size(); i++) { + TableFunctionArgument argument = arguments.get(i); + ArgumentSpecification argumentSpecification = argumentSpecifications.get(i); // TODO args passed positionally - can one only pass some prefix of args? 
+ ArgumentAnalysis argumentAnalysis = analyzeArgument(argumentSpecification, argument, scope); + passedArguments.put(argumentSpecification.getName(), argumentAnalysis.getArgument()); + argumentAnalysis.getTableArgumentAnalysis().ifPresent(tableArgumentAnalyses::add); + } + // apply defaults for not specified arguments + for (int i = arguments.size(); i < argumentSpecifications.size(); i++) { + ArgumentSpecification argumentSpecification = argumentSpecifications.get(i); + passedArguments.put(argumentSpecification.getName(), analyzeDefault(argumentSpecification, errorLocation)); + } + return new ArgumentsAnalysis(passedArguments.buildOrThrow(), tableArgumentAnalyses.build()); + } + + private ArgumentAnalysis analyzeArgument(ArgumentSpecification argumentSpecification, TableFunctionArgument argument, Optional scope) + { + String actualType = getArgumentTypeString(argument); + switch (argumentSpecification.getArgumentType()) { + case TableArgumentSpecification.argumentType: + return analyzeTableArgument(argument, (TableArgumentSpecification) argumentSpecification, scope, actualType); + case DescriptorArgumentSpecification.argumentType: + return analyzeDescriptorArgument(argument, (DescriptorArgumentSpecification) argumentSpecification, actualType); + case ScalarArgumentSpecification.argumentType: + return analyzeScalarArgument(argument, argumentSpecification, actualType); + default: + throw new IllegalStateException("Unexpected argument specification: " + argumentSpecification.getClass().getSimpleName()); + } + } + + private Argument analyzeDefault(ArgumentSpecification argumentSpecification, Node errorLocation) + { + if (argumentSpecification.isRequired()) { + throw new SemanticException(TABLE_FUNCTION_MISSING_ARGUMENT, errorLocation, "Missing argument: " + argumentSpecification.getName()); + } + + checkArgument(!(argumentSpecification instanceof TableArgumentSpecification), "invalid table argument specification: default set"); + + if (argumentSpecification instanceof 
DescriptorArgumentSpecification) { + return DescriptorArgument.builder() + .descriptor((Descriptor) argumentSpecification.getDefaultValue()) + .build(); + } + if (argumentSpecification instanceof ScalarArgumentSpecification) { + return ScalarArgument.builder() + .type(((ScalarArgumentSpecification) argumentSpecification).getType()) + .value(argumentSpecification.getDefaultValue()) + .build(); + } + + throw new IllegalStateException("Unexpected argument specification: " + argumentSpecification.getClass().getSimpleName()); + } + + private String getArgumentTypeString(TableFunctionArgument argument) + { + try { + return argument.getValue().getArgumentTypeString(); + } + catch (IllegalArgumentException e) { + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "Unexpected table function argument type: ", argument.getClass().getSimpleName()); + } + } + + private ArgumentAnalysis analyzeScalarArgument(TableFunctionArgument argument, ArgumentSpecification argumentSpecification, String actualType) + { + Type type = ((ScalarArgumentSpecification) argumentSpecification).getType(); + if (!(argument.getValue() instanceof Expression)) { + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. 
Expected expression, got %s", argumentSpecification.getName(), actualType); + } + Expression expression = (Expression) argument.getValue(); + // 'descriptor' as a function name is not allowed in this context + if (expression instanceof FunctionCall && ((FunctionCall) expression).getName().hasSuffix(QualifiedName.of("descriptor"))) { // function name is always compared case-insensitive + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "'descriptor' function is not allowed as a table function argument"); + } + // inline parameters + Expression inlined = ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() + { + @Override + public Expression rewriteParameter(Parameter node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (analysis.isDescribe()) { + // We cannot handle DESCRIBE when a table function argument involves a parameter. + // In DESCRIBE, the parameter values are not known. We cannot pass a dummy value for a parameter. + // The value of a table function argument can affect the returned relation type. The returned + // relation type can affect the assumed types for other parameters in the query. 
+ throw new SemanticException(NOT_SUPPORTED, node, "DESCRIBE is not supported if a table function uses parameters"); + } + return analysis.getParameters().get(NodeRef.of(node)); + } + }, expression); + // currently, only constant arguments are supported + Object constantValue = ExpressionInterpreter.evaluateConstantExpression(inlined, type, metadata, session, analysis.getParameters()); + return new ArgumentAnalysis( + ScalarArgument.builder() + .type(type) + .value(constantValue) + .build(), + Optional.empty()); + } + + private ArgumentAnalysis analyzeTableArgument(TableFunctionArgument argument, TableArgumentSpecification argumentSpecification, Optional scope, String actualType) + { + if (!(argument.getValue() instanceof TableFunctionTableArgument)) { + if (argument.getValue() instanceof FunctionCall) { + // probably an attempt to pass a table function call, which is not supported, and was parsed as a function call + throw new SemanticException(NOT_SUPPORTED, argument, "Invalid table argument %s. Table functions are not allowed as table function arguments", argumentSpecification.getName()); + } + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. 
Expected table, got %s", argumentSpecification.getName(), actualType); + } + TableFunctionTableArgument tableArgument = (TableFunctionTableArgument) argument.getValue(); + + TableArgument.Builder argumentBuilder = TableArgument.builder(); + TableArgumentAnalysis.Builder analysisBuilder = TableArgumentAnalysis.builder(); + analysisBuilder.withArgumentName(argumentSpecification.getName()); + + // process the relation + Relation relation = tableArgument.getTable(); + analysisBuilder.withRelation(relation); + Scope argumentScope = process(relation, scope); + QualifiedName relationName = analysis.getRelationName(relation); + if (relationName != null) { + analysisBuilder.withName(relationName); + } + + argumentBuilder.rowType(RowType.from(argumentScope.getRelationType().getVisibleFields().stream() + .map(field -> new RowType.Field(field.getName(), field.getType())) + .collect(toImmutableList()))); + + // analyze PARTITION BY + if (tableArgument.getPartitionBy().isPresent()) { + if (argumentSpecification.isRowSemantics()) { + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. 
Partitioning specified for table argument with row semantics", argumentSpecification.getName()); + } + List partitionBy = tableArgument.getPartitionBy().get(); + analysisBuilder.withPartitionBy(partitionBy); + partitionBy.stream() + .forEach(partitioningColumn -> { + validateAndGetInputField(partitioningColumn, argumentScope); + Type type = analyzeExpression(partitioningColumn, argumentScope).getType(partitioningColumn); + if (!type.isComparable()) { + throw new SemanticException(TYPE_MISMATCH, partitioningColumn, "%s is not comparable, and therefore cannot be used in PARTITION BY", type); + } + }); + argumentBuilder.partitionBy(partitionBy.stream() + // each expression is either an Identifier or a DereferenceExpression + .map(Expression::toString) + .collect(toImmutableList())); + } + + // analyze ORDER BY + if (tableArgument.getOrderBy().isPresent()) { + if (argumentSpecification.isRowSemantics()) { + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. 
Ordering specified for table argument with row semantics", argumentSpecification.getName()); + } + OrderBy orderBy = tableArgument.getOrderBy().get(); + analysisBuilder.withOrderBy(orderBy); + orderBy.getSortItems().stream() + .map(SortItem::getSortKey) + .forEach(orderingColumn -> { + validateAndGetInputField(orderingColumn, argumentScope); + Type type = analyzeExpression(orderingColumn, argumentScope).getType(orderingColumn); + if (!type.isOrderable()) { + throw new SemanticException(TYPE_MISMATCH, orderingColumn, "%s is not orderable, and therefore cannot be used in ORDER BY", type); + } + }); + argumentBuilder.orderBy(orderBy.getSortItems().stream() + // each sort key is either an Identifier or a DereferenceExpression + .map(sortItem -> sortItem.getSortKey().toString()) + .collect(toImmutableList())); + } + + // analyze the PRUNE/KEEP WHEN EMPTY property + boolean pruneWhenEmpty = argumentSpecification.isPruneWhenEmpty(); + if (tableArgument.getEmptyTableTreatment().isPresent()) { + if (argumentSpecification.isRowSemantics()) { + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, tableArgument.getEmptyTableTreatment().get(), "Invalid argument %s. 
Empty behavior specified for table argument with row semantics", argumentSpecification.getName()); + } + pruneWhenEmpty = tableArgument.getEmptyTableTreatment().get().getTreatment() == EmptyTableTreatment.Treatment.PRUNE; + } + analysisBuilder.withPruneWhenEmpty(pruneWhenEmpty); + + // record remaining properties + analysisBuilder.withRowSemantics(argumentSpecification.isRowSemantics()); + analysisBuilder.withPassThroughColumns(argumentSpecification.isPassThroughColumns()); + + return new ArgumentAnalysis(argumentBuilder.build(), Optional.of(analysisBuilder.build())); + } + + private ArgumentAnalysis analyzeDescriptorArgument(TableFunctionArgument argument, DescriptorArgumentSpecification argumentSpecification, String actualType) + { + if (!(argument.getValue() instanceof TableFunctionDescriptorArgument)) { + if (argument.getValue() instanceof FunctionCall && ((FunctionCall) argument.getValue()).getName().hasSuffix(QualifiedName.of("descriptor"))) { // function name is always compared case-insensitive + // malformed descriptor which parsed as a function call + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "Invalid descriptor argument %s. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'", (Object) argumentSpecification.getName()); + } + throw new SemanticException(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, argument, "Invalid argument %s. 
Expected descriptor, got %s", argumentSpecification.getName(), actualType); + } + return new ArgumentAnalysis( + ((TableFunctionDescriptorArgument) argument.getValue()).getDescriptor() + .map(descriptor -> DescriptorArgument.builder() + .descriptor(new Descriptor(descriptor.getFields().stream() + .map(field -> new Descriptor.Field( + field.getName().getCanonicalValue(), + field.getType().map(type -> { + try { + return functionAndTypeResolver.getType(parseTypeSignature(type)); + } + catch (IllegalArgumentException | UnknownTypeException e) { + throw new SemanticException(TYPE_MISMATCH, field, "Unknown type: %s", type); + } + }))) + .collect(toImmutableList()))) + .build()) + .orElse(NULL_DESCRIPTOR), + Optional.empty()); + } + + private Field validateAndGetInputField(Expression expression, Scope inputScope) + { + QualifiedName qualifiedName; + if (expression instanceof Identifier) { + qualifiedName = QualifiedName.of(ImmutableList.of(((Identifier) expression))); + } + else if (expression instanceof DereferenceExpression) { + qualifiedName = getQualifiedName((DereferenceExpression) expression); + } + else { + throw new SemanticException(TABLE_FUNCTION_INVALID_COLUMN_REFERENCE, expression, "Expected column reference. Actual: %s", expression); + } + Optional field = inputScope.tryResolveField(expression, qualifiedName); + if (!field.isPresent() || !field.get().isLocal()) { + throw new SemanticException(TABLE_FUNCTION_COLUMN_NOT_FOUND, expression, "Column %s is not present in the input relation", expression); + } + + return field.get().getField(); + } + + private List> analyzeCopartitioning(List> copartitioning, List tableArgumentAnalyses) + { + // map table arguments by relation names. usa a multimap, because multiple arguments can have the same value, e.g. 
input_1 => tpch.tiny.orders, input_2 => tpch.tiny.orders + ImmutableMultimap.Builder unqualifiedInputsBuilder = ImmutableMultimap.builder(); + ImmutableMultimap.Builder qualifiedInputsBuilder = ImmutableMultimap.builder(); + tableArgumentAnalyses.stream() + .filter(argument -> argument.getName().isPresent()) + .forEach(argument -> { + QualifiedName name = argument.getName().get(); + if (name.getParts().size() == 1) { + unqualifiedInputsBuilder.put(name, argument); + } + else if (name.getParts().size() == 3) { + qualifiedInputsBuilder.put(name, argument); + } + else { + throw new IllegalStateException("relation name should be unqualified or fully qualified"); + } + }); + Multimap unqualifiedInputs = unqualifiedInputsBuilder.build(); + Multimap qualifiedInputs = qualifiedInputsBuilder.build(); + + ImmutableList.Builder> copartitionBuilder = ImmutableList.builder(); + Set referencedArguments = new HashSet<>(); + for (List nameList : copartitioning) { + ImmutableList.Builder copartitionListBuilder = ImmutableList.builder(); + + // resolve copartition tables as references to table arguments + for (QualifiedName name : nameList) { + Collection candidates = emptyList(); + if (name.getParts().size() == 1) { + // try to match unqualified name. 
it might be a reference to a CTE or an aliased relation + candidates = unqualifiedInputs.get(name); + } + if (candidates.isEmpty()) { + // qualify the name using current schema and catalog + // Since we lost the Identifier context, create a new one here + QualifiedObjectName fullyQualifiedName = createQualifiedObjectName(session, new Identifier(name.getOriginalParts().get(0).getValue()), name, metadata); + candidates = qualifiedInputs.get(QualifiedName.of(fullyQualifiedName.getCatalogName(), fullyQualifiedName.getSchemaName(), fullyQualifiedName.getObjectName())); + } + if (candidates.isEmpty()) { + throw new SemanticException(TABLE_FUNCTION_INVALID_COPARTITIONING, name.getOriginalParts().get(0), "No table argument found for name: " + name); + } + if (candidates.size() > 1) { + throw new SemanticException(TABLE_FUNCTION_INVALID_COPARTITIONING, name.getOriginalParts().get(0), "Ambiguous reference: multiple table arguments found for name: " + name); + } + TableArgumentAnalysis argument = getOnlyElement(candidates); + if (!referencedArguments.add(argument.getArgumentName())) { + // multiple references to argument in COPARTITION clause are implicitly prohibited by + // ISO/IEC TR REPORT 19075-7, p.33, Feature B203, “More than one copartition specification” + throw new SemanticException(TABLE_FUNCTION_INVALID_COPARTITIONING, name.getOriginalParts().get(0), "Multiple references to table argument: %s in COPARTITION clause", name); + } + copartitionListBuilder.add(argument); + } + List copartitionList = copartitionListBuilder.build(); + + // analyze partitioning columns + copartitionList.stream() + .filter(argument -> !argument.getPartitionBy().isPresent()) + .findFirst().ifPresent(unpartitioned -> { + throw new SemanticException(TABLE_FUNCTION_INVALID_COPARTITIONING, unpartitioned.getRelation(), "Table %s referenced in COPARTITION clause is not partitioned", unpartitioned.getName().orElseThrow(() -> new IllegalStateException("Missing unpartitioned TableArgumentAnalysis 
name"))); + }); + // TODO(#26147): make sure that copartitioned tables cannot have empty partitioning lists. + // ISO/IEC TR REPORT 19075-7, 4.5 Partitioning and ordering, p.25 is not clear: "With copartitioning, the copartitioned table arguments must have the same number of partitioning columns, + // and corresponding partitioning columns must be comparable. The DBMS effectively performs a full outer equijoin on the copartitioning columns" + copartitionList.stream() + .filter(argument -> argument.getPartitionBy().orElseThrow(() -> new IllegalStateException("PartitionBy not present in copartitionList")).isEmpty()) + .findFirst().ifPresent(partitionedOnEmpty -> { + // table is partitioned but no partitioning columns are specified (single partition) + throw new SemanticException(TABLE_FUNCTION_INVALID_COPARTITIONING, partitionedOnEmpty.getRelation(), "No partitioning columns specified for table %s referenced in COPARTITION clause", partitionedOnEmpty.getName().orElseThrow(() -> new IllegalStateException("Missing partitionedOnEmpty TableArgumentAnalysis name"))); + }); + List> partitioningColumns = copartitionList.stream() + .map(TableArgumentAnalysis::getPartitionBy) + .map(opt -> opt.orElseThrow(() -> new IllegalStateException("PartitionBy not present in partitioningColumns"))) + .collect(toImmutableList()); + if (partitioningColumns.stream() + .map(List::size) + .distinct() + .count() > 1) { + throw new SemanticException(TABLE_FUNCTION_INVALID_COPARTITIONING, nameList.get(0).getOriginalParts().get(0), "Numbers of partitioning columns in copartitioned tables do not match"); + } + + // coerce corresponding copartition columns to common supertype + for (int index = 0; index < partitioningColumns.get(0).size(); index++) { + Type commonSuperType = analysis.getType(partitioningColumns.get(0).get(index)); + // find common supertype + for (List columnList : partitioningColumns) { + Optional superType = functionAndTypeResolver.getCommonSuperType(commonSuperType, 
analysis.getType(columnList.get(index))); + if (!superType.isPresent()) { + throw new SemanticException(TYPE_MISMATCH, nameList.get(0).getOriginalParts().get(0), "Partitioning columns in copartitioned tables have incompatible types"); + } + commonSuperType = superType.get(); + } + for (List columnList : partitioningColumns) { + Expression column = columnList.get(index); + Type type = analysis.getType(column); + if (!type.equals(commonSuperType)) { + if (!functionAndTypeResolver.canCoerce(type, commonSuperType)) { + throw new SemanticException(TYPE_MISMATCH, column, "Cannot coerce column of type %s to common supertype: %s", type.getDisplayName(), commonSuperType.getDisplayName()); + } + analysis.addCoercion(column, commonSuperType, functionAndTypeResolver.isTypeOnlyCoercion(type, commonSuperType)); + } + } + } + + // record the resolved copartition arguments by argument names + copartitionBuilder.add(copartitionList.stream() + .map(TableArgumentAnalysis::getArgumentName) + .collect(toImmutableList())); + } + + return copartitionBuilder.build(); + } + @Override protected Scope visitTable(Table table, Optional scope) { @@ -1284,6 +2125,7 @@ protected Scope visitTable(Table table, Optional scope) if (withQuery.isPresent()) { Query query = withQuery.get().getQuery(); analysis.registerNamedQuery(table, query, false); + analysis.setRelationName(table, table.getName()); // re-alias the fields with the name assigned to the query in the WITH declaration RelationType queryDescriptor = analysis.getOutputDescriptor(query); @@ -1297,7 +2139,7 @@ protected Scope visitTable(Table table, Optional scope) Iterator visibleFieldsIterator = queryDescriptor.getVisibleFields().iterator(); for (Identifier columnName : columnNames.get()) { Field inputField = visibleFieldsIterator.next(); - fieldBuilder.add(Field.newQualified( + Field field = Field.newQualified( columnName.getLocation(), QualifiedName.of(name), Optional.of(columnName.getValue()), @@ -1305,30 +2147,37 @@ protected Scope 
visitTable(Table table, Optional scope) false, inputField.getOriginTable(), inputField.getOriginColumnName(), - inputField.isAliased())); + inputField.isAliased()); + fieldBuilder.add(field); + analysis.addSourceColumns(field, analysis.getSourceColumns(inputField)); } fields = fieldBuilder.build(); } else { - fields = queryDescriptor.getAllFields().stream() - .map(field -> Field.newQualified( - field.getNodeLocation(), - QualifiedName.of(name), - field.getName(), - field.getType(), - field.isHidden(), - field.getOriginTable(), - field.getOriginColumnName(), - field.isAliased())) - .collect(toImmutableList()); + ImmutableList.Builder fieldBuilder = ImmutableList.builder(); + for (Field inputField : queryDescriptor.getAllFields()) { + Field field = Field.newQualified( + inputField.getNodeLocation(), + QualifiedName.of(name), + inputField.getName(), + inputField.getType(), + inputField.isHidden(), + inputField.getOriginTable(), + inputField.getOriginColumnName(), + inputField.isAliased()); + fieldBuilder.add(field); + analysis.addSourceColumns(field, analysis.getSourceColumns(inputField)); + } + fields = fieldBuilder.build(); } return createAndAssignScope(table, scope, fields); } } - QualifiedObjectName name = createQualifiedObjectName(session, table, table.getName()); + QualifiedObjectName name = createQualifiedObjectName(session, table, table.getName(), metadata); + analysis.setRelationName(table, QualifiedName.of(name.getCatalogName(), name.getSchemaName(), name.getObjectName())); if (name.getObjectName().isEmpty()) { throw new SemanticException(MISSING_TABLE, table, "Table name is empty"); } @@ -1354,15 +2203,30 @@ protected Scope visitTable(Table table, Optional scope) optionalMaterializedView.get().getTable()); } Statement statement = analysis.getStatement(); - if (isMaterializedViewDataConsistencyEnabled(session) && optionalMaterializedView.isPresent() && statement instanceof Query) { - // When the materialized view has already been expanded, do not process it. 
Just use it as a table. - MaterializedViewAnalysisState materializedViewAnalysisState = analysis.getMaterializedViewAnalysisState(table); + if (optionalMaterializedView.isPresent() && statement instanceof Query) { + if (isMaterializedViewDataConsistencyEnabled(session) || !isLegacyMaterializedViews(session)) { + // When the materialized view has already been expanded, do not process it. Just use it as a table. + MaterializedViewAnalysisState materializedViewAnalysisState = analysis.getMaterializedViewAnalysisState(table); - if (materializedViewAnalysisState.isNotVisited()) { - return processMaterializedView(table, name, scope, optionalMaterializedView.get()); + if (materializedViewAnalysisState.isNotVisited()) { + return processMaterializedView(table, name, scope, optionalMaterializedView.get()); + } + if (materializedViewAnalysisState.isVisited()) { + throw new SemanticException(MATERIALIZED_VIEW_IS_RECURSIVE, table, "Materialized view is recursive"); + } } - if (materializedViewAnalysisState.isVisited()) { - throw new SemanticException(MATERIALIZED_VIEW_IS_RECURSIVE, table, "Materialized view is recursive"); + else { + // when stitching is not enabled, still check permission of each base table + MaterializedViewDefinition materializedViewDefinition = optionalMaterializedView.get(); + analysis.getViewDefinitionReferences().addMaterializedViewDefinitionReference(name, materializedViewDefinition); + + Query viewQuery = (Query) sqlParser.createStatement( + materializedViewDefinition.getOriginalSql(), + createParsingOptions(session, warningCollector)); + + analysis.registerMaterializedViewForAnalysis(name, table, materializedViewDefinition.getOriginalSql()); + process(viewQuery, scope); + analysis.unregisterMaterializedViewForAnalysis(table); } } @@ -1389,6 +2253,17 @@ protected Scope visitTable(Table table, Optional scope) ColumnHandle columnHandle = columnHandles.get(column.getName()); checkArgument(columnHandle != null, "Unknown field %s", field); 
analysis.setColumn(field, columnHandle); + analysis.addSourceColumns(field, ImmutableSet.of(new SourceColumn(name, column.getName()))); + } + + boolean isMergeIntoStatement = statement instanceof Merge && ((Merge) statement).getTargetTable().equals(table); + if (isMergeIntoStatement) { + // Add the target table row id field used to process the MERGE command. + ColumnHandle targetTableRowIdColumnHandle = metadata.getMergeTargetTableRowIdColumnHandle(session, tableHandle.get()); + Type targetTableRowIdType = metadata.getColumnMetadata(session, tableHandle.get(), targetTableRowIdColumnHandle).getType(); + Field targetTableRowIdField = Field.newUnqualified(table.getLocation(), "$target_table_row_id", targetTableRowIdType); + fields.add(targetTableRowIdField); + analysis.setColumn(targetTableRowIdField, targetTableRowIdColumnHandle); } analysis.registerTable(table, tableHandle.get()); @@ -1415,7 +2290,16 @@ protected Scope visitTable(Table table, Optional scope) } } - return createAndAssignScope(table, scope, outputFields); + Scope tableScope = createAndAssignScope(table, scope, outputFields); + + if (isMergeIntoStatement) { + // Set the target table row id field reference used to process the MERGE command. + FieldReference targetTableRowIdFieldReference = new FieldReference(outputFields.size() - 1); + analyzeExpression(targetTableRowIdFieldReference, tableScope); + analysis.setRowIdField(table, targetTableRowIdFieldReference); + } + + return tableScope; } private Optional getTableHandle(TableColumnMetadata tableColumnsMetadata, Table table, QualifiedObjectName name, Optional scope) @@ -1450,6 +2334,7 @@ private VersionType toVersionType(TableVersionType type) } throw new SemanticException(NOT_SUPPORTED, "Table version type %s not supported." 
+ type); } + private Optional processTableVersion(Table table, QualifiedObjectName name, Optional scope) { Expression stateExpr = table.getTableVersionExpression().get().getStateExpression(); @@ -1459,7 +2344,7 @@ private Optional processTableVersion(Table table, QualifiedObjectNa analysis.recordSubqueries(table, expressionAnalysis); Type stateExprType = expressionAnalysis.getType(stateExpr); if (stateExprType == UNKNOWN) { - throw new PrestoException(INVALID_ARGUMENTS, format("Table version AS OF/BEFORE expression cannot be NULL for %s", name.toString())); + throw new PrestoException(StandardErrorCode.INVALID_ARGUMENTS, format("Table version AS OF/BEFORE expression cannot be NULL for %s", name.toString())); } Object evalStateExpr = evaluateConstantExpression(stateExpr, stateExprType, metadata, session, analysis.getParameters()); if (tableVersionType == TIMESTAMP) { @@ -1483,7 +2368,7 @@ private Optional processTableVersion(Table table, QualifiedObjectNa private Scope getScopeFromTable(Table table, Optional scope) { - QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName(), metadata); TableColumnMetadata tableColumnsMetadata = getTableColumnsMetadata(session, metadataResolver, analysis.getMetadataHandle(), tableName); // TODO: discover columns lazily based on where they are needed (to support connectors that can't enumerate all tables) @@ -1512,7 +2397,7 @@ private Scope processView(Table table, Optional scope, QualifiedObjectNam Statement statement = analysis.getStatement(); if (statement instanceof CreateView) { CreateView viewStatement = (CreateView) statement; - QualifiedObjectName viewNameFromStatement = createQualifiedObjectName(session, viewStatement, viewStatement.getName()); + QualifiedObjectName viewNameFromStatement = createQualifiedObjectName(session, viewStatement, viewStatement.getName(), metadata); if 
(viewStatement.isReplace() && viewNameFromStatement.equals(name)) { throw new SemanticException(VIEW_IS_RECURSIVE, table, "Statement would create a recursive view"); } @@ -1522,7 +2407,11 @@ private Scope processView(Table table, Optional scope, QualifiedObjectNam } ViewDefinition view = optionalView.get(); - analysis.getAccessControlReferences().addViewDefinitionReference(name, view); + analysis.getViewDefinitionReferences().addViewDefinitionReference(name, view); + + Optional savedViewAccessorWhereClause = analysis.getCurrentQuerySpecification() + .flatMap(QuerySpecification::getWhere); + savedViewAccessorWhereClause.ifPresent(analysis::setViewAccessorWhereClause); Query query = parseView(view.getOriginalSql(), name, table); @@ -1530,6 +2419,11 @@ private Scope processView(Table table, Optional scope, QualifiedObjectNam analysis.registerTableForView(table); RelationType descriptor = analyzeView(query, name, view.getCatalog(), view.getSchema(), view.getOwner(), table); analysis.unregisterTableForView(); + + if (savedViewAccessorWhereClause.isPresent()) { + analysis.clearViewAccessorWhereClause(); + } + if (isViewStale(view.getColumns(), descriptor.getVisibleFields())) { throw new SemanticException(VIEW_IS_STALE, table, "View '%s' is stale; it must be re-created", name); } @@ -1567,24 +2461,97 @@ private Scope processMaterializedView( { MaterializedViewPlanValidator.validate((Query) sqlParser.createStatement(materializedViewDefinition.getOriginalSql(), createParsingOptions(session, warningCollector))); - analysis.getAccessControlReferences().addMaterializedViewDefinitionReference(materializedViewName, materializedViewDefinition); + analysis.getViewDefinitionReferences().addMaterializedViewDefinitionReference(materializedViewName, materializedViewDefinition); analysis.registerMaterializedViewForAnalysis(materializedViewName, materializedView, materializedViewDefinition.getOriginalSql()); - String newSql = getMaterializedViewSQL(materializedView, 
materializedViewName, materializedViewDefinition, scope); - Query query = (Query) sqlParser.createStatement(newSql, createParsingOptions(session, warningCollector)); - analysis.registerNamedQuery(materializedView, query, true); + if (isLegacyMaterializedViews(session)) { + // Legacy SQL stitching approach: create UNION query with base tables + String newSql = getMaterializedViewSQL(materializedView, materializedViewName, materializedViewDefinition, scope); - Scope queryScope = process(query, scope); - RelationType relationType = queryScope.getRelationType().withAlias(materializedViewName.getObjectName(), null); - analysis.unregisterMaterializedViewForAnalysis(materializedView); + Query query = (Query) sqlParser.createStatement(newSql, createParsingOptions(session, warningCollector)); + analysis.registerNamedQuery(materializedView, query, true); - Scope accessControlScope = Scope.builder() - .withRelationType(RelationId.anonymous(), relationType) - .build(); - analyzeFiltersAndMasks(materializedView, materializedViewName, accessControlScope, relationType.getAllFields()); + Scope queryScope = process(query, scope); + RelationType relationType = queryScope.getRelationType().withAlias(materializedViewName.getObjectName(), null); + analysis.unregisterMaterializedViewForAnalysis(materializedView); + + Scope accessControlScope = Scope.builder() + .withRelationType(RelationId.anonymous(), relationType) + .build(); + analyzeFiltersAndMasks(materializedView, materializedViewName, accessControlScope, relationType.getAllFields()); + + return createAndAssignScope(materializedView, scope, relationType); + } + else { + Query viewQuery = (Query) sqlParser.createStatement( + materializedViewDefinition.getOriginalSql(), + createParsingOptions(session, warningCollector)); + + QualifiedName dataTableName = QualifiedName.of( + materializedViewName.getCatalogName(), + materializedViewDefinition.getSchema(), + materializedViewDefinition.getTable()); + Table dataTable = new 
Table(dataTableName); + + Analysis.MaterializedViewInfo mvInfo = new Analysis.MaterializedViewInfo( + materializedViewName, + dataTable, + viewQuery, + materializedViewDefinition); + analysis.setMaterializedViewInfo(materializedView, mvInfo); + + // Legacy materialized views are treated as INVOKER rights + ViewSecurity securityMode = materializedViewDefinition.getSecurityMode().orElse(INVOKER); + + Identity queryIdentity; + AccessControl queryAccessControl; + if (securityMode == DEFINER) { + Optional owner = materializedViewDefinition.getOwner(); + if (!owner.isPresent()) { + throw new SemanticException(NOT_SUPPORTED, "Owner must be present for DEFINER security mode"); + } + queryIdentity = new Identity(owner.get(), Optional.empty(), emptyMap(), session.getIdentity().getExtraCredentials(), emptyMap(), Optional.empty(), session.getIdentity().getReasonForSelect(), emptyList()); + // Use ViewAccessControl when the session user is not the owner, matching regular view behavior. + // This checks CREATE_VIEW_WITH_SELECT_COLUMNS permissions to prevent privilege escalation + // where a user with only SELECT could grant access to others via a DEFINER MV. 
+ if (!owner.get().equals(session.getIdentity().getUser())) { + queryAccessControl = new ViewAccessControl(accessControl); + } + else { + queryAccessControl = accessControl; + } + } + else { + queryIdentity = session.getIdentity(); + queryAccessControl = accessControl; + } + + Session materializedViewSession = createViewSession( + Optional.of(materializedViewName.getCatalogName()), + Optional.of(materializedViewDefinition.getSchema()), + queryIdentity); + + StatementAnalyzer materializedViewAnalyzer = new StatementAnalyzer( + analysis, + metadata, + sqlParser, + queryAccessControl, + materializedViewSession, + warningCollector); + materializedViewAnalyzer.analyze(viewQuery, scope); - return createAndAssignScope(materializedView, scope, relationType); + Scope queryScope = process(dataTable, scope); + RelationType relationType = queryScope.getRelationType().withOnlyVisibleFields().withAlias(materializedViewName.getObjectName(), null); + analysis.unregisterMaterializedViewForAnalysis(materializedView); + + Scope accessControlScope = Scope.builder() + .withRelationType(RelationId.anonymous(), relationType) + .build(); + analyzeFiltersAndMasks(materializedView, materializedViewName, accessControlScope, relationType.getAllFields()); + + return createAndAssignScope(materializedView, scope, relationType); + } } private String getMaterializedViewSQL( @@ -1614,7 +2581,7 @@ else if (materializedViewStatus.isPartiallyMaterialized()) { baseTablePredicates = generateBaseTablePredicates(materializedViewStatus.getPartitionsFromBaseTables(), metadata); } - Query predicateStitchedQuery = (Query) new PredicateStitcher(session, baseTablePredicates).process(createSqlStatement, new PredicateStitcherContext()); + Query predicateStitchedQuery = (Query) new PredicateStitcher(session, baseTablePredicates, metadata).process(createSqlStatement, new PredicateStitcherContext()); // TODO: consider materialized view predicates https://github.com/prestodb/presto/issues/16034 QuerySpecification 
materializedViewQuerySpecification = new QuerySpecification( @@ -1633,7 +2600,7 @@ else if (materializedViewStatus.isPartiallyMaterialized()) { Query unionQuery = new Query(predicateStitchedQuery.getWith(), union, predicateStitchedQuery.getOrderBy(), predicateStitchedQuery.getOffset(), predicateStitchedQuery.getLimit()); // can we return the above query object, instead of building a query string? // in case of returning the query object, make sure to clone the original query object. - return SqlFormatterUtil.getFormattedSql(unionQuery, sqlParser, Optional.empty()); + return getFormattedSql(unionQuery, sqlParser, Optional.empty()); } /** @@ -1651,7 +2618,12 @@ private MaterializedViewStatus getMaterializedViewStatus(QualifiedObjectName mat checkArgument(analysis.getCurrentQuerySpecification().isPresent(), "Current subquery should be set when processing materialized view"); QuerySpecification currentSubquery = analysis.getCurrentQuerySpecification().get(); - if (currentSubquery.getWhere().isPresent() && isMaterializedViewPartitionFilteringEnabled(session)) { + // Collect where clause from both current subquery and possible logical view + List wherePredicates = new ArrayList<>(); + currentSubquery.getWhere().ifPresent(wherePredicates::add); + analysis.getViewAccessorWhereClause().ifPresent(wherePredicates::add); + + if (!wherePredicates.isEmpty() && isMaterializedViewPartitionFilteringEnabled(session)) { Optional materializedViewDefinition = getMaterializedViewDefinition(session, metadataResolver, analysis.getMetadataHandle(), materializedViewName); if (!materializedViewDefinition.isPresent()) { log.warn("Materialized view definition not present as expected when fetching materialized view status"); @@ -1659,39 +2631,69 @@ private MaterializedViewStatus getMaterializedViewStatus(QualifiedObjectName mat } Scope sourceScope = getScopeFromTable(table, scope); - Expression viewQueryWhereClause = currentSubquery.getWhere().get(); + Expression combinedWhereClause = 
ExpressionUtils.combineConjuncts(wherePredicates); + + // Extract column names from materialized view scope + Set materializedViewColumns = sourceScope.getRelationType().getAllFields().stream() + .map(field -> field.getName()) + .filter(Optional::isPresent) + .map(Optional::get) + .map(QualifiedName::of) + .collect(Collectors.toSet()); - analyzeWhere(currentSubquery, sourceScope, viewQueryWhereClause); + // Only proceed with partition filtering if there are conjuncts that reference MV columns + List conjuncts = ExpressionUtils.extractConjuncts(combinedWhereClause); + List mvConjuncts = conjuncts.stream() + .filter(conjunct -> { + Set referencedColumns = VariablesExtractor.extractNames(conjunct, analysis.getColumnReferences()); + return !referencedColumns.isEmpty() && referencedColumns.stream().allMatch(materializedViewColumns::contains); + }) + .collect(Collectors.toList()); - DomainTranslator domainTranslator = new RowExpressionDomainTranslator(metadata); - RowExpression rowExpression = SqlToRowExpressionTranslator.translate( - viewQueryWhereClause, - analysis.getTypes(), - ImmutableMap.of(), - metadata.getFunctionAndTypeManager(), - session); + if (!mvConjuncts.isEmpty()) { + Expression filteredWhereClause = ExpressionUtils.combineConjuncts(mvConjuncts); - TupleDomain viewQueryDomain = MaterializedViewUtils.getDomainFromFilter(session, domainTranslator, rowExpression); + // Analyze the filtered WHERE clause only for type inference, don't record it in analysis + // to avoid preventing the full WHERE clause from being analyzed later + ExpressionAnalysis expressionAnalysis = analyzeExpression(filteredWhereClause, sourceScope); - Map> directColumnMappings = materializedViewDefinition.get().getDirectColumnMappingsAsMap(); + DomainTranslator domainTranslator = new RowExpressionDomainTranslator(metadata); + RowExpression rowExpression = SqlToRowExpressionTranslator.translate( + filteredWhereClause, + analysis.getTypes(), + ImmutableMap.of(), + 
metadata.getFunctionAndTypeManager(), + session); - // Get base query domain we have mapped from view query- if there are not direct mappings, don't filter partition count for predicate - boolean mappedToOneTable = true; - Map rewrittenDomain = new HashMap<>(); + TupleDomain viewQueryDomain = MaterializedViewUtils.getDomainFromFilter(session, domainTranslator, rowExpression); - for (Map.Entry entry : viewQueryDomain.getDomains().orElse(ImmutableMap.of()).entrySet()) { - Map baseTableMapping = directColumnMappings.get(entry.getKey()); - if (baseTableMapping == null || baseTableMapping.size() != 1) { - mappedToOneTable = false; - break; - } + Map> directColumnMappings = materializedViewDefinition.get().getDirectColumnMappingsAsMap(); - String baseColumnName = baseTableMapping.entrySet().stream().findAny().get().getValue(); - rewrittenDomain.put(baseColumnName, entry.getValue()); - } + // Get base query domain we have mapped from view query- if there are not direct mappings, don't filter partition count for predicate + boolean mappedToOneTable = true; + Map rewrittenDomain = new HashMap<>(); + + for (Map.Entry entry : viewQueryDomain.getDomains().orElse(ImmutableMap.of()).entrySet()) { + Map baseTableMapping = null; + for (String columnName : directColumnMappings.keySet()) { + if (columnName.equalsIgnoreCase(entry.getKey())) { + baseTableMapping = directColumnMappings.get(columnName); + break; + } + } + + if (baseTableMapping == null || baseTableMapping.size() != 1) { + mappedToOneTable = false; + break; + } + + String baseColumnName = baseTableMapping.entrySet().stream().findAny().get().getValue(); + rewrittenDomain.put(baseColumnName, entry.getValue()); + } - if (mappedToOneTable) { - baseQueryDomain = TupleDomain.withColumnDomains(rewrittenDomain); + if (mappedToOneTable) { + baseQueryDomain = TupleDomain.withColumnDomains(rewrittenDomain); + } } } @@ -1701,10 +2703,19 @@ private MaterializedViewStatus getMaterializedViewStatus(QualifiedObjectName mat @Override 
protected Scope visitAliasedRelation(AliasedRelation relation, Optional scope) { + analysis.setRelationName(relation, QualifiedName.of(relation.getAlias().getValue())); + analysis.addAliased(relation.getRelation()); Scope relationScope = process(relation.getRelation(), scope); - // todo this check should be inside of TupleDescriptor.withAlias, but the exception needs the node object RelationType relationType = relationScope.getRelationType(); + + // special-handle table function invocation + if (relation.getRelation() instanceof TableFunctionInvocation) { + return createAndAssignScope(relation, scope, + aliasTableFunctionInvocation(relation, relationType, (TableFunctionInvocation) relation.getRelation())); + } + + // todo this check should be inside of TupleDescriptor.withAlias, but the exception needs the node object if (relation.getColumnNames() != null) { int totalColumns = relationType.getVisibleFieldCount(); if (totalColumns != relation.getColumnNames().size()) { @@ -1713,17 +2724,107 @@ protected Scope visitAliasedRelation(AliasedRelation relation, Optional s } List aliases = null; + Collection inputFields = relationType.getAllFields(); if (relation.getColumnNames() != null) { aliases = relation.getColumnNames().stream() .map(Identifier::getValue) - .collect(Collectors.toList()); + .collect(toImmutableList()); + inputFields = relationType.getVisibleFields(); } RelationType descriptor = relationType.withAlias(relation.getAlias().getValue(), aliases); + checkArgument(inputFields.size() == descriptor.getAllFieldCount(), + "Expected %s fields, got %s", + descriptor.getAllFieldCount(), + inputFields.size()); + + Streams.forEachPair( + descriptor.getAllFields().stream(), + inputFields.stream(), + (newField, field) -> analysis.addSourceColumns(newField, analysis.getSourceColumns(field))); return createAndAssignScope(relation, scope, descriptor); } + // As described by the SQL standard ISO/IEC 9075-2, 7.6
, p. 409 + private RelationType aliasTableFunctionInvocation(AliasedRelation relation, RelationType relationType, TableFunctionInvocation function) + { + TableFunctionInvocationAnalysis tableFunctionAnalysis = analysis.getTableFunctionAnalysis(function); + int properColumnsCount = tableFunctionAnalysis.getProperColumnsCount(); + + // check that relation alias is different from range variables of all table arguments + tableFunctionAnalysis.getTableArgumentAnalyses().stream() + .map(TableArgumentAnalysis::getName) + .filter(Optional::isPresent) + .map(Optional::get) + .filter(name -> name.hasSuffix(QualifiedName.of(ImmutableList.of(relation.getAlias())))) + .findFirst() + .ifPresent(name -> { + throw new SemanticException(TABLE_FUNCTION_DUPLICATE_RANGE_VARIABLE, relation.getAlias(), "Relation alias: %s is a duplicate of input table name: %s", relation.getAlias(), name); + }); + + // build the new relation type. the alias must be applied to the proper columns only, + // and it must not shadow the range variables exposed by the table arguments + ImmutableList.Builder fieldsBuilder = ImmutableList.builder(); + // first, put the table function's proper columns with alias + if (relation.getColumnNames() != null) { + // check that number of column aliases matches number of table function's proper columns + if (properColumnsCount != relation.getColumnNames().size()) { + throw new SemanticException(MISMATCHED_COLUMN_ALIASES, relation, "Column alias list has %s entries but table function has %s proper columns", relation.getColumnNames().size(), properColumnsCount); + } + for (int i = 0; i < properColumnsCount; i++) { + // proper columns are not hidden, so we don't need to skip hidden fields + Field field = relationType.getFieldByIndex(i); + fieldsBuilder.add(Field.newQualified( + field.getNodeLocation(), + QualifiedName.of(ImmutableList.of(relation.getAlias())), + Optional.of(relation.getColumnNames().get(i).getCanonicalValue()), // although the canonical name is recorded, 
fields are resolved case-insensitive + field.getType(), + field.isHidden(), + field.getOriginTable(), + field.getOriginColumnName(), + field.isAliased())); + } + } + else { + for (int i = 0; i < properColumnsCount; i++) { + Field field = relationType.getFieldByIndex(i); + fieldsBuilder.add(Field.newQualified( + field.getNodeLocation(), + QualifiedName.of(ImmutableList.of(relation.getAlias())), + field.getName(), + field.getType(), + field.isHidden(), + field.getOriginTable(), + field.getOriginColumnName(), + field.isAliased())); + } + } + + // append remaining fields. They are not being aliased, so hidden fields are included + for (int i = properColumnsCount; i < relationType.getAllFieldCount(); i++) { + fieldsBuilder.add(relationType.getFieldByIndex(i)); + } + + List fields = fieldsBuilder.build(); + + // check that there are no duplicate names within the table function's proper columns + Set names = new HashSet<>(); + fields.subList(0, properColumnsCount).stream() + .map(Field::getName) + .filter(Optional::isPresent) + .map(Optional::get) + // field names are resolved case-insensitive + .map(name -> name.toLowerCase(ENGLISH)) + .forEach(name -> { + if (!names.add(name)) { + throw new SemanticException(DUPLICATE_COLUMN_NAME, relation.getRelation(), "Duplicate name of table function proper column: " + name); + } + }); + + return new RelationType(fields); + } + @Override protected Scope visitSampledRelation(SampledRelation relation, Optional scope) { @@ -1767,9 +2868,33 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional s analysis.setSampleRatio(relation, samplePercentageValue / 100); Scope relationScope = process(relation.getRelation(), scope); + + // TABLESAMPLE cannot be applied to a polymorphic table function (SQL standard ISO/IEC 9075-2, 7.6
, p. 409) + // Note: the below method finds a table function immediately nested in SampledRelation, or aliased. + // Potentially, a table function could be also nested with intervening PatternRecognitionRelation. + // Such case is handled in visitPatternRecognitionRelation(). + validateNoNestedTableFunction(relation.getRelation(), "sample"); + return createAndAssignScope(relation, scope, relationScope.getRelationType()); } + // this method should run after the `base` relation is processed, so that it is + // determined whether the table function is polymorphic + private void validateNoNestedTableFunction(Relation base, String context) + { + TableFunctionInvocation tableFunctionInvocation = null; + if (base instanceof TableFunctionInvocation) { + tableFunctionInvocation = (TableFunctionInvocation) base; + } + else if (base instanceof AliasedRelation && + ((AliasedRelation) base).getRelation() instanceof TableFunctionInvocation) { + tableFunctionInvocation = (TableFunctionInvocation) ((AliasedRelation) base).getRelation(); + } + if (tableFunctionInvocation != null && analysis.isPolymorphicTableFunction(tableFunctionInvocation)) { + throw new SemanticException(TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION, base, "Cannot apply %s to polymorphic table function invocation", context); + } + } + @Override protected Scope visitTableSubquery(TableSubquery node, Optional scope) { @@ -1919,6 +3044,14 @@ protected Scope visitSetOperation(SetOperation node, Optional scope) oldField.getOriginTable(), oldField.getOriginColumnName(), oldField.isAliased()); + + int index = i; + analysis.addSourceColumns( + outputDescriptorFields[index], + relationScopes.stream() + .map(relationType -> relationType.getRelationType().getFieldByIndex(index)) + .flatMap(field -> analysis.getSourceColumns(field).stream()) + .collect(toImmutableSet())); } for (int i = 0; i < node.getRelations().size(); i++) { @@ -2093,7 +3226,7 @@ private String createWarningMessage(Node node, String description) 
protected Scope visitUpdate(Update update, Optional scope) { Table table = update.getTable(); - QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName()); + QualifiedObjectName tableName = createQualifiedObjectName(session, table, table.getName(), metadata); MetadataHandle metadataHandle = analysis.getMetadataHandle(); if (getViewDefinition(session, metadataResolver, metadataHandle, tableName).isPresent()) { @@ -2112,21 +3245,21 @@ protected Scope visitUpdate(Update update, Optional scope) .collect(toImmutableMap(ColumnMetadata::getName, Function.identity())); for (UpdateAssignment assignment : update.getAssignments()) { - String columnName = assignment.getName().getValue(); + String columnName = metadata.normalizeIdentifier(session, tableName.getCatalogName(), assignment.getName().getValue()); if (!columns.containsKey(columnName)) { throw new SemanticException(MISSING_COLUMN, assignment.getName(), "The UPDATE SET target column %s doesn't exist", columnName); } } Set assignmentTargets = update.getAssignments().stream() - .map(assignment -> assignment.getName().getValue()) + .map(assignment -> metadata.normalizeIdentifier(session, tableName.getCatalogName(), assignment.getName().getValue())) .collect(toImmutableSet()); accessControl.checkCanUpdateTableColumns(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), tableName, assignmentTargets); List updatedColumns = allColumns.stream() .filter(column -> assignmentTargets.contains(column.getName())) .collect(toImmutableList()); - analysis.setUpdateType("UPDATE"); + analysis.setUpdateInfo(update.getUpdateInfo()); analysis.setUpdatedColumns(updatedColumns); // Analyzer checks for select permissions but UPDATE has a separate permission, so disable access checks @@ -2153,7 +3286,7 @@ protected Scope visitUpdate(Update update, Optional scope) List expressionTypes = expressionTypesBuilder.build(); List tableTypes = update.getAssignments().stream() - 
.map(assignment -> requireNonNull(columns.get(assignment.getName().getValue()))) + .map(assignment -> requireNonNull(columns.get(metadata.normalizeIdentifier(session, tableName.getCatalogName(), assignment.getName().getValue())))) .map(ColumnMetadata::getType) .collect(toImmutableList()); @@ -2170,6 +3303,248 @@ protected Scope visitUpdate(Update update, Optional scope) return createAndAssignScope(update, scope, Field.newUnqualified(update.getLocation(), "rows", BIGINT)); } + @Override + protected Scope visitMerge(Merge merge, Optional scope) + { + Relation targetRelation = merge.getTarget(); + Table targetTable = getMergeTargetTable(targetRelation); + QualifiedObjectName targetTableQualifiedName = createQualifiedObjectName(session, targetTable, targetTable.getName(), metadata); + MetadataHandle metadataHandle = analysis.getMetadataHandle(); + + if (getViewDefinition(session, metadataResolver, metadataHandle, targetTableQualifiedName).isPresent()) { + throw new SemanticException(NOT_SUPPORTED, merge, "Merging into views is not supported"); + } + + if (getMaterializedViewDefinition(session, metadataResolver, metadataHandle, targetTableQualifiedName).isPresent()) { + throw new SemanticException(NOT_SUPPORTED, merge, "Merging into materialized views is not supported"); + } + + TableColumnMetadata targetTableColumnsMetadata = getTableColumnsMetadata(session, metadataResolver, metadataHandle, targetTableQualifiedName); + + TableHandle targetTableHandle = targetTableColumnsMetadata.getTableHandle() + .orElseThrow(() -> new SemanticException(MISSING_TABLE, targetTable, "Table '%s' does not exist", targetTableQualifiedName)); + + // The analyzer checks for select permissions, but the MERGE INTO statement has different permissions, so disable access checks. 
+ StatementAnalyzer statementAnalyzer = new StatementAnalyzer(analysis, metadata, sqlParser, + new AllowAllAccessControl(), session, warningCollector); + + Scope targetTableScope = statementAnalyzer.analyze(targetRelation, scope); + Scope sourceTableScope = process(merge.getSource(), scope); + Scope joinScope = createAndAssignScope(merge, scope, targetTableScope.getRelationType().joinWith(sourceTableScope.getRelationType())); + + List targetColumnsMetadata = targetTableColumnsMetadata.getColumnsMetadata().stream() + .filter(column -> !column.isHidden()) + .collect(toImmutableList()); + + Map targetAllColumnHandles = metadata.getColumnHandles(session, targetTableHandle); + ImmutableList.Builder targetColumnHandlesBuilder = ImmutableList.builder(); + ImmutableSet.Builder targetColumnNamesBuilder = ImmutableSet.builder(); + for (ColumnMetadata columnMetadata : targetColumnsMetadata) { + String targetColumnName = columnMetadata.getName(); + ColumnHandle targetColumnHandle = targetAllColumnHandles.get(targetColumnName); + targetColumnHandlesBuilder.add(targetColumnHandle); + targetColumnNamesBuilder.add(targetColumnName); + } + List targetColumnHandles = targetColumnHandlesBuilder.build(); + Set targetColumnNames = targetColumnNamesBuilder.build(); + + Map targetColumnTypes = targetColumnsMetadata.stream().collect(toImmutableMap(ColumnMetadata::getName, ColumnMetadata::getType)); + + // Analyze all expressions in the Merge node + + Expression mergePredicate = merge.getPredicate(); + ExpressionAnalysis mergePredicateAnalysis = analyzeExpression(mergePredicate, joinScope); + Type mergePredicateType = mergePredicateAnalysis.getType(mergePredicate); + if (!mergePredicateType.equals(BOOLEAN)) { + if (!mergePredicateType.equals(UNKNOWN)) { + throw new SemanticException(TYPE_MISMATCH, mergePredicate, "The MERGE predicate must evaluate to a boolean: actual type %s", mergePredicateType); + } + // coerce null to boolean + analysis.addCoercion(mergePredicate, BOOLEAN, false); + } 
+ analysis.recordSubqueries(merge, mergePredicateAnalysis); + + Set allUpdateColumnNames = new HashSet<>(); + + for (int caseCounter = 0; caseCounter < merge.getMergeCases().size(); caseCounter++) { + MergeCase mergeCase = merge.getMergeCases().get(caseCounter); + List setColumnNames = lowercaseIdentifierList(mergeCase.getSetColumns()); + if (mergeCase instanceof MergeUpdate) { + allUpdateColumnNames.addAll(setColumnNames); + } + else if (mergeCase instanceof MergeInsert && setColumnNames.isEmpty()) { + setColumnNames = targetColumnsMetadata.stream().map(ColumnMetadata::getName).collect(toImmutableList()); + } + int mergeCaseSetColumnCount = setColumnNames.size(); + List mergeCaseSetExpressions = mergeCase.getSetExpressions(); + checkArgument( + mergeCaseSetColumnCount == mergeCaseSetExpressions.size(), + "Number of merge columns (%s) isn't equal to number of expressions (%s)", + mergeCaseSetColumnCount, mergeCaseSetExpressions.size()); + Set mergeCaseColumnNameSet = new HashSet<>(mergeCaseSetColumnCount); + // Look for missing or duplicate column names. + setColumnNames.forEach(mergeCaseColumnName -> { + if (!targetColumnNames.contains(mergeCaseColumnName)) { + throw new SemanticException(MISSING_COLUMN, merge, "Merge column name does not exist in target table: %s", mergeCaseColumnName); + } + if (!mergeCaseColumnNameSet.add(mergeCaseColumnName)) { + throw new SemanticException(DUPLICATE_COLUMN_NAME, merge, "Merge column name is specified more than once: %s", mergeCaseColumnName); + } + }); + + // Collects types for columns and expressions in this MergeCase. 
+ ImmutableList.Builder setColumnTypesBuilder = ImmutableList.builder(); + ImmutableList.Builder setExpressionTypesBuilder = ImmutableList.builder(); + for (int index = 0; index < setColumnNames.size(); index++) { + String columnName = setColumnNames.get(index); + Expression setExpression = mergeCaseSetExpressions.get(index); + ExpressionAnalysis setExpressionAnalysis = analyzeExpression(setExpression, joinScope); + analysis.recordSubqueries(merge, setExpressionAnalysis); + Type setColumnType = requireNonNull(targetColumnTypes.get(columnName)); + setColumnTypesBuilder.add(setColumnType); + setExpressionTypesBuilder.add(setExpressionAnalysis.getType(setExpression)); + } + List setColumnTypes = setColumnTypesBuilder.build(); + List setExpressionTypes = setExpressionTypesBuilder.build(); + + // Check if the types of the columns and expressions match for the MERGE SET clause. + if (!checkTypesMatchForMergeSet(setColumnTypes, setExpressionTypes)) { + throw new SemanticException(TYPE_MISMATCH, + mergeCase, + "MERGE table column types don't match for MERGE case %s, SET expressions: Table: [%s], Expressions: [%s]", + caseCounter, + Joiner.on(", ").join(setColumnTypes), + Joiner.on(", ").join(setExpressionTypes)); + } + + // Add coercion if the target column type and set expression type do not match. 
+ for (int index = 0; index < setColumnNames.size(); index++) { + Expression setExpression = mergeCase.getSetExpressions().get(index); + Type targetColumnType = targetColumnTypes.get(setColumnNames.get(index)); + Type setExpressionType = setExpressionTypes.get(index); + if (!targetColumnType.equals(setExpressionType)) { + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + analysis.addCoercion(setExpression, targetColumnType, functionAndTypeManager.isTypeOnlyCoercion(setExpressionType, targetColumnType)); + } + } + } + + // Check if the user has permission to insert into the target table + merge.getMergeCases().stream() + .filter(mergeCase -> mergeCase instanceof MergeInsert) + .findFirst() + .ifPresent(mergeCase -> accessControl.checkCanInsertIntoTable(session.getRequiredTransactionId(), + session.getIdentity(), session.getAccessControlContext(), targetTableQualifiedName)); + + // If there are any columns to update then verify the user has permission to update these columns. + if (!allUpdateColumnNames.isEmpty()) { + accessControl.checkCanUpdateTableColumns(session.getRequiredTransactionId(), session.getIdentity(), + session.getAccessControlContext(), targetTableQualifiedName, allUpdateColumnNames); + } + + analysis.setUpdateInfo(merge.getUpdateInfo()); + + List> mergeCaseColumnHandles = buildMergeCaseColumnLists(merge, targetColumnsMetadata, targetAllColumnHandles); + + ImmutableMap.Builder columnHandleFieldNumbersBuilder = ImmutableMap.builder(); + Map fieldIndexes = new HashMap<>(); + RelationType targetRelationType = targetTableScope.getRelationType(); + for (Field targetField : targetRelationType.getAllFields()) { + targetField.getName() + .filter(targetFieldName -> !"$target_table_row_id".equals(targetFieldName)) // Skip "$target_table_row_id" column. 
+ .ifPresent(targetFieldName -> { + int targetFieldIndex = targetRelationType.indexOf(targetField); + ColumnHandle targetColumnHandle = targetAllColumnHandles.get(targetFieldName); + verify(targetColumnHandle != null, "targetAllColumnHandles does not contain the named handle: %s", targetFieldName); + columnHandleFieldNumbersBuilder.put(targetColumnHandle, targetFieldIndex); + fieldIndexes.put(targetFieldName, targetFieldIndex); + }); + } + Map columnHandleFieldNumbers = columnHandleFieldNumbersBuilder.buildOrThrow(); + + Set nonNullableColumnHandles = metadata.getTableMetadata(session, targetTableHandle).getColumns().stream() + .filter(column -> !column.isNullable()) + .map(ColumnMetadata::getName) + .map(targetAllColumnHandles::get) + .collect(toImmutableSet()); + + analysis.setMergeAnalysis(new MergeAnalysis( + targetTable, + targetColumnsMetadata, + targetColumnHandles, + mergeCaseColumnHandles, + nonNullableColumnHandles, + columnHandleFieldNumbers, + targetTableScope, + joinScope)); + + return createAndAssignScope(merge, Optional.empty(), Field.newUnqualified(merge.getLocation(), "rows", BIGINT)); + } + + private boolean checkTypesMatchForMergeSet(Iterable tableTypes, Iterable queryTypes) + { + if (Iterables.size(tableTypes) != Iterables.size(queryTypes)) { + return false; + } + + Iterator tableTypesIterator = tableTypes.iterator(); + Iterator queryTypesIterator = queryTypes.iterator(); + while (tableTypesIterator.hasNext()) { + Type tableType = tableTypesIterator.next(); + Type queryType = queryTypesIterator.next(); + + if (!metadata.getFunctionAndTypeManager().canCoerce(queryType, tableType)) { + return false; + } + } + + return true; + } + + private Table getMergeTargetTable(Relation relation) + { + if (relation instanceof Table) { + return (Table) relation; + } + checkArgument(relation instanceof AliasedRelation, "relation is neither a Table nor an AliasedRelation"); + return (Table) ((AliasedRelation) relation).getRelation(); + } + + /** + * Builds a list 
of column handles for each merge case in the given merge statement. + * + * @param merge the merge statement + * @param columnSchemas the list of column metadata for the target table. + * @param allColumnHandles a map of column names to column handles for the target table. + * @return a list of lists of column handles, where each inner list corresponds to a merge case. + */ + private List> buildMergeCaseColumnLists(Merge merge, List columnSchemas, Map allColumnHandles) + { + ImmutableList.Builder> mergeCaseColumnsListsBuilder = ImmutableList.builder(); + for (int caseCounter = 0; caseCounter < merge.getMergeCases().size(); caseCounter++) { + MergeCase mergeCase = merge.getMergeCases().get(caseCounter); + List mergeColumnNames; + if (mergeCase instanceof MergeInsert && mergeCase.getSetColumns().isEmpty()) { + mergeColumnNames = columnSchemas.stream().map(ColumnMetadata::getName).collect(toImmutableList()); + } + else { + mergeColumnNames = lowercaseIdentifierList(mergeCase.getSetColumns()); + } + mergeCaseColumnsListsBuilder.add( + mergeColumnNames.stream() + .map(name -> requireNonNull(allColumnHandles.get(name), "No column found for name")) + .collect(toImmutableList())); + } + return mergeCaseColumnsListsBuilder.build(); + } + + private List lowercaseIdentifierList(Collection identifiers) + { + return identifiers.stream() + .map(identifier -> identifier.getValue().toLowerCase(ENGLISH)) + .collect(toImmutableList()); + } + private Scope analyzeJoinUsing(Join node, List columns, Optional scope, Scope left, Scope right) { List joinFields = new ArrayList<>(); @@ -2479,7 +3854,7 @@ private void checkFunctionName(Statement node, QualifiedName functionName, boole .map(SqlFunction::getSignature) .map(Signature::getName) .map(QualifiedObjectName::getObjectName) - .collect(Collectors.toList()); + .collect(toImmutableList()); if (builtInFunctionNames.contains(functionName.toString())) { throw new SemanticException(INVALID_FUNCTION_NAME, node, format("Function %s is already 
registered as a built-in function.", functionName)); } @@ -2515,7 +3890,7 @@ public Expression rewriteIdentifier(Identifier reference, Void context, Expressi } if (expressions.size() == 1) { - return Iterables.getOnlyElement(expressions); + return getOnlyElement(expressions); } // otherwise, couldn't resolve name against output aliases, so fall through... @@ -2669,7 +4044,9 @@ private Scope computeAndAssignOutputScope(QuerySpecification node, Optional starPrefix = ((AllColumns) item).getPrefix(); for (Field field : sourceScope.getRelationType().resolveFieldsWithPrefix(starPrefix)) { - outputFields.add(Field.newUnqualified(node.getSelect().getLocation(), field.getName(), field.getType(), field.getOriginTable(), field.getOriginColumnName(), false)); + Field newField = Field.newUnqualified(node.getSelect().getLocation(), field.getName(), field.getType(), field.getOriginTable(), field.getOriginColumnName(), false); + analysis.addSourceColumns(newField, analysis.getSourceColumns(field)); + outputFields.add(newField); } } else if (item instanceof SingleColumn) { @@ -2686,7 +4063,7 @@ else if (item instanceof SingleColumn) { name = QualifiedName.of(((Identifier) expression).getValue()); } else if (expression instanceof DereferenceExpression) { - name = DereferenceExpression.getQualifiedName((DereferenceExpression) expression); + name = getQualifiedName((DereferenceExpression) expression); } if (name != null) { @@ -2699,11 +4076,19 @@ else if (expression instanceof DereferenceExpression) { if (!field.isPresent()) { if (name != null) { - field = Optional.of(new Identifier(getLast(name.getOriginalParts()))); + field = Optional.of(name.getOriginalSuffix()); } } - - outputFields.add(Field.newUnqualified(expression.getLocation(), field.map(Identifier::getValue), analysis.getType(expression), originTable, originColumn, column.getAlias().isPresent())); // TODO don't use analysis as a side-channel. 
Use outputExpressions to look up the type + Field newField = Field.newUnqualified(expression.getLocation(), field.map(Identifier::getValue), analysis.getType(expression), originTable, originColumn, column.getAlias().isPresent()); + if (originTable.isPresent()) { + analysis.addSourceColumns(newField, ImmutableSet.of( + new SourceColumn(originTable.get(), originColumn.orElseThrow( + () -> new NoSuchElementException("originColumn not found"))))); + } + else { + analysis.addSourceColumns(newField, analysis.getExpressionSourceColumns(expression)); + } + outputFields.add(newField); } else { throw new IllegalArgumentException("Unsupported SelectItem type: " + item.getClass().getName()); @@ -2916,7 +4301,7 @@ private RelationType analyzeView(Query query, QualifiedObjectName name, Optional AccessControl viewAccessControl; if (owner.isPresent() && !owner.get().equals(session.getIdentity().getUser())) { // definer mode - identity = new Identity(owner.get(), Optional.empty(), session.getIdentity().getExtraCredentials()); + identity = new Identity(owner.get(), Optional.empty(), emptyMap(), session.getIdentity().getExtraCredentials(), emptyMap(), Optional.empty(), session.getIdentity().getReasonForSelect(), emptyList()); viewAccessControl = new ViewAccessControl(accessControl); } else { @@ -3303,4 +4688,48 @@ private static boolean hasScopeAsLocalParent(Scope root, Scope parent) return false; } + + private static final class ArgumentAnalysis + { + private final Argument argument; + private final Optional tableArgumentAnalysis; + + public ArgumentAnalysis(Argument argument, Optional tableArgumentAnalysis) + { + this.argument = requireNonNull(argument, "argument is null"); + this.tableArgumentAnalysis = requireNonNull(tableArgumentAnalysis, "tableArgumentAnalysis is null"); + } + + public Argument getArgument() + { + return argument; + } + + public Optional getTableArgumentAnalysis() + { + return tableArgumentAnalysis; + } + } + + private static final class ArgumentsAnalysis + { + 
private final Map passedArguments; + private final List tableArgumentAnalyses; + + public ArgumentsAnalysis(Map passedArguments, List tableArgumentAnalyses) + { + this.passedArguments = ImmutableMap.copyOf(requireNonNull(passedArguments, "passedArguments is null")); + this.tableArgumentAnalyses = ImmutableList.copyOf(requireNonNull(tableArgumentAnalyses, "tableArgumentAnalyses is null")); + } + + public Map getPassedArguments() + { + return passedArguments; + } + + public List getTableArgumentAnalyses() + { + return tableArgumentAnalyses; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java b/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java index 0d200e5555a92..21216d4f04aee 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/ExpressionOptimizerManager.java @@ -13,12 +13,14 @@ */ package com.facebook.presto.sql.expressions; +import com.facebook.airlift.log.Logger; import com.facebook.presto.FullConnectorSession; import com.facebook.presto.Session; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.ExpressionOptimizerProvider; import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; @@ -26,8 +28,7 @@ import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.RowExpressionOptimizer; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.File; import java.io.IOException; @@ -47,12 +48,14 
@@ public class ExpressionOptimizerManager implements ExpressionOptimizerProvider { + private static final Logger log = Logger.get(ExpressionOptimizerManager.class); public static final String DEFAULT_EXPRESSION_OPTIMIZER_NAME = "default"; private static final File EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY = new File("etc/expression-manager/"); private static final String EXPRESSION_MANAGER_FACTORY_NAME = "expression-manager-factory.name"; private final NodeManager nodeManager; private final FunctionAndTypeManager functionAndTypeManager; + private final RowExpressionSerde rowExpressionSerde; private final FunctionResolution functionResolution; private final File configurationDirectory; @@ -60,16 +63,17 @@ public class ExpressionOptimizerManager private final Map expressionOptimizers = new ConcurrentHashMap<>(); @Inject - public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager) + public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, RowExpressionSerde rowExpressionSerde) { - this(nodeManager, functionAndTypeManager, EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY); + this(nodeManager, functionAndTypeManager, rowExpressionSerde, EXPRESSION_MANAGER_CONFIGURATION_DIRECTORY); } - public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, File configurationDirectory) + public ExpressionOptimizerManager(PluginNodeManager nodeManager, FunctionAndTypeManager functionAndTypeManager, RowExpressionSerde rowExpressionSerde, File configurationDirectory) { requireNonNull(nodeManager, "nodeManager is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.rowExpressionSerde = requireNonNull(rowExpressionSerde, "rowExpressionSerde is null"); this.functionResolution = new 
FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); this.configurationDirectory = requireNonNull(configurationDirectory, "configurationDirectory is null"); expressionOptimizers.put(DEFAULT_EXPRESSION_OPTIMIZER_NAME, new RowExpressionOptimizer(functionAndTypeManager)); @@ -89,23 +93,32 @@ public void loadExpressionOptimizerFactories() } } - private void loadExpressionOptimizerFactory(File configurationFile) + public void loadExpressionOptimizerFactory(File configurationFile) throws IOException { - String name = getNameWithoutExtension(configurationFile.getName()); - checkArgument(!isNullOrEmpty(name), "File name is empty, full path: %s", configurationFile.getAbsolutePath()); - checkArgument(!name.equals(DEFAULT_EXPRESSION_OPTIMIZER_NAME), "Cannot name an expression optimizer instance %s", DEFAULT_EXPRESSION_OPTIMIZER_NAME); + String optimizerName = getNameWithoutExtension(configurationFile.getName()); + checkArgument(!isNullOrEmpty(optimizerName), "File name is empty, full path: %s", configurationFile.getAbsolutePath()); + checkArgument(!optimizerName.equals(DEFAULT_EXPRESSION_OPTIMIZER_NAME), "Cannot name an expression optimizer instance %s", DEFAULT_EXPRESSION_OPTIMIZER_NAME); Map properties = new HashMap<>(loadProperties(configurationFile)); String factoryName = properties.remove(EXPRESSION_MANAGER_FACTORY_NAME); checkArgument(!isNullOrEmpty(factoryName), "%s does not contain %s", configurationFile, EXPRESSION_MANAGER_FACTORY_NAME); + + loadExpressionOptimizerFactory(factoryName, optimizerName, properties); + } + + public void loadExpressionOptimizerFactory(String factoryName, String optimizerName, Map properties) + { + requireNonNull(factoryName, "factoryName is null"); checkArgument(expressionOptimizerFactories.containsKey(factoryName), "ExpressionOptimizerFactory %s is not registered, registered factories: ", factoryName, expressionOptimizerFactories.keySet()); + log.info("-- Loading expression optimizer [%s] --", optimizerName); 
ExpressionOptimizer optimizer = expressionOptimizerFactories.get(factoryName).createOptimizer( properties, - new ExpressionOptimizerContext(nodeManager, functionAndTypeManager, functionResolution)); - expressionOptimizers.put(name, optimizer); + new ExpressionOptimizerContext(nodeManager, rowExpressionSerde, functionAndTypeManager, functionResolution)); + expressionOptimizers.put(optimizerName, optimizer); + log.info("-- Added expression optimizer [%s] --", optimizerName); } public void addExpressionOptimizerFactory(ExpressionOptimizerFactory expressionOptimizerFactory) @@ -123,8 +136,14 @@ public ExpressionOptimizer getExpressionOptimizer(ConnectorSession connectorSess checkArgument(connectorSession instanceof FullConnectorSession, "connectorSession is not an instance of FullConnectorSession"); Session session = ((FullConnectorSession) connectorSession).getSession(); String expressionOptimizerName = getExpressionOptimizerName(session); - checkArgument(expressionOptimizers.containsKey(expressionOptimizerName), "ExpressionOptimizer '%s' is not registered", expressionOptimizerName); - return expressionOptimizers.get(expressionOptimizerName); + return getExpressionOptimizer(expressionOptimizerName); + } + + public ExpressionOptimizer getExpressionOptimizer(String optimizerName) + { + requireNonNull(optimizerName, "optimizerName is null"); + checkArgument(expressionOptimizers.containsKey(optimizerName), "ExpressionOptimizer '%s' is not registered", optimizerName); + return expressionOptimizers.get(optimizerName); } private static List listFiles(File directory) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java b/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java new file mode 100644 index 0000000000000..cf276d6de3223 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/expressions/JsonCodecRowExpressionSerde.java @@ -0,0 +1,46 @@ +/* + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.expressions; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.spi.RowExpressionSerde; +import com.facebook.presto.spi.relation.RowExpression; + +import javax.inject.Inject; + +import static java.util.Objects.requireNonNull; + +public class JsonCodecRowExpressionSerde + implements RowExpressionSerde +{ + private final JsonCodec codec; + + @Inject + public JsonCodecRowExpressionSerde(JsonCodec codec) + { + this.codec = requireNonNull(codec, "codec is null"); + } + + @Override + public String serialize(RowExpression expression) + { + return codec.toJson(expression); + } + + @Override + public RowExpression deserialize(String data) + { + return codec.fromJson(data); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/CompilerOperations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/CompilerOperations.java index 7005702dd9dfb..c46aa596b68ef 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/CompilerOperations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/CompilerOperations.java @@ -14,8 +14,7 @@ package com.facebook.presto.sql.gen; import com.facebook.presto.common.block.Block; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.List; import java.util.Optional; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java index 5e3b87c38c8ee..5c6d5fd0bd138 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/CursorProcessorCompiler.java @@ -78,6 +78,8 @@ public class CursorProcessorCompiler implements BodyCompiler { + private static final int PROJECT_LIST_BATCH_SIZE = 1000; + public static final int HIGH_PROJECTION_WARNING_THRESHOLD = 2000; private static Logger log = Logger.get(CursorProcessorCompiler.class); private final Metadata metadata; @@ -94,6 +96,12 @@ public CursorProcessorCompiler(Metadata metadata, boolean isOptimizeCommonSubExp @Override public void generateMethods(SqlFunctionProperties sqlFunctionProperties, ClassDefinition classDefinition, CallSiteBinder callSiteBinder, RowExpression filter, List projections) { + if (projections.size() >= HIGH_PROJECTION_WARNING_THRESHOLD) { + log.warn("Query contains %d projections, which exceeds the recommended threshold of %d. " + + "Queries with very high projection counts may encounter JVM constant pool limits " + + "or performance issues. 
Consider reducing the number of projected columns if possible.", + projections.size(), HIGH_PROJECTION_WARNING_THRESHOLD); + } CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder); List rowExpressions = ImmutableList.builder() @@ -219,9 +227,9 @@ private static void generateProcessMethod(ClassDefinition classDefinition, int p cseFields.values().forEach(field -> whileFunctionBlock.append(scope.getThis().setField(field.getEvaluatedField(), constantBoolean(false)))); whileFunctionBlock.comment("do the projection") - .append(createProjectIfStatement(classDefinition, method, properties, cursor, pageBuilder, projections)) - .comment("completedPositions++;") - .incrementVariable(completedPositionsVariable, (byte) 1); + .append(createProjectIfStatement(classDefinition, method, properties, cursor, pageBuilder, projections)) + .comment("completedPositions++;") + .incrementVariable(completedPositionsVariable, (byte) 1); WhileLoop whileLoop = new WhileLoop() .condition(constantTrue()) @@ -255,31 +263,74 @@ private static IfStatement createProjectIfStatement( .getVariable(pageBuilder) .invokeVirtual(PageBuilder.class, "declarePosition", void.class); - // this.project_43(properties, cursor, pageBuilder.getBlockBuilder(42))); - for (int projectionIndex = 0; projectionIndex < projections; projectionIndex++) { + // Call batch methods instead of inlining all projections. 
The is to prevent against MethodTooLargeException + // Define all the batch methods + int batchCount = (projections + PROJECT_LIST_BATCH_SIZE - 1) / PROJECT_LIST_BATCH_SIZE; + generateProjectBatchMethods(classDefinition, projections, batchCount); + for (int batchNumber = 0; batchNumber < batchCount; batchNumber++) { ifStatement.ifTrue() .append(method.getThis()) .getVariable(properties) - .getVariable(cursor); - - // pageBuilder.getBlockBuilder(0) - ifStatement.ifTrue() + .getVariable(cursor) .getVariable(pageBuilder) - .push(projectionIndex) - .invokeVirtual(PageBuilder.class, "getBlockBuilder", BlockBuilder.class, int.class); - - // project(block..., blockBuilder)gen - ifStatement.ifTrue() + .push(batchNumber) .invokeVirtual(classDefinition.getType(), - "project_" + projectionIndex, + "processBatch_" + batchNumber, type(void.class), type(SqlFunctionProperties.class), type(RecordCursor.class), - type(BlockBuilder.class)); + type(PageBuilder.class), + type(int.class)); } + return ifStatement; } + private static void generateProjectBatchMethods( + ClassDefinition classDefinition, + int projections, + int batchCount) + { + for (int batchNumber = 0; batchNumber < batchCount; batchNumber++) { + Parameter properties = arg("properties", SqlFunctionProperties.class); + Parameter cursor = arg("cursor", RecordCursor.class); + Parameter pageBuilder = arg("pageBuilder", PageBuilder.class); + Parameter batchIndex = arg("batchIndex", int.class); + + MethodDefinition batchMethod = classDefinition.declareMethod( + a(PRIVATE), + "processBatch_" + batchNumber, + type(void.class), + properties, cursor, pageBuilder, batchIndex); + + BytecodeBlock body = batchMethod.getBody(); + + int startProjection = batchNumber * PROJECT_LIST_BATCH_SIZE; + int endProjection = Math.min(projections, (batchNumber + 1) * PROJECT_LIST_BATCH_SIZE); + + for (int projectionIndex = startProjection; projectionIndex < endProjection; projectionIndex++) { + body.append(batchMethod.getThis()) + 
.getVariable(properties) + .getVariable(cursor) + + // pageBuilder.getBlockBuilder(projectionIndex) + .getVariable(pageBuilder) + .push(projectionIndex) + .invokeVirtual(PageBuilder.class, "getBlockBuilder", BlockBuilder.class, int.class) + + // project_X(properties, cursor, blockBuilder) + .invokeVirtual(classDefinition.getType(), + "project_" + projectionIndex, + type(void.class), + type(SqlFunctionProperties.class), + type(RecordCursor.class), + type(BlockBuilder.class)); + } + + body.ret(); + } + } + private void generateFilterMethod( ClassDefinition classDefinition, RowExpressionCompiler compiler, @@ -298,16 +349,16 @@ private void generateFilterMethod( LabelNode end = new LabelNode("end"); body.comment("boolean wasNull = false;") - .putVariable(wasNullVariable, false) - .comment("evaluate filter: " + filter) - .append(compiler.compile(filter, scope, Optional.empty())) - .comment("if (wasNull) return false;") - .getVariable(wasNullVariable) - .ifFalseGoto(end) - .pop(boolean.class) - .push(false) - .visitLabel(end) - .retBoolean(); + .putVariable(wasNullVariable, false) + .comment("evaluate filter: " + filter) + .append(compiler.compile(filter, scope, Optional.empty())) + .comment("if (wasNull) return false;") + .getVariable(wasNullVariable) + .ifFalseGoto(end) + .pop(boolean.class) + .push(false) + .visitLabel(end) + .retBoolean(); } private void generateProjectMethod( diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java index 4721e326daf8b..6cadc749f18b4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/ExpressionCompiler.java @@ -32,11 +32,10 @@ import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.UncheckedExecutionException; +import 
jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.util.List; import java.util.Map; import java.util.Objects; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/ExpressionProfiler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/ExpressionProfiler.java index b8719343d28ba..7caf6a102a5bd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/ExpressionProfiler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/ExpressionProfiler.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.sql.gen; +import com.facebook.airlift.units.Duration; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ticker; -import io.airlift.units.Duration; import static com.facebook.presto.execution.executor.PrioritizedSplitRunner.SPLIT_RUN_QUANTA; import static com.google.common.base.Ticker.systemTicker; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java index 207a7fa9640fe..b75aafcfe2401 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/JoinCompiler.java @@ -52,12 +52,11 @@ import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; +import jakarta.inject.Inject; import org.openjdk.jol.info.ClassLayout; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java index 5a22fac20bfd4..7ddc1949a4977 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/JoinFilterFunctionCompiler.java @@ -42,11 +42,10 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.inject.Inject; - import java.lang.reflect.Constructor; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java index 3449873800675..ceb69b5dd388f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/PageFunctionCompiler.java @@ -65,12 +65,11 @@ import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Primitives; import com.google.common.util.concurrent.UncheckedExecutionException; +import jakarta.annotation.Nullable; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.Nullable; -import javax.inject.Inject; - import java.util.Collection; import java.util.HashSet; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/RowExpressionPredicateCompiler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/RowExpressionPredicateCompiler.java index 08b29f520d636..d3099fee6a0e2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/gen/RowExpressionPredicateCompiler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/gen/RowExpressionPredicateCompiler.java @@ -40,8 +40,7 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import 
com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index 348e031539142..2326c62b56dfd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -25,6 +25,8 @@ import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.MetadataDeleteNode; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; import com.facebook.presto.spi.plan.PartitioningHandle; @@ -40,13 +42,17 @@ import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; -import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import 
com.facebook.presto.sql.planner.sanity.PlanChecker; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -158,6 +164,7 @@ private SubPlan buildFragment(PlanNode root, FragmentProperties properties, Plan properties.getPartitioningHandle(), schedulingOrder, properties.getPartitioningScheme(), + properties.getOutputOrderingScheme(), StageExecutionDescriptor.ungroupedExecution(), outputTableWriterFragment, Optional.of(statsAndCosts.getForSubplan(root)), @@ -263,6 +270,27 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) + { + return context.defaultRewrite(node, context.get()); + } + + @Override + public PlanNode visitMergeProcessor(MergeProcessorNode node, RewriteContext context) + { + return context.defaultRewrite(node, context.get()); + } + + @Override + public PlanNode visitCallDistributedProcedure(CallDistributedProcedureNode node, RewriteContext context) + { + if (node.getPartitioningScheme().isPresent()) { + context.get().setDistribution(node.getPartitioningScheme().get().getPartitioning().getHandle(), metadata, session); + } + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitValues(ValuesNode node, RewriteContext context) { @@ -270,6 +298,22 @@ public PlanNode visitValues(ValuesNode node, RewriteContext return context.defaultRewrite(node, context.get()); } + @Override + public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + // context is mutable. The leaf node should set the PartitioningHandle. 
+ context.get().addSourceDistribution(node.getId(), SOURCE_DISTRIBUTION, metadata, session); + } + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitExchange(ExchangeNode exchange, RewriteContext context) { @@ -283,7 +327,7 @@ public PlanNode visitExchange(ExchangeNode exchange, RewriteContext builder = ImmutableList.builder(); for (int sourceIndex = 0; sourceIndex < exchange.getSources().size(); sourceIndex++) { - FragmentProperties childProperties = new FragmentProperties(translateOutputLayout(partitioningScheme, exchange.getInputs().get(sourceIndex))); + PartitioningScheme childPartitioningScheme = translateOutputLayout(partitioningScheme, exchange.getInputs().get(sourceIndex)); + FragmentProperties childProperties = new FragmentProperties(childPartitioningScheme); + + // If the exchange has ordering requirements, translate them for the child fragment + Optional childOutputOrderingScheme = Optional.empty(); + if (exchange.getOrderingScheme().isPresent()) { + childOutputOrderingScheme = exchange.getOrderingScheme(); + } + + // Set the output ordering scheme for the child fragment + childProperties.setOutputOrderingScheme(childOutputOrderingScheme); builder.add(buildSubPlan(exchange.getSources().get(sourceIndex), childProperties, context)); } @@ -334,7 +388,7 @@ else if (exchangeType == ExchangeNode.Type.REPARTITION) { } } - private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, RewriteContext context) + private PlanNode createRemoteMaterializedExchange(Metadata metadata, Session session, ExchangeNode exchange, RewriteContext context) { checkArgument(exchange.getType() == REPARTITION, "Unexpected exchange type: %s", exchange.getType()); checkArgument(exchange.getScope() == REMOTE_MATERIALIZED, "Unexpected exchange scope: %s", exchange.getScope()); @@ -352,7 +406,8 @@ private PlanNode createRemoteMaterializedExchange(ExchangeNode exchange, Rewrite Partitioning partitioning = 
partitioningScheme.getPartitioning(); PartitioningVariableAssignments partitioningVariableAssignments = assignPartitioningVariables(variableAllocator, partitioning); - Map variableToColumnMap = assignTemporaryTableColumnNames(exchange.getOutputVariables(), partitioningVariableAssignments.getConstants().keySet()); + Map variableToColumnMap = assignTemporaryTableColumnNames(metadata, + session, connectorId.getCatalogName(), exchange.getOutputVariables(), partitioningVariableAssignments.getConstants().keySet()); List partitioningVariables = partitioningVariableAssignments.getVariables(); List partitionColumns = partitioningVariables.stream() .map(variable -> variableToColumnMap.get(variable).getName()) @@ -434,11 +489,24 @@ public static class FragmentProperties private Optional partitioningHandle = Optional.empty(); private final Set partitionedSources = new HashSet<>(); + // Output ordering scheme for the fragment - this gets transferred to the PlanFragment + private Optional outputOrderingScheme = Optional.empty(); + public FragmentProperties(PartitioningScheme partitioningScheme) { this.partitioningScheme = partitioningScheme; } + public void setOutputOrderingScheme(Optional outputOrderingScheme) + { + this.outputOrderingScheme = requireNonNull(outputOrderingScheme, "outputOrderingScheme is null"); + } + + public Optional getOutputOrderingScheme() + { + return outputOrderingScheme; + } + public List getChildren() { return children; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index 068b4459ea02d..6f5274eb36925 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.PrestoException; import 
com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.GroupingSetDescriptor; @@ -50,6 +51,7 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.CallExpression; @@ -63,7 +65,6 @@ import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.JsonProcessingException; @@ -1115,6 +1116,12 @@ public SchemaTableName getSchemaTableName() return new SchemaTableName("schema", "table"); } + @Override + public Optional> getOutputColumns() + { + return Optional.empty(); + } + @Override public String toString() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CompilerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CompilerConfig.java index 77348c13508f1..6269e8b9aa644 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CompilerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CompilerConfig.java @@ -17,8 +17,7 @@ import com.facebook.airlift.configuration.ConfigDescription; import com.facebook.airlift.configuration.DefunctConfig; import com.facebook.presto.spi.function.Description; - -import 
javax.validation.constraints.Min; +import jakarta.validation.constraints.Min; @DefunctConfig("compiler.interpreter-enabled") public class CompilerConfig diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ConnectorPlanOptimizerManager.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ConnectorPlanOptimizerManager.java index 7dfb48014de1c..4c31753341b6f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ConnectorPlanOptimizerManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ConnectorPlanOptimizerManager.java @@ -18,8 +18,7 @@ import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.connector.ConnectorPlanOptimizerProvider; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Map; import java.util.Set; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ExpressionDomainTranslator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ExpressionDomainTranslator.java index 6e7f05434eda5..f8ed668cad493 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ExpressionDomainTranslator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ExpressionDomainTranslator.java @@ -52,8 +52,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.PeekingIterator; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java index 603f218f23344..de44fcd5dcb92 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/ExpressionInterpreter.java @@ -116,6 +116,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static com.facebook.presto.SystemSessionProperties.getMaxSerializableObjectSize; import static com.facebook.presto.SystemSessionProperties.isLegacyRowFieldOrdinalAccessEnabled; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; @@ -160,7 +161,6 @@ @Deprecated public class ExpressionInterpreter { - private static final long MAX_SERIALIZABLE_OBJECT_SIZE = 1000; private final Expression expression; private final Metadata metadata; private final LiteralEncoder literalEncoder; @@ -1318,7 +1318,7 @@ private boolean isSerializable(Object value, Type type) { requireNonNull(type, "type is null"); // If value is already Expression, literal values contained inside should already have been made serializable. Otherwise, we make sure the object is small and serializable. 
- return value instanceof Expression || (isSupportedLiteralType(type) && estimatedSizeInBytes(value) <= MAX_SERIALIZABLE_OBJECT_SIZE); + return value instanceof Expression || (isSupportedLiteralType(type) && estimatedSizeInBytes(value) <= getMaxSerializableObjectSize(session)); } private List toExpressions(List values, List types) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java index d4a6d6981f612..19ac5a575ded5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java @@ -27,7 +27,9 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; @@ -39,7 +41,7 @@ import java.util.OptionalInt; import static com.facebook.presto.SystemSessionProperties.GROUPED_EXECUTION; -import static com.facebook.presto.SystemSessionProperties.isGroupedExecutionEnabled; +import static com.facebook.presto.SystemSessionProperties.preferSortMergeJoin; import static com.facebook.presto.spi.StandardErrorCode.INVALID_PLAN_ERROR; import static com.facebook.presto.spi.connector.ConnectorCapabilities.SUPPORTS_PAGE_SINK_COMMIT; import static com.facebook.presto.spi.connector.ConnectorCapabilities.SUPPORTS_REWINDABLE_SPLIT_SOURCE; @@ -57,13 +59,15 @@ class GroupedExecutionTagger private final Metadata metadata; private final 
NodePartitioningManager nodePartitioningManager; private final boolean groupedExecutionEnabled; + private final boolean isPrestoOnSpark; - public GroupedExecutionTagger(Session session, Metadata metadata, NodePartitioningManager nodePartitioningManager) + public GroupedExecutionTagger(Session session, Metadata metadata, NodePartitioningManager nodePartitioningManager, boolean groupedExecutionEnabled, boolean isPrestoOnSpark) { this.session = requireNonNull(session, "session is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); - this.groupedExecutionEnabled = isGroupedExecutionEnabled(session); + this.groupedExecutionEnabled = groupedExecutionEnabled; + this.isPrestoOnSpark = isPrestoOnSpark; } @Override @@ -161,6 +165,19 @@ public GroupedExecutionTagger.GroupedExecutionProperties visitMergeJoin(MergeJoi left.totalLifespans, left.recoveryEligible && right.recoveryEligible); } + if (preferSortMergeJoin(session)) { + // TODO: This will break the other use case for merge join operating on sorted tables, which requires grouped execution for correctness. + return GroupedExecutionTagger.GroupedExecutionProperties.notCapable(); + } + + if (isPrestoOnSpark) { + GroupedExecutionTagger.GroupedExecutionProperties mergeJoinLeft = node.getLeft().accept(new GroupedExecutionTagger(session, metadata, nodePartitioningManager, true, true), null); + GroupedExecutionTagger.GroupedExecutionProperties mergeJoinRight = node.getRight().accept(new GroupedExecutionTagger(session, metadata, nodePartitioningManager, true, true), null); + if (mergeJoinLeft.currentNodeCapable || mergeJoinRight.currentNodeCapable) { + return GroupedExecutionTagger.GroupedExecutionProperties.notCapable(); + } + } + throw new PrestoException( INVALID_PLAN_ERROR, format("When grouped execution can't be enabled, merge join plan is not valid." 
+ @@ -226,6 +243,22 @@ public GroupedExecutionTagger.GroupedExecutionProperties visitMarkDistinct(MarkD return GroupedExecutionTagger.GroupedExecutionProperties.notCapable(); } + @Override + public GroupedExecutionTagger.GroupedExecutionProperties visitCallDistributedProcedure(CallDistributedProcedureNode node, Void context) + { + GroupedExecutionTagger.GroupedExecutionProperties properties = node.getSource().accept(this, null); + boolean recoveryEligible = properties.isRecoveryEligible(); + CallDistributedProcedureTarget target = node.getTarget().orElseThrow(() -> new VerifyException("target is absent")); + recoveryEligible &= metadata.getConnectorCapabilities(session, target.getConnectorId()).contains(SUPPORTS_PAGE_SINK_COMMIT); + + return new GroupedExecutionTagger.GroupedExecutionProperties( + properties.isCurrentNodeCapable(), + properties.isSubTreeUseful(), + properties.getCapableTableScanNodes(), + properties.getTotalLifespans(), + recoveryEligible); + } + @Override public GroupedExecutionTagger.GroupedExecutionProperties visitTableWriter(TableWriterNode node, Void context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/InputExtractor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/InputExtractor.java index 5fa416acc165e..05f1fe1ea17ce 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/InputExtractor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/InputExtractor.java @@ -69,7 +69,7 @@ private Input createInput(TableMetadata table, TableHandle tableHandle, Set inputMetadata = metadata.getInfo(session, tableHandle); - return new Input(table.getConnectorId(), schemaTable.getSchemaName(), schemaTable.getTableName(), inputMetadata, ImmutableList.copyOf(columns), statistics, ""); + return new Input(table.getConnectorId(), schemaTable.getSchemaName(), schemaTable.getTableName(), inputMetadata, ImmutableList.copyOf(columns), statistics, Optional.empty()); } 
private class Visitor diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LazySplitSource.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LazySplitSource.java index 3200bfefbeb07..239b2e4a08953 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LazySplitSource.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LazySplitSource.java @@ -19,8 +19,7 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.split.SplitSource; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.function.Supplier; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalDynamicFiltersCollector.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalDynamicFiltersCollector.java index 8059775953aec..a3faae96afaef 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalDynamicFiltersCollector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalDynamicFiltersCollector.java @@ -15,9 +15,8 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.spi.relation.VariableReferenceExpression; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; @ThreadSafe public class LocalDynamicFiltersCollector diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java index 6c46f02d91576..13cc38caef3fe 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LocalExecutionPlanner.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.Page; @@ -30,13 +31,14 @@ import com.facebook.presto.execution.FragmentResultCacheContext; import com.facebook.presto.execution.StageExecutionId; import com.facebook.presto.execution.TaskManagerConfig; -import com.facebook.presto.execution.TaskMetadataContext; import com.facebook.presto.execution.buffer.OutputBuffer; import com.facebook.presto.execution.buffer.PagesSerdeFactory; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.CreateHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.DeleteHandle; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.ExecuteProcedureHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.InsertHandle; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.MergeHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.RefreshMaterializedViewHandle; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.UpdateHandle; import com.facebook.presto.execution.scheduler.TableWriteInfo; @@ -48,7 +50,6 @@ import com.facebook.presto.memory.MemoryManagerConfig; import com.facebook.presto.metadata.AnalyzeTableHandle; import com.facebook.presto.metadata.BuiltInFunctionHandle; -import com.facebook.presto.metadata.ConnectorMetadataUpdaterManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.operator.AggregationOperator.AggregationOperatorFactory; @@ -68,12 +69,15 @@ import 
com.facebook.presto.operator.JoinBridgeManager; import com.facebook.presto.operator.JoinOperatorFactory; import com.facebook.presto.operator.JoinOperatorFactory.OuterOperatorFactoryResult; +import com.facebook.presto.operator.LeafTableFunctionOperator; import com.facebook.presto.operator.LimitOperator.LimitOperatorFactory; import com.facebook.presto.operator.LocalPlannerAware; import com.facebook.presto.operator.LookupJoinOperators; import com.facebook.presto.operator.LookupOuterOperator.LookupOuterOperatorFactory; import com.facebook.presto.operator.LookupSourceFactory; import com.facebook.presto.operator.MarkDistinctOperator.MarkDistinctOperatorFactory; +import com.facebook.presto.operator.MergeProcessorOperator; +import com.facebook.presto.operator.MergeWriterOperator; import com.facebook.presto.operator.MetadataDeleteOperator.MetadataDeleteOperatorFactory; import com.facebook.presto.operator.NestedLoopJoinBridge; import com.facebook.presto.operator.NestedLoopJoinPagesSupplier; @@ -86,6 +90,7 @@ import com.facebook.presto.operator.PartitionFunction; import com.facebook.presto.operator.PartitionedLookupSourceFactory; import com.facebook.presto.operator.PipelineExecutionStrategy; +import com.facebook.presto.operator.RegularTableFunctionPartition; import com.facebook.presto.operator.RemoteProjectOperator.RemoteProjectOperatorFactory; import com.facebook.presto.operator.RowNumberOperator; import com.facebook.presto.operator.ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory; @@ -99,6 +104,7 @@ import com.facebook.presto.operator.StreamingAggregationOperator.StreamingAggregationOperatorFactory; import com.facebook.presto.operator.TableCommitContext; import com.facebook.presto.operator.TableFinishOperator.PageSinkCommitter; +import com.facebook.presto.operator.TableFunctionOperator; import com.facebook.presto.operator.TableScanOperator.TableScanOperatorFactory; import com.facebook.presto.operator.TableWriterMergeOperator.TableWriterMergeOperatorFactory; 
import com.facebook.presto.operator.TaskContext; @@ -132,6 +138,7 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorIndex; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.RecordSet; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.function.FunctionHandle; @@ -141,20 +148,24 @@ import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.function.aggregation.LambdaProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.Step; import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.JoinDistributionType; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PartitioningScheme; @@ -172,6 +183,7 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import 
com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.plan.WindowNode.Frame; @@ -198,20 +210,22 @@ import com.facebook.presto.sql.gen.PageFunctionCompiler; import com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; -import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.VariableToChannelTranslator; @@ -233,9 +247,7 @@ import com.google.common.collect.Multimap; import com.google.common.collect.SetMultimap; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; - -import javax.inject.Inject; +import jakarta.inject.Inject; 
import java.util.ArrayList; import java.util.Arrays; @@ -244,6 +256,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; @@ -255,6 +268,7 @@ import java.util.stream.IntStream; import static com.facebook.airlift.concurrent.MoreFutures.addSuccessCallback; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.SystemSessionProperties.getAdaptivePartialAggregationRowsReductionRatioThreshold; import static com.facebook.presto.SystemSessionProperties.getDynamicFilteringMaxPerDriverRowCount; import static com.facebook.presto.SystemSessionProperties.getDynamicFilteringMaxPerDriverSize; @@ -313,6 +327,7 @@ import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.StandardWarningCode.PERFORMANCE_WARNING; import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.INTERMEDIATE; import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; @@ -325,6 +340,7 @@ import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.gen.CursorProcessorCompiler.HIGH_PROJECTION_WARNING_THRESHOLD; import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.compileLambdaProvider; import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter; import static com.facebook.presto.sql.planner.SortExpressionExtractor.getSortExpressionContext; @@ -349,10 +365,11 
@@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.DiscreteDomain.integers; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; import static com.google.common.collect.Range.closedOpen; -import static io.airlift.units.DataSize.Unit.BYTE; +import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.IntStream.range; @@ -366,7 +383,6 @@ public class LocalExecutionPlanner private final PartitioningProviderManager partitioningProviderManager; private final NodePartitioningManager nodePartitioningManager; private final PageSinkManager pageSinkManager; - private final ConnectorMetadataUpdaterManager metadataUpdaterManager; private final ExpressionCompiler expressionCompiler; private final PageFunctionCompiler pageFunctionCompiler; private final JoinFilterFunctionCompiler joinFilterFunctionCompiler; @@ -402,7 +418,6 @@ public LocalExecutionPlanner( PartitioningProviderManager partitioningProviderManager, NodePartitioningManager nodePartitioningManager, PageSinkManager pageSinkManager, - ConnectorMetadataUpdaterManager metadataUpdaterManager, ExpressionCompiler expressionCompiler, PageFunctionCompiler pageFunctionCompiler, JoinFilterFunctionCompiler joinFilterFunctionCompiler, @@ -431,7 +446,6 @@ public LocalExecutionPlanner( this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null"); - this.metadataUpdaterManager = requireNonNull(metadataUpdaterManager, "metadataUpdaterManager is null"); this.expressionCompiler = 
requireNonNull(expressionCompiler, "compiler is null"); this.pageFunctionCompiler = requireNonNull(pageFunctionCompiler, "pageFunctionCompiler is null"); this.joinFilterFunctionCompiler = requireNonNull(joinFilterFunctionCompiler, "compiler is null"); @@ -793,11 +807,6 @@ private void setInputDriver(boolean inputDriver) this.inputDriver = inputDriver; } - public TaskMetadataContext getTaskMetadataContext() - { - return taskContext.getTaskMetadataContext(); - } - public TableWriteInfo getTableWriteInfo() { return tableWriteInfo; @@ -1212,6 +1221,98 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext return new PhysicalOperation(operatorFactory, outputMappings.build(), context, source); } + @Override + public PhysicalOperation visitTableFunction(TableFunctionNode node, LocalExecutionPlanContext context) + { + throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + } + + @Override + public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode node, LocalExecutionPlanContext context) + { + TableFunctionProcessorProvider processorProvider = metadata.getFunctionAndTypeManager().getTableFunctionProcessorProvider(node.getHandle()); + + if (!node.getSource().isPresent()) { + OperatorFactory operatorFactory = new LeafTableFunctionOperator.LeafTableFunctionOperatorFactory(context.getNextOperatorId(), node.getId(), processorProvider, node.getHandle().getFunctionHandle()); + return new PhysicalOperation(operatorFactory, makeLayout(node), context, Optional.empty(), UNGROUPED_EXECUTION); + } + + PhysicalOperation source = node.getSource().orElseThrow(NoSuchElementException::new).accept(this, context); + + int properChannelsCount = node.getProperOutputs().size(); + + long passThroughSourcesCount = node.getPassThroughSpecifications().stream() + .filter(TableFunctionNode.PassThroughSpecification::isDeclaredAsPassThrough) + .count(); + + List> requiredChannels = 
node.getRequiredVariables().stream() + .map(list -> getChannelsForVariables(list, source.getLayout())) + .collect(toImmutableList()); + + Optional> markerChannels = node.getMarkerVariables() + .map(map -> map.entrySet().stream() + .collect(toImmutableMap(entry -> source.getLayout().get(entry.getKey()), entry -> source.getLayout().get(entry.getValue())))); + + int channel = properChannelsCount; + ImmutableList.Builder passThroughColumnSpecifications = ImmutableList.builder(); + for (TableFunctionNode.PassThroughSpecification specification : node.getPassThroughSpecifications()) { + // the table function produces one index channel for each source declared as pass-through. They are laid out after the proper channels. + int indexChannel = specification.isDeclaredAsPassThrough() ? channel++ : -1; + for (TableFunctionNode.PassThroughColumn column : specification.getColumns()) { + passThroughColumnSpecifications.add(new RegularTableFunctionPartition.PassThroughColumnSpecification(column.isPartitioningColumn(), source.getLayout().get(column.getOutputVariables()), indexChannel)); + } + } + + List partitionChannels = node.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .map(list -> getChannelsForVariables(list, source.getLayout())) + .orElse(ImmutableList.of()); + + List sortChannels = ImmutableList.of(); + List sortOrders = ImmutableList.of(); + if (node.getSpecification().flatMap(DataOrganizationSpecification::getOrderingScheme).isPresent()) { + OrderingScheme orderingScheme = node.getSpecification().flatMap(DataOrganizationSpecification::getOrderingScheme).orElseThrow(NoSuchElementException::new); + sortChannels = getChannelsForVariables(orderingScheme.getOrderByVariables(), source.getLayout()); + sortOrders = orderingScheme.getOrderingsMap().values().stream().collect(toImmutableList()); + } + + OperatorFactory operator = new TableFunctionOperator.TableFunctionOperatorFactory( + context.getNextOperatorId(), + node.getId(), + processorProvider, + 
node.getHandle().getFunctionHandle(), + properChannelsCount, + toIntExact(passThroughSourcesCount), + requiredChannels, + markerChannels, + passThroughColumnSpecifications.build(), + node.isPruneWhenEmpty(), + partitionChannels, + getChannelsForVariables(ImmutableList.copyOf(node.getPrePartitioned()), source.getLayout()), + sortChannels, + sortOrders, + node.getPreSorted(), + source.getTypes(), + 10_000, + pagesIndexFactory); + + ImmutableMap.Builder outputMappings = ImmutableMap.builder(); + for (int i = 0; i < node.getProperOutputs().size(); i++) { + outputMappings.put(node.getProperOutputs().get(i), i); + } + List passThroughVariables = node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(toImmutableList()); + int outputChannel = properChannelsCount; + for (VariableReferenceExpression passThroughVariable : passThroughVariables) { + outputMappings.put(passThroughVariable, outputChannel++); + } + + return new PhysicalOperation(operator, outputMappings.buildOrThrow(), context, source); + } + @Override public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext context) { @@ -1513,6 +1614,13 @@ private PhysicalOperation visitScanFilterAndProject( .collect(toImmutableList()); try { + if (projections.size() >= HIGH_PROJECTION_WARNING_THRESHOLD) { + session.getWarningCollector().add(new PrestoWarning( + PERFORMANCE_WARNING.toWarningCode(), + String.format("Query contains %d projections, which exceeds the recommended threshold of %d. 
" + + "Queries with very high projection counts may encounter JVM constant pool limits or performance issues.", + projections.size(), HIGH_PROJECTION_WARNING_THRESHOLD))); + } if (columns != null) { Supplier cursorProcessor = expressionCompiler.compileCursorProcessor( session.getSqlFunctionProperties(), @@ -2666,6 +2774,46 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont return new PhysicalOperation(operator, outputMappings, context, probeSource); } + @Override + public PhysicalOperation visitCallDistributedProcedure(CallDistributedProcedureNode node, LocalExecutionPlanContext context) + { + // Set table writer count + if (node.getPartitioningScheme().isPresent()) { + context.setDriverInstanceCount(getTaskPartitionedWriterCount(session)); + } + else { + context.setDriverInstanceCount(getTaskWriterCount(session)); + } + + PhysicalOperation source = node.getSource().accept(this, context); + + ImmutableMap.Builder outputMapping = ImmutableMap.builder(); + outputMapping.put(node.getRowCountVariable(), ROW_COUNT_CHANNEL); + outputMapping.put(node.getFragmentVariable(), FRAGMENT_CHANNEL); + outputMapping.put(node.getTableCommitContextVariable(), CONTEXT_CHANNEL); + + List inputChannels = node.getColumns().stream() + .map(source::variableToChannel) + .collect(toImmutableList()); + List notNullChannelColumnNames = node.getColumns().stream() + .map(variable -> node.getNotNullColumnVariables().contains(variable) ? 
node.getColumnNames().get(source.variableToChannel(variable)) : null) + .collect(Collectors.toList()); + + OperatorFactory operatorFactory = new TableWriterOperatorFactory( + context.getNextOperatorId(), + node.getId(), + pageSinkManager, + context.getTableWriteInfo().getWriterTarget().orElseThrow(() -> new VerifyException("writerTarget is absent")), + inputChannels, + notNullChannelColumnNames, + session, + new DevNullOperatorFactory(context.getNextOperatorId(), node.getId()), // statistics are not calculated + getVariableTypes(node.getOutputVariables()), + tableCommitContextCodec, + getPageSinkCommitStrategy()); + return new PhysicalOperation(operatorFactory, outputMapping.build(), context, source); + } + @Override public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPlanContext context) { @@ -2739,8 +2887,6 @@ public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPl context.getNextOperatorId(), node.getId(), pageSinkManager, - metadataUpdaterManager, - context.getTaskMetadataContext(), context.getTableWriteInfo().getWriterTarget().orElseThrow(() -> new VerifyException("writerTarget is absent")), inputChannels, notNullChannelColumnNames, @@ -2764,6 +2910,43 @@ private PageSinkCommitStrategy getPageSinkCommitStrategy() return NO_COMMIT; } + @Override + public PhysicalOperation visitMergeWriter(MergeWriterNode node, LocalExecutionPlanContext context) + { + context.setDriverInstanceCount(getTaskWriterCount(session)); + + PhysicalOperation source = node.getSource().accept(this, context); + OperatorFactory operatorFactory = new MergeWriterOperator.MergeWriterOperatorFactory( + context.getNextOperatorId(), node.getId(), pageSinkManager, node.getTarget(), session, + tableCommitContextCodec); + return new PhysicalOperation(operatorFactory, makeLayout(node), context, source); + } + + @Override + public PhysicalOperation visitMergeProcessor(MergeProcessorNode node, LocalExecutionPlanContext context) + { + PhysicalOperation 
source = node.getSource().accept(this, context); + + Map nodeLayout = makeLayout(node); + Map sourceLayout = makeLayout(node.getSource()); + int rowIdChannel = sourceLayout.get(node.getTargetTableRowIdColumnVariable()); + int mergeRowChannel = sourceLayout.get(node.getMergeRowVariable()); + + List targetColumnChannels = node.getTargetColumnVariables().stream() + .map(nodeLayout::get) + .collect(toImmutableList()); + + OperatorFactory operatorFactory = new MergeProcessorOperator.MergeProcessorOperatorFactory( + context.getNextOperatorId(), + node.getId(), + node.getTarget().getMergeParadigmAndTypes(), + rowIdChannel, + mergeRowChannel, + targetColumnChannels); + + return new PhysicalOperation(operatorFactory, nodeLayout, context, source); + } + @Override public PhysicalOperation visitStatisticsWriterNode(StatisticsWriterNode node, LocalExecutionPlanContext context) { @@ -2854,7 +3037,7 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl Map aggregationMap = aggregation.getAggregations().entrySet() .stream().collect( - ImmutableMap.toImmutableMap( + toImmutableMap( Map.Entry::getKey, entry -> createAggregation(entry.getValue()))); if (groupingVariables.isEmpty()) { @@ -2919,8 +3102,10 @@ public PhysicalOperation visitTableFinish(TableFinishNode node, LocalExecutionPl public PhysicalOperation visitDelete(DeleteNode node, LocalExecutionPlanContext context) { PhysicalOperation source = node.getSource().accept(this, context); - - OperatorFactory operatorFactory = new DeleteOperatorFactory(context.getNextOperatorId(), node.getId(), source.getLayout().get(node.getRowId()), tableCommitContextCodec); + if (!node.getRowId().isPresent()) { + throw new PrestoException(NOT_SUPPORTED, "DELETE is not supported by this connector"); + } + OperatorFactory operatorFactory = new DeleteOperatorFactory(context.getNextOperatorId(), node.getId(), source.getLayout().get(node.getRowId().get()), tableCommitContextCodec); Map layout = ImmutableMap.builder() 
.put(node.getOutputVariables().get(0), 0) @@ -2996,16 +3181,15 @@ private List createColumnValueAndRowIdChannels(List= 0) { - columnValueAndRowIdChannels[index] = symbolCounter; - } + for (VariableReferenceExpression columnValueAndRowIdSymbol : columnValueAndRowIdSymbols) { + int index = variableReferenceExpressions.indexOf(columnValueAndRowIdSymbol); + + verify(index >= 0, "Could not find columnValueAndRowIdSymbol %s in the variableReferenceExpressions %s", columnValueAndRowIdSymbol, variableReferenceExpressions); + columnValueAndRowIdChannels[symbolCounter] = index; + symbolCounter++; } - checkArgument(symbolCounter == columnValueAndRowIdSymbols.size(), "symbolCounter %s should be columnValueAndRowIdChannels.size() %s", symbolCounter); + return Arrays.asList(columnValueAndRowIdChannels); } @@ -3480,8 +3664,7 @@ else if (target instanceof InsertHandle) { return metadata.finishInsert(session, ((InsertHandle) target).getHandle(), fragments, statistics); } else if (target instanceof DeleteHandle) { - metadata.finishDelete(session, ((DeleteHandle) target).getHandle(), fragments); - return Optional.empty(); + return metadata.finishDeleteWithOutput(session, ((DeleteHandle) target).getHandle(), fragments); } else if (target instanceof RefreshMaterializedViewHandle) { return metadata.finishRefreshMaterializedView(session, ((RefreshMaterializedViewHandle) target).getHandle(), fragments, statistics); @@ -3490,6 +3673,14 @@ else if (target instanceof UpdateHandle) { metadata.finishUpdate(session, ((UpdateHandle) target).getHandle(), fragments); return Optional.empty(); } + else if (target instanceof MergeHandle) { + metadata.finishMerge(session, ((MergeHandle) target).getHandle(), fragments, statistics); + return Optional.empty(); + } + else if (target instanceof ExecuteProcedureHandle) { + metadata.finishCallDistributedProcedure(session, ((ExecuteProcedureHandle) target).getHandle(), ((ExecuteProcedureHandle) target).getProcedureName(), fragments); + return 
Optional.empty(); + } else { throw new AssertionError("Unhandled target type: " + target.getClass().getName()); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 4f2f6b826b221..fe73971051712 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -18,6 +18,8 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableLayout; +import com.facebook.presto.metadata.TableLayout.TablePartitioning; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; @@ -33,6 +35,7 @@ import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; +import com.facebook.presto.spi.plan.PartitioningHandle; import com.facebook.presto.spi.plan.PartitioningScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; @@ -41,6 +44,7 @@ import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.spi.plan.TableWriterNode.DeleteHandle; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.RowExpression; @@ -53,10 +57,13 @@ import com.facebook.presto.sql.analyzer.Scope; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.StatisticsAggregationPlanner.TableStatisticAggregation; +import 
com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.facebook.presto.sql.tree.Analyze; +import com.facebook.presto.sql.tree.Call; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CreateTableAsSelect; import com.facebook.presto.sql.tree.Delete; @@ -66,10 +73,12 @@ import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.Insert; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; +import com.facebook.presto.sql.tree.Merge; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.Query; +import com.facebook.presto.sql.tree.QuerySpecification; import com.facebook.presto.sql.tree.RefreshMaterializedView; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.Update; @@ -85,12 +94,14 @@ import java.util.Map.Entry; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; import static com.facebook.presto.metadata.MetadataUtil.toSchemaTableName; import static com.facebook.presto.spi.PartitionedTableWritePolicy.MULTIPLE_WRITERS_PER_PARTITION_ALLOWED; +import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.LimitNode.Step.FINAL; @@ -103,6 +114,7 @@ import 
static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; +import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.TranslateExpressionsUtil.toRowExpression; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.tree.ExplainFormat.Type.TEXT; @@ -171,6 +183,15 @@ private RelationPlan planStatementWithoutOutput(Analysis analysis, Statement sta else if (statement instanceof Analyze) { return createAnalyzePlan(analysis, (Analyze) statement); } + else if (statement instanceof Call) { + checkState(analysis.getDistributedProcedureType().isPresent(), "Call distributed procedure analysis is missing"); + switch (analysis.getDistributedProcedureType().get()) { + case TABLE_DATA_REWRITE: + return createCallDistributedProcedurePlanForTableDataRewrite(analysis, (Call) statement); + default: + throw new PrestoException(NOT_SUPPORTED, "Unsupported distributed procedure type: " + analysis.getDistributedProcedureType().get()); + } + } else if (statement instanceof Insert) { checkState(analysis.getInsert().isPresent(), "Insert handle is missing"); return createInsertPlan(analysis, (Insert) statement); @@ -181,6 +202,9 @@ else if (statement instanceof Delete) { if (statement instanceof Update) { return createUpdatePlan(analysis, (Update) statement); } + if (statement instanceof Merge) { + return createMergePlan(analysis, (Merge) statement); + } else if (statement instanceof Query) { return createRelationPlan(analysis, (Query) statement, new SqlPlannerContext(0)); } @@ -211,6 +235,83 @@ private RelationPlan createExplainAnalyzePlan(Analysis analysis, Explain stateme return new RelationPlan(root, scope, ImmutableList.of(outputVariable)); } + private RelationPlan 
createCallDistributedProcedurePlanForTableDataRewrite(Analysis analysis, Call statement) + { + TableHandle targetTable = analysis.getCallTarget() + .orElseThrow(() -> new PrestoException(NOT_FOUND, "Target table does not exist")); + Optional procedureName = analysis.getProcedureName(); + Optional procedureArguments = analysis.getProcedureArguments(); + + QuerySpecification querySpecification = analysis.getTargetQuery() + .orElseThrow(() -> new PrestoException(NOT_FOUND, "The query for target table does not exist")); + RelationPlan plan = createRelationPlan(analysis, querySpecification, new SqlPlannerContext(0)); + + ConnectorTableMetadata tableMetadata = metadata.getTableMetadata(session, targetTable).getMetadata(); + List columnNames = tableMetadata.getColumns().stream() + .filter(column -> !column.isHidden()) + .map(ColumnMetadata::getName) + .collect(toImmutableList()); + + Map columnHandleMap = metadata.getColumnHandles(session, targetTable); + TableLayout tableLayout = metadata.getLayout(session, targetTable); + List columnHandles = columnNames.stream().map(columnHandleMap::get).collect(Collectors.toList()); + List outputLayout = plan.getRoot().getOutputVariables(); + + Optional partitioningScheme = Optional.empty(); + Optional partitioningHandle = tableLayout.getTablePartitioning().map(TablePartitioning::getPartitioningHandle); + if (partitioningHandle.isPresent()) { + List partitionFunctionArguments = new ArrayList<>(); + tableLayout.getTablePartitioning().get().getPartitioningColumns().stream() + .mapToInt(columnHandles::indexOf) + .mapToObj(outputLayout::get) + .forEach(partitionFunctionArguments::add); + partitioningScheme = Optional.of(new PartitioningScheme( + Partitioning.create(partitioningHandle.get(), partitionFunctionArguments), + outputLayout)); + } + + verify(columnNames.size() == outputLayout.size(), "columnNames.size() != outputLayout.size(): %s and %s", columnNames, outputLayout); + List variables = plan.getFieldMappings(); + 
verify(columnNames.size() == variables.size(), "columnNames.size() != variables.size(): %s and %s", columnNames, variables); + Map columnToVariableMap = zip(columnNames.stream(), plan.getFieldMappings().stream(), SimpleImmutableEntry::new) + .collect(toImmutableMap(Entry::getKey, Entry::getValue)); + + Set notNullColumnVariables = tableMetadata.getColumns().stream() + .filter(column -> !column.isNullable()) + .map(ColumnMetadata::getName) + .map(columnToVariableMap::get) + .collect(toImmutableSet()); + + CallDistributedProcedureTarget callDistributedProcedureTarget = new CallDistributedProcedureTarget( + procedureName.get(), + procedureArguments.get(), + Optional.of(targetTable), + tableMetadata.getTable(), + false); + TableFinishNode commitNode = new TableFinishNode( + Optional.empty(), + idAllocator.getNextId(), + new CallDistributedProcedureNode( + Optional.empty(), + idAllocator.getNextId(), + Optional.empty(), + plan.getRoot(), + Optional.of(callDistributedProcedureTarget), + variableAllocator.newVariable("rows", BIGINT), + variableAllocator.newVariable("fragment", VARBINARY), + variableAllocator.newVariable("commitcontext", VARBINARY), + plan.getRoot().getOutputVariables(), + columnNames, + notNullColumnVariables, + partitioningScheme), + Optional.of(callDistributedProcedureTarget), + variableAllocator.newVariable("rows", BIGINT), + Optional.empty(), + Optional.empty(), + Optional.empty()); + return new RelationPlan(commitNode, analysis.getScope(statement), commitNode.getOutputVariables()); + } + private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStatement) { TableHandle targetTable = analysis.getAnalyzeTarget().get(); @@ -271,7 +372,7 @@ private RelationPlan createTableCreationPlan(Analysis analysis, Query query) ConnectorTableMetadata tableMetadata = createTableMetadata( destination, - getOutputTableColumns(plan, analysis.getColumnAliases()), + getOutputTableColumns(metadata, session, destination.getCatalogName(), plan, 
analysis.getColumnAliases()), analysis.getCreateTableProperties(), analysis.getParameters(), analysis.getCreateTableComment()); @@ -286,7 +387,7 @@ private RelationPlan createTableCreationPlan(Analysis analysis, Query query) return createTableWriterPlan( analysis, plan, - new CreateName(new ConnectorId(destination.getCatalogName()), tableMetadata, newTableLayout), + new CreateName(new ConnectorId(destination.getCatalogName()), tableMetadata, newTableLayout, analysis.getUpdatedSourceColumns()), columnNames, tableMetadata.getColumns(), newTableLayout, @@ -310,7 +411,7 @@ private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement) TableHandle tableHandle = insertAnalysis.getTarget(); List columnHandles = insertAnalysis.getColumns(); - TableWriterNode.WriterTarget target = new InsertReference(tableHandle, metadata.getTableMetadata(session, tableHandle).getTable()); + TableWriterNode.WriterTarget target = new InsertReference(tableHandle, metadata.getTableMetadata(session, tableHandle).getTable(), analysis.getUpdatedSourceColumns()); return buildInternalInsertPlan(tableHandle, columnHandles, insertStatement.getQuery(), analysis, target); } @@ -511,7 +612,7 @@ private RelationPlan createUpdatePlan(Analysis analysis, Update node) .filter(column -> !column.isHidden()) .collect(toImmutableList()); List targetColumnNames = node.getAssignments().stream() - .map(assignment -> assignment.getName().getValue()) + .map(assignment -> metadata.normalizeIdentifier(session, handle.getConnectorId().getCatalogName(), assignment.getName().getValue())) .collect(toImmutableList()); for (ColumnMetadata columnMetadata : dataColumns) { @@ -537,6 +638,26 @@ private RelationPlan createUpdatePlan(Analysis analysis, Update node) return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputVariables()); } + private RelationPlan createMergePlan(Analysis analysis, Merge node) + { + SqlPlannerContext context = new SqlPlannerContext(0); + MergeWriterNode 
mergeNode = new QueryPlanner(analysis, variableAllocator, idAllocator, + buildLambdaDeclarationToVariableMap(analysis, variableAllocator), metadata, session, context, sqlParser) + .plan(node); + + TableFinishNode commitNode = new TableFinishNode( + mergeNode.getSourceLocation(), + idAllocator.getNextId(), + mergeNode, + Optional.of(mergeNode.getTarget()), + variableAllocator.newVariable("rows", BIGINT), + Optional.empty(), + Optional.empty(), + Optional.empty()); + + return new RelationPlan(commitNode, analysis.getScope(node), commitNode.getOutputVariables()); + } + private PlanNode createOutputPlan(RelationPlan plan, Analysis analysis) { ImmutableList.Builder outputs = ImmutableList.builder(); @@ -564,6 +685,12 @@ private RelationPlan createRelationPlan(Analysis analysis, Query query, SqlPlann .process(query, context); } + private RelationPlan createRelationPlan(Analysis analysis, QuerySpecification query, SqlPlannerContext context) + { + return new RelationPlanner(analysis, variableAllocator, idAllocator, buildLambdaDeclarationToVariableMap(analysis, variableAllocator), metadata, session, sqlParser) + .process(query, context); + } + private ConnectorTableMetadata createTableMetadata(QualifiedObjectName table, List columns, Map propertyExpressions, Map, Expression> parameters, Optional comment) { Map properties = metadata.getTablePropertyManager().getProperties( @@ -589,14 +716,14 @@ private RowExpression rowExpression(Expression expression, SqlPlannerContext con context.getTranslatorContext()); } - private static List getOutputTableColumns(RelationPlan plan, Optional> columnAliases) + private static List getOutputTableColumns(Metadata metadata, Session session, String catalogName, RelationPlan plan, Optional> columnAliases) { ImmutableList.Builder columns = ImmutableList.builder(); int aliasPosition = 0; for (Field field : plan.getDescriptor().getVisibleFields()) { String columnName = columnAliases.isPresent() ? 
columnAliases.get().get(aliasPosition).getValue() : field.getName().get(); columns.add(ColumnMetadata.builder() - .setName(columnName) + .setName(metadata.normalizeIdentifier(session, catalogName, columnName)) .setType(field.getType()) .build()); aliasPosition++; @@ -640,7 +767,7 @@ private static Optional getPartitioningSchemeForTableWrite(O List outputLayout = new ArrayList<>(variables); partitioningScheme = Optional.of(new PartitioningScheme( - Partitioning.create(tableLayout.get().getPartitioning(), partitionFunctionArguments), + Partitioning.create(tableLayout.get().getPartitioning().orElse(FIXED_HASH_DISTRIBUTION), partitionFunctionArguments), outputLayout, tableLayout.get().getWriterPolicy() == MULTIPLE_WRITERS_PER_PARTITION_ALLOWED)); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/NodePartitioningManager.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/NodePartitioningManager.java index 24e635bb7fe39..3de6783a91a41 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/NodePartitioningManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/NodePartitioningManager.java @@ -39,8 +39,7 @@ import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.Collections; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java index bd4418d337e1f..6467b03df04d3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/OutputExtractor.java @@ -13,23 +13,21 @@ */ package com.facebook.presto.sql.planner; -import com.facebook.presto.execution.Column; import 
com.facebook.presto.execution.Output; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.google.common.base.VerifyException; -import com.google.common.collect.ImmutableList; -import java.util.ArrayList; import java.util.List; import java.util.Optional; -import static com.facebook.presto.spi.connector.ConnectorCommitHandle.EMPTY_COMMIT_OUTPUT; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; public class OutputExtractor @@ -46,8 +44,8 @@ public Optional extractOutput(PlanNode root) visitor.getConnectorId(), visitor.getSchemaTableName().getSchemaName(), visitor.getSchemaTableName().getTableName(), - EMPTY_COMMIT_OUTPUT, - Optional.of(ImmutableList.copyOf(visitor.getColumns())))); + visitor.getOutputColumns(), + Optional.empty())); } private class Visitor @@ -55,7 +53,7 @@ private class Visitor { private ConnectorId connectorId; private SchemaTableName schemaTableName; - private List columns = new ArrayList<>(); + private Optional> outputColumns = Optional.empty(); @Override public Void visitTableWriter(TableWriterNode node, Void context) @@ -65,11 +63,33 @@ public Void visitTableWriter(TableWriterNode node, Void context) checkState(schemaTableName == null || schemaTableName.equals(writerTarget.getSchemaTableName()), "cannot have more than a single create, insert or delete in a query"); schemaTableName = writerTarget.getSchemaTableName(); + outputColumns = writerTarget.getOutputColumns(); + return null; + } - 
checkArgument(node.getColumnNames().size() == node.getColumns().size(), "Column names and columns sizes must be equal"); - for (int i = 0; i < node.getColumnNames().size(); i++) { - columns.add(new Column(node.getColumnNames().get(i), node.getColumns().get(i).getType().toString())); + @Override + public Void visitTableFinish(TableFinishNode node, Void context) + { + if (node.getTarget().isPresent() && node.getTarget().get() instanceof TableWriterNode.DeleteHandle) { + TableWriterNode.DeleteHandle deleteHandle = (TableWriterNode.DeleteHandle) node.getTarget().get(); + connectorId = deleteHandle.getConnectorId(); + checkState(schemaTableName == null || schemaTableName.equals(deleteHandle.getSchemaTableName()), + "cannot have more than a single create, insert or delete in a query"); + schemaTableName = deleteHandle.getSchemaTableName(); + outputColumns = deleteHandle.getOutputColumns(); + return null; } + return super.visitTableFinish(node, context); + } + + @Override + public Void visitCallDistributedProcedure(CallDistributedProcedureNode node, Void context) + { + TableWriterNode.WriterTarget writerTarget = node.getTarget().orElseThrow(() -> new VerifyException("target is absent")); + connectorId = writerTarget.getConnectorId(); + checkState(schemaTableName == null || schemaTableName.equals(writerTarget.getSchemaTableName()), + "cannot have more than a single create, insert or delete in a query"); + schemaTableName = writerTarget.getSchemaTableName(); return null; } @@ -100,9 +120,9 @@ public SchemaTableName getSchemaTableName() return schemaTableName; } - public List getColumns() + public Optional> getOutputColumns() { - return columns; + return outputColumns; } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java index 8729e4831c9e2..25cd1d1d8b57d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java 
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanBuilder.java @@ -25,6 +25,8 @@ import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; +import com.facebook.presto.sql.tree.NodeRef; import com.google.common.collect.ImmutableMap; import java.util.Map; @@ -47,6 +49,16 @@ public PlanBuilder(TranslationMap translations, PlanNode root) this.root = root; } + public static PlanBuilder newPlanBuilder( + RelationPlan plan, + Analysis analysis, + Map, VariableReferenceExpression> lambdaArguments) + { + return new PlanBuilder( + new TranslationMap(plan, analysis, lambdaArguments), + plan.getRoot()); + } + public TranslationMap copyTranslations() { TranslationMap translations = new TranslationMap(getRelationPlan(), getAnalysis(), getTranslations().getLambdaDeclarationToVariableMap()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java index 5e88d518a0cd0..1fdf87a5865b7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragment.java @@ -17,6 +17,7 @@ import com.facebook.airlift.json.Codec; import com.facebook.presto.common.type.Type; import com.facebook.presto.cost.StatsAndCosts; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PartitioningHandle; import com.facebook.presto.spi.plan.PartitioningScheme; import com.facebook.presto.spi.plan.PlanFragmentId; @@ -30,8 +31,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; - -import javax.annotation.concurrent.GuardedBy; +import 
com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.List; import java.util.Optional; @@ -55,6 +55,10 @@ public class PlanFragment private final PartitioningScheme partitioningScheme; private final StageExecutionDescriptor stageExecutionDescriptor; + // Describes the ordering of the fragment's output data + // This is separate from partitioningScheme as ordering is orthogonal to partitioning + private final Optional outputOrderingScheme; + // Only true for output table writer and false for temporary table writers private final boolean outputTableWriterFragment; private final Optional statsAndCosts; @@ -74,6 +78,7 @@ public PlanFragment( @JsonProperty("partitioning") PartitioningHandle partitioning, @JsonProperty("tableScanSchedulingOrder") List tableScanSchedulingOrder, @JsonProperty("partitioningScheme") PartitioningScheme partitioningScheme, + @JsonProperty("outputOrderingScheme") Optional outputOrderingScheme, @JsonProperty("stageExecutionDescriptor") StageExecutionDescriptor stageExecutionDescriptor, @JsonProperty("outputTableWriterFragment") boolean outputTableWriterFragment, @JsonProperty("statsAndCosts") Optional statsAndCosts, @@ -85,6 +90,7 @@ public PlanFragment( this.partitioning = requireNonNull(partitioning, "partitioning is null"); this.tableScanSchedulingOrder = ImmutableList.copyOf(requireNonNull(tableScanSchedulingOrder, "tableScanSchedulingOrder is null")); this.stageExecutionDescriptor = requireNonNull(stageExecutionDescriptor, "stageExecutionDescriptor is null"); + this.outputOrderingScheme = requireNonNull(outputOrderingScheme, "outputOrderingScheme is null"); this.outputTableWriterFragment = outputTableWriterFragment; this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); this.jsonRepresentation = requireNonNull(jsonRepresentation, "jsonRepresentation is null"); @@ -157,6 +163,12 @@ public Optional getStatsAndCosts() return statsAndCosts; } + @JsonProperty + public Optional getOutputOrderingScheme() + 
{ + return outputOrderingScheme; + } + @JsonProperty public Optional getJsonRepresentation() { @@ -188,6 +200,7 @@ private PlanFragment forTaskSerialization() id, root, variables, partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, stageExecutionDescriptor, outputTableWriterFragment, Optional.empty(), @@ -247,6 +260,7 @@ public PlanFragment withBucketToPartition(Optional bucketToPartition) partitioning, tableScanSchedulingOrder, partitioningScheme.withBucketToPartition(bucketToPartition), + outputOrderingScheme, stageExecutionDescriptor, outputTableWriterFragment, statsAndCosts, @@ -262,6 +276,7 @@ public PlanFragment withFixedLifespanScheduleGroupedExecution(List c partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, StageExecutionDescriptor.fixedLifespanScheduleGroupedExecution(capableTableScanNodes, totalLifespans), outputTableWriterFragment, statsAndCosts, @@ -277,6 +292,7 @@ public PlanFragment withDynamicLifespanScheduleGroupedExecution(List partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, StageExecutionDescriptor.dynamicLifespanScheduleGroupedExecution(capableTableScanNodes, totalLifespans), outputTableWriterFragment, statsAndCosts, @@ -292,6 +308,7 @@ public PlanFragment withRecoverableGroupedExecution(List capableTabl partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, StageExecutionDescriptor.recoverableGroupedExecution(capableTableScanNodes, totalLifespans), outputTableWriterFragment, statsAndCosts, @@ -307,6 +324,7 @@ public PlanFragment withSubPlan(PlanNode subPlan) partitioning, tableScanSchedulingOrder, partitioningScheme, + outputOrderingScheme, stageExecutionDescriptor, outputTableWriterFragment, statsAndCosts, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java index 
584a72a4a077f..e526fdf4ed573 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenter.java @@ -31,8 +31,7 @@ import com.facebook.presto.sql.planner.sanity.PlanChecker; import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Set; @@ -53,6 +52,7 @@ public class PlanFragmenter private final QueryManagerConfig config; private final PlanChecker distributedPlanChecker; private final PlanChecker singleNodePlanChecker; + private final boolean isPrestoOnSpark; @Inject public PlanFragmenter(Metadata metadata, NodePartitioningManager nodePartitioningManager, QueryManagerConfig queryManagerConfig, FeaturesConfig featuresConfig, PlanCheckerProviderManager planCheckerProviderManager) @@ -62,6 +62,7 @@ public PlanFragmenter(Metadata metadata, NodePartitioningManager nodePartitionin this.config = requireNonNull(queryManagerConfig, "queryManagerConfig is null"); this.distributedPlanChecker = new PlanChecker(requireNonNull(featuresConfig, "featuresConfig is null"), false, planCheckerProviderManager); this.singleNodePlanChecker = new PlanChecker(requireNonNull(featuresConfig, "featuresConfig is null"), true, planCheckerProviderManager); + this.isPrestoOnSpark = featuresConfig.isPrestoSparkExecutionEnvironment(); } public SubPlan createSubPlans(Session session, Plan plan, boolean noExchange, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) @@ -91,7 +92,7 @@ public SubPlan createSubPlans(Session session, Plan plan, boolean noExchange, Pl PlanNode root = SimplePlanRewriter.rewriteWith(fragmenter, plan.getRoot(), properties); SubPlan subPlan = fragmenter.buildRootFragment(root, properties); - return finalizeSubPlan(subPlan, config, metadata, nodePartitioningManager, session, noExchange, warningCollector, 
subPlan.getFragment().getPartitioning()); + return finalizeSubPlan(subPlan, config, metadata, nodePartitioningManager, session, noExchange, warningCollector, subPlan.getFragment().getPartitioning(), isPrestoOnSpark); } private static class Fragmenter diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java index 5f267f4c94e22..d58617ae19543 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanFragmenterUtils.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; import com.facebook.presto.spi.plan.PartitioningHandle; @@ -32,7 +33,6 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; -import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; @@ -46,6 +46,7 @@ import static com.facebook.presto.SystemSessionProperties.getExchangeMaterializationStrategy; import static com.facebook.presto.SystemSessionProperties.getQueryMaxStageCount; import static com.facebook.presto.SystemSessionProperties.isForceSingleNodeOutput; +import static com.facebook.presto.SystemSessionProperties.isGroupedExecutionEnabled; import static com.facebook.presto.SystemSessionProperties.isRecoverableGroupedExecutionEnabled; import static 
com.facebook.presto.SystemSessionProperties.isSingleNodeExecutionEnabled; import static com.facebook.presto.spi.StandardErrorCode.QUERY_HAS_TOO_MANY_STAGES; @@ -92,12 +93,13 @@ public static SubPlan finalizeSubPlan( Session session, boolean noExchange, WarningCollector warningCollector, - PartitioningHandle partitioningHandle) + PartitioningHandle partitioningHandle, + boolean isPrestoOnSpark) { subPlan = reassignPartitioningHandleIfNecessary(metadata, session, subPlan, partitioningHandle); if (!noExchange && !isSingleNodeExecutionEnabled(session)) { // grouped execution is not supported for SINGLE_DISTRIBUTION or SINGLE_NODE_EXECUTION_ENABLED - subPlan = analyzeGroupedExecution(session, subPlan, false, metadata, nodePartitioningManager); + subPlan = analyzeGroupedExecution(session, subPlan, false, metadata, nodePartitioningManager, isPrestoOnSpark); } checkState(subPlan.getFragment().getId().getId() != ROOT_FRAGMENT_ID || !isForceSingleNodeOutput(session) || subPlan.getFragment().getPartitioning().isSingleNode(), "Root of PlanFragment is not single node"); @@ -148,10 +150,10 @@ private static void sanityCheckFragmentedPlan( * TODO: We should introduce "query section" and make recoverability analysis done at query section level. 
*/ - private static SubPlan analyzeGroupedExecution(Session session, SubPlan subPlan, boolean parentContainsTableFinish, Metadata metadata, NodePartitioningManager nodePartitioningManager) + private static SubPlan analyzeGroupedExecution(Session session, SubPlan subPlan, boolean parentContainsTableFinish, Metadata metadata, NodePartitioningManager nodePartitioningManager, boolean isPrestoOnSpark) { PlanFragment fragment = subPlan.getFragment(); - GroupedExecutionTagger.GroupedExecutionProperties properties = fragment.getRoot().accept(new GroupedExecutionTagger(session, metadata, nodePartitioningManager), null); + GroupedExecutionTagger.GroupedExecutionProperties properties = fragment.getRoot().accept(new GroupedExecutionTagger(session, metadata, nodePartitioningManager, isGroupedExecutionEnabled(session), isPrestoOnSpark), null); if (properties.isSubTreeUseful()) { boolean preferDynamic = fragment.getRemoteSourceNodes().stream().allMatch(node -> node.getExchangeType() == REPLICATE) && new HashSet<>(properties.getCapableTableScanNodes()).containsAll(fragment.getTableScanSchedulingOrder()); @@ -185,7 +187,7 @@ private static SubPlan analyzeGroupedExecution(Session session, SubPlan subPlan, ImmutableList.Builder result = ImmutableList.builder(); boolean containsTableFinishNode = containsTableFinishNode(fragment); for (SubPlan child : subPlan.getChildren()) { - result.add(analyzeGroupedExecution(session, child, containsTableFinishNode, metadata, nodePartitioningManager)); + result.add(analyzeGroupedExecution(session, child, containsTableFinishNode, metadata, nodePartitioningManager, isPrestoOnSpark)); } return new SubPlan(fragment, result.build()); } @@ -231,6 +233,7 @@ private static SubPlan reassignPartitioningHandleIfNecessaryHelper(Metadata meta outputPartitioningScheme.isScaleWriters(), outputPartitioningScheme.getEncoding(), outputPartitioningScheme.getBucketToPartition()), + fragment.getOutputOrderingScheme(), fragment.getStageExecutionDescriptor(), 
fragment.isOutputTableWriterFragment(), fragment.getStatsAndCosts(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 1da8a7acfe6dd..8da6e1866cad2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -20,6 +20,7 @@ import com.facebook.presto.cost.TaskCountEstimator; import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; import com.facebook.presto.sql.analyzer.FeaturesConfig; @@ -28,9 +29,11 @@ import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.properties.LogicalPropertiesProviderImpl; +import com.facebook.presto.sql.planner.iterative.rule.AddDistinctForSemiJoinBuild; import com.facebook.presto.sql.planner.iterative.rule.AddExchangesBelowPartialAggregationOverGroupIdRuleSet; import com.facebook.presto.sql.planner.iterative.rule.AddIntermediateAggregations; import com.facebook.presto.sql.planner.iterative.rule.AddNotNullFiltersToJoinNode; +import com.facebook.presto.sql.planner.iterative.rule.CombineApproxDistinctFunctions; import com.facebook.presto.sql.planner.iterative.rule.CombineApproxPercentileFunctions; import com.facebook.presto.sql.planner.iterative.rule.CreatePartialTopN; import com.facebook.presto.sql.planner.iterative.rule.CrossJoinWithArrayContainsToInnerJoin; @@ -44,6 +47,7 @@ import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroLimit; import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroSample; import 
com.facebook.presto.sql.planner.iterative.rule.ExtractSpatialJoins; +import com.facebook.presto.sql.planner.iterative.rule.ExtractSystemTableFilterRuleSet; import com.facebook.presto.sql.planner.iterative.rule.GatherAndMergeWindows; import com.facebook.presto.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter; import com.facebook.presto.sql.planner.iterative.rule.ImplementFilteredAggregations; @@ -53,12 +57,14 @@ import com.facebook.presto.sql.planner.iterative.rule.InlineSqlFunctions; import com.facebook.presto.sql.planner.iterative.rule.LeftJoinNullFilterToSemiJoin; import com.facebook.presto.sql.planner.iterative.rule.LeftJoinWithArrayContainsToEquiJoinCondition; +import com.facebook.presto.sql.planner.iterative.rule.MaterializedViewRewrite; import com.facebook.presto.sql.planner.iterative.rule.MergeDuplicateAggregation; import com.facebook.presto.sql.planner.iterative.rule.MergeFilters; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithDistinct; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithSort; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithTopN; import com.facebook.presto.sql.planner.iterative.rule.MergeLimits; +import com.facebook.presto.sql.planner.iterative.rule.MinMaxByToWindowFunction; import com.facebook.presto.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct; import com.facebook.presto.sql.planner.iterative.rule.PickTableLayout; import com.facebook.presto.sql.planner.iterative.rule.PlanRemoteProjections; @@ -72,12 +78,15 @@ import com.facebook.presto.sql.planner.iterative.rule.PruneJoinColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneLimitColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneMarkDistinctColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneMergeSourceColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneOrderByInAggregation; import 
com.facebook.presto.sql.planner.iterative.rule.PruneOutputColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneProjectColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneRedundantProjectionAssignments; import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinFilteringSourceColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneTableFunctionProcessorColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneTableFunctionProcessorSourceColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTableScanColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTopNColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneUpdateSourceColumns; @@ -103,6 +112,7 @@ import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughGroupId; import com.facebook.presto.sql.planner.iterative.rule.PushTableWriteThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.PushTopNThroughUnion; +import com.facebook.presto.sql.planner.iterative.rule.RandomizeSourceKeyInSemiJoin; import com.facebook.presto.sql.planner.iterative.rule.RemoveCrossJoinWithConstantInput; import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete; import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample; @@ -116,6 +126,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantLimit; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantSort; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantSortColumns; +import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTableFunctionProcessor; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTopN; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTopNColumns; import com.facebook.presto.sql.planner.iterative.rule.RemoveTrivialFilters; @@ -123,11 +134,14 
@@ import com.facebook.presto.sql.planner.iterative.rule.RemoveUnreferencedScalarLateralNodes; import com.facebook.presto.sql.planner.iterative.rule.RemoveUnsupportedDynamicFilters; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins; +import com.facebook.presto.sql.planner.iterative.rule.ReplaceConditionalApproxDistinct; import com.facebook.presto.sql.planner.iterative.rule.RewriteAggregationIfToFilter; import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseExpressionPredicate; import com.facebook.presto.sql.planner.iterative.rule.RewriteCaseToMap; import com.facebook.presto.sql.planner.iterative.rule.RewriteConstantArrayContainsToInExpression; +import com.facebook.presto.sql.planner.iterative.rule.RewriteExcludeColumnsFunctionToProjection; import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject; +import com.facebook.presto.sql.planner.iterative.rule.RewriteRowExpressions; import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation; import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides; import com.facebook.presto.sql.planner.iterative.rule.ScaledWriterRule; @@ -145,6 +159,8 @@ import com.facebook.presto.sql.planner.iterative.rule.TransformDistinctInnerJoinToLeftEarlyOutJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformDistinctInnerJoinToRightEarlyOutJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformExistsApplyToLateralNode; +import com.facebook.presto.sql.planner.iterative.rule.TransformTableFunctionProcessorToTableScan; +import com.facebook.presto.sql.planner.iterative.rule.TransformTableFunctionToTableFunctionProcessor; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToDistinctInnerJoin; import com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToSemiJoin; import 
com.facebook.presto.sql.planner.iterative.rule.TransformUncorrelatedLateralToJoin; @@ -180,9 +196,12 @@ import com.facebook.presto.sql.planner.optimizations.ReplaceConstantVariableReferencesWithConstants; import com.facebook.presto.sql.planner.optimizations.ReplicateSemiJoinInDelete; import com.facebook.presto.sql.planner.optimizations.RewriteIfOverAggregation; +import com.facebook.presto.sql.planner.optimizations.RewriteWriterTarget; import com.facebook.presto.sql.planner.optimizations.SetFlatteningOptimizer; import com.facebook.presto.sql.planner.optimizations.ShardJoins; import com.facebook.presto.sql.planner.optimizations.SimplifyPlanWithEmptyInput; +import com.facebook.presto.sql.planner.optimizations.SortMergeJoinOptimizer; +import com.facebook.presto.sql.planner.optimizations.SortedExchangeRule; import com.facebook.presto.sql.planner.optimizations.StatsRecordingPlanOptimizer; import com.facebook.presto.sql.planner.optimizations.TransformQuantifiedComparisonApplyToLateralJoin; import com.facebook.presto.sql.planner.optimizations.UnaliasSymbolReferences; @@ -190,12 +209,11 @@ import com.facebook.presto.sql.relational.FunctionResolution; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.MBeanExporter; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; - import java.util.List; import java.util.Optional; import java.util.Set; @@ -227,7 +245,8 @@ public PlanOptimizers( PartitioningProviderManager partitioningProviderManager, FeaturesConfig featuresConfig, ExpressionOptimizerManager expressionOptimizerManager, - TaskManagerConfig taskManagerConfig) + TaskManagerConfig taskManagerConfig, + AccessControl accessControl) { this(metadata, sqlParser, @@ -244,7 +263,8 @@ public PlanOptimizers( partitioningProviderManager, featuresConfig, 
expressionOptimizerManager, - taskManagerConfig); + taskManagerConfig, + accessControl); } @PostConstruct @@ -277,7 +297,8 @@ public PlanOptimizers( PartitioningProviderManager partitioningProviderManager, FeaturesConfig featuresConfig, ExpressionOptimizerManager expressionOptimizerManager, - TaskManagerConfig taskManagerConfig) + TaskManagerConfig taskManagerConfig, + AccessControl accessControl) { this.exporter = exporter; ImmutableList.Builder builder = ImmutableList.builder(); @@ -295,6 +316,7 @@ public PlanOptimizers( new PruneJoinChildrenColumns(), new PruneJoinColumns(), new PruneUpdateSourceColumns(), + new PruneMergeSourceColumns(), new PruneMarkDistinctColumns(), new PruneOutputColumns(), new PruneProjectColumns(), @@ -304,10 +326,19 @@ public PlanOptimizers( new PruneValuesColumns(), new PruneWindowColumns(), new PruneLimitColumns(), + new PruneTableFunctionProcessorColumns(), + new PruneTableFunctionProcessorSourceColumns(), new PruneTableScanColumns()); builder.add(new LogicalCteOptimizer(metadata)); + builder.add(new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new MaterializedViewRewrite(metadata, accessControl)))); + IterativeOptimizer inlineProjections = new IterativeOptimizer( metadata, ruleStats, @@ -402,6 +433,7 @@ public PlanOptimizers( .addAll(predicatePushDownRules) .addAll(columnPruningRules) .addAll(ImmutableSet.of( + new TransformTableFunctionToTableFunctionProcessor(metadata), new MergeDuplicateAggregation(metadata.getFunctionAndTypeManager()), new RemoveRedundantIdentityProjections(), new RemoveFullSample(), @@ -424,6 +456,9 @@ public PlanOptimizers( new MergeLimitWithDistinct(), new PruneCountAggregationOverScalar(metadata.getFunctionAndTypeManager()), new PruneOrderByInAggregation(metadata.getFunctionAndTypeManager()), + new RemoveRedundantTableFunctionProcessor(), // must run after TransformTableFunctionToTableFunctionProcessor + new 
RewriteExcludeColumnsFunctionToProjection(), // must run after TransformTableFunctionToTableFunctionProcessor + new TransformTableFunctionProcessorToTableScan(metadata), // must run after TransformTableFunctionToTableFunctionProcessor new RewriteSpatialPartitioningAggregation(metadata))) .build()), new IterativeOptimizer( @@ -453,6 +488,12 @@ public PlanOptimizers( new ReplaceConstantVariableReferencesWithConstants(metadata.getFunctionAndTypeManager()), simplifyRowExpressionOptimizer, new ReplaceConstantVariableReferencesWithConstants(metadata.getFunctionAndTypeManager()), + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new ReplaceConditionalApproxDistinct(metadata.getFunctionAndTypeManager()))), new IterativeOptimizer( metadata, ruleStats, @@ -530,6 +571,9 @@ public PlanOptimizers( ImmutableSet.>builder().add(new RemoveRedundantCastToVarcharInJoinClause(metadata.getFunctionAndTypeManager())) .addAll(new RemoveMapCastRule(metadata.getFunctionAndTypeManager()).rules()).build())); + builder.add(new IterativeOptimizer(metadata, ruleStats, statsCalculator, estimatedExchangesCostCalculator, + new RewriteRowExpressions(expressionOptimizerManager).rules())); + builder.add(new IterativeOptimizer( metadata, ruleStats, @@ -544,6 +588,13 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.of(new CombineApproxPercentileFunctions(metadata.getFunctionAndTypeManager())))); + builder.add(new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new CombineApproxDistinctFunctions(metadata.getFunctionAndTypeManager())))); + // In RewriteIfOverAggregation, we can only optimize when the aggregation output is used in only one IF expression, and not used in any other expressions (excluding // identity assignments). 
Hence we need to simplify projection assignments to combine/inline expressions in assignments so as to identify the candidate IF expressions. builder.add( @@ -592,6 +643,12 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new LeftJoinNullFilterToSemiJoin(metadata.getFunctionAndTypeManager()))), + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new AddDistinctForSemiJoinBuild())), new KeyBasedSampler(metadata), new IterativeOptimizer( metadata, @@ -653,6 +710,12 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new SimplifyCountOverConstant(metadata.getFunctionAndTypeManager()))), + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new MinMaxByToWindowFunction(metadata.getFunctionAndTypeManager()))), new LimitPushDown(), // Run LimitPushDown before WindowFilterPushDown new WindowFilterPushDown(metadata), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits prefilterForLimitingAggregation, @@ -726,6 +789,14 @@ public PlanOptimizers( // Pass a supplier so that we pickup connector optimizers that are installed later builder.add( new ApplyConnectorOptimization(() -> planOptimizerManager.getOptimizers(LOGICAL)), + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of( + new RewriteFilterWithExternalFunctionToProject(metadata.getFunctionAndTypeManager()), + new PlanRemoteProjections(metadata.getFunctionAndTypeManager()))), projectionPushDown, new PruneUnreferencedOutputs()); @@ -738,13 +809,17 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new RemoveRedundantIdentityProjections(), new PruneRedundantProjectionAssignments())), + ImmutableSet.of( + new 
RemoveRedundantIdentityProjections(), + new PruneRedundantProjectionAssignments(), + // Re-run RemoveRedundantTableFunctionProcessor after SimplifyPlanWithEmptyInput to optimize empty input tables to empty ValueNode + new RemoveRedundantTableFunctionProcessor())), new PushdownSubfields(metadata, expressionOptimizerManager)); builder.add(predicatePushDown); // Run predicate push down one more time in case we can leverage new information from layouts' effective predicate builder.add(simplifyRowExpressionOptimizer); // Should be always run after PredicatePushDown - builder.add(new MetadataQueryOptimizer(metadata)); + builder.add(new MetadataQueryOptimizer(metadata, expressionOptimizerManager)); // This can pull up Filter and Project nodes from between Joins, so we need to push them down again builder.add( @@ -847,7 +922,15 @@ public PlanOptimizers( // to avoid temporarily having an invalid plan new DetermineSemiJoinDistributionType(costComparator, taskCountEstimator)))); - builder.add(new RandomizeNullKeyInOuterJoin(metadata.getFunctionAndTypeManager(), statsCalculator), + builder.add( + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of( + new RandomizeSourceKeyInSemiJoin(metadata.getFunctionAndTypeManager()))), + new RandomizeNullKeyInOuterJoin(metadata.getFunctionAndTypeManager(), statsCalculator), new PruneUnreferencedOutputs(), new IterativeOptimizer( metadata, @@ -907,7 +990,15 @@ public PlanOptimizers( // MergeJoinForSortedInputOptimizer can avoid the local exchange for a join operation // Should be placed after AddExchanges, but before AddLocalExchange // To replace the JoinNode to MergeJoin ahead of AddLocalExchange to avoid adding extra local exchange - builder.add(new MergeJoinForSortedInputOptimizer(metadata, featuresConfig.isNativeExecutionEnabled())); + builder.add(new MergeJoinForSortedInputOptimizer(metadata, featuresConfig.isNativeExecutionEnabled(), 
featuresConfig.isPrestoSparkExecutionEnvironment()), + new SortMergeJoinOptimizer(metadata, featuresConfig.isNativeExecutionEnabled())); + // SortedExchangeRule pushes sorts down to exchange nodes for distributed queries + // The rule is added unconditionally but only applies when: + // 1. Native execution is enabled + // 2. Running in Presto Spark execution environment + // 3. Session property sorted_exchange_enabled is true + builder.add(new SortedExchangeRule( + featuresConfig.isNativeExecutionEnabled() && featuresConfig.isPrestoSparkExecutionEnvironment())); // Optimizers above this don't understand local exchanges, so be careful moving this. builder.add(new AddLocalExchanges(metadata, featuresConfig.isNativeExecutionEnabled())); @@ -915,13 +1006,13 @@ public PlanOptimizers( // Optimizers above this do not need to care about aggregations with the type other than SINGLE // This optimizer must be run after all exchange-related optimizers builder.add(new IterativeOptimizer( - metadata, - ruleStats, - statsCalculator, - costCalculator, - ImmutableSet.of( - new PushPartialAggregationThroughJoin(), - new PushPartialAggregationThroughExchange(metadata.getFunctionAndTypeManager(), featuresConfig.isNativeExecutionEnabled()))), + metadata, + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of( + new PushPartialAggregationThroughJoin(), + new PushPartialAggregationThroughExchange(metadata.getFunctionAndTypeManager(), featuresConfig.isNativeExecutionEnabled()))), // MergePartialAggregationsWithFilter should immediately follow PushPartialAggregationThroughExchange new MergePartialAggregationsWithFilter(metadata.getFunctionAndTypeManager()), new IterativeOptimizer( @@ -960,6 +1051,14 @@ public PlanOptimizers( // Pass after connector optimizer, as it relies on connector optimizer to identify empty input tables and convert them to empty ValuesNode builder.add(new SimplifyPlanWithEmptyInput()); + builder.add( + new IterativeOptimizer( + metadata, + ruleStats, + 
statsCalculator, + costCalculator, + new ExtractSystemTableFilterRuleSet(metadata.getFunctionAndTypeManager()).rules())); + // DO NOT add optimizers that change the plan shape (computations) after this point // Precomputed hashes - this assumes that partitioning will not change @@ -974,6 +1073,8 @@ public PlanOptimizers( featuresConfig.isPrestoSparkExecutionEnvironment())))); builder.add(new MetadataDeleteOptimizer(metadata)); + builder.add(new RewriteWriterTarget(metadata, accessControl)); + // TODO: consider adding a formal final plan sanitization optimizer that prepares the plan for transmission/execution/logging // TODO: figure out how to improve the set flattening optimizer so that it can run at any point this.planningTimeOptimizers = builder.build(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index 74d182887b03d..4285cd778fa01 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -58,6 +58,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slices; import java.lang.invoke.MethodHandle; import java.util.HashSet; @@ -66,20 +67,24 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount; import static com.facebook.presto.common.function.OperatorType.EQUAL; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DateType.DATE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static 
com.facebook.presto.metadata.CastType.CAST; import static com.facebook.presto.spi.ConnectorId.isInternalSystemConnector; import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.Expressions.callOperator; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.variable; import static com.facebook.presto.type.TypeUtils.NULL_HASH_CODE; @@ -201,6 +206,9 @@ public static PlanNode addOverrideProjection(PlanNode source, PlanNodeIdAllocato || source.getOutputVariables().stream().distinct().count() != source.getOutputVariables().size()) { return source; } + if (source instanceof ProjectNode && ((ProjectNode) source).getAssignments().getMap().equals(variableMap)) { + return source; + } Assignments.Builder assignmentsBuilder = Assignments.builder(); assignmentsBuilder.putAll(source.getOutputVariables().stream().collect(toImmutableMap(identity(), x -> variableMap.containsKey(x) ? 
variableMap.get(x) : x))); return new ProjectNode(source.getSourceLocation(), planNodeIdAllocator.getNextId(), source, assignmentsBuilder.build(), LOCAL); @@ -537,4 +545,52 @@ public static boolean isConstant(RowExpression expression, Type type, Object val ((ConstantExpression) expression).getType() == type && ((ConstantExpression) expression).getValue() == value; } + + /** + * Creates a randomized join key expression to mitigate data skew caused by null values. + * The function uses COALESCE to replace null values with a randomized string containing + * the provided prefix and a random number within the partition count range. + * + * @param session the session containing configuration like hash partition count + * @param functionAndTypeManager function and type manager for creating expressions + * @param keyExpression the original join key expression + * @param prefix prefix to use for the randomized value (e.g., "l" for left, "r" for right) + * @return a RowExpression that coalesces the original key with a randomized replacement + */ + public static RowExpression randomizeJoinKey(Session session, FunctionAndTypeManager functionAndTypeManager, RowExpression keyExpression, String prefix) + { + int partitionCount = getHashPartitionCount(session); + RowExpression randomNumber = call( + functionAndTypeManager, + "random", + BIGINT, + constant((long) partitionCount, BIGINT)); + RowExpression randomNumberVarchar = call("CAST", functionAndTypeManager.lookupCast(CAST, randomNumber.getType(), VARCHAR), VARCHAR, randomNumber); + RowExpression concatExpression = call(functionAndTypeManager, + "concat", + VARCHAR, + ImmutableList.of(constant(Slices.utf8Slice(prefix), VARCHAR), randomNumberVarchar)); + + RowExpression castToVarchar = keyExpression; + // Only do cast if keyExpression is not VARCHAR type + if (!(keyExpression.getType() instanceof VarcharType)) { + castToVarchar = call("CAST", functionAndTypeManager.lookupCast(CAST, keyExpression.getType(), VARCHAR), VARCHAR, 
keyExpression); + } + return new SpecialFormExpression(COALESCE, VARCHAR, ImmutableList.of(castToVarchar, concatExpression)); + } + + public static RowExpression getVariableHash(List inputVariables, FunctionAndTypeManager functionAndTypeManager) + { + checkArgument(!inputVariables.isEmpty()); + List hashExpressionList = inputVariables.stream().map(keyVariable -> + callOperator(functionAndTypeManager.getFunctionAndTypeResolver(), OperatorType.XX_HASH_64, BIGINT, keyVariable)).collect(toImmutableList()); + RowExpression hashExpression = hashExpressionList.get(0); + if (hashExpressionList.size() > 1) { + hashExpression = orNullHashCode(hashExpression); + for (int i = 1; i < hashExpressionList.size(); ++i) { + hashExpression = call(functionAndTypeManager, "combine_hash", BIGINT, hashExpression, orNullHashCode(hashExpressionList.get(i))); + } + } + return hashExpression; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index 71149793bd3d8..c7deb6a31cdec 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -14,16 +14,21 @@ package com.facebook.presto.sql.planner; import com.facebook.presto.Session; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.TableHandle; +import 
com.facebook.presto.spi.TableMetadata; import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.connector.RowChangeParadigm; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionMetadata; import com.facebook.presto.spi.plan.AggregationNode; @@ -33,6 +38,7 @@ import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.LimitNode; +import com.facebook.presto.spi.plan.MarkDistinctNode; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; @@ -41,6 +47,7 @@ import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.SortNode; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.CallExpression; @@ -54,34 +61,54 @@ import com.facebook.presto.sql.analyzer.RelationType; import com.facebook.presto.sql.analyzer.Scope; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.GroupIdNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.OffsetNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Delete; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FieldReference; import com.facebook.presto.sql.tree.FrameBound; import com.facebook.presto.sql.tree.FunctionCall; +import 
com.facebook.presto.sql.tree.GenericLiteral; import com.facebook.presto.sql.tree.GroupingOperation; import com.facebook.presto.sql.tree.IfExpression; import com.facebook.presto.sql.tree.IntervalLiteral; +import com.facebook.presto.sql.tree.IsNotNullPredicate; +import com.facebook.presto.sql.tree.IsNullPredicate; +import com.facebook.presto.sql.tree.Join; import com.facebook.presto.sql.tree.LambdaArgumentDeclaration; import com.facebook.presto.sql.tree.LambdaExpression; +import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.sql.tree.Merge; +import com.facebook.presto.sql.tree.MergeCase; +import com.facebook.presto.sql.tree.MergeInsert; +import com.facebook.presto.sql.tree.MergeUpdate; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NodeLocation; import com.facebook.presto.sql.tree.NodeRef; +import com.facebook.presto.sql.tree.NotExpression; +import com.facebook.presto.sql.tree.NullLiteral; import com.facebook.presto.sql.tree.Offset; import com.facebook.presto.sql.tree.OrderBy; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.sql.tree.Query; import com.facebook.presto.sql.tree.QuerySpecification; +import com.facebook.presto.sql.tree.Row; +import com.facebook.presto.sql.tree.SearchedCaseExpression; import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.sql.tree.SubscriptExpression; import com.facebook.presto.sql.tree.SymbolReference; +import com.facebook.presto.sql.tree.Table; import com.facebook.presto.sql.tree.Update; +import com.facebook.presto.sql.tree.WhenClause; import com.facebook.presto.sql.tree.Window; import com.facebook.presto.sql.tree.WindowFrame; import com.google.common.base.VerifyException; @@ -104,8 +131,12 @@ import static com.facebook.presto.SystemSessionProperties.isSkipRedundantSort; import static 
com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TinyintType.TINYINT; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.ConnectorMergeSink.INSERT_OPERATION_NUMBER; +import static com.facebook.presto.spi.ConnectorMergeSink.UPDATE_OPERATION_NUMBER; import static com.facebook.presto.spi.StandardErrorCode.INVALID_LIMIT_CLAUSE; import static com.facebook.presto.spi.plan.AggregationNode.groupingSets; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; @@ -113,8 +144,9 @@ import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.isNumericType; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation; -import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.PlanBuilder.newPlanBuilder; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; import static com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme; import static com.facebook.presto.sql.planner.PlannerUtils.toSortOrder; @@ -143,7 +175,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -class QueryPlanner +public class QueryPlanner { private final Analysis analysis; private final VariableAllocator variableAllocator; @@ -253,8 +285,6 @@ public DeleteNode plan(Delete node) { RelationType descriptor = analysis.getOutputDescriptor(node.getTable()); TableHandle handle = 
analysis.getTableHandle(node.getTable()); - ColumnHandle rowIdHandle = metadata.getDeleteRowIdColumnHandle(session, handle); - Type rowIdType = metadata.getColumnMetadata(session, handle, rowIdHandle).getType(); // add table columns ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); @@ -268,11 +298,16 @@ public DeleteNode plan(Delete node) } // add rowId column - Field rowIdField = Field.newUnqualified(node.getLocation(), Optional.empty(), rowIdType); - VariableReferenceExpression rowIdVariable = variableAllocator.newVariable(getSourceLocation(node), "$rowId", rowIdField.getType()); - outputVariablesBuilder.add(rowIdVariable); - columns.put(rowIdVariable, rowIdHandle); - fields.add(rowIdField); + Optional rowIdHandle = metadata.getDeleteRowIdColumn(session, handle); + Optional rowIdField = Optional.empty(); + if (rowIdHandle.isPresent()) { + Type rowIdType = metadata.getColumnMetadata(session, handle, rowIdHandle.get()).getType(); + rowIdField = Optional.of(Field.newUnqualified(node.getLocation(), Optional.empty(), rowIdType)); + VariableReferenceExpression rowIdVariable = variableAllocator.newVariable(getSourceLocation(node), "$rowId", rowIdType); + outputVariablesBuilder.add(rowIdVariable); + columns.put(rowIdVariable, rowIdHandle.get()); + fields.add(rowIdField.get()); + } // create table scan List outputVariables = outputVariablesBuilder.build(); @@ -290,12 +325,14 @@ public DeleteNode plan(Delete node) } // create delete node - VariableReferenceExpression rowId = new VariableReferenceExpression(Optional.empty(), builder.translate(new FieldReference(relationPlan.getDescriptor().indexOf(rowIdField))).getName(), rowIdField.getType()); + PlanBuilder finalBuilder = builder; + Optional rowId = rowIdField.map(f -> + new VariableReferenceExpression(Optional.empty(), finalBuilder.translate(new FieldReference(relationPlan.getDescriptor().indexOf(f))).getName(), f.getType())); List deleteNodeOutputVariables = ImmutableList.of( 
variableAllocator.newVariable("partialrows", BIGINT), variableAllocator.newVariable("fragment", VARBINARY)); - return new DeleteNode(getSourceLocation(node), idAllocator.getNextId(), builder.getRoot(), rowId, deleteNodeOutputVariables, Optional.empty()); + return new DeleteNode(getSourceLocation(node), idAllocator.getNextId(), finalBuilder.getRoot(), rowId, deleteNodeOutputVariables, Optional.empty()); } public UpdateNode plan(Update node) @@ -312,11 +349,10 @@ public UpdateNode plan(Update node) .map(Map.Entry::getValue) .collect(toImmutableList()); handle = metadata.beginUpdate(session, handle, updatedColumns); - ColumnHandle rowIdHandle = metadata.getUpdateRowIdColumnHandle(session, handle, updatedColumns); - Type rowIdType = metadata.getColumnMetadata(session, handle, rowIdHandle).getType(); + String catalogName = handle.getConnectorId().getCatalogName(); List targetColumnNames = node.getAssignments().stream() - .map(assignment -> assignment.getName().getValue()) + .map(assignment -> metadata.normalizeIdentifier(session, catalogName, assignment.getName().getValue())) .collect(toImmutableList()); // Create lists of columnnames and SET expressions, in table column order @@ -338,11 +374,16 @@ public UpdateNode plan(Update node) List orderedColumnValues = orderedColumnValuesBuilder.build(); // add rowId column - Field rowIdField = Field.newUnqualified(node.getLocation(), Optional.empty(), rowIdType); - VariableReferenceExpression rowIdVariable = variableAllocator.newVariable(getSourceLocation(node), "$rowId", rowIdField.getType()); - outputVariablesBuilder.add(rowIdVariable); - columns.put(rowIdVariable, rowIdHandle); - fields.add(rowIdField); + Optional rowIdHandle = metadata.getUpdateRowIdColumn(session, handle, updatedColumns); + Optional rowIdField = Optional.empty(); + if (rowIdHandle.isPresent()) { + Type rowIdType = metadata.getColumnMetadata(session, handle, rowIdHandle.get()).getType(); + rowIdField = Optional.of(Field.newUnqualified(node.getLocation(), 
Optional.empty(), rowIdType)); + VariableReferenceExpression rowIdVariable = variableAllocator.newVariable(getSourceLocation(node), "$rowId", rowIdType); + outputVariablesBuilder.add(rowIdVariable); + columns.put(rowIdVariable, rowIdHandle.get()); + fields.add(rowIdField.get()); + } // create table scan List outputVariables = outputVariablesBuilder.build(); @@ -365,8 +406,10 @@ public UpdateNode plan(Update node) ImmutableList.Builder updatedColumnValuesBuilder = ImmutableList.builder(); orderedColumnValues.forEach(columnValue -> updatedColumnValuesBuilder.add(planAndMappings.get(columnValue))); - VariableReferenceExpression rowId = new VariableReferenceExpression(Optional.empty(), builder.translate(new FieldReference(relationPlan.getDescriptor().indexOf(rowIdField))).getName(), rowIdField.getType()); - updatedColumnValuesBuilder.add(rowId); + PlanBuilder finalBuilder = builder; + Optional rowId = rowIdField.map(f -> + new VariableReferenceExpression(Optional.empty(), finalBuilder.translate(new FieldReference(relationPlan.getDescriptor().indexOf(f))).getName(), f.getType())); + rowId.ifPresent(r -> updatedColumnValuesBuilder.add(r)); List outputs = ImmutableList.of( variableAllocator.newVariable("partialrows", BIGINT), @@ -379,12 +422,347 @@ public UpdateNode plan(Update node) return new UpdateNode( getSourceLocation(node), idAllocator.getNextId(), - builder.getRoot(), + finalBuilder.getRoot(), rowId, updatedColumnValuesBuilder.build(), outputs); } + /** + * Plan a MERGE statement. The MERGE statement is processed by creating a RIGHT JOIN between the target table and the source. + * Example of converting a MERGE statement into a SELECT statement with a RIGHT JOIN: + * Merge statement: + * MERGE INTO t USING s + * ON (t. = s.) + * WHEN NOT MATCHED THEN + * INSERT (column1, column2, column3) + * VALUES (s.column1, s.column2, s.column3); + * WHEN MATCHED THEN + * UPDATE SET = s. + t., + * = s. 
+ * + * SELECT statement with a RIGHT JOIN created to process the previous MERGE statement: + * SELECT + * CASE + * WHEN NOT MATCHED THEN + * -- Insert column values: operation INSERT=1, case_number=0 + * row(s.column1, s.column2, s.column3, 1, 0) + * WHEN MATCHED THEN + * -- Update column values: operation UPDATE=3, case_number=1 + * row(t.column1, s.column1 + t.column1, s.column2, 3, 1) + * ELSE + * -- Null values for no case matched: operation=-1, case_number=-1 + * row(null, null, null, -1, -1) + * END + * FROM + * t RIGHT JOIN s + * ON t. = s.; + * + * @param mergeStmt the MERGE statement to plan into a MergeWriterNode. + * @return a MergeWriterNode that represents the plan for the MERGE statement. + */ + public MergeWriterNode plan(Merge mergeStmt) + { + // The goal of this method is to build the following MERGE INTO execution plan: + // + // MergeWriterNode : Write the merge results into the target table. + // | + // MergeProcessorNode : Processes the result of the RIGHT JOIN to identify which rows need to be inserted and which need to be updated. + // | + // FilterDuplicateMatchingRows : Look for marked rows in the previous step. If it finds one, then it stops MERGE execution and returns an error. + // | + // MarkDistinctNode : Look for target rows that matched more than one source row and mark them. + // | + // RightEquiJoin : Run a RIGHT JOIN as a first step to process the MERGE INTO command. + // / \ + // / \ + // AssignUniqueID \ : Assign a unique ID to each row in the target table. + // | \ + // TableScan TableScan : Read data from the target and source tables. 
+ // | | + // (target table) (source table) + + Analysis.MergeAnalysis mergeAnalysis = analysis.getMergeAnalysis().orElseThrow(() -> new IllegalArgumentException("analysis.getMergeAnalysis() isn't present")); + + // Make the plan for the merge target table scan + RelationPlan targetTableRelationPlan = new RelationPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, sqlParser) + .process(mergeStmt.getTarget(), sqlPlannerContext); + + // Assign a unique id to every target table row + VariableReferenceExpression targetUniqueIdColumnVariable = variableAllocator.newVariable("unique_id", BIGINT); + AssignUniqueId assignUniqueRowIdToTargetTable = new AssignUniqueId(getSourceLocation(mergeStmt), idAllocator.getNextId(), targetTableRelationPlan.getRoot(), targetUniqueIdColumnVariable); + RelationPlan relationPlanWithUniqueRowIds = new RelationPlan( + assignUniqueRowIdToTargetTable, + mergeAnalysis.getTargetTableScope(), + targetTableRelationPlan.getFieldMappings()); + + RelationPlan sourceRelationPlan = new RelationPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, sqlParser) + .process(mergeStmt.getSource(), sqlPlannerContext); + + RelationPlan joinRelationPlan = new RelationPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, sqlParser) + .planJoin( + coerceIfNecessary(analysis, mergeStmt.getPredicate(), mergeStmt.getPredicate()), + Join.Type.RIGHT, mergeAnalysis.getJoinScope(), + relationPlanWithUniqueRowIds, sourceRelationPlan, + mergeStmt, sqlPlannerContext); + + // Build the SearchedCaseExpression that creates the project "merge_row" + PlanBuilder joinSubPlan = newPlanBuilder(joinRelationPlan, analysis, lambdaDeclarationToVariableMap); + + // CASE + // WHEN (unique_id IS NULL) THEN row(column1, column2, ..., present=false, operation INSERT=1, case_number=0) + // WHEN (unique_id IS NOT NULL) THEN row(column1, column2, ..., 
present=true, operation UPDATE=3, case_number=1) + // ELSE row(null, null, ..., false, -1, -1) + // END + ImmutableList.Builder whenClauses = ImmutableList.builder(); + for (int caseNumber = 0; caseNumber < mergeStmt.getMergeCases().size(); caseNumber++) { + MergeCase mergeCase = mergeStmt.getMergeCases().get(caseNumber); + + ImmutableList.Builder joinResultBuilder = ImmutableList.builder(); + List> mergeCaseColumnsHandles = mergeAnalysis.getMergeCaseColumnHandles(); + List mergeCaseSetColumns = mergeCaseColumnsHandles.get(caseNumber); + for (ColumnHandle targetColumnHandle : mergeAnalysis.getTargetColumnHandles()) { + int index = mergeCaseSetColumns.indexOf(targetColumnHandle); + int fieldNumber = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(targetColumnHandle), "Field number for ColumnHandle is null"); + Expression expression; + if (index >= 0) { // Update column value + Expression setExpression = mergeCase.getSetExpressions().get(index); + joinSubPlan = subqueryPlanner.handleSubqueries(joinSubPlan, setExpression, mergeStmt, sqlPlannerContext); + expression = joinSubPlan.rewrite(setExpression); + expression = coerceIfNecessary(analysis, setExpression, expression); + expression = checkNotNullColumns(targetColumnHandle, expression, fieldNumber, mergeAnalysis); + } + else { // Insert column value + expression = createSymbolReference(relationPlanWithUniqueRowIds.getFieldMappings().get(fieldNumber)); + + if (mergeCase instanceof MergeInsert) { + expression = checkNotNullColumns(targetColumnHandle, expression, fieldNumber, mergeAnalysis); + } + } + joinResultBuilder.add(expression); + } + + // Add the operation number + joinResultBuilder.add(new GenericLiteral("TINYINT", String.valueOf(getMergeCaseOperationNumber(mergeCase)))); + + // Add the mergeStmt case number, needed by MarkDistinct + joinResultBuilder.add(new GenericLiteral("INTEGER", String.valueOf(caseNumber))); + + // Build the match condition for the MERGE case + SymbolReference 
targetUniqueIdColumnSymbolReference = createSymbolReference(targetUniqueIdColumnVariable); + Expression mergeCondition = mergeCase instanceof MergeInsert ? + new IsNullPredicate(targetUniqueIdColumnSymbolReference) : new IsNotNullPredicate(targetUniqueIdColumnSymbolReference); + + whenClauses.add(new WhenClause(mergeCondition, new Row(joinResultBuilder.build()))); + } + + // Build the "else" clause for the SearchedCaseExpression + ImmutableList.Builder joinElseBuilder = ImmutableList.builder(); + mergeAnalysis.getTargetColumnsMetadata().forEach(columnMetadata -> + joinElseBuilder.add(new Cast(new NullLiteral(), columnMetadata.getType().getDisplayName()))); + + // The operation number column value: -1 + joinElseBuilder.add(new GenericLiteral("TINYINT", "-1")); + // The case number column value: -1 + joinElseBuilder.add(new GenericLiteral("INTEGER", "-1")); + + SearchedCaseExpression caseExpression = new SearchedCaseExpression(whenClauses.build(), Optional.of(new Row(joinElseBuilder.build()))); + + RowType mergeRowType = createMergeRowType(mergeAnalysis.getTargetColumnsMetadata()); + Table targetTable = mergeAnalysis.getTargetTable(); + FieldReference targetTableRowIdReference = analysis.getRowIdField(targetTable); + + VariableReferenceExpression targetTableRowIdColumnVariable = relationPlanWithUniqueRowIds.getFieldMappings().get(targetTableRowIdReference.getFieldIndex()); + VariableReferenceExpression mergeRowColumnVariable = variableAllocator.newVariable("merge_row", mergeRowType); + + // Project the partitioning variables, the merge_row, the rowId, and the unique_id variable. 
+ Assignments.Builder projectionAssignmentsBuilder = Assignments.builder(); + projectionAssignmentsBuilder.put(targetUniqueIdColumnVariable, targetUniqueIdColumnVariable); + projectionAssignmentsBuilder.put(targetTableRowIdColumnVariable, targetTableRowIdColumnVariable); + projectionAssignmentsBuilder.put(mergeRowColumnVariable, rowExpression(caseExpression, sqlPlannerContext)); + + ProjectNode joinSubPlanProject = new ProjectNode( + idAllocator.getNextId(), + joinSubPlan.getRoot(), + projectionAssignmentsBuilder.build()); + + // Now add a column for the case_number, gotten from the merge_row + SubscriptExpression caseNumberExpression = new SubscriptExpression( + createSymbolReference(mergeRowColumnVariable), new LongLiteral(Long.toString(mergeRowType.getFields().size()))); + + VariableReferenceExpression caseNumberVariable = variableAllocator.newVariable("case_number", INTEGER); + + ProjectNode joinProjectNode = new ProjectNode( + joinSubPlanProject.getSourceLocation(), + idAllocator.getNextId(), + joinSubPlanProject, + Assignments.builder() + .putAll(joinSubPlanProject.getOutputVariables().stream().collect(toImmutableMap(Function.identity(), Function.identity()))) + .put(caseNumberVariable, rowExpression(caseNumberExpression, sqlPlannerContext)) + .build(), + LOCAL); + + // Mark distinct combinations of the unique_id value and the case_number + VariableReferenceExpression isDistinctVariable = variableAllocator.newVariable("is_distinct", BOOLEAN); + MarkDistinctNode markDistinctNode = new MarkDistinctNode( + getSourceLocation(mergeStmt), idAllocator.getNextId(), joinProjectNode, isDistinctVariable, + ImmutableList.of(targetUniqueIdColumnVariable, caseNumberVariable), Optional.empty()); + + // Raise an error if unique_id variable is non-null and the unique_id/case_number combination was not distinct + Expression multipleMatchesExpression = new IfExpression( + LogicalBinaryExpression.and( + new NotExpression(createSymbolReference(isDistinctVariable)), + new 
IsNotNullPredicate(createSymbolReference(targetUniqueIdColumnVariable))), + new Cast( + new FunctionCall( + QualifiedName.of("presto", "default", "fail"), + ImmutableList.of(new Cast(new StringLiteral( + "MERGE INTO operation failed for target table '" + targetTable.getName() + "'. " + + "One or more rows in the target table matched multiple source rows. " + + "The MERGE INTO command requires each target row to match at most one source row. " + + "Please review the ON condition to ensure it produces a one-to-one or one-to-none match."), + VARCHAR.getTypeSignature().toString()))), + BOOLEAN.getTypeSignature().toString()), + TRUE_LITERAL); + + FilterNode filterMultipleMatches = new FilterNode(getSourceLocation(mergeStmt), idAllocator.getNextId(), + markDistinctNode, rowExpression(multipleMatchesExpression, sqlPlannerContext)); + + TableHandle targetTableHandle = analysis.getTableHandle(targetTable); + RowChangeParadigm rowChangeParadigm = metadata.getRowChangeParadigm(session, targetTableHandle); + Type targetTableRowIdColumnType = analysis.getType(analysis.getRowIdField(targetTable)); + TableMetadata targetTableMetadata = metadata.getTableMetadata(session, targetTableHandle); + + List targetColumnsDataTypes = targetTableMetadata.getMetadata().getColumns().stream() + .filter(column -> !column.isHidden()) + .map(ColumnMetadata::getType) + .collect(toImmutableList()); + + TableWriterNode.MergeParadigmAndTypes mergeParadigmAndTypes = + new TableWriterNode.MergeParadigmAndTypes(rowChangeParadigm, targetColumnsDataTypes, targetTableRowIdColumnType); + + Optional mergeHandle = Optional.of(metadata.beginMerge(session, targetTableHandle)); + TableWriterNode.MergeTarget mergeTarget = + new TableWriterNode.MergeTarget(targetTableHandle, mergeHandle, targetTableMetadata.getTable(), mergeParadigmAndTypes); + + ImmutableList.Builder mergeColumnVariablesBuilder = ImmutableList.builder(); + for (ColumnHandle columnHandle : mergeAnalysis.getTargetColumnHandles()) { + int fieldIndex 
= requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(columnHandle), "Could not find field number for column handle"); + mergeColumnVariablesBuilder.add(relationPlanWithUniqueRowIds.getFieldMappings().get(fieldIndex)); + } + List mergeColumnVariables = mergeColumnVariablesBuilder.build(); + + // Variable to specify whether the MERGE INTO statement should insert a new row or update an existing one. + // Operations defined in ConnectorMergeSink: INSERT_OPERATION_NUMBER and UPDATE_OPERATION_NUMBER. + VariableReferenceExpression mergeOperationVariable = variableAllocator.newVariable("operation", TINYINT); + + // Variable to indicate whether the row is an insert resulting from an UPDATE or an INSERT. + // The RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW will use this information. + // Values: + // 1: if this row is an insert derived from an UPDATE + // 0: if this row is an insert derived from an INSERT + VariableReferenceExpression insertFromUpdateVariable = variableAllocator.newVariable("insert_from_update", TINYINT); + + List mergeProjectedVariables = ImmutableList.builder() + .addAll(mergeColumnVariables) + .add(mergeOperationVariable) + .add(targetTableRowIdColumnVariable) + .add(insertFromUpdateVariable) + .build(); + + MergeProcessorNode mergeProcessorNode = new MergeProcessorNode( + getSourceLocation(mergeStmt), + idAllocator.getNextId(), + filterMultipleMatches, + mergeTarget, + targetTableRowIdColumnVariable, + mergeRowColumnVariable, + mergeColumnVariables, + mergeProjectedVariables); + + List mergeWriterOutputs = ImmutableList.of( + variableAllocator.newVariable("partialrows", BIGINT), + variableAllocator.newVariable("fragment", VARBINARY)); + + return new MergeWriterNode( + getSourceLocation(mergeStmt), + idAllocator.getNextId(), + mergeProcessorNode, + mergeTarget, + mergeProjectedVariables, + mergeWriterOutputs); + } + + /** + * Method to create a CoalesceExpression that triggers an error if the query attempts to insert a NULL value + * into a 
non-nullable column. The same applies if the query tries to update a non-nullable column with a NULL value. + * + * @return The same expression if the column allows null values; otherwise, a CoalesceExpression that + * prevents inserting NULL values into non-nullable columns. + */ + private static Expression checkNotNullColumns( + ColumnHandle targetColumnHandle, + Expression expression, + int fieldNumber, + Analysis.MergeAnalysis mergeAnalysis) + { + // If the current column allows NULL values, then the method returns the original expression. + if (!mergeAnalysis.getNonNullableColumnHandles().contains(targetColumnHandle)) { + return expression; + } + + ColumnMetadata columnMetadata = mergeAnalysis.getTargetColumnsMetadata().get(fieldNumber); + + // Build a coalesce expression that returns an error when the original expression value is NULL. + return new CoalesceExpression(expression, + new Cast( + new FunctionCall( + QualifiedName.of("presto", "default", "fail"), + ImmutableList.of(new Cast(new StringLiteral( + "NULL value not allowed for NOT NULL column. 
Table: " + mergeAnalysis.getTargetTable().getName() + + " Column: " + columnMetadata.getName()), + VARCHAR.getTypeSignature().toString()))), + columnMetadata.getType().getTypeSignature().toString())); + } + + private static int getMergeCaseOperationNumber(MergeCase mergeCase) + { + if (mergeCase instanceof MergeInsert) { + return INSERT_OPERATION_NUMBER; + } + if (mergeCase instanceof MergeUpdate) { + return UPDATE_OPERATION_NUMBER; + } + throw new IllegalArgumentException("Unrecognized MergeCase: " + mergeCase); + } + + private static RowType createMergeRowType(List allColumnsMetadata) + { + // Create the RowType that holds all column values + List fields = new ArrayList<>(); + for (ColumnMetadata columnMetadata : allColumnsMetadata) { + fields.add(new RowType.Field(Optional.empty(), columnMetadata.getType())); + } + + fields.add(new RowType.Field(Optional.empty(), TINYINT)); // operation_number + fields.add(new RowType.Field(Optional.empty(), INTEGER)); // case_number + return RowType.from(fields); + } + + public static Expression coerceIfNecessary(Analysis analysis, Expression original, Expression rewritten) + { + Type coercion = analysis.getCoercion(original); + if (coercion == null) { + return rewritten; + } + + return new Cast( + rewritten, + coercion.getDisplayName(), + false, + analysis.isTypeOnlyCoercion(original)); + } + private Optional getIdForLeftTableScan(PlanNode node) { if (node instanceof TableScanNode) { @@ -513,7 +891,7 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression * * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed */ - private PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) + public PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, 
VariableAllocator variableAllocator, Metadata metadata) { Assignments.Builder assignments = Assignments.builder(); assignments.putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), Function.identity()))); @@ -1102,7 +1480,7 @@ private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMapp // First, append filter to validate offset values. They mustn't be negative or null. VariableReferenceExpression offsetSymbol = coercions.get(frameOffset.get()); Expression zeroOffset = zeroOfType(TypeProvider.viewOf(variableAllocator.getVariables()).get(offsetSymbol)); - FunctionHandle fail = metadata.getFunctionAndTypeManager().resolveFunction(Optional.empty(), Optional.empty(), QualifiedObjectName.valueOf("presto.default.fail"), fromTypes(VARCHAR)); + CatalogSchemaName defaultNamespace = metadata.getFunctionAndTypeManager().getDefaultNamespace(); Expression predicate = new IfExpression( new ComparisonExpression( GREATER_THAN_OR_EQUAL, @@ -1111,7 +1489,7 @@ private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMapp TRUE_LITERAL, new Cast( new FunctionCall( - QualifiedName.of("presto", "default", "fail"), + QualifiedName.of(defaultNamespace.getCatalogName(), defaultNamespace.getSchemaName(), "fail"), ImmutableList.of(new Cast(new StringLiteral("Window frame offset value must not be negative or null"), VARCHAR.getTypeSignature().toString()))), BOOLEAN.getTypeSignature().toString())); subPlan = subPlan.withNewRoot(new FilterNode( @@ -1335,15 +1713,18 @@ private RowExpression rowExpression(Expression expression, SqlPlannerContext con context.getTranslatorContext()); } - private static List toSymbolReferences(List variables) + public static List toSymbolReferences(List variables) { return variables.stream() - .map(variable -> new SymbolReference( - variable.getSourceLocation().map(location -> new NodeLocation(location.getLine(), location.getColumn())), - variable.getName())) + 
.map(QueryPlanner::toSymbolReference) .collect(toImmutableList()); } + public static SymbolReference toSymbolReference(VariableReferenceExpression variable) + { + return new SymbolReference(variable.getSourceLocation().map(location -> new NodeLocation(location.getLine(), location.getColumn())), variable.getName()); + } + public static class PlanAndMappings { private final PlanBuilder subPlan; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index a32672c70674b..e07fadcbf4666 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -15,12 +15,15 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.TableHandle; @@ -29,16 +32,20 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; 
+import com.facebook.presto.spi.plan.MaterializedViewScanNode; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -48,13 +55,17 @@ import com.facebook.presto.sql.analyzer.Field; import com.facebook.presto.sql.analyzer.RelationId; import com.facebook.presto.sql.analyzer.RelationType; +import com.facebook.presto.sql.analyzer.ResolvedField; import com.facebook.presto.sql.analyzer.Scope; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.optimizations.SampleNodeUtil; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.SampleNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CoalesceExpression; @@ -82,8 +93,10 @@ import com.facebook.presto.sql.tree.Row; import com.facebook.presto.sql.tree.SampledRelation; import com.facebook.presto.sql.tree.SetOperation; +import com.facebook.presto.sql.tree.SortItem; import com.facebook.presto.sql.tree.SymbolReference; import 
com.facebook.presto.sql.tree.Table; +import com.facebook.presto.sql.tree.TableFunctionInvocation; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.Union; import com.facebook.presto.sql.tree.Unnest; @@ -91,11 +104,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.ListMultimap; import com.google.common.collect.UnmodifiableIterator; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.ArrayList; import java.util.HashMap; @@ -112,6 +125,7 @@ import static com.facebook.presto.SystemSessionProperties.getQueryAnalyzerTimeout; import static com.facebook.presto.common.type.TypeUtils.isEnumType; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_PLAN_ERROR; import static com.facebook.presto.spi.StandardErrorCode.QUERY_PLANNING_TIMEOUT; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; @@ -176,6 +190,11 @@ public RelationPlan process(Node node, @Nullable SqlPlannerContext context) @Override protected RelationPlan visitTable(Table node, SqlPlannerContext context) { + Optional materializedViewInfo = analysis.getMaterializedViewInfo(node); + if (materializedViewInfo.isPresent()) { + return planMaterializedView(node, materializedViewInfo.get(), context); + } + NamedQuery namedQuery = analysis.getNamedQuery(node); Scope scope = analysis.getScope(node); @@ -183,7 +202,7 @@ protected RelationPlan visitTable(Table node, SqlPlannerContext context) if (namedQuery != null) { String cteName = node.getName().toString(); if (namedQuery.isFromView()) { - cteName = 
createQualifiedObjectName(session, node, node.getName()).toString(); + cteName = createQualifiedObjectName(session, node, node.getName(), metadata).toString(); } RelationPlan subPlan = process(namedQuery.getQuery(), context); if (getCteMaterializationStrategy(session).equals(NONE)) { @@ -298,6 +317,240 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan, SqlPlannerCo return new RelationPlan(planBuilder.getRoot(), plan.getScope(), newMappings.build()); } + private RelationPlan planMaterializedView(Table node, Analysis.MaterializedViewInfo materializedViewInfo, SqlPlannerContext context) + { + RelationPlan dataTablePlan = process(materializedViewInfo.getDataTable(), context); + RelationPlan viewQueryPlan = process(materializedViewInfo.getViewQuery(), context); + + Scope scope = analysis.getScope(node); + + QualifiedObjectName materializedViewName = materializedViewInfo.getMaterializedViewName(); + + RelationType dataTableDescriptor = dataTablePlan.getDescriptor(); + List dataTableVariables = dataTablePlan.getFieldMappings(); + List viewQueryVariables = viewQueryPlan.getFieldMappings(); + + checkArgument( + dataTableDescriptor.getVisibleFieldCount() == viewQueryVariables.size(), + "Materialized view %s has mismatched field counts: data table has %s visible fields but view query has %s fields", + materializedViewName, + dataTableDescriptor.getVisibleFieldCount(), + viewQueryVariables.size()); + + ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); + ImmutableMap.Builder dataTableMappingsBuilder = ImmutableMap.builder(); + ImmutableMap.Builder viewQueryMappingsBuilder = ImmutableMap.builder(); + + for (Field field : dataTableDescriptor.getVisibleFields()) { + int fieldIndex = dataTableDescriptor.indexOf(field); + VariableReferenceExpression dataTableVar = dataTableVariables.get(fieldIndex); + VariableReferenceExpression viewQueryVar = viewQueryVariables.get(fieldIndex); + + VariableReferenceExpression outputVar = 
variableAllocator.newVariable(dataTableVar); + outputVariablesBuilder.add(outputVar); + + dataTableMappingsBuilder.put(outputVar, dataTableVar); + viewQueryMappingsBuilder.put(outputVar, viewQueryVar); + } + + List outputVariables = outputVariablesBuilder.build(); + Map dataTableMappings = dataTableMappingsBuilder.build(); + Map viewQueryMappings = viewQueryMappingsBuilder.build(); + + MaterializedViewScanNode mvScanNode = new MaterializedViewScanNode( + getSourceLocation(node.getLocation()), + idAllocator.getNextId(), + dataTablePlan.getRoot(), + viewQueryPlan.getRoot(), + materializedViewName, + dataTableMappings, + viewQueryMappings, + outputVariables); + + return new RelationPlan(mvScanNode, scope, outputVariables); + } + + /** + * Processes a {@code TableFunctionInvocation} node to construct and return a {@link RelationPlan}. + * This involves preparing the necessary plan nodes, variable mappings, and associated properties + * to represent the execution plan for the invoked table function. + * + * @param node The {@code TableFunctionInvocation} syntax tree node to be processed. + * @param context The SQL planner context used for planning and analysis tasks. + * @return A {@link RelationPlan} encapsulating the execution plan for the table function invocation. 
+ */ + @Override + protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node, SqlPlannerContext context) + { + Analysis.TableFunctionInvocationAnalysis functionAnalysis = analysis.getTableFunctionAnalysis(node); + ImmutableList.Builder sources = ImmutableList.builder(); + ImmutableList.Builder sourceProperties = ImmutableList.builder(); + ImmutableList.Builder outputVariables = ImmutableList.builder(); + + // create new symbols for table function's proper columns + RelationType relationType = analysis.getScope(node).getRelationType(); + List properOutputs = IntStream.range(0, functionAnalysis.getProperColumnsCount()) + .mapToObj(relationType::getFieldByIndex) + .map(field -> variableAllocator.newVariable(getSourceLocation(node), field.getName().orElse("field"), field.getType())) + .collect(toImmutableList()); + + outputVariables.addAll(properOutputs); + + processTableArguments(context, functionAnalysis, outputVariables, sources, sourceProperties); + + PlanNode root = new TableFunctionNode( + idAllocator.getNextId(), + functionAnalysis.getFunctionName(), + functionAnalysis.getArguments(), + properOutputs, + sources.build(), + sourceProperties.build(), + functionAnalysis.getCopartitioningLists(), + new TableFunctionHandle( + functionAnalysis.getConnectorId(), + functionAnalysis.getConnectorTableFunctionHandle(), + functionAnalysis.getTransactionHandle())); + + return new RelationPlan(root, analysis.getScope(node), outputVariables.build()); + } + + private void processTableArguments(SqlPlannerContext context, + Analysis.TableFunctionInvocationAnalysis functionAnalysis, + ImmutableList.Builder outputVariables, + ImmutableList.Builder sources, + ImmutableList.Builder sourceProperties) + { + QueryPlanner partitionQueryPlanner = new QueryPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, context, sqlParser); + // process sources in order of argument declarations + for (Analysis.TableArgumentAnalysis 
tableArgument : functionAnalysis.getTableArgumentAnalyses()) { + RelationPlan sourcePlan = process(tableArgument.getRelation(), context); + PlanBuilder sourcePlanBuilder = initializePlanBuilder(sourcePlan); + + int[] fieldIndexForVisibleColumn = getFieldIndexesForVisibleColumns(sourcePlan); + + List requiredColumns = functionAnalysis.getRequiredColumns().get(tableArgument.getArgumentName()).stream() + .map(column -> fieldIndexForVisibleColumn[column]) + .map(sourcePlan::getVariable) + .collect(toImmutableList()); + + Optional specification = Optional.empty(); + + // if the table argument has set semantics, create Specification + if (!tableArgument.isRowSemantics()) { + // partition by + List partitionBy = ImmutableList.of(); + // if there are partitioning columns, they might have to be coerced for copartitioning + if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) { + List partitioningColumns = tableArgument.getPartitionBy().get(); + for (Expression partitionColumn : partitioningColumns) { + if (!sourcePlanBuilder.canTranslate(partitionColumn)) { + ResolvedField partition = sourcePlan.getScope().tryResolveField(partitionColumn).orElseThrow(() -> new PrestoException(INVALID_PLAN_ERROR, "Missing equivalent alias")); + sourcePlanBuilder.getTranslations().put(partitionColumn, sourcePlan.getVariable(partition.getRelationFieldIndex())); + } + } + QueryPlanner.PlanAndMappings copartitionCoercions = partitionQueryPlanner.coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, variableAllocator, metadata); + sourcePlanBuilder = copartitionCoercions.getSubPlan(); + partitionBy = partitioningColumns.stream() + .map(copartitionCoercions::get) + .collect(toImmutableList()); + } + + // order by + Optional orderBy = getOrderingScheme(tableArgument, sourcePlanBuilder, sourcePlan); + specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy)); + } + + // add output symbols passed from the table 
argument + ImmutableList.Builder passThroughColumns = ImmutableList.builder(); + addPassthroughColumns(outputVariables, tableArgument, sourcePlan, specification, passThroughColumns, sourcePlanBuilder); + sources.add(sourcePlanBuilder.getRoot()); + + sourceProperties.add(new TableArgumentProperties( + tableArgument.getArgumentName(), + tableArgument.isRowSemantics(), + tableArgument.isPruneWhenEmpty(), + new PassThroughSpecification(tableArgument.isPassThroughColumns(), passThroughColumns.build()), + requiredColumns, + specification)); + } + } + + private static int[] getFieldIndexesForVisibleColumns(RelationPlan sourcePlan) + { + // required columns are a subset of visible columns of the source. remap required column indexes to field indexes in source relation type. + RelationType sourceRelationType = sourcePlan.getScope().getRelationType(); + int[] fieldIndexForVisibleColumn = new int[sourceRelationType.getVisibleFieldCount()]; + int visibleColumn = 0; + for (int i = 0; i < sourceRelationType.getAllFieldCount(); i++) { + if (!sourceRelationType.getFieldByIndex(i).isHidden()) { + fieldIndexForVisibleColumn[visibleColumn] = i; + visibleColumn++; + } + } + return fieldIndexForVisibleColumn; + } + + private static Optional getOrderingScheme(Analysis.TableArgumentAnalysis tableArgument, PlanBuilder sourcePlanBuilder, RelationPlan sourcePlan) + { + Optional orderBy = Optional.empty(); + if (tableArgument.getOrderBy().isPresent()) { + List sortItems = tableArgument.getOrderBy().get().getSortItems(); + + // Ensure all ORDER BY columns can be translated (populate missing translations if needed) + for (SortItem sortItem : sortItems) { + Expression sortKey = sortItem.getSortKey(); + if (!sourcePlanBuilder.canTranslate(sortKey)) { + Optional resolvedField = sourcePlan.getScope().tryResolveField(sortKey); + resolvedField.ifPresent(field -> sourcePlanBuilder.getTranslations().put( + sortKey, + sourcePlan.getVariable(field.getRelationFieldIndex()))); + } + } + + // The ordering 
symbols are coerced + List coerced = sortItems.stream() + .map(SortItem::getSortKey) + .map(sourcePlanBuilder::translate) + .collect(toImmutableList()); + + List sortOrders = sortItems.stream() + .map(PlannerUtils::toSortOrder) + .collect(toImmutableList()); + + orderBy = Optional.of(PlannerUtils.toOrderingScheme(coerced, sortOrders)); + } + return orderBy; + } + + private static void addPassthroughColumns(ImmutableList.Builder outputVariables, + Analysis.TableArgumentAnalysis tableArgument, RelationPlan sourcePlan, + Optional specification, + ImmutableList.Builder passThroughColumns, + PlanBuilder sourcePlanBuilder) + { + if (tableArgument.isPassThroughColumns()) { + // the original output symbols from the source node, not coerced + // note: hidden columns are included. They are present in sourcePlan.fieldMappings + outputVariables.addAll(sourcePlan.getFieldMappings()); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + sourcePlan.getFieldMappings().stream() + .map(variable -> new PassThroughColumn(variable, partitionBy.contains(variable))) + .forEach(passThroughColumns::add); + } + else if (tableArgument.getPartitionBy().isPresent()) { + tableArgument.getPartitionBy().get().stream() + .map(sourcePlanBuilder::translate) + // the original symbols for partitioning columns, not coerced + .forEach(variable -> { + outputVariables.add(variable); + passThroughColumns.add(new PassThroughColumn(variable, true)); + }); + } + } + @Override protected RelationPlan visitAliasedRelation(AliasedRelation node, SqlPlannerContext context) { @@ -370,6 +623,11 @@ protected RelationPlan visitJoin(Join node, SqlPlannerContext context) return planJoinUsing(node, leftPlan, rightPlan, context); } + return planJoin(analysis.getJoinCriteria(node), node.getType(), analysis.getScope(node), leftPlan, rightPlan, node, context); + } + + public RelationPlan planJoin(Expression criteria, Join.Type 
type, Scope scope, RelationPlan leftPlan, RelationPlan rightPlan, Node node, SqlPlannerContext context) + { PlanBuilder leftPlanBuilder = initializePlanBuilder(leftPlan); PlanBuilder rightPlanBuilder = initializePlanBuilder(rightPlan); @@ -383,12 +641,10 @@ protected RelationPlan visitJoin(Join node, SqlPlannerContext context) List complexJoinExpressions = new ArrayList<>(); List postInnerJoinConditions = new ArrayList<>(); - if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) { - Expression criteria = analysis.getJoinCriteria(node); - - RelationType left = analysis.getOutputDescriptor(node.getLeft()); - RelationType right = analysis.getOutputDescriptor(node.getRight()); + RelationType left = leftPlan.getDescriptor(); + RelationType right = rightPlan.getDescriptor(); + if (type != Join.Type.CROSS && type != Join.Type.IMPLICIT) { List leftComparisonExpressions = new ArrayList<>(); List rightComparisonExpressions = new ArrayList<>(); List joinConditionComparisonOperators = new ArrayList<>(); @@ -396,7 +652,7 @@ protected RelationPlan visitJoin(Join node, SqlPlannerContext context) for (Expression conjunct : ExpressionUtils.extractConjuncts(criteria)) { conjunct = ExpressionUtils.normalize(conjunct); - if (!isEqualComparisonExpression(conjunct) && node.getType() != INNER) { + if (!isEqualComparisonExpression(conjunct) && type != INNER) { complexJoinExpressions.add(conjunct); continue; } @@ -460,7 +716,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende PlanNode root = new JoinNode( getSourceLocation(node), idAllocator.getNextId(), - JoinNodeUtils.typeConvert(node.getType()), + JoinNodeUtils.typeConvert(type), leftPlanBuilder.getRoot(), rightPlanBuilder.getRoot(), equiClauses.build(), @@ -474,7 +730,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende Optional.empty(), ImmutableMap.of()); - if (node.getType() != INNER) { + if (type != INNER) { for (Expression complexExpression 
: complexJoinExpressions) { Set inPredicates = subqueryPlanner.collectInPredicateSubqueries(complexExpression, node); if (!inPredicates.isEmpty()) { @@ -483,10 +739,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende } } - if (node.getType() == LEFT || node.getType() == RIGHT) { - RelationType left = analysis.getOutputDescriptor(node.getLeft()); - RelationType right = analysis.getOutputDescriptor(node.getRight()); - + if (type == LEFT || type == RIGHT) { for (Expression complexJoinExpression : complexJoinExpressions) { Set dependencies = VariablesExtractor.extractNames(complexJoinExpression, analysis.getColumnReferences()); // If there are no dependencies, no subqueries, or if the expression references both inputs, @@ -502,7 +755,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende // If the subquery references the right input, those variables will remain unresolved and caught in NoIdentifierLeftChecker leftPlanBuilder = subqueryPlanner.handleUncorrelatedSubqueries(leftPlanBuilder, ImmutableList.of(complexJoinExpression), node, context); } - else if (node.getType() == LEFT && !dependencies.stream().allMatch(left::canResolve)) { + else if (type == LEFT && !dependencies.stream().allMatch(left::canResolve)) { rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, complexJoinExpression, node, context); } else { @@ -517,19 +770,19 @@ else if (node.getType() == LEFT && !dependencies.stream().allMatch(left::canReso } } - RelationPlan intermediateRootRelationPlan = new RelationPlan(root, analysis.getScope(node), outputs); + RelationPlan intermediateRootRelationPlan = new RelationPlan(root, scope, outputs); TranslationMap translationMap = new TranslationMap(intermediateRootRelationPlan, analysis, lambdaDeclarationToVariableMap); translationMap.setFieldMappings(outputs); translationMap.putExpressionMappingsFrom(leftPlanBuilder.getTranslations()); 
translationMap.putExpressionMappingsFrom(rightPlanBuilder.getTranslations()); - if (node.getType() != INNER && !complexJoinExpressions.isEmpty()) { + if (type != INNER && !complexJoinExpressions.isEmpty()) { Expression joinedFilterCondition = ExpressionUtils.and(complexJoinExpressions); Expression rewrittenFilterCondition = translationMap.rewrite(joinedFilterCondition); root = new JoinNode( getSourceLocation(node), idAllocator.getNextId(), - JoinNodeUtils.typeConvert(node.getType()), + JoinNodeUtils.typeConvert(type), leftPlanBuilder.getRoot(), rightPlanBuilder.getRoot(), equiClauses.build(), @@ -544,7 +797,7 @@ else if (node.getType() == LEFT && !dependencies.stream().allMatch(left::canReso ImmutableMap.of()); } - if (node.getType() == INNER) { + if (type == INNER) { // rewrite all the other conditions using output variables from left + right plan node. PlanBuilder rootPlanBuilder = new PlanBuilder(translationMap, root); rootPlanBuilder = subqueryPlanner.handleSubqueries(rootPlanBuilder, complexJoinExpressions, node, context); @@ -561,7 +814,7 @@ else if (node.getType() == LEFT && !dependencies.stream().allMatch(left::canReso } } - return new RelationPlan(root, analysis.getScope(node), outputs); + return new RelationPlan(root, scope, outputs); } private RelationPlan planJoinUsing(Join node, RelationPlan left, RelationPlan right, SqlPlannerContext context) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RowExpressionInterpreter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RowExpressionInterpreter.java index a4498a5434f27..3b24606a808ef 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RowExpressionInterpreter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RowExpressionInterpreter.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner; import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.FullConnectorSession; import 
com.facebook.presto.client.FailureInfo; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.BlockBuilder; @@ -60,6 +61,7 @@ import java.util.Set; import java.util.stream.Stream; +import static com.facebook.presto.SystemSessionProperties.getMaxSerializableObjectSize; import static com.facebook.presto.common.function.OperatorType.EQUAL; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.IntegerType.INTEGER; @@ -111,7 +113,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.instanceOf; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.slice.Slices.utf8Slice; @@ -122,7 +123,6 @@ public class RowExpressionInterpreter { - private static final long MAX_SERIALIZABLE_OBJECT_SIZE = 1000; private final RowExpression expression; private final ConnectorSession session; private final Level optimizationLevel; @@ -132,15 +132,6 @@ public class RowExpressionInterpreter private final FunctionResolution resolution; private final Visitor visitor; - - public static Object evaluateConstantRowExpression(RowExpression expression, FunctionAndTypeManager functionAndTypeManager, ConnectorSession session) - { - // evaluate the expression - Object result = new RowExpressionInterpreter(expression, functionAndTypeManager, session, EVALUATED).evaluate(); - verify(!(result instanceof RowExpression), "RowExpression interpreter returned an unresolved expression"); - return result; - } - public static RowExpressionInterpreter rowExpressionInterpreter(RowExpression expression, FunctionAndTypeManager functionAndTypeManager, ConnectorSession session) { return new RowExpressionInterpreter(expression, 
functionAndTypeManager, session, EVALUATED); @@ -273,7 +264,7 @@ public Object visitCall(CallExpression node, Object context) (!functionMetadata.isDeterministic() || hasUnresolvedValue(argumentValues) || isDynamicFilter(node) || - resolution.isFailFunction(functionHandle))) { + resolution.isJavaBuiltInFailFunction(functionHandle))) { return call(node.getDisplayName(), functionHandle, node.getType(), toRowExpressions(argumentValues, node.getArguments())); } @@ -779,7 +770,7 @@ private RowExpression toRowExpression(Object value, RowExpression originalRowExp private boolean isSerializable(Object value, Type type) { // If value is already RowExpression, constant values contained inside should already have been made serializable. Otherwise, we make sure the object is small and serializable. - return value instanceof RowExpression || (isSupportedLiteralType(type) && estimatedSizeInBytes(value) <= MAX_SERIALIZABLE_OBJECT_SIZE); + return value instanceof RowExpression || (isSupportedLiteralType(type) && estimatedSizeInBytes(value) <= getMaxSerializableObjectSize(((FullConnectorSession) session).getSession())); } private SpecialCallResult tryHandleArrayConstructor(CallExpression callExpression, List argumentValues) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java index d34e01cda201a..471c797c426a8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java @@ -14,17 +14,19 @@ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.SemiJoinNode; import 
com.facebook.presto.spi.plan.SpatialJoinNode; import com.facebook.presto.spi.plan.TableScanNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.NoSuchElementException; import java.util.function.Consumer; public class SchedulingOrderVisitor @@ -88,5 +90,17 @@ public Void visitTableScan(TableScanNode node, Consumer schedulingOr schedulingOrder.accept(node.getId()); return null; } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Consumer schedulingOrder) + { + if (!node.getSource().isPresent()) { + schedulingOrder.accept(node.getId()); + } + else { + node.getSource().orElseThrow(NoSuchElementException::new).accept(this, schedulingOrder); + } + return null; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java index b51bb17a8b56b..387271aa94e16 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SplitSourceFactory.java @@ -23,10 +23,12 @@ import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; import com.facebook.presto.spi.plan.MergeJoinNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; @@ -40,32 
+42,35 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.split.SampledSplitSource; import com.facebook.presto.split.SplitSource; import com.facebook.presto.split.SplitSourceProvider; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; -import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.function.Supplier; import static 
com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.GROUPED_SCHEDULING; @@ -280,6 +285,21 @@ public Map visitRowNumber(RowNumberNode node, Context c return node.getSource().accept(this, context); } + @Override + public Map visitTableFunctionProcessor(TableFunctionProcessorNode node, Context context) + { + if (!node.getSource().isPresent()) { + // this is a source node, so produce splits + SplitSource splitSource = splitSourceProvider.getSplits( + session, + node.getHandle()); + splitSources.add(splitSource); + return ImmutableMap.of(node.getId(), splitSource); + } + + return node.getSource().orElseThrow(NoSuchElementException::new).accept(this, context); + } + @Override public Map visitTopNRowNumber(TopNRowNumberNode node, Context context) { @@ -346,6 +366,12 @@ public Map visitTableWriter(TableWriterNode node, Conte return node.getSource().accept(this, context); } + @Override + public Map visitCallDistributedProcedure(CallDistributedProcedureNode node, Context context) + { + return node.getSource().accept(this, context); + } + @Override public Map visitTableWriteMerge(TableWriterMergeNode node, Context context) { @@ -410,6 +436,18 @@ public Map visitPlan(PlanNode node, Context context) { throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName()); } + + @Override + public Map visitMergeWriter(MergeWriterNode node, Context context) + { + return node.getSource().accept(this, context); + } + + @Override + public Map visitMergeProcessor(MergeProcessorNode node, Context context) + { + return node.getSource().accept(this, context); + } } private static class Context diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SubPlan.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SubPlan.java index c7f30bbc2e549..775fa0cf225eb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SubPlan.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SubPlan.java @@ -17,8 +17,7 @@ import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.Multiset; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SystemPartitioningHandle.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SystemPartitioningHandle.java index 8a4174c9cefe7..47aed56286370 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SystemPartitioningHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SystemPartitioningHandle.java @@ -49,7 +49,7 @@ public final class SystemPartitioningHandle implements ConnectorPartitioningHandle { - private enum SystemPartitioning + public enum SystemPartitioning { SINGLE, FIXED, @@ -74,16 +74,31 @@ private static PartitioningHandle createSystemPartitioning(SystemPartitioning pa return new PartitioningHandle(Optional.empty(), Optional.empty(), new SystemPartitioningHandle(partitioning, function)); } + public static PartitioningHandle createSystemPartitioning(SystemPartitioning partitioning, SystemPartitionFunction function, int partitionCount) + { + return new PartitioningHandle(Optional.empty(), Optional.empty(), new SystemPartitioningHandle(partitioning, function, Optional.of(partitionCount))); + } + private final SystemPartitioning partitioning; private final SystemPartitionFunction function; + private final Optional partitionCount; @JsonCreator public SystemPartitioningHandle( @JsonProperty("partitioning") SystemPartitioning partitioning, - @JsonProperty("function") SystemPartitionFunction function) + @JsonProperty("function") SystemPartitionFunction function, + @JsonProperty("partitionCount") Optional partitionCount) { this.partitioning = 
requireNonNull(partitioning, "partitioning is null"); this.function = requireNonNull(function, "function is null"); + this.partitionCount = requireNonNull(partitionCount, "partitionCount is null"); + } + + public SystemPartitioningHandle( + SystemPartitioning partitioning, + SystemPartitionFunction function) + { + this(partitioning, function, Optional.empty()); } @JsonProperty @@ -98,6 +113,12 @@ public SystemPartitionFunction getFunction() return function; } + @JsonProperty + public Optional getPartitionCount() + { + return partitionCount; + } + @Override public boolean isSingleNode() { @@ -133,13 +154,14 @@ public boolean equals(Object o) } SystemPartitioningHandle that = (SystemPartitioningHandle) o; return partitioning == that.partitioning && - function == that.function; + function == that.function && + partitionCount.equals(that.partitionCount); } @Override public int hashCode() { - return Objects.hash(partitioning, function); + return Objects.hash(partitioning, function, partitionCount); } @Override @@ -162,7 +184,12 @@ else if (partitioning == SystemPartitioning.SINGLE) { nodes = nodeSelector.selectRandomNodes(1); } else if (partitioning == SystemPartitioning.FIXED) { - nodes = nodeSelector.selectRandomNodes(min(getHashPartitionCount(session), getMaxTasksPerStage(session))); + if (!partitionCount.isPresent()) { + nodes = nodeSelector.selectRandomNodes(min(getHashPartitionCount(session), getMaxTasksPerStage(session))); + } + else { + nodes = nodeSelector.selectRandomNodes(min(partitionCount.get(), min(getHashPartitionCount(session), getMaxTasksPerStage(session)))); + } } else { throw new IllegalArgumentException("Unsupported plan distribution " + partitioning); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java index 79768c665870a..7798d019181ee 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.cost.CachingCostProvider; @@ -39,7 +40,6 @@ import com.facebook.presto.sql.planner.optimizations.PlanOptimizerResult; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; import java.util.HashSet; import java.util.Iterator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java index 412e35666ddc8..42c860d36b2d8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java @@ -26,7 +26,7 @@ public interface Lookup /** * Resolves a node by materializing GroupReference nodes * representing symbolic references to other nodes. This method - * is deprecated since is assumes group contains only one node. + * is deprecated since it assumes group contains only one node. *

* If the node is not a GroupReference, it returns the * argument as is. diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java index c9a105cc31202..e782c2847c731 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/Memo.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.google.common.collect.HashMultiset; import com.google.common.collect.Multiset; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.HashMap; import java.util.HashSet; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/properties/EquivalenceClassProperty.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/properties/EquivalenceClassProperty.java index a0a369fcd16bd..f78e0c2f6d57e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/properties/EquivalenceClassProperty.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/properties/EquivalenceClassProperty.java @@ -196,7 +196,7 @@ public EquivalenceClassProperty addPredicate(RowExpression predicate, FunctionRe private static boolean isVariableEqualVariableOrConstant(FunctionResolution functionResolution, RowExpression expression) { if (expression instanceof CallExpression - && functionResolution.isEqualFunction(((CallExpression) expression).getFunctionHandle()) + && functionResolution.isEqualsFunction(((CallExpression) expression).getFunctionHandle()) && ((CallExpression) expression).getArguments().size() == 2) { RowExpression e1 = ((CallExpression) expression).getArguments().get(0); RowExpression e2 = ((CallExpression) expression).getArguments().get(1); diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddDistinctForSemiJoinBuild.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddDistinctForSemiJoinBuild.java new file mode 100644 index 0000000000000..2f30a29b5af69 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddDistinctForSemiJoinBuild.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SemiJoinNode; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.isAddDistinctBelowSemiJoinBuildEnabled; +import static com.facebook.presto.spi.plan.AggregationNode.isDistinct; +import static 
com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.planner.plan.Patterns.semiJoin; + +/** + * Add a distinct aggregation under the build side of semi join, for example: + * Rewrite query from + *

+ *     - SemiJoin
+ *          l.col in r.col
+ *          - scan l
+ *              col
+ *          - scan r
+ *              col
+ * 
+ * into + *
+ *     - SemiJoin
+ *          l.col in r.col
+ *          - scan l
+ *              col
+ *          - Aggregate
+ *              group by r.col
+ *              - scan r
+ *                  col
+ * 
+ */ +public class AddDistinctForSemiJoinBuild + implements Rule +{ + @Override + public Pattern getPattern() + { + return semiJoin(); + } + + @Override + public boolean isEnabled(Session session) + { + return isAddDistinctBelowSemiJoinBuildEnabled(session); + } + + @Override + public Result apply(SemiJoinNode node, Captures captures, Context context) + { + PlanNode filterSource = context.getLookup().resolve(node.getFilteringSource()); + VariableReferenceExpression filteringSourceVariable = node.getFilteringSourceJoinVariable(); + if (isOutputDistinct(filterSource, filteringSourceVariable, context)) { + return Result.empty(); + } + AggregationNode.GroupingSetDescriptor groupingSetDescriptor = singleGroupingSet(ImmutableList.of(node.getFilteringSourceJoinVariable())); + AggregationNode distinctAggregation = new AggregationNode( + node.getSourceLocation(), + context.getIdAllocator().getNextId(), + filterSource, + ImmutableMap.of(), + groupingSetDescriptor, + ImmutableList.of(), + AggregationNode.Step.SINGLE, + Optional.empty(), + Optional.empty(), + Optional.empty()); + + return Result.ofPlanNode(node.replaceChildren(ImmutableList.of(node.getSource(), distinctAggregation))); + } + + boolean isOutputDistinct(PlanNode node, VariableReferenceExpression output, Context context) + { + if (node instanceof AggregationNode) { + AggregationNode aggregationNode = (AggregationNode) node; + return isDistinct(aggregationNode) && aggregationNode.getGroupingKeys().size() == 1 && aggregationNode.getGroupingKeys().contains(output); + } + else if (node instanceof ProjectNode) { + ProjectNode projectNode = (ProjectNode) node; + RowExpression inputExpression = projectNode.getAssignments().get(output); + if (inputExpression instanceof VariableReferenceExpression) { + return isOutputDistinct(context.getLookup().resolve(projectNode.getSource()), (VariableReferenceExpression) inputExpression, context); + } + return false; + } + else if (node instanceof FilterNode) { + return 
isOutputDistinct(context.getLookup().resolve(((FilterNode) node).getSource()), output, context); + } + return false; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java index 8fc35402b5dda..45f075c88b1eb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.cost.TaskCountEstimator; @@ -38,7 +39,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multiset; -import io.airlift.units.DataSize; import java.util.Collection; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CombineApproxDistinctFunctions.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CombineApproxDistinctFunctions.java new file mode 100644 index 0000000000000..6b45357ab0398 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CombineApproxDistinctFunctions.java @@ -0,0 +1,319 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.facebook.presto.SystemSessionProperties.isCombineApproxDistinctEnabled; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; +import static com.facebook.presto.sql.relational.Expressions.call; +import static 
com.facebook.presto.sql.relational.Expressions.constant; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +/** + * For multiple approx_distinct() function calls on expressions of the same type, combine them using set_agg. + *

+ * From: + *

+ *   Aggregation (approx_distinct(e1), approx_distinct(e2), approx_distinct(e3))
+ * 
+ * To: + *
+ *   Project (coalesce(cardinality(array_distinct(remove_nulls(ads[1]))), 0),
+ *            coalesce(cardinality(array_distinct(remove_nulls(ads[2]))), 0),
+ *            coalesce(cardinality(array_distinct(remove_nulls(ads[3]))), 0))
+ *   - Project (ads <- transpose(ads_array))
+ *     - Aggregation (ads_array <- set_agg(array[e1, e2, e3]))
+ * 
+ *

+ */ +public class CombineApproxDistinctFunctions + implements Rule +{ + private static final String APPROX_DISTINCT = "approx_distinct"; + private static final String SET_AGG = "set_agg"; + private static final String ARRAY_CONSTRUCTOR = "array_constructor"; + private static final String ARRAY_TRANSPOSE = "array_transpose"; + private static final String ARRAY_DISTINCT = "array_distinct"; + private static final String REMOVE_NULLS = "remove_nulls"; + private static final String CARDINALITY = "cardinality"; + private static final String ELEMENT_AT = "element_at"; + + private final FunctionAndTypeManager functionAndTypeManager; + + public CombineApproxDistinctFunctions(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + private static final Pattern PATTERN = aggregation() + .matching(CombineApproxDistinctFunctions::hasMultipleApproxDistinct); + + private static boolean hasMultipleApproxDistinct(AggregationNode aggregation) + { + return aggregation.getAggregations().values().stream() + .filter(agg -> agg.getCall().getDisplayName().equals(APPROX_DISTINCT)).count() > 1; + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public boolean isEnabled(Session session) + { + return isCombineApproxDistinctEnabled(session); + } + + private static boolean aggregationCanMerge(AggregationNode.Aggregation aggregation1, AggregationNode.Aggregation aggregation2) + { + if (!aggregation1.getMask().equals(aggregation2.getMask()) + || !aggregation1.getOrderBy().equals(aggregation2.getOrderBy()) + || !aggregation1.getFilter().equals(aggregation2.getFilter()) + || aggregation1.isDistinct() != aggregation2.isDistinct()) { + return false; + } + CallExpression expression1 = aggregation1.getCall(); + CallExpression expression2 = aggregation2.getCall(); + if (expression1.getArguments().size() != expression2.getArguments().size()) { + return false; 
+ } + boolean isSameType = expression1.getArguments().get(0).getType().equals(expression2.getArguments().get(0).getType()); + boolean isSameError = expression1.getArguments().size() == 1 || expression1.getArguments().get(1).equals(expression2.getArguments().get(1)); + return isSameType && isSameError; + } + + private static List> createMergeableAggregations(List candidateAggregations) + { + ImmutableList.Builder> result = ImmutableList.builder(); + Set mergedAggregation = new HashSet<>(); + for (int i = 0; i < candidateAggregations.size(); ++i) { + if (mergedAggregation.contains(candidateAggregations.get(i))) { + continue; + } + ImmutableList.Builder aggregationCanBeMerged = ImmutableList.builder(); + mergedAggregation.add(candidateAggregations.get(i)); + aggregationCanBeMerged.add(candidateAggregations.get(i)); + for (int j = i + 1; j < candidateAggregations.size(); ++j) { + if (mergedAggregation.contains(candidateAggregations.get(j))) { + continue; + } + if (aggregationCanMerge(candidateAggregations.get(i), candidateAggregations.get(j))) { + mergedAggregation.add(candidateAggregations.get(j)); + aggregationCanBeMerged.add(candidateAggregations.get(j)); + } + } + result.add(aggregationCanBeMerged.build()); + } + return result.build(); + } + + private CallExpression createArrayExpression(List aggregations) + { + List expressions = aggregations.stream() + .map(x -> x.getCall().getArguments().get(0)) + .collect(Collectors.toList()); + + return call( + functionAndTypeManager, + ARRAY_CONSTRUCTOR, + new ArrayType(expressions.get(0).getType()), + expressions); + } + + private AggregationNode.Aggregation createSetAggAggregation(List candidateList, VariableReferenceExpression arrayVariableReference) + { + AggregationNode.Aggregation aggregationBeforeMerge = candidateList.get(0); + Type elementType = aggregationBeforeMerge.getCall().getArguments().get(0).getType(); + Type arrayType = new ArrayType(elementType); + Type setType = new ArrayType(arrayType); + + CallExpression 
setAggCall = call( + functionAndTypeManager, + SET_AGG, + setType, + ImmutableList.of(arrayVariableReference)); + + return new AggregationNode.Aggregation( + setAggCall, + aggregationBeforeMerge.getFilter(), + aggregationBeforeMerge.getOrderBy(), + aggregationBeforeMerge.isDistinct(), + aggregationBeforeMerge.getMask()); + } + + @Override + public Result apply(AggregationNode aggregationNode, Captures captures, Context context) + { + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + + List approxDistinct = aggregationNode.getAggregations().values().stream().filter( + x -> x.getCall().getDisplayName().equals(APPROX_DISTINCT)).collect(Collectors.toList()); + + Map aggregationOccurrences = approxDistinct.stream().collect(Collectors.groupingBy(identity(), Collectors.counting())); + ImmutableList candidateApproxDistinct = approxDistinct.stream().filter(x -> aggregationOccurrences.get(x) == 1).collect(toImmutableList()); + + Map> sameTypeAggregations = + candidateApproxDistinct.stream().collect(Collectors.groupingBy( + x -> x.getCall().getArguments().get(0).getType(), LinkedHashMap::new, Collectors.toList())); + + ImmutableList.Builder> candidateLists = ImmutableList.builder(); + sameTypeAggregations.values().forEach(aggregationList -> { + candidateLists.addAll(createMergeableAggregations(aggregationList)); + }); + + List> candidateAggregationLists = + candidateLists.build().stream().filter(x -> x.size() > 1).collect(Collectors.toList()); + + if (candidateAggregationLists.isEmpty()) { + return Result.empty(); + } + + Set combinedAggregations = candidateAggregationLists.stream().flatMap(List::stream).collect(Collectors.toSet()); + Map aggregationVariableMap = new HashMap<>(); + Set combinedVariableReference = new HashSet<>(); + aggregationNode.getAggregations().forEach((variable, aggregation) -> { + if (combinedAggregations.contains(aggregation)) { + aggregationVariableMap.put(aggregation, variable); + combinedVariableReference.add(variable); + } + }); + + 
Assignments.Builder sourceProjectAssignments = Assignments.builder(); + Assignments.Builder intermediateProjectAssignments = Assignments.builder(); + Assignments.Builder outputProjectAssignments = Assignments.builder(); + + for (List candidateList : candidateAggregationLists) { + RowExpression arrayExpression = createArrayExpression(candidateList); + VariableReferenceExpression arrayVariableReference = context.getVariableAllocator().newVariable(arrayExpression); + sourceProjectAssignments.put(arrayVariableReference, arrayExpression); + + AggregationNode.Aggregation newAggregation = createSetAggAggregation(candidateList, arrayVariableReference); + VariableReferenceExpression setAggVariableReference = context.getVariableAllocator().newVariable(newAggregation.getCall()); + aggregations.put(setAggVariableReference, newAggregation); + + Type elementType = candidateList.get(0).getCall().getArguments().get(0).getType(); + Type arrayType = new ArrayType(elementType); + Type arrayArrayType = new ArrayType(arrayType); + + CallExpression transposeCall = call( + functionAndTypeManager, + ARRAY_TRANSPOSE, + arrayArrayType, + ImmutableList.of(setAggVariableReference)); + + VariableReferenceExpression transposeVariableReference = context.getVariableAllocator().newVariable(transposeCall); + intermediateProjectAssignments.put(transposeVariableReference, transposeCall); + + Map elementAtMap = + IntStream.range(0, candidateList.size()).boxed().collect(ImmutableMap.toImmutableMap( + x -> aggregationVariableMap.get(candidateList.get(x)), + x -> { + CallExpression elementAt = call( + functionAndTypeManager, + ELEMENT_AT, + arrayType, + ImmutableList.of(transposeVariableReference, constant((long) x + 1, BIGINT))); + CallExpression removeNullsCall = call( + functionAndTypeManager, + REMOVE_NULLS, + arrayType, + ImmutableList.of(elementAt)); + CallExpression arrayDistinctCall = call( + functionAndTypeManager, + ARRAY_DISTINCT, + arrayType, + ImmutableList.of(removeNullsCall)); + 
CallExpression cardinalityCall = call( + functionAndTypeManager, + CARDINALITY, + BIGINT, + ImmutableList.of(arrayDistinctCall)); + return new SpecialFormExpression( + COALESCE, + BIGINT, + ImmutableList.of(cardinalityCall, constant(0L, BIGINT))); + })); + outputProjectAssignments.putAll(elementAtMap); + } + + aggregationNode.getAggregations().forEach((key, value) -> { + if (!combinedVariableReference.contains(key)) { + aggregations.put(key, value); + } + }); + + aggregationNode.getOutputVariables().forEach(variable -> { + if (!combinedVariableReference.contains(variable)) { + outputProjectAssignments.put(variable, variable); + intermediateProjectAssignments.put(variable, variable); + } + }); + + aggregationNode.getSource().getOutputVariables().forEach(variable -> sourceProjectAssignments.put(variable, variable)); + + AggregationNode newAggregationNode = new AggregationNode( + aggregationNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + new ProjectNode(context.getIdAllocator().getNextId(), + aggregationNode.getSource(), sourceProjectAssignments.build()), + aggregations.build(), + aggregationNode.getGroupingSets(), + aggregationNode.getPreGroupedVariables(), + aggregationNode.getStep(), + aggregationNode.getHashVariable(), + aggregationNode.getGroupIdVariable(), + aggregationNode.getAggregationId()); + + ProjectNode intermediateProjectNode = new ProjectNode( + context.getIdAllocator().getNextId(), + newAggregationNode, + intermediateProjectAssignments.build()); + + return Result.ofPlanNode( + new ProjectNode(context.getIdAllocator().getNextId(), + intermediateProjectNode, + outputProjectAssignments.build())); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithArrayContainsToInnerJoin.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithArrayContainsToInnerJoin.java index 5234be3cc6e40..004c6114f258d 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithArrayContainsToInnerJoin.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithArrayContainsToInnerJoin.java @@ -24,12 +24,12 @@ import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.PlannerUtils; import com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithArrayNotContainsToAntiJoin.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithArrayNotContainsToAntiJoin.java index f4060837c4c2d..61a16418b76cd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithArrayNotContainsToAntiJoin.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithArrayNotContainsToAntiJoin.java @@ -27,13 +27,13 @@ import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.PlannerUtils; import 
com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithOrFilterToInnerJoin.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithOrFilterToInnerJoin.java index 220347d826a27..b841b922c93a4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithOrFilterToInnerJoin.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CrossJoinWithOrFilterToInnerJoin.java @@ -30,13 +30,13 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.PlannerUtils; import com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java index 683755bb5da2e..fd70e91978fc0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineJoinDistributionType.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.LocalCostEstimate; @@ -36,7 +37,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Ordering; -import io.airlift.units.DataSize; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java index 34cf06d392872..dfe03747e5ba3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/DetermineSemiJoinDistributionType.java @@ -27,6 +27,7 @@ */ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.LocalCostEstimate; @@ -41,7 +42,6 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.iterative.Rule; import com.google.common.collect.Ordering; -import io.airlift.units.DataSize; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java index acf09957821b5..db28daede971c 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -46,6 +46,7 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.SpatialJoinNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -56,7 +57,6 @@ import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.Rule.Context; import com.facebook.presto.sql.planner.iterative.Rule.Result; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.relational.Expressions; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Splitter; @@ -67,7 +67,6 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; @@ -117,14 +116,14 @@ *

  • SELECT ... FROM a, b WHERE 15.5 > ST_Distance(b.geometry, a.geometry)
  • * *

    - * Joins expressed via ST_Contains and ST_Intersects functions must match all of + * Joins expressed via ST_Contains and ST_Intersects functions must match all * the following criteria: *

    * - arguments of the spatial function are non-scalar expressions; * - one of the arguments uses symbols from left side of the join, the other from right. *

    * Joins expressed via ST_Distance function must use less than or less than or equals operator - * to compare ST_Distance value with a radius and must match all of the following criteria: + * to compare ST_Distance value with a radius and must match all the following criteria: *

    * - arguments of the spatial function are non-scalar expressions; * - one of the arguments uses symbols from left side of the join, the other from right; @@ -161,6 +160,53 @@ public class ExtractSpatialJoins private final SplitManager splitManager; private final PageSourceManager pageSourceManager; + private enum VariableSide + { + Neither, + Left, + Right, + Both + } + + private static VariableSide inferVariableSide(RowExpression expression, JoinNode joinNode) + { + Set expressionVariables = extractUnique(expression); + + if (expressionVariables.isEmpty()) { + return VariableSide.Neither; + } + + List leftVariables = joinNode.getLeft().getOutputVariables(); + List rightVariables = joinNode.getRight().getOutputVariables(); + boolean leftContains = false; + boolean rightContains = false; + for (VariableReferenceExpression var : leftVariables) { + if (expressionVariables.contains(var)) { + leftContains = true; + break; + } + } + for (VariableReferenceExpression var : rightVariables) { + if (expressionVariables.contains(var)) { + rightContains = true; + break; + } + } + + if (leftContains && rightContains) { + return VariableSide.Both; + } + else if (leftContains) { + return VariableSide.Left; + } + else if (rightContains) { + return VariableSide.Right; + } + else { + return VariableSide.Neither; + } + } + public ExtractSpatialJoins(Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) { this.metadata = requireNonNull(metadata, "metadata is null"); @@ -305,18 +351,18 @@ private static Result tryCreateSpatialJoin( PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); - List leftVariables = leftNode.getOutputVariables(); - List rightVariables = rightNode.getOutputVariables(); - RowExpression radius; Optional newRadiusVariable; + VariableReferenceExpression radiusVariable; CallExpression newComparison; if (spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN || 
spatialComparisonMetadata.getOperatorType().get() == OperatorType.LESS_THAN_OR_EQUAL) { // ST_Distance(a, b) <= r radius = spatialComparison.getArguments().get(1); - Set radiusVariables = extractUnique(radius); - if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) { - newRadiusVariable = newRadiusVariable(context, radius); + VariableSide radiusSide = inferVariableSide(radius, joinNode); + if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) { + newRadiusVariable = newVariable(context, radius); + // If newRadiusVariable is empty, radius is VRE + radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius); newComparison = new CallExpression( spatialComparison.getSourceLocation(), spatialComparison.getDisplayName(), @@ -331,9 +377,11 @@ private static Result tryCreateSpatialJoin( else { // r >= ST_Distance(a, b) radius = spatialComparison.getArguments().get(0); - Set radiusVariables = extractUnique(radius); - if (radiusVariables.isEmpty() || (rightVariables.containsAll(radiusVariables) && containsNone(leftVariables, radiusVariables))) { - newRadiusVariable = newRadiusVariable(context, radius); + VariableSide radiusSide = inferVariableSide(radius, joinNode); + if (radiusSide == VariableSide.Neither || radiusSide == VariableSide.Right) { + newRadiusVariable = newVariable(context, radius); + // If newRadiusVariable is empty, radius is VRE + radiusVariable = newRadiusVariable.orElseGet(() -> (VariableReferenceExpression) radius); OperatorType flippedOperatorType = flip(spatialComparisonMetadata.getOperatorType().get()); FunctionHandle flippedHandle = getFlippedFunctionHandle(spatialComparison, metadata.getFunctionAndTypeManager()); newComparison = new CallExpression( @@ -365,7 +413,7 @@ private static Result tryCreateSpatialJoin( joinNode.getDistributionType(), joinNode.getDynamicFilters()); - return tryCreateSpatialJoin(context, newJoinNode, 
newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(newComparison.getArguments().get(1)), metadata, splitManager, pageSourceManager); + return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputVariables, (CallExpression) newComparison.getArguments().get(0), Optional.of(radiusVariable), metadata, splitManager, pageSourceManager); } private static Result tryCreateSpatialJoin( @@ -375,7 +423,7 @@ private static Result tryCreateSpatialJoin( PlanNodeId nodeId, List outputVariables, CallExpression spatialFunction, - Optional radius, + Optional radius, Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) @@ -393,18 +441,27 @@ private static Result tryCreateSpatialJoin( return Result.empty(); } - Set firstVariables = extractUnique(firstArgument); - Set secondVariables = extractUnique(secondArgument); - - if (firstVariables.isEmpty() || secondVariables.isEmpty()) { + VariableSide firstSide = inferVariableSide(firstArgument, joinNode); + VariableSide secondSide = inferVariableSide(secondArgument, joinNode); + boolean firstArgumentOnLeft; + if (firstSide == VariableSide.Left && secondSide == VariableSide.Right) { + firstArgumentOnLeft = true; + } + else if (firstSide == VariableSide.Right && secondSide == VariableSide.Left) { + firstArgumentOnLeft = false; + } + else { + // Spatial joins require each argument comes from only one side, and they come from opposite sides return Result.empty(); } // If either firstArgument or secondArgument is not a // VariableReferenceExpression, will replace the left/right join node // with a projection that adds the argument as a variable. 
- Optional newFirstVariable = newGeometryVariable(context, firstArgument); - Optional newSecondVariable = newGeometryVariable(context, secondArgument); + Optional newFirstVariable = newVariable(context, firstArgument); + Optional newSecondVariable = newVariable(context, secondArgument); + VariableReferenceExpression leftGeometryVariable; + VariableReferenceExpression rightGeometryVariable; PlanNode leftNode = joinNode.getLeft(); PlanNode rightNode = joinNode.getRight(); @@ -412,17 +469,19 @@ private static Result tryCreateSpatialJoin( PlanNode newRightNode; // Check if the order of arguments of the spatial function matches the order of join sides - int alignment = checkAlignment(joinNode, firstVariables, secondVariables); - if (alignment > 0) { + if (firstArgumentOnLeft) { newLeftNode = newFirstVariable.map(variable -> addProjection(context, leftNode, variable, firstArgument)).orElse(leftNode); newRightNode = newSecondVariable.map(variable -> addProjection(context, rightNode, variable, secondArgument)).orElse(rightNode); + // If new variables are empty, argument is VariableReferenceExpression + leftGeometryVariable = newFirstVariable.orElseGet(() -> (VariableReferenceExpression) firstArgument); + rightGeometryVariable = newSecondVariable.orElseGet(() -> (VariableReferenceExpression) secondArgument); } - else if (alignment < 0) { + else { newLeftNode = newSecondVariable.map(variable -> addProjection(context, leftNode, variable, secondArgument)).orElse(leftNode); newRightNode = newFirstVariable.map(variable -> addProjection(context, rightNode, variable, firstArgument)).orElse(rightNode); - } - else { - return Result.empty(); + // If new variables are empty, argument is VariableReferenceExpression + leftGeometryVariable = newSecondVariable.orElseGet(() -> (VariableReferenceExpression) secondArgument); + rightGeometryVariable = newFirstVariable.orElseGet(() -> (VariableReferenceExpression) firstArgument); } RowExpression newFirstArgument = 
mapToExpression(newFirstVariable, firstArgument); @@ -441,7 +500,7 @@ else if (alignment < 0) { leftPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newFirstArgument.getSourceLocation(), "pid", INTEGER)); rightPartitionVariable = Optional.of(context.getVariableAllocator().newVariable(newSecondArgument.getSourceLocation(), "pid", INTEGER)); - if (alignment > 0) { + if (firstArgumentOnLeft) { newLeftNode = addPartitioningNodes(context, functionAndTypeManager, newLeftNode, leftPartitionVariable.get(), kdbTree.get(), newFirstArgument, Optional.empty()); newRightNode = addPartitioningNodes(context, functionAndTypeManager, newRightNode, rightPartitionVariable.get(), kdbTree.get(), newSecondArgument, radius); } @@ -457,10 +516,13 @@ else if (alignment < 0) { return Result.ofPlanNode(new SpatialJoinNode( joinNode.getSourceLocation(), nodeId, - SpatialJoinNode.Type.fromJoinNodeType(joinNode.getType()), + SpatialJoinNode.SpatialJoinType.fromJoinNodeType(joinNode.getType()), newLeftNode, newRightNode, outputVariables, + leftGeometryVariable, + rightGeometryVariable, + radius, newFilter, leftPartitionVariable, rightPartitionVariable, @@ -469,6 +531,11 @@ else if (alignment < 0) { private static boolean isSphericalJoin(Metadata metadata, RowExpression firstArgument, RowExpression secondArgument) { + // In sidecar-enabled clusters, SphericalGeography isn't a supported type. + // If SphericalGeography is not supported, it can be assumed that this join isn't a spherical join, hence returning False. 
+ if (!metadata.getFunctionAndTypeManager().hasType(SPHERICAL_GEOGRAPHY_TYPE_SIGNATURE)) { + return false; + } Type sphericalGeographyType = metadata.getType(SPHERICAL_GEOGRAPHY_TYPE_SIGNATURE); return firstArgument.getType().equals(sphericalGeographyType) || secondArgument.getType().equals(sphericalGeographyType); } @@ -596,43 +663,12 @@ private static QualifiedObjectName toQualifiedObjectName(String name, String cat throw new PrestoException(INVALID_SPATIAL_PARTITIONING, format("Invalid name: %s", name)); } - private static int checkAlignment(JoinNode joinNode, Set maybeLeftVariables, Set maybeRightVariables) - { - List leftVariables = joinNode.getLeft().getOutputVariables(); - List rightVariables = joinNode.getRight().getOutputVariables(); - - if (leftVariables.containsAll(maybeLeftVariables) - && containsNone(leftVariables, maybeRightVariables) - && rightVariables.containsAll(maybeRightVariables) - && containsNone(rightVariables, maybeLeftVariables)) { - return 1; - } - - if (leftVariables.containsAll(maybeRightVariables) - && containsNone(leftVariables, maybeLeftVariables) - && rightVariables.containsAll(maybeLeftVariables) - && containsNone(rightVariables, maybeRightVariables)) { - return -1; - } - - return 0; - } - private static RowExpression mapToExpression(Optional optionalVariable, RowExpression defaultExpression) { return optionalVariable.map(RowExpression.class::cast).orElse(defaultExpression); } - private static Optional newGeometryVariable(Context context, RowExpression expression) - { - if (expression instanceof VariableReferenceExpression) { - return Optional.empty(); - } - - return Optional.of(context.getVariableAllocator().newVariable(expression)); - } - - private static Optional newRadiusVariable(Context context, RowExpression expression) + private static Optional newVariable(Context context, RowExpression expression) { if (expression instanceof VariableReferenceExpression) { return Optional.empty(); @@ -652,7 +688,7 @@ private static PlanNode 
addProjection(Context context, PlanNode node, VariableRe return new ProjectNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node, projections.build(), LOCAL); } - private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional radius) + private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeManager functionAndTypeManager, PlanNode node, VariableReferenceExpression partitionVariable, KdbTree kdbTree, RowExpression geometry, Optional radius) { Assignments.Builder projections = Assignments.builder(); for (VariableReferenceExpression outputVariable : node.getOutputVariables()) { @@ -685,9 +721,4 @@ private static PlanNode addPartitioningNodes(Context context, FunctionAndTypeMan ImmutableMap.of(partitionsVariable, ImmutableList.of(partitionVariable)), Optional.empty()); } - - private static boolean containsNone(Collection values, Collection testValues) - { - return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains); - } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSystemTableFilterRuleSet.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSystemTableFilterRuleSet.java new file mode 100644 index 0000000000000..05aa75acba01e --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSystemTableFilterRuleSet.java @@ -0,0 +1,316 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.PartitioningScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.PlannerUtils; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.matching.Capture.newCapture; +import static com.facebook.presto.sql.planner.plan.Patterns.exchange; +import static com.facebook.presto.sql.planner.plan.Patterns.filter; +import static com.facebook.presto.sql.planner.plan.Patterns.project; +import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.planner.plan.Patterns.tableScan; +import static com.facebook.presto.sql.relational.RowExpressionUtils.containsNonCoordinatorEligibleCallExpression; +import static java.util.Objects.requireNonNull; + +/** + * RuleSet for extracting 
system table filters when they contain non-coordinator-eligible functions (e.g., CPP functions). + * This ensures that system table scans happen on the coordinator while CPP functions execute on workers. + * + * Patterns handled: + * 1. Exchange -> Project -> Filter -> TableScan (system) => Project -> Filter -> Exchange -> TableScan + * 2. Exchange -> Project -> TableScan (system) => Project -> Exchange -> TableScan + * 3. Exchange -> Filter -> TableScan (system) => Filter -> Exchange -> TableScan + */ +public class ExtractSystemTableFilterRuleSet +{ + private final FunctionAndTypeManager functionAndTypeManager; + + public ExtractSystemTableFilterRuleSet(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + public Set> rules() + { + return ImmutableSet.of( + new ProjectFilterScanRule(), + new ProjectScanRule(), + new FilterScanRule()); + } + + private abstract class SystemTableFilterRule + implements Rule + { + protected final Capture tableScanCapture = newCapture(); + + protected boolean containsFunctionsIneligibleOnCoordinator(Optional filterNode, Optional projectNode) + { + boolean hasIneligiblePredicates = filterNode + .map(filter -> containsNonCoordinatorEligibleCallExpression(functionAndTypeManager, filter.getPredicate())) + .orElse(false); + + boolean hasIneligibleProjections = projectNode + .map(project -> project.getAssignments().getExpressions().stream() + .anyMatch(expression -> containsNonCoordinatorEligibleCallExpression(functionAndTypeManager, expression))) + .orElse(false); + + return hasIneligiblePredicates || hasIneligibleProjections; + } + } + + private final class ProjectFilterScanRule + extends SystemTableFilterRule + { + private final Capture exchangeCapture = newCapture(); + private final Capture projectCapture = newCapture(); + private final Capture filterCapture = newCapture(); + + @Override + public Pattern getPattern() + { + 
return exchange() + .capturedAs(exchangeCapture) + .with(source().matching( + project() + .capturedAs(projectCapture) + .with(source().matching( + filter() + .capturedAs(filterCapture) + .with(source().matching( + tableScan() + .capturedAs(tableScanCapture) + .matching(PlannerUtils::containsSystemTableScan))))))); + } + + @Override + public Result apply(ExchangeNode node, Captures captures, Context context) + { + TableScanNode tableScanNode = captures.get(tableScanCapture); + ExchangeNode exchangeNode = captures.get(exchangeCapture); + ProjectNode projectNode = captures.get(projectCapture); + FilterNode filterNode = captures.get(filterCapture); + + if (!containsFunctionsIneligibleOnCoordinator(Optional.of(filterNode), Optional.of(projectNode))) { + return Result.empty(); + } + + // The exchange's output variables must match what the filter expects + // Since the filter was originally between project and table scan, it expects + // the table scan's output variables + PartitioningScheme newPartitioningScheme = new PartitioningScheme( + exchangeNode.getPartitioningScheme().getPartitioning(), + tableScanNode.getOutputVariables(), + exchangeNode.getPartitioningScheme().getHashColumn(), + exchangeNode.getPartitioningScheme().isScaleWriters(), + exchangeNode.getPartitioningScheme().isReplicateNullsAndAny(), + exchangeNode.getPartitioningScheme().getEncoding(), + exchangeNode.getPartitioningScheme().getBucketToPartition()); + + // Create new exchange with table scan as source + ExchangeNode newExchange = new ExchangeNode( + exchangeNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + exchangeNode.getType(), + exchangeNode.getScope(), + newPartitioningScheme, + ImmutableList.of(tableScanNode), + ImmutableList.of(tableScanNode.getOutputVariables()), + exchangeNode.isEnsureSourceOrdering(), + exchangeNode.getOrderingScheme()); + + // Recreate filter with exchange as source + FilterNode newFilter = new FilterNode( + filterNode.getSourceLocation(), + 
context.getIdAllocator().getNextId(), + newExchange, + filterNode.getPredicate()); + + // Recreate project with filter as source + ProjectNode newProject = new ProjectNode( + projectNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + newFilter, + projectNode.getAssignments(), + projectNode.getLocality()); + + return Result.ofPlanNode(newProject); + } + } + + private final class ProjectScanRule + extends SystemTableFilterRule + { + private final Capture exchangeCapture = newCapture(); + private final Capture projectCapture = newCapture(); + + @Override + public Pattern getPattern() + { + return exchange() + .capturedAs(exchangeCapture) + .with(source().matching( + project() + .capturedAs(projectCapture) + .with(source().matching( + tableScan() + .capturedAs(tableScanCapture) + .matching(PlannerUtils::containsSystemTableScan))))); + } + + @Override + public Result apply(ExchangeNode node, Captures captures, Context context) + { + TableScanNode tableScanNode = captures.get(tableScanCapture); + ExchangeNode exchangeNode = captures.get(exchangeCapture); + ProjectNode projectNode = captures.get(projectCapture); + + if (!containsFunctionsIneligibleOnCoordinator(Optional.empty(), Optional.of(projectNode))) { + return Result.empty(); + } + + // Update partitioning scheme to match table scan outputs + PartitioningScheme newPartitioningScheme = new PartitioningScheme( + exchangeNode.getPartitioningScheme().getPartitioning(), + tableScanNode.getOutputVariables(), + exchangeNode.getPartitioningScheme().getHashColumn(), + exchangeNode.getPartitioningScheme().isScaleWriters(), + exchangeNode.getPartitioningScheme().isReplicateNullsAndAny(), + exchangeNode.getPartitioningScheme().getEncoding(), + exchangeNode.getPartitioningScheme().getBucketToPartition()); + + // Create new exchange with table scan as source + ExchangeNode newExchange = new ExchangeNode( + exchangeNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + exchangeNode.getType(), + 
exchangeNode.getScope(), + newPartitioningScheme, + ImmutableList.of(tableScanNode), + ImmutableList.of(tableScanNode.getOutputVariables()), + exchangeNode.isEnsureSourceOrdering(), + exchangeNode.getOrderingScheme()); + + // Recreate project with exchange as source + ProjectNode newProject = new ProjectNode( + projectNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + newExchange, + projectNode.getAssignments(), + projectNode.getLocality()); + + return Result.ofPlanNode(newProject); + } + } + + private final class FilterScanRule + extends SystemTableFilterRule + { + private final Capture exchangeCapture = newCapture(); + private final Capture filterCapture = newCapture(); + + @Override + public Pattern getPattern() + { + return exchange() + .capturedAs(exchangeCapture) + .with(source().matching( + filter() + .capturedAs(filterCapture) + .with(source().matching( + tableScan() + .capturedAs(tableScanCapture) + .matching(PlannerUtils::containsSystemTableScan))))); + } + + @Override + public Result apply(ExchangeNode node, Captures captures, Context context) + { + TableScanNode tableScanNode = captures.get(tableScanCapture); + ExchangeNode exchangeNode = captures.get(exchangeCapture); + FilterNode filterNode = captures.get(filterCapture); + + if (!containsFunctionsIneligibleOnCoordinator(Optional.of(filterNode), Optional.empty())) { + return Result.empty(); + } + + // Update partitioning scheme to match table scan outputs + PartitioningScheme newPartitioningScheme = new PartitioningScheme( + exchangeNode.getPartitioningScheme().getPartitioning(), + tableScanNode.getOutputVariables(), + exchangeNode.getPartitioningScheme().getHashColumn(), + exchangeNode.getPartitioningScheme().isScaleWriters(), + exchangeNode.getPartitioningScheme().isReplicateNullsAndAny(), + exchangeNode.getPartitioningScheme().getEncoding(), + exchangeNode.getPartitioningScheme().getBucketToPartition()); + + // Create new exchange with table scan as source + ExchangeNode newExchange 
= new ExchangeNode( + exchangeNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + exchangeNode.getType(), + exchangeNode.getScope(), + newPartitioningScheme, + ImmutableList.of(tableScanNode), + ImmutableList.of(tableScanNode.getOutputVariables()), + exchangeNode.isEnsureSourceOrdering(), + exchangeNode.getOrderingScheme()); + + // Recreate filter with exchange as source + FilterNode newFilter = new FilterNode( + filterNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + newExchange, + filterNode.getPredicate()); + + // Check if the original exchange's output variables match the filter's output + // If not, add a project node to align them + if (!exchangeNode.getOutputVariables().equals(newFilter.getOutputVariables())) { + Assignments.Builder assignments = Assignments.builder(); + for (VariableReferenceExpression variable : exchangeNode.getOutputVariables()) { + assignments.put(variable, variable); + } + + ProjectNode projectNode = new ProjectNode( + exchangeNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + newFilter, + assignments.build(), + ProjectNode.Locality.LOCAL); + + return Result.ofPlanNode(projectNode); + } + + return Result.ofPlanNode(newFilter); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java index f548fc8ac5b2c..c5cb256552057 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.iterative.rule; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; @@ -21,6 +22,7 @@ import 
com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; @@ -28,10 +30,8 @@ import com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties; import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations; import com.facebook.presto.sql.planner.plan.ExchangeNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java index 9cc6a3674a771..bd1e575f8e78b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/LeftJoinWithArrayContainsToEquiJoinCondition.java @@ -22,13 +22,13 @@ import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.LeftJoinArrayContainsToInnerJoinStrategy; import com.facebook.presto.sql.planner.PlannerUtils; import 
com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.google.common.collect.ImmutableList; @@ -48,7 +48,7 @@ import static java.util.Objects.requireNonNull; /** - * When the join condition of a left join has pattern of contains(array, element) where array, we can rewrite it as a equi join condition. For example: + * When the join condition of a left join has pattern of contains(array, element) where array is from the right-side relation and element is from the left-side relation, we can rewrite it as an equi join condition. For example: *

      * - Left Join
      *      empty join clause
    diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MaterializedViewRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MaterializedViewRewrite.java
    new file mode 100644
    index 0000000000000..68f3d892a360a
    --- /dev/null
    +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MaterializedViewRewrite.java
    @@ -0,0 +1,218 @@
    +/*
    + * Licensed under the Apache License, Version 2.0 (the "License");
    + * you may not use this file except in compliance with the License.
    + * You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package com.facebook.presto.sql.planner.iterative.rule;
    +
    +import com.facebook.airlift.units.Duration;
    +import com.facebook.presto.Session;
    +import com.facebook.presto.common.QualifiedObjectName;
    +import com.facebook.presto.common.predicate.TupleDomain;
    +import com.facebook.presto.matching.Captures;
    +import com.facebook.presto.matching.Pattern;
    +import com.facebook.presto.metadata.Metadata;
    +import com.facebook.presto.spi.ColumnHandle;
    +import com.facebook.presto.spi.ColumnMetadata;
    +import com.facebook.presto.spi.MaterializedViewDefinition;
    +import com.facebook.presto.spi.MaterializedViewStalenessConfig;
    +import com.facebook.presto.spi.MaterializedViewStatus;
    +import com.facebook.presto.spi.PrestoException;
    +import com.facebook.presto.spi.SchemaTableName;
    +import com.facebook.presto.spi.TableHandle;
    +import com.facebook.presto.spi.analyzer.MetadataResolver;
    +import com.facebook.presto.spi.plan.Assignments;
    +import com.facebook.presto.spi.plan.MaterializedViewScanNode;
    +import com.facebook.presto.spi.plan.PlanNode;
    +import com.facebook.presto.spi.plan.ProjectNode;
    +import com.facebook.presto.spi.relation.VariableReferenceExpression;
    +import com.facebook.presto.spi.security.AccessControl;
    +import com.facebook.presto.spi.security.ViewExpression;
    +import com.facebook.presto.spi.security.ViewSecurity;
    +import com.facebook.presto.sql.planner.iterative.Rule;
    +
    +import java.util.List;
    +import java.util.Map;
    +import java.util.Optional;
    +
    +import static com.facebook.presto.SystemSessionProperties.getMaterializedViewStaleReadBehavior;
    +import static com.facebook.presto.SystemSessionProperties.isLegacyMaterializedViews;
    +import static com.facebook.presto.spi.MaterializedViewStaleReadBehavior.USE_VIEW_QUERY;
    +import static com.facebook.presto.spi.StandardErrorCode.MATERIALIZED_VIEW_STALE;
    +import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
    +import static com.facebook.presto.spi.security.ViewSecurity.DEFINER;
    +import static com.facebook.presto.spi.security.ViewSecurity.INVOKER;
    +import static com.facebook.presto.sql.planner.plan.Patterns.materializedViewScan;
    +import static com.google.common.base.Preconditions.checkState;
    +import static com.google.common.collect.ImmutableList.toImmutableList;
    +import static java.lang.System.currentTimeMillis;
    +import static java.util.Objects.requireNonNull;
    +
    +public class MaterializedViewRewrite
    +        implements Rule
    +{
    +    private final Metadata metadata;
    +    private final AccessControl accessControl;
    +
    +    public MaterializedViewRewrite(Metadata metadata, AccessControl accessControl)
    +    {
    +        this.metadata = requireNonNull(metadata, "metadata is null");
    +        this.accessControl = requireNonNull(accessControl, "accessControl is null");
    +    }
    +
    +    @Override
    +    public Pattern getPattern()
    +    {
    +        return materializedViewScan();
    +    }
    +
    +    @Override
    +    public Result apply(MaterializedViewScanNode node, Captures captures, Context context)
    +    {
    +        Session session = context.getSession();
    +        checkState(!isLegacyMaterializedViews(session), "Materialized view rewrite rule should not fire when legacy materialized views are enabled");
    +
    +        MetadataResolver metadataResolver = metadata.getMetadataResolver(session);
    +
    +        boolean useDataTable = isUseDataTable(node, metadataResolver, session);
    +        PlanNode chosenPlan = useDataTable ? node.getDataTablePlan() : node.getViewQueryPlan();
    +        Map chosenMappings =
    +                useDataTable ? node.getDataTableMappings() : node.getViewQueryMappings();
    +
    +        Assignments.Builder assignments = Assignments.builder();
    +        for (VariableReferenceExpression outputVariable : node.getOutputVariables()) {
    +            VariableReferenceExpression sourceVariable = chosenMappings.get(outputVariable);
    +            requireNonNull(sourceVariable, "No mapping found for output variable: " + outputVariable);
    +            assignments.put(outputVariable, sourceVariable);
    +        }
    +
    +        return Result.ofPlanNode(new ProjectNode(
    +                node.getSourceLocation(),
    +                context.getIdAllocator().getNextId(),
    +                chosenPlan,
    +                assignments.build(),
    +                LOCAL));
    +    }
    +
    +    private boolean isUseDataTable(MaterializedViewScanNode node, MetadataResolver metadataResolver, Session session)
    +    {
    +        Optional materializedViewDefinition = metadataResolver.getMaterializedView(node.getMaterializedViewName());
    +        checkState(materializedViewDefinition.isPresent(), "Materialized view definition not found for: %s", node.getMaterializedViewName());
    +        MaterializedViewDefinition definition = materializedViewDefinition.get();
    +
    +        MaterializedViewStatus status = metadataResolver.getMaterializedViewStatus(node.getMaterializedViewName(), TupleDomain.all());
    +        if (status.isFullyMaterialized()) {
    +            return canUseDataTableWithSecurityChecks(node, metadataResolver, session, definition);
    +        }
    +
    +        Optional stalenessConfig = definition.getStalenessConfig();
    +        if (stalenessConfig.isPresent()) {
    +            MaterializedViewStalenessConfig config = stalenessConfig.get();
    +
    +            if (isStalenessBeyondTolerance(config, status)) {
    +                return applyStaleReadBehavior(config, node.getMaterializedViewName());
    +            }
    +            return canUseDataTableWithSecurityChecks(node, metadataResolver, session, definition);
    +        }
    +
    +        if (getMaterializedViewStaleReadBehavior(session) == USE_VIEW_QUERY) {
    +            return false;
    +        }
    +        throw new PrestoException(
    +                MATERIALIZED_VIEW_STALE,
    +                String.format("Materialized view '%s' is stale (base tables have changed since last refresh)", node.getMaterializedViewName()));
    +    }
    +
    +    private boolean isStalenessBeyondTolerance(
    +            MaterializedViewStalenessConfig config,
    +            MaterializedViewStatus status)
    +    {
    +        Duration stalenessWindow = config.getStalenessWindow();
    +
    +        Optional lastFreshTime = status.getLastFreshTime();
    +        return lastFreshTime
    +                .map(time -> (currentTimeMillis() - time) > stalenessWindow.toMillis())
    +                .orElse(true);
    +    }
    +
    +    private boolean applyStaleReadBehavior(MaterializedViewStalenessConfig config, QualifiedObjectName viewName)
    +    {
    +        switch (config.getStaleReadBehavior()) {
    +            case FAIL:
    +                throw new PrestoException(
    +                        MATERIALIZED_VIEW_STALE,
    +                        String.format("Materialized view '%s' is stale beyond the configured staleness window", viewName));
    +            case USE_VIEW_QUERY:
    +                return false;
    +            default:
    +                throw new IllegalStateException("Unexpected stale read behavior: " + config.getStaleReadBehavior());
    +        }
    +    }
    +
    +    private boolean canUseDataTableWithSecurityChecks(
    +            MaterializedViewScanNode node,
    +            MetadataResolver metadataResolver,
    +            Session session,
    +            MaterializedViewDefinition definition)
    +    {
    +        // Security mode defaults to INVOKER for legacy materialized views created without explicitly specifying it
    +        ViewSecurity securityMode = definition.getSecurityMode().orElse(INVOKER);
    +
    +        // In definer rights, there's only one user permissions (the definer), so row filters and column masks
    +        // do not depend on the invoker and can be safely ignored when deciding whether to use the data table
    +        if (securityMode == DEFINER) {
    +            return true;
    +        }
    +
    +        // Invoker rights: need to check for row filters and column masks on base tables because they may alter
    +        // the data returned by the materialized view depending on the invoker's permissions.
    +        String catalogName = node.getMaterializedViewName().getCatalogName();
    +        for (SchemaTableName schemaTableName : definition.getBaseTables()) {
    +            QualifiedObjectName baseTable = new QualifiedObjectName(catalogName, schemaTableName.getSchemaName(), schemaTableName.getTableName());
    +
    +            // Check for row filters on this base table
    +            List rowFilters = accessControl.getRowFilters(
    +                    session.getTransactionId().get(),
    +                    session.getIdentity(),
    +                    session.getAccessControlContext(),
    +                    baseTable);
    +
    +            if (!rowFilters.isEmpty()) {
    +                return false;
    +            }
    +
    +            Optional tableHandle = metadataResolver.getTableHandle(baseTable);
    +            if (!tableHandle.isPresent()) {
    +                return false;
    +            }
    +
    +            // Check for column masks on this base table
    +            Map columnHandles = metadata.getColumnHandles(session, tableHandle.get());
    +            List columnsMetadata = columnHandles.values().stream()
    +                    .map(handle -> metadata.getColumnMetadata(session, tableHandle.get(), handle))
    +                    .collect(toImmutableList());
    +
    +            Map columnMasks = accessControl.getColumnMasks(
    +                    session.getTransactionId().get(),
    +                    session.getIdentity(),
    +                    session.getAccessControlContext(),
    +                    baseTable,
    +                    columnsMetadata);
    +
    +            if (!columnMasks.isEmpty()) {
    +                return false;
    +            }
    +        }
    +
    +        // No row filters or column masks found on base tables, safe to use data table
    +        return true;
    +    }
    +}
    diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java
    new file mode 100644
    index 0000000000000..93447d712844d
    --- /dev/null
    +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java
    @@ -0,0 +1,163 @@
    +/*
    + * Licensed under the Apache License, Version 2.0 (the "License");
    + * you may not use this file except in compliance with the License.
    + * You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package com.facebook.presto.sql.planner.iterative.rule;
    +
    +import com.facebook.presto.Session;
    +import com.facebook.presto.common.block.SortOrder;
    +import com.facebook.presto.common.type.ArrayType;
    +import com.facebook.presto.common.type.MapType;
    +import com.facebook.presto.matching.Captures;
    +import com.facebook.presto.matching.Pattern;
    +import com.facebook.presto.metadata.FunctionAndTypeManager;
    +import com.facebook.presto.spi.plan.AggregationNode;
    +import com.facebook.presto.spi.plan.Assignments;
    +import com.facebook.presto.spi.plan.DataOrganizationSpecification;
    +import com.facebook.presto.spi.plan.FilterNode;
    +import com.facebook.presto.spi.plan.Ordering;
    +import com.facebook.presto.spi.plan.OrderingScheme;
    +import com.facebook.presto.spi.plan.ProjectNode;
    +import com.facebook.presto.spi.relation.ConstantExpression;
    +import com.facebook.presto.spi.relation.RowExpression;
    +import com.facebook.presto.spi.relation.VariableReferenceExpression;
    +import com.facebook.presto.sql.planner.iterative.Rule;
    +import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
    +import com.facebook.presto.sql.relational.FunctionResolution;
    +import com.google.common.collect.ImmutableList;
    +import com.google.common.collect.ImmutableMap;
    +
    +import java.util.List;
    +import java.util.Map;
    +import java.util.Optional;
    +
    +import static com.facebook.presto.SystemSessionProperties.isRewriteMinMaxByToTopNEnabled;
    +import static com.facebook.presto.common.function.OperatorType.EQUAL;
    +import static com.facebook.presto.common.type.BigintType.BIGINT;
    +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
    +import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
    +import static com.facebook.presto.sql.relational.Expressions.comparisonExpression;
    +import static com.google.common.collect.ImmutableMap.toImmutableMap;
    +
    +/**
    + * For queries with min_by/max_by functions on map, rewrite it with top n window functions.
    + * For example, for query `select id, max(ds), max_by(feature, ds) from t group by id`,
    + * it will be rewritten from:
    + * 
    + * - Aggregation
    + *      ds_0 := max(ds)
    + *      feature_0 := max_by(feature, ds)
    + *      group by id
    + *      - scan t
    + *          ds
    + *          feature
    + *          id
    + * 
    + * into: + *
    + *     - Filter
    + *          row_num = 1
    + *          - TopNRow
    + *              partition by id
    + *              order by ds desc
    + *              maxRowCountPerPartition = 1
    + *              - scan t
    + *                  ds
    + *                  feature
    + *                  id
    + * 
    + */ +public class MinMaxByToWindowFunction + implements Rule +{ + private static final Pattern PATTERN = aggregation().matching(x -> !x.getHashVariable().isPresent() && !x.getGroupingKeys().isEmpty() && x.getGroupingSetCount() == 1 && x.getStep().equals(AggregationNode.Step.SINGLE)); + private final FunctionResolution functionResolution; + + public MinMaxByToWindowFunction(FunctionAndTypeManager functionAndTypeManager) + { + this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); + } + + @Override + public boolean isEnabled(Session session) + { + return isRewriteMinMaxByToTopNEnabled(session); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(AggregationNode node, Captures captures, Context context) + { + Map maxByAggregations = node.getAggregations().entrySet().stream() + .filter(x -> functionResolution.isMaxByFunction(x.getValue().getFunctionHandle())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + Map minByAggregations = node.getAggregations().entrySet().stream() + .filter(x -> functionResolution.isMinByFunction(x.getValue().getFunctionHandle())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + boolean isMaxByAggregation; + Map candidateAggregation; + if (maxByAggregations.isEmpty() && !minByAggregations.isEmpty()) { + isMaxByAggregation = false; + candidateAggregation = minByAggregations; + } + else if (!maxByAggregations.isEmpty() && minByAggregations.isEmpty()) { + isMaxByAggregation = true; + candidateAggregation = maxByAggregations; + } + else { + return Result.empty(); + } + if (candidateAggregation.values().stream().noneMatch(x -> x.getArguments().get(0).getType() instanceof MapType || x.getArguments().get(0).getType() instanceof ArrayType)) { + return Result.empty(); + } + boolean allMaxOrMinByWithSameField = candidateAggregation.values().stream().map(x -> x.getArguments().get(1)).distinct().count() 
== 1; + if (!allMaxOrMinByWithSameField) { + return Result.empty(); + } + VariableReferenceExpression orderByVariable = (VariableReferenceExpression) candidateAggregation.values().stream().findFirst().get().getArguments().get(1); + Map remainingAggregations = node.getAggregations().entrySet().stream().filter(x -> !candidateAggregation.containsKey(x.getKey())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + boolean remainingEmptyOrMinOrMaxOnOrderBy = remainingAggregations.isEmpty() || (remainingAggregations.size() == 1 + && remainingAggregations.values().stream().allMatch(x -> (isMaxByAggregation ? functionResolution.isMaxFunction(x.getFunctionHandle()) : functionResolution.isMinFunction(x.getFunctionHandle())) && x.getArguments().size() == 1 && x.getArguments().get(0).equals(orderByVariable))); + if (!remainingEmptyOrMinOrMaxOnOrderBy) { + return Result.empty(); + } + + List partitionKeys = node.getGroupingKeys(); + OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(new Ordering(orderByVariable, isMaxByAggregation ? 
SortOrder.DESC_NULLS_LAST : SortOrder.ASC_NULLS_LAST))); + DataOrganizationSpecification dataOrganizationSpecification = new DataOrganizationSpecification(partitionKeys, Optional.of(orderingScheme)); + VariableReferenceExpression rowNumberVariable = context.getVariableAllocator().newVariable("row_number", BIGINT); + TopNRowNumberNode topNRowNumberNode = + new TopNRowNumberNode(node.getSourceLocation(), + context.getIdAllocator().getNextId(), + node.getStatsEquivalentPlanNode(), + node.getSource(), + dataOrganizationSpecification, + rowNumberVariable, + 1, + false, + Optional.empty()); + RowExpression equal = comparisonExpression(functionResolution, EQUAL, rowNumberVariable, new ConstantExpression(1L, BIGINT)); + FilterNode filterNode = new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getStatsEquivalentPlanNode(), topNRowNumberNode, equal); + Map assignments = ImmutableMap.builder() + .putAll(node.getAggregations().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, x -> x.getValue().getArguments().get(0)))).build(); + + ProjectNode projectNode = new ProjectNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getStatsEquivalentPlanNode(), filterNode, + Assignments.builder().putAll(assignments).putAll(identityAssignments(node.getGroupingKeys())).build(), ProjectNode.Locality.LOCAL); + return Result.ofPlanNode(projectNode); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java index faf96d27adabc..ea79310608dfb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java @@ -59,10 +59,12 @@ import static com.facebook.presto.metadata.TableLayoutResult.computeEnforced; import static 
com.facebook.presto.spi.relation.DomainTranslator.BASIC_COLUMN_EXTRACTOR; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.sql.planner.PlannerUtils.containsSystemTableScan; import static com.facebook.presto.sql.planner.iterative.rule.PreconditionRules.checkRulesAreFiredBeforeAddExchangesRule; import static com.facebook.presto.sql.planner.plan.Patterns.filter; import static com.facebook.presto.sql.planner.plan.Patterns.source; import static com.facebook.presto.sql.planner.plan.Patterns.tableScan; +import static com.facebook.presto.sql.relational.RowExpressionUtils.containsNonCoordinatorEligibleCallExpression; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Sets.intersection; @@ -271,6 +273,16 @@ private static PlanNode pushPredicateIntoTableScan( new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()), metadata.getFunctionAndTypeManager()); RowExpression deterministicPredicate = logicalRowExpressions.filterDeterministicConjuncts(predicate); + // If the predicate contains non-Java expressions, we cannot prune partitions over system tables. 
+ RowExpression ineligiblePredicate = TRUE_CONSTANT; + if (containsSystemTableScan(node)) { + ineligiblePredicate = logicalRowExpressions.filterConjuncts( + deterministicPredicate, + expression -> containsNonCoordinatorEligibleCallExpression(metadata.getFunctionAndTypeManager(), expression)); + deterministicPredicate = logicalRowExpressions.filterConjuncts( + deterministicPredicate, + expression -> !containsNonCoordinatorEligibleCallExpression(metadata.getFunctionAndTypeManager(), expression)); + } DomainTranslator.ExtractionResult decomposedPredicate = domainTranslator.fromPredicate( session.toConnectorSession(), deterministicPredicate, @@ -339,7 +351,8 @@ private static PlanNode pushPredicateIntoTableScan( RowExpression resultingPredicate = logicalRowExpressions.combineConjuncts( domainTranslator.toPredicate(layout.getUnenforcedConstraint().transform(assignments::get)), logicalRowExpressions.filterNonDeterministicConjuncts(predicate), - decomposedPredicate.getRemainingExpression()); + decomposedPredicate.getRemainingExpression(), + ineligiblePredicate); if (!TRUE_CONSTANT.equals(resultingPredicate)) { return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), tableScan, resultingPredicate); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMergeSourceColumns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMergeSourceColumns.java new file mode 100644 index 0000000000000..9237e98e06833 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneMergeSourceColumns.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; +import com.google.common.collect.ImmutableSet; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictChildOutputs; +import static com.facebook.presto.sql.planner.plan.Patterns.mergeWriter; + +public class PruneMergeSourceColumns + implements Rule +{ + private static final Pattern PATTERN = mergeWriter(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(MergeWriterNode mergeNode, Captures captures, Context context) + { + return restrictChildOutputs(context.getIdAllocator(), mergeNode, ImmutableSet.copyOf(mergeNode.getMergeProcessorProjectedVariables())) + .map(Result::ofPlanNode) + .orElse(Result.empty()); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java new file mode 100644 index 0000000000000..c95212bf38c0d --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.collect.ImmutableList.toImmutableList; + +/** + * TableFunctionProcessorNode has two kinds of outputs: + * - proper outputs, which are the columns produced by the table function, + * - pass-through outputs, which are the columns copied from table arguments. + * This rule filters out unreferenced pass-through symbols. + * Unreferenced proper symbols are not pruned, because there is currently no way + * to communicate to the table function the request for not producing certain columns. + * // TODO prune table function's proper outputs + * Example: + *
    + * - Project
    + *   assignments={proper->proper1}
    + *  - TableFunctionProcessor
    + *    properOutputs=[proper1, proper2]
    + *    passThroughSymbols=[[passthrough1],[passthrough2]]
    + * 
    + * is transformed into + *
    + * - Project
    + *   assignments={proper->proper1}
    + *   - TableFunctionProcessor
    + *     properOutputs=[proper1, proper2]
    + *     passThroughSymbols=[]
    + * 
    + */ +public class PruneTableFunctionProcessorColumns + extends ProjectOffPushDownRule +{ + public PruneTableFunctionProcessorColumns() + { + super(tableFunctionProcessor()); + } + + @Override + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, TableFunctionProcessorNode node, Set referencedOutputs) + { + List prunedPassThroughSpecifications = node.getPassThroughSpecifications().stream() + .map(sourceSpecification -> { + List prunedPassThroughColumns = sourceSpecification.getColumns().stream() + .filter(column -> referencedOutputs.contains(column.getOutputVariables())) + .collect(toImmutableList()); + return new TableFunctionNode.PassThroughSpecification(sourceSpecification.isDeclaredAsPassThrough(), prunedPassThroughColumns); + }) + .collect(toImmutableList()); + + int originalPassThroughCount = node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .mapToInt(List::size) + .sum(); + + int prunedPassThroughCount = prunedPassThroughSpecifications.stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .mapToInt(List::size) + .sum(); + + if (originalPassThroughCount == prunedPassThroughCount) { + return Optional.empty(); + } + + return Optional.of(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + node.getSource(), + node.isPruneWhenEmpty(), + prunedPassThroughSpecifications, + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle())); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java new file mode 100644 index 0000000000000..d90f668d4c98f --- /dev/null 
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.collect.Maps.filterKeys; + +/** + * This rule prunes unreferenced outputs of TableFunctionProcessorNode. + * First, it extracts all symbols required for: + * - pass-through + * - table function computation + * - partitioning and ordering (including the hashSymbol) + * Next, a mapping of input symbols to marker symbols is updated + * so that it only contains mappings for the required symbols. 
+ * Last, all the remaining marker symbols are added to the collection + * of required symbols. + * Any source output symbols not included in the required symbols + * can be pruned. + * Example: + *
    + * - TableFunctionProcessor
    + *   properOutputs=[proper]
    + *   passThroughSymbols=[[passthrough1],[passthrough2]]
    + *   requiredSymbols=[[require1], [require2]]
    + *   specification=[partition={[partition1]} orderby={[order1 ASC_NULLS_LAST]}]
    + *   hashSymbol=[hash]
    + *   markerVariables={passthrough1->marker1, require1->marker1, partition1->marker1, order1->marker1, passthrough2->marker2, require2->marker, unreferenced->marker2}
    + *   - Source (which produces passthrough1, require1, partition1, order1, passthrough2, require2, marker, hash, unreferenced)
    + * 
    + * is transformed into + *
    + * - TableFunctionProcessor
    + *   properOutputs=[proper]
    + *   passThroughSymbols=[[passthrough1],[passthrough2]]
    + *   requiredSymbols=[[require1], [require2]]
    + *   specification=[partition={[partition1]} orderby={[order1 ASC_NULLS_LAST]}]
    + *   hashSymbol=[hash]
    + *   markerVariables={passthrough1->marker1, require1->marker1, partition1->marker1, order1->marker1, passthrough2->marker2, require2->marker}
    + *   - Project
    + *     assignments=[passthrough1, require1, partition1, order1, passthrough2, require2, marker, hash]
    + *     - Source (which produces passthrough1, require1, partition1, order1, passthrough2, require2, marker, hash, unreferenced)
    + * 
    + */ +public class PruneTableFunctionProcessorSourceColumns + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (!node.getSource().isPresent()) { + return Result.empty(); + } + + ImmutableSet.Builder requiredInputs = ImmutableSet.builder(); + + node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .forEach(requiredInputs::add); + + node.getRequiredVariables() + .forEach(requiredInputs::addAll); + + node.getSpecification().ifPresent(specification -> { + requiredInputs.addAll(specification.getPartitionBy()); + specification.getOrderingScheme().ifPresent(orderingScheme -> requiredInputs.addAll(orderingScheme.getOrderByVariables())); + }); + + node.getHashSymbol().ifPresent(requiredInputs::add); + + Optional> updatedMarkerSymbols = node.getMarkerVariables() + .map(mapping -> filterKeys(mapping, requiredInputs.build()::contains)); + + updatedMarkerSymbols.ifPresent(mapping -> requiredInputs.addAll(mapping.values())); + + return restrictOutputs(context.getIdAllocator(), node.getSource().orElseThrow(NoSuchElementException::new), requiredInputs.build()) + .map(child -> Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(child), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + updatedMarkerSymbols, + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle()))) + .orElse(Result.empty()); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java index d948bad21bdbf..1d09e2699b238 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java @@ -31,6 +31,7 @@ import com.facebook.presto.spi.plan.SemiJoinNode; import com.facebook.presto.spi.plan.SortNode; import com.facebook.presto.spi.plan.TopNNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; @@ -40,7 +41,6 @@ import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; import com.google.common.collect.ImmutableList; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java index 10071de898254..967cc68238c79 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughExchange.java @@ -33,6 +33,7 @@ import java.util.Map; import java.util.Set; +import static com.facebook.presto.SystemSessionProperties.isSkipPushdownThroughExchangeForRemoteProjection; import static com.facebook.presto.matching.Capture.newCapture; import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; import static 
com.facebook.presto.sql.planner.plan.Patterns.exchange; @@ -80,6 +81,10 @@ public Pattern getPattern() public Result apply(ProjectNode project, Captures captures, Context context) { ExchangeNode exchange = captures.get(CHILD); + if (isSkipPushdownThroughExchangeForRemoteProjection(context.getSession()) && project.getLocality().equals(ProjectNode.Locality.REMOTE)) { + return Result.empty(); + } + Set partitioningColumns = exchange.getPartitioningScheme().getPartitioning().getVariableReferences(); ImmutableList.Builder newSourceBuilder = ImmutableList.builder(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RandomizeSourceKeyInSemiJoin.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RandomizeSourceKeyInSemiJoin.java new file mode 100644 index 0000000000000..d23fec7bc7f6d --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RandomizeSourceKeyInSemiJoin.java @@ -0,0 +1,195 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.expressions.LogicalRowExpressions; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SemiJoinNode; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.planner.PlannerUtils; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.google.common.collect.ImmutableMap; + +import java.util.Optional; +import java.util.stream.Stream; + +import static com.facebook.presto.SystemSessionProperties.getRandomizeNullSourceKeyInSemiJoinStrategy; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.metadata.CastType.CAST; +import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; +import static com.facebook.presto.sql.planner.plan.Patterns.semiJoin; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.Expressions.constantNull; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +/** + * Randomizes semi-join source keys to improve hash distribution and avoid data skew on NULL 
values in semi-join. + * + *

    The rule transforms a semi-join with potentially skewed keys: + *

    + * - SemiJoin
    + *      output: semi_output
    + *      condition: source_key = filtering_key
    + *      - Source (table scan)
    + *          source_key (value are skewed on NULLs)
    + *      - FilteringSource (table scan)
    + *          filtering_key
    + * 
    + *

    + * into a semi-join with randomized keys and NULL-aware logic: + *

    + * - Project
    + *      semi_output := new_semi_output OR (source_key IS NULL ? NULL : FALSE)
    + *      - SemiJoin
    + *          output: new_semi_output
    + *          condition: randomized_source_key = cast_filtering_key
    + *          - Project
    + *              randomized_source_key := COALESCE(source_key, Randomize(source_key))
    + *              - Source (table scan)
    + *                  source_key
    + *          - Project
    + *              cast_filtering_key := CAST(filtering_key AS VARCHAR)
    + *              - FilteringSource (table scan)
    + *                  filtering_key
    + * 
    + * Since the randomization will turn the semi join output for NULL source key to be false, we add one more projection + * semi_output := new_semi_output OR (source_key IS NULL ? NULL : FALSE) to project the semi join output back to NULL. + */ +public class RandomizeSourceKeyInSemiJoin + implements Rule +{ + private static final String LEFT_PREFIX = "l"; + private final FunctionAndTypeManager functionAndTypeManager; + + public RandomizeSourceKeyInSemiJoin(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + private static boolean isSupportedType(VariableReferenceExpression variable) + { + return Stream.of(INTEGER, BIGINT, DATE).anyMatch(x -> x.equals(variable.getType())); + } + + @Override + public Pattern getPattern() + { + return semiJoin(); + } + + @Override + public boolean isEnabled(Session session) + { + return getRandomizeNullSourceKeyInSemiJoinStrategy(session).equals(FeaturesConfig.RandomizeNullSourceKeyInSemiJoinStrategy.ALWAYS); + } + + @Override + public Result apply(SemiJoinNode node, Captures captures, Context context) + { + if ((node.getDistributionType().isPresent() && node.getDistributionType().get().equals(SemiJoinNode.DistributionType.REPLICATED))) { + return Result.empty(); + } + + // Only process supported types + if (!isSupportedType(node.getSourceJoinVariable()) || !isSupportedType(node.getFilteringSourceJoinVariable())) { + return Result.empty(); + } + + VariableReferenceExpression randomizedSourceKey; + RowExpression sourceRandomExpression = PlannerUtils.randomizeJoinKey(context.getSession(), functionAndTypeManager, node.getSourceJoinVariable(), LEFT_PREFIX); + randomizedSourceKey = context.getVariableAllocator().newVariable( + sourceRandomExpression, + RandomizeSourceKeyInSemiJoin.class.getSimpleName()); + + // Create project nodes to add randomized keys + Assignments.Builder sourceAssignments = Assignments.builder(); + 
sourceAssignments.putAll(node.getSource().getOutputVariables().stream() + .collect(toImmutableMap(identity(), identity()))); + sourceAssignments.put(randomizedSourceKey, sourceRandomExpression); + + ProjectNode newSource = new ProjectNode( + node.getSource().getSourceLocation(), + context.getIdAllocator().getNextId(), + node.getSource().getStatsEquivalentPlanNode(), + node.getSource(), + sourceAssignments.build(), + LOCAL); + + RowExpression newFilterExpression = call("CAST", functionAndTypeManager.lookupCast(CAST, node.getFilteringSourceJoinVariable().getType(), VARCHAR), VARCHAR, node.getFilteringSourceJoinVariable()); + VariableReferenceExpression newFilteringSourceKey = context.getVariableAllocator().newVariable(newFilterExpression); + + Assignments.Builder filteringSourceAssignments = Assignments.builder(); + filteringSourceAssignments.putAll(node.getFilteringSource().getOutputVariables().stream() + .collect(toImmutableMap(identity(), identity()))); + filteringSourceAssignments.put(newFilteringSourceKey, newFilterExpression); + + ProjectNode newFilteringSource = new ProjectNode( + node.getFilteringSource().getSourceLocation(), + context.getIdAllocator().getNextId(), + node.getFilteringSource().getStatsEquivalentPlanNode(), + node.getFilteringSource(), + filteringSourceAssignments.build(), + LOCAL); + + VariableReferenceExpression newSemiJoinOutput = context.getVariableAllocator().newVariable(node.getSemiJoinOutput()); + SemiJoinNode newSemiJoin = new SemiJoinNode( + node.getSourceLocation(), + context.getIdAllocator().getNextId(), + node.getStatsEquivalentPlanNode(), + newSource, + newFilteringSource, + randomizedSourceKey, + newFilteringSourceKey, + newSemiJoinOutput, + Optional.empty(), + Optional.empty(), + node.getDistributionType(), + ImmutableMap.of()); + RowExpression outputExpression = LogicalRowExpressions.or( + newSemiJoinOutput, + new SpecialFormExpression( + node.getSemiJoinOutput().getSourceLocation(), + SpecialFormExpression.Form.IF, + BOOLEAN, + 
new SpecialFormExpression( + SpecialFormExpression.Form.IS_NULL, + BOOLEAN, + node.getSourceJoinVariable()), + constantNull(BOOLEAN), + LogicalRowExpressions.FALSE_CONSTANT)); + Assignments.Builder outputAssignments = Assignments.builder(); + outputAssignments.putAll(node.getOutputVariables().stream().collect(toImmutableMap(identity(), x -> x.equals(node.getSemiJoinOutput()) ? outputExpression : x))); + return Result.ofPlanNode(new ProjectNode( + node.getSourceLocation(), + context.getIdAllocator().getNextId(), + node.getStatsEquivalentPlanNode(), + newSemiJoin, + outputAssignments.build(), + LOCAL)); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java index 428640b5fc556..7915157f14650 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveMapCastRule.java @@ -14,36 +14,50 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.Type; import com.facebook.presto.expressions.RowExpressionRewriter; import com.facebook.presto.expressions.RowExpressionTreeRewriter; +import com.facebook.presto.metadata.CastType; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.relational.FunctionResolution; +import com.google.common.collect.ImmutableList; import 
com.google.common.collect.ImmutableSet; +import java.util.List; import java.util.Set; import static com.facebook.presto.SystemSessionProperties.isRemoveMapCastEnabled; import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TypeUtils.readNativeValue; import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.castToInteger; +import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.tryCast; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; /** - * Remove cast on map if possible. Currently it only supports subscript and element_at function, and only works when map key is of type integer and index is bigint. For example: - * Input: cast(feature as map)[key], where feature is of type map and key is of type bigint + * Remove cast on map if possible. Currently it only supports subscript, element_at and map_subset function, and only works when map key is of type integer and index is bigint. For example: + * Input: cast(feature as map)[key], where feature is of type map and key is of type bigint * Output: feature[cast(key as integer)] - * - * Input: element_at(cast(feature as map), key), where feature is of type map and key is of type bigint + *

    + * Input: element_at(cast(feature as map), key), where feature is of type map and key is of type bigint * Output: element_at(feature, try_cast(key as integer)) - * + *

    + * Input: map_subset(cast(feature as map), array[k1, k2]) where feature is of type map and key is of type array + * Output: cast(map_subset(feature, array[try_cast(k1 as integer), try_cast(k2 as integer)]) as map) + * When k1, or k2 is out of integer range, try_cast will return NULL, where map_subset will not return values for this key, which is the same behavior for both input and output + *

    * Notice that here when it's accessing the map using subscript function, we use CAST function in index, and when it's element_at function, we use TRY_CAST function, so that * when the key is out of integer range, for feature[key] it will fail both with and without optimization, fail with map key not exists before optimization and with cast failure after optimization * when the key is out of integer range, for element_at(feature, key) it will return NULL both before and after optimization @@ -101,7 +115,8 @@ private RemoveMapCastRewriter(FunctionAndTypeManager functionAndTypeManager) @Override public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter treeRewriter) { - if ((functionResolution.isSubscriptFunction(node.getFunctionHandle()) || functionResolution.isElementAtFunction(node.getFunctionHandle())) && node.getArguments().get(0) instanceof CallExpression + if ((functionResolution.isSubscriptFunction(node.getFunctionHandle()) || functionResolution.isElementAtFunction(node.getFunctionHandle()) || functionResolution.isMapSubSetFunction(node.getFunctionHandle())) + && node.getArguments().get(0) instanceof CallExpression && functionResolution.isCastFunction(((CallExpression) node.getArguments().get(0)).getFunctionHandle()) && ((CallExpression) node.getArguments().get(0)).getArguments().get(0).getType() instanceof MapType) { CallExpression castExpression = (CallExpression) node.getArguments().get(0); @@ -116,18 +131,42 @@ public RowExpression rewriteCall(CallExpression node, Void context, RowExpressio RowExpression newIndex = castToInteger(functionAndTypeManager, node.getArguments().get(1)); return call(SUBSCRIPT.name(), functionResolution.subscriptFunction(castInput.getType(), newIndex.getType()), node.getType(), castInput, newIndex); } - else { + else if (functionResolution.isElementAtFunction(node.getFunctionHandle())) { RowExpression newIndex = tryCast(functionAndTypeManager, node.getArguments().get(1), INTEGER); return 
call(functionAndTypeManager, "element_at", node.getType(), castInput, newIndex); } + else if (functionResolution.isMapSubSetFunction(node.getFunctionHandle())) { + RowExpression newKeyArray = null; + if (node.getArguments().get(1) instanceof CallExpression && functionResolution.isArrayConstructor(((CallExpression) node.getArguments().get(1)).getFunctionHandle())) { + CallExpression arrayConstruct = (CallExpression) node.getArguments().get(1); + List newArguments = arrayConstruct.getArguments().stream().map(x -> tryCast(functionAndTypeManager, x, INTEGER)).collect(toImmutableList()); + newKeyArray = call(functionAndTypeManager, "array_constructor", new ArrayType(INTEGER), newArguments); + } + else if (node.getArguments().get(1) instanceof ConstantExpression) { + ConstantExpression constantArray = (ConstantExpression) node.getArguments().get(1); + checkState(constantArray.getValue() instanceof Block && constantArray.getType() instanceof ArrayType); + Block arrayValue = (Block) constantArray.getValue(); + Type arrayElementType = ((ArrayType) constantArray.getType()).getElementType(); + ImmutableList.Builder arguments = ImmutableList.builder(); + for (int i = 0; i < arrayValue.getPositionCount(); ++i) { + ConstantExpression mapKey = constant(readNativeValue(arrayElementType, arrayValue, i), arrayElementType); + arguments.add(tryCast(functionAndTypeManager, mapKey, INTEGER)); + } + newKeyArray = call(functionAndTypeManager, "array_constructor", new ArrayType(INTEGER), arguments.build()); + } + if (newKeyArray != null) { + CallExpression mapSubset = call(functionAndTypeManager, "map_subset", castInput.getType(), castInput, newKeyArray); + return call("CAST", functionAndTypeManager.lookupCast(CastType.CAST, mapSubset.getType(), node.getType()), node.getType(), mapSubset); + } + } } } return null; } - private static boolean canRemoveMapCast(Type fromKeyType, Type fromValueType, Type toKeyType, Type toValueType, Type indexType) + private static boolean canRemoveMapCast(Type 
fromKeyType, Type fromValueType, Type toKeyType, Type toValueType, Type subsetKeysType) { - return fromValueType.equals(toValueType) && fromKeyType.equals(INTEGER) && toKeyType.equals(BIGINT) && indexType.equals(BIGINT); + return fromValueType.equals(toValueType) && fromKeyType.equals(INTEGER) && toKeyType.equals(BIGINT) && (subsetKeysType.equals(BIGINT) || (subsetKeysType instanceof ArrayType && ((ArrayType) subsetKeysType).getElementType().equals(BIGINT))); } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunctionProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunctionProcessor.java new file mode 100644 index 0000000000000..48f58bce5952a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunctionProcessor.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.google.common.collect.ImmutableList; + +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMost; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; + +/** + * Table function can take multiple table arguments. Each argument is either "prune when empty" or "keep when empty". + * "Prune when empty" means that if this argument has no rows, the function result is empty, so the function can be + * removed from the plan, and replaced with empty values. + * "Keep when empty" means that even if the argument has no rows, the function should still be executed, and it can + * return a non-empty result. + * All the table arguments are combined into a single source of a TableFunctionProcessorNode. If either argument is + * "prune when empty", the overall result is "prune when empty". This rule removes a redundant TableFunctionProcessorNode + * based on the "prune when empty" property. 
+ */ +public class RemoveRedundantTableFunctionProcessor + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (node.isPruneWhenEmpty() && node.getSource().isPresent()) { + if (isAtMost(node.getSource().orElseThrow(NoSuchElementException::new), context.getLookup(), 0)) { + return Result.ofPlanNode( + new ValuesNode(node.getSourceLocation(), + node.getId(), + node.getOutputVariables(), + ImmutableList.of(), + Optional.empty())); + } + } + + return Result.empty(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java index 818450a21ca2c..a19747039e57a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java @@ -443,7 +443,7 @@ JoinCondition extractJoinConditions(List joinPredicates, for (RowExpression predicate : joinPredicates) { if (predicate instanceof CallExpression - && functionResolution.isEqualFunction(((CallExpression) predicate).getFunctionHandle()) + && functionResolution.isEqualsFunction(((CallExpression) predicate).getFunctionHandle()) && ((CallExpression) predicate).getArguments().size() == 2) { RowExpression argument0 = ((CallExpression) predicate).getArguments().get(0); RowExpression argument1 = ((CallExpression) predicate).getArguments().get(1); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReplaceConditionalApproxDistinct.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReplaceConditionalApproxDistinct.java new file mode 100644 index 
0000000000000..6079ed2bd45b7 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReplaceConditionalApproxDistinct.java @@ -0,0 +1,233 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.Map.Entry; + +import static com.facebook.presto.SystemSessionProperties.isOptimizeConditionalApproxDistinctEnabled; +import 
static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; +import static com.facebook.presto.sql.planner.plan.Patterns.project; +import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.sql.relational.Expressions.constantNull; +import static com.facebook.presto.sql.relational.Expressions.isNull; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +/** + * elimination of approx distinct on conditional constant values. + *

    + * depending on the inner conditional, the expression is converted + * to its equivalent arbitrary() expression. + * + * - approx_distinct(if(..., non-null)) -> arbitrary(if(..., 1, NULL)) + * - approx_distinct(if(..., null, non-null)) -> arbitrary(if(..., NULL, 1)) + * - approx_distinct(if(..., null, null)) -> arbitrary(0) + * + * An intermediate projection is inserted to convert any NULL arbitrary output + * to zero values. + */ +public class ReplaceConditionalApproxDistinct + implements Rule +{ + private static final Capture SOURCE = Capture.newCapture(); + + private static final Pattern PATTERN = aggregation() + .with(source().matching(project().capturedAs(SOURCE))); + + private final StandardFunctionResolution functionResolution; + + private static final String ARBITRARY = "arbitrary"; + + public ReplaceConditionalApproxDistinct(FunctionAndTypeManager functionAndTypeManager) + { + requireNonNull(functionAndTypeManager, "functionManager is null"); + this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); + } + + @Override + public boolean isEnabled(Session session) + { + return isOptimizeConditionalApproxDistinctEnabled(session); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(AggregationNode parent, Captures captures, Context context) + { + VariableAllocator variableAllocator = context.getVariableAllocator(); + boolean changed = false; + ProjectNode project = captures.get(SOURCE); + Assignments.Builder outputs = Assignments.builder(); + Assignments.Builder inputs = Assignments.builder(); + + ImmutableMap.Builder aggregations = ImmutableMap.builder(); + for (Entry entry : parent.getAggregations().entrySet()) { + VariableReferenceExpression variable = entry.getKey(); + AggregationNode.Aggregation aggregation = entry.getValue(); + SpecialFormExpression replaced; + VariableReferenceExpression intermediate; + VariableReferenceExpression expression; + + 
if (!isApproxDistinct(aggregation) || !aggregationIsReplaceable(aggregation, project.getAssignments())) { + aggregations.put(variable, aggregation); + outputs.put(variable, variable); + continue; + } + changed = true; + replaced = (SpecialFormExpression) project.getAssignments().get( + (VariableReferenceExpression) aggregation.getArguments().get(0)); + + expression = variableAllocator.newVariable("expression", BIGINT); + inputs.put(expression, replaceIfExpression(replaced)); + + intermediate = variableAllocator.newVariable("intermediate", BIGINT); + aggregations.put(intermediate, new AggregationNode.Aggregation( + new CallExpression( + aggregation.getCall().getSourceLocation(), + ARBITRARY, + functionResolution.arbitraryFunction(BIGINT), + BIGINT, + ImmutableList.of(expression)), + aggregation.getFilter(), + aggregation.getOrderBy(), + aggregation.isDistinct(), + aggregation.getMask())); + + outputs.put(variable, new SpecialFormExpression( + COALESCE, + BIGINT, + ImmutableList.of( + intermediate, + constant(0L, BIGINT)))); + } + + if (!changed) { + return Result.empty(); + } + + ProjectNode child = new ProjectNode( + project.getSourceLocation(), + context.getIdAllocator().getNextId(), + project.getSource(), + inputs.putAll(project.getAssignments()).build(), + project.getLocality()); + + AggregationNode aggregation = new AggregationNode( + parent.getSourceLocation(), + context.getIdAllocator().getNextId(), + child, + aggregations.build(), + parent.getGroupingSets(), + ImmutableList.of(), + parent.getStep(), + parent.getHashVariable(), + parent.getGroupIdVariable(), + parent.getAggregationId()); + + aggregation.getHashVariable().ifPresent(hashvariable -> outputs.put(hashvariable, hashvariable)); + aggregation.getGroupingSets().getGroupingKeys().forEach(groupingKey -> outputs.put(groupingKey, groupingKey)); + return Result.ofPlanNode(new ProjectNode( + context.getIdAllocator().getNextId(), + aggregation, + outputs.build())); + } + + private boolean 
isApproxDistinct(AggregationNode.Aggregation aggregation) + { + return functionResolution.isApproximateCountDistinctFunction(aggregation.getFunctionHandle()); + } + + private ConstantExpression convertConstant(ConstantExpression expression) + { + return isNull(expression) ? constantNull(BIGINT) : constant(1L, BIGINT); + } + + private RowExpression replaceIfExpression(SpecialFormExpression ifCondition) + { + ConstantExpression trueThen = (ConstantExpression) ifCondition.getArguments().get(1); + ConstantExpression falseThen = (ConstantExpression) ifCondition.getArguments().get(2); + RowExpression replace; + + if ((isNull(trueThen) && !isNull(falseThen)) || (!isNull(trueThen) && isNull(falseThen))) { + // if(..., null, non-null) or if(..., non-null, null) + replace = new SpecialFormExpression( + ifCondition.getSourceLocation(), + IF, + BIGINT, + ImmutableList.of( + ifCondition.getArguments().get(0), + convertConstant(trueThen), + convertConstant(falseThen))); + } + else { + // if(..., null, null) + checkState(isNull(trueThen) && isNull(falseThen), + "expected true (%s) and false (%s) predicates to be null", + trueThen, falseThen); + replace = convertConstant(trueThen); + } + return replace; + } + + private boolean aggregationIsReplaceable(AggregationNode.Aggregation aggregation, Assignments inputs) + { + RowExpression argument = aggregation.getArguments().get(0); + RowExpression ifCondition = null; + RowExpression trueThen = null; + RowExpression falseThen = null; + + if (argument instanceof VariableReferenceExpression) { + ifCondition = inputs.get((VariableReferenceExpression) argument); + } + + if (ifCondition instanceof SpecialFormExpression && ((SpecialFormExpression) ifCondition).getForm() == IF) { + trueThen = ((SpecialFormExpression) ifCondition).getArguments().get(1); + falseThen = ((SpecialFormExpression) ifCondition).getArguments().get(2); + } + + return trueThen instanceof ConstantExpression && + falseThen instanceof ConstantExpression && + 
(isNull(trueThen) || isNull(falseThen)); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteConstantArrayContainsToInExpression.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteConstantArrayContainsToInExpression.java index 9fefbaa902169..14796c655117b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteConstantArrayContainsToInExpression.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteConstantArrayContainsToInExpression.java @@ -32,7 +32,7 @@ import java.util.Set; -import static com.facebook.presto.SystemSessionProperties.isRwriteConstantArrayContainsToInExpressionEnabled; +import static com.facebook.presto.SystemSessionProperties.isRewriteConstantArrayContainsToInExpressionEnabled; import static com.facebook.presto.common.type.DateType.DATE; import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; import static com.facebook.presto.common.type.TimestampType.TIMESTAMP_MICROSECONDS; @@ -53,7 +53,7 @@ public RewriteConstantArrayContainsToInExpression(FunctionAndTypeManager functio @Override public boolean isRewriterEnabled(Session session) { - return isRwriteConstantArrayContainsToInExpressionEnabled(session); + return isRewriteConstantArrayContainsToInExpressionEnabled(session); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteExcludeColumnsFunctionToProjection.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteExcludeColumnsFunctionToProjection.java new file mode 100644 index 0000000000000..1857dafa9493c --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteExcludeColumnsFunctionToProjection.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.operator.table.ExcludeColumns.ExcludeColumnsFunctionHandle; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; + +import java.util.List; +import java.util.NoSuchElementException; + +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Iterators.getOnlyElement; +/** + * Rewrite a TableFunctionProcessorNode into a Project node if the table function is exclude_columns. + *

    + * - TableFunctionProcessorNode
    + *   properOutputs=[A, B]
    + *   passthroughColumns=[C, D]
    + *   - (input) plan which produces symbols [A, B, C, D]
    + * 
    + * into + *
    + * - Project
    + *   assignments={A, B, C, D}
    + *   - (input) plan which produces symbols [A, B, C, D]
    + * 
    + */ +public class RewriteExcludeColumnsFunctionToProjection + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (!(node.getHandle().getFunctionHandle() instanceof ExcludeColumnsFunctionHandle)) { + return Result.empty(); + } + + List inputSymbols = getOnlyElement(node.getRequiredVariables().iterator()); + List outputSymbols = node.getOutputVariables(); + + checkState(inputSymbols.size() == outputSymbols.size(), "inputSymbols size differs from outputSymbols size"); + Assignments.Builder assignments = Assignments.builder(); + for (int i = 0; i < outputSymbols.size(); i++) { + assignments.put(outputSymbols.get(i), inputSymbols.get(i)); + } + + return Result.ofPlanNode(new ProjectNode( + node.getId(), + node.getSource().orElseThrow(NoSuchElementException::new), + assignments.build())); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteRowExpressions.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteRowExpressions.java new file mode 100644 index 0000000000000..cfd7a421de0f2 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteRowExpressions.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.google.common.annotations.VisibleForTesting; + +import static com.facebook.presto.SystemSessionProperties.getExpressionOptimizerInRowExpressionRewrite; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static java.util.Objects.requireNonNull; + +/** + * A rule set that rewrites row expressions using a custom expression optimizer. + */ +public class RewriteRowExpressions + extends RowExpressionRewriteRuleSet +{ + public RewriteRowExpressions(ExpressionOptimizerManager expressionOptimizerManager) + { + super(new Rewriter(expressionOptimizerManager)); + } + + @Override + public boolean isRewriterEnabled(Session session) + { + return !getExpressionOptimizerInRowExpressionRewrite(session).isEmpty(); + } + + private static class Rewriter + implements PlanRowExpressionRewriter + { + private final ExpressionOptimizerManager expressionOptimizerManager; + + public Rewriter(ExpressionOptimizerManager expressionOptimizerManager) + { + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null"); + } + + @Override + public RowExpression rewrite(RowExpression expression, Rule.Context context) + { + return rewrite(expression, context.getSession()); + } + + private RowExpression rewrite(RowExpression expression, Session session) + { + if (getExpressionOptimizerInRowExpressionRewrite(session).isEmpty()) { + return expression; + } + ExpressionOptimizer optimizer = 
expressionOptimizerManager.getExpressionOptimizer(getExpressionOptimizerInRowExpressionRewrite(session)); + return optimizer.optimize(expression, OPTIMIZED, session.toConnectorSession()); + } + } + + @VisibleForTesting + public static RowExpression rewrite(RowExpression expression, Session session, ExpressionOptimizerManager expressionOptimizerManager) + { + return new Rewriter(expressionOptimizerManager).rewrite(expression, session); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java index e969326ecfff5..9938f9ea4c740 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RowExpressionRewriteRuleSet.java @@ -218,6 +218,9 @@ public Result apply(SpatialJoinNode spatialJoinNode, Captures captures, Context spatialJoinNode.getLeft(), spatialJoinNode.getRight(), spatialJoinNode.getOutputVariables(), + spatialJoinNode.getProbeGeometryVariable(), + spatialJoinNode.getBuildGeometryVariable(), + spatialJoinNode.getRadiusVariable(), rewritten, spatialJoinNode.getLeftPartitionVariable(), spatialJoinNode.getRightPartitionVariable(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionProcessorToTableScan.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionProcessorToTableScan.java new file mode 100644 index 0000000000000..6cbf4378b334b --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionProcessorToTableScan.java @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.google.common.collect.ImmutableMap; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.matching.Pattern.empty; +import static com.facebook.presto.sql.planner.plan.Patterns.sources; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +/* + * This rule converts connector-resolvable TableFunctionProcessorNodes into equivalent + * TableScanNodes by invoking the connector's applyTableFunction() method during query planning. + * + * It enables table-valued functions whose results can be represented as a ConnectorTableHandle + * to be treated like regular table scans, allowing them to benefit from standard scan optimizations. 
+ * + * Example: + * Before Transformation: + * TableFunction(my_function(arg1, arg2)) + * + * After Transformation: + * TableScan(my_function(arg1, arg2)) + * assignments: { + * outputVar1 -> my_function(arg1, arg2)_colHandle1, + * outputVar2 -> my_function(arg1, arg2)_colHandle2 + * } + */ +public class TransformTableFunctionProcessorToTableScan + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor() + .with(empty(sources())); + + private final Metadata metadata; + + public TransformTableFunctionProcessorToTableScan(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + Optional> result = metadata.applyTableFunction(context.getSession(), node.getHandle()); + + if (!result.isPresent()) { + return Result.empty(); + } + + List columnHandles = result.get().getColumnHandles(); + checkState(node.getOutputVariables().size() == columnHandles.size(), + "Connector returned %s columns but TableFunctionProcessorNode expects %s outputs", + columnHandles.size(), node.getOutputVariables().size()); + ImmutableMap.Builder assignments = ImmutableMap.builder(); + for (int i = 0; i < columnHandles.size(); i++) { + assignments.put(node.getOutputVariables().get(i), columnHandles.get(i)); + } + + return Result.ofPlanNode(new TableScanNode( + node.getSourceLocation(), + node.getId(), + result.get().getTableHandle(), + node.getOutputVariables(), + assignments.buildOrThrow(), + TupleDomain.all(), + TupleDomain.all(), Optional.empty())); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionToTableFunctionProcessor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionToTableFunctionProcessor.java new file mode 100644 index 
0000000000000..8d143ea0d006e --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/TransformTableFunctionToTableFunctionProcessor.java @@ -0,0 +1,1032 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.plan.WindowNode.Frame; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; 
+import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; + +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.spi.plan.JoinType.FULL; +import static com.facebook.presto.spi.plan.JoinType.INNER; +import static com.facebook.presto.spi.plan.JoinType.LEFT; +import static com.facebook.presto.spi.plan.JoinType.RIGHT; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.ROWS; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static 
com.facebook.presto.sql.planner.plan.Patterns.tableFunction; +import static com.facebook.presto.sql.relational.Expressions.coalesce; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +/** + * This rule prepares cartesian product of partitions + * from all inputs of table function. + *

    + * It rewrites a TableFunctionNode with potentially many sources + * into a TableFunctionProcessorNode. The new node has a single + * source that combines the original sources. + *

    + * The original sources are combined with joins. The join + * conditions depend on the prune when empty property, and on + * the co-partitioning of sources. + *

    + * The resulting source should be partitioned and ordered + * according to combined schemas from the component sources. + *

    + * Example transformation for two sources, both with set semantics + * and KEEP WHEN EMPTY property: + *

    + * - TableFunction foo
    + *      - source T1(a1, b1) PARTITION BY a1 ORDER BY b1
    + *      - source T2(a2, b2) PARTITION BY a2
    + * 
    + * Is transformed into: + *
    + * - TableFunctionDataProcessor foo
    + *      PARTITION BY (a1, a2), ORDER BY combined_row_number
    + *      - Project
    + *          marker_1 <= IF(table1_row_number = combined_row_number, table1_row_number, CAST(null AS bigint))
    + *          marker_2 <= IF(table2_row_number = combined_row_number, table2_row_number, CAST(null AS bigint))
    + *          - Project
    + *              combined_row_number <= IF(COALESCE(table1_row_number, BIGINT '-1') > COALESCE(table2_row_number, BIGINT '-1'), table1_row_number, table2_row_number)
    + *              combined_partition_size <= IF(COALESCE(table1_partition_size, BIGINT '-1') > COALESCE(table2_partition_size, BIGINT '-1'), table1_partition_size, table2_partition_size)
    + *              - FULL Join
    + *                  [table1_row_number = table2_row_number OR
    + *                   table1_row_number > table2_partition_size AND table2_row_number = BIGINT '1' OR
    + *                   table2_row_number > table1_partition_size AND table1_row_number = BIGINT '1']
    + *                  - Window [PARTITION BY a1 ORDER BY b1]
    + *                      table1_row_number <= row_number()
    + *                      table1_partition_size <= count()
    + *                          - source T1(a1, b1)
    + *                  - Window [PARTITION BY a2]
    + *                      table2_row_number <= row_number()
    + *                      table2_partition_size <= count()
    + *                          - source T2(a2, b2)
    + * 
    + */ +public class TransformTableFunctionToTableFunctionProcessor + implements Rule +{ + private static final Pattern PATTERN = tableFunction(); + private static final Frame FULL_FRAME = new Frame( + ROWS, + UNBOUNDED_PRECEDING, + Optional.empty(), + Optional.empty(), + UNBOUNDED_FOLLOWING, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + private static final DataOrganizationSpecification UNORDERED_SINGLE_PARTITION = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); + + private final Metadata metadata; + + public TransformTableFunctionToTableFunctionProcessor(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionNode node, Captures captures, Context context) + { + if (node.getSources().isEmpty()) { + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.empty(), + false, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + if (node.getSources().size() == 1) { + // Single source does not require pre-processing. + // If the source has row semantics, its specification is empty. + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // This property can be used later to choose optimal distribution. 
+ TableArgumentProperties sourceProperties = getOnlyElement(node.getTableArgumentProperties()); + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(getOnlyElement(node.getSources())), + sourceProperties.isPruneWhenEmpty(), + ImmutableList.of(sourceProperties.getPassThroughSpecification()), + ImmutableList.of(sourceProperties.getRequiredColumns()), + Optional.empty(), + sourceProperties.getSpecification(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + Map sources = mapSourcesByName(node.getSources(), node.getTableArgumentProperties()); + ImmutableList.Builder intermediateResultsBuilder = ImmutableList.builder(); + + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + + // Create call expression for row_number + FunctionHandle rowNumberFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("row_number")), + ImmutableList.of()); + + FunctionMetadata rowNumberFunctionMetadata = functionAndTypeManager.getFunctionMetadata(rowNumberFunctionHandle); + CallExpression rowNumberFunction = new CallExpression("row_number", rowNumberFunctionHandle, functionAndTypeManager.getType(rowNumberFunctionMetadata.getReturnType()), ImmutableList.of()); + + // Create call expression for count + FunctionHandle countFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("count")), + ImmutableList.of()); + + FunctionMetadata countFunctionMetadata = functionAndTypeManager.getFunctionMetadata(countFunctionHandle); + CallExpression countFunction = new CallExpression("count", 
countFunctionHandle, functionAndTypeManager.getType(countFunctionMetadata.getReturnType()), ImmutableList.of()); + + // handle co-partitioned sources + for (List copartitioningList : node.getCopartitioningLists()) { + List sourceList = copartitioningList.stream() + .map(sources::get) + .collect(toImmutableList()); + intermediateResultsBuilder.add(copartition(sourceList, rowNumberFunction, countFunction, context, metadata)); + } + + // prepare non-co-partitioned sources + Set copartitionedSources = node.getCopartitioningLists().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + sources.entrySet().stream() + .filter(entry -> !copartitionedSources.contains(entry.getKey())) + .map(entry -> planWindowFunctionsForSource(entry.getValue().source(), entry.getValue().properties(), rowNumberFunction, countFunction, context)) + .forEach(intermediateResultsBuilder::add); + + NodeWithVariables finalResultSource; + + List intermediateResultSources = intermediateResultsBuilder.build(); + if (intermediateResultSources.size() == 1) { + finalResultSource = getOnlyElement(intermediateResultSources); + } + else { + NodeWithVariables first = intermediateResultSources.get(0); + NodeWithVariables second = intermediateResultSources.get(1); + JoinedNodes joined = join(first, second, context, metadata); + + for (int i = 2; i < intermediateResultSources.size(); i++) { + NodeWithVariables joinedWithSymbols = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + joined = join(joinedWithSymbols, intermediateResultSources.get(i), context, metadata); + } + + finalResultSource = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + } + + // For each source, all source's output symbols are mapped to the source's row number symbol. + // The row number symbol will be later converted to a marker of "real" input rows vs "filler" input rows of the source. 
+ // The "filler" input rows are the rows appended while joining partitions of different lengths, + // to fill the smaller partition up to the bigger partition's size. They are a side effect of the algorithm, + // and should not be processed by the table function. + Map rowNumberSymbols = finalResultSource.rowNumberSymbolsMapping(); + + // The max row number symbol from all joined partitions. + VariableReferenceExpression finalRowNumberSymbol = finalResultSource.rowNumber(); + // Combined partitioning lists from all sources. + List finalPartitionBy = finalResultSource.partitionBy(); + + NodeWithMarkers marked = appendMarkerSymbols(finalResultSource.node(), ImmutableSet.copyOf(rowNumberSymbols.values()), finalRowNumberSymbol, context, metadata); + + // Remap the symbol mapping: replace the row number symbol with the corresponding marker symbol. + // In the new map, every source symbol is associated with the corresponding marker symbol. + // Null value of the marker indicates that the source value should be ignored by the table function. + ImmutableMap markerSymbols = rowNumberSymbols.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> marked.variableToMarker().get(entry.getValue()))); + + // Use the final row number symbol for ordering the combined sources. + // It runs along each partition in the cartesian product, numbering the partition's rows according to the expected ordering / orderings. + // note: ordering is necessary even if all the source tables are not ordered. Thanks to the ordering, the original rows + // of each input table come before the "filler" rows. 
+ ImmutableList.Builder newOrderings = ImmutableList.builder(); + newOrderings.add(new Ordering(finalRowNumberSymbol, ASC_NULLS_LAST)); + Optional finalOrderBy = Optional.of(new OrderingScheme(newOrderings.build())); + + // derive the prune when empty property + boolean pruneWhenEmpty = node.getTableArgumentProperties().stream().anyMatch(TableArgumentProperties::isPruneWhenEmpty); + + // Combine the pass through specifications from all sources + List passThroughSpecifications = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .collect(toImmutableList()); + + // Combine the required symbols from all sources + List> requiredVariables = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getRequiredColumns) + .collect(toImmutableList()); + + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(marked.node()), + pruneWhenEmpty, + passThroughSpecifications, + requiredVariables, + Optional.of(markerSymbols), + Optional.of(new DataOrganizationSpecification(finalPartitionBy, finalOrderBy)), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + private static Map mapSourcesByName(List sources, List properties) + { + return Streams.zip(sources.stream(), properties.stream(), SourceWithProperties::new) + .collect(toImmutableMap(entry -> entry.properties().getArgumentName(), identity())); + } + + private static NodeWithVariables planWindowFunctionsForSource( + PlanNode source, + TableArgumentProperties argumentProperties, + CallExpression rowNumberFunction, + CallExpression countFunction, + Context context) + { + String argumentName = argumentProperties.getArgumentName(); + + VariableReferenceExpression rowNumber = context.getVariableAllocator().newVariable(argumentName + "_row_number", BIGINT); + Map rowNumberSymbolMapping = source.getOutputVariables().stream() + .collect(toImmutableMap(identity(), 
symbol -> rowNumber)); + + VariableReferenceExpression partitionSize = context.getVariableAllocator().newVariable(argumentName + "_partition_size", BIGINT); + + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // If the source has row semantics, its specification is empty. Currently, such source is processed + // as if it was a single partition. Alternatively, it could be split into smaller partitions of arbitrary size. + DataOrganizationSpecification specification = argumentProperties.getSpecification().orElse(UNORDERED_SINGLE_PARTITION); + + PlanNode innerWindow = new WindowNode( + source.getSourceLocation(), + context.getIdAllocator().getNextId(), + source, + specification, + ImmutableMap.of( + rowNumber, new WindowNode.Function(rowNumberFunction, FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + PlanNode window = new WindowNode( + innerWindow.getSourceLocation(), + context.getIdAllocator().getNextId(), + innerWindow, + specification, + ImmutableMap.of( + partitionSize, new WindowNode.Function(countFunction, FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + + return new NodeWithVariables(window, rowNumber, partitionSize, specification.getPartitionBy(), argumentProperties.isPruneWhenEmpty(), rowNumberSymbolMapping); + } + + private static NodeWithVariables copartition( + List sourceList, + CallExpression rowNumberFunction, + CallExpression countFunction, + Context context, + Metadata metadata) + { + checkArgument(sourceList.size() >= 2, "co-partitioning list should contain at least two tables"); + + // Reorder the co-partitioned sources to process the sources with prune when empty property first. + // It allows to use inner or side joins instead of outer joins. + sourceList = sourceList.stream() + .sorted(Comparator.comparingInt(source -> source.properties().isPruneWhenEmpty() ? 
-1 : 1)) + .collect(toImmutableList()); + + NodeWithVariables first = planWindowFunctionsForSource(sourceList.get(0).source(), sourceList.get(0).properties(), rowNumberFunction, countFunction, context); + NodeWithVariables second = planWindowFunctionsForSource(sourceList.get(1).source(), sourceList.get(1).properties(), rowNumberFunction, countFunction, context); + JoinedNodes copartitioned = copartition(first, second, context, metadata); + + for (int i = 2; i < sourceList.size(); i++) { + NodeWithVariables copartitionedWithSymbols = appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata); + NodeWithVariables next = planWindowFunctionsForSource(sourceList.get(i).source(), sourceList.get(i).properties(), rowNumberFunction, countFunction, context); + copartitioned = copartition(copartitionedWithSymbols, next, context, metadata); + } + + return appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata); + } + + private static JoinedNodes copartition(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata) + { + checkArgument(left.partitionBy().size() == right.partitionBy().size(), "co-partitioning lists do not match"); + + // In StatementAnalyzer we require that co-partitioned tables have non-empty partitioning column lists. + // Co-partitioning tables with empty partition by would be ineffective. 
+ checkState(!left.partitionBy().isEmpty(), "co-partitioned tables must have partitioning columns"); + + FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + + Optional copartitionConjuncts = Streams.zip( + left.partitionBy.stream(), + right.partitionBy.stream(), + (leftColumn, rightColumn) -> new CallExpression("NOT", + functionResolution.notFunction(), + BOOLEAN, + ImmutableList.of( + new CallExpression(IS_DISTINCT_FROM.name(), + functionResolution.comparisonFunction(IS_DISTINCT_FROM, leftColumn.getType(), rightColumn.getType()), + BOOLEAN, + ImmutableList.of(leftColumn, rightColumn))))) + .map(expr -> expr) + .reduce((expr, conjunct) -> new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of(expr, conjunct))); + + // Align matching partitions (co-partitions) from left and right source, according to row number. + // Matching partitions are identified by their corresponding partitioning columns being NOT DISTINCT from each other. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. + // It preserves the outstanding rows from the bigger partition, matching them to the first row from the smaller partition. + // + // (P1_1 IS NOT DISTINCT FROM P2_1) AND (P1_2 IS NOT DISTINCT FROM P2_2) AND ... 
+ // AND ( + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1)) + + SpecialFormExpression orExpression = new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.rowNumber)), + new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))), + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, left.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT))))))))); + RowExpression joinCondition = copartitionConjuncts.map( + conjunct -> new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of(conjunct, orExpression))) + .orElse(orExpression); + + // The join type depends on the prune when empty property of the sources. + // If a source is prune when empty, we should not process any co-partition which is not present in this source, + // so effectively the other source becomes inner side of the join. 
+ // + // example: + // table T1 partition by P1 table T2 partition by P2 + // P1 C1 P2 C2 + // ---------- ---------- + // 1 'a' 2 'c' + // 2 'b' 3 'd' + // + // co-partitioning results: + // 1) T1 is prune when empty: do LEFT JOIN to drop co-partition '3' + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // + // 2) T2 is prune when empty: do RIGHT JOIN to drop co-partition '1' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // null null 3 'd' + // + // 3) T1 and T2 are both prune when empty: do INNER JOIN to drop co-partitions '1' and '3' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // + // 4) neither table is prune when empty: do FULL JOIN to preserve all co-partitions + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // null null 3 'd' + JoinType joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + Optional.empty(), + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + Stream.concat(left.node().getOutputVariables().stream(), + right.node().getOutputVariables().stream()) + .collect(Collectors.toList()), + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithVariables appendHelperSymbolsForCopartitionedNodes( + JoinedNodes copartitionedNodes, + Context context, + Metadata metadata) + { + 
checkArgument(copartitionedNodes.leftPartitionBy().size() == copartitionedNodes.rightPartitionBy().size(), "co-partitioning lists do not match"); + + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftRowNumber(), + copartitionedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. 
+ VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftPartitionSize(), + copartitionedNodes.rightPartitionSize())); + + // Derive partitioning columns for joined partitions. + // Either the combined partitioning columns are pairwise NOT DISTINCT (this is the co-partitioning rule), + // or one of them is null as a result of outer join. 
+ ImmutableList.Builder joinedPartitionBy = ImmutableList.builder(); + Assignments.Builder joinedPartitionByAssignments = Assignments.builder(); + for (int i = 0; i < copartitionedNodes.leftPartitionBy().size(); i++) { + VariableReferenceExpression leftColumn = copartitionedNodes.leftPartitionBy().get(i); + VariableReferenceExpression rightColumn = copartitionedNodes.rightPartitionBy().get(i); + Type type = context.getVariableAllocator().getVariables().get(leftColumn.getName()); + + VariableReferenceExpression joinedColumn = context.getVariableAllocator().newVariable("combined_partition_column", type); + joinedPartitionByAssignments.put(joinedColumn, coalesce(leftColumn, rightColumn)); + joinedPartitionBy.add(joinedColumn); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + copartitionedNodes.joinedNode(), + Assignments.builder() + .putAll( + copartitionedNodes.joinedNode().getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .putAll(joinedPartitionByAssignments.build()) + .build()); + boolean joinedPruneWhenEmpty = copartitionedNodes.leftPruneWhenEmpty() || copartitionedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(copartitionedNodes.leftRowNumberSymbolsMapping()) + .putAll(copartitionedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy.build(), joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static JoinedNodes join(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata) + { + // Align rows from left and right source according to row number. Because every partition is row-numbered, this produces cartesian product of partitions. 
+ // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. It preserves the outstanding rows + // from the bigger partition, matching them to the first row from the smaller partition. + // + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1) + + FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + RowExpression joinCondition = new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.rowNumber)), + new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))), + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, left.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT))))))))); + JoinType joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + 
} + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + Optional.empty(), + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + Stream.concat(left.node().getOutputVariables().stream(), + right.node().getOutputVariables().stream()) + .collect(Collectors.toList()), + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithVariables appendHelperSymbolsForJoinedNodes(JoinedNodes joinedNodes, Context context, Metadata metadata) + { + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftRowNumber(), + joinedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. 
+ VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftPartitionSize(), + joinedNodes.rightPartitionSize())); + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + joinedNodes.joinedNode(), + Assignments.builder() + .putAll( + joinedNodes.joinedNode().getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .build()); + + List joinedPartitionBy = ImmutableList.builder() + .addAll(joinedNodes.leftPartitionBy()) + .addAll(joinedNodes.rightPartitionBy()) + .build(); + + boolean joinedPruneWhenEmpty = joinedNodes.leftPruneWhenEmpty() || joinedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(joinedNodes.leftRowNumberSymbolsMapping()) + .putAll(joinedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy, joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static NodeWithMarkers appendMarkerSymbols(PlanNode node, Set variables, VariableReferenceExpression referenceSymbol, Context context, Metadata metadata) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putAll( + node.getOutputVariables().stream() + 
.collect(toImmutableMap(v -> v, v -> v))); + + ImmutableMap.Builder variablesToMarkers = ImmutableMap.builder(); + + for (VariableReferenceExpression variable : variables) { + VariableReferenceExpression marker = context.getVariableAllocator().newVariable("marker", BIGINT); + variablesToMarkers.put(variable, marker); + RowExpression ifExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + EQUAL.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.EQUAL, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of(variable, referenceSymbol)), + variable, + new ConstantExpression(null, BIGINT))); + assignments.put(marker, ifExpression); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + node, + assignments.build()); + + return new NodeWithMarkers(project, variablesToMarkers.buildOrThrow()); + } + + private static class SourceWithProperties + { + private final PlanNode source; + private final TableArgumentProperties properties; + + public SourceWithProperties(PlanNode source, TableArgumentProperties properties) + { + this.source = requireNonNull(source, "source is null"); + this.properties = requireNonNull(properties, "properties is null"); + } + + public PlanNode source() + { + return source; + } + + public TableArgumentProperties properties() + { + return properties; + } + } + + public static final class NodeWithVariables + { + private final PlanNode node; + private final VariableReferenceExpression rowNumber; + private final VariableReferenceExpression partitionSize; + private final List partitionBy; + private final boolean pruneWhenEmpty; + private final Map rowNumberSymbolsMapping; + + public NodeWithVariables(PlanNode node, VariableReferenceExpression rowNumber, VariableReferenceExpression partitionSize, + List partitionBy, boolean pruneWhenEmpty, + Map rowNumberSymbolsMapping) + { + this.node = requireNonNull(node, "node is null"); + this.rowNumber = 
requireNonNull(rowNumber, "rowNumber is null"); + this.partitionSize = requireNonNull(partitionSize, "partitionSize is null"); + this.partitionBy = ImmutableList.copyOf(partitionBy); + this.pruneWhenEmpty = pruneWhenEmpty; + this.rowNumberSymbolsMapping = ImmutableMap.copyOf(rowNumberSymbolsMapping); + } + + public PlanNode node() + { + return node; + } + + public VariableReferenceExpression rowNumber() + { + return rowNumber; + } + + public VariableReferenceExpression partitionSize() + { + return partitionSize; + } + + public List partitionBy() + { + return partitionBy; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public Map rowNumberSymbolsMapping() + { + return rowNumberSymbolsMapping; + } + } + + private static class JoinedNodes + { + private final PlanNode joinedNode; + private final VariableReferenceExpression leftRowNumber; + private final VariableReferenceExpression leftPartitionSize; + private final List leftPartitionBy; + private final boolean leftPruneWhenEmpty; + private final Map leftRowNumberSymbolsMapping; + private final VariableReferenceExpression rightRowNumber; + private final VariableReferenceExpression rightPartitionSize; + private final List rightPartitionBy; + private final boolean rightPruneWhenEmpty; + private final Map rightRowNumberSymbolsMapping; + + public JoinedNodes( + PlanNode joinedNode, + VariableReferenceExpression leftRowNumber, + VariableReferenceExpression leftPartitionSize, + List leftPartitionBy, + boolean leftPruneWhenEmpty, + Map leftRowNumberSymbolsMapping, + VariableReferenceExpression rightRowNumber, + VariableReferenceExpression rightPartitionSize, + List rightPartitionBy, + boolean rightPruneWhenEmpty, + Map rightRowNumberSymbolsMapping) + { + this.joinedNode = requireNonNull(joinedNode, "joinedNode is null"); + this.leftRowNumber = requireNonNull(leftRowNumber, "leftRowNumber is null"); + this.leftPartitionSize = requireNonNull(leftPartitionSize, "leftPartitionSize is null"); + 
this.leftPartitionBy = ImmutableList.copyOf(requireNonNull(leftPartitionBy, "leftPartitionBy is null")); + this.leftPruneWhenEmpty = leftPruneWhenEmpty; + this.leftRowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(leftRowNumberSymbolsMapping, "leftRowNumberSymbolsMapping is null")); + this.rightRowNumber = requireNonNull(rightRowNumber, "rightRowNumber is null"); + this.rightPartitionSize = requireNonNull(rightPartitionSize, "rightPartitionSize is null"); + this.rightPartitionBy = ImmutableList.copyOf(requireNonNull(rightPartitionBy, "rightPartitionBy is null")); + this.rightPruneWhenEmpty = rightPruneWhenEmpty; + this.rightRowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(rightRowNumberSymbolsMapping, "rightRowNumberSymbolsMapping is null")); + } + + public PlanNode joinedNode() + { + return joinedNode; + } + public VariableReferenceExpression leftRowNumber() + { + return leftRowNumber; + } + public VariableReferenceExpression leftPartitionSize() + { + return leftPartitionSize; + } + public List leftPartitionBy() + { + return leftPartitionBy; + } + public boolean leftPruneWhenEmpty() + { + return leftPruneWhenEmpty; + } + public Map leftRowNumberSymbolsMapping() + { + return leftRowNumberSymbolsMapping; + } + public VariableReferenceExpression rightRowNumber() + { + return rightRowNumber; + } + public VariableReferenceExpression rightPartitionSize() + { + return rightPartitionSize; + } + public List rightPartitionBy() + { + return rightPartitionBy; + } + public boolean rightPruneWhenEmpty() + { + return rightPruneWhenEmpty; + } + public Map rightRowNumberSymbolsMapping() + { + return rightRowNumberSymbolsMapping; + } + } + + private static class NodeWithMarkers + { + private final PlanNode node; + private final Map variableToMarker; + + public NodeWithMarkers(PlanNode node, Map variableToMarker) + { + this.node = requireNonNull(node, "node is null"); + this.variableToMarker = ImmutableMap.copyOf(requireNonNull(variableToMarker, 
"symbolToMarker is null")); + } + + public PlanNode node() + { + return node; + } + + public Map variableToMarker() + { + return variableToMarker; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java index 9660296d02931..f562d79b50c32 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ActualProperties.java @@ -25,8 +25,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.Collection; import java.util.HashMap; @@ -50,6 +49,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.transform; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class ActualProperties @@ -57,11 +57,22 @@ public class ActualProperties private final Global global; private final List> localProperties; private final Map constants; + // Used to track the properties of the unique row_id + private final Optional propertiesFromUniqueColumn; private ActualProperties( Global global, List> localProperties, Map constants) + { + this(global, localProperties, constants, Optional.empty()); + } + + private ActualProperties( + Global global, + List> localProperties, + Map constants, + Optional propertiesFromUniqueColumn) { requireNonNull(global, "globalProperties is null"); requireNonNull(localProperties, "localProperties is null"); @@ -86,6 +97,8 @@ private ActualProperties( this.localProperties = 
ImmutableList.copyOf(updatedLocalProperties); this.constants = ImmutableMap.copyOf(constants); + propertiesFromUniqueColumn.ifPresent(actualProperties -> checkArgument(!actualProperties.getPropertiesFromUniqueColumn().isPresent())); + this.propertiesFromUniqueColumn = propertiesFromUniqueColumn; } public boolean isCoordinatorOnly() @@ -93,6 +106,11 @@ public boolean isCoordinatorOnly() return global.isCoordinatorOnly(); } + public Optional getPropertiesFromUniqueColumn() + { + return propertiesFromUniqueColumn; + } + /** * @return true if the plan will only execute on a single node */ @@ -121,6 +139,16 @@ public boolean isStreamPartitionedOn(Collection col } } + public boolean isStreamPartitionedOnAdditionalProperty(Collection columns, boolean exactly) + { + if (exactly) { + return propertiesFromUniqueColumn.isPresent() && propertiesFromUniqueColumn.get().global.isStreamPartitionedOnExactly(columns, ImmutableSet.of(), false); + } + else { + return propertiesFromUniqueColumn.isPresent() && propertiesFromUniqueColumn.get().global.isStreamPartitionedOn(columns, ImmutableSet.of(), false); + } + } + public boolean isNodePartitionedOn(Collection columns, boolean exactly) { return isNodePartitionedOn(columns, false, exactly); @@ -136,6 +164,16 @@ public boolean isNodePartitionedOn(Collection colum } } + public boolean isNodePartitionedOnAdditionalProperty(Collection columns, boolean exactly) + { + if (exactly) { + return propertiesFromUniqueColumn.isPresent() && propertiesFromUniqueColumn.get().global.isNodePartitionedOnExactly(columns, ImmutableSet.of(), false); + } + else { + return propertiesFromUniqueColumn.isPresent() && propertiesFromUniqueColumn.get().global.isNodePartitionedOn(columns, ImmutableSet.of(), false); + } + } + @Deprecated public boolean isCompatibleTablePartitioningWith(Partitioning partitioning, boolean nullsAndAnyReplicated, Metadata metadata, Session session) { @@ -195,6 +233,13 @@ public ActualProperties translateVariable(Function 
newAdditionalProperty = Optional.empty(); + if (propertiesFromUniqueColumn.isPresent()) { + ActualProperties translatedAdditionalProperty = propertiesFromUniqueColumn.get().translateVariable(translator); + if (!translatedAdditionalProperty.getLocalProperties().isEmpty()) { + newAdditionalProperty = Optional.of(translatedAdditionalProperty); + } + } return builder() .global(global.translateVariableToRowExpression(variable -> { Optional translated = translator.apply(variable).map(RowExpression.class::cast); @@ -205,6 +250,7 @@ public ActualProperties translateVariable(Function !inputToOutputVariables.containsKey(entry.getKey())) .forEach(inputToOutputMappings::put); + + Optional newAdditionalProperty = Optional.empty(); + if (propertiesFromUniqueColumn.isPresent()) { + ActualProperties translatedAdditionalProperty = propertiesFromUniqueColumn.get().translateRowExpression(assignments); + if (!translatedAdditionalProperty.getLocalProperties().isEmpty()) { + newAdditionalProperty = Optional.of(translatedAdditionalProperty); + } + } + return builder() .global(global.translateRowExpression(inputToOutputMappings.build(), assignments)) .local(LocalProperties.translate(localProperties, variable -> Optional.ofNullable(inputToOutputVariables.get(variable)))) .constants(translatedConstants) + .propertiesFromUniqueColumn(newAdditionalProperty) .build(); } @@ -275,6 +331,7 @@ public static class Builder private List> localProperties; private Map constants; private boolean unordered; + private Optional propertiesFromUniqueColumn; public Builder() { @@ -282,10 +339,16 @@ public Builder() } public Builder(Global global, List> localProperties, Map constants) + { + this(global, localProperties, constants, Optional.empty()); + } + + public Builder(Global global, List> localProperties, Map constants, Optional propertiesFromUniqueColumn) { this.global = requireNonNull(global, "global is null"); this.localProperties = ImmutableList.copyOf(localProperties); this.constants = 
ImmutableMap.copyOf(constants); + this.propertiesFromUniqueColumn = propertiesFromUniqueColumn; } public Builder global(Global global) @@ -318,6 +381,18 @@ public Builder unordered(boolean unordered) return this; } + public Builder propertiesFromUniqueColumn(Optional propertiesFromUniqueColumn) + { + if (propertiesFromUniqueColumn.isPresent() && !propertiesFromUniqueColumn.get().getLocalProperties().isEmpty()) { + checkArgument(propertiesFromUniqueColumn.get().getLocalProperties().size() == 1); + this.propertiesFromUniqueColumn = propertiesFromUniqueColumn; + } + else { + this.propertiesFromUniqueColumn = Optional.empty(); + } + return this; + } + public ActualProperties build() { List> localProperties = this.localProperties; @@ -333,14 +408,19 @@ public ActualProperties build() } localProperties = newLocalProperties.build(); } - return new ActualProperties(global, localProperties, constants); + if (propertiesFromUniqueColumn.isPresent() && unordered) { + propertiesFromUniqueColumn = Optional.of(ActualProperties.builderFrom(propertiesFromUniqueColumn.get()) + .unordered(unordered) + .build()); + } + return new ActualProperties(global, localProperties, constants, propertiesFromUniqueColumn); } } @Override public int hashCode() { - return Objects.hash(global, localProperties, constants.keySet()); + return Objects.hash(global, localProperties, constants.keySet(), propertiesFromUniqueColumn); } @Override @@ -355,7 +435,8 @@ public boolean equals(Object obj) final ActualProperties other = (ActualProperties) obj; return Objects.equals(this.global, other.global) && Objects.equals(this.localProperties, other.localProperties) - && Objects.equals(this.constants.keySet(), other.constants.keySet()); + && Objects.equals(this.constants.keySet(), other.constants.keySet()) + && Objects.equals(this.propertiesFromUniqueColumn, other.propertiesFromUniqueColumn); } @Override @@ -365,6 +446,7 @@ public String toString() .add("globalProperties", global) .add("localProperties", 
localProperties) .add("constants", constants) + .add("propertiesFromUniqueColumn", propertiesFromUniqueColumn) .toString(); } @@ -392,7 +474,7 @@ private Global(Optional nodePartitioning, Optional s || !streamPartitioning.isPresent() || nodePartitioning.get().getVariableReferences().containsAll(streamPartitioning.get().getVariableReferences()) || streamPartitioning.get().getVariableReferences().containsAll(nodePartitioning.get().getVariableReferences()), - "Global stream partitioning columns should match node partitioning columns"); + format("Global stream partitioning columns should match node partitioning columns, nodePartitioning: %s, streamPartitioning: %s", nodePartitioning, streamPartitioning)); this.nodePartitioning = requireNonNull(nodePartitioning, "nodePartitioning is null"); this.streamPartitioning = requireNonNull(streamPartitioning, "streamPartitioning is null"); this.nullsAndAnyReplicated = nullsAndAnyReplicated; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 7fe4f5d08937a..083b88040db61 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -18,7 +18,9 @@ import com.facebook.presto.connector.system.GlobalSystemConnector; import com.facebook.presto.execution.QueryManagerConfig.ExchangeMaterializationStrategy; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.GroupingProperty; import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.PrestoException; @@ -32,12 +34,14 @@ import com.facebook.presto.spi.plan.DistinctLimitNode; import 
com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.InputDistribution; import com.facebook.presto.spi.plan.JoinDistributionType; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; import com.facebook.presto.spi.plan.PartitioningHandle; @@ -53,30 +57,36 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.statistics.TableStatistics; import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationPartitioningMergingStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.PartialMergePushdownStrategy; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.PreferredProperties.PartitioningProperties; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ChildReplacer; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExchangeNode.Scope; 
import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; +import com.facebook.presto.sql.planner.plan.UpdateNode; import com.google.common.annotations.VisibleForTesting; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; @@ -96,9 +106,12 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; import java.util.function.Function; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; import java.util.stream.Stream; import static com.facebook.presto.SystemSessionProperties.getAggregationPartitioningMergingStrategy; @@ -106,6 +119,10 @@ import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount; import static com.facebook.presto.SystemSessionProperties.getPartialMergePushdownStrategy; import static com.facebook.presto.SystemSessionProperties.getPartitioningProviderCatalog; +import static com.facebook.presto.SystemSessionProperties.getRemoteFunctionFixedParallelismTaskCount; +import static com.facebook.presto.SystemSessionProperties.getRemoteFunctionNamesForFixedParallelism; +import static 
com.facebook.presto.SystemSessionProperties.getTableScanShuffleParallelismThreshold; +import static com.facebook.presto.SystemSessionProperties.getTableScanShuffleStrategy; import static com.facebook.presto.SystemSessionProperties.getTaskPartitionedWriterCount; import static com.facebook.presto.SystemSessionProperties.isAddPartialNodeForRowNumberWithLimit; import static com.facebook.presto.SystemSessionProperties.isColocatedJoinEnabled; @@ -125,6 +142,9 @@ import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.plan.ExchangeEncoding.COLUMNAR; import static com.facebook.presto.spi.plan.LimitNode.Step.PARTIAL; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.ShuffleForTableScanStrategy.ALWAYS_ENABLED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.ShuffleForTableScanStrategy.COST_BASED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.ShuffleForTableScanStrategy.DISABLED; import static com.facebook.presto.sql.planner.FragmentTableScanCounter.getNumberOfTableScans; import static com.facebook.presto.sql.planner.FragmentTableScanCounter.hasMultipleTableScans; import static com.facebook.presto.sql.planner.PlannerUtils.containsSystemTableScan; @@ -207,6 +227,7 @@ private class Rewriter private final ExchangeMaterializationStrategy exchangeMaterializationStrategy; private final PartitioningProviderManager partitioningProviderManager; private final boolean nativeExecution; + private boolean isDeleteOrUpdateQuery; public Rewriter( PlanNodeIdAllocator idAllocator, @@ -243,8 +264,31 @@ public PlanWithProperties visitProject(ProjectNode node, PreferredProperties pre { Map identities = computeIdentityTranslations(node.getAssignments()); PreferredProperties translatedPreferred = preferredProperties.translate(symbol -> Optional.ofNullable(identities.get(symbol))); + PlanWithProperties planWithProperties = planChild(node, translatedPreferred); + + if 
(node.getLocality().equals(ProjectNode.Locality.REMOTE)) { + String functionNameRegex = getRemoteFunctionNamesForFixedParallelism(session); + if (!functionNameRegex.isEmpty()) { + Pattern pattern; + try { + pattern = Pattern.compile(functionNameRegex); + } + catch (PatternSyntaxException e) { + return rebaseAndDeriveProperties(node, planWithProperties); + } + if (node.getAssignments().getExpressions().stream().filter(x -> x instanceof CallExpression) + .anyMatch(x -> pattern.matcher(((CallExpression) x).getFunctionHandle().getName()).matches())) { + int taskCount = getRemoteFunctionFixedParallelismTaskCount(session); + checkState(taskCount > 0, "taskCount should be larger than 0"); + PlanNode newNode = roundRobinExchange(idAllocator.getNextId(), REMOTE_STREAMING, planWithProperties.getNode(), taskCount); + newNode = ChildReplacer.replaceChildren(node, ImmutableList.of(newNode)); + newNode = roundRobinExchange(idAllocator.getNextId(), REMOTE_STREAMING, newNode); + return new PlanWithProperties(newNode, derivePropertiesRecursively(newNode)); + } + } + } - return rebaseAndDeriveProperties(node, planChild(node, translatedPreferred)); + return rebaseAndDeriveProperties(node, planWithProperties); } @Override @@ -310,7 +354,8 @@ public PlanWithProperties visitAggregation(AggregationNode node, PreferredProper child.getProperties()); } else if (hasMixedGroupingSets - || !isStreamPartitionedOn(child.getProperties(), partitioningRequirement) && !isNodePartitionedOn(child.getProperties(), partitioningRequirement)) { + || !isStreamPartitionedOn(child.getProperties(), partitioningRequirement) && !isNodePartitionedOn(child.getProperties(), partitioningRequirement) + && !isNodePartitionedOnAdditionalProperty(child.getProperties(), partitioningRequirement) && !isStreamPartitionedOnAdditionalProperty(child.getProperties(), partitioningRequirement)) { child = withDerivedProperties( partitionedExchange( idAllocator.getNextId(), @@ -409,6 +454,64 @@ public PlanWithProperties 
visitWindow(WindowNode node, PreferredProperties prefe return rebaseAndDeriveProperties(node, child); } + @Override + public PlanWithProperties visitTableFunction(TableFunctionNode node, PreferredProperties preferredProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, PreferredProperties preferredProperties) + { + if (!node.getSource().isPresent()) { + return new PlanWithProperties(node, deriveProperties(node, ImmutableList.of())); + } + + if (!node.getSpecification().isPresent()) { + // node.getSpecification.isEmpty() indicates that there were no sources or a single source with row semantics. + // The case of no sources was addressed above. + // The case of a single source with row semantics is addressed here. A single source with row semantics can be distributed arbitrarily. + PlanWithProperties child = planChild(node, PreferredProperties.any()); + return rebaseAndDeriveProperties(node, child); + } + + List partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification().orElseThrow(NoSuchElementException::new) + .getOrderingScheme() + .ifPresent(orderingScheme -> + orderingScheme.getOrderByVariables().stream() + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) + .forEach(desiredProperties::add)); + + PlanWithProperties child = planChild(node, PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(partitionBy), desiredProperties)); + + // TODO do not gather if already gathered + if (!node.isPruneWhenEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } 
+ else if (!isStreamPartitionedOn(child.getProperties(), partitionBy) && + !isNodePartitionedOn(child.getProperties(), partitionBy)) { + if (partitionBy.isEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } + else { + child = withDerivedProperties( + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode(), Partitioning.create(FIXED_HASH_DISTRIBUTION, partitionBy), node.getHashSymbol()), + child.getProperties()); + } + } + + return rebaseAndDeriveProperties(node, child); + } + @Override public PlanWithProperties visitRowNumber(RowNumberNode node, PreferredProperties preferredProperties) { @@ -540,9 +643,17 @@ public PlanWithProperties visitTopN(TopNNode node, PreferredProperties preferred return rebaseAndDeriveProperties(node, child); } + @Override + public PlanWithProperties visitUpdate(UpdateNode node, PreferredProperties context) + { + isDeleteOrUpdateQuery = true; + return visitPlan(node, context); + } + @Override public PlanWithProperties visitDelete(DeleteNode node, PreferredProperties preferredProperties) { + isDeleteOrUpdateQuery = true; if (!node.getInputDistribution().isPresent()) { return visitPlan(node, preferredProperties); } @@ -732,13 +843,44 @@ public PlanWithProperties visitTableScan(TableScanNode node, PreferredProperties return planTableScan(node, TRUE_CONSTANT); } + @Override + public PlanWithProperties visitMetadataDelete(MetadataDeleteNode node, PreferredProperties preferredProperties) + { + // MetadataDeleteNode is a leaf node that runs on coordinator + return new PlanWithProperties(node); + } + + @Override + public PlanWithProperties visitCallDistributedProcedure(CallDistributedProcedureNode node, PreferredProperties preferredProperties) + { + Optional partitioningScheme = node.getPartitioningScheme(); + boolean isSingleWriterPerPartitionRequired = partitioningScheme.isPresent(); + return 
getTableWriterPlanWithProperties(node, preferredProperties, partitioningScheme, isSingleWriterPerPartitionRequired); + } + + @Override + public PlanWithProperties visitMergeWriter(MergeWriterNode node, PreferredProperties preferredProperties) + { + return getTableWriterPlanWithProperties(node, preferredProperties, Optional.empty(), false); + } + @Override public PlanWithProperties visitTableWriter(TableWriterNode node, PreferredProperties preferredProperties) { - PlanWithProperties source = accept(node.getSource(), preferredProperties); + return getTableWriterPlanWithProperties(node, preferredProperties, node.getTablePartitioningScheme(), node.isSingleWriterPerPartitionRequired()); + } + + private PlanWithProperties getTableWriterPlanWithProperties( + PlanNode node, + PreferredProperties preferredProperties, + Optional nodeTablePartitioningScheme, + boolean isSingleWriterPerPartitionRequired) + { + checkArgument(node instanceof TableWriterNode || node instanceof CallDistributedProcedureNode || node instanceof MergeWriterNode); + PlanWithProperties source = accept(node.getSources().get(0), preferredProperties); - Optional shufflePartitioningScheme = node.getTablePartitioningScheme(); - if (!node.isSingleWriterPerPartitionRequired()) { + Optional shufflePartitioningScheme = nodeTablePartitioningScheme; + if (!isSingleWriterPerPartitionRequired) { // prefer scale writers if single writer per partition is not required // TODO: take into account partitioning scheme in scale writer tasks implementation if (scaleWriters) { @@ -752,15 +894,16 @@ else if (redistributeWrites) { } } + PlanWithProperties newSource = source; if (shufflePartitioningScheme.isPresent() && // TODO: Deprecate compatible table partitioning - !source.getProperties().isCompatibleTablePartitioningWith(shufflePartitioningScheme.get().getPartitioning(), false, metadata, session) && - !(source.getProperties().isRefinedPartitioningOver(shufflePartitioningScheme.get().getPartitioning(), false, metadata, 
session) && - canPushdownPartialMerge(source.getNode(), partialMergePushdownStrategy))) { + !newSource.getProperties().isCompatibleTablePartitioningWith(shufflePartitioningScheme.get().getPartitioning(), false, metadata, session) && + !(newSource.getProperties().isRefinedPartitioningOver(shufflePartitioningScheme.get().getPartitioning(), false, metadata, session) && + canPushdownPartialMerge(newSource.getNode(), partialMergePushdownStrategy))) { PartitioningScheme exchangePartitioningScheme = shufflePartitioningScheme.get(); - if (node.getTablePartitioningScheme().isPresent() && isPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(session)) { + if (nodeTablePartitioningScheme.isPresent() && isPrestoSparkAssignBucketToPartitionForPartitionedTableWriteEnabled(session)) { int writerThreadsPerNode = getTaskPartitionedWriterCount(session); - int bucketCount = getBucketCount(node.getTablePartitioningScheme().get().getPartitioning().getHandle()); + int bucketCount = getBucketCount(nodeTablePartitioningScheme.get().getPartitioning().getHandle()); int[] bucketToPartition = new int[bucketCount]; for (int i = 0; i < bucketCount; i++) { bucketToPartition[i] = i / writerThreadsPerNode; @@ -768,15 +911,15 @@ else if (redistributeWrites) { exchangePartitioningScheme = exchangePartitioningScheme.withBucketToPartition(Optional.of(bucketToPartition)); } - source = withDerivedProperties( + newSource = withDerivedProperties( partitionedExchange( idAllocator.getNextId(), REMOTE_STREAMING, - source.getNode(), + newSource.getNode(), exchangePartitioningScheme), - source.getProperties()); + newSource.getProperties()); } - return rebaseAndDeriveProperties(node, source); + return rebaseAndDeriveProperties(node, newSource); } private int getBucketCount(PartitioningHandle partitioning) @@ -803,6 +946,18 @@ private PlanWithProperties planTableScan(TableScanNode node, RowExpression predi if (nativeExecution && containsSystemTableScan(plan)) { plan = 
gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, plan); } + else if (!getTableScanShuffleStrategy(session).equals(DISABLED) && !isDeleteOrUpdateQuery) { + if (getTableScanShuffleStrategy(session).equals(ALWAYS_ENABLED)) { + plan = roundRobinExchange(idAllocator.getNextId(), REMOTE_STREAMING, plan); + } + else if (getTableScanShuffleStrategy(session).equals(COST_BASED)) { + Constraint constraint = new Constraint<>(node.getCurrentConstraint()); + TableStatistics tableStatistics = metadata.getTableStatistics(session, node.getTable(), ImmutableList.copyOf(node.getAssignments().values()), constraint); + if (!tableStatistics.getParallelismFactor().isUnknown() && tableStatistics.getParallelismFactor().getValue() < getTableScanShuffleParallelismThreshold(session)) { + plan = roundRobinExchange(idAllocator.getNextId(), REMOTE_STREAMING, plan); + } + } + } // TODO: Support selecting layout with best local property once connector can participate in query optimization. return new PlanWithProperties(plan, derivePropertiesRecursively(plan)); } @@ -1616,11 +1771,21 @@ private boolean isNodePartitionedOn(ActualProperties properties, Collection columns) + { + return properties.isNodePartitionedOnAdditionalProperty(columns, isExactPartitioningPreferred(session)); + } + private boolean isStreamPartitionedOn(ActualProperties properties, Collection columns) { return properties.isStreamPartitionedOn(columns, isExactPartitioningPreferred(session)); } + private boolean isStreamPartitionedOnAdditionalProperty(ActualProperties properties, Collection columns) + { + return properties.isStreamPartitionedOnAdditionalProperty(columns, isExactPartitioningPreferred(session)); + } + private boolean shouldAggregationMergePartitionPreferences(AggregationPartitioningMergingStrategy aggregationPartitioningMergingStrategy) { if (isExactPartitioningPreferred(session)) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index b1be3d3964611..46d17de0d459c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -22,13 +22,16 @@ import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.InputDistribution; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; +import com.facebook.presto.spi.plan.MergeJoinNode; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; @@ -48,14 +51,17 @@ import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import 
com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.google.common.collect.ImmutableList; @@ -64,6 +70,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; @@ -108,6 +115,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -307,11 +315,10 @@ public PlanWithProperties visitLimit(LimitNode node, StreamPreferredProperties p } // final limit requires that all data be in one stream - // also, a final changes the input organization completely, so we do not pass through parent preferences return planAndEnforceChildren( node, singleStream(), - defaultParallelism(session)); + parentPreferences.withDefaultParallelism(session)); } @Override @@ -398,7 +405,8 @@ public PlanWithProperties visitAggregation(AggregationNode node, StreamPreferred // [A, B] [(A, C)] -> List.of(Optional.of(GroupingProperty(C))) // [A, B] [(D, A, C)] -> List.of(Optional.of(GroupingProperty(D, C))) List>> matchResult = LocalProperties.match(child.getProperties().getLocalProperties(), LocalProperties.grouped(groupingKeys)); - if (!matchResult.get(0).isPresent()) { + List>> matchResultForAdditional = LocalProperties.match(child.getProperties().getAdditionalLocalProperties(), LocalProperties.grouped(groupingKeys)); + if (!matchResult.get(0).isPresent() || 
!matchResultForAdditional.get(0).isPresent()) { // !isPresent() indicates the property was satisfied completely preGroupedSymbols = groupingKeys; } @@ -497,6 +505,87 @@ public PlanWithProperties visitDelete(DeleteNode node, StreamPreferredProperties return deriveProperties(result, child.getProperties()); } + @Override + public PlanWithProperties visitTableFunction(TableFunctionNode node, StreamPreferredProperties parentPreferences) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, StreamPreferredProperties parentPreferences) + { + if (!node.getSource().isPresent()) { + return deriveProperties(node, ImmutableList.of()); + } + + if (!node.getSpecification().isPresent()) { + // node.getSpecification.isEmpty() indicates that there were no sources or a single source with row semantics. + // The case of no sources was addressed above. + // The case of a single source with row semantics is addressed here. 
Source's properties do not hold after the TableFunctionProcessorNode + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), StreamPreferredProperties.any(), StreamPreferredProperties.any()); + return rebaseAndDeriveProperties(node, ImmutableList.of(child)); + } + + List partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + StreamPreferredProperties childRequirements; + if (!node.isPruneWhenEmpty()) { + childRequirements = singleStream(); + } + else { + childRequirements = parentPreferences + .constrainTo(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()) + .withDefaultParallelism(session) + .withPartitioning(partitionBy); + } + + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), childRequirements, childRequirements); + + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification() + .flatMap(DataOrganizationSpecification::getOrderingScheme) + .ifPresent(orderingScheme -> + orderingScheme.getOrderByVariables().stream() + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) + .forEach(desiredProperties::add)); + Iterator>> matchIterator = LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).iterator(); + + Set prePartitionedInputs = ImmutableSet.of(); + if (!partitionBy.isEmpty()) { + Optional> groupingRequirement = matchIterator.next(); + Set unPartitionedInputs = groupingRequirement.map(LocalProperty::getColumns).orElse(ImmutableSet.of()); + prePartitionedInputs = partitionBy.stream() + .filter(symbol -> !unPartitionedInputs.contains(symbol)) + .collect(toImmutableSet()); + } + + int preSortedOrderPrefix = 0; + if (prePartitionedInputs.equals(ImmutableSet.copyOf(partitionBy))) { + while (matchIterator.hasNext() && 
!matchIterator.next().isPresent()) { + preSortedOrderPrefix++; + } + } + + TableFunctionProcessorNode result = new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(child.getNode()), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + prePartitionedInputs, + preSortedOrderPrefix, + node.getHashSymbol(), + node.getHandle()); + + return deriveProperties(result, child.getProperties()); + } + @Override public PlanWithProperties visitMarkDistinct(MarkDistinctNode node, StreamPreferredProperties parentPreferences) { @@ -595,6 +684,20 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, StreamPrefe return planAndEnforceChildren(node, requiredProperties, requiredProperties); } + @Override + public PlanWithProperties visitCallDistributedProcedure(CallDistributedProcedureNode node, StreamPreferredProperties parentPreferences) + { + if (node.getPartitioningScheme().isPresent() && getTaskPartitionedWriterCount(session) == 1) { + return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); + } + + if (!node.getPartitioningScheme().isPresent() && getTaskWriterCount(session) == 1) { + return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); + } + + return planAndEnforceChildren(node, fixedParallelism(), fixedParallelism()); + } + // // Table Writer // @@ -717,6 +820,24 @@ private PlanWithProperties planTableWriteWithTableWriteMerge(TableWriterNode tab gatherExchangeWithProperties.getProperties()); } + private PlanWithProperties visitPartitionedWriter(PlanNode node) + { + if (getTaskWriterCount(session) == 1) { + return planAndEnforceChildren(node, singleStream(), defaultParallelism(session)); + } + return planAndEnforceChildren(node, fixedParallelism(), fixedParallelism()); + } + + // + // Merge + // + + @Override + public PlanWithProperties 
visitMergeWriter(MergeWriterNode node, StreamPreferredProperties parentPreferences) + { + return visitPartitionedWriter(node); + } + @Override public PlanWithProperties visitTableWriteMerge(TableWriterMergeNode node, StreamPreferredProperties context) { @@ -868,7 +989,8 @@ public PlanWithProperties visitSemiJoin(SemiJoinNode node, StreamPreferredProper parentPreferences.constrainTo(node.getSource().getOutputVariables()).withDefaultParallelism(session)); // this filter source consumes the input completely, so we do not pass through parent preferences - PlanWithProperties filteringSource = planAndEnforce(node.getFilteringSource(), singleStream(), singleStream()); + StreamPreferredProperties filteringPreference = nativeExecution ? defaultParallelism(session) : singleStream(); + PlanWithProperties filteringSource = planAndEnforce(node.getFilteringSource(), filteringPreference, filteringPreference); return rebaseAndDeriveProperties(node, ImmutableList.of(source, filteringSource)); } @@ -887,6 +1009,16 @@ public PlanWithProperties visitSpatialJoin(SpatialJoinNode node, StreamPreferred return rebaseAndDeriveProperties(node, ImmutableList.of(probe, build)); } + @Override + public PlanWithProperties visitMergeJoin(MergeJoinNode node, StreamPreferredProperties parentPreferences) + { + // The optimizer rule MergeJoinForSortedInputOptimizer and SortMergeJoinOptimizer which add the merge join node is responsible to ensure the input of the merge join is sorted. + // Here we use `any().withOrderSensitivity()` meaning respect the input distribution of the input and keep the input order. 
+ PlanWithProperties probe = planAndEnforce(node.getLeft(), any().withOrderSensitivity(), any().withOrderSensitivity()); + PlanWithProperties build = planAndEnforce(node.getRight(), any().withOrderSensitivity(), any().withOrderSensitivity()); + return rebaseAndDeriveProperties(node, ImmutableList.of(probe, build)); + } + @Override public PlanWithProperties visitIndexJoin(IndexJoinNode node, StreamPreferredProperties parentPreferences) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AggregationNodeUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AggregationNodeUtils.java index 4e8863af458e5..6e32ac3ad33ef 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AggregationNodeUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AggregationNodeUtils.java @@ -13,18 +13,11 @@ */ package com.facebook.presto.sql.planner.optimizations; -import com.facebook.presto.Session; -import com.facebook.presto.cost.CachingStatsProvider; -import com.facebook.presto.cost.PlanNodeStatsEstimate; -import com.facebook.presto.cost.StatsCalculator; -import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.spi.plan.AggregationNode; -import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.VariablesExtractor; import com.facebook.presto.sql.relational.FunctionResolution; import com.google.common.collect.ImmutableList; @@ -33,10 +26,8 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; import static 
com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.LOW; import static com.google.common.collect.ImmutableList.toImmutableList; public class AggregationNodeUtils @@ -73,20 +64,6 @@ private static List extractAll(RowExpression expres .collect(toImmutableList()); } - public static boolean isAllLowCardinalityGroupByKeys(AggregationNode aggregationNode, TableScanNode scanNode, Session session, StatsCalculator statsCalculator, TypeProvider types, long count) - { - List groupbyKeys = aggregationNode.getGroupingSets().getGroupingKeys().stream().collect(Collectors.toList()); - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types); - PlanNodeStatsEstimate estimate = statsProvider.getStats(scanNode); - if (estimate.confidenceLevel() == LOW) { - // For safety, we assume they are low card if not confident - // TODO(kaikalur) : maybe return low card only for partition keys if/when we can detect that - return true; - } - - return groupbyKeys.stream().noneMatch(x -> estimate.getVariableStatistics(x).getDistinctValuesCount() >= count); - } - public static AggregationNode.Aggregation removeFilterAndMask(AggregationNode.Aggregation aggregation) { Optional filter = aggregation.getFilter(); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java index 3834d137199d4..dcd9435290f36 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ApplyConnectorOptimization.java @@ -17,6 +17,7 @@ import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorPlanOptimizer; +import 
com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; @@ -27,11 +28,13 @@ import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; +import com.facebook.presto.spi.plan.MaterializedViewScanNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; @@ -42,6 +45,7 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.sql.planner.TypeProvider; import com.google.common.base.Supplier; @@ -53,11 +57,13 @@ import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Queue; import java.util.Set; +import static com.facebook.presto.SystemSessionProperties.isEmptyConnectorOptimizerEnabled; import static com.facebook.presto.SystemSessionProperties.isIncludeValuesNodeInConnectorOptimizer; import static com.facebook.presto.common.RuntimeUnit.NANO; import static com.facebook.presto.sql.OptimizerRuntimeTrackUtil.getOptimizerNameForLog; @@ -85,17 +91,20 @@ public class ApplyConnectorOptimization ProjectNode.class, AggregationNode.class, MarkDistinctNode.class, + MaterializedViewScanNode.class, UnionNode.class, IntersectNode.class, 
ExceptNode.class, SemiJoinNode.class, JoinNode.class, + IndexJoinNode.class, + UnnestNode.class, TableWriterNode.class, TableFinishNode.class, DeleteNode.class); // for a leaf node that does not belong to any connector (e.g., ValuesNode) - private static final ConnectorId EMPTY_CONNECTOR_ID = new ConnectorId("$internal$" + ApplyConnectorOptimization.class + "_CONNECTOR"); + private static final ConnectorId EMPTY_CONNECTOR_ID = new ConnectorId("$internal$ApplyConnectorOptimization_EMPTY_CONNECTOR"); private final Supplier>> connectorOptimizersSupplier; @@ -122,6 +131,7 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider // retrieve all the connectors ImmutableSet.Builder connectorIds = ImmutableSet.builder(); getAllConnectorIds(plan, connectorIds); + Set connectorIdSet = connectorIds.build(); // for each connector, retrieve the set of subplans to optimize // TODO: what if a new connector is added by an existing one @@ -129,79 +139,118 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider // create a UNION_ALL to federate data sources from both C1 and C2 (regardless of the classloader issue). // For such case, it is dangerous to re-calculate the "max closure" given the fixpoint property will be broken. // In order to preserve the fixpoint, we will "pretend" the newly added C2 table scan is part of C1's job to maintain. - for (ConnectorId connectorId : connectorIds.build()) { - Set optimizers = connectorOptimizers.get(connectorId); - if (optimizers == null) { + for (ConnectorId connectorId : connectorIdSet) { + Set optimizers; + if (isEmptyConnectorOptimizerEnabled(session) && connectorIdSet.stream() + .allMatch(x -> x.equals(EMPTY_CONNECTOR_ID)) && session.getCatalog().isPresent()) { + ConnectorId queryConnectorId = new ConnectorId(session.getCatalog().get()); + optimizers = connectorOptimizers.get(queryConnectorId) == null ? 
null + : connectorOptimizers.get(queryConnectorId).stream() + .filter(x -> x.getSupportedConnectorIds().size() == 1 + && x.getSupportedConnectorIds().get(0).equals(EMPTY_CONNECTOR_ID)) + .collect( + toImmutableSet()); + } + else { + optimizers = connectorOptimizers.get(connectorId); + } + if (optimizers == null || optimizers.isEmpty()) { continue; } - ImmutableMap.Builder contextMapBuilder = ImmutableMap.builder(); - buildConnectorPlanNodeContext(plan, null, contextMapBuilder); - Map contextMap = contextMapBuilder.build(); - - // keep track of changed nodes; the keys are original nodes and the values are the new nodes - Map updates = new HashMap<>(); - - // process connector optimizers - for (PlanNode node : contextMap.keySet()) { - // For a subtree with root `node` to be a max closure, the following conditions must hold: - // * The subtree with root `node` is a closure. - // * `node` has no parent, or the subtree with root as `node`'s parent is not a closure. - ConnectorPlanNodeContext context = contextMap.get(node); - if (!context.isClosure(connectorId, session) || - !context.getParent().isPresent() || - contextMap.get(context.getParent().get()).isClosure(connectorId, session)) { - continue; + ImmutableMap.Builder, Set> optimizersWithConnectorRange = ImmutableMap.builder(); + List currentConnectors = null; + ImmutableSet.Builder currentGroup = null; + for (ConnectorPlanOptimizer optimizer : optimizers) { + List supportedConnectors = optimizer.getSupportedConnectorIds().isEmpty() + ? 
ImmutableList.of(connectorId) + : optimizer.getSupportedConnectorIds(); + + if (!supportedConnectors.equals(currentConnectors)) { + if (currentGroup != null) { + optimizersWithConnectorRange.put(currentConnectors, currentGroup.build()); + } + currentConnectors = supportedConnectors; + currentGroup = ImmutableSet.builder(); } + currentGroup.add(optimizer); + } + optimizersWithConnectorRange.put(currentConnectors, currentGroup.build()); + + for (Map.Entry, Set> entry : optimizersWithConnectorRange.build().entrySet()) { + // keep track of changed nodes; the keys are original nodes and the values are the new nodes + Map updates = new HashMap<>(); + + ImmutableMap.Builder contextMapBuilder = ImmutableMap.builder(); + buildConnectorPlanNodeContext(plan, null, contextMapBuilder); + Map contextMap = contextMapBuilder.build(); + + // process connector optimizers + for (PlanNode node : contextMap.keySet()) { + // For a subtree with root `node` to be a max closure, the following conditions must hold: + // * The subtree with root `node` is a closure. + // * `node` has no parent, or the subtree with root as `node`'s parent is not a closure. 
+ ConnectorPlanNodeContext context = contextMap.get(node); + if (!context.isClosure(connectorId, session, entry.getKey()) || + !context.getParent().isPresent() || + contextMap.get(context.getParent().get()).isClosure(connectorId, session, entry.getKey())) { + continue; + } - PlanNode newNode = node; - - // the returned node is still a max closure (only if there is no new connector added, which does happen but ignored here) - for (ConnectorPlanOptimizer optimizer : optimizers) { - long start = System.nanoTime(); - newNode = optimizer.optimize(newNode, session.toConnectorSession(connectorId), variableAllocator, idAllocator); - if (enableVerboseRuntimeStats || trackOptimizerRuntime(session, optimizer)) { - session.getRuntimeStats().addMetricValue(String.format("optimizer%sTimeNanos", getOptimizerNameForLog(optimizer)), NANO, System.nanoTime() - start); + PlanNode newNode = node; + + // the returned node is still a max closure (only if there is no new connector added, which does happen but ignored here) + for (ConnectorPlanOptimizer optimizer : entry.getValue()) { + long start = System.nanoTime(); + ConnectorSession connectorSession = session.toConnectorSession(connectorId); + if (isEmptyConnectorOptimizerEnabled(session) && connectorId.equals(EMPTY_CONNECTOR_ID) && session.getCatalog().isPresent()) { + connectorSession = session.toConnectorSession(new ConnectorId(session.getCatalog().get())); + } + checkState(connectorSession.getConnectorId().isPresent()); + newNode = optimizer.optimize(newNode, connectorSession, variableAllocator, idAllocator); + if (enableVerboseRuntimeStats || trackOptimizerRuntime(session, optimizer)) { + session.getRuntimeStats().addMetricValue(String.format("optimizer%sTimeNanos", getOptimizerNameForLog(optimizer)), NANO, System.nanoTime() - start); + } } - } - if (node != newNode) { - // the optimizer has allocated a new PlanNode - checkState( - containsAll(ImmutableSet.copyOf(newNode.getOutputVariables()), node.getOutputVariables()), - "the 
connector optimizer from %s returns a node that does not cover all output before optimization", - connectorId); + if (node != newNode) { + // the optimizer has allocated a new PlanNode + checkState( + containsAll(ImmutableSet.copyOf(newNode.getOutputVariables()), node.getOutputVariables()), + "the connector optimizer from %s returns a node that does not cover all output before optimization", + connectorId); - updates.put(node, newNode); - } - } - // up to this point, we have a set of updated nodes; need to recursively update their parents - - // alter the plan with a bottom-up approach (but does not have to be strict bottom-up to guarantee the correctness of the algorithm) - // use "original nodes" to keep track of the plan structure and "updates" to keep track of the new nodes - Queue originalNodes = new LinkedList<>(updates.keySet()); - while (!originalNodes.isEmpty()) { - PlanNode originalNode = originalNodes.poll(); - - if (!contextMap.get(originalNode).getParent().isPresent()) { - // originalNode must be the root; update the plan - plan = updates.get(originalNode); - continue; + updates.put(node, newNode); + } } + // up to this point, we have a set of updated nodes; need to recursively update their parents + + // alter the plan with a bottom-up approach (but does not have to be strict bottom-up to guarantee the correctness of the algorithm) + // use "original nodes" to keep track of the plan structure and "updates" to keep track of the new nodes + Queue originalNodes = new LinkedList<>(updates.keySet()); + while (!originalNodes.isEmpty()) { + PlanNode originalNode = originalNodes.poll(); + + if (!contextMap.get(originalNode).getParent().isPresent()) { + // originalNode must be the root; update the plan + plan = updates.get(originalNode); + continue; + } - PlanNode originalParent = contextMap.get(originalNode).getParent().get(); + PlanNode originalParent = contextMap.get(originalNode).getParent().get(); - // need to create a new parent given the child has 
changed; the new parent needs to point to the new child. - // if a node has been updated, it will occur in `updates`; otherwise, just use the original node - ImmutableList.Builder newChildren = ImmutableList.builder(); - originalParent.getSources().forEach(child -> newChildren.add(updates.getOrDefault(child, child))); - PlanNode newParent = originalParent.replaceChildren(newChildren.build()); + // need to create a new parent given the child has changed; the new parent needs to point to the new child. + // if a node has been updated, it will occur in `updates`; otherwise, just use the original node + ImmutableList.Builder newChildren = ImmutableList.builder(); + originalParent.getSources().forEach(child -> newChildren.add(updates.getOrDefault(child, child))); + PlanNode newParent = originalParent.replaceChildren(newChildren.build()); - // mark the new parent as updated - updates.put(originalParent, newParent); + // mark the new parent as updated + updates.put(originalParent, newParent); - // enqueue the parent node in order to recursively update its ancestors - originalNodes.add(originalParent); + // enqueue the parent node in order to recursively update its ancestors + originalNodes.add(originalParent); + } } } @@ -304,17 +353,19 @@ public Set> getReachablePlanNodeTypes() return reachablePlanNodeTypes; } - boolean isClosure(ConnectorId connectorId, Session session) + boolean isClosure(ConnectorId connectorId, Session session, List supportedConnectorId) { + if (isEmptyConnectorOptimizerEnabled(session) && reachableConnectors.stream().allMatch(x -> x.equals(EMPTY_CONNECTOR_ID)) && supportedConnectorId.size() == 1 && supportedConnectorId.get(0).equals(EMPTY_CONNECTOR_ID)) { + return containsAll(CONNECTOR_ACCESSIBLE_PLAN_NODES, reachablePlanNodeTypes); + } // check if all children can reach the only connector boolean includeValuesNode = isIncludeValuesNodeInConnectorOptimizer(session); Set connectorIds = includeValuesNode ? 
reachableConnectors.stream().filter(x -> !x.equals(EMPTY_CONNECTOR_ID)).collect(toImmutableSet()) : reachableConnectors; - if (connectorIds.size() != 1 || !connectorIds.contains(connectorId)) { - return false; + if (connectorIds.contains(connectorId) && new HashSet<>(supportedConnectorId).containsAll(connectorIds) && supportedConnectorId.size() == connectorIds.size()) { + // check if all children are accessible by connectors + return containsAll(CONNECTOR_ACCESSIBLE_PLAN_NODES, reachablePlanNodeTypes); } - - // check if all children are accessible by connectors - return containsAll(CONNECTOR_ACCESSIBLE_PLAN_NODES, reachablePlanNodeTypes); + return false; } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index 75e1368958b58..84f42ef283d6b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.MarkDistinctNode; @@ -34,6 +35,7 @@ import com.facebook.presto.spi.plan.SpatialJoinNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; @@ -43,13 +45,11 @@ import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; 
import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.google.common.collect.BiMap; import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; @@ -506,7 +506,8 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, HashComputationSet node.getCriteria(), node.getFilter(), Optional.of(probeHashVariable), - Optional.of(indexHashVariable)), + Optional.of(indexHashVariable), + node.getLookupVariables()), allHashVariables); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java index 7b7f9f9e6bf1c..2091c17a9a5a0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/IndexJoinOptimizer.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.metadata.Metadata; @@ -22,9 +21,12 @@ import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.function.StandardFunctionResolution; 
import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.JoinType; @@ -34,11 +36,13 @@ import com.facebook.presto.spi.plan.SortNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.DomainTranslator.ExtractionResult; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SimplePlanVisitor; import com.facebook.presto.sql.planner.TypeProvider; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.relational.FunctionResolution; @@ -51,19 +55,23 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; +import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts; import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.RANGE; import static 
com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Predicates.in; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; @@ -94,11 +102,11 @@ public PlanOptimizerResult optimize( requireNonNull(idAllocator, "idAllocator is null"); IndexJoinRewriter rewriter; - if (SystemSessionProperties.isNativeExecutionEnabled(session)) { - rewriter = new NativeRewriter(idAllocator, metadata, session); + if (isNativeExecutionEnabled(session)) { + rewriter = new NativeIndexJoinRewriter(idAllocator, metadata, session); } else { - rewriter = new DefaultRewriter(idAllocator, metadata, session); + rewriter = new DefaultIndexJoinRewriter(idAllocator, metadata, session); } PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, null); return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); @@ -142,6 +150,7 @@ protected static PlanNode createIndexJoinWithExpectedOutputs( PlanNode index, List equiJoinClause, Optional filter, + List lookupVariables, PlanNodeIdAllocator idAllocator) { PlanNode result = new IndexJoinNode( @@ -153,7 +162,8 @@ protected static PlanNode createIndexJoinWithExpectedOutputs( equiJoinClause, filter, Optional.empty(), - Optional.empty()); + Optional.empty(), + lookupVariables); if (!result.getOutputVariables().equals(expectedOutputs)) { result = new ProjectNode( idAllocator.getNextId(), @@ -164,10 +174,10 @@ protected static PlanNode createIndexJoinWithExpectedOutputs( } } - private static class DefaultRewriter + private static class DefaultIndexJoinRewriter extends IndexJoinRewriter { - private DefaultRewriter(PlanNodeIdAllocator 
idAllocator, Metadata metadata, Session session) + private DefaultIndexJoinRewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) { super(idAllocator, metadata, session); } @@ -192,7 +202,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // Sanity check that we can trace the path for the index lookup key Map trace = IndexKeyTracer.trace(leftIndexCandidate.get(), ImmutableSet.copyOf(leftJoinVariables)); - checkState(!trace.isEmpty() && leftJoinVariables.containsAll(trace.keySet())); + checkState(!trace.isEmpty() && ImmutableSet.copyOf(leftJoinVariables).containsAll(trace.keySet())); } Optional rightIndexCandidate = IndexSourceRewriter.rewriteWithIndex( @@ -205,7 +215,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) // Sanity check that we can trace the path for the index lookup key Map trace = IndexKeyTracer.trace(rightIndexCandidate.get(), ImmutableSet.copyOf(rightJoinVariables)); - checkState(!trace.isEmpty() && rightJoinVariables.containsAll(trace.keySet())); + checkState(!trace.isEmpty() && ImmutableSet.copyOf(rightJoinVariables).containsAll(trace.keySet())); } switch (node.getType()) { @@ -222,7 +232,8 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) createEquiJoinClause(leftJoinVariables, rightJoinVariables), node.getFilter(), Optional.empty(), - Optional.empty()); + Optional.empty(), + ImmutableList.copyOf(rightJoinVariables)); } else if (leftIndexCandidate.isPresent()) { indexJoinNode = new IndexJoinNode( @@ -234,7 +245,8 @@ else if (leftIndexCandidate.isPresent()) { createEquiJoinClause(rightJoinVariables, leftJoinVariables), node.getFilter(), Optional.empty(), - Optional.empty()); + Optional.empty(), + ImmutableList.copyOf(leftJoinVariables)); } if (indexJoinNode != null) { @@ -264,6 +276,7 @@ else if (leftIndexCandidate.isPresent()) { rightIndexCandidate.get(), createEquiJoinClause(leftJoinVariables, rightJoinVariables), node.getFilter(), + 
ImmutableList.copyOf(rightJoinVariables), idAllocator); } break; @@ -278,6 +291,7 @@ else if (leftIndexCandidate.isPresent()) { leftIndexCandidate.get(), createEquiJoinClause(rightJoinVariables, leftJoinVariables), node.getFilter(), + ImmutableList.copyOf(leftJoinVariables), idAllocator); } break; @@ -310,10 +324,10 @@ else if (leftIndexCandidate.isPresent()) { } } - private static class NativeRewriter + private static class NativeIndexJoinRewriter extends IndexJoinRewriter { - private NativeRewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) + private NativeIndexJoinRewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) { super(idAllocator, metadata, session); } @@ -324,111 +338,166 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) PlanNode leftRewritten = context.rewrite(node.getLeft()); PlanNode rightRewritten = context.rewrite(node.getRight()); - if (!node.getCriteria().isEmpty()) { // Index join only possible with JOIN criteria - List leftJoinVariables = Lists.transform(node.getCriteria(), EquiJoinClause::getLeft); - List rightJoinVariables = Lists.transform(node.getCriteria(), EquiJoinClause::getRight); + StandardFunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + LookupVariableExtractor.Context leftExtractorContext = new LookupVariableExtractor.Context(new HashSet<>(), functionResolution); + LookupVariableExtractor.Context rightExtractorContext = new LookupVariableExtractor.Context(new HashSet<>(), functionResolution); + + Set leftLookupVariables = leftExtractorContext.getLookupVariables(); + Set rightLookupVariables = rightExtractorContext.getLookupVariables(); + + // Extract non-equal join keys. 
+ if (node.getFilter().isPresent()) { + LookupVariableExtractor.Context filterExtractorContext = new LookupVariableExtractor.Context(new HashSet<>(), functionResolution); + LookupVariableExtractor.extractFromFilter(node.getFilter().get(), filterExtractorContext); + if (filterExtractorContext.isEligible()) { + for (VariableReferenceExpression variableExpression : filterExtractorContext.getLookupVariables()) { + if (node.getLeft().getOutputVariables().contains(variableExpression)) { + leftLookupVariables.add(variableExpression); + } + if (node.getRight().getOutputVariables().contains(variableExpression)) { + rightLookupVariables.add(variableExpression); + } + } + } + else { + return node; + } + } - Optional leftIndexCandidate = IndexSourceRewriter.rewriteWithIndex( + // Extract equal Join keys. + List leftEqualJoinVariables; + List rightEqualJoinVariables; + if (node.getCriteria().isEmpty()) { + leftEqualJoinVariables = ImmutableList.of(); + rightEqualJoinVariables = ImmutableList.of(); + } + else { + leftEqualJoinVariables = node.getCriteria().stream().map(EquiJoinClause::getLeft).collect(toImmutableList()); + rightEqualJoinVariables = node.getCriteria().stream().map(EquiJoinClause::getRight).collect(toImmutableList()); + leftLookupVariables.addAll(leftEqualJoinVariables); + rightLookupVariables.addAll(rightEqualJoinVariables); + } + + // Extract Join keys from pushed-down filters. 
+ LookupVariableExtractor.extractFromSubPlan(node.getLeft(), leftExtractorContext); + LookupVariableExtractor.extractFromSubPlan(node.getRight(), rightExtractorContext); + + Optional leftIndexCandidate; + if (leftExtractorContext.isEligible() && !leftExtractorContext.getLookupVariables().isEmpty()) { + leftIndexCandidate = IndexSourceRewriter.rewriteWithIndex( leftRewritten, - ImmutableSet.copyOf(leftJoinVariables), + leftLookupVariables, idAllocator, metadata, session); - if (leftIndexCandidate.isPresent()) { - // Sanity check that we can trace the path for the index lookup key - Map trace - = IndexKeyTracer.trace(leftIndexCandidate.get(), ImmutableSet.copyOf(leftJoinVariables)); - checkState(!trace.isEmpty() && leftJoinVariables.containsAll(trace.keySet())); - } + } + else { + leftIndexCandidate = Optional.empty(); + } - Optional rightIndexCandidate = IndexSourceRewriter.rewriteWithIndex( + if (leftIndexCandidate.isPresent()) { + // Sanity check that we can trace the path for the index lookup key + Map trace = IndexKeyTracer.trace(leftIndexCandidate.get(), leftLookupVariables); + checkState(!trace.isEmpty() && leftLookupVariables.containsAll(trace.keySet())); + } + + Optional rightIndexCandidate; + if (rightExtractorContext.isEligible() && !rightExtractorContext.getLookupVariables().isEmpty()) { + rightIndexCandidate = IndexSourceRewriter.rewriteWithIndex( rightRewritten, - ImmutableSet.copyOf(rightJoinVariables), + rightLookupVariables, idAllocator, metadata, session); - if (rightIndexCandidate.isPresent()) { - // Sanity check that we can trace the path for the index lookup key - Map trace - = IndexKeyTracer.trace(rightIndexCandidate.get(), ImmutableSet.copyOf(rightJoinVariables)); - checkState(!trace.isEmpty() && rightJoinVariables.containsAll(trace.keySet())); - } - - switch (node.getType()) { - // Only INNER and LEFT joins are supported on native. 
- case INNER: - // Prefer the right candidate over the left candidate - PlanNode indexJoinNode = null; - if (rightIndexCandidate.isPresent()) { - indexJoinNode = new IndexJoinNode( - leftRewritten.getSourceLocation(), - idAllocator.getNextId(), - JoinType.INNER, - leftRewritten, - rightIndexCandidate.get(), - createEquiJoinClause(leftJoinVariables, rightJoinVariables), - node.getFilter(), - Optional.empty(), - Optional.empty()); - } - else if (leftIndexCandidate.isPresent()) { - indexJoinNode = new IndexJoinNode( - rightRewritten.getSourceLocation(), - idAllocator.getNextId(), - JoinType.INNER, - rightRewritten, - leftIndexCandidate.get(), - createEquiJoinClause(rightJoinVariables, leftJoinVariables), - node.getFilter(), - Optional.empty(), - Optional.empty()); - } - - if (indexJoinNode != null) { - if (!indexJoinNode.getOutputVariables().equals(node.getOutputVariables())) { - indexJoinNode = new ProjectNode( - idAllocator.getNextId(), - indexJoinNode, - identityAssignments(node.getOutputVariables())); - } - - planChanged = true; - return indexJoinNode; - } - break; + } + else { + rightIndexCandidate = Optional.empty(); + } - case LEFT: - if (rightIndexCandidate.isPresent()) { - planChanged = true; - return createIndexJoinWithExpectedOutputs( - node.getOutputVariables(), - leftRewritten, - rightIndexCandidate.get(), - createEquiJoinClause(leftJoinVariables, rightJoinVariables), - node.getFilter(), - idAllocator); - } - break; + if (rightIndexCandidate.isPresent()) { + // Sanity check that we can trace the path for the index lookup key + Map trace = IndexKeyTracer.trace(rightIndexCandidate.get(), rightLookupVariables); + checkState(!trace.isEmpty() && rightLookupVariables.containsAll(trace.keySet())); + } - case RIGHT: - if (leftIndexCandidate.isPresent()) { - planChanged = true; - return createIndexJoinWithExpectedOutputs( - node.getOutputVariables(), - rightRewritten, - leftIndexCandidate.get(), - createEquiJoinClause(rightJoinVariables, leftJoinVariables), - 
node.getFilter(), - idAllocator); + switch (node.getType()) { + // Only INNER and LEFT joins are supported on native. + case INNER: + // Prefer the right candidate over the left candidate + PlanNode indexJoinNode = null; + if (rightIndexCandidate.isPresent()) { + indexJoinNode = new IndexJoinNode( + leftRewritten.getSourceLocation(), + idAllocator.getNextId(), + JoinType.INNER, + leftRewritten, + rightIndexCandidate.get(), + createEquiJoinClause(leftEqualJoinVariables, rightEqualJoinVariables), + node.getFilter(), + Optional.empty(), + Optional.empty(), + ImmutableList.copyOf(rightLookupVariables)); + } + else if (leftIndexCandidate.isPresent()) { + indexJoinNode = new IndexJoinNode( + rightRewritten.getSourceLocation(), + idAllocator.getNextId(), + JoinType.INNER, + rightRewritten, + leftIndexCandidate.get(), + createEquiJoinClause(rightEqualJoinVariables, leftEqualJoinVariables), + node.getFilter(), + Optional.empty(), + Optional.empty(), + ImmutableList.copyOf(leftLookupVariables)); + } + + if (indexJoinNode != null) { + if (!indexJoinNode.getOutputVariables().equals(node.getOutputVariables())) { + indexJoinNode = new ProjectNode( + idAllocator.getNextId(), + indexJoinNode, + identityAssignments(node.getOutputVariables())); } - break; - case FULL: - break; - - default: - throw new IllegalArgumentException("Unknown type: " + node.getType()); - } + planChanged = true; + return indexJoinNode; + } + break; + + case LEFT: + if (rightIndexCandidate.isPresent()) { + planChanged = true; + return createIndexJoinWithExpectedOutputs( + node.getOutputVariables(), + leftRewritten, + rightIndexCandidate.get(), + createEquiJoinClause(leftEqualJoinVariables, rightEqualJoinVariables), + node.getFilter(), + ImmutableList.copyOf(rightLookupVariables), + idAllocator); + } + break; + + case RIGHT: + if (leftIndexCandidate.isPresent()) { + planChanged = true; + return createIndexJoinWithExpectedOutputs( + node.getOutputVariables(), + rightRewritten, + leftIndexCandidate.get(), + 
createEquiJoinClause(rightEqualJoinVariables, leftEqualJoinVariables), + node.getFilter(), + ImmutableList.copyOf(leftLookupVariables), + idAllocator); + } + break; + + case FULL: + break; + + default: + throw new IllegalArgumentException("Unknown type: " + node.getType()); } if (leftRewritten != node.getLeft() || rightRewritten != node.getRight()) { @@ -512,7 +581,10 @@ private PlanNode planTableScan(TableScanNode node, RowExpression predicate, Cont .transform(variableName -> node.getAssignments().get(variableName)) .intersect(node.getEnforcedConstraint()); - checkState(node.getOutputVariables().containsAll(context.getLookupVariables())); + if (!ImmutableSet.copyOf(node.getOutputVariables()).containsAll(context.getLookupVariables())) { + // Lookup variable is not present in the plan. Index join is not possible. + return node; + } Set lookupColumns = context.getLookupVariables().stream() .map(variable -> node.getAssignments().get(variable)) @@ -545,7 +617,7 @@ private PlanNode planTableScan(TableScanNode node, RowExpression predicate, Cont decomposedPredicate.getRemainingExpression()); if (!resultingPredicate.equals(TRUE_CONSTANT)) { - // todo it is likely we end up with redundant filters here because the predicate push down has already been run... the fix is to run predicate push down again + // TODO: it is likely we end up with redundant filters here because the predicate push down has already been run... the fix is to run predicate push down again source = new FilterNode(source.getSourceLocation(), idAllocator.getNextId(), source, resultingPredicate); } context.markSuccess(); @@ -555,11 +627,34 @@ private PlanNode planTableScan(TableScanNode node, RowExpression predicate, Cont @Override public PlanNode visitProject(ProjectNode node, RewriteContext context) { + if (isNativeExecutionEnabled(session)) { + // Preserve the lookup variables for native execution. 
+ ProjectNode rewrittenNode = (ProjectNode) context.defaultRewrite(node, context.get()); + Set directVariables = Maps.filterValues(node.getAssignments().getMap(), IndexJoinOptimizer::isVariable).keySet(); + if (!directVariables.containsAll(context.get().getLookupVariables())) { + Assignments.Builder newAssignments = Assignments.builder(); + for (VariableReferenceExpression variable : context.get().getLookupVariables()) { + newAssignments.put(variable, variable); + } + for (Map.Entry entry : node.getAssignments().entrySet()) { + if (!context.get().lookupVariables.contains(entry.getKey())) { + newAssignments.put(entry); + } + } + return new ProjectNode(rewrittenNode.getSourceLocation(), + rewrittenNode.getId(), + rewrittenNode.getStatsEquivalentPlanNode(), + rewrittenNode.getSource(), + newAssignments.build(), + rewrittenNode.getLocality()); + } + return rewrittenNode; + } // Rewrite the lookup variables in terms of only the pre-projected variables that have direct translations ImmutableSet.Builder newLookupVariablesBuilder = ImmutableSet.builder(); for (VariableReferenceExpression variable : context.get().getLookupVariables()) { RowExpression expression = node.getAssignments().get(variable); - if (expression instanceof VariableReferenceExpression) { + if (isVariable(expression)) { newLookupVariablesBuilder.add((VariableReferenceExpression) expression); } } @@ -645,7 +740,8 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext conte node.getCriteria(), node.getFilter(), node.getProbeHashVariable(), - node.getIndexHashVariable()); + node.getIndexHashVariable(), + node.getLookupVariables()); } return source; @@ -703,6 +799,130 @@ public void markSuccess() } } + public static class LookupVariableExtractor + extends SimplePlanVisitor + { + public static class Context + { + private final Set lookupVariables; + private final StandardFunctionResolution standardFunctionResolution; + private boolean eligible = true; + + public Context(Set lookupVariables, 
StandardFunctionResolution standardFunctionResolution) + { + this.lookupVariables = requireNonNull(lookupVariables, "lookupVariables is null"); + this.standardFunctionResolution = requireNonNull(standardFunctionResolution, "standardFunctionResolution is null"); + } + + public Set getLookupVariables() + { + return lookupVariables; + } + + public StandardFunctionResolution getStandardFunctionResolution() + { + return standardFunctionResolution; + } + + public boolean isEligible() + { + return eligible; + } + + public void markIneligible() + { + eligible = false; + } + + @Override + public String toString() + { + return "Context{" + + "lookupVariables=" + lookupVariables + + ", eligible=" + eligible + + '}'; + } + } + + // Traverse the non-equal join condition and extract the lookup variables. + private static void extractFromFilter(RowExpression expression, Context context) + { + List conjuncts = extractConjuncts(expression); + for (RowExpression conjunct : conjuncts) { + // Index lookup condition only supports Equal, BETWEEN and CONTAINS. + if (!(conjunct instanceof CallExpression)) { + continue; + } + + CallExpression callExpression = (CallExpression) conjunct; + if (context.getStandardFunctionResolution().isEqualsFunction(callExpression.getFunctionHandle()) + && callExpression.getArguments().size() == 2) { + RowExpression leftArg = callExpression.getArguments().get(0); + RowExpression rightArg = callExpression.getArguments().get(1); + + VariableReferenceExpression variable = null; + // Check for pattern: constant = variable or variable = constant. + if (isConstant(leftArg) && isVariable(rightArg)) { + variable = (VariableReferenceExpression) rightArg; + } + else if (isVariable(leftArg) && isConstant(rightArg)) { + variable = (VariableReferenceExpression) leftArg; + } + + if (variable != null) { + // It is a lookup equal condition only when it's variable=constant. 
+ context.getLookupVariables().add(variable); + } + } + else if (context.getStandardFunctionResolution().isBetweenFunction(callExpression.getFunctionHandle()) + && isVariable(callExpression.getArguments().get(0))) { + context.getLookupVariables().add((VariableReferenceExpression) callExpression.getArguments().get(0)); + } + else if (callExpression.getDisplayName().equalsIgnoreCase("CONTAINS") + && callExpression.getArguments().size() == 2 + && isVariable(callExpression.getArguments().get(1))) { + context.getLookupVariables().add((VariableReferenceExpression) callExpression.getArguments().get(1)); + } + } + } + + public static void extractFromSubPlan(PlanNode node, Context context) + { + node.accept(new LookupVariableExtractor(), context); + } + + @Override + public Void visitPlan(PlanNode node, Context context) + { + // Node isn't expected to be part of the index pipeline. + context.markIneligible(); + return null; + } + + @Override + public Void visitProject(ProjectNode node, Context context) + { + node.getSource().accept(this, context); + return null; + } + + @Override + public Void visitFilter(FilterNode node, Context context) + { + if (node.getSource() instanceof TableScanNode) { + extractFromFilter(node.getPredicate(), context); + return null; + } + return node.getSource().accept(this, context); + } + + @Override + public Void visitTableScan(TableScanNode node, Context context) + { + return null; + } + } + /** * Identify the mapping from the lookup variables used at the top of the index plan to * the actual variables produced by the IndexSource. 
Note that multiple top-level lookup variables may share the same @@ -730,7 +950,7 @@ public Map visitProjec { // Map from output Variables to source Variables Map directVariableTranslationOutputMap = Maps.transformValues( - Maps.filterValues(node.getAssignments().getMap(), IndexKeyTracer::isVariable), + Maps.filterValues(node.getAssignments().getMap(), IndexJoinOptimizer::isVariable), VariableReferenceExpression.class::cast); Map outputToSourceMap = lookupVariables.stream() .filter(directVariableTranslationOutputMap.keySet()::contains) @@ -796,10 +1016,15 @@ public Map visitIndexS return lookupVariables.stream().collect(toImmutableMap(identity(), identity())); } } + } - private static boolean isVariable(RowExpression expression) - { - return expression instanceof VariableReferenceExpression; - } + private static boolean isVariable(RowExpression expression) + { + return expression instanceof VariableReferenceExpression; + } + + private static boolean isConstant(RowExpression expression) + { + return expression instanceof ConstantExpression; } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java index c590bf6a6cb1f..1f3581d29233f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; @@ -27,7 +26,6 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import 
com.facebook.presto.spi.plan.SemiJoinNode; -import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; @@ -42,7 +40,6 @@ import java.util.stream.IntStream; import static com.facebook.presto.SystemSessionProperties.isJoinPrefilterEnabled; -import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; @@ -50,13 +47,11 @@ import static com.facebook.presto.spi.plan.JoinType.LEFT; import static com.facebook.presto.sql.planner.PlannerUtils.addProjections; import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode; +import static com.facebook.presto.sql.planner.PlannerUtils.getVariableHash; import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject; -import static com.facebook.presto.sql.planner.PlannerUtils.orNullHashCode; import static com.facebook.presto.sql.planner.PlannerUtils.projectExpressions; import static com.facebook.presto.sql.planner.PlannerUtils.restrictOutput; import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; -import static com.facebook.presto.sql.relational.Expressions.call; -import static com.facebook.presto.sql.relational.Expressions.callOperator; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -208,7 +203,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) PlanNode leftKeys = clonePlanNode(rewrittenLeft, session, metadata, idAllocator, leftKeyList, leftVarMap); ImmutableList.Builder expressionsToProject = ImmutableList.builder(); if (hashJoinKey) { - 
RowExpression hashExpression = getVariableHash(leftKeyList); + RowExpression hashExpression = getVariableHash(leftKeyList, functionAndTypeManager); expressionsToProject.add(hashExpression); } else { @@ -218,7 +213,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) VariableReferenceExpression rightKeyToFilter = rightKeyList.get(0); if (hashJoinKey) { - RowExpression hashExpression = getVariableHash(rightKeyList); + RowExpression hashExpression = getVariableHash(rightKeyList, functionAndTypeManager); rightKeyToFilter = variableAllocator.newVariable(hashExpression); rewrittenRight = addProjections(rewrittenRight, idAllocator, ImmutableMap.of(rightKeyToFilter, hashExpression)); } @@ -273,19 +268,5 @@ public boolean isPlanChanged() { return planChanged; } - - private RowExpression getVariableHash(List inputVariables) - { - List hashExpressionList = inputVariables.stream().map(keyVariable -> - callOperator(functionAndTypeManager.getFunctionAndTypeResolver(), OperatorType.XX_HASH_64, BIGINT, keyVariable)).collect(toImmutableList()); - RowExpression hashExpression = hashExpressionList.get(0); - if (hashExpressionList.size() > 1) { - hashExpression = orNullHashCode(hashExpression); - for (int i = 1; i < hashExpressionList.size(); ++i) { - hashExpression = call(functionAndTypeManager, "combine_hash", BIGINT, hashExpression, orNullHashCode(hashExpressionList.get(i))); - } - } - return hashExpression; - } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java index f34359aeba6af..1c7856a62caea 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/KeyBasedSampler.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.optimizations; import 
com.facebook.presto.Session; -import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.Varchars; @@ -55,7 +54,6 @@ import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.metadata.CastType.CAST; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.spi.StandardWarningCode.SAMPLED_FIELDS; @@ -150,7 +148,7 @@ private PlanNode addSamplingFilter(PlanNode tableScanNode, Optional List> sorted(Collection columns, SortOrder return columns.stream().map(column -> new SortingProperty<>(column, order)).collect(toImmutableList()); } + public static List> unique(T column) + { + return ImmutableList.of(new UniqueProperty<>(column)); + } + public static List> stripLeadingConstants(List> properties) { PeekingIterator> iterator = peekingIterator(properties.iterator()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeJoinForSortedInputOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeJoinForSortedInputOptimizer.java index fc1735f7e7e70..f8edfeb37e69b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeJoinForSortedInputOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeJoinForSortedInputOptimizer.java @@ -20,19 +20,24 @@ import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.MergeJoinNode; +import com.facebook.presto.spi.plan.Ordering; +import 
com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.SortNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Optional; import static com.facebook.presto.SystemSessionProperties.isGroupedExecutionEnabled; import static com.facebook.presto.SystemSessionProperties.isSingleNodeExecutionEnabled; import static com.facebook.presto.SystemSessionProperties.preferMergeJoinForSortedInputs; import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST; -import static com.facebook.presto.spi.plan.JoinType.INNER; +import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -41,12 +46,14 @@ public class MergeJoinForSortedInputOptimizer { private final Metadata metadata; private final boolean nativeExecution; + private final boolean prestoOnSpark; private boolean isEnabledForTesting; - public MergeJoinForSortedInputOptimizer(Metadata metadata, boolean nativeExecution) + public MergeJoinForSortedInputOptimizer(Metadata metadata, boolean nativeExecution, boolean prestoOnSpark) { this.metadata = requireNonNull(metadata, "metadata is null"); this.nativeExecution = nativeExecution; + this.prestoOnSpark = prestoOnSpark; } @Override @@ -58,7 +65,7 @@ public void setEnabledForTesting(boolean isSet) @Override public boolean isEnabled(Session session) { - return isEnabledForTesting || isGroupedExecutionEnabled(session) && preferMergeJoinForSortedInputs(session) && !isSingleNodeExecutionEnabled(session); + return isEnabledForTesting || nativeExecution && 
(isGroupedExecutionEnabled(session) || prestoOnSpark) && preferMergeJoinForSortedInputs(session) && !isSingleNodeExecutionEnabled(session); } @Override @@ -70,7 +77,7 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider requireNonNull(idAllocator, "idAllocator is null"); if (isEnabled(session)) { - Rewriter rewriter = new MergeJoinForSortedInputOptimizer.Rewriter(variableAllocator, idAllocator, metadata, session); + Rewriter rewriter = new MergeJoinForSortedInputOptimizer.Rewriter(idAllocator, metadata, session, prestoOnSpark); PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, null); return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); } @@ -83,15 +90,15 @@ private class Rewriter private final PlanNodeIdAllocator idAllocator; private final Metadata metadata; private final Session session; - private final TypeProvider types; + private final boolean prestoOnSpark; private boolean planChanged; - private Rewriter(VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) + private Rewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session, boolean prestoOnSpark) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.session = requireNonNull(session, "session is null"); - this.types = TypeProvider.viewOf(variableAllocator.getVariables()); + this.prestoOnSpark = prestoOnSpark; } public boolean isPlanChanged() @@ -102,62 +109,63 @@ public boolean isPlanChanged() @Override public PlanNode visitJoin(JoinNode node, RewriteContext context) { - // As of now, we only support inner join for merge join - if (node.getType() != INNER) { - return node; + PlanNode rewrittenLeft = node.getLeft().accept(this, context); + PlanNode rewrittenRight = node.getRight().accept(this, context); + + boolean leftInputSorted = 
isPlanOutputSortedByColumns(rewrittenLeft, node.getCriteria().stream().map(EquiJoinClause::getLeft).collect(toImmutableList())); + boolean rightInputSorted = isPlanOutputSortedByColumns(rewrittenRight, node.getCriteria().stream().map(EquiJoinClause::getRight).collect(toImmutableList())); + + if ((!leftInputSorted && !rightInputSorted) || (!prestoOnSpark && (!leftInputSorted || !rightInputSorted))) { + return replaceChildren(node, ImmutableList.of(rewrittenLeft, rewrittenRight)); } - // Fast path merge join optimization (no sort, no local merge) - - // For example: when we have a plan that looks like: - // JoinNode - //- TableScanA - //- TableScanB - - // We check the data properties of TableScanA and TableScanB to see if they meet requirements for merge join: - // 1. If so, we replace the JoinNode to MergeJoinNode - // MergeJoinNode - //- TableScanA - //- TableScanB - - // 2. If not, we don't optimize - if (meetsDataRequirement(node.getLeft(), node.getRight(), node)) { - planChanged = true; - return new MergeJoinNode( - node.getSourceLocation(), - node.getId(), - node.getType(), - node.getLeft(), - node.getRight(), - node.getCriteria(), - node.getOutputVariables(), - node.getFilter(), - node.getLeftHashVariable(), - node.getRightHashVariable()); + if (!leftInputSorted) { + List leftOrdering = node.getCriteria().stream() + .map(criterion -> new Ordering(criterion.getLeft(), ASC_NULLS_FIRST)) + .collect(toImmutableList()); + rewrittenLeft = new SortNode( + Optional.empty(), + idAllocator.getNextId(), + rewrittenLeft, + new OrderingScheme(leftOrdering), + true, + ImmutableList.of()); } - return node; + if (!rightInputSorted) { + List rightOrdering = node.getCriteria().stream() + .map(criterion -> new Ordering(criterion.getRight(), ASC_NULLS_FIRST)) + .collect(toImmutableList()); + rewrittenRight = new SortNode( + Optional.empty(), + idAllocator.getNextId(), + rewrittenRight, + new OrderingScheme(rightOrdering), + true, + ImmutableList.of()); + } + planChanged = true; 
+ return new MergeJoinNode( + node.getSourceLocation(), + node.getId(), + node.getType(), + rewrittenLeft, + rewrittenRight, + node.getCriteria(), + node.getOutputVariables(), + node.getFilter(), + Optional.empty(), + Optional.empty()); } - private boolean meetsDataRequirement(PlanNode left, PlanNode right, JoinNode node) + private boolean isPlanOutputSortedByColumns(PlanNode plan, List columns) { - // Acquire data properties for both left and right side - StreamPropertyDerivations.StreamProperties leftProperties = StreamPropertyDerivations.derivePropertiesRecursively(left, metadata, session, nativeExecution); - StreamPropertyDerivations.StreamProperties rightProperties = StreamPropertyDerivations.derivePropertiesRecursively(right, metadata, session, nativeExecution); - - List leftJoinColumns = node.getCriteria().stream().map(EquiJoinClause::getLeft).collect(toImmutableList()); - List rightJoinColumns = node.getCriteria().stream() - .map(EquiJoinClause::getRight) - .collect(toImmutableList()); - - // Check if both the left side and right side's partitioning columns (bucketed-by columns [B]) are a subset of join columns [J] - // B = subset (J) - if (!verifyStreamProperties(leftProperties, leftJoinColumns) || !verifyStreamProperties(rightProperties, rightJoinColumns)) { + StreamPropertyDerivations.StreamProperties properties = StreamPropertyDerivations.derivePropertiesRecursively(plan, metadata, session, nativeExecution); + + if (!verifyStreamProperties(properties, columns)) { return false; } - // Check if the left side and right side are both ordered by the join columns - return !LocalProperties.match(rightProperties.getLocalProperties(), LocalProperties.sorted(rightJoinColumns, ASC_NULLS_FIRST)).get(0).isPresent() && - !LocalProperties.match(leftProperties.getLocalProperties(), LocalProperties.sorted(leftJoinColumns, ASC_NULLS_FIRST)).get(0).isPresent(); + return !LocalProperties.match(properties.getLocalProperties(), LocalProperties.sorted(columns, 
ASC_NULLS_FIRST)).get(0).isPresent(); } private boolean verifyStreamProperties(StreamPropertyDerivations.StreamProperties streamProperties, List joinColumns) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java index 1c6b974853eda..3db7dc66cc6e4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MergePartialAggregationsWithFilter.java @@ -36,12 +36,16 @@ import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import static com.facebook.presto.SystemSessionProperties.isMergeAggregationsWithAndWithoutFilter; +import static com.facebook.presto.expressions.LogicalRowExpressions.or; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL; @@ -123,11 +127,13 @@ private static class Context { private final Map partialResultToMask; private final Map partialOutputMapping; + private final List newAggregationOutput; public Context() { partialResultToMask = new HashMap<>(); partialOutputMapping = new HashMap<>(); + newAggregationOutput = new LinkedList<>(); } public boolean isEmpty() @@ -139,6 +145,7 @@ public void clear() { partialResultToMask.clear(); partialOutputMapping.clear(); + newAggregationOutput.clear(); } public Map getPartialOutputMapping() @@ -150,6 +157,11 @@ public Map getPartialR { return partialResultToMask; } + + public List getNewAggregationOutput() + { 
+ return newAggregationOutput; + } } private static class Rewriter @@ -218,17 +230,60 @@ else if (node.getStep().equals(FINAL)) { private AggregationNode createPartialAggregationNode(AggregationNode node, PlanNode rewrittenSource, RewriteContext context) { checkState(context.get().isEmpty(), "There should be no partial aggregation left unmerged for a partial aggregation node"); + Map aggregationsWithoutMaskToOutput = node.getAggregations().entrySet().stream() .filter(x -> !x.getValue().getMask().isPresent()) - .collect(toImmutableMap(x -> x.getValue(), x -> x.getKey(), (a, b) -> a)); + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey, (a, b) -> a)); Map aggregationsToMergeOutput = node.getAggregations().entrySet().stream() .filter(x -> x.getValue().getMask().isPresent() && aggregationsWithoutMaskToOutput.containsKey(removeFilterAndMask(x.getValue()))) - .collect(toImmutableMap(x -> x.getValue(), x -> x.getKey())); + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); + + ImmutableMap.Builder partialAggregationToOutputBuilder = ImmutableMap.builder(); + partialAggregationToOutputBuilder.putAll(aggregationsToMergeOutput.keySet().stream().collect(toImmutableMap(Function.identity(), x -> aggregationsWithoutMaskToOutput.get(removeFilterAndMask(x))))); + + List> candidateAggregationsWithMaskNotMatched = node.getAggregations().entrySet().stream().map(Map.Entry::getValue) + .filter(x -> x.getMask().isPresent() && !aggregationsToMergeOutput.containsKey(x)) + .collect(Collectors.groupingBy(AggregationNodeUtils::removeFilterAndMask)).values() + .stream().filter(x -> x.size() > 1).collect(toImmutableList()); + + Map aggregationsWithMaskToMerge = node.getAggregations().entrySet().stream() + .filter(x -> aggregationsToMergeOutput.containsKey(x.getValue()) || candidateAggregationsWithMaskNotMatched.stream().anyMatch(aggregations -> aggregations.contains(x.getValue()))) + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); + 
ImmutableMap.Builder newMaskAssignmentsBuilder = ImmutableMap.builder(); + ImmutableMap.Builder aggregationsAddedBuilder = ImmutableMap.builder(); + List newAggregationAdded = candidateAggregationsWithMaskNotMatched.stream() + .map(aggregations -> + { + List maskVariables = aggregations.stream().map(x -> x.getMask().get()).collect(toImmutableList()); + RowExpression orMaskVariables = or(maskVariables); + VariableReferenceExpression newMaskVariable = variableAllocator.newVariable(orMaskVariables); + newMaskAssignmentsBuilder.put(newMaskVariable, orMaskVariables); + AggregationNode.Aggregation newAggregation = new AggregationNode.Aggregation( + aggregations.get(0).getCall(), + Optional.empty(), + aggregations.get(0).getOrderBy(), + aggregations.get(0).isDistinct(), + Optional.of(newMaskVariable)); + VariableReferenceExpression newAggregationVariable = variableAllocator.newVariable(newAggregation.getCall()); + aggregationsAddedBuilder.put(newAggregationVariable, newAggregation); + aggregations.forEach(x -> partialAggregationToOutputBuilder.put(x, newAggregationVariable)); + return newAggregation; + }) + .collect(toImmutableList()); + Map newMaskAssignments = newMaskAssignmentsBuilder.build(); + Map aggregationsAdded = aggregationsAddedBuilder.build(); + Map partialAggregationToOutput = partialAggregationToOutputBuilder.build(); + + Map aggregationsToMergeOutputCombined = + node.getAggregations().entrySet().stream() + .filter(x -> x.getValue().getMask().isPresent() && aggregationsToMergeOutput.containsKey(x.getValue()) || candidateAggregationsWithMaskNotMatched.stream().anyMatch(aggregations -> aggregations.contains(x.getValue()))) + .collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey)); - context.get().getPartialResultToMask().putAll(aggregationsToMergeOutput.entrySet().stream() - .collect(toImmutableMap(x -> x.getValue(), x -> x.getKey().getMask().get()))); - context.get().getPartialOutputMapping().putAll(aggregationsToMergeOutput.entrySet().stream() - 
.collect(toImmutableMap(x -> x.getValue(), x -> aggregationsWithoutMaskToOutput.get(removeFilterAndMask(x.getKey()))))); + context.get().getNewAggregationOutput().addAll(aggregationsAdded.keySet()); + context.get().getPartialResultToMask().putAll(aggregationsWithMaskToMerge.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getValue, x -> x.getKey().getMask().get()))); + context.get().getPartialOutputMapping().putAll(aggregationsWithMaskToMerge.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getValue, x -> partialAggregationToOutput.get(x.getKey())))); Set maskVariables = new HashSet<>(context.get().getPartialResultToMask().values()); if (maskVariables.isEmpty()) { @@ -242,14 +297,21 @@ private AggregationNode createPartialAggregationNode(AggregationNode node, PlanN AggregationNode.GroupingSetDescriptor partialGroupingSetDescriptor = new AggregationNode.GroupingSetDescriptor( groupingVariables.build(), groupingSetDescriptor.getGroupingSetCount(), groupingSetDescriptor.getGlobalGroupingSets()); - Set partialResultToMerge = new HashSet<>(aggregationsToMergeOutput.values()); - Map newAggregations = node.getAggregations().entrySet().stream() + Set partialResultToMerge = new HashSet<>(aggregationsToMergeOutputCombined.values()); + Map aggregationsRemained = node.getAggregations().entrySet().stream() .filter(x -> !partialResultToMerge.contains(x.getKey())).collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + Map newAggregations = ImmutableMap.builder() + .putAll(aggregationsRemained).putAll(aggregationsAdded).build(); + + PlanNode newChild = rewrittenSource; + if (!newMaskAssignments.isEmpty()) { + newChild = addProjections(newChild, planNodeIdAllocator, newMaskAssignments); + } return new AggregationNode( node.getSourceLocation(), node.getId(), - rewrittenSource, + newChild, newAggregations, partialGroupingSetDescriptor, node.getPreGroupedVariables(), @@ -265,7 +327,7 @@ private AggregationNode createFinalAggregationNode(AggregationNode 
node, PlanNod return (AggregationNode) node.replaceChildren(ImmutableList.of(rewrittenSource)); } List intermediateVariables = node.getAggregations().values().stream() - .map(x -> (VariableReferenceExpression) x.getArguments().get(0)).collect(Collectors.toList()); + .map(x -> (VariableReferenceExpression) x.getArguments().get(0)).collect(toImmutableList()); checkState(intermediateVariables.containsAll(context.get().partialResultToMask.keySet())); ImmutableList.Builder projectionsFromPartialAgg = ImmutableList.builder(); @@ -331,6 +393,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext context) .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); assignments.putAll(excludeMergedAssignments); assignments.putAll(identityAssignments(context.get().getPartialResultToMask().values())); + assignments.putAll(identityAssignments(context.get().getNewAggregationOutput())); return new ProjectNode( node.getSourceLocation(), node.getId(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataDeleteOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataDeleteOptimizer.java index 2d132f80d6842..58ab23f719c4d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataDeleteOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataDeleteOptimizer.java @@ -15,22 +15,25 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableLayout; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.DeleteNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.TableFinishNode; import 
com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.ExchangeNode; -import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.google.common.collect.Iterables; import java.util.List; import java.util.Optional; +import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; import static java.util.Objects.requireNonNull; /** @@ -101,6 +104,13 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont return context.defaultRewrite(node); } + // Check for remaining predicates that require row-level filtering. + // The remainingPredicate contains filters that couldn't be pushed down to the connector, + // such as non-deterministic functions (RAND(), UUID(), etc.) or complex expressions. + if (hasRemainingPredicates(tableScanNode)) { + return context.defaultRewrite(node); + } + planChanged = true; return new MetadataDeleteNode( node.getSourceLocation(), @@ -109,6 +119,22 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont Iterables.getOnlyElement(node.getOutputVariables())); } + private boolean hasRemainingPredicates(TableScanNode tableScanNode) + { + if (!tableScanNode.getTable().getLayout().isPresent()) { + return false; + } + + TableLayout tableLayout = metadata.getLayout(session, tableScanNode.getTable()); + + Optional remainingPredicate = tableLayout.getRemainingPredicate(); + if (remainingPredicate.isPresent() && !TRUE_CONSTANT.equals(remainingPredicate.get())) { + return true; + } + + return false; + } + private static Optional findNode(PlanNode source, Class clazz) { while (true) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java index d3737dd73be2a..52890f8a31d89 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -46,6 +46,7 @@ import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.VariablesExtractor; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; @@ -64,9 +65,10 @@ import java.util.Set; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; -import static com.facebook.presto.sql.planner.RowExpressionInterpreter.evaluateConstantRowExpression; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED; import static com.facebook.presto.sql.relational.Expressions.call; import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; @@ -81,12 +83,13 @@ public class MetadataQueryOptimizer private final Set allowedFunctions; private final Map aggregationScalarMapping; private final Metadata metadata; + private final ExpressionOptimizerManager expressionOptimizerManager; - public MetadataQueryOptimizer(Metadata metadata) + public MetadataQueryOptimizer(Metadata metadata, ExpressionOptimizerManager expressionOptimizerManager) { - requireNonNull(metadata, "metadata is null"); + this.metadata = requireNonNull(metadata, "metadata is 
null"); + this.expressionOptimizerManager = requireNonNull(expressionOptimizerManager, "expressionOptimizerManager is null"); - this.metadata = metadata; CatalogSchemaName defaultNamespace = metadata.getFunctionAndTypeManager().getDefaultNamespace(); this.allowedFunctions = ImmutableSet.of( QualifiedObjectName.valueOf(defaultNamespace, "max"), @@ -104,7 +107,7 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider if (!SystemSessionProperties.isOptimizeMetadataQueries(session) && !SystemSessionProperties.isOptimizeMetadataQueriesIgnoreStats(session)) { return PlanOptimizerResult.optimizerResult(plan, false); } - Optimizer optimizer = new Optimizer(session, metadata, idAllocator); + Optimizer optimizer = new Optimizer(session, metadata, idAllocator, expressionOptimizerManager); PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(optimizer, plan, null); return PlanOptimizerResult.optimizerResult(rewrittenPlan, optimizer.isPlanChanged()); } @@ -130,8 +133,9 @@ private static class Optimizer private final int metastoreCallNumThreshold; private boolean planChanged; private final MetadataQueryOptimizer metadataQueryOptimizer; + private final ExpressionOptimizerManager expressionOptimizerManager; - private Optimizer(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator) + private Optimizer(Session session, Metadata metadata, PlanNodeIdAllocator idAllocator, ExpressionOptimizerManager expressionOptimizerManager) { this.session = session; this.metadata = metadata; @@ -139,7 +143,8 @@ private Optimizer(Session session, Metadata metadata, PlanNodeIdAllocator idAllo this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata); this.ignoreMetadataStats = SystemSessionProperties.isOptimizeMetadataQueriesIgnoreStats(session); this.metastoreCallNumThreshold = SystemSessionProperties.getOptimizeMetadataQueriesCallThreshold(session); - this.metadataQueryOptimizer = new MetadataQueryOptimizer(metadata); + 
this.metadataQueryOptimizer = new MetadataQueryOptimizer(metadata, expressionOptimizerManager); + this.expressionOptimizerManager = expressionOptimizerManager; } public boolean isPlanChanged() @@ -374,15 +379,17 @@ private RowExpression evaluateMinMax(FunctionMetadata aggregationFunctionMetadat List reducedArguments = new ArrayList<>(); // We fold for every 100 values because GREATEST/LEAST has argument count limit for (List partitionedArguments : Lists.partition(arguments, 100)) { - Object reducedValue = evaluateConstantRowExpression( + RowExpression expression = expressionOptimizerManager.getExpressionOptimizer(connectorSession).optimize( call( metadata.getFunctionAndTypeManager(), scalarFunctionName, returnType, partitionedArguments), - metadata.getFunctionAndTypeManager(), - connectorSession); - reducedArguments.add(constant(reducedValue, returnType)); + EVALUATED, + connectorSession, + i -> i); + verify(expression instanceof ConstantExpression, "Expected constant expression"); + reducedArguments.add(expression); } arguments = reducedArguments; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 2cb459394a84c..1ef7e2642d2dc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -14,10 +14,10 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.function.FunctionHandle; import 
com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; @@ -51,7 +51,6 @@ import static com.facebook.presto.common.function.OperatorType.EQUAL; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; -import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; @@ -214,10 +213,10 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext variableToColumnMap = - assignTemporaryTableColumnNames(actualSource.getOutputVariables()); + assignTemporaryTableColumnNames(metadata, session, partitioningProviderCatalog, actualSource.getOutputVariables()); TableHandle temporaryTableHandle; try { temporaryTableHandle = metadata.createTemporaryTable( diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java index b5a4edaf0351f..88bda8083ab12 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PredicatePushDown.java @@ -41,6 +41,7 @@ import com.facebook.presto.spi.plan.SpatialJoinNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; @@ -61,7 +62,6 @@ import 
com.facebook.presto.sql.planner.plan.GroupIdNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.relational.Expressions; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; @@ -905,15 +905,18 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext scanNode = getTableScanNodeWithOnlyFilterAndProject(aggregationNode.getSource()); // Since we duplicate the source of the aggregation - we want to restrict it to simple scan/filter/project // so we can do this opportunistic optimization without too much latency/cpu overhead to support common BI usecases - if (scanNode.isPresent() && - !isAllLowCardinalityGroupByKeys(aggregationNode, scanNode.get(), session, statsCalculator, types, limitNode.getCount())) { + if (scanNode.isPresent()) { PlanNode rewrittenAggregation = addPrefilter(aggregationNode, limitNode.getCount()); if (rewrittenAggregation != aggregationNode) { planChanged = true; @@ -225,8 +223,8 @@ private PlanNode addPrefilter(AggregationNode aggregationNode, long count) SystemSessionProperties.getPrefilterForGroupbyLimitTimeoutMS(session)); FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); - RowExpression leftHashExpression = getHashExpression(functionAndTypeManager, keys).get(); - RowExpression rightHashExpression = getHashExpression(functionAndTypeManager, timedDistinctLimitNode.getOutputVariables()).get(); + RowExpression leftHashExpression = getVariableHash(keys, functionAndTypeManager); + RowExpression rightHashExpression = getVariableHash(timedDistinctLimitNode.getOutputVariables(), functionAndTypeManager); Type mapType = createMapType(functionAndTypeManager, BIGINT, BOOLEAN); PlanNode rightProjectNode = projectExpressions(timedDistinctLimitNode, idAllocator, 
variableAllocator, ImmutableList.of(rightHashExpression, constant(TRUE, BOOLEAN)), ImmutableList.of()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index d62c9e17bf1fc..7dc181ca661e0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -23,17 +23,21 @@ import com.facebook.presto.spi.GroupingProperty; import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.SortingProperty; +import com.facebook.presto.spi.UniqueProperty; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; import com.facebook.presto.spi.plan.MergeJoinNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PartitioningHandle; @@ -46,6 +50,7 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import 
com.facebook.presto.spi.relation.ConstantExpression; @@ -55,21 +60,24 @@ import com.facebook.presto.sql.planner.optimizations.ActualProperties.Global; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.facebook.presto.sql.relational.RowExpressionDomainTranslator; import com.google.common.collect.ImmutableBiMap; @@ -89,6 +97,7 @@ import static com.facebook.presto.SystemSessionProperties.isJoinSpillingEnabled; import static com.facebook.presto.SystemSessionProperties.isSpillEnabled; +import static com.facebook.presto.SystemSessionProperties.isUtilizeUniquePropertyInQueryPlanningEnabled; import static 
com.facebook.presto.SystemSessionProperties.planWithTableNodePartitioning; import static com.facebook.presto.common.predicate.TupleDomain.toLinkedMap; import static com.facebook.presto.spi.relation.DomainTranslator.BASIC_COLUMN_EXTRACTOR; @@ -105,6 +114,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.String.format; import static java.util.stream.Collectors.toMap; public class PropertyDerivations @@ -141,6 +151,22 @@ public static ActualProperties streamBackdoorDeriveProperties(PlanNode node, Lis return node.accept(new Visitor(metadata, session), inputProperties); } + public static Optional uniqueToGroupProperties(ActualProperties properties) + { + // We only call uniqueToGroupProperties on derived properties from propertiesFromUniqueColumn, which can have one local property if the column is preserved in a node + // output, or no local property if the column is not preserved in a node output + checkArgument(properties.getLocalProperties().size() <= 1); + if (properties.getLocalProperties().isEmpty()) { + return Optional.empty(); + } + LocalProperty localProperty = Iterables.getOnlyElement(properties.getLocalProperties()); + if (localProperty instanceof UniqueProperty) { + return Optional.of(ActualProperties.builderFrom(properties).local(ImmutableList.of(new GroupingProperty<>(ImmutableList.of(((UniqueProperty) localProperty).getColumn())))).build()); + } + checkState(localProperty instanceof GroupingProperty, "returned actual properties should have grouping property"); + return Optional.of(properties); + } + private static class Visitor extends InternalPlanVisitor> { @@ -195,12 +221,14 @@ public ActualProperties visitAssignUniqueId(AssignUniqueId node, List inpu .build(); } + @Override + public ActualProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new 
IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public ActualProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + ImmutableList.Builder> localProperties = ImmutableList.builder(); + + if (node.getSource().isPresent()) { + ActualProperties properties = Iterables.getOnlyElement(inputProperties); + + // Only the partitioning properties of the source are passed-through, because the pass-through mechanism preserves the partitioning values. + // Sorting properties might be broken because input rows can be shuffled or nulls can be inserted as the result of pass-through. + // Constant properties might be broken because nulls can be inserted as the result of pass-through. + if (!node.getPrePartitioned().isEmpty()) { + GroupingProperty prePartitionedProperty = new GroupingProperty<>(node.getPrePartitioned()); + for (LocalProperty localProperty : properties.getLocalProperties()) { + if (!prePartitionedProperty.isSimplifiedBy(localProperty)) { + break; + } + localProperties.add(localProperty); + } + } + } + + List partitionBy = node.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .orElse(ImmutableList.of()); + if (!partitionBy.isEmpty()) { + localProperties.add(new GroupingProperty<>(partitionBy)); + } + + // TODO add global single stream property when there's Specification present with no partitioning columns + + return ActualProperties.builder() + .local(localProperties.build()) + .build() + // Crop properties to output columns. + .translateVariable(variable -> node.getOutputVariables().contains(variable) ? 
Optional.of(variable) : Optional.empty()); + } + @Override public ActualProperties visitGroupId(GroupIdNode node, List inputProperties) { @@ -281,7 +353,10 @@ public ActualProperties visitGroupId(GroupIdNode node, List in inputToOutputMappings.putIfAbsent(argument, argument); } - return Iterables.getOnlyElement(inputProperties).translateVariable(column -> Optional.ofNullable(inputToOutputMappings.get(column))); + ActualProperties properties = Iterables.getOnlyElement(inputProperties); + return ActualProperties.builderFrom(properties.translateVariable(column -> Optional.ofNullable(inputToOutputMappings.get(column)))) + .propertiesFromUniqueColumn(properties.getPropertiesFromUniqueColumn().flatMap(x -> uniqueToGroupProperties(x.translateVariable(column -> Optional.ofNullable(inputToOutputMappings.get(column)))))) + .build(); } @Override @@ -290,10 +365,9 @@ public ActualProperties visitAggregation(AggregationNode node, List node.getGroupingKeys().contains(variable) ? Optional.of(variable) : Optional.empty()); - return ActualProperties.builderFrom(translated) .local(LocalProperties.grouped(node.getGroupingKeys())) - .build(); + .propertiesFromUniqueColumn(uniqueProperties(translated.getPropertiesFromUniqueColumn())).build(); } @Override @@ -302,6 +376,14 @@ public ActualProperties visitRowNumber(RowNumberNode node, List uniqueProperties(Optional properties) + { + if (properties.isPresent() && properties.get().getLocalProperties().size() == 1 && properties.get().getLocalProperties().get(0) instanceof UniqueProperty) { + return properties; + } + return Optional.empty(); + } + @Override public ActualProperties visitTopNRowNumber(TopNRowNumberNode node, List inputProperties) { @@ -315,6 +397,7 @@ public ActualProperties visitTopNRowNumber(TopNRowNumberNode node, List inputPro return ActualProperties.builderFrom(properties) .local(localProperties) + .propertiesFromUniqueColumn(uniqueProperties(properties.getPropertiesFromUniqueColumn())) .build(); } @@ -343,6 +427,7 @@ 
public ActualProperties visitSort(SortNode node, List inputPro return ActualProperties.builderFrom(properties) .local(localProperties) + .propertiesFromUniqueColumn(uniqueProperties(properties.getPropertiesFromUniqueColumn())) .build(); } @@ -359,6 +444,7 @@ public ActualProperties visitDistinctLimit(DistinctLimitNode node, List inpu return Iterables.getOnlyElement(inputProperties).translateVariable(symbol -> Optional.empty()); } + @Override + public ActualProperties visitMergeWriter(MergeWriterNode node, List inputProperties) + { + return visitPartitionedWriter(inputProperties); + } + + @Override + public ActualProperties visitMergeProcessor(MergeProcessorNode node, List inputProperties) + { + return Iterables.getOnlyElement(inputProperties).translateVariable(symbol -> Optional.empty()); + } + + @Override + public ActualProperties visitMetadataDelete(MetadataDeleteNode node, List inputProperties) + { + // MetadataDeleteNode runs on coordinator and produces row count + return ActualProperties.builder() + .global(coordinatorSingleStreamPartition()) + .build(); + } + @Override public ActualProperties visitJoin(JoinNode node, List inputProperties) { @@ -421,10 +528,12 @@ public ActualProperties visitJoin(JoinNode node, List inputPro return ActualProperties.builderFrom(probeProperties) .constants(constants) + .propertiesFromUniqueColumn(probeProperties.getPropertiesFromUniqueColumn().flatMap(x -> uniqueToGroupProperties(x.translateVariable(column -> filterOrRewrite(outputVariableReferences, node.getCriteria(), column))))) .unordered(unordered) .build(); case LEFT: return ActualProperties.builderFrom(probeProperties.translateVariable(column -> filterIfMissing(outputVariableReferences, column))) + .propertiesFromUniqueColumn(probeProperties.getPropertiesFromUniqueColumn().flatMap(x -> uniqueToGroupProperties(x.translateVariable(column -> filterIfMissing(outputVariableReferences, column))))) .unordered(unordered) .build(); case RIGHT: @@ -508,10 +617,12 @@ public 
ActualProperties visitIndexJoin(IndexJoinNode node, List uniqueToGroupProperties(x))) .build(); case SOURCE_OUTER: return ActualProperties.builderFrom(probeProperties) .constants(probeProperties.getConstants()) + .propertiesFromUniqueColumn(probeProperties.getPropertiesFromUniqueColumn().flatMap(x -> uniqueToGroupProperties(x))) .build(); default: throw new UnsupportedOperationException("Unsupported join type: " + node.getType()); @@ -543,16 +654,19 @@ public ActualProperties visitMergeJoin(MergeJoinNode node, List uniqueToGroupProperties(x.translateVariable(column -> filterOrRewrite(outputVariableReferences, node.getCriteria(), column))))) .constants(constants) .build(); case LEFT: return ActualProperties.builderFrom(leftProperties.translateVariable(column -> filterIfMissing(outputVariableReferences, column))) + .propertiesFromUniqueColumn(leftProperties.getPropertiesFromUniqueColumn().flatMap(x -> uniqueToGroupProperties(x.translateVariable(column -> filterIfMissing(outputVariableReferences, column))))) .build(); case RIGHT: rightProperties = rightProperties.translateVariable(column -> filterIfMissing(node.getOutputVariables(), column)); return ActualProperties.builderFrom(rightProperties.translateVariable(column -> filterIfMissing(outputVariableReferences, column))) .local(ImmutableList.of()) + .propertiesFromUniqueColumn(rightProperties.getPropertiesFromUniqueColumn().flatMap(x -> uniqueToGroupProperties(x.translateVariable(column -> filterIfMissing(outputVariableReferences, column))))) .unordered(true) .build(); case FULL: @@ -588,6 +702,7 @@ public ActualProperties visitExchange(ExchangeNode node, List checkArgument(!node.getScope().isRemote() || inputProperties.stream().noneMatch(ActualProperties::isNullsAndAnyReplicated), "Null-and-any replicated inputs should not be remotely exchanged"); Set> entries = null; + ActualProperties translated = null; for (int sourceIndex = 0; sourceIndex < node.getSources().size(); sourceIndex++) { List inputVariables = 
node.getInputs().get(sourceIndex); Map inputToOutput = new HashMap<>(); @@ -595,7 +710,7 @@ public ActualProperties visitExchange(ExchangeNode node, List inputToOutput.put(inputVariables.get(i), node.getOutputVariables().get(i)); } - ActualProperties translated = inputProperties.get(sourceIndex).translateVariable(variable -> Optional.ofNullable(inputToOutput.get(variable))); + translated = inputProperties.get(sourceIndex).translateVariable(variable -> Optional.ofNullable(inputToOutput.get(variable))); entries = (entries == null) ? translated.getConstants().entrySet() : Sets.intersection(entries, translated.getConstants().entrySet()); } @@ -611,6 +726,8 @@ public ActualProperties visitExchange(ExchangeNode node, List .forEach(localProperties::add); } + boolean additionalPropertyIsUnique = inputProperties.size() == 1 && uniqueProperties(translated.getPropertiesFromUniqueColumn()).isPresent() && !node.getType().equals(ExchangeNode.Type.REPLICATE); + // Local exchanges are only created in AddLocalExchanges, at the end of optimization, and // local exchanges do not produce all global properties as represented by ActualProperties. // This is acceptable because AddLocalExchanges does not use global properties and is only @@ -630,6 +747,10 @@ else if (inputProperties.stream().anyMatch(ActualProperties::isSingleNode)) { builder.global(coordinatorSingleStreamPartition()); } + if (additionalPropertyIsUnique) { + builder.propertiesFromUniqueColumn(translated.getPropertiesFromUniqueColumn()); + } + return builder.build(); } @@ -640,6 +761,7 @@ else if (inputProperties.stream().anyMatch(ActualProperties::isSingleNode)) { .global(coordinatorOnly ? coordinatorSingleStreamPartition() : singleStreamPartition()) .local(localProperties.build()) .constants(constants) + .propertiesFromUniqueColumn(additionalPropertyIsUnique ? 
translated.getPropertiesFromUniqueColumn() : Optional.empty()) .build(); case REPARTITION: { Global globalPartitioning; @@ -656,6 +778,7 @@ else if (inputProperties.stream().anyMatch(ActualProperties::isSingleNode)) { return ActualProperties.builder() .global(globalPartitioning) .constants(constants) + .propertiesFromUniqueColumn(additionalPropertyIsUnique ? translated.getPropertiesFromUniqueColumn() : Optional.empty()) .build(); } case REPLICATE: @@ -681,6 +804,7 @@ public ActualProperties visitFilter(FilterNode node, List inpu return ActualProperties.builderFrom(properties) .constants(constants) + .propertiesFromUniqueColumn(properties.getPropertiesFromUniqueColumn()) .build(); } @@ -718,11 +842,32 @@ else if (!(value instanceof RowExpression)) { return ActualProperties.builderFrom(translatedProperties) .constants(constants) + .propertiesFromUniqueColumn(properties.getPropertiesFromUniqueColumn().map(x -> x.translateRowExpression(node.getAssignments().getMap()))) .build(); } @Override public ActualProperties visitTableWriter(TableWriterNode node, List inputProperties) + { + return visitPartitionedWriter(inputProperties); + } + + private ActualProperties visitPartitionedWriter(List inputProperties) + { + ActualProperties properties = Iterables.getOnlyElement(inputProperties); + + if (properties.isCoordinatorOnly()) { + return ActualProperties.builder() + .global(coordinatorSingleStreamPartition()) + .build(); + } + return ActualProperties.builder() + .global(properties.isSingleNode() ? 
singleStreamPartition() : arbitraryPartition()) + .build(); + } + + @Override + public ActualProperties visitCallDistributedProcedure(CallDistributedProcedureNode node, List inputProperties) { ActualProperties properties = Iterables.getOnlyElement(inputProperties); @@ -804,6 +949,13 @@ public ActualProperties visitTableScan(TableScanNode node, List Optional.ofNullable(assignments.get(column)))); + if (isUtilizeUniquePropertyInQueryPlanningEnabled(session) && layout.getUniqueColumn().isPresent() && assignments.containsKey(layout.getUniqueColumn().get())) { + VariableReferenceExpression uniqueVariable = assignments.get(layout.getUniqueColumn().get()); + ActualProperties.Builder propertiesFromUniqueColumn = ActualProperties.builder(); + propertiesFromUniqueColumn.global(partitionedOn(ARBITRARY_DISTRIBUTION, ImmutableList.of(uniqueVariable), Optional.empty())); + propertiesFromUniqueColumn.local(LocalProperties.unique(uniqueVariable)); + properties.propertiesFromUniqueColumn(Optional.of(propertiesFromUniqueColumn.build())); + } return properties.build(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index f5ed851a7e089..201ec219823af 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; @@ -47,6 +48,7 @@ import 
com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.RowExpression; @@ -55,18 +57,20 @@ import com.facebook.presto.sql.planner.VariablesExtractor; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; @@ -320,18 +324,40 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext> context) { + Set expectedFilterInputs = new HashSet<>(); + if (node.getFilter().isPresent()) { + expectedFilterInputs = ImmutableSet.builder() + 
.addAll(VariablesExtractor.extractUnique(node.getFilter().get())) + .build(); + } + ImmutableSet.Builder probeInputsBuilder = ImmutableSet.builder(); probeInputsBuilder.addAll(context.get()) .addAll(Iterables.transform(node.getCriteria(), IndexJoinNode.EquiJoinClause::getProbe)); if (node.getProbeHashVariable().isPresent()) { probeInputsBuilder.add(node.getProbeHashVariable().get()); } + probeInputsBuilder.addAll(expectedFilterInputs); Set probeInputs = probeInputsBuilder.build(); ImmutableSet.Builder indexInputBuilder = ImmutableSet.builder(); @@ -340,6 +366,9 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext indexInputs = indexInputBuilder.build(); PlanNode probeSource = context.rewrite(node.getProbeSource(), probeInputs); @@ -355,7 +384,8 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext> context) + { + Set expectedInputs = ImmutableSet.builder() + .addAll(node.getMergeProcessorProjectedVariables()) + .build(); + + PlanNode source = context.rewrite(node.getSource(), expectedInputs); + + return new MergeWriterNode( + node.getSourceLocation(), + node.getId(), + node.getStatsEquivalentPlanNode(), + source, + node.getTarget(), + node.getMergeProcessorProjectedVariables(), + node.getOutputVariables()); + } + + @Override + public PlanNode visitMergeProcessor(MergeProcessorNode node, RewriteContext> context) + { + Set expectedInputs = ImmutableSet.builder() + .add(node.getTargetTableRowIdColumnVariable()) + .add(node.getMergeRowVariable()) + .build(); + + PlanNode source = context.rewrite(node.getSource(), expectedInputs); + + return new MergeProcessorNode( + node.getSourceLocation(), + node.getId(), + node.getStatsEquivalentPlanNode(), + source, + node.getTarget(), + node.getTargetTableRowIdColumnVariable(), + node.getMergeRowVariable(), + node.getTargetColumnVariables(), + node.getOutputVariables()); + } + @Override public PlanNode visitFilter(FilterNode node, RewriteContext> context) { @@ -781,6 +852,25 @@ public PlanNode 
visitTableWriteMerge(TableWriterMergeNode node, RewriteContext> context) + { + PlanNode source = context.rewrite(node.getSource(), ImmutableSet.copyOf(node.getSource().getOutputVariables())); + return new CallDistributedProcedureNode( + node.getSourceLocation(), + node.getId(), + node.getStatsEquivalentPlanNode(), + source, + node.getTarget(), + node.getRowCountVariable(), + node.getFragmentVariable(), + node.getTableCommitContextVariable(), + node.getColumns(), + node.getColumnNames(), + node.getNotNullColumnVariables(), + node.getPartitioningScheme()); + } + @Override public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext> context) { @@ -816,7 +906,7 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext> context) { ImmutableSet.Builder builder = ImmutableSet.builder(); - builder.add(node.getRowId()); + node.getRowId().ifPresent(r -> builder.add(r)); if (node.getInputDistribution().isPresent()) { builder.addAll(node.getInputDistribution().get().getInputVariables()); } @@ -995,5 +1085,25 @@ public PlanNode visitLateralJoin(LateralJoinNode node, RewriteContext> context) + { + return node.getSource().map(source -> new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(context.rewrite(source, ImmutableSet.copyOf(source.getOutputVariables()))), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle() + )).orElse(node); + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java index 6a0a5d37c79d1..9121a6ca44844 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java 
+++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PushdownSubfields.java @@ -18,9 +18,11 @@ import com.facebook.presto.common.Subfield; import com.facebook.presto.common.Subfield.NestedField; import com.facebook.presto.common.Subfield.PathElement; +import com.facebook.presto.common.block.Block; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.MapType; import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.Type; import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.Metadata; @@ -39,6 +41,8 @@ import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; +import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.MarkDistinctNode; import com.facebook.presto.spi.plan.OrderingScheme; @@ -53,6 +57,7 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; @@ -64,13 +69,12 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.ApplyNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import 
com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.tree.QualifiedName; import com.google.common.collect.ImmutableList; @@ -89,13 +93,18 @@ import java.util.stream.IntStream; import static com.facebook.presto.SystemSessionProperties.isLegacyUnnest; +import static com.facebook.presto.SystemSessionProperties.isPushSubfieldsForCardinalityEnabled; +import static com.facebook.presto.SystemSessionProperties.isPushSubfieldsForMapFunctionsEnabled; import static com.facebook.presto.SystemSessionProperties.isPushdownSubfieldsEnabled; import static com.facebook.presto.SystemSessionProperties.isPushdownSubfieldsFromArrayLambdasEnabled; import static com.facebook.presto.common.Subfield.allSubscripts; import static com.facebook.presto.common.Subfield.noSubfield; +import static com.facebook.presto.common.Subfield.structureOnly; +import static com.facebook.presto.common.type.TypeUtils.readNativeValue; import static com.facebook.presto.common.type.Varchars.isVarcharType; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IN; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; @@ -151,7 +160,7 @@ private static class Rewriter { private final Session session; private final Metadata metadata; - private final StandardFunctionResolution functionResolution; + private final FunctionResolution functionResolution; private final ExpressionOptimizer 
expressionOptimizer; private final SubfieldExtractor subfieldExtractor; private static final QualifiedObjectName ARBITRARY_AGGREGATE_FUNCTION = QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, "arbitrary"); @@ -169,7 +178,7 @@ public Rewriter(Session session, Metadata metadata, ExpressionOptimizerProvider expressionOptimizer, session.toConnectorSession(), metadata.getFunctionAndTypeManager(), - isPushdownSubfieldsFromArrayLambdasEnabled(session)); + session); } public boolean isPlanChanged() @@ -255,6 +264,9 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext conte node.getCriteria().stream() .map(IndexJoinNode.EquiJoinClause::getIndex) .forEach(context.get().variables::add); + node.getFilter() + .ifPresent(expression -> expression.accept(subfieldExtractor, context.get())); + context.get().variables.addAll(node.getLookupVariables()); return context.defaultRewrite(node, context.get()); } @@ -307,9 +319,9 @@ public PlanNode visitProject(ProjectNode node, RewriteContext context) continue; } - Optional subfield = toSubfield(expression, functionResolution, expressionOptimizer, session.toConnectorSession(), metadata.getFunctionAndTypeManager()); + Optional> subfield = toSubfield(expression, functionResolution, expressionOptimizer, session.toConnectorSession(), metadata.getFunctionAndTypeManager(), isPushSubfieldsForMapFunctionsEnabled(session)); if (subfield.isPresent()) { - context.get().addAssignment(variable, subfield.get()); + subfield.get().forEach(element -> context.get().addAssignment(variable, element)); continue; } @@ -403,6 +415,68 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext conte node.getEnforcedConstraint(), node.getCteMaterializationInfo()); } + @Override + public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext context) + { + if (context.get().subfields.isEmpty()) { + return node; + } + + ImmutableMap.Builder newAssignments = ImmutableMap.builder(); + + for (Map.Entry entry : 
node.getAssignments().entrySet()) { + VariableReferenceExpression variable = entry.getKey(); + if (context.get().variables.contains(variable)) { + newAssignments.put(entry); + continue; + } + + List subfields = context.get().findSubfields(variable.getName()); + + verify(!subfields.isEmpty(), "Missing variable: " + variable); + + String columnName = getColumnName(session, metadata, node.getTableHandle(), entry.getValue()); + + List subfieldsWithoutNoSubfield = subfields.stream().filter(subfield -> !containsNoSubfieldPathElement(subfield)).collect(toList()); + List subfieldsWithNoSubfield = subfields.stream().filter(subfield -> containsNoSubfieldPathElement(subfield)).collect(toList()); + + // Prune subfields: if one subfield is a prefix of another subfield, keep the shortest one. + // Example: {a.b.c, a.b} -> {a.b} + List columnSubfields = subfieldsWithoutNoSubfield.stream() + .filter(subfield -> !prefixExists(subfield, subfieldsWithoutNoSubfield)) + .map(Subfield::getPath) + .map(path -> new Subfield(columnName, path)) + .collect(toList()); + + columnSubfields.addAll(subfieldsWithNoSubfield.stream() + .filter(subfield -> !isPrefixOf(dropNoSubfield(subfield), subfieldsWithoutNoSubfield)) + .map(Subfield::getPath) + .map(path -> new Subfield(columnName, path)) + .collect(toList())); + + planChanged = true; + newAssignments.put(variable, entry.getValue().withRequiredSubfields(ImmutableList.copyOf(columnSubfields))); + } + + return new IndexSourceNode( + node.getSourceLocation(), + node.getId(), + node.getStatsEquivalentPlanNode(), + node.getIndexHandle(), + node.getTableHandle(), + node.getLookupVariables(), + node.getOutputVariables(), + newAssignments.build(), + node.getCurrentConstraint()); + } + + @Override + public PlanNode visitCallDistributedProcedure(CallDistributedProcedureNode node, RewriteContext context) + { + context.get().variables.addAll(node.getColumns()); + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode 
visitTableWriter(TableWriterNode node, RewriteContext context) { @@ -416,7 +490,7 @@ public PlanNode visitDelete(DeleteNode node, RewriteContext context) if (node.getInputDistribution().isPresent()) { context.get().variables.addAll(node.getInputDistribution().get().getInputVariables()); } - context.get().variables.add(node.getRowId()); + node.getRowId().ifPresent(r -> context.get().variables.add(r)); return context.defaultRewrite(node, context.get()); } @@ -570,17 +644,29 @@ private static String getColumnName(Session session, Metadata metadata, TableHan return metadata.getColumnMetadata(session, tableHandle, columnHandle).getName(); } - private static Optional toSubfield( + private static Optional> toSubfield( RowExpression expression, - StandardFunctionResolution functionResolution, + FunctionResolution functionResolution, ExpressionOptimizer expressionOptimizer, ConnectorSession connectorSession, - FunctionAndTypeManager functionAndTypeManager) + FunctionAndTypeManager functionAndTypeManager, + boolean isPushdownSubfieldsForMapFunctionsEnabled) { ImmutableList.Builder elements = ImmutableList.builder(); while (true) { if (expression instanceof VariableReferenceExpression) { - return Optional.of(new Subfield(((VariableReferenceExpression) expression).getName(), elements.build().reverse())); + return Optional.of(ImmutableList.of(new Subfield(((VariableReferenceExpression) expression).getName(), elements.build().reverse()))); + } + if (expression instanceof CallExpression) { + ComplexTypeFunctionDescriptor functionDescriptor = functionAndTypeManager.getFunctionMetadata(((CallExpression) expression).getFunctionHandle()).getDescriptor(); + Optional pushdownSubfieldArgIndex = functionDescriptor.getPushdownSubfieldArgIndex(); + if (pushdownSubfieldArgIndex.isPresent() && + ((CallExpression) expression).getArguments().size() > pushdownSubfieldArgIndex.get() && + ((CallExpression) expression).getArguments().get(pushdownSubfieldArgIndex.get()).getType() instanceof RowType 
+ && !elements.build().isEmpty()) { // ensures pushdown only happens when a subfield is read from a column + expression = ((CallExpression) expression).getArguments().get(pushdownSubfieldArgIndex.get()); + continue; + } } if (expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm() == DEREFERENCE) { @@ -623,7 +709,7 @@ private static Optional toSubfield( if (index instanceof Number) { //Fix for issue https://github.com/prestodb/presto/issues/22690 //Avoid negative index pushdown - if (((Number) index).longValue() < 0) { + if (((Number) index).longValue() < 0 && arguments.get(0).getType() instanceof ArrayType) { return Optional.empty(); } @@ -640,11 +726,88 @@ private static Optional toSubfield( } return Optional.empty(); } + // map_subset(feature, constant_array) is only accessing fields specified in feature map. + // For example map_subset(feature, array[1, 2]) is equivalent to calling element_at(feature, 1) and element_at(feature, 2) for subfield extraction + if (isPushdownSubfieldsForMapFunctionsEnabled && expression instanceof CallExpression && isMapSubSetWithConstantArray((CallExpression) expression, functionResolution)) { + CallExpression call = (CallExpression) expression; + ConstantExpression constantArray = (ConstantExpression) call.getArguments().get(1); + return extractSubfieldsFromArray(constantArray, (VariableReferenceExpression) call.getArguments().get(0)); + } + // map_filter(feature, (k, v) -> k in (1, 2, 3)), map_filter(feature, (k, v) -> contains(array[1, 2, 3], k)), map_filter(feature, (k, v) -> k = 2) only access specified elements + if (isPushdownSubfieldsForMapFunctionsEnabled && expression instanceof CallExpression && isMapFilterWithConstantFilterInMapKey((CallExpression) expression, functionResolution)) { + CallExpression call = (CallExpression) expression; + VariableReferenceExpression mapVariable = (VariableReferenceExpression) call.getArguments().get(0); + ImmutableList.Builder arguments = 
ImmutableList.builder(); + if (((LambdaDefinitionExpression) call.getArguments().get(1)).getBody() instanceof SpecialFormExpression) { + List mapKeys = ((SpecialFormExpression) ((LambdaDefinitionExpression) call.getArguments().get(1)).getBody()).getArguments().stream().skip(1).collect(toImmutableList()); + for (RowExpression mapKey : mapKeys) { + Optional mapKeySubfield = extractSubfieldsFromSingleValue((ConstantExpression) mapKey, mapVariable); + if (!mapKeySubfield.isPresent()) { + return Optional.empty(); + } + arguments.add(mapKeySubfield.get()); + } + return Optional.of(arguments.build()); + } + else if (((LambdaDefinitionExpression) call.getArguments().get(1)).getBody() instanceof CallExpression) { + CallExpression callExpression = (CallExpression) ((LambdaDefinitionExpression) call.getArguments().get(1)).getBody(); + if (functionResolution.isArrayContainsFunction(callExpression.getFunctionHandle())) { + return extractSubfieldsFromArray((ConstantExpression) callExpression.getArguments().get(0), mapVariable); + } + else if (functionResolution.isEqualsFunction(callExpression.getFunctionHandle())) { + ConstantExpression mapKey; + if (callExpression.getArguments().get(0) instanceof ConstantExpression) { + mapKey = (ConstantExpression) callExpression.getArguments().get(0); + } + else { + mapKey = (ConstantExpression) callExpression.getArguments().get(1); + } + Optional mapKeySubfield = extractSubfieldsFromSingleValue(mapKey, mapVariable); + return mapKeySubfield.map(ImmutableList::of); + } + } + } return Optional.empty(); } } + private static Optional> extractSubfieldsFromArray(ConstantExpression constantArray, VariableReferenceExpression mapVariable) + { + ImmutableList.Builder arguments = ImmutableList.builder(); + checkState(constantArray.getValue() instanceof Block && constantArray.getType() instanceof ArrayType); + Block arrayValue = (Block) constantArray.getValue(); + Type arrayElementType = ((ArrayType) constantArray.getType()).getElementType(); + for (int 
i = 0; i < arrayValue.getPositionCount(); ++i) { + Object mapKey = readNativeValue(arrayElementType, arrayValue, i); + if (mapKey == null) { + return Optional.empty(); + } + if (mapKey instanceof Number) { + arguments.add(new Subfield(mapVariable.getName(), ImmutableList.of(new Subfield.LongSubscript(((Number) mapKey).longValue())))); + } + if (isVarcharType(arrayElementType)) { + arguments.add(new Subfield(mapVariable.getName(), ImmutableList.of(new Subfield.StringSubscript(((Slice) mapKey).toStringUtf8())))); + } + } + return Optional.of(arguments.build()); + } + + private static Optional extractSubfieldsFromSingleValue(ConstantExpression mapKey, VariableReferenceExpression mapVariable) + { + Object value = mapKey.getValue(); + if (value == null) { + return Optional.empty(); + } + if (value instanceof Number) { + return Optional.of(new Subfield(mapVariable.getName(), ImmutableList.of(new Subfield.LongSubscript(((Number) value).longValue())))); + } + if (isVarcharType(mapKey.getType())) { + return Optional.of(new Subfield(mapVariable.getName(), ImmutableList.of(new Subfield.StringSubscript(((Slice) value).toStringUtf8())))); + } + return Optional.empty(); + } + private static NestedField nestedField(String name) { return new NestedField(name.toLowerCase(Locale.ENGLISH)); @@ -653,38 +816,57 @@ private static NestedField nestedField(String name) private static final class SubfieldExtractor extends DefaultRowExpressionTraversalVisitor { - private final StandardFunctionResolution functionResolution; + private final FunctionResolution functionResolution; private final ExpressionOptimizer expressionOptimizer; private final ConnectorSession connectorSession; private final FunctionAndTypeManager functionAndTypeManager; private final boolean isPushDownSubfieldsFromLambdasEnabled; + private final boolean isPushdownSubfieldsForMapFunctionsEnabled; + private final boolean isPushdownSubfieldsForCardinalityEnabled; private SubfieldExtractor( - StandardFunctionResolution 
functionResolution, + FunctionResolution functionResolution, ExpressionOptimizer expressionOptimizer, ConnectorSession connectorSession, FunctionAndTypeManager functionAndTypeManager, - boolean isPushDownSubfieldsFromLambdasEnabled) + Session session) { this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.expressionOptimizer = requireNonNull(expressionOptimizer, "expressionOptimizer is null"); this.connectorSession = connectorSession; this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); - this.isPushDownSubfieldsFromLambdasEnabled = isPushDownSubfieldsFromLambdasEnabled; + requireNonNull(session); + this.isPushDownSubfieldsFromLambdasEnabled = isPushdownSubfieldsFromArrayLambdasEnabled(session); + this.isPushdownSubfieldsForMapFunctionsEnabled = isPushSubfieldsForMapFunctionsEnabled(session); + this.isPushdownSubfieldsForCardinalityEnabled = isPushSubfieldsForCardinalityEnabled(session); } @Override public Void visitCall(CallExpression call, Context context) { + if (isPushdownSubfieldsForCardinalityEnabled && functionResolution.isCardinalityFunction(call.getFunctionHandle()) && call.getArguments().size() == 1) { + RowExpression argument = call.getArguments().get(0); + if (argument instanceof VariableReferenceExpression) { + Type argumentType = argument.getType(); + if (argumentType instanceof MapType || argumentType instanceof ArrayType) { + VariableReferenceExpression variable = (VariableReferenceExpression) argument; + Subfield cardinalitySubfield = new Subfield( + variable.getName(), + ImmutableList.of(structureOnly())); + context.subfields.add(cardinalitySubfield); + return null; + } + } + } ComplexTypeFunctionDescriptor functionDescriptor = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle()).getDescriptor(); - if (isSubscriptOrElementAtFunction(call, functionResolution, functionAndTypeManager)) { - Optional subfield = toSubfield(call, 
functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager); + if (isSubscriptOrElementAtFunction(call, functionResolution, functionAndTypeManager) || isMapSubSetWithConstantArray(call, functionResolution) || isMapFilterWithConstantFilterInMapKey(call, functionResolution)) { + Optional> subfield = toSubfield(call, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager, isPushdownSubfieldsForMapFunctionsEnabled); if (subfield.isPresent()) { if (context.isPruningLambdaSubfieldsPossible()) { - addRequiredLambdaSubfields(context, subfield.get()); + subfield.get().forEach(item -> addRequiredLambdaSubfields(context, item)); } else { - context.subfields.add(subfield.get()); + context.subfields.addAll(subfield.get()); } } else { @@ -837,14 +1019,14 @@ else if (specialForm.getForm() != DEREFERENCE) { return null; } - Optional subfield = toSubfield(specialForm, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager); + Optional> subfield = toSubfield(specialForm, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager, isPushdownSubfieldsForMapFunctionsEnabled); if (subfield.isPresent()) { if (context.isPruningLambdaSubfieldsPossible()) { - addRequiredLambdaSubfields(context, subfield.get()); + subfield.get().forEach(item -> addRequiredLambdaSubfields(context, item)); } else { - context.subfields.add(subfield.get()); + context.subfields.addAll(subfield.get()); } } else { @@ -877,7 +1059,7 @@ private void addRequiredLambdaSubfields(Context context, Subfield input) public Void visitVariableReference(VariableReferenceExpression reference, Context context) { if (context.isPruningLambdaSubfieldsPossible()) { - addRequiredLambdaSubfields(context, toSubfield(reference, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager).get()); + toSubfield(reference, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager, 
isPushdownSubfieldsForMapFunctionsEnabled).get().forEach(item -> addRequiredLambdaSubfields(context, item)); return null; } context.variables.add(reference); @@ -968,4 +1150,42 @@ private static boolean isSubscriptOrElementAtFunction(CallExpression expression, functionAndTypeManager.getFunctionAndTypeResolver().getFunctionMetadata(expression.getFunctionHandle()).getName() .equals(functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("element_at"))); } + + private static boolean isMapSubSetWithConstantArray(CallExpression expression, FunctionResolution functionResolution) + { + return functionResolution.isMapSubSetFunction(expression.getFunctionHandle()) + && expression.getArguments().get(0) instanceof VariableReferenceExpression + && expression.getArguments().get(1) instanceof ConstantExpression; + } + + private static boolean isMapFilterWithConstantFilterInMapKey(CallExpression expression, FunctionResolution functionResolution) + { + if (functionResolution.isMapFilterFunction(expression.getFunctionHandle()) + && expression.getArguments().get(0) instanceof VariableReferenceExpression && expression.getArguments().get(1) instanceof LambdaDefinitionExpression) { + LambdaDefinitionExpression lambdaDefinitionExpression = (LambdaDefinitionExpression) expression.getArguments().get(1); + if (lambdaDefinitionExpression.getBody() instanceof SpecialFormExpression) { + SpecialFormExpression specialFormExpression = (SpecialFormExpression) lambdaDefinitionExpression.getBody(); + if (specialFormExpression.getForm().equals(IN) && specialFormExpression.getArguments().get(0) instanceof VariableReferenceExpression + && ((VariableReferenceExpression) specialFormExpression.getArguments().get(0)).getName().equals(lambdaDefinitionExpression.getArguments().get(0))) { + return specialFormExpression.getArguments().stream().skip(1).allMatch(x -> x instanceof ConstantExpression); + } + } + else if (lambdaDefinitionExpression.getBody() instanceof 
CallExpression) { + CallExpression callExpression = (CallExpression) lambdaDefinitionExpression.getBody(); + if (functionResolution.isArrayContainsFunction(callExpression.getFunctionHandle())) { + return callExpression.getArguments().get(0) instanceof ConstantExpression && callExpression.getArguments().get(1) instanceof VariableReferenceExpression + && ((VariableReferenceExpression) callExpression.getArguments().get(1)).getName().equals(lambdaDefinitionExpression.getArguments().get(0)); + } + else if (functionResolution.isEqualsFunction(callExpression.getFunctionHandle())) { + return (callExpression.getArguments().get(0) instanceof VariableReferenceExpression + && ((VariableReferenceExpression) callExpression.getArguments().get(0)).getName().equals(lambdaDefinitionExpression.getArguments().get(0)) + && callExpression.getArguments().get(1) instanceof ConstantExpression) + || (callExpression.getArguments().get(1) instanceof VariableReferenceExpression + && ((VariableReferenceExpression) callExpression.getArguments().get(1)).getName().equals(lambdaDefinitionExpression.getArguments().get(0)) + && callExpression.getArguments().get(0) instanceof ConstantExpression); + } + } + } + return false; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java index ffd4806665c2c..cd4d5207fccc8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.iterative.GroupReference; import 
com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; @@ -102,6 +103,12 @@ public Range visitEnforceSingleRow(EnforceSingleRowNode node, Void context return Range.singleton(1L); } + @Override + public Range visitWindow(WindowNode node, Void context) + { + return node.getSource().accept(this, null); + } + @Override public Range visitAggregation(AggregationNode node, Void context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RandomizeNullKeyInOuterJoin.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RandomizeNullKeyInOuterJoin.java index 27ab650e7af46..057c84a3514cc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RandomizeNullKeyInOuterJoin.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RandomizeNullKeyInOuterJoin.java @@ -29,14 +29,14 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig.RandomizeOuterJoinNullKeyStrategy; +import com.facebook.presto.sql.planner.PlannerUtils; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; -import io.airlift.slice.Slices; import java.util.HashMap; import java.util.HashSet; @@ -46,30 +46,23 @@ import java.util.stream.IntStream; import java.util.stream.Stream; -import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount; import static 
com.facebook.presto.SystemSessionProperties.getJoinDistributionType; import static com.facebook.presto.SystemSessionProperties.getRandomizeOuterJoinNullKeyNullRatioThreshold; import static com.facebook.presto.SystemSessionProperties.getRandomizeOuterJoinNullKeyStrategy; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DateType.DATE; -import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.metadata.CastType.CAST; import static com.facebook.presto.spi.plan.JoinDistributionType.PARTITIONED; import static com.facebook.presto.spi.plan.JoinType.FULL; import static com.facebook.presto.spi.plan.JoinType.LEFT; import static com.facebook.presto.spi.plan.JoinType.RIGHT; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; -import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; -import static com.facebook.presto.sql.analyzer.FeaturesConfig.RandomizeOuterJoinNullKeyStrategy; import static com.facebook.presto.sql.analyzer.FeaturesConfig.RandomizeOuterJoinNullKeyStrategy.ALWAYS; import static com.facebook.presto.sql.analyzer.FeaturesConfig.RandomizeOuterJoinNullKeyStrategy.COST_BASED; import static com.facebook.presto.sql.analyzer.FeaturesConfig.RandomizeOuterJoinNullKeyStrategy.DISABLED; import static com.facebook.presto.sql.analyzer.FeaturesConfig.RandomizeOuterJoinNullKeyStrategy.KEY_FROM_OUTER_JOIN; import static com.facebook.presto.sql.planner.plan.ChildReplacer.replaceChildren; -import static com.facebook.presto.sql.relational.Expressions.call; -import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.specialForm; import static com.google.common.base.Preconditions.checkState; import static 
com.google.common.collect.ImmutableList.toImmutableList; @@ -350,8 +343,8 @@ public PlanNode visitJoin(JoinNode joinNode, RewriteContext joinOutputBuilder = ImmutableList.builder(); joinOutputBuilder.addAll(leftKeyRandomVariableMap.keySet()); // Input from left side should be before input from right side in join output @@ -390,28 +383,6 @@ private boolean hasNullSkew(JoinNodeStatsEstimate joinEstimate) && joinEstimate.getNullJoinProbeKeyCount() / joinEstimate.getJoinProbeKeyCount() > getRandomizeOuterJoinNullKeyNullRatioThreshold(session))); } - private RowExpression randomizeJoinKey(RowExpression keyExpression, String prefix) - { - int partitionCount = getHashPartitionCount(session); - RowExpression randomNumber = call( - functionAndTypeManager, - "random", - BIGINT, - constant((long) partitionCount, BIGINT)); - RowExpression randomNumberVarchar = call("CAST", functionAndTypeManager.lookupCast(CAST, randomNumber.getType(), VARCHAR), VARCHAR, randomNumber); - RowExpression concatExpression = call(functionAndTypeManager, - "concat", - VARCHAR, - ImmutableList.of(constant(Slices.utf8Slice(prefix), VARCHAR), randomNumberVarchar)); - - RowExpression castToVarchar = keyExpression; - // Only do cast if keyExpression is not VARCHAR type. 
- if (!(keyExpression.getType() instanceof VarcharType)) { - castToVarchar = call("CAST", functionAndTypeManager.lookupCast(CAST, keyExpression.getType(), VARCHAR), VARCHAR, keyExpression); - } - return new SpecialFormExpression(COALESCE, VARCHAR, ImmutableList.of(castToVarchar, concatExpression)); - } - // Do not need to generate randomized variable if the joinKey 1) has already been randomized with the same prefix 2) included in the output of the source node private boolean isAlreadyRandomized(PlanNode source, VariableReferenceExpression joinKey, String prefix) { @@ -425,7 +396,7 @@ private boolean enabledForJoinKeyFromOuterJoin(Set private Map generateRandomKeyMap(List joinKeys, String prefix) { - List randomExpressions = joinKeys.stream().map(x -> randomizeJoinKey(x, prefix)).collect(toImmutableList()); + List randomExpressions = joinKeys.stream().map(x -> PlannerUtils.randomizeJoinKey(session, functionAndTypeManager, x, prefix)).collect(toImmutableList()); List randomVariable = randomExpressions.stream() .map(x -> planVariableAllocator.newVariable(x, RandomizeNullKeyInOuterJoin.class.getSimpleName())) .collect(toImmutableList()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplaceConstantVariableReferencesWithConstants.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplaceConstantVariableReferencesWithConstants.java index 91bc3277795a5..e0c64a01f9373 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplaceConstantVariableReferencesWithConstants.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplaceConstantVariableReferencesWithConstants.java @@ -32,6 +32,7 @@ import com.facebook.presto.spi.plan.SortNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.CallExpression; 
import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.InputReferenceExpression; @@ -44,7 +45,6 @@ import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.OffsetNode; import com.facebook.presto.sql.planner.plan.SampleNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -202,7 +202,7 @@ public PlanNodeWithConstant visitFilter(FilterNode node, Void context) if (newConstantMap.containsKey(variable) && !newConstantMap.get(variable).equals(constant)) { return new PlanNodeWithConstant(replaceChildren(node, ImmutableList.of(rewrittenChild.getPlanNode())), ImmutableMap.of()); } - if (!constant.isNull()) { + if (!constant.isNull() && variable.getType().equals(constant.getType())) { planChanged = true; newConstantMap.put(variable, constant); } @@ -235,7 +235,7 @@ public PlanNodeWithConstant visitProject(ProjectNode node, Void context) for (Map.Entry entry : newProjectNode.getAssignments().getMap().entrySet()) { if (entry.getValue() instanceof ConstantExpression && isSupportedType(entry.getKey()) && isSupportedType(entry.getValue())) { ConstantExpression constantExpression = (ConstantExpression) entry.getValue(); - if (!constantExpression.isNull()) { + if (!constantExpression.isNull() && entry.getKey().getType().equals(constantExpression.getType())) { planChanged = true; newConstantMap.put(entry.getKey(), constantExpression); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplicateSemiJoinInDelete.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplicateSemiJoinInDelete.java index 3a2f0664c4747..9ed461c75214f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplicateSemiJoinInDelete.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ReplicateSemiJoinInDelete.java @@ -23,6 +23,7 @@ import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import static com.facebook.presto.SystemSessionProperties.isBroadcastSemiJoinForDeleteEnabled; import static com.facebook.presto.spi.plan.SemiJoinNode.DistributionType.REPLICATED; import static java.util.Objects.requireNonNull; @@ -32,10 +33,13 @@ public class ReplicateSemiJoinInDelete @Override public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) { - requireNonNull(plan, "plan is null"); - Rewriter rewriter = new Rewriter(); - PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan); - return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); + if (isBroadcastSemiJoinForDeleteEnabled(session)) { + requireNonNull(plan, "plan is null"); + Rewriter rewriter = new Rewriter(); + PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan); + return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); + } + return PlanOptimizerResult.optimizerResult(plan, false); } private static class Rewriter diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RewriteWriterTarget.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RewriteWriterTarget.java new file mode 100644 index 0000000000000..3198cf84c5755 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/RewriteWriterTarget.java @@ -0,0 +1,226 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.TableMetadata; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.TableFinishNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; +import com.facebook.presto.spi.plan.TableWriterNode.WriterTarget; +import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.security.AccessDeniedException; +import com.facebook.presto.spi.security.ViewExpression; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter.RewriteContext; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import 
static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.stream.Collectors.toSet; + +public class RewriteWriterTarget + implements PlanOptimizer +{ + private final Metadata metadata; + private final AccessControl accessControl; + public RewriteWriterTarget(Metadata metadata, AccessControl accessControl) + { + this.metadata = metadata; + this.accessControl = accessControl; + } + + @Override + public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + Rewriter rewriter = new Rewriter(session, metadata, accessControl); + PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, Optional.empty()); + return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); + } + + private class Rewriter + extends SimplePlanRewriter> + { + private final Session session; + private final Metadata metadata; + private final AccessControl accessControl; + private boolean planChanged; + + public Rewriter(Session session, Metadata metadata, AccessControl accessControl) + { + this.session = session; + this.metadata = metadata; + this.accessControl = accessControl; + } + + @Override + public PlanNode visitCallDistributedProcedure(CallDistributedProcedureNode node, RewriteContext> context) + { + CallDistributedProcedureTarget callDistributedProcedureTarget = (CallDistributedProcedureTarget) getContextTarget(context); + return new CallDistributedProcedureNode( + node.getSourceLocation(), + node.getId(), + node.getSource(), + Optional.of(callDistributedProcedureTarget), + node.getRowCountVariable(), + node.getFragmentVariable(), + node.getTableCommitContextVariable(), + node.getColumns(), + node.getColumnNames(), + node.getNotNullColumnVariables(), + node.getPartitioningScheme()); + } + + @Override + public PlanNode 
visitTableFinish(TableFinishNode node, RewriteContext> context) + { + PlanNode child = node.getSource(); + + Optional newTarget = getWriterTarget(child); + if (!newTarget.isPresent()) { + return node; + } + + planChanged = true; + child = context.rewrite(child, newTarget); + + return new TableFinishNode( + node.getSourceLocation(), + node.getId(), + child, + newTarget, + node.getRowCountVariable(), + node.getStatisticsAggregation(), + node.getStatisticsAggregationDescriptor(), + Optional.empty()); + } + + public Optional getWriterTarget(PlanNode node) + { + if (node instanceof CallDistributedProcedureNode) { + Optional tableHandle = findTableHandleForCallDistributedProcedure(((CallDistributedProcedureNode) node).getSource()); + Optional callDistributedProcedureTarget = ((CallDistributedProcedureNode) node).getTarget(); + return !tableHandle.isPresent() ? + callDistributedProcedureTarget.map(target -> new CallDistributedProcedureTarget( + target.getProcedureName(), + target.getProcedureArguments(), + target.getSourceHandle(), + target.getSchemaTableName(), + true)) : + callDistributedProcedureTarget.map(target -> new CallDistributedProcedureTarget( + target.getProcedureName(), + target.getProcedureArguments(), + tableHandle, + target.getSchemaTableName(), + false)); + } + + if (node instanceof ExchangeNode || node instanceof UnionNode) { + Set> writerTargets = node.getSources().stream() + .map(this::getWriterTarget) + .collect(toSet()); + return getOnlyElement(writerTargets); + } + + return Optional.empty(); + } + + private Optional findTableHandleForCallDistributedProcedure(PlanNode startNode) + { + List tableScanNodes = PlanNodeSearcher.searchFrom(startNode) + .where(node -> node instanceof TableScanNode) + .findAll(); + + if (tableScanNodes.size() == 1) { + TableHandle tableHandle = ((TableScanNode) tableScanNodes.get(0)).getTable(); + checkFullDataAccessControl(tableHandle); + return Optional.of(tableHandle); + } + + List valuesNodes = 
PlanNodeSearcher.searchFrom(startNode) + .where(node -> node instanceof ValuesNode) + .findAll(); + + if (valuesNodes.size() == 1) { + return Optional.empty(); + } + + throw new IllegalArgumentException("Expected to find exactly one update target TableScanNode in plan but found: " + tableScanNodes); + } + + public boolean isPlanChanged() + { + return planChanged; + } + + private void checkFullDataAccessControl(TableHandle tableHandle) + { + TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle); + QualifiedObjectName baseTable = new QualifiedObjectName(tableMetadata.getConnectorId().getCatalogName(), + tableMetadata.getTable().getSchemaName(), tableMetadata.getTable().getTableName()); + String errorMessage = "Full data access is restricted by row filters and column masks for table: " + baseTable; + + // Check for row filters on this target table + List rowFilters = accessControl.getRowFilters( + session.getRequiredTransactionId(), + session.getIdentity(), + session.getAccessControlContext(), + baseTable); + + if (!rowFilters.isEmpty()) { + throw new AccessDeniedException(errorMessage); + } + + // Check for column masks on this target table + Map columnHandles = metadata.getColumnHandles(session, tableHandle); + List columnsMetadata = columnHandles.values().stream() + .map(handle -> metadata.getColumnMetadata(session, tableHandle, handle)) + .collect(toImmutableList()); + + Map columnMasks = accessControl.getColumnMasks( + session.getRequiredTransactionId(), + session.getIdentity(), + session.getAccessControlContext(), + baseTable, + columnsMetadata); + + if (!columnMasks.isEmpty()) { + throw new AccessDeniedException(errorMessage); + } + } + } + + private static WriterTarget getContextTarget(RewriteContext> context) + { + return context.get().orElseThrow(() -> new IllegalStateException("WriterTarget not present")); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ShardJoins.java 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ShardJoins.java index 75f8d9ec8ee23..7af27d8551863 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ShardJoins.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/ShardJoins.java @@ -24,12 +24,12 @@ import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.PlannerUtils; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyPlanWithEmptyInput.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyPlanWithEmptyInput.java index bba51cacd00cc..c6c0077747caa 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyPlanWithEmptyInput.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SimplifyPlanWithEmptyInput.java @@ -32,6 +32,7 @@ import com.facebook.presto.spi.plan.SortNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -43,7 +44,6 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import 
com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.google.common.collect.ImmutableList; import java.util.ArrayList; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortMergeJoinOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortMergeJoinOptimizer.java new file mode 100644 index 0000000000000..f97398529a6f5 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortMergeJoinOptimizer.java @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.EquiJoinClause; +import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.MergeJoinNode; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.SortNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.preferSortMergeJoin; +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class SortMergeJoinOptimizer + implements PlanOptimizer +{ + private final Metadata metadata; + private final boolean nativeExecution; + private boolean isEnabledForTesting; + + public SortMergeJoinOptimizer(Metadata metadata, boolean nativeExecution) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.nativeExecution = nativeExecution; + } + + @Override + public void setEnabledForTesting(boolean isSet) + { + isEnabledForTesting = isSet; + } + + @Override + public boolean isEnabled(Session session) + { + // TODO: Consider group execution and single node execution. 
+ return isEnabledForTesting || preferSortMergeJoin(session); + } + + @Override + public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider type, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + requireNonNull(plan, "plan is null"); + requireNonNull(session, "session is null"); + requireNonNull(variableAllocator, "variableAllocator is null"); + requireNonNull(idAllocator, "idAllocator is null"); + + if (isEnabled(session)) { + Rewriter rewriter = new SortMergeJoinOptimizer.Rewriter(idAllocator, metadata, session); + PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, null); + return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); + } + return PlanOptimizerResult.optimizerResult(plan, false); + } + + /** + * @param joinNode + * @return returns true if merge join is supported for the given join node. + */ + public boolean isMergeJoinEligible(JoinNode joinNode) + { + return (joinNode.getType() == JoinType.INNER || joinNode.getType() == JoinType.LEFT || joinNode.getType() == JoinType.RIGHT) + && !joinNode.isCrossJoin(); + } + + private class Rewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + private final Metadata metadata; + private final Session session; + private boolean planChanged; + + private Rewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.session = requireNonNull(session, "session is null"); + } + + public boolean isPlanChanged() + { + return planChanged; + } + + @Override + public PlanNode visitJoin(JoinNode node, RewriteContext context) + { + if (!isMergeJoinEligible(node)) { + return node; + } + + PlanNode left = node.getLeft(); + PlanNode right = node.getRight(); + + List leftJoinColumns = 
node.getCriteria().stream().map(EquiJoinClause::getLeft).collect(toImmutableList()); + + if (!isPlanOutputSortedByColumns(left, leftJoinColumns)) { + List leftOrdering = node.getCriteria().stream() + .map(criterion -> new Ordering(criterion.getLeft(), ASC_NULLS_FIRST)) + .collect(toImmutableList()); + left = new SortNode( + Optional.empty(), + idAllocator.getNextId(), + left, + new OrderingScheme(leftOrdering), + true, + ImmutableList.of()); + } + + List rightJoinColumns = node.getCriteria().stream() + .map(EquiJoinClause::getRight) + .collect(toImmutableList()); + if (!isPlanOutputSortedByColumns(right, rightJoinColumns)) { + List rightOrdering = node.getCriteria().stream() + .map(criterion -> new Ordering(criterion.getRight(), ASC_NULLS_FIRST)) + .collect(toImmutableList()); + right = new SortNode( + Optional.empty(), + idAllocator.getNextId(), + right, + new OrderingScheme(rightOrdering), + true, + ImmutableList.of()); + } + + planChanged = true; + return new MergeJoinNode( + Optional.empty(), + node.getId(), + node.getType(), + left, + right, + node.getCriteria(), + node.getOutputVariables(), + node.getFilter(), + node.getLeftHashVariable(), + node.getRightHashVariable()); + } + + private boolean isPlanOutputSortedByColumns(PlanNode plan, List columns) + { + StreamPropertyDerivations.StreamProperties properties = StreamPropertyDerivations.derivePropertiesRecursively(plan, metadata, session, nativeExecution); + + // Check if partitioning columns (bucketed-by columns [B]) are a subset of join columns [J] + // B = subset (J) + if (!verifyStreamProperties(properties, columns)) { + return false; + } + + // Check if the output of the subplan is ordered by the join columns + return !LocalProperties.match(properties.getLocalProperties(), LocalProperties.sorted(columns, ASC_NULLS_FIRST)).get(0).isPresent(); + } + + private boolean verifyStreamProperties(StreamPropertyDerivations.StreamProperties streamProperties, List joinColumns) + { + if 
(!streamProperties.getPartitioningColumns().isPresent()) { + return false; + } + List partitioningColumns = streamProperties.getPartitioningColumns().get(); + return partitioningColumns.size() <= joinColumns.size() && joinColumns.containsAll(partitioningColumns); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortedExchangeRule.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortedExchangeRule.java new file mode 100644 index 0000000000000..c3c892516949f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortedExchangeRule.java @@ -0,0 +1,207 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.SortNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.isSortedExchangeEnabled; +import static java.util.Objects.requireNonNull; + +/** + * Optimizer rule that pushes sort operations down to exchange nodes where possible. + * This optimization is beneficial for distributed queries where data needs to be sorted + * after shuffling, as it allows sorting to happen during the shuffle operation itself + * rather than requiring an explicit SortNode afterward. + * + * The rule looks for SortNode → ExchangeNode patterns and attempts to merge them into + * a single sorted exchange node when: + * - The exchange is a REMOTE exchange (REMOTE_STREAMING or REMOTE_MATERIALIZED) + * - The exchange is either REPARTITION or GATHER (merging) type + * - REPLICATE exchanges are excluded as sorting is not beneficial for broadcast operations + * - The exchange doesn't already have an ordering scheme + * - All ordering variables are available in the exchange output + */ +public class SortedExchangeRule + implements PlanOptimizer +{ + private final boolean isPrestoSparkExecution; + private boolean isEnabledForTesting; + + /** + * Constructor that accepts a flag indicating whether this is a Presto Spark execution environment. 
+ * + * @param isPrestoSparkExecution true if running in Presto Spark execution environment + */ + public SortedExchangeRule(boolean isPrestoSparkExecution) + { + this.isPrestoSparkExecution = isPrestoSparkExecution; + } + + @Override + public void setEnabledForTesting(boolean isSet) + { + isEnabledForTesting = isSet; + } + + @Override + public boolean isEnabled(Session session) + { + return (isSortedExchangeEnabled(session) && isPrestoSparkExecution) || isEnabledForTesting; + } + + @Override + public PlanOptimizerResult optimize( + PlanNode plan, + Session session, + TypeProvider types, + VariableAllocator variableAllocator, + PlanNodeIdAllocator idAllocator, + WarningCollector warningCollector) + { + requireNonNull(plan, "plan is null"); + requireNonNull(session, "session is null"); + requireNonNull(types, "types is null"); + requireNonNull(variableAllocator, "variableAllocator is null"); + requireNonNull(idAllocator, "idAllocator is null"); + requireNonNull(warningCollector, "warningCollector is null"); + + if (isEnabled(session)) { + Rewriter rewriter = new Rewriter(idAllocator); + PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, null); + return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); + } + return PlanOptimizerResult.optimizerResult(plan, false); + } + + private static class Rewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + private boolean planChanged; + + public Rewriter(PlanNodeIdAllocator idAllocator) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + } + + public boolean isPlanChanged() + { + return planChanged; + } + + @Override + public PlanNode visitSort(SortNode node, RewriteContext context) + { + PlanNode source = context.rewrite(node.getSource()); + + // Try to push sort down to exchange if we can find one in the source tree + Optional sortedExchange = pushSortToExchangeIfPossible(source, node.getOrderingScheme()); + if 
(sortedExchange.isPresent()) { + planChanged = true; + return sortedExchange.get(); + } + + // If we can't optimize, return a new SortNode with the rewritten source + // to preserve any optimizations that happened deeper in the tree + if (source != node.getSource()) { + return new SortNode( + node.getSourceLocation(), + node.getId(), + source, + node.getOrderingScheme(), + node.isPartial(), + node.getPartitionBy()); + } + + return node; + } + + /** + * Attempts to push the sorting operation down to the Exchange node if the plan structure allows it. + * This is beneficial for distributed queries where we can sort during the shuffle operation instead of + * adding an explicit SortNode. + * + * IMPORTANT: This only optimizes if the immediate child is an ExchangeNode. We do NOT look through + * intermediate nodes to find exchanges deeper in the tree. + * + * @param plan The plan node that needs sorting (must be immediate child of SortNode) + * @param orderingScheme The required ordering scheme + * @return Optional containing the sorted exchange if push-down is possible, empty otherwise + */ + private Optional pushSortToExchangeIfPossible(PlanNode plan, OrderingScheme orderingScheme) + { + // Only optimize if the immediate child is an exchange node + if (!(plan instanceof ExchangeNode)) { + return Optional.empty(); + } + + ExchangeNode exchangeNode = (ExchangeNode) plan; + + // TODO: Future work. Support rewrite for exchange node with + // multiple sources. 
+ if (exchangeNode.getSources().size() > 1) { + return Optional.empty(); + } + + // Only push sort down to exchanges in remote scope + // These are the exchanges that involve shuffling data between executors + if (!exchangeNode.getScope().isRemote()) { + return Optional.empty(); + } + + // Only push sort down to REPARTITION and GATHER (merging) exchanges + // Do not support REPLICATED exchanges + if (exchangeNode.getType() == ExchangeNode.Type.REPLICATE) { + return Optional.empty(); + } + + // Validate that all ordering variables are present in the source's output + // This ensures we don't create invalid sorted exchanges + // TODO: Translate variable names in multiple node scenario + // tracking issue - https://github.com/prestodb/presto/issues/26602 + PlanNode source = exchangeNode.getSources().get(0); + List sourceOutputVariables = source.getOutputVariables(); + for (VariableReferenceExpression orderingVariable : orderingScheme.getOrderByVariables()) { + if (!sourceOutputVariables.contains(orderingVariable)) { + // Cannot push sort down if ordering references variables not in source output + return Optional.empty(); + } + } + + // Create a new sorted exchange node + ExchangeNode sortedExchange = ExchangeNode.sortedPartitionedExchange( + idAllocator.getNextId(), + exchangeNode.getScope(), + exchangeNode.getSources().get(0), + exchangeNode.getPartitioningScheme().getPartitioning(), + exchangeNode.getPartitioningScheme().getHashColumn(), + orderingScheme); + + return Optional.of(sortedExchange); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java index 221d224ea6d40..b1b44d241e6d6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPreferredProperties.java @@ -105,6 +105,16 @@ public StreamPreferredProperties withFixedParallelism() return fixedParallelism(); } + public static StreamPreferredProperties partitionedOn(Collection partitionSymbols) + { + if (partitionSymbols.isEmpty()) { + return singleStream(); + } + + // Prefer partitioning on given partitioning symbols. Partition hash can be evaluated in any order. + return new StreamPreferredProperties(Optional.of(FIXED), false, Optional.of(ImmutableSet.copyOf(partitionSymbols)), false); + } + public static StreamPreferredProperties exactlyPartitionedOn(Collection partitionVariables) { if (partitionVariables.isEmpty()) { @@ -188,9 +198,11 @@ else if (actualProperties.getDistribution() == SINGLE) { // is there a preference for a specific partitioning scheme? if (partitioningColumns.isPresent()) { if (exactColumnOrder) { - return actualProperties.isExactlyPartitionedOn(partitioningColumns.get()); + return actualProperties.isExactlyPartitionedOn(partitioningColumns.get()) + || actualProperties.getStreamPropertiesFromUniqueColumn().isPresent() && actualProperties.getStreamPropertiesFromUniqueColumn().get().isExactlyPartitionedOn(partitioningColumns.get()); } - return actualProperties.isPartitionedOn(partitioningColumns.get()); + return actualProperties.isPartitionedOn(partitioningColumns.get()) + || actualProperties.getStreamPropertiesFromUniqueColumn().isPresent() && actualProperties.getStreamPropertiesFromUniqueColumn().get().isPartitionedOn(partitioningColumns.get()); } return true; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index 9bb82623312ca..3effedf6638fb 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -18,15 +18,18 @@ import com.facebook.presto.metadata.TableLayout; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.LocalProperty; +import com.facebook.presto.spi.UniqueProperty; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; import com.facebook.presto.spi.plan.MergeJoinNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; @@ -38,6 +41,7 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.RowExpression; @@ -45,38 +49,44 @@ import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import 
com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.google.common.collect.ImmutableBiMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import static com.facebook.presto.SystemSessionProperties.isUtilizeUniquePropertyInQueryPlanningEnabled; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static com.facebook.presto.sql.planner.optimizations.PropertyDerivations.extractFixedValuesToConstantExpressions; 
import static com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.FIXED; @@ -122,7 +132,10 @@ public static StreamProperties deriveProperties(PlanNode node, List properties.otherActualProperties) + .map(properties -> { + checkState(properties.otherActualProperties.isPresent(), "otherActualProperties should always exist"); + return properties.otherActualProperties.get(); + }) .collect(toImmutableList()), metadata, session); @@ -224,11 +237,8 @@ public StreamProperties visitIndexJoin(IndexJoinNode node, List streamPropertiesFromUniqueColumn = Optional.empty(); + if (isUtilizeUniquePropertyInQueryPlanningEnabled(session) && layout.getUniqueColumn().isPresent() && assignments.containsKey(layout.getUniqueColumn().get())) { + streamPropertiesFromUniqueColumn = Optional.of(new StreamProperties(streamDistribution, Optional.of(ImmutableList.of(assignments.get(layout.getUniqueColumn().get()))), false)); + } + // if we are partitioned on empty set, we must say multiple of unknown partitioning, because // the connector does not guarantee a single split in this case (since it might not understand // that the value is a constant). 
if (streamPartitionSymbols.isPresent() && streamPartitionSymbols.get().isEmpty()) { - return new StreamProperties(streamDistribution, Optional.empty(), false); + return new StreamProperties(streamDistribution, Optional.empty(), false, Optional.empty(), streamPropertiesFromUniqueColumn); } - return new StreamProperties(streamDistribution, streamPartitionSymbols, false); + return new StreamProperties(streamDistribution, streamPartitionSymbols, false, Optional.empty(), streamPropertiesFromUniqueColumn); } private Optional> getNonConstantVariables(Set columnHandles, Map assignments, Set globalConstants) @@ -336,20 +351,32 @@ public StreamProperties visitExchange(ExchangeNode node, List return new StreamProperties(MULTIPLE, Optional.empty(), false); } + Optional additionalUniqueProperty = Optional.empty(); + if (inputProperties.size() == 1 && inputProperties.get(0).hasUniqueProperties() && !node.getType().equals(ExchangeNode.Type.REPLICATE)) { + List inputVariables = node.getInputs().get(0); + Map inputToOutput = new HashMap<>(); + for (int i = 0; i < node.getOutputVariables().size(); i++) { + inputToOutput.put(inputVariables.get(i), node.getOutputVariables().get(i)); + } + checkState(inputProperties.get(0).getStreamPropertiesFromUniqueColumn().isPresent(), + "when unique columns exists, the stream is also partitioned by the unique column and should be represented in the streamPropertiesFromUniqueColumn field"); + additionalUniqueProperty = Optional.of(inputProperties.get(0).getStreamPropertiesFromUniqueColumn().get().translate(column -> Optional.ofNullable(inputToOutput.get(column)))); + } + if (node.getScope().isRemote()) { // TODO: correctly determine if stream is parallelised // based on session properties - return StreamProperties.fixedStreams(); + return StreamProperties.fixedStreams().withStreamPropertiesFromUniqueColumn(additionalUniqueProperty); } switch (node.getType()) { case GATHER: - return StreamProperties.singleStream(); + return 
StreamProperties.singleStream().withStreamPropertiesFromUniqueColumn(additionalUniqueProperty); case REPARTITION: if (node.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION) || // no strict partitioning guarantees when multiple writers per partitions are allows (scaled writers) node.getPartitioningScheme().isScaleWriters()) { - return new StreamProperties(FIXED, Optional.empty(), false); + return new StreamProperties(FIXED, Optional.empty(), false).withStreamPropertiesFromUniqueColumn(additionalUniqueProperty); } checkArgument( node.getPartitioningScheme().getPartitioning().getArguments().stream().allMatch(VariableReferenceExpression.class::isInstance), @@ -358,7 +385,7 @@ public StreamProperties visitExchange(ExchangeNode node, List FIXED, Optional.of(node.getPartitioningScheme().getPartitioning().getArguments().stream() .map(VariableReferenceExpression.class::cast) - .collect(toImmutableList())), false); + .collect(toImmutableList())), false).withStreamPropertiesFromUniqueColumn(additionalUniqueProperty); case REPLICATE: return new StreamProperties(MULTIPLE, Optional.empty(), false); } @@ -447,6 +474,14 @@ public StreamProperties visitDelete(DeleteNode node, List inpu return properties.withUnspecifiedPartitioning(); } + @Override + public StreamProperties visitCallDistributedProcedure(CallDistributedProcedureNode node, List inputProperties) + { + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + // call distributed procedure only outputs the row count + return properties.withUnspecifiedPartitioning(); + } + @Override public StreamProperties visitTableWriter(TableWriterNode node, List inputProperties) { @@ -462,6 +497,27 @@ public StreamProperties visitUpdate(UpdateNode node, List inpu return properties.withUnspecifiedPartitioning(); } + @Override + public StreamProperties visitMergeWriter(MergeWriterNode node, List inputProperties) + { + StreamProperties properties = 
Iterables.getOnlyElement(inputProperties); + return properties.withUnspecifiedPartitioning(); + } + + @Override + public StreamProperties visitMergeProcessor(MergeProcessorNode node, List inputProperties) + { + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + return properties.withUnspecifiedPartitioning(); + } + + @Override + public StreamProperties visitMetadataDelete(MetadataDeleteNode node, List inputProperties) + { + // MetadataDeleteNode runs on coordinator and outputs a single row count + return StreamProperties.singleStream(); + } + @Override public StreamProperties visitTableWriteMerge(TableWriterMergeNode node, List inputProperties) { @@ -555,6 +611,33 @@ public StreamProperties visitWindow(WindowNode node, List inpu return Iterables.getOnlyElement(inputProperties); } + @Override + public StreamProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public StreamProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + if (!node.getSource().isPresent()) { + return StreamProperties.singleStream(); // TODO allow multiple; return partitioning properties + } + + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + + Set passThroughInputs = Sets.intersection(ImmutableSet.copyOf(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()), ImmutableSet.copyOf(node.getOutputVariables())); + StreamProperties translatedProperties = properties.translate(column -> { + if (passThroughInputs.contains(column)) { + return Optional.of(column); + } + return Optional.empty(); + }); + // Mark as unordered since table functions have opaque logic that may reorder, generate, or filter rows + // even though partitioning properties are preserved for pass-through columns + return translatedProperties.unordered(true); + } + 
@Override public StreamProperties visitRowNumber(RowNumberNode node, List inputProperties) { @@ -660,21 +743,33 @@ public enum StreamDistribution // We are only interested in the local properties, but PropertyDerivations requires input // ActualProperties, so we hold on to the whole object - private final ActualProperties otherActualProperties; + private final Optional otherActualProperties; // NOTE: Partitioning on zero columns (or effectively zero columns if the columns are constant) indicates that all // the rows will be partitioned into a single stream. + private final Optional streamPropertiesFromUniqueColumn; + private StreamProperties(StreamDistribution distribution, Optional> partitioningColumns, boolean ordered) { - this(distribution, partitioningColumns, ordered, null); + this(distribution, partitioningColumns, ordered, Optional.empty()); } private StreamProperties( StreamDistribution distribution, Optional> partitioningColumns, boolean ordered, - ActualProperties otherActualProperties) + Optional otherActualProperties) + { + this(distribution, partitioningColumns, ordered, otherActualProperties, Optional.empty()); + } + + private StreamProperties( + StreamDistribution distribution, + Optional> partitioningColumns, + boolean ordered, + Optional otherActualProperties, + Optional streamPropertiesFromUniqueColumn) { this.distribution = requireNonNull(distribution, "distribution is null"); @@ -689,13 +784,38 @@ private StreamProperties( this.ordered = ordered; checkArgument(!ordered || distribution == SINGLE, "Ordered must be a single stream"); - this.otherActualProperties = otherActualProperties; + this.otherActualProperties = requireNonNull(otherActualProperties); + requireNonNull(streamPropertiesFromUniqueColumn).ifPresent(properties -> checkArgument(!properties.streamPropertiesFromUniqueColumn.isPresent())); + // When unique properties exists, the stream is also partitioned on the unique column + if (otherActualProperties.isPresent() && 
otherActualProperties.get().getPropertiesFromUniqueColumn().isPresent()) { + ActualProperties propertiesFromUniqueColumn = otherActualProperties.get().getPropertiesFromUniqueColumn().get(); + if (Iterables.getOnlyElement(propertiesFromUniqueColumn.getLocalProperties()) instanceof UniqueProperty) { + VariableReferenceExpression uniqueVariable = (VariableReferenceExpression) ((UniqueProperty) Iterables.getOnlyElement(propertiesFromUniqueColumn.getLocalProperties())).getColumn(); + checkState(streamPropertiesFromUniqueColumn.isPresent() && streamPropertiesFromUniqueColumn.get().partitioningColumns.isPresent() + && Iterables.getOnlyElement(streamPropertiesFromUniqueColumn.get().partitioningColumns.get()).equals(uniqueVariable)); + } + } + this.streamPropertiesFromUniqueColumn = streamPropertiesFromUniqueColumn; } public List> getLocalProperties() { - checkState(otherActualProperties != null, "otherActualProperties not set"); - return otherActualProperties.getLocalProperties(); + checkState(otherActualProperties.isPresent(), "otherActualProperties not set"); + return otherActualProperties.get().getLocalProperties(); + } + + public List> getAdditionalLocalProperties() + { + checkState(otherActualProperties.isPresent(), "otherActualProperties not set"); + if (!otherActualProperties.get().getPropertiesFromUniqueColumn().isPresent()) { + return ImmutableList.of(); + } + return otherActualProperties.get().getPropertiesFromUniqueColumn().get().getLocalProperties(); + } + + public Optional getStreamPropertiesFromUniqueColumn() + { + return streamPropertiesFromUniqueColumn; } private static StreamProperties singleStream() @@ -717,8 +837,9 @@ private StreamProperties unordered(boolean unordered) { if (unordered) { ActualProperties updatedProperties = null; - if (otherActualProperties != null) { - updatedProperties = ActualProperties.builderFrom(otherActualProperties) + if (otherActualProperties.isPresent()) { + updatedProperties = 
ActualProperties.builderFrom(otherActualProperties.get()) + .propertiesFromUniqueColumn(otherActualProperties.get().getPropertiesFromUniqueColumn().map(x -> ActualProperties.builderFrom(x).unordered(true).build())) .unordered(true) .build(); } @@ -726,11 +847,33 @@ private StreamProperties unordered(boolean unordered) distribution, partitioningColumns, false, - updatedProperties); + Optional.ofNullable(updatedProperties), + streamPropertiesFromUniqueColumn.map(x -> x.unordered(true))); + } + return this; + } + + public StreamProperties uniqueToGroupProperties() + { + if (otherActualProperties.isPresent() && otherActualProperties.get().getPropertiesFromUniqueColumn().isPresent()) { + if (Iterables.getOnlyElement(otherActualProperties.get().getPropertiesFromUniqueColumn().get().getLocalProperties()) instanceof UniqueProperty) { + Optional groupedProperties = PropertyDerivations.uniqueToGroupProperties(otherActualProperties.get().getPropertiesFromUniqueColumn().get()); + return new StreamProperties(distribution, partitioningColumns, ordered, + otherActualProperties.map(x -> ActualProperties.builderFrom(x).propertiesFromUniqueColumn(groupedProperties).build()), + streamPropertiesFromUniqueColumn.map(StreamProperties::uniqueToGroupProperties)); + } } return this; } + public boolean hasUniqueProperties() + { + if (otherActualProperties.isPresent() && otherActualProperties.get().getPropertiesFromUniqueColumn().isPresent()) { + return Iterables.getOnlyElement(otherActualProperties.get().getPropertiesFromUniqueColumn().get().getLocalProperties()) instanceof UniqueProperty; + } + return false; + } + public boolean isSingleStream() { return distribution == SINGLE; @@ -775,7 +918,12 @@ private StreamProperties withUnspecifiedPartitioning() private StreamProperties withOtherActualProperties(ActualProperties actualProperties) { - return new StreamProperties(distribution, partitioningColumns, ordered, actualProperties); + return new StreamProperties(distribution, 
partitioningColumns, ordered, Optional.ofNullable(actualProperties), streamPropertiesFromUniqueColumn); + } + + private StreamProperties withStreamPropertiesFromUniqueColumn(Optional streamPropertiesFromUniqueColumn) + { + return new StreamProperties(distribution, partitioningColumns, ordered, otherActualProperties, streamPropertiesFromUniqueColumn); } public StreamProperties translate(Function> translator) @@ -793,7 +941,8 @@ public StreamProperties translate(Function x.translateVariable(translator)), + streamPropertiesFromUniqueColumn.map(x -> x.translate(translator))); } public Optional> getPartitioningColumns() @@ -804,7 +953,7 @@ public Optional> getPartitioningColumns() @Override public int hashCode() { - return Objects.hash(distribution, partitioningColumns); + return Objects.hash(distribution, partitioningColumns, ordered, otherActualProperties, streamPropertiesFromUniqueColumn); } @Override @@ -818,7 +967,10 @@ public boolean equals(Object obj) } StreamProperties other = (StreamProperties) obj; return Objects.equals(this.distribution, other.distribution) && - Objects.equals(this.partitioningColumns, other.partitioningColumns); + Objects.equals(this.partitioningColumns, other.partitioningColumns) && + this.ordered == other.ordered && + Objects.equals(this.otherActualProperties, other.otherActualProperties) && + Objects.equals(this.streamPropertiesFromUniqueColumn, other.streamPropertiesFromUniqueColumn); } @Override @@ -827,6 +979,9 @@ public String toString() return toStringHelper(this) .add("distribution", distribution) .add("partitioningColumns", partitioningColumns) + .add("ordered", ordered) + .add("otherActualProperties", otherActualProperties) + .add("streamPropertiesFromUniqueColumn", streamPropertiesFromUniqueColumn) .toString(); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 
9805efad17939..824c392ce9619 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -20,6 +20,8 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.ExchangeEncoding; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PartitioningScheme; @@ -36,7 +38,12 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -51,12 +58,14 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.spi.StandardWarningCode.MULTIPLE_ORDER_BY; import static com.facebook.presto.spi.plan.AggregationNode.groupingSets; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation; import static com.facebook.presto.sql.planner.optimizations.PartitioningUtils.translateVariable; +import static 
com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -110,6 +119,13 @@ public VariableReferenceExpression map(VariableReferenceExpression variable) return new VariableReferenceExpression(variable.getSourceLocation(), canonical, types.get(new SymbolReference(getNodeLocation(variable.getSourceLocation()), canonical))); } + public List map(List variableReferenceExpressions) + { + return variableReferenceExpressions.stream() + .map(this::map) + .collect(toImmutableList()); + } + public Expression map(Expression value) { return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() @@ -135,6 +151,27 @@ public RowExpression rewriteVariableReference(VariableReferenceExpression variab }, value); } + public OrderingSchemeWithPreSortedPrefix map(OrderingScheme orderingScheme, int preSorted) + { + ImmutableList.Builder newOrderings = ImmutableList.builder(); + int newPreSorted = preSorted; + + Set added = new HashSet<>(orderingScheme.getOrderBy().size()); + + for (int i = 0; i < orderingScheme.getOrderBy().size(); i++) { + VariableReferenceExpression variable = orderingScheme.getOrderBy().get(i).getVariable(); + VariableReferenceExpression canonical = map(variable); + if (added.add(canonical)) { + newOrderings.add(new Ordering(canonical, orderingScheme.getOrdering(variable))); + } + else if (i < preSorted) { + newPreSorted--; + } + } + + return new OrderingSchemeWithPreSortedPrefix(new OrderingScheme(newOrderings.build()), newPreSorted); + } + public OrderingScheme map(OrderingScheme orderingScheme) { // SymbolMapper inlines symbol with multiple level reference (SymbolInliner only inline single level). 
@@ -262,6 +299,29 @@ public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId new node.getIsTemporaryTableWriter()); } + public CallDistributedProcedureNode map(CallDistributedProcedureNode node, PlanNode source) + { + ImmutableList columns = node.getColumns().stream() + .map(this::map) + .collect(toImmutableList()); + Set notNullColumnVariables = node.getNotNullColumnVariables().stream() + .map(this::map) + .collect(toImmutableSet()); + + return new CallDistributedProcedureNode( + node.getSourceLocation(), + node.getId(), + source, + node.getTarget(), + node.getRowCountVariable(), + node.getFragmentVariable(), + node.getTableCommitContextVariable(), + columns, + columns.stream().map(VariableReferenceExpression::getName).collect(toImmutableList()), + notNullColumnVariables, + node.getPartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source))); + } + public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source) { return new StatisticsWriterNode( @@ -274,6 +334,61 @@ public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source) node.getDescriptor().map(this::map)); } + public MergeWriterNode map(MergeWriterNode node, PlanNode source) + { + // Intentionally does not use mapAndDistinct on columns as that would remove columns + List newOutputs = map(node.getOutputVariables()); + + return new MergeWriterNode( + source.getSourceLocation(), + node.getId(), + source, + node.getTarget(), + map(node.getMergeProcessorProjectedVariables()), + newOutputs); + } + + public MergeWriterNode map(MergeWriterNode node, PlanNode source, PlanNodeId newId) + { + // Intentionally does not use mapAndDistinct on columns as that would remove columns + List newOutputs = map(node.getOutputVariables()); + + return new MergeWriterNode( + source.getSourceLocation(), + newId, + source, + node.getTarget(), + map(node.getMergeProcessorProjectedVariables()), + newOutputs); + } + + public MergeProcessorNode 
map(MergeProcessorNode node, PlanNode source) + { + List newOutputs = map(node.getOutputVariables()); + + return new MergeProcessorNode( + source.getSourceLocation(), + node.getId(), + source, + node.getTarget(), + map(node.getTargetTableRowIdColumnVariable()), + map(node.getMergeRowVariable()), + map(node.getTargetColumnVariables()), + newOutputs); + } + + public PartitioningScheme map(PartitioningScheme scheme, List sourceLayout) + { + return new PartitioningScheme( + translateVariable(scheme.getPartitioning(), this::map), + mapAndDistinctVariable(sourceLayout), + scheme.getHashColumn().map(this::map), + scheme.isReplicateNullsAndAny(), + scheme.isScaleWriters(), + ExchangeEncoding.COLUMNAR, + scheme.getBucketToPartition()); + } + public TableFinishNode map(TableFinishNode node, PlanNode source) { return new TableFinishNode( @@ -299,6 +414,68 @@ public TableWriterMergeNode map(TableWriterMergeNode node, PlanNode source) node.getStatisticsAggregation().map(this::map)); } + public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode source) + { + // rewrite and deduplicate pass-through specifications + // note: Potentially, pass-through symbols from different sources might be recognized as semantically identical, and rewritten + // to the same symbol. Currently, we retrieve the first occurrence of a symbol, and skip all the following occurrences. + // For better performance, we could pick the occurrence with "isPartitioningColumn" property, since the pass-through mechanism + // is more efficient for partitioning columns which are guaranteed to be constant within partition. 
+ // TODO choose a partitioning column to be retrieved while deduplicating + ImmutableList.Builder newPassThroughSpecifications = ImmutableList.builder(); + Set newPassThroughVariables = new HashSet<>(); + for (TableFunctionNode.PassThroughSpecification specification : node.getPassThroughSpecifications()) { + ImmutableList.Builder newColumns = ImmutableList.builder(); + for (TableFunctionNode.PassThroughColumn column : specification.getColumns()) { + VariableReferenceExpression newVariable = map(column.getOutputVariables()); + if (newPassThroughVariables.add(newVariable)) { + newColumns.add(new TableFunctionNode.PassThroughColumn(newVariable, column.isPartitioningColumn())); + } + } + newPassThroughSpecifications.add(new TableFunctionNode.PassThroughSpecification(specification.isDeclaredAsPassThrough(), newColumns.build())); + } + + // rewrite required symbols without deduplication. the table function expects specific input layout + List> newRequiredVariables = node.getRequiredVariables().stream() + .map(list -> list.stream() + .map(this::map) + .collect(toImmutableList())) + .collect(toImmutableList()); + + // rewrite and deduplicate marker mapping + Optional> newMarkerVariables = node.getMarkerVariables() + .map(mapping -> mapping.entrySet().stream() + .collect(toImmutableMap( + entry -> map(entry.getKey()), + entry -> map(entry.getValue()), + (first, second) -> { + checkState(first.equals(second), "Ambiguous marker symbols: %s and %s", first, second); + return first; + }))); + + // rewrite and deduplicate specification + Optional newSpecification = node.getSpecification().map(specification -> mapAndDistinct(specification, node.getPreSorted())); + + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs().stream() + .map(this::map) + .collect(toImmutableList()), + Optional.of(source), + node.isPruneWhenEmpty(), + newPassThroughSpecifications.build(), + newRequiredVariables, + newMarkerVariables, + 
newSpecification.map(SpecificationWithPreSortedPrefix::getSpecification), + node.getPrePartitioned().stream() + .map(this::map) + .collect(toImmutableSet()), + newSpecification.map(SpecificationWithPreSortedPrefix::getPreSorted).orElse(node.getPreSorted()), + node.getHashSymbol().map(this::map), + node.getHandle()); + } + private PartitioningScheme canonicalize(PartitioningScheme scheme, PlanNode source) { return new PartitioningScheme(translateVariable(scheme.getPartitioning(), this::map), @@ -348,6 +525,25 @@ private List mapAndDistinctVariable(List newOrderingScheme = specification.getOrderingScheme() + .map(orderingScheme -> map(orderingScheme, preSorted)); + + return new SpecificationWithPreSortedPrefix( + new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getOrderingScheme)), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getPreSorted).orElse(preSorted)); + } + + DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) + { + return new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + specification.getOrderingScheme().map(this::map)); + } + public static SymbolMapper.Builder builder(WarningCollector warningCollector) { return new Builder(warningCollector); @@ -379,4 +575,48 @@ public void put(VariableReferenceExpression from, VariableReferenceExpression to mappingsBuilder.put(from, to); } } + + private static class OrderingSchemeWithPreSortedPrefix + { + private final OrderingScheme orderingScheme; + private final int preSorted; + + public OrderingSchemeWithPreSortedPrefix(OrderingScheme orderingScheme, int preSorted) + { + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + this.preSorted = preSorted; + } + + public OrderingScheme getOrderingScheme() + { + return orderingScheme; + } + + public int getPreSorted() + { + return preSorted; + } + } + 
+ private static class SpecificationWithPreSortedPrefix + { + private final DataOrganizationSpecification specification; + private final int preSorted; + + public SpecificationWithPreSortedPrefix(DataOrganizationSpecification specification, int preSorted) + { + this.specification = requireNonNull(specification, "specification is null"); + this.preSorted = preSorted; + } + + public DataOrganizationSpecification getSpecification() + { + return specification; + } + + public int getPreSorted() + { + return preSorted; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 51fa468a822a8..1c5e2e506f282 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -29,11 +29,13 @@ import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.OutputNode; @@ -50,6 +52,7 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import 
com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.CallExpression; @@ -59,12 +62,14 @@ import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.OffsetNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; @@ -72,15 +77,17 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import 
com.google.common.collect.Lists; @@ -156,6 +163,11 @@ private Rewriter(TypeProvider types, FunctionAndTypeManager functionAndTypeManag this.warningCollector = warningCollector; } + public Map getMapping() + { + return mapping; + } + @Override public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { @@ -460,6 +472,22 @@ public PlanNode visitUpdate(UpdateNode node, RewriteContext context) return new UpdateNode(node.getSourceLocation(), node.getId(), node.getSource(), canonicalize(node.getRowId()), node.getColumnValueAndRowIdSymbols(), node.getOutputVariables()); } + @Override + public PlanNode visitMergeWriter(MergeWriterNode node, RewriteContext context) + { + PlanNode source = context.rewrite(node.getSource()); + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + return mapper.map(node, source); + } + + @Override + public PlanNode visitMergeProcessor(MergeProcessorNode node, RewriteContext context) + { + PlanNode source = context.rewrite(node.getSource()); + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + return mapper.map(node, source); + } + @Override public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, RewriteContext context) { @@ -476,6 +504,81 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont return mapper.map(node, source); } + @Override + public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) + { + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + + List newProperOutputs = node.getOutputVariables().stream() + .map(mapper::map) + .collect(toImmutableList()); + + ImmutableList.Builder newSources = ImmutableList.builder(); + ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); + + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode newSource = context.rewrite(node.getSources().get(i)); + newSources.add(newSource); + + // Use the mapping state from after 
processing the source for the input properties + SymbolMapper inputMapper = new SymbolMapper(mapping, types, warningCollector); + + TableFunctionNode.TableArgumentProperties properties = node.getTableArgumentProperties().get(i); + + Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); + TableFunctionNode.PassThroughSpecification newPassThroughSpecification = new TableFunctionNode.PassThroughSpecification( + properties.getPassThroughSpecification().isDeclaredAsPassThrough(), + properties.getPassThroughSpecification().getColumns().stream() + .map(column -> new TableFunctionNode.PassThroughColumn( + inputMapper.map(column.getOutputVariables()), + column.isPartitioningColumn())) + .collect(toImmutableList())); + newTableArgumentProperties.add(new TableFunctionNode.TableArgumentProperties( + properties.getArgumentName(), + properties.isRowSemantics(), + properties.isPruneWhenEmpty(), + newPassThroughSpecification, + inputMapper.map(properties.getRequiredColumns()), + newSpecification)); + } + + return new TableFunctionNode( + node.getId(), + node.getName(), + node.getArguments(), + newProperOutputs, + newSources.build(), + newTableArgumentProperties.build(), + node.getCopartitioningLists(), + node.getHandle()); + } + + @Override + public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + mapper.map(node.getProperOutputs()), + Optional.empty(), + node.isPruneWhenEmpty(), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + node.getHashSymbol().map(mapper::map), + node.getHandle()); + } + PlanNode rewrittenSource = context.rewrite(node.getSource().get()); + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + + return 
mapper.map(node, rewrittenSource); + } + @Override public PlanNode visitRowNumber(RowNumberNode node, RewriteContext context) { @@ -636,7 +739,20 @@ public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext cont PlanNode left = context.rewrite(node.getLeft()); PlanNode right = context.rewrite(node.getRight()); - return new SpatialJoinNode(node.getSourceLocation(), node.getId(), node.getType(), left, right, canonicalizeAndDistinct(node.getOutputVariables()), canonicalize(node.getFilter()), canonicalize(node.getLeftPartitionVariable()), canonicalize(node.getRightPartitionVariable()), node.getKdbTree()); + return new SpatialJoinNode( + node.getSourceLocation(), + node.getId(), + node.getType(), + left, + right, + canonicalizeAndDistinct(node.getOutputVariables()), + canonicalize(node.getProbeGeometryVariable()), + canonicalize(node.getBuildGeometryVariable()), + canonicalize(node.getRadiusVariable()), + canonicalize(node.getFilter()), + canonicalize(node.getLeftPartitionVariable()), + canonicalize(node.getRightPartitionVariable()), + node.getKdbTree()); } @Override @@ -660,7 +776,8 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) canonicalizeIndexJoinCriteria(node.getCriteria()), node.getFilter().map(this::canonicalize), canonicalize(node.getProbeHashVariable()), - canonicalize(node.getIndexHashVariable())); + canonicalize(node.getIndexHashVariable()), + node.getLookupVariables()); } @Override @@ -690,6 +807,14 @@ private static ImmutableList.Builder rewriteSources(SetOperationNode n return rewrittenSources; } + @Override + public PlanNode visitCallDistributedProcedure(CallDistributedProcedureNode node, RewriteContext context) + { + PlanNode source = context.rewrite(node.getSource()); + SymbolMapper mapper = new SymbolMapper(mapping, types, warningCollector); + return mapper.map(node, source); + } + @Override public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) { @@ -706,6 +831,13 @@ public PlanNode 
visitTableWriteMerge(TableWriterMergeNode node, RewriteContext context) + { + // MetadataDeleteNode has no symbols to unalias, so return unchanged + return node; + } + @Override public PlanNode visitPlan(PlanNode node, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java index 5d168fa6c7aeb..d2e67ca86d858 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ApplyNode.java @@ -21,8 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/CallDistributedProcedureNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/CallDistributedProcedureNode.java new file mode 100644 index 0000000000000..b07ec4ca4c2e6 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/CallDistributedProcedureNode.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.plan.PartitioningScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class CallDistributedProcedureNode + extends InternalPlanNode +{ + private final PlanNode source; + private final Optional target; + private final VariableReferenceExpression rowCountVariable; + private final VariableReferenceExpression fragmentVariable; + private final VariableReferenceExpression tableCommitContextVariable; + private final List columns; + private final List columnNames; + private final Set notNullColumnVariables; + private final Optional partitioningScheme; + private final List outputs; + + @JsonCreator + public CallDistributedProcedureNode( + Optional sourceLocation, + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("target") Optional target, + @JsonProperty("rowCountVariable") VariableReferenceExpression rowCountVariable, + @JsonProperty("fragmentVariable") VariableReferenceExpression fragmentVariable, + @JsonProperty("tableCommitContextVariable") VariableReferenceExpression tableCommitContextVariable, + @JsonProperty("columns") List columns, + @JsonProperty("columnNames") List columnNames, + 
@JsonProperty("notNullColumnVariables") Set notNullColumnVariables, + @JsonProperty("partitioningScheme") Optional partitioningScheme) + { + this(sourceLocation, id, Optional.empty(), source, target, rowCountVariable, fragmentVariable, tableCommitContextVariable, columns, columnNames, notNullColumnVariables, partitioningScheme); + } + + public CallDistributedProcedureNode( + Optional sourceLocation, + PlanNodeId id, + Optional statsEquivalentPlanNode, + PlanNode source, + Optional target, + VariableReferenceExpression rowCountVariable, + VariableReferenceExpression fragmentVariable, + VariableReferenceExpression tableCommitContextVariable, + List columns, + List columnNames, + Set notNullColumnVariables, + Optional partitioningScheme) + { + super(sourceLocation, id, statsEquivalentPlanNode); + + requireNonNull(columns, "columns is null"); + requireNonNull(columnNames, "columnNames is null"); + checkArgument(columns.size() == columnNames.size(), "columns and columnNames sizes don't match"); + + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + this.rowCountVariable = requireNonNull(rowCountVariable, "rowCountVariable is null"); + this.fragmentVariable = requireNonNull(fragmentVariable, "fragmentVariable is null"); + this.tableCommitContextVariable = requireNonNull(tableCommitContextVariable, "tableCommitContextVariable is null"); + this.columns = ImmutableList.copyOf(columns); + this.columnNames = ImmutableList.copyOf(columnNames); + this.notNullColumnVariables = ImmutableSet.copyOf(requireNonNull(notNullColumnVariables, "notNullColumns is null")); + this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null"); + + ImmutableList.Builder outputs = ImmutableList.builder() + .add(rowCountVariable) + .add(fragmentVariable) + .add(tableCommitContextVariable); + this.outputs = outputs.build(); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + 
@JsonIgnore + public Optional getTarget() + { + return target; + } + + @JsonProperty + public VariableReferenceExpression getRowCountVariable() + { + return rowCountVariable; + } + + @JsonProperty + public VariableReferenceExpression getFragmentVariable() + { + return fragmentVariable; + } + + @JsonProperty + public VariableReferenceExpression getTableCommitContextVariable() + { + return tableCommitContextVariable; + } + + @JsonProperty + public Optional getPartitioningScheme() + { + return partitioningScheme; + } + + @JsonProperty + public List getColumns() + { + return columns; + } + + @JsonProperty + public List getColumnNames() + { + return columnNames; + } + + @JsonProperty + public Set getNotNullColumnVariables() + { + return notNullColumnVariables; + } + + @JsonProperty + public List getOutputs() + { + return outputs; + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @Override + public List getOutputVariables() + { + return outputs; + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitCallDistributedProcedure(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new CallDistributedProcedureNode( + this.getSourceLocation(), + getId(), + this.getStatsEquivalentPlanNode(), + Iterables.getOnlyElement(newChildren), + target, + rowCountVariable, + fragmentVariable, + tableCommitContextVariable, + columns, + columnNames, + notNullColumnVariables, + partitioningScheme); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new CallDistributedProcedureNode( + this.getSourceLocation(), + getId(), + statsEquivalentPlanNode, + source, + target, + rowCountVariable, + fragmentVariable, + tableCommitContextVariable, + columns, + columnNames, + notNullColumnVariables, + partitioningScheme); + } +} diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java index bcf0c3f13fd2a..33318f3cd0ffd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/EnforceSingleRowNode.java @@ -21,8 +21,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java index ed2b2307dccbe..311076e818315 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExchangeNode.java @@ -21,11 +21,13 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.SystemPartitioningHandle; +import com.facebook.presto.sql.planner.SystemPartitioningHandle.SystemPartitionFunction; +import com.facebook.presto.sql.planner.SystemPartitioningHandle.SystemPartitioning; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; @@ -146,7 +148,6 @@ public ExchangeNode( orderingScheme.ifPresent(ordering -> { PartitioningHandle 
partitioningHandle = partitioningScheme.getPartitioning().getHandle(); - checkArgument(!scope.isRemote() || partitioningHandle.equals(SINGLE_DISTRIBUTION), "remote merging exchange requires single distribution"); checkArgument(!scope.isLocal() || partitioningHandle.equals(FIXED_PASSTHROUGH_DISTRIBUTION), "local merging exchange requires passthrough distribution"); checkArgument(partitioningScheme.getOutputLayout().containsAll(ordering.getOrderByVariables()), "Partitioning scheme does not supply all required ordering symbols"); }); @@ -261,6 +262,16 @@ public static ExchangeNode roundRobinExchange(PlanNodeId id, Scope scope, PlanNo new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), child.getOutputVariables())); } + public static ExchangeNode roundRobinExchange(PlanNodeId id, Scope scope, PlanNode child, int partitionCount) + { + checkArgument(partitionCount > 0, "partitionCount must be positive"); + return partitionedExchange( + id, + scope, + child, + new PartitioningScheme(Partitioning.create(SystemPartitioningHandle.createSystemPartitioning(SystemPartitioning.FIXED, SystemPartitionFunction.ROUND_ROBIN, partitionCount), ImmutableList.of()), child.getOutputVariables())); + } + public static ExchangeNode mergingExchange(PlanNodeId id, Scope scope, PlanNode child, OrderingScheme orderingScheme) { PartitioningHandle partitioningHandle = scope.isLocal() ? FIXED_PASSTHROUGH_DISTRIBUTION : SINGLE_DISTRIBUTION; @@ -276,6 +287,24 @@ public static ExchangeNode mergingExchange(PlanNodeId id, Scope scope, PlanNode Optional.of(orderingScheme)); } + /** + * Creates an exchange node that performs sorting during the shuffle operation. + * This is used for merge joins where we want to push down sorting to the exchange layer. 
+ */ + public static ExchangeNode sortedPartitionedExchange(PlanNodeId id, Scope scope, PlanNode child, Partitioning partitioning, Optional hashColumn, OrderingScheme sortOrder) + { + return new ExchangeNode( + child.getSourceLocation(), + id, + REPARTITION, + scope, + new PartitioningScheme(partitioning, child.getOutputVariables(), hashColumn, false, false, COLUMNAR, Optional.empty()), + ImmutableList.of(child), + ImmutableList.of(child.getOutputVariables()), + true, // Ensure source ordering since we're sorting + Optional.of(sortOrder)); + } + @JsonProperty public Type getType() { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java index 26c3260154e01..a438fb48c3079 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/ExplainAnalyzeNode.java @@ -22,8 +22,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java index b77a1fd2e7194..0038b80d6da77 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/GroupIdNode.java @@ -24,8 +24,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Sets; - -import javax.annotation.concurrent.Immutable; +import 
com.google.errorprone.annotations.Immutable; import java.util.Collection; import java.util.HashSet; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java index 43608a98ec271..6bfa05d32a9a8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.plan; +import com.facebook.presto.spi.plan.MergeJoinNode; import com.facebook.presto.spi.plan.PlanVisitor; import com.facebook.presto.sql.planner.CanonicalJoinNode; import com.facebook.presto.sql.planner.CanonicalTableScanNode; @@ -37,22 +38,27 @@ public R visitExplainAnalyze(ExplainAnalyzeNode node, C context) return visitPlan(node, context); } - public R visitIndexJoin(IndexJoinNode node, C context) + public R visitMergeJoin(MergeJoinNode node, C context) { return visitPlan(node, context); } - public R visitOffset(OffsetNode node, C context) + public R visitMergeWriter(MergeWriterNode node, C context) { return visitPlan(node, context); } - public R visitTableWriteMerge(TableWriterMergeNode node, C context) + public R visitMergeProcessor(MergeProcessorNode node, C context) + { + return visitPlan(node, context); + } + + public R visitOffset(OffsetNode node, C context) { return visitPlan(node, context); } - public R visitMetadataDelete(MetadataDeleteNode node, C context) + public R visitTableWriteMerge(TableWriterMergeNode node, C context) { return visitPlan(node, context); } @@ -67,7 +73,7 @@ public R visitStatisticsWriterNode(StatisticsWriterNode node, C context) return visitPlan(node, context); } - public R visitUnnest(UnnestNode node, C context) + public R visitCallDistributedProcedure(CallDistributedProcedureNode node, C context) { return visitPlan(node, context); } @@ 
-136,4 +142,14 @@ public R visitSequence(SequenceNode node, C context) { return visitPlan(node, context); } + + public R visitTableFunction(TableFunctionNode node, C context) + { + return visitPlan(node, context); + } + + public R visitTableFunctionProcessor(TableFunctionProcessorNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java index c9105843adbee..6163d6737d93c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/LateralJoinNode.java @@ -21,8 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/MergeProcessorNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/MergeProcessorNode.java new file mode 100644 index 0000000000000..8f4352a4a8bb2 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/MergeProcessorNode.java @@ -0,0 +1,142 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.plan.TableWriterNode.MergeTarget; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; + +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +/** + * The node processes the result of the Searched CASE and RIGHT JOIN + * derived from a MERGE statement. + */ +public class MergeProcessorNode + extends InternalPlanNode +{ + private final PlanNode source; + private final MergeTarget target; + private final VariableReferenceExpression targetTableRowIdColumnVariable; + private final VariableReferenceExpression mergeRowVariable; + private final List targetColumnVariables; + private final List outputs; + + @JsonCreator + public MergeProcessorNode( + Optional sourceLocation, + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("target") MergeTarget target, + @JsonProperty("targetTableRowIdColumnVariable") VariableReferenceExpression targetTableRowIdColumnVariable, + @JsonProperty("mergeRowVariable") VariableReferenceExpression mergeRowVariable, + @JsonProperty("targetColumnVariables") List targetColumnVariables, + @JsonProperty("outputs") List outputs) + { + this(sourceLocation, id, Optional.empty(), source, target, targetTableRowIdColumnVariable, mergeRowVariable, targetColumnVariables, outputs); + } + + public MergeProcessorNode( + Optional sourceLocation, + PlanNodeId id, + Optional statsEquivalentPlanNode, + PlanNode source, + 
MergeTarget target, + VariableReferenceExpression targetTableRowIdColumnVariable, + VariableReferenceExpression mergeRowVariable, + List targetColumnVariables, + List outputs) + { + super(sourceLocation, id, statsEquivalentPlanNode); + + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + this.mergeRowVariable = requireNonNull(mergeRowVariable, "mergeRowVariable is null"); + this.targetTableRowIdColumnVariable = requireNonNull(targetTableRowIdColumnVariable, "targetTableRowIdColumnVariable is null"); + this.targetColumnVariables = requireNonNull(targetColumnVariables, "targetColumnVariables is null"); + this.outputs = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @JsonProperty + public MergeTarget getTarget() + { + return target; + } + + @JsonProperty + public VariableReferenceExpression getMergeRowVariable() + { + return mergeRowVariable; + } + + @JsonProperty + public VariableReferenceExpression getTargetTableRowIdColumnVariable() + { + return targetTableRowIdColumnVariable; + } + + @JsonProperty + public List getTargetColumnVariables() + { + return targetColumnVariables; + } + + @JsonProperty("outputs") + @Override + public List getOutputVariables() + { + return outputs; + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitMergeProcessor(this, context); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new MergeProcessorNode(getSourceLocation(), getId(), Iterables.getOnlyElement(newChildren), + target, targetTableRowIdColumnVariable, mergeRowVariable, targetColumnVariables, outputs); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new MergeProcessorNode(getSourceLocation(), 
getId(), statsEquivalentPlanNode, source, target, + targetTableRowIdColumnVariable, mergeRowVariable, targetColumnVariables, outputs); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/MergeWriterNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/MergeWriterNode.java new file mode 100644 index 0000000000000..73580e7a546cb --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/MergeWriterNode.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.plan.TableWriterNode.MergeTarget; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.errorprone.annotations.Immutable; + +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +@Immutable +public class MergeWriterNode + extends InternalPlanNode +{ + private final PlanNode source; + private final MergeTarget target; + private final List mergeProcessorProjectedVariables; + private final List outputs; + + @JsonCreator + public MergeWriterNode( + Optional sourceLocation, + @JsonProperty("id") PlanNodeId id, + @JsonProperty("source") PlanNode source, + @JsonProperty("target") MergeTarget target, + @JsonProperty("mergeProcessorProjectedVariables") List mergeProcessorProjectedVariables, + @JsonProperty("outputs") List outputs) + { + this(sourceLocation, id, Optional.empty(), source, target, mergeProcessorProjectedVariables, outputs); + } + + public MergeWriterNode( + Optional sourceLocation, + PlanNodeId id, + Optional statsEquivalentPlanNode, + PlanNode source, + MergeTarget target, + List mergeProcessorProjectedVariables, + List outputs) + { + super(sourceLocation, id, statsEquivalentPlanNode); + + this.source = requireNonNull(source, "source is null"); + this.target = requireNonNull(target, "target is null"); + this.mergeProcessorProjectedVariables = requireNonNull(mergeProcessorProjectedVariables, "mergeProcessorProjectedVariables is null"); + this.outputs = ImmutableList.copyOf(requireNonNull(outputs, "outputs is null")); + } + + @JsonProperty + public 
PlanNode getSource() + { + return source; + } + + @JsonProperty + public MergeTarget getTarget() + { + return target; + } + + @JsonProperty + public List getMergeProcessorProjectedVariables() + { + return mergeProcessorProjectedVariables; + } + + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @Override + @JsonProperty("outputs") + public List getOutputVariables() + { + return outputs; + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitMergeWriter(this, context); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new MergeWriterNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, target, + mergeProcessorProjectedVariables, outputs); + } + + @Override + public PlanNode replaceChildren(List newChildren) + { + return new MergeWriterNode(getSourceLocation(), getId(), Iterables.getOnlyElement(newChildren), + target, mergeProcessorProjectedVariables, outputs); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/OffsetNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/OffsetNode.java index b785078b1497c..2d9a3060d27a1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/OffsetNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/OffsetNode.java @@ -21,8 +21,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java index 21a54b9b325a6..1db8a8b817eb5 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; +import com.facebook.presto.spi.plan.MaterializedViewScanNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; @@ -38,6 +39,7 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.RowExpression; @@ -133,6 +135,11 @@ public static Pattern markDistinct() return typeOf(MarkDistinctNode.class); } + public static Pattern materializedViewScan() + { + return typeOf(MaterializedViewScanNode.class); + } + public static Pattern output() { return typeOf(OutputNode.class); @@ -203,6 +210,11 @@ public static Pattern tableWriterMergeNode() return typeOf(TableWriterMergeNode.class); } + public static Pattern mergeWriter() + { + return typeOf(MergeWriterNode.class); + } + public static Pattern topN() { return typeOf(TopNNode.class); @@ -228,6 +240,16 @@ public static Pattern window() return typeOf(WindowNode.class); } + public static Pattern tableFunction() + { + return typeOf(TableFunctionNode.class); + } + + public static Pattern tableFunctionProcessor() + { + return typeOf(TableFunctionProcessorNode.class); + } + public static Pattern rowNumber() { return typeOf(RowNumberNode.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java 
index f5e980a4470b7..b9de1a0673733 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/RemoteSourceNode.java @@ -23,8 +23,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java index 29b8781c7ad65..ec97ad42ee26c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/RowNumberNode.java @@ -21,8 +21,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java index 636754a57e372..ca75ac588d083 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SampleNode.java @@ -21,8 +21,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import 
java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java index 22d4f18e42ff9..f87c1a1bba5c5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java @@ -61,6 +61,11 @@ public C get() return userContext; } + public SimplePlanRewriter getNodeRewriter() + { + return nodeRewriter; + } + /** * Invoke the rewrite logic recursively on children of the given node and swap it * out with an identical copy with the rewritten children diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java new file mode 100644 index 0000000000000..8838e82b48c91 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java @@ -0,0 +1,289 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.metadata.TableFunctionHandle; +import com.facebook.presto.spi.SourceLocation; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.Immutable; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +@Immutable +public class TableFunctionNode + extends InternalPlanNode +{ + private final String name; + private final Map arguments; + private final List outputVariables; + private final List sources; + private final List tableArgumentProperties; + private final List> copartitioningLists; + private final TableFunctionHandle handle; + + @JsonCreator + public TableFunctionNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("name") String name, + @JsonProperty("arguments") Map arguments, + @JsonProperty("outputVariables") List outputVariables, + @JsonProperty("sources") List sources, + @JsonProperty("tableArgumentProperties") List tableArgumentProperties, + @JsonProperty("copartitioningLists") List> copartitioningLists, + @JsonProperty("handle") TableFunctionHandle handle) + { + this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); + } + + public TableFunctionNode( + Optional 
sourceLocation, + PlanNodeId id, + Optional statsEquivalentPlanNode, + String name, + Map arguments, + List outputVariables, + List sources, + List tableArgumentProperties, + List> copartitioningLists, + TableFunctionHandle handle) + { + super(sourceLocation, id, statsEquivalentPlanNode); + this.name = requireNonNull(name, "name is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.outputVariables = ImmutableList.copyOf(outputVariables); + this.sources = ImmutableList.copyOf(sources); + this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public Map getArguments() + { + return arguments; + } + + @Override + public List getOutputVariables() + { + ImmutableList.Builder variables = ImmutableList.builder(); + variables.addAll(outputVariables); + + tableArgumentProperties.stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + .forEach(variables::add); + + return variables.build(); + } + + public List getProperOutputs() + { + return outputVariables; + } + + @JsonProperty + public List getTableArgumentProperties() + { + return tableArgumentProperties; + } + + @JsonProperty + public List> getCopartitioningLists() + { + return copartitioningLists; + } + + @JsonProperty + public TableFunctionHandle getHandle() + { + return handle; + } + + @JsonProperty + @Override + public List getSources() + { + return sources; + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitTableFunction(this, context); + } + + @Override + public PlanNode replaceChildren(List newSources) + { + 
checkArgument(sources.size() == newSources.size(), "wrong number of new children"); + return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, copartitioningLists, handle); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); + } + + public static class TableArgumentProperties + { + private final String argumentName; + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final PassThroughSpecification passThroughSpecification; + private final List requiredColumns; + private final Optional specification; + + @JsonCreator + public TableArgumentProperties( + @JsonProperty("argumentName") String argumentName, + @JsonProperty("rowSemantics") boolean rowSemantics, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughSpecification") PassThroughSpecification passThroughSpecification, + @JsonProperty("requiredColumns") List requiredColumns, + @JsonProperty("specification") Optional specification) + { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughSpecification = requireNonNull(passThroughSpecification, "passThroughSpecification is null"); + this.requiredColumns = ImmutableList.copyOf(requiredColumns); + this.specification = requireNonNull(specification, "specification is null"); + } + + @JsonProperty + public String getArgumentName() + { + return argumentName; + } + + @JsonProperty + public boolean isRowSemantics() + { + return rowSemantics; + } + + @JsonProperty + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + @JsonProperty + public PassThroughSpecification 
getPassThroughSpecification() + { + return passThroughSpecification; + } + + @JsonProperty + public List getRequiredColumns() + { + return requiredColumns; + } + + @JsonProperty + public Optional getSpecification() + { + return specification; + } + } + + /** + * Specifies how columns from source tables are passed through to the output of a table function. + * This class manages both explicitly declared pass-through columns and partitioning columns + * that must be preserved in the output. + */ + public static class PassThroughSpecification + { + private final boolean declaredAsPassThrough; + private final List columns; + + @JsonCreator + public PassThroughSpecification( + @JsonProperty("declaredAsPassThrough") boolean declaredAsPassThrough, + @JsonProperty("columns") List columns) + { + this.declaredAsPassThrough = declaredAsPassThrough; + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + checkArgument( + declaredAsPassThrough || this.columns.stream().allMatch(PassThroughColumn::isPartitioningColumn), + "non-partitioning pass-through column for non-pass-through source of a table function"); + } + + @JsonProperty + public boolean isDeclaredAsPassThrough() + { + return declaredAsPassThrough; + } + + @JsonProperty + public List getColumns() + { + return columns; + } + } + + public static class PassThroughColumn + { + private final VariableReferenceExpression outputVariables; + private final boolean isPartitioningColumn; + + @JsonCreator + public PassThroughColumn( + @JsonProperty("outputVariables") VariableReferenceExpression outputVariables, + @JsonProperty("partitioningColumn") boolean isPartitioningColumn) + { + this.outputVariables = requireNonNull(outputVariables, "symbol is null"); + this.isPartitioningColumn = isPartitioningColumn; + } + + @JsonProperty + public VariableReferenceExpression getOutputVariables() + { + return outputVariables; + } + + @JsonProperty + public boolean isPartitioningColumn() + { + return 
isPartitioningColumn; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java new file mode 100644 index 0000000000000..851f776a2c90f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java @@ -0,0 +1,286 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.metadata.TableFunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorNode + extends InternalPlanNode +{ + private final String name; + + // symbols produced by the function + private final List properOutputs; + + // pre-planned sources + private final Optional source; + // TODO do we need the info of which source has row semantics, or is it already included in the joins / join distribution? 
+ + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + + // all source symbols to be produced on output, ordered as table argument specifications + private final List passThroughSpecifications; + + // symbols required from each source, ordered as table argument specifications + private final List> requiredVariables; + + // mapping from source symbol to helper "marker" symbol which indicates whether the source value is valid + // for processing or for pass-through. null value in the marker column indicates that the value at the same + // position in the source column should not be processed or passed-through. + // the mapping is only present if there are two or more sources. + // + // Example: + // Given two input tables T1(a,b) PARTITION BY a and T2(c, d) PARTITION BY c + // T1 partitions: T2 partitions: + // a | b c | d + // ---+--- ---+--- + // 1 | 10 5 | 50 + // 1 | 20 5 | 60 + // 1 | 30 6 | 90 + // 2 | 40 6 | 100 + // 2 | 50 6 | 110 + // + // TransformTableFunctionToTableFunctionProcessor creates a join that produces a cartesian product of partitions from each table, resulting in 4 partitions: + // + // Partition (a=1, c=5): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 1 | 10 | 1 | 5 | 50 | 1 (row 1 from both partitions) + // 1 | 20 | 2 | 5 | 60 | 2 (row 2 from both partitions) + // 1 | 30 | 3 | 5 | 50 | null (filler row for T2, real row 3 from T1) + // + // Partition (a=1, c=6): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 1 | 10 | 1 | 6 | 90 | 1 (row 1 from both partitions) + // 1 | 20 | 2 | 6 | 100 | 2 (row 2 from both partitions) + // 1 | 30 | 3 | 6 | 110 | 3 (row 3 from both partitions) + // + // Partition (a=2, c=5): + // a | b | marker_1 | c | d | marker_2 + // 
-----+------+----------+----+-----+---------- + // 2 | 40 | 1 | 5 | 50 | 1 (row 1 from both partitions) + // 2 | 50 | 2 | 5 | 60 | 2 (row 2 from both partitions) + // + // Partition (a=2, c=6): + // a | b | marker_1 | c | d | marker_2 + // -----+------+----------+----+-----+---------- + // 2 | 40 | 1 | 6 | 90 | 1 (row 1 from both partitions) + // 2 | 50 | 2 | 6 | 100 | 2 (row 2 from both partitions) + // 2 | 40 | null | 6 | 110 | 3 (filler row for T1, real row 3 from T2) + // + // markerVariables map: + // { + // VariableReferenceExpression(a) -> VariableReferenceExpression(marker_1), + // VariableReferenceExpression(b) -> VariableReferenceExpression(marker_1), + // VariableReferenceExpression(c) -> VariableReferenceExpression(marker_2), + // VariableReferenceExpression(d) -> VariableReferenceExpression(marker_2) + // } + // + // When marker_1 is null, columns a and b should not be processed or passed-through. + // When marker_2 is null, columns c and d should not be processed or passed-through. 
+ + private final Optional> markerVariables; + + private final Optional specification; + private final Set prePartitioned; + private final int preSorted; + private final Optional hashSymbol; + + private final TableFunctionHandle handle; + + @JsonCreator + public TableFunctionProcessorNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("name") String name, + @JsonProperty("properOutputs") List properOutputs, + @JsonProperty("source") Optional source, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughSpecifications") List passThroughSpecifications, + @JsonProperty("requiredVariables") List> requiredVariables, + @JsonProperty("markerVariables") Optional> markerVariables, + @JsonProperty("specification") Optional specification, + @JsonProperty("prePartitioned") Set prePartitioned, + @JsonProperty("preSorted") int preSorted, + @JsonProperty("hashSymbol") Optional hashSymbol, + @JsonProperty("handle") TableFunctionHandle handle) + { + super(Optional.empty(), id, Optional.empty()); + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.source = requireNonNull(source, "source is null"); + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + this.requiredVariables = requiredVariables.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerVariables = markerVariables.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.prePartitioned = ImmutableSet.copyOf(prePartitioned); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + checkArgument(partitionBy.containsAll(prePartitioned), "all pre-partitioned symbols must be contained in the partitioning list"); + this.preSorted = preSorted; + checkArgument( + specification + 
.flatMap(DataOrganizationSpecification::getOrderingScheme) + .map(OrderingScheme::getOrderBy) + .map(List::size) + .orElse(0) >= preSorted, + "the number of pre-sorted symbols cannot be greater than the number of all ordering symbols"); + checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public List getProperOutputs() + { + return properOutputs; + } + + @JsonProperty + public Optional getSource() + { + return source; + } + + @JsonProperty + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + @JsonProperty + public List getPassThroughSpecifications() + { + return passThroughSpecifications; + } + + @JsonProperty + public List> getRequiredVariables() + { + return requiredVariables; + } + + @JsonProperty + public Optional> getMarkerVariables() + { + return markerVariables; + } + + @JsonProperty + public Optional getSpecification() + { + return specification; + } + + @JsonProperty + public Set getPrePartitioned() + { + return prePartitioned; + } + + @JsonProperty + public int getPreSorted() + { + return preSorted; + } + + @JsonProperty + public Optional getHashSymbol() + { + return hashSymbol; + } + + @JsonProperty + public TableFunctionHandle getHandle() + { + return handle; + } + + @JsonProperty + @Override + public List getSources() + { + return source.map(ImmutableList::of).orElse(ImmutableList.of()); + } + + @Override + public List getOutputVariables() + { + ImmutableList.Builder variables = ImmutableList.builder(); + + variables.addAll(properOutputs); + + passThroughSpecifications.stream() + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + 
.map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .forEach(variables::add); + + return variables.build(); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return this; + } + + @Override + public PlanNode replaceChildren(List newSources) + { + Optional newSource = newSources.isEmpty() ? Optional.empty() : Optional.of(getOnlyElement(newSources)); + return new TableFunctionProcessorNode(getId(), name, properOutputs, newSource, pruneWhenEmpty, passThroughSpecifications, requiredVariables, markerVariables, specification, prePartitioned, preSorted, hashSymbol, handle); + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitTableFunctionProcessor(this, context); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java index d7e66afb664e4..f5fc3bde590a6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java @@ -23,8 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/UpdateNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/UpdateNode.java index 7a50e3aa0da05..2494cd0547054 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/UpdateNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/UpdateNode.java @@ -21,8 +21,7 @@ import 
com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; - -import javax.annotation.concurrent.Immutable; +import com.google.errorprone.annotations.Immutable; import java.util.List; import java.util.Optional; @@ -34,7 +33,7 @@ public class UpdateNode extends InternalPlanNode { private final PlanNode source; - private final VariableReferenceExpression rowId; + private final Optional rowId; private final List columnValueAndRowIdSymbols; private final List outputVariables; @@ -43,7 +42,7 @@ public UpdateNode( Optional sourceLocation, @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, - @JsonProperty("rowId") VariableReferenceExpression rowId, + @JsonProperty("rowId") Optional rowId, @JsonProperty("columnValueAndRowIdSymbols") List columnValueAndRowIdSymbols, @JsonProperty("outputVariables") List outputVariables) { @@ -55,7 +54,7 @@ public UpdateNode( PlanNodeId id, Optional statsEquivalentPlanNode, PlanNode source, - VariableReferenceExpression rowId, + Optional rowId, List columnValueAndRowIdSymbols, List outputVariables) { @@ -74,7 +73,7 @@ public PlanNode getSource() } @JsonProperty - public VariableReferenceExpression getRowId() + public Optional getRowId() { return rowId; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/HashCollisionPlanNodeStats.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/HashCollisionPlanNodeStats.java index fc2d099ef5189..6fff4cf2b15a4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/HashCollisionPlanNodeStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/HashCollisionPlanNodeStats.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.planner.planPrinter; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import 
com.facebook.presto.operator.DynamicFilterStats; import com.facebook.presto.spi.plan.PlanNodeId; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.util.Map; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/IOPlanPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/IOPlanPrinter.java index cae2a0f85f93b..a16d72f732d0c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/IOPlanPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/IOPlanPrinter.java @@ -20,8 +20,11 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DateType; import com.facebook.presto.common.type.IntegerType; import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TimestampWithTimeZoneType; import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; @@ -55,6 +58,9 @@ import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.common.predicate.Marker.Bound.EXACTLY; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.util.DateTimeUtils.printDate; +import static com.facebook.presto.util.DateTimeUtils.printTimestampWithTimeZone; +import static com.facebook.presto.util.DateTimeUtils.printTimestampWithoutTimeZone; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -565,6 +571,17 @@ private String getVarcharValue(Type type, Object value) 
if (type instanceof BooleanType) { return ((Boolean) value).toString(); } + if (type instanceof TimestampType) { + TimestampType timestampType = (TimestampType) type; + long timestampValue = timestampType.getPrecision().toMillis((Long) value); + return printTimestampWithoutTimeZone(timestampValue); + } + if (type instanceof TimestampWithTimeZoneType) { + return printTimestampWithTimeZone((Long) value); + } + if (type instanceof DateType) { + return printDate(((Long) value).intValue()); + } throw new PrestoException(NOT_SUPPORTED, format("Unsupported data type in EXPLAIN (TYPE IO): %s", type.getDisplayName())); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java index 676231c237053..0a61a39ffc0fe 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStats.java @@ -13,22 +13,22 @@ */ package com.facebook.presto.sql.planner.planPrinter; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.operator.DynamicFilterStats; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.util.Mergeable; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.util.Collections; import java.util.Map; import java.util.Optional; import java.util.Set; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.util.MoreMaps.mergeMaps; import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Double.max; import static java.lang.Math.sqrt; import static 
java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java index ec8d8d8e63fdc..17b8fce3ebf48 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanNodeStatsSummarizer.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.planPrinter; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.StageInfo; import com.facebook.presto.execution.TaskInfo; import com.facebook.presto.operator.DynamicFilterStats; @@ -23,7 +24,6 @@ import com.facebook.presto.operator.WindowInfo; import com.facebook.presto.spi.plan.PlanNodeId; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import java.util.ArrayList; import java.util.HashMap; @@ -33,11 +33,11 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.succinctDataSize; import static com.facebook.presto.util.MoreMaps.mergeMaps; import static com.google.common.collect.Iterables.getLast; import static com.google.common.collect.Lists.reverse; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.succinctDataSize; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.stream.Collectors.toList; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 4532ce5d1b807..c35f28525bb09 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner.planPrinter; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.predicate.Domain; import com.facebook.presto.common.predicate.Range; @@ -34,6 +35,9 @@ import com.facebook.presto.spi.SourceLocation; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; @@ -45,12 +49,14 @@ import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; import com.facebook.presto.spi.plan.MergeJoinNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; @@ -67,8 +73,10 @@ import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import 
com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.CallExpression; @@ -82,22 +90,25 @@ import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.LateralJoinNode; -import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; @@ -107,31 +118,35 @@ import 
com.google.common.base.Functions; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Streams; import io.airlift.slice.Slice; -import io.airlift.units.Duration; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.SystemSessionProperties.isVerboseOptimizerInfoEnabled; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.execution.StageInfo.getAllStages; import static com.facebook.presto.expressions.DynamicFilters.extractDynamicFilters; import static com.facebook.presto.metadata.CastType.CAST; +import static com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.planner.SortExpressionExtractor.getSortExpressionContext; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -145,11 +160,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.succinctBytes; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static 
java.lang.String.format; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.function.Function.identity; import static java.util.stream.Collectors.toList; public class PlanPrinter @@ -434,6 +450,9 @@ private static String formatFragment( Joiner.on(", ").join(partitioningScheme.getPartitioning().getArguments()), formatHash(partitioningScheme.getHashColumn()))); } + if (fragment.getOutputOrderingScheme().isPresent()) { + builder.append(indentString(1)).append(format("Output ordering: %s%n", fragment.getOutputOrderingScheme())); + } builder.append(indentString(1)).append(format("Output encoding: %s%n", fragment.getPartitioningScheme().getEncoding())); builder.append(indentString(1)).append(format("Stage Execution Strategy: %s%n", fragment.getStageExecutionDescriptor().getStageExecutionStrategy())); @@ -465,6 +484,7 @@ public static String graphvizLogicalPlan(PlanNode plan, TypeProvider types, Stat SINGLE_DISTRIBUTION, ImmutableList.of(plan.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getOutputVariables()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(estimatedStatsAndCosts), @@ -587,6 +607,10 @@ public Void visitIndexSource(IndexSourceNode node, Void context) nodeOutput.appendDetailsLine("TableHandle: %s", node.getTableHandle().getConnectorHandle().toString()); + Optional tableLayoutHandle = node.getTableHandle().getLayout(); + tableLayoutHandle.ifPresent( + connectorTableLayoutHandle -> nodeOutput.appendDetailsLine("TableHandleLayout: %s", connectorTableLayoutHandle.toString())); + for (Map.Entry entry : node.getAssignments().entrySet()) { if (node.getOutputVariables().contains(entry.getKey())) { nodeOutput.appendDetailsLine("%s := %s%s", entry.getKey(), entry.getValue(), formatSourceLocation(entry.getKey().getSourceLocation())); @@ -1126,9 +1150,15 @@ 
public Void visitSort(SortNode node, Void context) @Override public Void visitRemoteSource(RemoteSourceNode node, Void context) { + String nodeName = "RemoteSource"; + String orderingSchemStr = ""; + if (node.getOrderingScheme().isPresent()) { + orderingSchemStr = node.getOrderingScheme().toString(); + nodeName = "RemoteMerge"; + } addNode(node, - format("Remote%s", node.getOrderingScheme().isPresent() ? "Merge" : "Source"), - format("[%s]", Joiner.on(',').join(node.getSourceFragmentIds())), + format("%s", nodeName), + format("[%s] %s", Joiner.on(',').join(node.getSourceFragmentIds()), orderingSchemStr), ImmutableList.of(), ImmutableList.of(), node.getSourceFragmentIds()); @@ -1193,6 +1223,13 @@ public Void visitStatisticsWriterNode(StatisticsWriterNode node, Void context) return processChildren(node, context); } + @Override + public Void visitCallDistributedProcedure(CallDistributedProcedureNode node, Void context) + { + addNode(node, "CallDistributedProcedure", format("[%s]", node.getTarget().map(CallDistributedProcedureTarget::getProcedureName).orElse(null))); + return processChildren(node, context); + } + @Override public Void visitTableFinish(TableFinishNode node, Void context) { @@ -1267,6 +1304,27 @@ public Void visitMetadataDelete(MetadataDeleteNode node, Void context) return processChildren(node, context); } + @Override + public Void visitMergeWriter(MergeWriterNode node, Void context) + { + addNode(node, "MergeWriter", format("table: %s", node.getTarget().toString())); + return processChildren(node, context); + } + + @Override + public Void visitMergeProcessor(MergeProcessorNode node, Void context) + { + String identifier = format("[target: %s, output: %s]", node.getTarget(), node.getOutputVariables()); + + NodeRepresentation nodeOutput = addNode(node, "MergeProcessor", identifier); + nodeOutput.appendDetails("target: %s", node.getTarget()); + nodeOutput.appendDetails("merge row column: %s", node.getMergeRowVariable()); + nodeOutput.appendDetails("target 
table row id column: %s", node.getTargetTableRowIdColumnVariable()); + nodeOutput.appendDetails("data columns: %s", node.getTargetColumnVariables()); + + return processChildren(node, context); + } + @Override public Void visitEnforceSingleRow(EnforceSingleRowNode node, Void context) { @@ -1308,6 +1366,189 @@ public Void visitLateralJoin(LateralJoinNode node, Void context) return processChildren(node, context); } + @Override + public Void visitTableFunction(TableFunctionNode node, Void context) + { + NodeRepresentation nodeOutput = addNode( + node, + "TableFunction", + node.getName()); + + if (!node.getArguments().isEmpty()) { + nodeOutput.appendDetails("Arguments:"); + + Map tableArguments = node.getTableArgumentProperties().stream() + .collect(toImmutableMap(TableArgumentProperties::getArgumentName, identity())); + + node.getArguments().entrySet().stream() + .forEach(entry -> nodeOutput.appendDetailsLine(formatArgument(entry.getKey(), entry.getValue(), tableArguments))); + + if (!node.getCopartitioningLists().isEmpty()) { + nodeOutput.appendDetailsLine(node.getCopartitioningLists().stream() + .map(list -> list.stream().collect(Collectors.joining(", ", "(", ")"))) + .collect(Collectors.joining(", ", "Co-partition: [", "] "))); + } + } + + processChildren(node, context); + + return null; + } + + private String formatArgument(String argumentName, Argument argument, Map tableArguments) + { + if (argument instanceof ScalarArgument) { + ScalarArgument scalarArgument = (ScalarArgument) argument; + return formatScalarArgument(argumentName, scalarArgument); + } + if (argument instanceof DescriptorArgument) { + DescriptorArgument descriptorArgument = (DescriptorArgument) argument; + return formatDescriptorArgument(argumentName, descriptorArgument); + } + else { + TableArgumentProperties argumentProperties = tableArguments.get(argumentName); + return formatTableArgument(argumentName, argumentProperties); + } + } + + private String formatScalarArgument(String argumentName, 
ScalarArgument argument) + { + return format( + "%s => ScalarArgument{type=%s, value=%s}", + argumentName, + argument.getType().getDisplayName(), + argument.getValue()); + } + + private String formatDescriptorArgument(String argumentName, DescriptorArgument argument) + { + String descriptor; + if (argument.equals(NULL_DESCRIPTOR)) { + descriptor = "NULL"; + } + else { + descriptor = argument.getDescriptor().orElseThrow(() -> new IllegalStateException("Missing descriptor")).getFields().stream() + .map(field -> field.getName() + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) + .collect(Collectors.joining(", ", "(", ")")); + } + return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + } + + private String formatTableArgument(String argumentName, TableArgumentProperties argumentProperties) + { + List properties = new ArrayList<>(); + + if (argumentProperties.isRowSemantics()) { + properties.add("row semantics "); + } + argumentProperties.getSpecification().ifPresent(specification -> { + StringBuilder specificationBuilder = new StringBuilder(); + specificationBuilder + .append("partition by: [") + .append(Joiner.on(", ").join(specification.getPartitionBy())) + .append("]"); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + specificationBuilder + .append(", order by: ") + .append(formatOrderingScheme(orderingScheme)); + }); + properties.add(specificationBuilder.toString()); + }); + + properties.add("required columns: [" + + Joiner.on(", ").join(argumentProperties.getRequiredColumns()) + "]"); + + if (argumentProperties.isPruneWhenEmpty()) { + properties.add("prune when empty"); + } + + if (argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + properties.add("pass through columns"); + } + + return format("%s => TableArgument{%s}", argumentName, Joiner.on(", ").join(properties)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme) + { + return 
formatCollection(orderingScheme.getOrderByVariables(), variable -> variable + " " + orderingScheme.getOrdering(variable)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme, int preSortedOrderPrefix) + { + List orderBy = Stream.concat( + orderingScheme.getOrderByVariables().stream() + .limit(preSortedOrderPrefix) + .map(variable -> "<" + variable + " " + orderingScheme.getOrdering(variable) + ">"), + orderingScheme.getOrderByVariables().stream() + .skip(preSortedOrderPrefix) + .map(variable -> variable + " " + orderingScheme.getOrdering(variable))) + .collect(toImmutableList()); + return formatCollection(orderBy, Objects::toString); + } + + public String formatCollection(Collection collection, Function formatter) + { + return collection.stream() + .map(formatter) + .collect(Collectors.joining(", ", "[", "]")); + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Void context) + { + ImmutableMap.Builder descriptor = ImmutableMap.builder(); + + descriptor.put("name", node.getName()); + + descriptor.put("properOutputs", format("[%s]", Joiner.on(", ").join(node.getProperOutputs()))); + + String specs = node.getPassThroughSpecifications().stream() + .map(spec -> spec.getColumns().stream() + .map(col -> col.getOutputVariables().toString()) + .collect(Collectors.joining(", ", "[", "]"))) + .collect(Collectors.joining(", ")); + descriptor.put("passThroughSymbols", format("[%s]", specs)); + + String requiredSymbols = node.getRequiredVariables().stream() + .map(vars -> vars.stream() + .map(VariableReferenceExpression::toString) + .collect(Collectors.joining(", ", "[", "]"))) + .collect(Collectors.joining(", ", "[", "]")); + descriptor.put("requiredSymbols", format("[%s]", requiredSymbols)); + + node.getSpecification().ifPresent(specification -> { + if (!specification.getPartitionBy().isEmpty()) { + List prePartitioned = specification.getPartitionBy().stream() + .filter(node.getPrePartitioned()::contains) + 
.collect(toImmutableList()); + + List notPrePartitioned = specification.getPartitionBy().stream() + .filter(column -> !node.getPrePartitioned().contains(column)) + .collect(toImmutableList()); + + StringBuilder builder = new StringBuilder(); + if (!prePartitioned.isEmpty()) { + builder.append(prePartitioned.stream() + .map(VariableReferenceExpression::toString) + .collect(Collectors.joining(", ", "<", ">"))); + if (!notPrePartitioned.isEmpty()) { + builder.append(", "); + } + } + if (!notPrePartitioned.isEmpty()) { + builder.append(Joiner.on(", ").join(notPrePartitioned)); + } + descriptor.put("partitionBy", format("[%s]", builder)); + } + specification.getOrderingScheme().ifPresent(orderingScheme -> descriptor.put("orderBy", formatOrderingScheme(orderingScheme, node.getPreSorted()))); + }); + + addNode(node, "TableFunctionProcessorNode" + descriptor.build()); + + return processChildren(node, context); + } + @Override public Void visitPlan(PlanNode node, Void context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanRepresentation.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanRepresentation.java index 3e1d41e99a1f4..167a2a15a8800 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanRepresentation.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanRepresentation.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.sql.planner.planPrinter; +import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.eventlistener.CTEInformation; import com.facebook.presto.spi.eventlistener.PlanOptimizerInformation; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.optimizations.OptimizerResult; -import io.airlift.units.Duration; import java.util.HashMap; import 
java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java index 1f7330a742153..2809b8145831a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/TextRenderer.java @@ -29,8 +29,8 @@ import java.util.Optional; import java.util.Set; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.google.common.collect.Iterables.getOnlyElement; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.Double.isFinite; import static java.lang.Double.isNaN; import static java.lang.String.format; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/WindowPlanNodeStats.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/WindowPlanNodeStats.java index f373525c004f3..512a580c597d1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/WindowPlanNodeStats.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/WindowPlanNodeStats.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.sql.planner.planPrinter; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.operator.DynamicFilterStats; import com.facebook.presto.spi.plan.PlanNodeId; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.util.Map; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CallDistributedProcedureValidator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CallDistributedProcedureValidator.java new file mode 100644 index 0000000000000..084cb6e7048b7 --- /dev/null +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CallDistributedProcedureValidator.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.sanity; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.OutputNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TableFinishNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; +import com.facebook.presto.sql.planner.plan.ExchangeNode; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; + +public final class CallDistributedProcedureValidator + implements PlanChecker.Checker +{ + @Override + public void validate(PlanNode planNode, Session session, Metadata metadata, WarningCollector warningCollector) + { + Optional callDistributedProcedureNode = searchFrom(planNode) + .where(node -> node instanceof CallDistributedProcedureNode) + .findFirst(); + + if (!callDistributedProcedureNode.isPresent()) { + // not a call distributed procedure plan + return; + } + + searchFrom(planNode) + .findAll() + .forEach(node -> { + if 
(!isAllowedNode(node)) { + throw new IllegalStateException("Unexpected " + node.getClass().getSimpleName() + " found in plan; probably connector was not able to handle provided WHERE expression"); + } + }); + } + + private boolean isAllowedNode(PlanNode node) + { + return node instanceof TableScanNode + || node instanceof ValuesNode + || node instanceof ProjectNode + || node instanceof CallDistributedProcedureNode + || node instanceof OutputNode + || node instanceof ExchangeNode + || node instanceof TableFinishNode; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CheckNoIneligibleFunctionsInCoordinatorFragments.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CheckNoIneligibleFunctionsInCoordinatorFragments.java new file mode 100644 index 0000000000000..a2abbe91c84a9 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CheckNoIneligibleFunctionsInCoordinatorFragments.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.sanity; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.planner.SimplePlanVisitor; +import com.facebook.presto.sql.planner.plan.ExchangeNode; + +import static com.facebook.presto.sql.planner.PlannerUtils.containsSystemTableScan; +import static com.facebook.presto.sql.relational.RowExpressionUtils.containsNonCoordinatorEligibleCallExpression; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +/** + * Validates that there are no filter or projection nodes containing non-Java functions + * (which must be evaluated on native nodes) within the same fragment as a system table scan + * (which must be evaluated on the coordinator). 
+ */ +public class CheckNoIneligibleFunctionsInCoordinatorFragments + implements PlanChecker.Checker +{ + @Override + public void validate(PlanNode planNode, Session session, Metadata metadata, WarningCollector warningCollector) + { + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + // Validate each fragment independently + validateFragment(planNode, functionAndTypeManager); + } + + private void validateFragment(PlanNode root, FunctionAndTypeManager functionAndTypeManager) + { + // First, collect information about this fragment + FragmentValidator validator = new FragmentValidator(functionAndTypeManager); + root.accept(validator, null); + + // Check if this fragment violates the constraint + checkState( + !(validator.hasSystemTableScan() && validator.hasNonCoordinatorEligibleFunction()), + "Fragment contains both system table scan and non-Java functions. " + + "System table scans must execute on the coordinator while non-Java functions must execute on native nodes. " + + "These operations must be in separate fragments separated by an exchange."); + + // Recursively validate child fragments + ChildFragmentVisitor childVisitor = new ChildFragmentVisitor(functionAndTypeManager); + root.accept(childVisitor, null); + } + + /** + * Visits nodes within a single fragment to collect information about + * system table scans and non-coordinator-eligible functions. + * Stops at exchange boundaries. 
+ */ + private static class FragmentValidator + extends SimplePlanVisitor + { + private final FunctionAndTypeManager functionAndTypeManager; + private boolean hasSystemTableScan; + private boolean hasNonCoordinatorEligibleFunction; + + public FragmentValidator(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + public boolean hasSystemTableScan() + { + return hasSystemTableScan; + } + + public boolean hasNonCoordinatorEligibleFunction() + { + return hasNonCoordinatorEligibleFunction; + } + + @Override + public Void visitExchange(ExchangeNode node, Void context) + { + // Don't traverse into exchange sources - they are different fragments + return null; + } + + @Override + public Void visitTableScan(TableScanNode node, Void context) + { + if (containsSystemTableScan(node)) { + hasSystemTableScan = true; + } + return null; + } + + @Override + public Void visitFilter(FilterNode node, Void context) + { + RowExpression predicate = node.getPredicate(); + if (containsNonCoordinatorEligibleCallExpression(functionAndTypeManager, predicate)) { + hasNonCoordinatorEligibleFunction = true; + } + return visitPlan(node, context); + } + + @Override + public Void visitProject(ProjectNode node, Void context) + { + boolean hasIneligibleProjections = node.getAssignments().getExpressions().stream() + .anyMatch(expression -> containsNonCoordinatorEligibleCallExpression(functionAndTypeManager, expression)); + + if (hasIneligibleProjections) { + hasNonCoordinatorEligibleFunction = true; + } + return visitPlan(node, context); + } + + @Override + public Void visitPlan(PlanNode node, Void context) + { + for (PlanNode source : node.getSources()) { + source.accept(this, context); + } + return null; + } + } + + /** + * Visits nodes to find and validate child fragments (those below exchanges). 
+ */ + private class ChildFragmentVisitor + extends SimplePlanVisitor + { + private final FunctionAndTypeManager functionAndTypeManager; + + public ChildFragmentVisitor(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + @Override + public Void visitExchange(ExchangeNode node, Void context) + { + // Each source of an exchange is a separate fragment + for (PlanNode source : node.getSources()) { + validateFragment(source, functionAndTypeManager); + } + return null; + } + + @Override + public Void visitPlan(PlanNode node, Void context) + { + for (PlanNode source : node.getSources()) { + source.accept(this, context); + } + return null; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java index 310a4bb09e0da..0f38868fc5c4d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java @@ -23,8 +23,7 @@ import com.facebook.presto.sql.planner.PlanFragment; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.Multimap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.requireNonNull; @@ -74,10 +73,14 @@ public PlanChecker(FeaturesConfig featuresConfig, boolean noExchange, PlanChecke new VerifyNoIntermediateFormExpression(), new VerifyProjectionLocality(), new DynamicFiltersChecker(), - new WarnOnScanWithoutPartitionPredicate(featuresConfig)); - if (featuresConfig.isNativeExecutionEnabled() && (featuresConfig.isDisableTimeStampWithTimeZoneForNative() || - featuresConfig.isDisableIPAddressForNative())) { - builder.put(Stage.INTERMEDIATE, new CheckUnsupportedPrestissimoTypes(featuresConfig)); + new 
WarnOnScanWithoutPartitionPredicate(featuresConfig), + new CallDistributedProcedureValidator()); + if (featuresConfig.isNativeExecutionEnabled()) { + if (featuresConfig.isDisableTimeStampWithTimeZoneForNative() || + featuresConfig.isDisableIPAddressForNative()) { + builder.put(Stage.INTERMEDIATE, new CheckUnsupportedPrestissimoTypes(featuresConfig)); + } + builder.put(Stage.FINAL, new CheckNoIneligibleFunctionsInCoordinatorFragments()); } checkers = builder.build(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanCheckerProviderManager.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanCheckerProviderManager.java index 1349a1dcd301c..fe08425d3ffd6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanCheckerProviderManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanCheckerProviderManager.java @@ -31,6 +31,7 @@ import java.util.concurrent.CopyOnWriteArrayList; import static com.facebook.presto.util.PropertiesUtil.loadProperties; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Strings.isNullOrEmpty; import static java.lang.String.format; @@ -65,29 +66,38 @@ public void addPlanCheckerProviderFactory(PlanCheckerProviderFactory planChecker public void loadPlanCheckerProviders(NodeManager nodeManager) throws IOException { - PlanCheckerProviderContext planCheckerProviderContext = new PlanCheckerProviderContext(simplePlanFragmentSerde, nodeManager); - for (File file : listFiles(configDirectory)) { if (file.isFile() && file.getName().endsWith(".properties")) { // unlike function namespaces and connectors, we don't have a concept of catalog // name here (conventionally config file name without the extension) // because plan checkers are never referenced by name. 
Map properties = new HashMap<>(loadProperties(file)); - checkState(!isNullOrEmpty(properties.get(PLAN_CHECKER_PROVIDER_NAME)), + String planCheckerProviderName = properties.remove(PLAN_CHECKER_PROVIDER_NAME); + checkState(!isNullOrEmpty(planCheckerProviderName), "Plan checker configuration %s does not contain %s", file.getAbsoluteFile(), PLAN_CHECKER_PROVIDER_NAME); - String planCheckerProviderName = properties.remove(PLAN_CHECKER_PROVIDER_NAME); - log.info("-- Loading plan checker provider [%s] --", planCheckerProviderName); - PlanCheckerProviderFactory providerFactory = providerFactories.get(planCheckerProviderName); - checkState(providerFactory != null, - "No planCheckerProviderFactory found for '%s'. Available factories were %s", planCheckerProviderName, providerFactories.keySet()); - providers.addIfAbsent(providerFactory.create(properties, planCheckerProviderContext)); - log.info("-- Added plan checker provider [%s] --", planCheckerProviderName); + loadPlanCheckerProvider(planCheckerProviderName, properties, nodeManager); } } } + public void loadPlanCheckerProvider(String planCheckerProviderName, Map properties, NodeManager nodeManager) + { + checkArgument(!isNullOrEmpty(planCheckerProviderName), "Plan checker provider name is null or empty"); + requireNonNull(properties, "properties is null"); + requireNonNull(nodeManager, "nodeManager is null"); + + PlanCheckerProviderContext planCheckerProviderContext = new PlanCheckerProviderContext(simplePlanFragmentSerde, nodeManager); + + log.info("-- Loading plan checker provider [%s] --", planCheckerProviderName); + PlanCheckerProviderFactory providerFactory = providerFactories.get(planCheckerProviderName); + checkState(providerFactory != null, + "No planCheckerProviderFactory found for '%s'. 
Available factories were %s", planCheckerProviderName, providerFactories.keySet()); + providers.addIfAbsent(providerFactory.create(properties, planCheckerProviderContext)); + log.info("-- Added plan checker provider [%s] --", planCheckerProviderName); + } + private static List listFiles(File dir) { if (dir != null && dir.isDirectory()) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanCheckerProviderManagerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanCheckerProviderManagerConfig.java index d356f433d75ef..99d5a46e59126 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanCheckerProviderManagerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanCheckerProviderManagerConfig.java @@ -15,8 +15,7 @@ package com.facebook.presto.sql.planner.sanity; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; import java.io.File; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index ecc264a2518af..1a88f259c882e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -26,12 +26,15 @@ import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import 
com.facebook.presto.spi.plan.MarkDistinctNode; +import com.facebook.presto.spi.plan.MaterializedViewScanNode; import com.facebook.presto.spi.plan.MergeJoinNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; @@ -46,6 +49,7 @@ import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.RowExpression; @@ -54,23 +58,26 @@ import com.facebook.presto.sql.planner.optimizations.WindowNodeUtil; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.LateralJoinNode; -import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.OffsetNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import 
com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -83,6 +90,7 @@ import static com.facebook.presto.spi.plan.JoinNode.checkLeftOutputVariablesBeforeRight; import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.extractAggregationUniqueVariables; import static com.facebook.presto.sql.planner.optimizations.IndexJoinOptimizer.IndexKeyTracer; +import static com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -116,6 +124,123 @@ public Void visitPlan(PlanNode node, Set boundVaria throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName()); } + @Override + public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) + { + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode source = node.getSources().get(i); + source.accept(this, boundSymbols); + Set inputs = createInputs(source, boundSymbols); + TableFunctionNode.TableArgumentProperties argumentProperties = node.getTableArgumentProperties().get(i); + + checkDependencies( + inputs, + argumentProperties.getRequiredColumns(), + "Invalid node. 
Required input symbols from source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + argumentProperties.getRequiredColumns(), + source.getOutputVariables()); + argumentProperties.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. Order by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + orderingScheme.getOrderBy(), + source.getOutputVariables()); + }); + }); + Set passThroughVariable = argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughVariable, + "Invalid node. Pass-through symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + passThroughVariable, + source.getOutputVariables()); + } + return null; + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundVariables) + { + if (!node.getSource().isPresent()) { + return null; + } + + PlanNode source = node.getSource().get(); + source.accept(this, boundVariables); + + Set inputs = createInputs(source, boundVariables); + + Set passThroughSymbols = node.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughSymbols, + "Invalid node. 
Pass-through symbols (%s) not in source plan output (%s)", + passThroughSymbols, + source.getOutputVariables()); + + Set requiredSymbols = node.getRequiredVariables().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + checkDependencies( + inputs, + requiredSymbols, + "Invalid node. Required symbols (%s) not in source plan output (%s)", + requiredSymbols, + source.getOutputVariables()); + + node.getMarkerVariables().ifPresent(mapping -> { + checkDependencies( + inputs, + mapping.keySet(), + "Invalid node. Source symbols (%s) not in source plan output (%s)", + mapping.keySet(), + source.getOutputVariables()); + checkDependencies( + inputs, + mapping.values(), + "Invalid node. Source marker symbols (%s) not in source plan output (%s)", + mapping.values(), + source.getOutputVariables()); + }); + + node.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols (%s) not in source plan output (%s)", + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. 
Order by symbols (%s) not in source plan output (%s)", + orderingScheme.getOrderBy(), + source.getOutputVariables()); + }); + }); + + return null; + } + @Override public Void visitExplainAnalyze(ExplainAnalyzeNode node, Set boundVariables) { @@ -360,6 +485,47 @@ public Void visitDistinctLimit(DistinctLimitNode node, Set boundVariables) + { + PlanNode dataTablePlan = node.getSources().get(0); + PlanNode viewQueryPlan = node.getSources().get(1); + + dataTablePlan.accept(this, boundVariables); + viewQueryPlan.accept(this, boundVariables); + + Set dataTableOutputs = ImmutableSet.copyOf(dataTablePlan.getOutputVariables()); + Set viewQueryOutputs = ImmutableSet.copyOf(viewQueryPlan.getOutputVariables()); + + for (VariableReferenceExpression outputVariable : node.getOutputVariables()) { + VariableReferenceExpression dataTableVariable = node.getDataTableMappings().get(outputVariable); + checkArgument( + dataTableVariable != null, + "Output variable %s has no mapping in dataTableMappings", + outputVariable); + checkArgument( + dataTableOutputs.contains(dataTableVariable), + "Data table mapping variable %s for output %s not in data table plan output (%s)", + dataTableVariable, + outputVariable, + dataTableOutputs); + + VariableReferenceExpression viewQueryVariable = node.getViewQueryMappings().get(outputVariable); + checkArgument( + viewQueryVariable != null, + "Output variable %s has no mapping in viewQueryMappings", + outputVariable); + checkArgument( + viewQueryOutputs.contains(viewQueryVariable), + "View query mapping variable %s for output %s not in view query plan output (%s)", + viewQueryVariable, + outputVariable, + viewQueryOutputs); + } + + return null; + } + @Override public Void visitJoin(JoinNode node, Set boundVariables) { @@ -472,9 +638,7 @@ public Void visitIndexJoin(IndexJoinNode node, Set checkArgument(indexSourceInputs.contains(clause.getIndex()), "Index variable from index join clause (%s) not in index source (%s)", clause.getIndex(), 
node.getIndexSource().getOutputVariables()); } - Set lookupVariables = node.getCriteria().stream() - .map(IndexJoinNode.EquiJoinClause::getIndex) - .collect(toImmutableSet()); + Set lookupVariables = ImmutableSet.copyOf(node.getLookupVariables()); Map trace = IndexKeyTracer.trace(node.getIndexSource(), lookupVariables); checkArgument(!trace.isEmpty() && lookupVariables.containsAll(trace.keySet()), "Index lookup symbols are not traceable to index source: %s", @@ -590,6 +754,15 @@ public Void visitExchange(ExchangeNode node, Set bo return null; } + @Override + public Void visitCallDistributedProcedure(CallDistributedProcedureNode node, Set boundVariables) + { + PlanNode source = node.getSource(); + source.accept(this, boundVariables); // visit child + + return null; + } + @Override public Void visitTableWriter(TableWriterNode node, Set boundVariables) { @@ -608,13 +781,39 @@ public Void visitTableWriteMerge(TableWriterMergeNode node, Set boundSymbols) + { + PlanNode source = node.getSource(); + source.accept(this, boundSymbols); // visit child + return null; + } + + @Override + public Void visitMergeProcessor(MergeProcessorNode node, Set boundSymbols) + { + PlanNode source = node.getSource(); + source.accept(this, boundSymbols); // visit child + + checkArgument(source.getOutputVariables().contains(node.getTargetTableRowIdColumnVariable()), + "Invalid node. rowId symbol (%s) is not in source plan output (%s)", + node.getTargetTableRowIdColumnVariable(), node.getSource().getOutputVariables()); + checkArgument(source.getOutputVariables().contains(node.getMergeRowVariable()), + "Invalid node. 
Merge row symbol (%s) is not in source plan output (%s)", + node.getMergeRowVariable(), node.getSource().getOutputVariables()); + + return null; + } + @Override public Void visitDelete(DeleteNode node, Set boundVariables) { PlanNode source = node.getSource(); source.accept(this, boundVariables); // visit child - checkArgument(source.getOutputVariables().contains(node.getRowId()), "Invalid node. Row ID symbol (%s) is not in source plan output (%s)", node.getRowId(), node.getSource().getOutputVariables()); + node.getRowId().ifPresent(rowid -> + checkArgument(source.getOutputVariables().contains(rowid), + "Invalid node. Row ID symbol (%s) is not in source plan output (%s)", rowid, node.getSource().getOutputVariables())); return null; } @@ -624,7 +823,8 @@ public Void visitUpdate(UpdateNode node, Set boundV { PlanNode source = node.getSource(); source.accept(this, boundVariables); // visit child - checkArgument(source.getOutputVariables().contains(node.getRowId()), "Invalid node. Row ID symbol (%s) is not in source plan output (%s)", node.getRowId(), node.getSource().getOutputVariables()); + node.getRowId().ifPresent(r -> + checkArgument(source.getOutputVariables().contains(r), "Invalid node. Row ID symbol (%s) is not in source plan output (%s)", node.getRowId(), node.getSource().getOutputVariables())); checkArgument(source.getOutputVariables().containsAll(node.getColumnValueAndRowIdSymbols()), "Invalid node. 
Some UPDATE SET expression symbols (%s) are not contained in the outputSymbols (%s)", node.getColumnValueAndRowIdSymbols(), source.getOutputVariables()); return null; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java index cfe040826bdcf..f5ca4ab176918 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java @@ -86,8 +86,10 @@ public Void visitAggregation(AggregationNode node, Void context) List> desiredProperties = ImmutableList.of(new GroupingProperty<>(node.getPreGroupedVariables())); Iterator>> matchIterator = LocalProperties.match(properties.getLocalProperties(), desiredProperties).iterator(); + Iterator>> additionalMatchIterator = LocalProperties.match(properties.getAdditionalLocalProperties(), desiredProperties).iterator(); Optional> unsatisfiedRequirement = Iterators.getOnlyElement(matchIterator); - checkArgument(!unsatisfiedRequirement.isPresent(), "Streaming aggregation with input not grouped on the grouping keys"); + Optional> additionalUnsatisfiedRequirement = Iterators.getOnlyElement(additionalMatchIterator); + checkArgument(!unsatisfiedRequirement.isPresent() || !additionalUnsatisfiedRequirement.isPresent(), "Streaming aggregation with input not grouped on the grouping keys"); return null; } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/Expressions.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/Expressions.java index 3f4340ffdd445..373ce2bfd65a2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/Expressions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/Expressions.java @@ -13,7 +13,6 @@ */ 
package com.facebook.presto.sql.relational; -import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.CastType; @@ -154,12 +153,6 @@ public static CallExpression call(FunctionAndTypeManager functionAndTypeManager, return call(name, functionHandle, returnType, arguments); } - public static CallExpression call(FunctionAndTypeManager functionAndTypeManager, QualifiedObjectName qualifiedObjectName, Type returnType, List arguments) - { - FunctionHandle functionHandle = functionAndTypeManager.lookupFunction(qualifiedObjectName, fromTypes(arguments.stream().map(RowExpression::getType).collect(toImmutableList()))); - return call(String.valueOf(qualifiedObjectName), functionHandle, returnType, arguments); - } - public static CallExpression call(FunctionAndTypeResolver functionAndTypeResolver, String name, Type returnType, RowExpression... arguments) { FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(name, fromTypes(Arrays.stream(arguments).map(RowExpression::getType).collect(toImmutableList()))); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java index 860d1e5ee3927..22a4e5d54473b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java @@ -148,6 +148,12 @@ public boolean isCastFunction(FunctionHandle functionHandle) return functionAndTypeResolver.getFunctionMetadata(functionHandle).getOperatorType().equals(Optional.of(OperatorType.CAST)); } + @Override + public FunctionHandle lookupCast(String castType, Type fromType, Type toType) + { + return functionAndTypeResolver.lookupCast(castType, fromType, toType); + } + public boolean 
isTryCastFunction(FunctionHandle functionHandle) { return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, "TRY_CAST")); @@ -275,14 +281,13 @@ public boolean isComparisonFunction(FunctionHandle functionHandle) public boolean isEqualsFunction(FunctionHandle functionHandle) { - Optional operatorType = functionAndTypeResolver.getFunctionMetadata(functionHandle).getOperatorType(); - return operatorType.isPresent() && operatorType.get().getOperator().equals(EQUAL.getOperator()); + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getOperatorType().map(EQUAL::equals).orElse(false); } @Override public FunctionHandle subscriptFunction(Type baseType, Type indexType) { - return functionAndTypeResolver.lookupFunction(SUBSCRIPT.getFunctionName().getObjectName(), fromTypes(baseType, indexType)); + return functionAndTypeResolver.resolveOperator(SUBSCRIPT, fromTypes(baseType, indexType)); } @Override @@ -301,9 +306,12 @@ public boolean isTryFunction(FunctionHandle functionHandle) return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().getObjectName().equals("$internal$try"); } - public boolean isFailFunction(FunctionHandle functionHandle) + public boolean isJavaBuiltInFailFunction(FunctionHandle functionHandle) { - return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("fail"))); + // todo: Revert this hack once constant folding support lands in C++. + // For now, we always use the presto.default.fail function even when the default namespace is switched. + // This is done for consistency since the BuiltInNamespaceRewriter rewrites the functionHandles to presto.default functionHandles. 
+ return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, "fail")); } @Override @@ -312,6 +320,12 @@ public boolean isCountFunction(FunctionHandle functionHandle) return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("count"))); } + @Override + public boolean isCountIfFunction(FunctionHandle functionHandle) + { + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("count_if"))); + } + @Override public FunctionHandle countFunction() { @@ -324,6 +338,22 @@ public FunctionHandle countFunction(Type valueType) return functionAndTypeResolver.lookupFunction("count", fromTypes(valueType)); } + public boolean isMaxByFunction(FunctionHandle functionHandle) + { + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("max_by"))); + } + + public boolean isMinByFunction(FunctionHandle functionHandle) + { + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("min_by"))); + } + + @Override + public FunctionHandle arbitraryFunction(Type valueType) + { + return functionAndTypeResolver.lookupFunction("arbitrary", fromTypes(valueType)); + } + @Override public boolean isMaxFunction(FunctionHandle functionHandle) { @@ -384,11 +414,6 @@ public FunctionHandle approximateSetFunction(Type valueType) return functionAndTypeResolver.lookupFunction("approx_set", fromTypes(valueType)); } - public boolean isEqualFunction(FunctionHandle functionHandle) - { - return functionAndTypeResolver.getFunctionMetadata(functionHandle).getOperatorType().map(EQUAL::equals).orElse(false); - } - public boolean isArrayContainsFunction(FunctionHandle functionHandle) { return 
functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("contains"))); @@ -404,9 +429,30 @@ public boolean isWindowValueFunction(FunctionHandle functionHandle) return windowValueFunctions.contains(functionAndTypeResolver.getFunctionMetadata(functionHandle).getName()); } + public boolean isMapSubSetFunction(FunctionHandle functionHandle) + { + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("map_subset"))); + } + + public boolean isMapFilterFunction(FunctionHandle functionHandle) + { + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("map_filter"))); + } + + public boolean isCardinalityFunction(FunctionHandle functionHandle) + { + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("cardinality"))); + } + @Override public FunctionHandle lookupBuiltInFunction(String functionName, List inputTypes) { return functionAndTypeResolver.lookupFunction(functionName, fromTypes(inputTypes)); } + + @Override + public FunctionHandle lookupFunction(String catalog, String schema, String functionName, List inputTypes) + { + return functionAndTypeResolver.resolveFunction(Optional.empty(), Optional.empty(), QualifiedObjectName.valueOf(catalog, schema, functionName), fromTypes(inputTypes)); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionDeterminismEvaluator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionDeterminismEvaluator.java index 7c55d7cb9af75..da485dd86f4fc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionDeterminismEvaluator.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionDeterminismEvaluator.java @@ -27,8 +27,7 @@ import com.facebook.presto.spi.relation.RowExpressionVisitor; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionDomainTranslator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionDomainTranslator.java index 8b0561ff73074..028fccaebf8a4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionDomainTranslator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionDomainTranslator.java @@ -50,9 +50,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.PeekingIterator; - -import javax.annotation.Nullable; -import javax.inject.Inject; +import jakarta.annotation.Nullable; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.List; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java index 173b1c996e0ea..811fe88c20c10 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionOptimizer.java @@ -39,7 +39,7 @@ public RowExpressionOptimizer(Metadata metadata) public RowExpressionOptimizer(FunctionAndTypeManager functionAndTypeManager) { - this.functionAndTypeManager = 
requireNonNull(functionAndTypeManager, "functionMetadataManager is null"); + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionUtils.java new file mode 100644 index 0000000000000..531b25e715de5 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionUtils.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.relational; + +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.RowExpressionVisitor; + +import static java.util.Objects.requireNonNull; + +public class RowExpressionUtils +{ + private RowExpressionUtils() {} + + public static boolean containsNonCoordinatorEligibleCallExpression(FunctionAndTypeManager functionAndTypeManager, RowExpression expression) + { + return expression.accept(new ContainsNonCoordinatorEligibleCallExpressionVisitor(functionAndTypeManager), null); + } + + private static class ContainsNonCoordinatorEligibleCallExpressionVisitor + implements RowExpressionVisitor + { + private final FunctionAndTypeManager functionAndTypeManager; + + public ContainsNonCoordinatorEligibleCallExpressionVisitor(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + @Override + public Boolean visitCall(CallExpression call, Void context) + { + // If the call is not a Java function, we return true to indicate that we found a non-Java expression + FunctionHandle functionHandle = call.getFunctionHandle(); + FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(functionHandle); + if (!functionMetadata.getImplementationType().canBeEvaluatedInCoordinator()) { + return true; + } + for (RowExpression argument : call.getArguments()) { + if (argument.accept(this, context)) { + return true; // Found a non-Java expression in arguments + } + } + return false; + } + + @Override + public Boolean visitExpression(RowExpression expression, Void context) + { + for (RowExpression child : expression.getChildren()) { + if (child.accept(this, 
context)) { + return true; // Found a non-Java expression + } + } + return false; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DefaultTreeRewriter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DefaultTreeRewriter.java deleted file mode 100644 index 2b3009a64dde9..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DefaultTreeRewriter.java +++ /dev/null @@ -1,745 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.sql.rewrite; - -import com.facebook.presto.sql.tree.AddColumn; -import com.facebook.presto.sql.tree.AliasedRelation; -import com.facebook.presto.sql.tree.Analyze; -import com.facebook.presto.sql.tree.AstVisitor; -import com.facebook.presto.sql.tree.Call; -import com.facebook.presto.sql.tree.CallArgument; -import com.facebook.presto.sql.tree.ColumnDefinition; -import com.facebook.presto.sql.tree.CreateMaterializedView; -import com.facebook.presto.sql.tree.CreateSchema; -import com.facebook.presto.sql.tree.CreateTable; -import com.facebook.presto.sql.tree.CreateTableAsSelect; -import com.facebook.presto.sql.tree.CreateView; -import com.facebook.presto.sql.tree.Cube; -import com.facebook.presto.sql.tree.Deallocate; -import com.facebook.presto.sql.tree.Delete; -import com.facebook.presto.sql.tree.Except; -import com.facebook.presto.sql.tree.Execute; -import com.facebook.presto.sql.tree.Explain; -import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.FrameBound; -import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.GroupBy; -import com.facebook.presto.sql.tree.GroupingElement; -import com.facebook.presto.sql.tree.GroupingSets; -import com.facebook.presto.sql.tree.Identifier; -import com.facebook.presto.sql.tree.Insert; -import com.facebook.presto.sql.tree.Intersect; -import com.facebook.presto.sql.tree.Join; -import com.facebook.presto.sql.tree.JoinCriteria; -import com.facebook.presto.sql.tree.JoinOn; -import com.facebook.presto.sql.tree.Lateral; -import com.facebook.presto.sql.tree.Node; -import com.facebook.presto.sql.tree.OrderBy; -import com.facebook.presto.sql.tree.Prepare; -import com.facebook.presto.sql.tree.Property; -import com.facebook.presto.sql.tree.Query; -import com.facebook.presto.sql.tree.QueryBody; -import com.facebook.presto.sql.tree.QuerySpecification; -import com.facebook.presto.sql.tree.RefreshMaterializedView; -import 
com.facebook.presto.sql.tree.Relation; -import com.facebook.presto.sql.tree.Return; -import com.facebook.presto.sql.tree.Rollup; -import com.facebook.presto.sql.tree.Row; -import com.facebook.presto.sql.tree.SampledRelation; -import com.facebook.presto.sql.tree.Select; -import com.facebook.presto.sql.tree.SelectItem; -import com.facebook.presto.sql.tree.ShowStats; -import com.facebook.presto.sql.tree.SimpleGroupBy; -import com.facebook.presto.sql.tree.SingleColumn; -import com.facebook.presto.sql.tree.SortItem; -import com.facebook.presto.sql.tree.Statement; -import com.facebook.presto.sql.tree.Table; -import com.facebook.presto.sql.tree.TableElement; -import com.facebook.presto.sql.tree.TableSubquery; -import com.facebook.presto.sql.tree.Union; -import com.facebook.presto.sql.tree.Unnest; -import com.facebook.presto.sql.tree.Values; -import com.facebook.presto.sql.tree.Window; -import com.facebook.presto.sql.tree.WindowFrame; -import com.facebook.presto.sql.tree.With; -import com.facebook.presto.sql.tree.WithQuery; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Iterables; - -import java.util.Iterator; -import java.util.List; -import java.util.Optional; - -/** - * A default implementation of {@link AstVisitor} that reconstructs a node if any of its children is reconstructed. - * Expression node reconstruction is not supported and left to the users. 
- */ -public class DefaultTreeRewriter - extends AstVisitor -{ - @Override - protected Node visitNode(Node node, C context) - { - return node; - } - - @Override - protected Node visitExpression(Expression node, C context) - { - throw new UnsupportedOperationException("Not yet implemented: " + getClass().getSimpleName() + " for " + node.getClass().getName()); - } - - @Override - protected Node visitAddColumn(AddColumn node, C context) - { - Node column = process(node.getColumn(), context); - if (node.getColumn() == column) { - return node; - } - - return new AddColumn(node.getName(), (ColumnDefinition) column, node.isTableExists(), node.isColumnNotExists()); - } - - @Override - protected Node visitAliasedRelation(AliasedRelation node, C context) - { - Node relation = process(node.getRelation(), context); - Node alias = process(node.getAlias(), context); - List columnNames = process(node.getColumnNames(), context); - if (node.getRelation() == relation && node.getAlias() == alias && sameElements(node.getColumnNames(), columnNames)) { - return node; - } - - return new AliasedRelation((Relation) relation, (Identifier) alias, columnNames); - } - - @Override - protected Node visitAnalyze(Analyze node, C context) - { - List properties = process(node.getProperties(), context); - if (sameElements(node.getProperties(), properties)) { - return node; - } - - return new Analyze(node.getTableName(), properties); - } - - @Override - protected Node visitCall(Call node, C context) - { - List arguments = process(node.getArguments(), context); - if (sameElements(node.getArguments(), arguments)) { - return node; - } - - return new Call(node.getName(), arguments); - } - - @Override - protected Node visitCallArgument(CallArgument node, C context) - { - Node value = process(node.getValue(), context); - if (node.getValue() == value) { - return node; - } - - return node.getName().isPresent() ? 
new CallArgument(node.getName().get(), (Expression) value) : new CallArgument((Expression) value); - } - - @Override - protected Node visitColumnDefinition(ColumnDefinition node, C context) - { - Node name = process(node.getName(), context); - List properties = process(node.getProperties(), context); - if (node.getName() == name && sameElements(node.getProperties(), properties)) { - return node; - } - - return new ColumnDefinition((Identifier) name, node.getType(), node.isNullable(), properties, node.getComment()); - } - - @Override - protected Node visitCreateMaterializedView(CreateMaterializedView node, C context) - { - Node query = process(node.getQuery(), context); - List properties = process(node.getProperties(), context); - if (node.getQuery() == query && node.getProperties() == properties) { - return node; - } - - return new CreateMaterializedView(node.getName(), (Query) query, node.isNotExists(), properties, node.getComment()); - } - - @Override - protected Node visitCreateSchema(CreateSchema node, C context) - { - List properties = process(node.getProperties(), context); - if (sameElements(node.getProperties(), properties)) { - return node; - } - - return new CreateSchema(node.getSchemaName(), node.isNotExists(), properties); - } - - @Override - protected Node visitCreateTable(CreateTable node, C context) - { - List elements = process(node.getElements(), context); - List properties = process(node.getProperties(), context); - if (sameElements(node.getElements(), elements) && sameElements(node.getProperties(), properties)) { - return node; - } - - return new CreateTable(node.getName(), elements, node.isNotExists(), properties, node.getComment()); - } - - @Override - protected Node visitCreateTableAsSelect(CreateTableAsSelect node, C context) - { - Node query = process(node.getQuery(), context); - List properties = process(node.getProperties(), context); - Optional> columnAliases = node.getColumnAliases().map(aliases -> process(aliases, context)); - if 
(node.getQuery() == query && node.getProperties() == properties && (!columnAliases.isPresent() || sameElements(node.getColumnAliases().get(), columnAliases.get()))) { - return node; - } - - return new CreateTableAsSelect(node.getName(), (Query) query, node.isNotExists(), properties, node.isWithData(), node.getColumnAliases(), node.getComment()); - } - - @Override - protected Node visitCreateView(CreateView node, C context) - { - Node query = process(node.getQuery(), context); - if (node.getQuery() == query) { - return node; - } - - return new CreateView(node.getName(), (Query) query, node.isReplace(), node.getSecurity()); - } - - @Override - protected Node visitCube(Cube node, C context) - { - List expressions = process(node.getExpressions(), context); - if (sameElements(node.getExpressions(), expressions)) { - return node; - } - - return new Cube(expressions); - } - - @Override - protected Node visitDeallocate(Deallocate node, C context) - { - Node name = process(node.getName(), context); - if (node.getName() == name) { - return node; - } - - return new Deallocate((Identifier) name); - } - - @Override - protected Node visitDelete(Delete node, C context) - { - Node table = process(node.getTable(), context); - Optional where = process(node.getWhere(), context); - if (node.getTable() == table && sameElement(node.getWhere(), where)) { - return node; - } - - return new Delete((Table) table, where); - } - - @Override - protected Node visitExcept(Except node, C context) - { - Node left = process(node.getLeft(), context); - Node right = process(node.getRight(), context); - if (node.getLeft() == left && node.getRight() == right) { - return node; - } - - return new Except((Relation) left, (Relation) right, node.isDistinct()); - } - - @Override - protected Node visitExecute(Execute node, C context) - { - Node name = process(node.getName(), context); - List parameters = process(node.getParameters(), context); - if (node.getName() == name && sameElements(node.getParameters(), 
parameters)) { - return node; - } - - return new Execute(node.getName(), parameters); - } - - @Override - protected Node visitExplain(Explain node, C context) - { - Node statement = process(node.getStatement(), context); - if (node.getStatement() == statement) { - return node; - } - - return new Explain((Statement) statement, node.isAnalyze(), node.isVerbose(), node.getOptions()); - } - - @Override - protected Node visitFrameBound(FrameBound node, C context) - { - Optional value = process(node.getValue(), context); - if (sameElement(node.getValue(), value)) { - return node; - } - - return value.map(expression -> new FrameBound(node.getType(), expression)).orElseGet(() -> new FrameBound(node.getType())); - } - - @Override - protected Node visitFunctionCall(FunctionCall node, C context) - { - Optional window = process(node.getWindow(), context); - Optional filter = process(node.getFilter(), context); - Optional orderBy = process(node.getOrderBy(), context); - List arguments = process(node.getArguments(), context); - if (sameElement(node.getWindow(), window) && sameElement(node.getFilter(), filter) && sameElement(node.getOrderBy(), orderBy) && sameElements(node.getArguments(), arguments)) { - return node; - } - - return node.getLocation().isPresent() ? 
new FunctionCall(node.getLocation().get(), node.getName(), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), node.getArguments()) : - new FunctionCall(node.getName(), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), node.getArguments()); - } - - @Override - protected Node visitGroupBy(GroupBy node, C context) - { - List groupingElements = process(node.getGroupingElements(), context); - if (sameElements(node.getGroupingElements(), groupingElements)) { - return node; - } - - return new GroupBy(node.isDistinct(), groupingElements); - } - - @Override - protected Node visitGroupingSets(GroupingSets node, C context) - { - List> sets = node.getSets().stream().map(expressionList -> process(expressionList, context)).collect(ImmutableList.toImmutableList()); - - if (sameElements(node.getSets(), sets)) { - return node; - } - - return new GroupingSets(sets); - } - - @Override - protected Node visitInsert(Insert node, C context) - { - Node query = process(node.getQuery(), context); - Optional> columns = node.getColumns().map(columnList -> process(columnList, context)); - if (node.getQuery() == query && (!columns.isPresent() || sameElements(node.getColumns().get(), columns.get()))) { - return node; - } - - return new Insert(node.getTarget(), columns, (Query) query); - } - - @Override - protected Node visitIntersect(Intersect node, C context) - { - List relations = process(node.getRelations(), context); - if (sameElements(node.getRelations(), relations)) { - return node; - } - - return new Intersect(relations, node.isDistinct()); - } - - @Override - protected Node visitJoin(Join node, C context) - { - Node left = process(node.getLeft(), context); - Node right = process(node.getRight(), context); - Optional joinCriteria = node.getCriteria() - .map(criteria -> { - if (criteria instanceof JoinOn) { - Node expression = process(((JoinOn) criteria).getExpression(), context); - if (((JoinOn) 
criteria).getExpression() == expression) { - return criteria; - } - return new JoinOn((Expression) expression); - } - return criteria; - }); - if (node.getLeft() == left && node.getRight() == right && node.getCriteria() == joinCriteria) { - return node; - } - - return new Join(node.getType(), (Relation) left, (Relation) right, joinCriteria); - } - - @Override - protected Node visitLateral(Lateral node, C context) - { - Node query = process(node.getQuery(), context); - if (node.getQuery() == query) { - return node; - } - - return new Lateral((Query) query); - } - - @Override - protected Node visitOrderBy(OrderBy node, C context) - { - List sortItems = process(node.getSortItems(), context); - if (sameElements(node.getSortItems(), sortItems)) { - return node; - } - - return new OrderBy(sortItems); - } - - @Override - protected Node visitPrepare(Prepare node, C context) - { - Node statement = process(node.getStatement(), context); - if (node.getStatement() == statement) { - return node; - } - - return new Prepare(node.getName(), (Statement) statement); - } - - @Override - protected Node visitProperty(Property node, C context) - { - Node name = process(node.getName(), context); - Node value = process(node.getValue(), context); - if (node.getName() == name && node.getValue() == value) { - return node; - } - - return new Property((Identifier) name, (Expression) value); - } - - @Override - protected Node visitQuery(Query node, C context) - { - Optional with = process(node.getWith(), context); - Node queryBody = process(node.getQueryBody(), context); - Optional orderBy = process(node.getOrderBy(), context); - if (node.getQueryBody() == queryBody && sameElement(node.getWith(), with) && sameElement(node.getOrderBy(), orderBy)) { - return node; - } - - return new Query(with, (QueryBody) queryBody, orderBy, node.getOffset(), node.getLimit()); - } - - @Override - protected Node visitQuerySpecification(QuerySpecification node, C context) - { - Node select = 
process(node.getSelect(), context); - Optional from = process(node.getFrom(), context); - Optional where = process(node.getWhere(), context); - Optional groupBy = process(node.getGroupBy(), context); - Optional having = process(node.getHaving(), context); - Optional orderBy = process(node.getOrderBy(), context); - if (node.getSelect() == - select && sameElement(node.getFrom(), from) && sameElement(node.getWhere(), where) && sameElement(node.getGroupBy(), groupBy) && sameElement(node.getHaving(), - having) && sameElement(node.getOrderBy(), orderBy)) { - return node; - } - - return new QuerySpecification( - (Select) select, - from, - where, - groupBy, - having, - orderBy, - node.getOffset(), - node.getLimit()); - } - - @Override - protected Node visitRefreshMaterializedView(RefreshMaterializedView node, C context) - { - Node table = process(node.getTarget(), context); - Node where = process(node.getWhere(), context); - if (node.getTarget() == table && node.getWhere() == where) { - return node; - } - - return new RefreshMaterializedView((Table) table, (Expression) where); - } - - @Override - protected Node visitReturn(Return node, C context) - { - Node expression = process(node.getExpression(), context); - if (node.getExpression() == expression) { - return node; - } - - return new Return((Expression) expression); - } - - @Override - protected Node visitRollup(Rollup node, C context) - { - List expressions = process(node.getExpressions(), context); - if (sameElements(node.getExpressions(), expressions)) { - return node; - } - - return new Rollup(expressions); - } - - @Override - protected Node visitRow(Row node, C context) - { - List items = process(node.getItems(), context); - if (sameElements(node.getItems(), items)) { - return node; - } - - return new Row(items); - } - - @Override - protected Node visitSampledRelation(SampledRelation node, C context) - { - Node relation = process(node.getRelation(), context); - Node samplePercentage = 
process(node.getSamplePercentage(), context); - if (node.getRelation() == relation && node.getSamplePercentage() == samplePercentage) { - return node; - } - - return new SampledRelation((Relation) relation, node.getType(), (Expression) samplePercentage); - } - - @Override - protected Node visitSelect(Select node, C context) - { - List selectItems = process(node.getSelectItems(), context); - if (sameElements(node.getSelectItems(), selectItems)) { - return node; - } - - return new Select(node.isDistinct(), selectItems); - } - - @Override - protected Node visitShowStats(ShowStats node, C context) - { - Node relation = process(node.getRelation(), context); - if (node.getRelation() == relation) { - return node; - } - - return new ShowStats((Relation) relation); - } - - @Override - protected Node visitSimpleGroupBy(SimpleGroupBy node, C context) - { - List columns = process(node.getExpressions(), context); - if (sameElements(node.getExpressions(), columns)) { - return node; - } - - return new SimpleGroupBy(columns); - } - - @Override - protected Node visitSingleColumn(SingleColumn node, C context) - { - Node expression = process(node.getExpression(), context); - if (node.getExpression() == expression) { - return node; - } - - return new SingleColumn((Expression) expression, node.getAlias()); - } - - @Override - protected Node visitSortItem(SortItem node, C context) - { - Node sortKey = process(node.getSortKey(), context); - if (node.getSortKey() == sortKey) { - return node; - } - - return new SortItem((Expression) sortKey, node.getOrdering(), node.getNullOrdering()); - } - - @Override - protected Node visitTableSubquery(TableSubquery node, C context) - { - Node query = process(node.getQuery(), context); - if (node.getQuery() == query) { - return node; - } - - return new TableSubquery((Query) query); - } - - @Override - protected Node visitUnion(Union node, C context) - { - List relations = process(node.getRelations(), context); - if (sameElements(node.getRelations(), 
relations)) { - return node; - } - - return new Union(relations, node.isDistinct()); - } - - @Override - protected Node visitUnnest(Unnest node, C context) - { - List expressions = process(node.getExpressions(), context); - if (sameElements(node.getExpressions(), expressions)) { - return node; - } - - return new Unnest(expressions, node.isWithOrdinality()); - } - - @Override - protected Node visitValues(Values node, C context) - { - List expressions = process(node.getRows(), context); - if (sameElements(node.getRows(), expressions)) { - return node; - } - - return new Values(expressions); - } - - @Override - protected Node visitWindow(Window node, C context) - { - List partitionBy = process(node.getPartitionBy(), context); - Optional orderBy = process(node.getOrderBy(), context); - Optional frame = process(node.getFrame(), context); - if (sameElements(node.getPartitionBy(), partitionBy) && sameElement(node.getOrderBy(), orderBy) && sameElement(node.getFrame(), frame)) { - return node; - } - - return new Window(partitionBy, orderBy, frame); - } - - @Override - protected Node visitWindowFrame(WindowFrame node, C context) - { - Node start = process(node.getStart(), context); - Optional end = process(node.getEnd(), context); - if (node.getStart() == start && sameElement(node.getEnd(), end)) { - return node; - } - - return new WindowFrame(node.getType(), (FrameBound) start, end); - } - - @Override - protected Node visitWith(With node, C context) - { - List queries = process(node.getQueries(), context); - if (sameElements(node.getQueries(), queries)) { - return node; - } - - return new With(node.isRecursive(), queries); - } - - @Override - protected Node visitWithQuery(WithQuery node, C context) - { - Node name = process(node.getName(), context); - Node query = process(node.getQuery(), context); - Optional> columnNames = node.getColumnNames().map(columnNamesList -> process(columnNamesList, context)); - if (node.getName() == name && node.getQuery() == query && 
sameElement(node.getColumnNames(), columnNames)) { - return node; - } - - return new WithQuery(node.getName(), (Query) query, node.getColumnNames()); - } - - private List process(List elements, C context) - { - if (elements == null) { - return null; - } - List result = elements.stream().map(element -> (T) process(element, context)).collect(ImmutableList.toImmutableList()); - return sameElements(elements, result) ? elements : result; - } - - private Optional process(Optional element, C context) - { - if (element == null) { - return null; - } - Optional result = element.map(e -> (T) process(e, context)); - return sameElement(element, result) ? element : result; - } - - private static boolean sameElement(Optional a, Optional b) - { - if (a == null && b == null) { - return true; - } - if (a == null || b == null) { - return false; - } - if (!a.isPresent() && !b.isPresent()) { - return true; - } - else if (a.isPresent() != b.isPresent()) { - return false; - } - - return a.get() == b.get(); - } - - @SuppressWarnings("ObjectEquality") - private static boolean sameElements(Iterable a, Iterable b) - { - if (a == null && b == null) { - return true; - } - if (a == null || b == null) { - return false; - } - - if (Iterables.size(a) != Iterables.size(b)) { - return false; - } - - Iterator first = a.iterator(); - Iterator second = b.iterator(); - - while (first.hasNext() && second.hasNext()) { - if (first.next() != second.next()) { - return false; - } - } - - return true; - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DescribeInputRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DescribeInputRewrite.java index b3ee4b42a543c..5f36afa14d7e4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DescribeInputRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DescribeInputRewrite.java @@ -17,6 +17,8 @@ import com.facebook.presto.common.type.Type; import 
com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.AccessControlReferences; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Analyzer; @@ -50,6 +52,7 @@ import static com.facebook.presto.sql.QueryUtil.simpleQuery; import static com.facebook.presto.sql.QueryUtil.values; import static com.facebook.presto.sql.analyzer.utils.ParameterExtractor.getParameters; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissionsForTablesAndColumns; import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; import static java.util.Objects.requireNonNull; @@ -67,9 +70,10 @@ public Statement rewrite( Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { - return (Statement) new Visitor(session, parser, metadata, queryExplainer, parameters, parameterLookup, accessControl, warningCollector, query).process(node, null); + return (Statement) new Visitor(session, parser, metadata, queryExplainer, parameters, parameterLookup, accessControl, warningCollector, query, viewDefinitionReferences).process(node, null); } private static final class Visitor @@ -84,6 +88,7 @@ private static final class Visitor private final AccessControl accessControl; private final WarningCollector warningCollector; private final String query; + private final ViewDefinitionReferences viewDefinitionReferences; public Visitor( Session session, @@ -94,7 +99,8 @@ public Visitor( Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { this.session = requireNonNull(session, "session is null"); 
this.parser = parser; @@ -105,6 +111,7 @@ public Visitor( this.parameterLookup = parameterLookup; this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.query = requireNonNull(query, "query is null"); + this.viewDefinitionReferences = requireNonNull(viewDefinitionReferences, "viewDefinitionReferences is null"); } @Override @@ -115,8 +122,10 @@ protected Node visitDescribeInput(DescribeInput node, Void context) Statement statement = parser.createStatement(sqlString, createParsingOptions(session, warningCollector)); // create analysis for the query we are describing. - Analyzer analyzer = new Analyzer(session, metadata, parser, accessControl, queryExplainer, parameters, parameterLookup, warningCollector, query); - Analysis analysis = analyzer.analyze(statement, true); + Analyzer analyzer = new Analyzer(session, metadata, parser, accessControl, queryExplainer, parameters, parameterLookup, warningCollector, query, viewDefinitionReferences); + Analysis analysis = analyzer.analyzeSemantic(statement, true); + AccessControlReferences accessControlReferences = analysis.getAccessControlReferences(); + checkAccessPermissionsForTablesAndColumns(accessControlReferences); // get all parameters in query List parameters = getParameters(statement); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DescribeOutputRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DescribeOutputRewrite.java index f8fc8c19ef5ec..b8108619ceb03 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DescribeOutputRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/DescribeOutputRewrite.java @@ -18,6 +18,8 @@ import com.facebook.presto.common.type.FixedWidthType; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.AccessControlReferences; +import 
com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.analyzer.Analysis; import com.facebook.presto.sql.analyzer.Analyzer; @@ -48,6 +50,7 @@ import static com.facebook.presto.sql.QueryUtil.selectList; import static com.facebook.presto.sql.QueryUtil.simpleQuery; import static com.facebook.presto.sql.QueryUtil.values; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissionsForTablesAndColumns; import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; import static java.util.Objects.requireNonNull; @@ -65,9 +68,10 @@ public Statement rewrite( Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { - return (Statement) new Visitor(session, parser, metadata, queryExplainer, parameters, parameterLookup, accessControl, warningCollector, query).process(node, null); + return (Statement) new Visitor(session, parser, metadata, queryExplainer, parameters, parameterLookup, accessControl, warningCollector, query, viewDefinitionReferences).process(node, null); } private static final class Visitor @@ -82,6 +86,7 @@ private static final class Visitor private final AccessControl accessControl; private final WarningCollector warningCollector; private final String query; + private final ViewDefinitionReferences viewDefinitionReferences; public Visitor( Session session, @@ -92,7 +97,8 @@ public Visitor( Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { this.session = requireNonNull(session, "session is null"); this.parser = parser; @@ -103,6 +109,7 @@ public Visitor( this.accessControl = accessControl; this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.query = 
requireNonNull(query, "query is null"); + this.viewDefinitionReferences = requireNonNull(viewDefinitionReferences, "viewDefinitionReferences is null"); } @Override @@ -111,8 +118,10 @@ protected Node visitDescribeOutput(DescribeOutput node, Void context) String sqlString = session.getPreparedStatement(node.getName().getValue()); Statement statement = parser.createStatement(sqlString, createParsingOptions(session, warningCollector)); - Analyzer analyzer = new Analyzer(session, metadata, parser, accessControl, queryExplainer, parameters, parameterLookup, warningCollector, query); - Analysis analysis = analyzer.analyze(statement, true); + Analyzer analyzer = new Analyzer(session, metadata, parser, accessControl, queryExplainer, parameters, parameterLookup, warningCollector, query, viewDefinitionReferences); + Analysis analysis = analyzer.analyzeSemantic(statement, true); + AccessControlReferences accessControlReferences = analysis.getAccessControlReferences(); + checkAccessPermissionsForTablesAndColumns(accessControlReferences); Optional limit = Optional.empty(); Row[] rows = analysis.getRootScope().getRelationType().getVisibleFields().stream().map(field -> createDescribeOutputRow(field, analysis)).toArray(Row[]::new); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java index 1b5ee3c1a47dd..3dcfb983b21f0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ExplainRewrite.java @@ -17,6 +17,8 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AnalyzerOptions; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.security.AccessControl; import 
com.facebook.presto.sql.analyzer.BuiltInQueryPreparer; import com.facebook.presto.sql.analyzer.BuiltInQueryPreparer.BuiltInPreparedQuery; @@ -61,9 +63,11 @@ public Statement rewrite( Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { - return (Statement) new Visitor(session, parser, queryExplainer, warningCollector, query).process(node, null); + return (Statement) new Visitor(session, parser, queryExplainer, metadata.getProcedureRegistry(), warningCollector, query, viewDefinitionReferences) + .process(node, null); } private static final class Visitor @@ -74,19 +78,23 @@ private static final class Visitor private final Optional queryExplainer; private final WarningCollector warningCollector; private final String query; + private final ViewDefinitionReferences viewDefinitionReferences; public Visitor( Session session, SqlParser parser, Optional queryExplainer, + ProcedureRegistry procedureRegistry, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { this.session = requireNonNull(session, "session is null"); - this.queryPreparer = new BuiltInQueryPreparer(requireNonNull(parser, "queryPreparer is null")); + this.queryPreparer = new BuiltInQueryPreparer(requireNonNull(parser, "queryPreparer is null"), procedureRegistry); this.queryExplainer = requireNonNull(queryExplainer, "queryExplainer is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.query = requireNonNull(query, "query is null"); + this.viewDefinitionReferences = requireNonNull(viewDefinitionReferences, "viewDefinitionReferences is null"); } @Override @@ -135,13 +143,13 @@ private Node getQueryPlan(Explain node, ExplainType.Type planType, ExplainFormat String plan; switch (planFormat) { case GRAPHVIZ: - plan = queryExplainer.get().getGraphvizPlan(session, 
preparedQuery.getStatement(), planType, preparedQuery.getParameters(), warningCollector, query); + plan = queryExplainer.get().getGraphvizPlan(session, preparedQuery.getStatement(), planType, preparedQuery.getParameters(), warningCollector, query, viewDefinitionReferences); break; case JSON: - plan = queryExplainer.get().getJsonPlan(session, preparedQuery.getStatement(), planType, preparedQuery.getParameters(), warningCollector, query); + plan = queryExplainer.get().getJsonPlan(session, preparedQuery.getStatement(), planType, preparedQuery.getParameters(), warningCollector, query, viewDefinitionReferences); break; case TEXT: - plan = queryExplainer.get().getPlan(session, preparedQuery.getStatement(), planType, preparedQuery.getParameters(), node.isVerbose(), warningCollector, query); + plan = queryExplainer.get().getPlan(session, preparedQuery.getStatement(), planType, preparedQuery.getParameters(), node.isVerbose(), warningCollector, query, viewDefinitionReferences); break; default: throw new IllegalArgumentException("Invalid Explain Format: " + planFormat.toString()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/MaterializedViewOptimizationRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/MaterializedViewOptimizationRewrite.java index 9e72b2653447c..9221821205f04 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/MaterializedViewOptimizationRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/MaterializedViewOptimizationRewrite.java @@ -17,6 +17,7 @@ import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.analyzer.MaterializedViewQueryOptimizer; import com.facebook.presto.sql.analyzer.QueryExplainer; @@ -52,7 
+53,8 @@ public Statement rewrite( Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { return (Statement) new MaterializedViewOptimizationRewrite .Visitor(metadata, session, parser, accessControl) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java deleted file mode 100644 index ebd9e3d31a543..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.sql.rewrite; - -import com.facebook.presto.Session; -import com.facebook.presto.SystemSessionProperties; -import com.facebook.presto.UnknownTypeException; -import com.facebook.presto.common.type.BigintEnumType; -import com.facebook.presto.common.type.EnumType; -import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.TypeWithName; -import com.facebook.presto.common.type.VarcharEnumType; -import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.spi.WarningCollector; -import com.facebook.presto.spi.security.AccessControl; -import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver; -import com.facebook.presto.sql.analyzer.QueryExplainer; -import com.facebook.presto.sql.analyzer.SemanticException; -import com.facebook.presto.sql.parser.SqlParser; -import com.facebook.presto.sql.tree.ArrayConstructor; -import com.facebook.presto.sql.tree.Cast; -import com.facebook.presto.sql.tree.Expression; -import com.facebook.presto.sql.tree.FunctionCall; -import com.facebook.presto.sql.tree.LongLiteral; -import com.facebook.presto.sql.tree.Node; -import com.facebook.presto.sql.tree.NodeRef; -import com.facebook.presto.sql.tree.Parameter; -import com.facebook.presto.sql.tree.QualifiedName; -import com.facebook.presto.sql.tree.Statement; -import com.facebook.presto.sql.tree.StringLiteral; -import com.google.common.collect.ImmutableList; - -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import static com.facebook.presto.common.type.StandardTypes.BIGINT; -import static com.facebook.presto.common.type.StandardTypes.BIGINT_ENUM; -import static com.facebook.presto.common.type.StandardTypes.VARCHAR; -import static com.facebook.presto.common.type.StandardTypes.VARCHAR_ENUM; -import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TYPE_MISMATCH; -import static 
java.util.Objects.requireNonNull; - -/** - * Queries can fail on native worker due to following missing support in Velox:

    - * 1. Named types: Presto supports {@link TypeWithName} which Velox is not able to parse.

    - * 2. {@link EnumType}: Velox does not support EnumTypes as well as its companion function {@code ENUM_KEY}.

    - * - * This rewrite addresses the above issues by resolving the type or function in coordinator for native execution:

    - * 1. Peel {@link TypeWithName} and only preserve the actual base type.

    - * 2. Rewrite {@code CAST(col AS EnumType)} -> {@code CAST(col AS )}.

    TODO: preserve the original type information for `typeof`.

    - * 3. Since enum can be treated as a map, rewrite {@code ENUM_KEY(EnumType)} -> {@code ELEMENT_AT(MAP(, VARCHAR))}.

    - */ -final class NativeExecutionTypeRewrite - implements StatementRewrite.Rewrite -{ - private static final String FUNCTION_ENUM_KEY = "enum_key"; - private static final String FUNCTION_ELEMENT_AT = "element_at"; - private static final String FUNCTION_MAP = "map"; - - @Override - public Statement rewrite( - Session session, - Metadata metadata, - SqlParser parser, - Optional queryExplainer, - Statement node, - List parameters, - Map, Expression> parameterLookup, - AccessControl accessControl, - WarningCollector warningCollector, - String query) - { - if (SystemSessionProperties.isNativeExecutionEnabled(session) - && SystemSessionProperties.isNativeExecutionTypeRewriteEnabled(session)) { - return (Statement) new Rewriter(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()).process(node, null); - } - return node; - } - - private static final class Rewriter - extends DefaultTreeRewriter - { - private final FunctionAndTypeResolver functionAndTypeResolver; - - public Rewriter(FunctionAndTypeResolver functionAndTypeResolver) - { - this.functionAndTypeResolver = requireNonNull(functionAndTypeResolver, "functionAndTypeResolver is null"); - } - - @Override - protected Node visitCast(Cast node, Void context) - { - try { - Type type = functionAndTypeResolver.getType(parseTypeSignature(node.getType())); - if (type instanceof TypeWithName) { - // Peel user defined type name. 
- type = ((TypeWithName) type).getType(); - switch (type.getTypeSignature().getBase()) { - case BIGINT_ENUM: - return new Cast(node.getLocation(), node.getExpression(), BIGINT, node.isSafe(), node.isTypeOnly()); - case VARCHAR_ENUM: - return new Cast(node.getLocation(), node.getExpression(), VARCHAR, node.isSafe(), node.isTypeOnly()); - default: - return new Cast(node.getLocation(), node.getExpression(), type.getTypeSignature().getBase(), node.isSafe(), node.isTypeOnly()); - } - } - } - catch (IllegalArgumentException | UnknownTypeException e) { - throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType()); - } - return node; - } - - @Override - protected Node visitFunctionCall(FunctionCall node, Void context) - { - if (isValidEnumKeyFunctionCall(node)) { - Cast argument = (Cast) node.getArguments().get(0); - Type argumentType = functionAndTypeResolver.getType(parseTypeSignature(argument.getType())); - if (argumentType instanceof TypeWithName) { - // Peel user defined type name. - argumentType = ((TypeWithName) argumentType).getType(); - } - if (argumentType instanceof EnumType) { - // Convert enum_key to element_at. - List arguments = ImmutableList.of(convertEnumTypeToMapExpression(argumentType), argument.getExpression()); - return node.getLocation().isPresent() - ? 
new FunctionCall(node.getLocation().get(), QualifiedName.of(FUNCTION_ELEMENT_AT), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments) - : new FunctionCall(QualifiedName.of(FUNCTION_ELEMENT_AT), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments); - } - } - return super.visitFunctionCall(node, context); - } - - @Override - protected Node visitExpression(Expression node, Void context) - { - return node; - } - - private boolean isValidEnumKeyFunctionCall(FunctionCall node) - { - return node.getName().equals(QualifiedName.of(FUNCTION_ENUM_KEY)) - && node.getArguments().size() == 1 - && node.getArguments().get(0) instanceof Cast; - } - - private Expression convertEnumTypeToMapExpression(Type type) - { - ImmutableList.Builder keys = ImmutableList.builder(); - ImmutableList.Builder values = ImmutableList.builder(); - switch (type.getTypeSignature().getBase()) { - case BIGINT_ENUM: - for (Map.Entry entry : ((BigintEnumType) type).getEnumMap().entrySet()) { - keys.add(new LongLiteral(entry.getValue().toString())); - values.add(new StringLiteral(entry.getKey())); - } - break; - case VARCHAR_ENUM: - for (Map.Entry entry : ((VarcharEnumType) type).getEnumMap().entrySet()) { - keys.add(new StringLiteral(entry.getValue())); - values.add(new StringLiteral(entry.getKey())); - } - break; - default: - throw new SemanticException(TYPE_MISMATCH, "Unknown type: " + type); - } - return new FunctionCall(QualifiedName.of(FUNCTION_MAP), - ImmutableList.of( - new ArrayConstructor(keys.build()), - new ArrayConstructor(values.build()))); - } - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java index ea58bb3ec537c..e0b5214a689e7 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ShowQueriesRewrite.java @@ -14,11 +14,13 @@ package com.facebook.presto.sql.rewrite; import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.SessionPropertyManager.SessionPropertyValue; +import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.MaterializedViewDefinition; @@ -29,6 +31,7 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.MetadataResolver; import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.constraints.NotNullConstraint; import com.facebook.presto.spi.constraints.PrimaryKeyConstraint; import com.facebook.presto.spi.constraints.UniqueConstraint; @@ -39,6 +42,7 @@ import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.PrincipalType; +import com.facebook.presto.spi.security.ViewSecurity; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.sql.QueryUtil; import com.facebook.presto.sql.analyzer.QueryExplainer; @@ -94,6 +98,7 @@ import com.google.common.collect.ImmutableSortedMap; import com.google.common.primitives.Primitives; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -112,7 +117,6 @@ import static com.facebook.presto.metadata.MetadataListing.listCatalogs; import static com.facebook.presto.metadata.MetadataListing.listSchemas; import static 
com.facebook.presto.metadata.MetadataUtil.createCatalogSchemaName; -import static com.facebook.presto.metadata.MetadataUtil.createQualifiedName; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; import static com.facebook.presto.metadata.SessionFunctionHandle.SESSION_NAMESPACE; @@ -121,6 +125,8 @@ import static com.facebook.presto.spi.StandardErrorCode.INVALID_COLUMN_PROPERTY; import static com.facebook.presto.spi.StandardErrorCode.INVALID_SCHEMA_PROPERTY; import static com.facebook.presto.spi.StandardErrorCode.INVALID_TABLE_PROPERTY; +import static com.facebook.presto.spi.security.ViewSecurity.DEFINER; +import static com.facebook.presto.spi.security.ViewSecurity.INVOKER; import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts; import static com.facebook.presto.sql.QueryUtil.aliased; import static com.facebook.presto.sql.QueryUtil.aliasedName; @@ -161,6 +167,7 @@ import static com.google.common.base.Strings.nullToEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getLast; import static java.lang.String.format; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; @@ -181,7 +188,8 @@ public Statement rewrite( Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { return (Statement) new Visitor(metadata, parser, session, parameters, accessControl, queryExplainer, warningCollector).process(node, null); } @@ -225,7 +233,7 @@ protected Node visitExplain(Explain node, Void context) @Override protected Node visitShowTables(ShowTables showTables, Void context) { - CatalogSchemaName schema = createCatalogSchemaName(session, showTables, 
showTables.getSchema()); + CatalogSchemaName schema = createCatalogSchemaName(session, showTables, showTables.getSchema(), metadata); accessControl.checkCanShowTablesMetadata(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), schema); @@ -263,7 +271,7 @@ protected Node visitShowGrants(ShowGrants showGrants, Void context) Optional tableName = showGrants.getTableName(); if (tableName.isPresent()) { - QualifiedObjectName qualifiedTableName = createQualifiedObjectName(session, showGrants, tableName.get()); + QualifiedObjectName qualifiedTableName = createQualifiedObjectName(session, showGrants, tableName.get(), metadata); if (!metadataResolver.getView(qualifiedTableName).isPresent() && !metadataResolver.getTableHandle(qualifiedTableName).isPresent()) { @@ -403,19 +411,28 @@ protected Node visitShowCatalogs(ShowCatalogs node, Void context) @Override protected Node visitShowColumns(ShowColumns showColumns, Void context) { - QualifiedObjectName tableName = createQualifiedObjectName(session, showColumns, showColumns.getTable()); + QualifiedObjectName tableName = createQualifiedObjectName(session, showColumns, showColumns.getTable(), metadata); if (!metadataResolver.getView(tableName).isPresent() && !metadataResolver.getTableHandle(tableName).isPresent()) { throw new SemanticException(MISSING_TABLE, showColumns, "Table '%s' does not exist", tableName); } + accessControl.checkCanShowColumnsMetadata( + session.getRequiredTransactionId(), + session.getIdentity(), + session.getAccessControlContext(), + tableName); + return simpleQuery( selectList( aliasedName("column_name", "Column"), aliasedName("data_type", "Type"), aliasedNullToEmpty("extra_info", "Extra"), - aliasedNullToEmpty("comment", "Comment")), + aliasedNullToEmpty("comment", "Comment"), + aliasedName("precision", "Precision"), + aliasedName("scale", "Scale"), + aliasedName("length", "Length")), from(tableName.getCatalogName(), TABLE_COLUMNS), logicalAnd( 
equal(identifier("table_schema"), new StringLiteral(tableName.getSchemaName())), @@ -462,7 +479,7 @@ private static Expression toExpression(Object value) protected Node visitShowCreate(ShowCreate node, Void context) { if (node.getType() == SCHEMA) { - CatalogSchemaName catalogSchemaName = createCatalogSchemaName(session, node, Optional.of(node.getName())); + CatalogSchemaName catalogSchemaName = createCatalogSchemaName(session, node, Optional.of(node.getName()), metadata); if (!metadataResolver.schemaExists(catalogSchemaName)) { throw new SemanticException(MISSING_SCHEMA, node, "Schema '%s' does not exist", catalogSchemaName); } @@ -477,7 +494,7 @@ protected Node visitShowCreate(ShowCreate node, Void context) return singleValueQuery("Create Schema", formatSql(createSchema, Optional.of(parameters)).trim()); } - QualifiedObjectName objectName = createQualifiedObjectName(session, node, node.getName()); + QualifiedObjectName objectName = createQualifiedObjectName(session, node, node.getName(), metadata); Optional viewDefinition = metadataResolver.getView(objectName); Optional materializedViewDefinition = metadataResolver.getMaterializedView(objectName); @@ -493,8 +510,11 @@ protected Node visitShowCreate(ShowCreate node, Void context) } Query query = parseView(viewDefinition.get().getOriginalSql(), objectName, node); - CreateView.Security security = (viewDefinition.get().isRunAsInvoker()) ? CreateView.Security.INVOKER : CreateView.Security.DEFINER; - String sql = formatSql(new CreateView(createQualifiedName(objectName), query, false, Optional.of(security)), Optional.of(parameters)).trim(); + + accessControl.checkCanShowCreateTable(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), objectName); + + ViewSecurity security = (viewDefinition.get().isRunAsInvoker()) ? 
INVOKER : DEFINER; + String sql = formatSql(new CreateView(getQualifiedName(node, objectName), query, false, Optional.of(security)), Optional.of(parameters)).trim(); return singleValueQuery("Create View", sql); } @@ -513,16 +533,23 @@ protected Node visitShowCreate(ShowCreate node, Void context) Query query = parseView(materializedViewDefinition.get().getOriginalSql(), objectName, node); + accessControl.checkCanShowCreateTable(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), objectName); + ConnectorTableMetadata connectorTableMetadata = metadata.getTableMetadata(session, tableHandle.get()).getMetadata(); Map properties = connectorTableMetadata.getProperties(); Map> allTableProperties = metadata.getTablePropertyManager().getAllProperties().get(tableHandle.get().getConnectorId()); List propertyNodes = buildProperties("materialized view " + objectName, INVALID_TABLE_PROPERTY, properties, allTableProperties); + Optional security = SystemSessionProperties.isLegacyMaterializedViews(session) + ? 
Optional.empty() + : materializedViewDefinition.get().getSecurityMode(); + CreateMaterializedView createMaterializedView = new CreateMaterializedView( Optional.empty(), - createQualifiedName(objectName), + getQualifiedName(node, objectName), query, false, + security, propertyNodes, connectorTableMetadata.getComment()); return singleValueQuery("Create Materialized View", formatSql(createMaterializedView, Optional.of(parameters)).trim()); @@ -541,6 +568,8 @@ protected Node visitShowCreate(ShowCreate node, Void context) throw new SemanticException(MISSING_TABLE, node, "Table '%s' does not exist", objectName); } + accessControl.checkCanShowCreateTable(session.getRequiredTransactionId(), session.getIdentity(), session.getAccessControlContext(), objectName); + ConnectorTableMetadata connectorTableMetadata = metadata.getTableMetadata(session, tableHandle.get()).getMetadata(); Set notNullColumns = connectorTableMetadata.getTableConstraintsHolder().getTableConstraints() @@ -550,7 +579,16 @@ protected Node visitShowCreate(ShowCreate node, Void context) .collect(toImmutableSet()); Map> allColumnProperties = metadata.getColumnPropertyManager().getAllProperties().get(tableHandle.get().getConnectorId()); - List columns = connectorTableMetadata.getColumns().stream() + + List allowedColumns = new ArrayList<>(); + allowedColumns = accessControl.filterColumns( + session.getRequiredTransactionId(), + session.getIdentity(), + session.getAccessControlContext(), + objectName, + connectorTableMetadata.getColumns()); + + List columns = allowedColumns.stream() .filter(column -> !column.isHidden()) .map(column -> { List propertyNodes = buildProperties(toQualifiedName(objectName, Optional.of(column.getName())), INVALID_COLUMN_PROPERTY, column.getProperties(), allColumnProperties); @@ -665,6 +703,16 @@ protected Node visitShowCreateFunction(ShowCreateFunction node, Void context) ordering(ascending("argument_types"))); } + private QualifiedName getQualifiedName(ShowCreate node, 
QualifiedObjectName objectName) + { + List parts = node.getName().getOriginalParts(); + Identifier tableName = getLast(parts); + Identifier schemaName = parts.size() > 1 ? parts.get(parts.size() - 2) : new Identifier(objectName.getSchemaName()); + Identifier catalogName = (parts.size() > 2) ? parts.get(0) : new Identifier(objectName.getCatalogName()); + + return QualifiedName.of(ImmutableList.of(catalogName, schemaName, tableName)); + } + private List buildProperties( Object objectName, StandardErrorCode errorCode, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java index 7d163c68bd3f5..fbd848eff2376 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/ShowStatsRewrite.java @@ -23,6 +23,7 @@ import com.facebook.presto.common.type.SmallintType; import com.facebook.presto.common.type.SqlTime; import com.facebook.presto.common.type.SqlTimestamp; +import com.facebook.presto.common.type.SqlTimestampWithTimeZone; import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; @@ -32,6 +33,7 @@ import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.TableMetadata; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.security.AccessControl; @@ -82,7 +84,9 @@ import static com.facebook.presto.common.type.StandardTypes.DOUBLE; import static com.facebook.presto.common.type.StandardTypes.VARCHAR; import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimeZoneKey.UTC_KEY; import static 
com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; import static com.facebook.presto.sql.QueryUtil.aliased; import static com.facebook.presto.sql.QueryUtil.selectAll; @@ -113,9 +117,10 @@ public Statement rewrite(Session session, Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { - return (Statement) new Visitor(metadata, session, parameters, queryExplainer, warningCollector, query).process(node, null); + return (Statement) new Visitor(metadata, session, parameters, queryExplainer, warningCollector, query, viewDefinitionReferences).process(node, null); } private static class Visitor @@ -127,8 +132,9 @@ private static class Visitor private final Optional queryExplainer; private final WarningCollector warningCollector; private final String sqlString; + private final ViewDefinitionReferences viewDefinitionReferences; - public Visitor(Metadata metadata, Session session, List parameters, Optional queryExplainer, WarningCollector warningCollector, String sqlString) + public Visitor(Metadata metadata, Session session, List parameters, Optional queryExplainer, WarningCollector warningCollector, String sqlString, ViewDefinitionReferences viewDefinitionReferences) { this.metadata = requireNonNull(metadata, "metadata is null"); this.session = requireNonNull(session, "session is null"); @@ -136,6 +142,7 @@ public Visitor(Metadata metadata, Session session, List parameters, this.queryExplainer = requireNonNull(queryExplainer, "queryExplainer is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.sqlString = requireNonNull(sqlString, "sqlString is null"); + this.viewDefinitionReferences = 
requireNonNull(viewDefinitionReferences, "viewDefinitionReferences is null"); } @Override @@ -146,7 +153,7 @@ protected Node visitShowStats(ShowStats node, Void context) if (node.getRelation() instanceof TableSubquery) { Query query = ((TableSubquery) node.getRelation()).getQuery(); QuerySpecification specification = (QuerySpecification) query.getQueryBody(); - Plan plan = queryExplainer.get().getLogicalPlan(session, new Query(Optional.empty(), specification, Optional.empty(), Optional.empty(), Optional.empty()), parameters, warningCollector, sqlString); + Plan plan = queryExplainer.get().getLogicalPlan(session, new Query(Optional.empty(), specification, Optional.empty(), Optional.empty(), Optional.empty()), parameters, warningCollector, sqlString, viewDefinitionReferences); Set columns = validateShowStatsSubquery(node, query, specification, plan); Table table = (Table) specification.getFrom().get(); Constraint constraint = getConstraint(plan); @@ -255,7 +262,7 @@ private Constraint getConstraint(Plan plan) private TableHandle getTableHandle(ShowStats node, QualifiedName table) { - QualifiedObjectName qualifiedTableName = createQualifiedObjectName(session, node, table); + QualifiedObjectName qualifiedTableName = createQualifiedObjectName(session, node, table, metadata); return metadata.getMetadataResolver(session).getTableHandle(qualifiedTableName) .orElseThrow(() -> new SemanticException(MISSING_TABLE, node, "Table %s not found", table)); } @@ -381,6 +388,9 @@ private Expression toStringLiteral(Type type, double value) if (type.equals(TIME)) { return new StringLiteral(new SqlTime(round(value)).toString()); } + if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) { + return new StringLiteral(new SqlTimestampWithTimeZone(round(value) / MICROSECONDS_PER_MILLISECOND, UTC_KEY).toString()); + } throw new IllegalArgumentException("Unexpected type: " + type); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/StatementRewrite.java 
b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/StatementRewrite.java index 7b7ea1f8c97e6..f65ff4af88243 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/StatementRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/StatementRewrite.java @@ -16,6 +16,7 @@ import com.facebook.presto.Session; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.parser.SqlParser; @@ -39,8 +40,7 @@ public final class StatementRewrite new ShowQueriesRewrite(), new ShowStatsRewrite(), new ExplainRewrite(), - new MaterializedViewOptimizationRewrite(), - new NativeExecutionTypeRewrite()); + new MaterializedViewOptimizationRewrite()); private StatementRewrite() {} @@ -54,10 +54,11 @@ public static Statement rewrite( Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query) + String query, + ViewDefinitionReferences viewDefinitionReferences) { for (Rewrite rewrite : REWRITES) { - node = requireNonNull(rewrite.rewrite(session, metadata, parser, queryExplainer, node, parameters, parameterLookup, accessControl, warningCollector, query), "Statement rewrite returned null"); + node = requireNonNull(rewrite.rewrite(session, metadata, parser, queryExplainer, node, parameters, parameterLookup, accessControl, warningCollector, query, viewDefinitionReferences), "Statement rewrite returned null"); } return node; } @@ -74,6 +75,7 @@ Statement rewrite( Map, Expression> parameterLookup, AccessControl accessControl, WarningCollector warningCollector, - String query); + String query, + ViewDefinitionReferences viewDefinitionReferences); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/storage/TempStorageManager.java 
b/presto-main-base/src/main/java/com/facebook/presto/storage/TempStorageManager.java index 5ad2b1f48ac52..5a11dc02fa4de 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/storage/TempStorageManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/storage/TempStorageManager.java @@ -27,8 +27,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.io.File; import java.io.IOException; diff --git a/presto-main-base/src/main/java/com/facebook/presto/tdigest/Centroid.java b/presto-main-base/src/main/java/com/facebook/presto/tdigest/Centroid.java index 6849a1baa050a..66923132494dd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/tdigest/Centroid.java +++ b/presto-main-base/src/main/java/com/facebook/presto/tdigest/Centroid.java @@ -31,7 +31,7 @@ package com.facebook.presto.tdigest; -import javax.annotation.concurrent.NotThreadSafe; +import com.facebook.airlift.concurrent.NotThreadSafe; import java.io.Serializable; import java.util.concurrent.atomic.AtomicInteger; diff --git a/presto-main-base/src/main/java/com/facebook/presto/tdigest/TDigest.java b/presto-main-base/src/main/java/com/facebook/presto/tdigest/TDigest.java index 2b15d45500d86..0a7977cd6fa67 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/tdigest/TDigest.java +++ b/presto-main-base/src/main/java/com/facebook/presto/tdigest/TDigest.java @@ -31,6 +31,7 @@ package com.facebook.presto.tdigest; +import com.facebook.airlift.concurrent.NotThreadSafe; import io.airlift.slice.BasicSliceInput; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; @@ -38,8 +39,6 @@ import io.airlift.slice.SliceOutput; import org.openjdk.jol.info.ClassLayout; -import javax.annotation.concurrent.NotThreadSafe; - import java.util.AbstractCollection; import java.util.ArrayList; import 
java.util.Arrays; @@ -204,6 +203,13 @@ public static TDigest createTDigest(Slice slice) r.mean = new double[r.activeCentroids]; sliceInput.readBytes(wrappedDoubleArray(r.weight), r.activeCentroids * SIZE_OF_DOUBLE); sliceInput.readBytes(wrappedDoubleArray(r.mean), r.activeCentroids * SIZE_OF_DOUBLE); + + // Validate deserialized TDigest data + for (int i = 0; i < r.activeCentroids; i++) { + checkArgument(!isNaN(r.mean[i]), "Deserialized t-digest contains NaN mean value"); + checkArgument(r.weight[i] > 0, "weight must be > 0"); + } + sliceInput.close(); return r; } @@ -713,6 +719,12 @@ public long estimatedInMemorySizeInBytes() public Slice serialize() { + // Validate data before serialization + for (int i = 0; i < activeCentroids; i++) { + checkArgument(!isNaN(mean[i]), "Cannot serialize t-digest with NaN mean value"); + checkArgument(weight[i] > 0, "Cannot serialize t-digest with non-positive weight"); + } + SliceOutput sliceOutput = new DynamicSliceOutput(toIntExact(estimatedSerializedSizeInBytes())); sliceOutput.writeByte(1); // version 1 of T-Digest serialization diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index e6a1e2495ab06..3a745229b7521 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -14,6 +14,8 @@ package com.facebook.presto.testing; import com.facebook.airlift.node.NodeInfo; +import com.facebook.airlift.units.Duration; +import com.facebook.drift.codec.ThriftCodecManager; import com.facebook.presto.ClientRequestFilterManager; import com.facebook.presto.GroupByHashPageIndexerFactory; import com.facebook.presto.PagesIndexPageSorter; @@ -26,8 +28,8 @@ import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.type.BooleanType; import 
com.facebook.presto.common.type.Type; +import com.facebook.presto.connector.ConnectorCodecManager; import com.facebook.presto.connector.ConnectorManager; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; import com.facebook.presto.connector.system.AnalyzePropertiesSystemTable; import com.facebook.presto.connector.system.CatalogSystemTable; import com.facebook.presto.connector.system.ColumnPropertiesSystemTable; @@ -96,18 +98,20 @@ import com.facebook.presto.memory.MemoryManagerConfig; import com.facebook.presto.memory.NodeMemoryConfig; import com.facebook.presto.metadata.AnalyzePropertyManager; +import com.facebook.presto.metadata.BuiltInProcedureRegistry; import com.facebook.presto.metadata.CatalogManager; import com.facebook.presto.metadata.ColumnPropertyManager; -import com.facebook.presto.metadata.ConnectorMetadataUpdaterManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.MaterializedViewPropertyManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.metadata.MetadataUtil; import com.facebook.presto.metadata.QualifiedTablePrefix; import com.facebook.presto.metadata.SchemaPropertyManager; import com.facebook.presto.metadata.Split; +import com.facebook.presto.metadata.TableFunctionRegistry; import com.facebook.presto.metadata.TablePropertyManager; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.operator.Driver; @@ -122,7 +126,7 @@ import com.facebook.presto.operator.TableCommitContext; import com.facebook.presto.operator.TaskContext; import com.facebook.presto.operator.index.IndexJoinLookupStats; -import com.facebook.presto.server.ConnectorMetadataUpdateHandleJsonSerde; +import com.facebook.presto.operator.table.ExcludeColumns; import 
com.facebook.presto.server.NodeStatusNotificationManager; import com.facebook.presto.server.PluginManager; import com.facebook.presto.server.PluginManagerConfig; @@ -150,6 +154,8 @@ import com.facebook.presto.spi.plan.SimplePlanFragment; import com.facebook.presto.spi.plan.StageExecutionDescriptor; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.procedure.ProcedureRegistry; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; import com.facebook.presto.spiller.GenericPartitioningSpillerFactory; import com.facebook.presto.spiller.GenericSpillerFactory; @@ -176,6 +182,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -240,7 +247,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.Closer; -import io.airlift.units.Duration; import org.intellij.lang.annotations.Language; import org.weakref.jmx.MBeanExporter; import org.weakref.jmx.testing.TestingMBeanServer; @@ -317,6 +323,7 @@ public class LocalQueryRunner private final PageSorter pageSorter; private final PageIndexerFactory pageIndexerFactory; private final MetadataManager metadata; + private final ProcedureRegistry procedureRegistry; private final ScalarStatsCalculator scalarStatsCalculator; private final StatsNormalizer statsNormalizer; private final FilterStatsCalculator filterStatsCalculator; @@ -332,8 +339,6 @@ public class LocalQueryRunner private final PartitioningProviderManager partitioningProviderManager; private final NodePartitioningManager 
nodePartitioningManager; private final ConnectorPlanOptimizerManager planOptimizerManager; - private final ConnectorMetadataUpdaterManager distributedMetadataManager; - private final ConnectorTypeSerdeManager connectorTypeSerdeManager; private final PageSinkManager pageSinkManager; private final TransactionManager transactionManager; private final FileSingleStreamSpillerFactory singleStreamSpillerFactory; @@ -431,14 +436,14 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, this.partitioningProviderManager = new PartitioningProviderManager(); this.nodePartitioningManager = new NodePartitioningManager(nodeScheduler, partitioningProviderManager, new NodeSelectionStats()); this.planOptimizerManager = new ConnectorPlanOptimizerManager(); - this.distributedMetadataManager = new ConnectorMetadataUpdaterManager(); - this.connectorTypeSerdeManager = new ConnectorTypeSerdeManager(new ConnectorMetadataUpdateHandleJsonSerde()); this.blockEncodingManager = new BlockEncodingManager(); featuresConfig.setIgnoreStatsCalculatorFailures(false); + FunctionAndTypeManager functionAndTypeManager = new FunctionAndTypeManager(transactionManager, new TableFunctionRegistry(), blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()); + this.procedureRegistry = new BuiltInProcedureRegistry(functionAndTypeManager); this.metadata = new MetadataManager( - new FunctionAndTypeManager(transactionManager, blockEncodingManager, featuresConfig, functionsConfig, new HandleResolver(), ImmutableSet.of()), + functionAndTypeManager, blockEncodingManager, createTestingSessionPropertyManager( new SystemSessionProperties( @@ -458,9 +463,11 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, nodeSpillConfig), new SchemaPropertyManager(), new TablePropertyManager(), + new MaterializedViewPropertyManager(), new ColumnPropertyManager(), new AnalyzePropertyManager(), - transactionManager); + transactionManager, + 
procedureRegistry); this.splitManager = new SplitManager(metadata, new QueryManagerConfig(), nodeSchedulerConfig); this.planCheckerProviderManager = new PlanCheckerProviderManager(new JsonCodecSimplePlanFragmentSerde(jsonCodec(SimplePlanFragment.class)), new PlanCheckerProviderManagerConfig()); this.distributedPlanChecker = new PlanChecker(featuresConfig, false, planCheckerProviderManager); @@ -470,7 +477,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, this.pageIndexerFactory = new GroupByHashPageIndexerFactory(joinCompiler); NodeInfo nodeInfo = new NodeInfo("test"); - expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()), getFunctionAndTypeManager()); + expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()), getFunctionAndTypeManager(), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); this.accessControl = new TestingAccessControlManager(transactionManager); this.statsNormalizer = new StatsNormalizer(); @@ -478,7 +485,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, this.filterStatsCalculator = new FilterStatsCalculator(metadata, scalarStatsCalculator, statsNormalizer); this.historyBasedPlanStatisticsManager = new HistoryBasedPlanStatisticsManager(objectMapper, createTestingSessionPropertyManager(), metadata, new HistoryBasedOptimizationConfig(), featuresConfig, new NodeVersion("1")); this.fragmentStatsProvider = new FragmentStatsProvider(); - this.statsCalculator = createNewStatsCalculator(metadata, scalarStatsCalculator, statsNormalizer, filterStatsCalculator, historyBasedPlanStatisticsManager, fragmentStatsProvider); + this.statsCalculator = createNewStatsCalculator(metadata, scalarStatsCalculator, statsNormalizer, filterStatsCalculator, historyBasedPlanStatisticsManager, fragmentStatsProvider, expressionOptimizerManager); 
this.taskCountEstimator = new TaskCountEstimator(() -> nodeCountForStats); this.costCalculator = new CostCalculatorUsingExchanges(taskCountEstimator); this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, taskCountEstimator); @@ -499,13 +506,12 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, indexManager, partitioningProviderManager, planOptimizerManager, - distributedMetadataManager, - connectorTypeSerdeManager, pageSinkManager, new HandleResolver(), nodeManager, nodeInfo, metadata.getFunctionAndTypeManager(), + procedureRegistry, pageSorter, pageIndexerFactory, transactionManager, @@ -515,7 +521,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager()), new FilterStatsCalculator(metadata, scalarStatsCalculator, statsNormalizer), blockEncodingManager, - featuresConfig); + featuresConfig, + new ConnectorCodecManager(ThriftCodecManager::new)); GlobalSystemConnectorFactory globalSystemConnectorFactory = new GlobalSystemConnectorFactory(ImmutableSet.of( new NodeSystemTable(nodeManager), @@ -525,11 +532,12 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new ColumnPropertiesSystemTable(transactionManager, metadata), new AnalyzePropertiesSystemTable(transactionManager, metadata), new TransactionsSystemTable(metadata.getFunctionAndTypeManager(), transactionManager)), - ImmutableSet.of()); + ImmutableSet.of(), + ImmutableSet.of(new ExcludeColumns.ExcludeColumnsFunction())); BuiltInQueryAnalyzer queryAnalyzer = new BuiltInQueryAnalyzer(metadata, sqlParser, accessControl, Optional.empty(), metadataExtractorExecutor); BuiltInAnalyzerProvider analyzerProvider = new BuiltInAnalyzerProvider(queryAnalyzer); - BuiltInQueryPreparer queryPreparer = new BuiltInQueryPreparer(sqlParser); + BuiltInQueryPreparer queryPreparer = new BuiltInQueryPreparer(sqlParser, 
procedureRegistry); BuiltInQueryPreparerProvider queryPreparerProvider = new BuiltInQueryPreparerProvider(queryPreparer); this.pluginManager = new PluginManager( @@ -777,7 +785,8 @@ public void installPlugin(Plugin plugin) @Override public void createCatalog(String catalogName, String connectorName, Map properties) { - throw new UnsupportedOperationException(); + nodeManager.addCurrentNodeConnector(new ConnectorId(catalogName)); + connectorManager.createConnection(catalogName, connectorName, properties); } @Override @@ -918,7 +927,7 @@ private MaterializedResultWithPlan executeInternal(Session session, @Language("S private MaterializedResultWithPlan executeExplainTypeValidate(String sql, Session session, WarningCollector warningCollector) { AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, warningCollector); - BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); + BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser, procedureRegistry).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); assertFormattedSql(sqlParser, createParsingOptions(session), preparedQuery.getStatement()); PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); @@ -939,7 +948,7 @@ private MaterializedResultWithPlan executeExplainTypeValidate(String sql, Sessio AnalyzerContext analyzerContext = getAnalyzerContext(queryAnalyzer, metadata.getMetadataResolver(session), idAllocator, new VariableAllocator(), session, sql); QueryAnalysis queryAnalysis = queryAnalyzer.analyze(analyzerContext, preparedQuery); - checkAccessPermissions(queryAnalysis.getAccessControlReferences(), sql); + checkAccessPermissions(queryAnalysis.getAccessControlReferences(), queryAnalysis.getViewDefinitionReferences(), sql, session.getPreparedStatements(), session.getIdentity(), accessControl, session.getAccessControlContext()); MaterializedResult 
result = MaterializedResult.resultBuilder(session, BooleanType.BOOLEAN) .row(true) @@ -950,7 +959,7 @@ private MaterializedResultWithPlan executeExplainTypeValidate(String sql, Sessio private boolean isExplainTypeValidate(String sql, Session session, WarningCollector warningCollector) { AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, warningCollector); - PreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); + PreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser, procedureRegistry).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); return preparedQuery.isExplainTypeValidate(); } @@ -990,7 +999,6 @@ private List createDrivers(Session session, Plan plan, OutputFactory out partitioningProviderManager, nodePartitioningManager, pageSinkManager, - distributedMetadataManager, expressionCompiler, pageFunctionCompiler, joinFilterFunctionCompiler, @@ -1124,12 +1132,17 @@ public Plan createPlan(Session session, @Language("SQL") String sql, Optimizer.P } public Plan createPlan(Session session, @Language("SQL") String sql, Optimizer.PlanStage stage, boolean noExchange, WarningCollector warningCollector) + { + return createPlan(session, sql, stage, noExchange, false, warningCollector); + } + + public Plan createPlan(Session session, @Language("SQL") String sql, Optimizer.PlanStage stage, boolean noExchange, boolean nativeExecutionEnabled, WarningCollector warningCollector) { AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, warningCollector); - BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); + BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser, procedureRegistry).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); 
assertFormattedSql(sqlParser, createParsingOptions(session), preparedQuery.getStatement()); - return createPlan(session, sql, getPlanOptimizers(noExchange), stage, warningCollector); + return createPlan(session, sql, getPlanOptimizers(noExchange, nativeExecutionEnabled), stage, warningCollector); } public void setAdditionalOptimizer(List additionalOptimizer) @@ -1138,10 +1151,16 @@ public void setAdditionalOptimizer(List additionalOptimizer) } public List getPlanOptimizers(boolean noExchange) + { + return getPlanOptimizers(noExchange, false); + } + + public List getPlanOptimizers(boolean noExchange, boolean nativeExecutionEnabled) { FeaturesConfig featuresConfig = new FeaturesConfig() .setDistributedIndexJoinsEnabled(false) - .setOptimizeHashGeneration(true); + .setOptimizeHashGeneration(true) + .setNativeExecutionEnabled(nativeExecutionEnabled); ImmutableList.Builder planOptimizers = ImmutableList.builder(); if (!additionalOptimizer.isEmpty()) { planOptimizers.addAll(additionalOptimizer); @@ -1162,7 +1181,8 @@ public List getPlanOptimizers(boolean noExchange) partitioningProviderManager, featuresConfig, expressionOptimizerManager, - taskManagerConfig).getPlanningTimeOptimizers()); + taskManagerConfig, + accessControl).getPlanningTimeOptimizers()); return planOptimizers.build(); } @@ -1174,7 +1194,7 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List optimizers, Optimizer.PlanStage stage, WarningCollector warningCollector) { AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, warningCollector); - BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); + BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser, procedureRegistry).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); assertFormattedSql(sqlParser, createParsingOptions(session), preparedQuery.getStatement()); 
PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); @@ -1195,7 +1215,7 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List types; private final Map setSessionProperties; private final Set resetSessionProperties; - private final Optional updateType; + private final Optional updateInfo; private final OptionalLong updateCount; + private final Optional startedTransactionId; + private final boolean clearTransactionId; private final List warnings; public MaterializedResult(List rows, List types) { - this(rows, types, ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of()); + this(rows, types, ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), Optional.empty(), false, ImmutableList.of()); } public MaterializedResult( @@ -111,16 +115,20 @@ public MaterializedResult( List types, Map setSessionProperties, Set resetSessionProperties, - Optional updateType, + Optional updateInfo, OptionalLong updateCount, + Optional startedTransactionId, + boolean clearTransactionId, List warnings) { this.rows = ImmutableList.copyOf(requireNonNull(rows, "rows is null")); this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); this.setSessionProperties = ImmutableMap.copyOf(requireNonNull(setSessionProperties, "setSessionProperties is null")); this.resetSessionProperties = ImmutableSet.copyOf(requireNonNull(resetSessionProperties, "resetSessionProperties is null")); - this.updateType = requireNonNull(updateType, "updateType is null"); + this.updateInfo = requireNonNull(updateInfo, "updateInfo is null"); this.updateCount = requireNonNull(updateCount, "updateCount is null"); + this.startedTransactionId = requireNonNull(startedTransactionId, "startedTransactionId is null"); + this.clearTransactionId = clearTransactionId; this.warnings = requireNonNull(warnings, "warnings is null"); } @@ -155,9 +163,9 @@ public Set getResetSessionProperties() return resetSessionProperties; } - 
public Optional getUpdateType() + public Optional getUpdateInfo() { - return updateType; + return updateInfo; } public OptionalLong getUpdateCount() @@ -165,6 +173,16 @@ public OptionalLong getUpdateCount() return updateCount; } + public Optional getStartedTransactionId() + { + return startedTransactionId; + } + + public boolean isClearTransactionId() + { + return clearTransactionId; + } + public List getWarnings() { return warnings; @@ -184,14 +202,16 @@ public boolean equals(Object obj) Objects.equals(rows, o.rows) && Objects.equals(setSessionProperties, o.setSessionProperties) && Objects.equals(resetSessionProperties, o.resetSessionProperties) && - Objects.equals(updateType, o.updateType) && - Objects.equals(updateCount, o.updateCount); + Objects.equals(updateInfo, o.updateInfo) && + Objects.equals(updateCount, o.updateCount) && + Objects.equals(startedTransactionId, o.startedTransactionId) && + Objects.equals(clearTransactionId, o.clearTransactionId); } @Override public int hashCode() { - return Objects.hash(rows, types, setSessionProperties, resetSessionProperties, updateType, updateCount); + return Objects.hash(rows, types, setSessionProperties, resetSessionProperties, updateInfo, updateCount, startedTransactionId, clearTransactionId); } @Override @@ -202,8 +222,10 @@ public String toString() .add("types", types) .add("setSessionProperties", setSessionProperties) .add("resetSessionProperties", resetSessionProperties) - .add("updateType", updateType.orElse(null)) + .add("updateInfo", updateInfo.orElse(null)) .add("updateCount", updateCount.isPresent() ? 
updateCount.getAsLong() : null) + .add("startedTransactionId", startedTransactionId.orElse(null)) + .add("clearTransactionId", clearTransactionId) .omitNullValues() .toString(); } @@ -358,8 +380,10 @@ public MaterializedResult toTestTypes() types, setSessionProperties, resetSessionProperties, - updateType, + updateInfo, updateCount, + startedTransactionId, + clearTransactionId, warnings); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/QueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/QueryRunner.java index 7132cbea4bc8f..f6c1a0b9b1405 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/QueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/QueryRunner.java @@ -109,7 +109,7 @@ default void installCoordinatorPlugin(CoordinatorPlugin plugin) void loadFunctionNamespaceManager(String functionNamespaceManagerName, String catalogName, Map properties); - default void loadSessionPropertyProvider(String sessionPropertyProviderName) + default void loadSessionPropertyProvider(String sessionPropertyProviderName, Map properties) { throw new UnsupportedOperationException(); } @@ -121,6 +121,16 @@ default void loadTypeManager(String typeManagerName) throw new UnsupportedOperationException(); } + default void loadPlanCheckerProviderManager(String planCheckerProviderName, Map properties) + { + throw new UnsupportedOperationException(); + } + + default void triggerConflictCheckWithBuiltInFunctions() + { + throw new UnsupportedOperationException(); + } + class MaterializedResultWithPlan { private final MaterializedResult materializedResult; diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestProcedureRegistry.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestProcedureRegistry.java new file mode 100644 index 0000000000000..722adad98c453 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestProcedureRegistry.java 
@@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.testing; + +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.connector.ConnectorProcedureContext; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.ProcedureRegistry; + +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.facebook.presto.spi.StandardErrorCode.PROCEDURE_NOT_FOUND; +import static java.util.Objects.requireNonNull; + +public class TestProcedureRegistry + implements ProcedureRegistry +{ + private final Map>> connectorProcedures = new ConcurrentHashMap<>(); + + @Override + public void addProcedures(ConnectorId connectorId, Collection> procedures) + { + requireNonNull(connectorId, "connectorId is null"); + requireNonNull(procedures, "procedures is null"); + + Map> proceduresByName = procedures.stream().collect(Collectors.toMap( + procedure -> new SchemaTableName(procedure.getSchema(), procedure.getName()), + Function.identity())); + if (connectorProcedures.putIfAbsent(connectorId, proceduresByName) != null) { + throw new IllegalStateException("Procedures already registered for 
connector: " + connectorId); + } + } + + @Override + public void removeProcedures(ConnectorId connectorId) + { + connectorProcedures.remove(connectorId); + } + + @Override + public BaseProcedure resolve(ConnectorId connectorId, SchemaTableName name) + { + Map> procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + BaseProcedure procedure = procedures.get(name); + if (procedure != null) { + return procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Procedure not registered: " + name); + } + + @Override + public DistributedProcedure resolveDistributed(ConnectorId connectorId, SchemaTableName name) + { + Map> procedures = connectorProcedures.get(connectorId); + if (procedures != null) { + BaseProcedure procedure = procedures.get(name); + if (procedure instanceof DistributedProcedure) { + return (DistributedProcedure) procedure; + } + } + throw new PrestoException(PROCEDURE_NOT_FOUND, "Distributed procedure not registered: " + name); + } + + @Override + public boolean isDistributedProcedure(ConnectorId connectorId, SchemaTableName name) + { + Map> procedures = connectorProcedures.get(connectorId); + return procedures != null && + procedures.containsKey(name) && + procedures.get(name) instanceof DistributedProcedure; + } + + public static class TestProcedureContext + implements ConnectorProcedureContext + {} +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java index 9a7be061f767b..7ccec9adafbc0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingAccessControlManager.java @@ -26,8 +26,7 @@ import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; - -import 
javax.inject.Inject; +import jakarta.inject.Inject; import java.security.Principal; import java.util.ArrayList; @@ -42,6 +41,7 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denyAddColumn; import static com.facebook.presto.spi.security.AccessDeniedException.denyAddConstraint; +import static com.facebook.presto.spi.security.AccessDeniedException.denyCallProcedure; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateSchema; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyCreateView; @@ -62,6 +62,7 @@ import static com.facebook.presto.spi.security.AccessDeniedException.denySetSystemSessionProperty; import static com.facebook.presto.spi.security.AccessDeniedException.denySetTableProperties; import static com.facebook.presto.spi.security.AccessDeniedException.denySetUser; +import static com.facebook.presto.spi.security.AccessDeniedException.denyShowCreateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyTruncateTable; import static com.facebook.presto.spi.security.AccessDeniedException.denyUpdateTableColumns; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.ADD_COLUMN; @@ -76,6 +77,7 @@ import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_SCHEMA; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_VIEW; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.EXECUTE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.INSERT_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.RENAME_COLUMN; import static 
com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.RENAME_SCHEMA; @@ -85,6 +87,7 @@ import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SET_SESSION; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SET_TABLE_PROPERTIES; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SET_USER; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SHOW_CREATE_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.TRUNCATE_TABLE; import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.UPDATE_TABLE; import static com.google.common.base.MoreObjects.toStringHelper; @@ -100,9 +103,16 @@ public class TestingAccessControlManager @Inject public TestingAccessControlManager(TransactionManager transactionManager) + { + this(transactionManager, true); + } + + public TestingAccessControlManager(TransactionManager transactionManager, boolean loadDefaultSystemAccessControl) { super(transactionManager); - setSystemAccessControl(AllowAllSystemAccessControl.NAME, ImmutableMap.of()); + if (loadDefaultSystemAccessControl) { + setSystemAccessControl(AllowAllSystemAccessControl.NAME, ImmutableMap.of()); + } } public static TestingPrivilege privilege(String entityName, TestingPrivilegeType type) @@ -182,6 +192,17 @@ public void checkCanRenameSchema(TransactionId transactionId, Identity identity, } } + @Override + public void checkCanShowCreateTable(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) + { + if (shouldDenyPrivilege(identity.getUser(), tableName.getObjectName(), SHOW_CREATE_TABLE)) { + denyShowCreateTable(tableName.toString()); + } + if (denyPrivileges.isEmpty()) { + super.checkCanShowCreateTable(transactionId, identity, context, tableName); + } + } + @Override public void 
checkCanCreateTable(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) { @@ -381,6 +402,15 @@ public void checkCanSelectFromColumns(TransactionId transactionId, Identity iden } } + @Override + public void checkCanCallProcedure(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName procedureName) + { + if (shouldDenyPrivilege(identity.getUser(), procedureName.getObjectName(), EXECUTE)) { + denyCallProcedure(procedureName.toString()); + } + super.checkCanCallProcedure(transactionId, identity, context, procedureName); + } + @Override public void checkCanDropConstraint(TransactionId transactionId, Identity identity, AccessControlContext context, QualifiedObjectName tableName) { @@ -437,11 +467,13 @@ public enum TestingPrivilegeType { SET_USER, CREATE_SCHEMA, DROP_SCHEMA, RENAME_SCHEMA, - CREATE_TABLE, DROP_TABLE, RENAME_TABLE, INSERT_TABLE, DELETE_TABLE, TRUNCATE_TABLE, UPDATE_TABLE, + SHOW_CREATE_TABLE, CREATE_TABLE, DROP_TABLE, RENAME_TABLE, INSERT_TABLE, DELETE_TABLE, TRUNCATE_TABLE, UPDATE_TABLE, ADD_COLUMN, DROP_COLUMN, RENAME_COLUMN, SELECT_COLUMN, + DROP_BRANCH, DROP_TAG, ADD_CONSTRAINT, DROP_CONSTRAINT, CREATE_VIEW, RENAME_VIEW, DROP_VIEW, CREATE_VIEW_WITH_SELECT_COLUMNS, SET_TABLE_PROPERTIES, - SET_SESSION + SET_SESSION, + EXECUTE } public static class TestingPrivilege diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingConnectorSession.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingConnectorSession.java index fcf3e95974b6b..9cffb7efc7ae8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingConnectorSession.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingConnectorSession.java @@ -13,8 +13,10 @@ */ package com.facebook.presto.testing; +import com.facebook.presto.FullConnectorSession; import com.facebook.presto.common.RuntimeStats; import 
com.facebook.presto.common.function.SqlFunctionProperties; +import com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.execution.QueryIdGenerator; import com.facebook.presto.spi.ConnectorId; @@ -39,12 +41,13 @@ import static com.facebook.presto.common.type.TimeZoneKey.UTC_KEY; import static com.facebook.presto.spi.StandardErrorCode.INVALID_SESSION_PROPERTY; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; public class TestingConnectorSession - implements ConnectorSession + extends FullConnectorSession { private static final QueryIdGenerator queryIdGenerator = new QueryIdGenerator(); public static final ConnectorSession SESSION = new TestingConnectorSession(ImmutableList.of()); @@ -105,6 +108,7 @@ public TestingConnectorSession( Optional schema, Map sessionFunctions) { + super(testSessionBuilder().build(), identity); this.queryId = queryIdGenerator.createNextQueryId().toString(); this.identity = requireNonNull(identity, "identity is null"); this.source = requireNonNull(source, "source is null"); @@ -225,6 +229,12 @@ public RuntimeStats getRuntimeStats() return new RuntimeStats(); } + @Override + public Optional getQueryType() + { + return Optional.of(QueryType.SELECT); + } + @Override public ConnectorSession forConnectorId(ConnectorId connectorId) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandle.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandle.java index 34de904e9068e..40294ba5dcb0f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandle.java @@ -13,12 +13,13 @@ */ package com.facebook.presto.testing; +import 
com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorInsertTableHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorTableLayoutHandle; public enum TestingHandle - implements ConnectorOutputTableHandle, ConnectorInsertTableHandle, ConnectorTableLayoutHandle + implements ConnectorOutputTableHandle, ConnectorInsertTableHandle, ConnectorTableLayoutHandle, ConnectorDistributedProcedureHandle { INSTANCE } diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandleResolver.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandleResolver.java index 4da22baca8df7..0421d2c316f58 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandleResolver.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingHandleResolver.java @@ -14,9 +14,9 @@ package com.facebook.presto.testing; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.ConnectorInsertTableHandle; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; import com.facebook.presto.spi.ConnectorOutputTableHandle; import com.facebook.presto.spi.ConnectorSplit; import com.facebook.presto.spi.ConnectorTableHandle; @@ -65,14 +65,14 @@ public Class getInsertTableHandleClass() } @Override - public Class getTransactionHandleClass() + public Class getDistributedProcedureHandleClass() { - return TestingTransactionHandle.class; + return TestingHandle.class; } @Override - public Class getMetadataUpdateHandleClass() + public Class getTransactionHandleClass() { - return TestingMetadataUpdateHandle.class; + return TestingTransactionHandle.class; } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingMetadata.java 
b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingMetadata.java index 0c84187591563..0cdd1bdbd1641 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingMetadata.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingMetadata.java @@ -146,7 +146,7 @@ public Map> listTableColumns(ConnectorSess ImmutableList.Builder columns = ImmutableList.builder(); for (ColumnMetadata column : tables.get(tableName).getColumns()) { columns.add(ColumnMetadata.builder() - .setName(column.getName()) + .setName(normalizeIdentifier(session, column.getName())) .setType(column.getType()) .build()); } @@ -289,7 +289,7 @@ public void renameColumn(ConnectorSession session, ConnectorTableHandle tableHan ColumnMetadata columnMetadata = getColumnMetadata(session, tableHandle, source); List columns = new ArrayList<>(tableMetadata.getColumns()); columns.set(columns.indexOf(columnMetadata), ColumnMetadata.builder() - .setName(target) + .setName(normalizeIdentifier(session, target)) .setType(columnMetadata.getType()) .setComment(columnMetadata.getComment().orElse(null)) .setHidden(columnMetadata.isHidden()) diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingMetadataUpdateHandle.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingMetadataUpdateHandle.java deleted file mode 100644 index 207d83102b0c2..0000000000000 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingMetadataUpdateHandle.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.testing; - -import com.facebook.drift.annotations.ThriftConstructor; -import com.facebook.drift.annotations.ThriftField; -import com.facebook.drift.annotations.ThriftStruct; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; - -@ThriftStruct -public class TestingMetadataUpdateHandle - implements ConnectorMetadataUpdateHandle -{ - private final int value; - - @ThriftConstructor - @JsonCreator - public TestingMetadataUpdateHandle(@JsonProperty("value") int value) - { - this.value = value; - } - - @ThriftField(1) - @JsonProperty - public int getValue() - { - return value; - } -} diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingPrestoServerModule.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingPrestoServerModule.java new file mode 100644 index 0000000000000..a0139a22ac673 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingPrestoServerModule.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.testing; + +import com.facebook.presto.eventlistener.EventListenerConfig; +import com.facebook.presto.eventlistener.EventListenerManager; +import com.facebook.presto.security.AccessControlManager; +import com.facebook.presto.server.GracefulShutdownHandler; +import com.facebook.presto.server.security.PrestoAuthenticatorManager; +import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.storage.TempStorageManager; +import com.facebook.presto.transaction.TransactionManager; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Provides; +import com.google.inject.Scopes; +import com.google.inject.Singleton; + +import static java.util.Objects.requireNonNull; + +public class TestingPrestoServerModule + implements Module +{ + private final boolean loadDefaultSystemAccessControl; + + public TestingPrestoServerModule(boolean loadDefaultSystemAccessControl) + { + this.loadDefaultSystemAccessControl = loadDefaultSystemAccessControl; + } + + @Override + public void configure(Binder binder) + { + binder.bind(PrestoAuthenticatorManager.class).in(Scopes.SINGLETON); + binder.bind(TestingEventListenerManager.class).in(Scopes.SINGLETON); + binder.bind(TestingTempStorageManager.class).in(Scopes.SINGLETON); + binder.bind(EventListenerManager.class).to(TestingEventListenerManager.class).in(Scopes.SINGLETON); + binder.bind(EventListenerConfig.class).in(Scopes.SINGLETON); + binder.bind(TempStorageManager.class).to(TestingTempStorageManager.class).in(Scopes.SINGLETON); + binder.bind(AccessControl.class).to(AccessControlManager.class).in(Scopes.SINGLETON); + binder.bind(GracefulShutdownHandler.class).in(Scopes.SINGLETON); + binder.bind(ProcedureTester.class).in(Scopes.SINGLETON); + } + + @Provides + @Singleton + public AccessControlManager createAccessControlManager(TransactionManager 
transactionManager) + { + requireNonNull(transactionManager, "transactionManager is null"); + return new TestingAccessControlManager(transactionManager, loadDefaultSystemAccessControl); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingTaskContext.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingTaskContext.java index edbfc8486d93f..9fb8f52d37c83 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingTaskContext.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingTaskContext.java @@ -15,6 +15,7 @@ import com.facebook.airlift.stats.GcMonitor; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskStateMachine; @@ -26,15 +27,14 @@ import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spiller.SpillSpaceTracker; -import io.airlift.units.DataSize; import java.util.Optional; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public final class TestingTaskContext { diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingWarningCollector.java b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingWarningCollector.java index 3c265596ab512..729fdb60c21f8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/TestingWarningCollector.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/TestingWarningCollector.java @@ -21,9 
+21,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.LinkedHashMultimap; import com.google.common.collect.Multimap; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; diff --git a/presto-main-base/src/main/java/com/facebook/presto/thrift/RemoteCodecProvider.java b/presto-main-base/src/main/java/com/facebook/presto/thrift/RemoteCodecProvider.java new file mode 100644 index 0000000000000..fa68ecc23d993 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/thrift/RemoteCodecProvider.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.thrift; + +import com.facebook.drift.codec.ThriftCodecManager; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.inject.Provider; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class RemoteCodecProvider + implements ConnectorCodecProvider +{ + private final Provider thriftCodecManagerProvider; + + public RemoteCodecProvider(Provider thriftCodecManagerProvider) + { + this.thriftCodecManagerProvider = requireNonNull(thriftCodecManagerProvider, "thriftCodecManagerProvider is null"); + } + + @Override + public Optional> getConnectorSplitCodec() + { + return Optional.of(new RemoteSplitCodec(thriftCodecManagerProvider)); + } + + @Override + public Optional> getConnectorTransactionHandleCodec() + { + return Optional.of(new RemoteTransactionHandleCodec(thriftCodecManagerProvider)); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/thrift/RemoteSplitCodec.java b/presto-main-base/src/main/java/com/facebook/presto/thrift/RemoteSplitCodec.java new file mode 100644 index 0000000000000..f8d1c90f1c35f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/thrift/RemoteSplitCodec.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.thrift; + +import com.facebook.drift.codec.ThriftCodecManager; +import com.facebook.drift.protocol.TProtocolException; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.split.RemoteSplit; +import com.google.inject.Provider; + +import static com.facebook.presto.server.thrift.ThriftCodecUtils.fromThrift; +import static com.facebook.presto.server.thrift.ThriftCodecUtils.toThrift; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static java.util.Objects.requireNonNull; + +public class RemoteSplitCodec + implements ConnectorCodec +{ + private final Provider thriftCodecManagerProvider; + + public RemoteSplitCodec(Provider thriftCodecManagerProvider) + { + this.thriftCodecManagerProvider = requireNonNull(thriftCodecManagerProvider, "thriftCodecManagerProvider is null"); + } + + @Override + public byte[] serialize(ConnectorSplit split) + { + try { + return toThrift((RemoteSplit) split, thriftCodecManagerProvider.get().getCodec(RemoteSplit.class)); + } + catch (TProtocolException e) { + throw new PrestoException(INVALID_ARGUMENTS, "Can not serialize remote split", e); + } + } + + @Override + public ConnectorSplit deserialize(byte[] bytes) + { + try { + return fromThrift(bytes, thriftCodecManagerProvider.get().getCodec(RemoteSplit.class)); + } + catch (TProtocolException e) { + throw new PrestoException(INVALID_ARGUMENTS, "Can not deserialize remote split", e); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/thrift/RemoteTransactionHandleCodec.java b/presto-main-base/src/main/java/com/facebook/presto/thrift/RemoteTransactionHandleCodec.java new file mode 100644 index 0000000000000..b3a664042764a --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/thrift/RemoteTransactionHandleCodec.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.thrift; + +import com.facebook.drift.codec.ThriftCodecManager; +import com.facebook.drift.protocol.TProtocolException; +import com.facebook.presto.metadata.RemoteTransactionHandle; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.inject.Provider; + +import static com.facebook.presto.server.thrift.ThriftCodecUtils.fromThrift; +import static com.facebook.presto.server.thrift.ThriftCodecUtils.toThrift; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static java.util.Objects.requireNonNull; + +public class RemoteTransactionHandleCodec + implements ConnectorCodec +{ + private final Provider thriftCodecManagerProvider; + + public RemoteTransactionHandleCodec(Provider thriftCodecManagerProvider) + { + this.thriftCodecManagerProvider = requireNonNull(thriftCodecManagerProvider, "thriftCodecManagerProvider is null"); + } + + @Override + public byte[] serialize(ConnectorTransactionHandle handle) + { + try { + return toThrift((RemoteTransactionHandle) handle, thriftCodecManagerProvider.get().getCodec(RemoteTransactionHandle.class)); + } + catch (TProtocolException e) { + throw new PrestoException(INVALID_ARGUMENTS, "Can not serialize remote transaction handle", e); + } + } + + @Override + public ConnectorTransactionHandle deserialize(byte[] bytes) + { 
+ try { + return fromThrift(bytes, thriftCodecManagerProvider.get().getCodec(RemoteTransactionHandle.class)); + } + catch (TProtocolException e) { + throw new PrestoException(INVALID_ARGUMENTS, "Can not deserialize remote transaction handle", e); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/tracing/TracerProviderManager.java b/presto-main-base/src/main/java/com/facebook/presto/tracing/TracerProviderManager.java index 6c1f35cd061fa..c63e95ffd3f9d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/tracing/TracerProviderManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/tracing/TracerProviderManager.java @@ -15,8 +15,7 @@ import com.facebook.presto.spi.tracing.TracerProvider; import com.google.inject.Inject; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import static java.lang.String.format; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/transaction/ForTransactionManager.java b/presto-main-base/src/main/java/com/facebook/presto/transaction/ForTransactionManager.java index 7c1006f383081..bbf98f8cfa083 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/transaction/ForTransactionManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/transaction/ForTransactionManager.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.transaction; -import javax.inject.Qualifier; +import jakarta.inject.Qualifier; import java.lang.annotation.Retention; import java.lang.annotation.Target; diff --git a/presto-main-base/src/main/java/com/facebook/presto/transaction/InMemoryTransactionManager.java b/presto-main-base/src/main/java/com/facebook/presto/transaction/InMemoryTransactionManager.java index b10b995c28d0c..bd3a4dbd4f1fd 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/transaction/InMemoryTransactionManager.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/transaction/InMemoryTransactionManager.java @@ -16,8 +16,10 @@ import com.facebook.airlift.concurrent.BoundedExecutor; import com.facebook.airlift.concurrent.ExecutorServiceAdapter; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.metadata.Catalog; +import com.facebook.presto.metadata.Catalog.CatalogContext; import com.facebook.presto.metadata.CatalogManager; import com.facebook.presto.metadata.CatalogMetadata; import com.facebook.presto.spi.ConnectorId; @@ -34,10 +36,8 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.util.HashMap; import java.util.Iterator; @@ -215,6 +215,12 @@ public Map getCatalogNames(TransactionId transactionId) return getTransactionMetadata(transactionId).getCatalogNames(); } + @Override + public Map getCatalogNamesWithConnectorContext(TransactionId transactionId) + { + return getTransactionMetadata(transactionId).getCatalogNamesWithConnectorContext(); + } + @Override public Optional getOptionalCatalogMetadata(TransactionId transactionId, String catalogName) { @@ -463,6 +469,20 @@ private synchronized Map getCatalogNames() return ImmutableMap.copyOf(catalogNames); } + private synchronized Map getCatalogNamesWithConnectorContext() + { + Map catalogNamesWithConnectorContext = new HashMap<>(); + catalogByName.values().stream() + .filter(Optional::isPresent) + .map(Optional::get) + .forEach(catalog -> catalogNamesWithConnectorContext.put(catalog.getCatalogName(), 
catalog.getCatalogContext())); + + catalogManager.getCatalogs().stream() + .forEach(catalog -> catalogNamesWithConnectorContext.putIfAbsent(catalog.getCatalogName(), catalog.getCatalogContext())); + + return ImmutableMap.copyOf(catalogNamesWithConnectorContext); + } + private synchronized Optional getConnectorId(String catalogName) { Optional catalog = catalogByName.get(catalogName); diff --git a/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionInfo.java b/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionInfo.java index 1a3999a4977cc..8eeb196a5b709 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionInfo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionInfo.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.transaction; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; @@ -20,7 +21,6 @@ import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.transaction.IsolationLevel; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; import java.util.List; import java.util.Optional; diff --git a/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionManager.java b/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionManager.java index b7d3e0ccdca4a..ecacaf515e0ca 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionManager.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.common.transaction.TransactionId; +import com.facebook.presto.metadata.Catalog.CatalogContext; import com.facebook.presto.metadata.CatalogMetadata; import 
com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; @@ -22,6 +23,7 @@ import com.facebook.presto.spi.function.FunctionNamespaceTransactionHandle; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.transaction.IsolationLevel; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; import java.util.List; @@ -47,6 +49,11 @@ public interface TransactionManager Map getCatalogNames(TransactionId transactionId); + default Map getCatalogNamesWithConnectorContext(TransactionId transactionId) + { + return ImmutableMap.of(); + } + Optional getOptionalCatalogMetadata(TransactionId transactionId, String catalogName); void enableRollback(TransactionId transactionId); diff --git a/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionManagerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionManagerConfig.java index 2b38c16ea4d1b..2b94e5ee43037 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionManagerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/transaction/TransactionManagerConfig.java @@ -15,13 +15,12 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.util.Map; import java.util.concurrent.TimeUnit; diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/ttl/nodettlfetchermanagers/ConfidenceBasedNodeTtlFetcherManager.java b/presto-main-base/src/main/java/com/facebook/presto/ttl/nodettlfetchermanagers/ConfidenceBasedNodeTtlFetcherManager.java index a716fc45095ee..40f69ea28e380 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/ttl/nodettlfetchermanagers/ConfidenceBasedNodeTtlFetcherManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/ttl/nodettlfetchermanagers/ConfidenceBasedNodeTtlFetcherManager.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.metadata.AllNodes; import com.facebook.presto.metadata.InternalNode; @@ -28,13 +29,11 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; -import io.airlift.units.Duration; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.PreDestroy; -import javax.inject.Inject; - import java.io.File; import java.time.Instant; import java.time.temporal.ChronoUnit; diff --git a/presto-main-base/src/main/java/com/facebook/presto/ttl/nodettlfetchermanagers/NodeTtlFetcherManagerConfig.java b/presto-main-base/src/main/java/com/facebook/presto/ttl/nodettlfetchermanagers/NodeTtlFetcherManagerConfig.java index 5d9fd626b55f4..7103ef976adf2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/ttl/nodettlfetchermanagers/NodeTtlFetcherManagerConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/ttl/nodettlfetchermanagers/NodeTtlFetcherManagerConfig.java @@ -14,7 +14,7 @@ package com.facebook.presto.ttl.nodettlfetchermanagers; import com.facebook.airlift.configuration.Config; -import io.airlift.units.Duration; 
+import com.facebook.airlift.units.Duration; public class NodeTtlFetcherManagerConfig { diff --git a/presto-main-base/src/main/java/com/facebook/presto/type/BigintOperators.java b/presto-main-base/src/main/java/com/facebook/presto/type/BigintOperators.java index 92dc55dd997db..4689354736795 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/type/BigintOperators.java +++ b/presto-main-base/src/main/java/com/facebook/presto/type/BigintOperators.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.function.BlockIndex; import com.facebook.presto.spi.function.BlockPosition; import com.facebook.presto.spi.function.IsNull; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarOperator; import com.facebook.presto.spi.function.SqlNullable; diff --git a/presto-main-base/src/main/java/com/facebook/presto/type/LikeFunctions.java b/presto-main-base/src/main/java/com/facebook/presto/type/LikeFunctions.java index 7b89abc3bf48d..672d8b3542ff8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/type/LikeFunctions.java +++ b/presto-main-base/src/main/java/com/facebook/presto/type/LikeFunctions.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.ScalarOperator; diff --git a/presto-main-base/src/main/java/com/facebook/presto/type/TypeDeserializer.java b/presto-main-base/src/main/java/com/facebook/presto/type/TypeDeserializer.java index 5cf48f762ca70..af0701477dc2b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/type/TypeDeserializer.java +++ 
b/presto-main-base/src/main/java/com/facebook/presto/type/TypeDeserializer.java @@ -17,8 +17,7 @@ import com.facebook.presto.common.type.TypeManager; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; import static com.google.common.base.Preconditions.checkArgument; diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/AnalyzerUtil.java b/presto-main-base/src/main/java/com/facebook/presto/util/AnalyzerUtil.java index f43ce90779cb7..8c795658e8360 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/AnalyzerUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/AnalyzerUtil.java @@ -16,7 +16,6 @@ import com.facebook.presto.Session; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.transaction.TransactionId; -import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; @@ -26,7 +25,7 @@ import com.facebook.presto.spi.analyzer.AnalyzerOptions; import com.facebook.presto.spi.analyzer.MetadataResolver; import com.facebook.presto.spi.analyzer.QueryAnalyzer; -import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.AccessControlContext; @@ -80,6 +79,8 @@ public static AnalyzerOptions createAnalyzerOptions(Session session, WarningColl .setLogFormattedQueryEnabled(isLogFormattedQueryEnabled(session)) .setWarningHandlingLevel(getWarningHandlingLevel(session)) .setWarningCollector(warningCollector) + 
.setSessionCatalogName(session.getCatalog()) + .setSessionSchemaName(session.getSchema()) .build(); } @@ -99,29 +100,19 @@ public static AnalyzerContext getAnalyzerContext( return new AnalyzerContext(metadataResolver, idAllocator, variableAllocator, query); } - public static void checkAccessPermissions(AccessControlReferences accessControlReferences, String query) + public static void checkAccessPermissions(AccessControlReferences accessControlReferences, ViewDefinitionReferences viewDefinitionReferences, String query, Map preparedStatements, Identity identity, AccessControl accessControl, AccessControlContext accessControlContext) { // Query check - checkAccessPermissionsForQuery(accessControlReferences, query); - // Table checks - checkAccessPermissionsForTable(accessControlReferences); - // Table Column checks - checkAccessPermissionsForColumns(accessControlReferences); + checkQueryIntegrity(identity, accessControl, accessControlContext, query, preparedStatements, viewDefinitionReferences); + + //Table and column checks + checkAccessPermissionsForTablesAndColumns(accessControlReferences); } - private static void checkAccessPermissionsForQuery(AccessControlReferences accessControlReferences, String query) + public static void checkAccessPermissionsForTablesAndColumns(AccessControlReferences accessControlReferences) { - AccessControlInfo queryAccessControlInfo = accessControlReferences.getQueryAccessControlInfo(); - // Only check access if query gets analyzed - if (queryAccessControlInfo != null) { - AccessControl queryAccessControl = queryAccessControlInfo.getAccessControl(); - Identity identity = queryAccessControlInfo.getIdentity(); - AccessControlContext queryAccessControlContext = queryAccessControlInfo.getAccessControlContext(); - Map viewDefinitionMap = accessControlReferences.getViewDefinitions(); - Map materializedViewDefinitionMap = accessControlReferences.getMaterializedViewDefinitions(); - - queryAccessControl.checkQueryIntegrity(identity, 
queryAccessControlContext, query, viewDefinitionMap, materializedViewDefinitionMap); - } + checkAccessPermissionsForTable(accessControlReferences); + checkAccessPermissionsForColumns(accessControlReferences); } private static void checkAccessPermissionsForColumns(AccessControlReferences accessControlReferences) @@ -166,4 +157,9 @@ private static void checkAccessPermissionsForTable(AccessControlReferences acces } })); } + + private static void checkQueryIntegrity(Identity identity, AccessControl accessControl, AccessControlContext accessControlContext, String query, Map preparedStatements, ViewDefinitionReferences viewDefinitionReferences) + { + accessControl.checkQueryIntegrity(identity, accessControlContext, query, preparedStatements, viewDefinitionReferences.getViewDefinitions(), viewDefinitionReferences.getMaterializedViewDefinitions()); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/DateTimeUtils.java b/presto-main-base/src/main/java/com/facebook/presto/util/DateTimeUtils.java index a6b63679c2699..de8c0ff93da1c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/DateTimeUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/DateTimeUtils.java @@ -18,6 +18,7 @@ import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.sql.tree.IntervalLiteral.IntervalField; +import jakarta.annotation.Nullable; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.joda.time.DurationFieldType; @@ -36,8 +37,6 @@ import org.joda.time.format.PeriodFormatterBuilder; import org.joda.time.format.PeriodParser; -import javax.annotation.Nullable; - import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.reflect.Method; diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/Failures.java b/presto-main-base/src/main/java/com/facebook/presto/util/Failures.java index 
0d8ece16f113e..a6f768085c961 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/Failures.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/Failures.java @@ -31,8 +31,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import io.airlift.slice.SliceTooLargeException; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.util.Arrays; import java.util.Collection; diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/FinalizerService.java b/presto-main-base/src/main/java/com/facebook/presto/util/FinalizerService.java index bfa8fd3b52a9f..9f93a18c44c02 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/FinalizerService.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/FinalizerService.java @@ -15,11 +15,10 @@ import com.facebook.airlift.log.Logger; import com.google.common.collect.Sets; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; import java.lang.ref.PhantomReference; import java.lang.ref.ReferenceQueue; diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/GcStatusMonitor.java b/presto-main-base/src/main/java/com/facebook/presto/util/GcStatusMonitor.java index 15d5083945470..ad88d3df9a7ee 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/GcStatusMonitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/GcStatusMonitor.java @@ -24,10 +24,10 @@ import com.facebook.presto.spi.QueryId; import com.google.common.collect.ImmutableList; import com.google.common.collect.ListMultimap; +import jakarta.annotation.PostConstruct; +import 
jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; import javax.management.JMException; import javax.management.Notification; import javax.management.NotificationListener; diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index 2c725b65abbd4..5368f400f4646 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -26,11 +26,13 @@ import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; import com.facebook.presto.spi.plan.MergeJoinNode; +import com.facebook.presto.spi.plan.MetadataDeleteNode; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.PlanFragmentId; import com.facebook.presto.spi.plan.PlanNode; @@ -41,8 +43,10 @@ import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.CallDistributedProcedureTarget; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.RowExpression; @@ -52,22 +56,24 @@ import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import 
com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; import com.facebook.presto.sql.planner.plan.LateralJoinNode; -import com.facebook.presto.sql.planner.plan.MetadataDeleteNode; +import com.facebook.presto.sql.planner.plan.MergeProcessorNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.planner.plan.UpdateNode; import com.facebook.presto.sql.planner.planPrinter.RowExpressionFormatter; import com.facebook.presto.sql.tree.ComparisonExpression; @@ -131,6 +137,9 @@ private enum NodeType ANALYZE_FINISH, EXPLAIN_ANALYZE, UPDATE, + MERGE, + TABLE_FUNCTION, + TABLE_FUNCTION_PROCESSOR } private static final Map NODE_COLORS = immutableEnumMap(ImmutableMap.builder() @@ -162,6 +171,9 @@ private enum NodeType .put(NodeType.ANALYZE_FINISH, "plum") .put(NodeType.EXPLAIN_ANALYZE, "cadetblue1") .put(NodeType.UPDATE, "blue") + 
.put(NodeType.MERGE, "lightblue") + .put(NodeType.TABLE_FUNCTION, "mediumorchid3") + .put(NodeType.TABLE_FUNCTION_PROCESSOR, "steelblue3") .build()); static { @@ -296,6 +308,13 @@ public Void visitSequence(SequenceNode node, Void context) return null; } + @Override + public Void visitCallDistributedProcedure(CallDistributedProcedureNode node, Void context) + { + printNode(node, format("CallDistributedProcedure[%s]", node.getTarget().map(CallDistributedProcedureTarget::getProcedureName).orElse(null)), NODE_COLORS.get(NodeType.TABLE_WRITER)); + return node.getSource().accept(this, context); + } + @Override public Void visitTableWriter(TableWriterNode node, Void context) { @@ -321,6 +340,20 @@ public Void visitUpdate(UpdateNode node, Void context) return node.getSource().accept(this, context); } + @Override + public Void visitMergeWriter(MergeWriterNode node, Void context) + { + printNode(node, format("MergeWriterNode[%s]", Joiner.on(", ").join(node.getOutputVariables())), NODE_COLORS.get(NodeType.MERGE)); + return node.getSource().accept(this, context); + } + + @Override + public Void visitMergeProcessor(MergeProcessorNode node, Void context) + { + printNode(node, format("MergeProcessorNode[%s]", Joiner.on(", ").join(node.getOutputVariables())), NODE_COLORS.get(NodeType.MERGE)); + return node.getSource().accept(this, context); + } + @Override public Void visitStatisticsWriterNode(StatisticsWriterNode node, Void context) { @@ -382,6 +415,24 @@ public Void visitWindow(WindowNode node, Void context) return node.getSource().accept(this, context); } + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Void context) + { + printNode(node, "Table Function Processor", NODE_COLORS.get(NodeType.TABLE_FUNCTION)); + if (node.getSource().isPresent()) { + node.getSource().get().accept(this, context); + } + return null; + } + + @Override + public Void visitTableFunction(TableFunctionNode node, Void context) + { + printNode(node, "Table Function 
Node", NODE_COLORS.get(NodeType.TABLE_FUNCTION)); + node.getSources().stream().map(source -> source.accept(this, context)); + return null; + } + @Override public Void visitRowNumber(RowNumberNode node, Void context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/JsonUtil.java b/presto-main-base/src/main/java/com/facebook/presto/util/JsonUtil.java index 4b48dcbd8fff8..e8e7cb4802de2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/JsonUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/JsonUtil.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.util; +import com.facebook.airlift.json.JsonObjectMapperProvider; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.block.SingleRowBlockWriter; @@ -41,6 +42,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.primitives.Shorts; import com.google.common.primitives.SignedBytes; +import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; @@ -86,6 +88,7 @@ import static com.fasterxml.jackson.core.JsonToken.FIELD_NAME; import static com.fasterxml.jackson.core.JsonToken.START_ARRAY; import static com.fasterxml.jackson.core.JsonToken.START_OBJECT; +import static com.fasterxml.jackson.databind.SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static it.unimi.dsi.fastutil.HashCommon.arraySize; @@ -106,6 +109,7 @@ public final class JsonUtil // `OBJECT_MAPPER.writeValueAsString(parser.readValueAsTree());` preserves input order. // Be aware. Using it arbitrarily can produce invalid json (ordered by key is required in Presto). 
private static final ObjectMapper OBJECT_MAPPED_UNORDERED = new ObjectMapper(JSON_FACTORY); + private static final ObjectMapper OBJECT_MAPPED_SORTED = new JsonObjectMapperProvider().get().configure(ORDER_MAP_ENTRIES_BY_KEYS, true); private static final int MAX_JSON_LENGTH_IN_ERROR_MESSAGE = 10_000; @@ -956,8 +960,18 @@ static BlockBuilderAppender createBlockBuilderAppender(Type type) return new VarcharBlockBuilderAppender(type); case StandardTypes.JSON: return (parser, blockBuilder, sqlFunctionProperties) -> { - String json = OBJECT_MAPPED_UNORDERED.writeValueAsString(parser.readValueAsTree()); - JSON.writeSlice(blockBuilder, Slices.utf8Slice(json)); + Slice slice = Slices.utf8Slice(OBJECT_MAPPED_UNORDERED.writeValueAsString(parser.readValueAsTree())); + try (JsonParser jsonParser = createJsonParser(JSON_FACTORY, slice)) { + SliceOutput dynamicSliceOutput = new DynamicSliceOutput(slice.length()); + OBJECT_MAPPED_SORTED.writeValue((OutputStream) dynamicSliceOutput, OBJECT_MAPPED_SORTED.readValue(jsonParser, Object.class)); + // nextToken() returns null if the input is parsed correctly, + // but will throw an exception if there are trailing characters. 
+ jsonParser.nextToken(); + JSON.writeSlice(blockBuilder, dynamicSliceOutput.slice()); + } + catch (Exception e) { + throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Cannot convert '%s' to JSON", slice.toStringUtf8())); + } }; case StandardTypes.ARRAY: return new ArrayBlockBuilderAppender(createBlockBuilderAppender(((ArrayType) type).getElementType())); diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/PowerOfTwo.java b/presto-main-base/src/main/java/com/facebook/presto/util/PowerOfTwo.java index 7053ef6079df2..c739b6d6c0e77 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/PowerOfTwo.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/PowerOfTwo.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.util; -import javax.validation.Constraint; -import javax.validation.Payload; +import jakarta.validation.Constraint; +import jakarta.validation.Payload; import java.lang.annotation.Documented; import java.lang.annotation.Retention; diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/PowerOfTwoValidator.java b/presto-main-base/src/main/java/com/facebook/presto/util/PowerOfTwoValidator.java index cdf6ad3f25a7b..cf35d387128f4 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/PowerOfTwoValidator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/PowerOfTwoValidator.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.util; -import javax.validation.ConstraintValidator; -import javax.validation.ConstraintValidatorContext; +import jakarta.validation.ConstraintValidator; +import jakarta.validation.ConstraintValidatorContext; public class PowerOfTwoValidator implements ConstraintValidator diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/PrestoDataDefBindingHelper.java b/presto-main-base/src/main/java/com/facebook/presto/util/PrestoDataDefBindingHelper.java index 3170b0f62d792..5091a7e3b285a 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/util/PrestoDataDefBindingHelper.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/PrestoDataDefBindingHelper.java @@ -28,6 +28,7 @@ import com.facebook.presto.execution.CreateViewTask; import com.facebook.presto.execution.DataDefinitionTask; import com.facebook.presto.execution.DeallocateTask; +import com.facebook.presto.execution.DropBranchTask; import com.facebook.presto.execution.DropColumnTask; import com.facebook.presto.execution.DropConstraintTask; import com.facebook.presto.execution.DropFunctionTask; @@ -35,6 +36,7 @@ import com.facebook.presto.execution.DropRoleTask; import com.facebook.presto.execution.DropSchemaTask; import com.facebook.presto.execution.DropTableTask; +import com.facebook.presto.execution.DropTagTask; import com.facebook.presto.execution.DropViewTask; import com.facebook.presto.execution.GrantRolesTask; import com.facebook.presto.execution.GrantTask; @@ -67,6 +69,7 @@ import com.facebook.presto.sql.tree.CreateType; import com.facebook.presto.sql.tree.CreateView; import com.facebook.presto.sql.tree.Deallocate; +import com.facebook.presto.sql.tree.DropBranch; import com.facebook.presto.sql.tree.DropColumn; import com.facebook.presto.sql.tree.DropConstraint; import com.facebook.presto.sql.tree.DropFunction; @@ -74,6 +77,7 @@ import com.facebook.presto.sql.tree.DropRole; import com.facebook.presto.sql.tree.DropSchema; import com.facebook.presto.sql.tree.DropTable; +import com.facebook.presto.sql.tree.DropTag; import com.facebook.presto.sql.tree.DropView; import com.facebook.presto.sql.tree.Grant; import com.facebook.presto.sql.tree.GrantRoles; @@ -124,6 +128,8 @@ private PrestoDataDefBindingHelper() {} dataDefBuilder.put(CreateTable.class, CreateTableTask.class); dataDefBuilder.put(RenameTable.class, RenameTableTask.class); dataDefBuilder.put(RenameColumn.class, RenameColumnTask.class); + dataDefBuilder.put(DropBranch.class, DropBranchTask.class); + 
dataDefBuilder.put(DropTag.class, DropTagTask.class); dataDefBuilder.put(DropColumn.class, DropColumnTask.class); dataDefBuilder.put(DropConstraint.class, DropConstraintTask.class); dataDefBuilder.put(AddConstraint.class, AddConstraintTask.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/QueryInfoUtils.java b/presto-main-base/src/main/java/com/facebook/presto/util/QueryInfoUtils.java index af834dedbb54e..61483a4c11091 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/QueryInfoUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/QueryInfoUtils.java @@ -13,8 +13,23 @@ */ package com.facebook.presto.util; +import com.facebook.presto.client.StageStats; +import com.facebook.presto.client.StatementStats; +import com.facebook.presto.execution.QueryInfo; +import com.facebook.presto.execution.QueryState; +import com.facebook.presto.execution.QueryStats; +import com.facebook.presto.execution.StageExecutionInfo; +import com.facebook.presto.execution.StageExecutionStats; +import com.facebook.presto.execution.StageInfo; +import com.facebook.presto.execution.TaskInfo; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; import io.airlift.slice.XxHash64; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + import static java.lang.Long.toHexString; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Objects.requireNonNull; @@ -35,4 +50,89 @@ public static String computeQueryHash(String query) long queryHash = new XxHash64().update(queryBytes).hash(); return toHexString(queryHash); } + + public static StatementStats toStatementStats(QueryInfo queryInfo) + { + QueryStats queryStats = queryInfo.getQueryStats(); + StageInfo outputStage = queryInfo.getOutputStage().orElse(null); + + Set globalUniqueNodes = new HashSet<>(); + StageStats rootStageStats = toStageStats(outputStage, globalUniqueNodes); + + return StatementStats.builder() + 
.setState(queryInfo.getState().toString()) + .setWaitingForPrerequisites(queryInfo.getState() == QueryState.WAITING_FOR_PREREQUISITES) + .setQueued(queryInfo.getState() == QueryState.QUEUED) + .setScheduled(queryInfo.isScheduled()) + .setNodes(globalUniqueNodes.size()) + .setTotalSplits(queryStats.getTotalDrivers()) + .setQueuedSplits(queryStats.getQueuedDrivers()) + .setRunningSplits(queryStats.getRunningDrivers() + queryStats.getBlockedDrivers()) + .setCompletedSplits(queryStats.getCompletedDrivers()) + .setCpuTimeMillis(queryStats.getTotalCpuTime().toMillis()) + .setWallTimeMillis(queryStats.getTotalScheduledTime().toMillis()) + .setWaitingForPrerequisitesTimeMillis(queryStats.getWaitingForPrerequisitesTime().toMillis()) + .setQueuedTimeMillis(queryStats.getQueuedTime().toMillis()) + .setElapsedTimeMillis(queryStats.getElapsedTime().toMillis()) + .setProcessedRows(queryStats.getRawInputPositions()) + .setProcessedBytes(queryStats.getRawInputDataSize().toBytes()) + .setPeakMemoryBytes(queryStats.getPeakUserMemoryReservation().toBytes()) + .setPeakTotalMemoryBytes(queryStats.getPeakTotalMemoryReservation().toBytes()) + .setPeakTaskTotalMemoryBytes(queryStats.getPeakTaskTotalMemory().toBytes()) + .setSpilledBytes(queryStats.getSpilledDataSize().toBytes()) + .setRootStage(rootStageStats) + .setRuntimeStats(queryStats.getRuntimeStats()) + .build(); + } + + private static StageStats toStageStats(StageInfo stageInfo, Set globalUniqueNodeIds) + { + if (stageInfo == null) { + return null; + } + + StageExecutionInfo currentStageExecutionInfo = stageInfo.getLatestAttemptExecutionInfo(); + StageExecutionStats stageExecutionStats = currentStageExecutionInfo.getStats(); + + // Store current stage details into a builder + StageStats.Builder builder = StageStats.builder() + .setStageId(String.valueOf(stageInfo.getStageId().getId())) + .setState(currentStageExecutionInfo.getState().toString()) + .setDone(currentStageExecutionInfo.getState().isDone()) + 
.setTotalSplits(stageExecutionStats.getTotalDrivers()) + .setQueuedSplits(stageExecutionStats.getQueuedDrivers()) + .setRunningSplits(stageExecutionStats.getRunningDrivers() + stageExecutionStats.getBlockedDrivers()) + .setCompletedSplits(stageExecutionStats.getCompletedDrivers()) + .setCpuTimeMillis(stageExecutionStats.getTotalCpuTime().toMillis()) + .setWallTimeMillis(stageExecutionStats.getTotalScheduledTime().toMillis()) + .setProcessedRows(stageExecutionStats.getRawInputPositions()) + .setProcessedBytes(stageExecutionStats.getRawInputDataSizeInBytes()) + .setNodes(countStageAndAddGlobalUniqueNodes(currentStageExecutionInfo.getTasks(), globalUniqueNodeIds)); + + // Recurse into child stages to create their StageStats + List subStages = stageInfo.getSubStages(); + if (subStages.isEmpty()) { + builder.setSubStages(ImmutableList.of()); + } + else { + ImmutableList.Builder subStagesBuilder = ImmutableList.builderWithExpectedSize(subStages.size()); + for (StageInfo subStage : subStages) { + subStagesBuilder.add(toStageStats(subStage, globalUniqueNodeIds)); + } + builder.setSubStages(subStagesBuilder.build()); + } + + return builder.build(); + } + + private static int countStageAndAddGlobalUniqueNodes(List tasks, Set globalUniqueNodes) + { + Set stageUniqueNodes = Sets.newHashSetWithExpectedSize(tasks.size()); + for (TaskInfo task : tasks) { + String nodeId = task.getNodeId(); + stageUniqueNodes.add(nodeId); + globalUniqueNodes.add(nodeId); + } + return stageUniqueNodes.size(); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/TaskUtils.java b/presto-main-base/src/main/java/com/facebook/presto/util/TaskUtils.java index cef3997f6e653..7ea843b7d94f3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/TaskUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/TaskUtils.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.util; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; 
import java.util.concurrent.ThreadLocalRandom; diff --git a/presto-main-base/src/main/java/com/facebook/presto/version/EmbedVersion.java b/presto-main-base/src/main/java/com/facebook/presto/version/EmbedVersion.java index 1cac4b65944f9..1da6886720210 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/version/EmbedVersion.java +++ b/presto-main-base/src/main/java/com/facebook/presto/version/EmbedVersion.java @@ -19,8 +19,7 @@ import com.facebook.presto.bytecode.Parameter; import com.facebook.presto.server.ServerConfig; import com.google.common.collect.ImmutableMap; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.lang.invoke.MethodHandle; import java.util.Objects; diff --git a/presto-main-base/src/test/java/com/facebook/presto/TestHiddenColumns.java b/presto-main-base/src/test/java/com/facebook/presto/TestHiddenColumns.java index c084b44dc0625..589310d28d4c6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/TestHiddenColumns.java +++ b/presto-main-base/src/test/java/com/facebook/presto/TestHiddenColumns.java @@ -22,6 +22,7 @@ import org.testng.annotations.Test; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.testing.assertions.Assert.assertEquals; @@ -48,10 +49,10 @@ public void destroy() @Test public void testDescribeTable() { - MaterializedResult expected = MaterializedResult.resultBuilder(TEST_SESSION, VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("regionkey", "bigint", "", "") - .row("name", "varchar(25)", "", "") - .row("comment", "varchar(152)", "", "") + MaterializedResult expected = MaterializedResult.resultBuilder(TEST_SESSION, VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("regionkey", "bigint", "", "", 19L, null, null) + .row("name", "varchar(25)", "", "", null, null, 25L) + .row("comment", "varchar(152)", 
"", "", null, null, 152L) .build(); assertEquals(runner.execute("DESC REGION"), expected); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/block/BenchmarkThriftUdfPageSerDe.java b/presto-main-base/src/test/java/com/facebook/presto/block/BenchmarkThriftUdfPageSerDe.java index 4c121eebca2c6..0a5d228e7542b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/block/BenchmarkThriftUdfPageSerDe.java +++ b/presto-main-base/src/test/java/com/facebook/presto/block/BenchmarkThriftUdfPageSerDe.java @@ -69,6 +69,7 @@ import static com.facebook.presto.thrift.api.udf.ThriftUdfPage.prestoPage; import static com.facebook.presto.thrift.api.udf.ThriftUdfPage.thriftPage; import static com.google.inject.Scopes.SINGLETON; +import static io.netty.buffer.ByteBufAllocator.DEFAULT; import static org.testng.Assert.assertTrue; @State(Scope.Thread) @@ -123,7 +124,7 @@ public void setup() server = injector.getInstance(DriftServer.class); ThriftCodecManager codecManager = new ThriftCodecManager(); Closer closer = Closer.create(); - MethodInvokerFactory methodInvokerFactory = closer.register(createStaticDriftNettyMethodInvokerFactory(new DriftNettyClientConfig())); + MethodInvokerFactory methodInvokerFactory = closer.register(createStaticDriftNettyMethodInvokerFactory(new DriftNettyClientConfig(), DEFAULT)); DriftClientFactory clientFactory = new DriftClientFactory( codecManager, methodInvokerFactory, diff --git a/presto-main-base/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java b/presto-main-base/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java index 8854816cba296..a5a0dabc5a363 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java +++ b/presto-main-base/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java @@ -18,6 +18,8 @@ import com.facebook.presto.common.block.DictionaryBlock; import com.facebook.presto.common.block.DictionaryId; import 
com.facebook.presto.common.block.IntArrayBlock; +import com.facebook.presto.common.block.LazyBlock; +import com.facebook.presto.common.block.RunLengthEncodedBlock; import com.facebook.presto.common.block.VariableWidthBlockBuilder; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; @@ -474,4 +476,78 @@ private static void assertDictionaryIds(DictionaryBlock dictionaryBlock, int... assertEquals(dictionaryBlock.getId(position), expected[position]); } } + + @Test + public void testCreateProjectionWithDictionaryBlock() + { + // Create a dictionary block + Slice[] expectedValues = createExpectedValues(10); + DictionaryBlock dictionaryBlock = createDictionaryBlock(expectedValues, 100); + + // Create a new dictionary block to project + Block newDictionary = createSlicesBlock(createExpectedValues(10)); + Block projectedBlock = dictionaryBlock.createProjection(newDictionary); + + // Assert that the projected block is a DictionaryBlock + assertTrue(projectedBlock instanceof DictionaryBlock); + DictionaryBlock projectedDictionaryBlock = (DictionaryBlock) projectedBlock; + + // Verify that the new dictionary in the projected block matches the input + assertEquals(projectedDictionaryBlock.getDictionary(), newDictionary); + + // Verify that the IDs in the projected block remain the same + int[] originalIds = dictionaryBlock.getRawIds(); + int[] projectedIds = projectedDictionaryBlock.getRawIds(); + assertEquals(originalIds.length, projectedIds.length); + for (int i = 0; i < originalIds.length; i++) { + assertEquals(originalIds[i], projectedIds[i]); + } + + // Verify the position count of the projected block + assertEquals(projectedDictionaryBlock.getPositionCount(), dictionaryBlock.getPositionCount()); + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "newDictionary must have the same position count") + public void testCreateProjectionWithDifferentDictionaries() + { + Slice[] expectedValues = 
createExpectedValues(10); + DictionaryBlock dictionaryBlock = createDictionaryBlock(expectedValues, 100); + + // New dictionary with mismatched position count + Block mismatchedDictionary = createSlicesBlock(createExpectedValues(5)); + dictionaryBlock.createProjection(mismatchedDictionary); + } + + @Test + public void testCreateProjectionWithLazyBlock() + { + Slice[] expectedValues = createExpectedValues(10); + DictionaryBlock dictionaryBlock = createDictionaryBlock(expectedValues, 100); + + LazyBlock lazyBlock = new LazyBlock(10, block -> { + throw new AssertionError("Lazy block should not be loaded"); + }); + Block projectedBlock = dictionaryBlock.createProjection(lazyBlock); + + // Validate that projected block is still a LazyBlock + assertTrue(projectedBlock instanceof LazyBlock); + assertEquals(((LazyBlock) projectedBlock).getPositionCount(), dictionaryBlock.getPositionCount()); + } + + @Test + public void testCreateProjectionWithRunLengthEncodedBlock() + { + Slice[] expectedValues = createExpectedValues(10); + DictionaryBlock dictionaryBlock = createDictionaryBlock(expectedValues, 100); + + Block singleValueBlock = dictionaryBlock.getDictionary().getSingleValueBlock(0); + Block rleBlock = new RunLengthEncodedBlock(singleValueBlock, 10); + Block projectedBlock = dictionaryBlock.createProjection(rleBlock); + + // Validate that projected block is a RunLengthEncodedBlock + assertTrue(projectedBlock instanceof RunLengthEncodedBlock); + assertEquals(((RunLengthEncodedBlock) projectedBlock).getPositionCount(), dictionaryBlock.getPositionCount()); + assertEquals(((RunLengthEncodedBlock) projectedBlock).getValue(), singleValueBlock); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/block/TestRowBlock.java b/presto-main-base/src/test/java/com/facebook/presto/block/TestRowBlock.java index da60f3567a17b..445997369597f 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/block/TestRowBlock.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/block/TestRowBlock.java @@ -17,7 +17,10 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.block.ByteArrayBlock; +import com.facebook.presto.common.block.DictionaryBlock; +import com.facebook.presto.common.block.RowBlock; import com.facebook.presto.common.block.RowBlockBuilder; +import com.facebook.presto.common.block.RunLengthEncodedBlock; import com.facebook.presto.common.block.SingleRowBlock; import com.facebook.presto.common.type.Type; import com.google.common.collect.ImmutableList; @@ -29,6 +32,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.stream.IntStream; import static com.facebook.presto.block.BlockAssertions.createLongDictionaryBlock; import static com.facebook.presto.block.BlockAssertions.createRLEBlock; @@ -323,4 +327,65 @@ private IntArrayList generatePositionList(int numRows, int numPositions) Collections.sort(positions); return positions; } + + @Test + public void testGetRowFieldsFromBlockWithRowBlock() + { + Block[] fieldBlocks = new Block[] { + createRandomDictionaryBlock(createRandomLongsBlock(10, 0), 10, false), + createRandomDictionaryBlock(createRandomLongsBlock(10, 0), 10, false) + }; + Block rowBlock = RowBlock.fromFieldBlocks(10, Optional.empty(), fieldBlocks); + + List rowFields = RowBlock.getRowFieldsFromBlock(rowBlock); + + assertEquals(rowFields.size(), fieldBlocks.length); + assertEquals(rowFields, Arrays.asList(fieldBlocks)); + } + + @Test + public void testGetRowFieldsFromBlockWithRunLengthEncodedBlock() + { + Block[] fieldBlocks = new Block[] { + createRandomDictionaryBlock(createRandomLongsBlock(1, 0), 1, false), + createRandomDictionaryBlock(createRandomLongsBlock(1, 0), 1, false) + }; + Block rowBlock = RowBlock.fromFieldBlocks(1, Optional.empty(), fieldBlocks); + RunLengthEncodedBlock rleBlock = new RunLengthEncodedBlock(rowBlock, 10); + 
+ List rowFields = RowBlock.getRowFieldsFromBlock(rleBlock); + + assertEquals(rowFields.size(), fieldBlocks.length); + for (int i = 0; i < rowFields.size(); i++) { + assertTrue(rowFields.get(i) instanceof RunLengthEncodedBlock); + assertEquals(((RunLengthEncodedBlock) rowFields.get(i)).getValue(), fieldBlocks[i]); + } + } + + @Test + public void testGetRowFieldsFromBlockWithDictionaryBlock() + { + Block[] fieldBlocks = new Block[] { + createRandomDictionaryBlock(createRandomLongsBlock(10, 0), 10, false), + createRandomDictionaryBlock(createRandomLongsBlock(10, 0), 10, false) + }; + Block rowBlock = RowBlock.fromFieldBlocks(10, Optional.empty(), fieldBlocks); + DictionaryBlock dictionaryBlock = new DictionaryBlock(rowBlock, IntStream.range(0, 10).toArray()); + + List rowFields = RowBlock.getRowFieldsFromBlock(dictionaryBlock); + + assertEquals(rowFields.size(), fieldBlocks.length); + for (int i = 0; i < rowFields.size(); i++) { + assertTrue(rowFields.get(i) instanceof DictionaryBlock); + assertEquals(((DictionaryBlock) rowFields.get(i)).getDictionary(), fieldBlocks[i]); + } + } + + @Test(expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "Unexpected block type: LongArrayBlock") + public void testGetRowFieldsFromBlockWithUnexpectedBlockType() + { + Block unexpectedBlock = createRandomLongsBlock(10, 0); + RowBlock.getRowFieldsFromBlock(unexpectedBlock); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/catalogserver/TestCatalogServerResponse.java b/presto-main-base/src/test/java/com/facebook/presto/catalogserver/TestCatalogServerResponse.java index f677e5b210bba..a54d9af284356 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/catalogserver/TestCatalogServerResponse.java +++ b/presto-main-base/src/test/java/com/facebook/presto/catalogserver/TestCatalogServerResponse.java @@ -14,8 +14,10 @@ package com.facebook.presto.catalogserver; import com.facebook.airlift.json.JsonModule; +import 
com.facebook.drift.codec.guice.ThriftCodecModule; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.transaction.TransactionId; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.connector.informationSchema.InformationSchemaTableHandle; import com.facebook.presto.connector.informationSchema.InformationSchemaTransactionHandle; import com.facebook.presto.metadata.HandleJsonModule; @@ -30,6 +32,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.inject.Guice; import com.google.inject.Injector; +import com.google.inject.Scopes; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; @@ -52,7 +55,11 @@ public class TestCatalogServerResponse public void setup() { this.testingCatalogServerClient = new TestingCatalogServerClient(); - Injector injector = Guice.createInjector(new JsonModule(), new HandleJsonModule()); + Injector injector = Guice.createInjector(new JsonModule(), binder -> { + binder.install(new HandleJsonModule()); + binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); + binder.install(new ThriftCodecModule()); + }); this.objectMapper = injector.getInstance(ObjectMapper.class); } @@ -212,6 +219,7 @@ public void testGetMaterializedView() table, baseTables, owner, + Optional.empty(), columnMappings, baseTablesOnOuterJoinSide, validRefreshColumns); diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorColumnHandle.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorColumnHandle.java new file mode 100644 index 0000000000000..063d4e37a839c --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorColumnHandle.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ColumnHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +public class TestTVFConnectorColumnHandle + implements ColumnHandle +{ + private final String name; + private final Type type; + + @JsonCreator + public TestTVFConnectorColumnHandle( + @JsonProperty("name") String name, + @JsonProperty("type") Type type) + { + this.name = requireNonNull(name, "name is null"); + this.type = requireNonNull(type, "type is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public Type getType() + { + return type; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("name", name) + .add("type", type) + .toString(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if ((o == null) || (getClass() != o.getClass())) { + return false; + } + TestTVFConnectorColumnHandle other = (TestTVFConnectorColumnHandle) o; + return Objects.equals(name, other.name) && + Objects.equals(type, other.type); + } + + @Override + public int hashCode() + { + return Objects.hash(name, type); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java 
b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java new file mode 100644 index 0000000000000..f64d51a0c10f3 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorFactory.java @@ -0,0 +1,563 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayout; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.ConnectorTableLayoutResult; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.ConnectorViewDefinition; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.InMemoryRecordSet; +import com.facebook.presto.spi.NodeProvider; +import 
com.facebook.presto.spi.RecordPageSource; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.SplitContext; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.TableFunctionHandleResolver; +import com.facebook.presto.spi.function.TableFunctionSplitResolver; +import com.facebook.presto.spi.function.table.ConnectorTableFunction; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.facebook.presto.spi.statistics.TableStatistics; +import com.facebook.presto.spi.transaction.IsolationLevel; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchRecordSetProvider; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static 
com.facebook.presto.connector.tvf.TestTVFConnectorFactory.TestTVFConnector.TestTVFConnectorSplit.TEST_TVF_CONNECTOR_SPLIT; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +public class TestTVFConnectorFactory + implements ConnectorFactory +{ + private final Function> listSchemaNames; + private final BiFunction> listTables; + private final BiFunction> getViews; + private final BiFunction> getColumnHandles; + private final Supplier getTableStatistics; + private final ApplyTableFunction applyTableFunction; + private final Set tableFunctions; + private final Function tableFunctionProcessorProvider; + private final TestTvfTableFunctionHandleResolver tableFunctionHandleResolver; + private final TestTvfTableFunctionSplitResolver tableFunctionSplitResolver; + private final Function tableFunctionSplitsSources; + + private TestTVFConnectorFactory( + Function> listSchemaNames, + BiFunction> listTables, + BiFunction> getViews, + BiFunction> getColumnHandles, + Supplier getTableStatistics, + ApplyTableFunction applyTableFunction, + Set tableFunctions, + Function getTableFunctionProcessorProvider, + TestTvfTableFunctionHandleResolver tableFunctionHandleResolver, + TestTvfTableFunctionSplitResolver tableFunctionSplitResolver, + Function tableFunctionSplitsSources) + { + this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); + this.listTables = requireNonNull(listTables, "listTables is null"); + this.getViews = requireNonNull(getViews, "getViews is null"); + this.getColumnHandles = requireNonNull(getColumnHandles, "getColumnHandles is null"); + this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); + this.applyTableFunction = requireNonNull(applyTableFunction, "applyTableFunction is null"); 
+ this.tableFunctions = requireNonNull(tableFunctions, "tableFunctions is null"); + this.tableFunctionProcessorProvider = requireNonNull(getTableFunctionProcessorProvider, "tableFunctionProcessorProvider is null"); + this.tableFunctionHandleResolver = requireNonNull(tableFunctionHandleResolver, "tableFunctionHandleResolver is null"); + this.tableFunctionSplitResolver = requireNonNull(tableFunctionSplitResolver, "tableFunctionSplitResolver is null"); + this.tableFunctionSplitsSources = requireNonNull(tableFunctionSplitsSources, "tableFunctionSplitsSources is null"); + } + + @Override + public String getName() + { + return "testTVF"; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new TestTVFHandleResolver(); + } + + @Override + public Optional getTableFunctionHandleResolver() + { + return Optional.of(tableFunctionHandleResolver); + } + + @Override + public Optional getTableFunctionSplitResolver() + { + return Optional.of(tableFunctionSplitResolver); + } + + @Override + public Connector create(String catalogName, Map config, ConnectorContext context) + { + return new TestTVFConnector(context, listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions, tableFunctionProcessorProvider, tableFunctionSplitsSources); + } + + public static Builder builder() + { + return new Builder(); + } + + public static Function> defaultGetColumns() + { + return table -> IntStream.range(0, 100) + .boxed() + .map(i -> ColumnMetadata.builder().setName("column_" + i).setType(createUnboundedVarcharType()).build()) + .collect(toImmutableList()); + } + + @FunctionalInterface + public interface ApplyTableFunction + { + Optional> apply(ConnectorSession session, ConnectorTableFunctionHandle handle); + } + + public static class TestTVFConnector + implements Connector + { + private static final String DELETE_ROW_ID = "delete_row_id"; + private static final String UPDATE_ROW_ID = "update_row_id"; + private 
static final String MERGE_ROW_ID = "merge_row_id"; + + private final ConnectorContext context; + private final Function> listSchemaNames; + private final BiFunction> listTables; + private final BiFunction> getViews; + private final BiFunction> getColumnHandles; + private final Supplier getTableStatistics; + private final ApplyTableFunction applyTableFunction; + private final Function tableFunctionProcessorProvider; + private final Set tableFunctions; + private final Function tableFunctionSplitsSources; + + public TestTVFConnector( + ConnectorContext context, + Function> listSchemaNames, + BiFunction> listTables, + BiFunction> getViews, + BiFunction> getColumnHandles, + Supplier getTableStatistics, + ApplyTableFunction applyTableFunction, + Set tableFunctions, + Function getTableFunctionProcessorProvider, + Function tableFunctionSplitsSources) + { + this.context = requireNonNull(context, "context is null"); + this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); + this.listTables = requireNonNull(listTables, "listTables is null"); + this.getViews = requireNonNull(getViews, "getViews is null"); + this.getColumnHandles = requireNonNull(getColumnHandles, "getColumnHandles is null"); + this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); + this.applyTableFunction = requireNonNull(applyTableFunction, "applyTableFunction is null"); + this.tableFunctions = requireNonNull(tableFunctions, "tableFunctions is null"); + this.tableFunctionProcessorProvider = requireNonNull(getTableFunctionProcessorProvider, "tableFunctionProcessorProvider is null"); + this.tableFunctionSplitsSources = requireNonNull(tableFunctionSplitsSources, "tableFunctionSplitsSources is null"); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return TestTVFConnectorTransactionHandle.INSTANCE; + } + + @Override + public ConnectorMetadata 
getMetadata(ConnectorTransactionHandle transaction) + { + return new TestTVFConnectorMetadata(); + } + + public enum TestTVFConnectorSplit + implements ConnectorSplit + { + TEST_TVF_CONNECTOR_SPLIT; + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return null; + } + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return new ConnectorSplitManager() + { + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) + { + return new FixedSplitSource(Collections.singleton(TEST_TVF_CONNECTOR_SPLIT)); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transaction, ConnectorSession session, ConnectorTableFunctionHandle functionHandle) + { + ConnectorSplitSource splits = tableFunctionSplitsSources.apply(functionHandle); + return requireNonNull(splits, "missing ConnectorSplitSource for table function handle " + + functionHandle.getClass().getSimpleName()); + } + }; + } + + @Override + public ConnectorRecordSetProvider getRecordSetProvider() + { + return new TpchRecordSetProvider(); + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return new TestTVFConnectorPageSourceProvider(); + } + + @Override + public Set getTableFunctions() + { + return tableFunctions; + } + + @Override + public Function getTableFunctionProcessorProvider() + { + return tableFunctionProcessorProvider; + } + + private class TestTVFConnectorMetadata + implements ConnectorMetadata + { + @Override + public List listSchemaNames(ConnectorSession session) + { + return listSchemaNames.apply(session); + } + + @Override + public ConnectorTableHandle 
getTableHandle(ConnectorSession session, SchemaTableName tableName) + { + return new ConnectorTableHandle() {}; + } + + @Override + public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) + { + TestTVFConnectorTableHandle table = (TestTVFConnectorTableHandle) tableHandle; + return new ConnectorTableMetadata( + table.getTableName(), + defaultGetColumns().apply(table.getTableName()), + ImmutableMap.of()); + } + + @Override + public List listTables(ConnectorSession session, String schemaNameOrNull) + { + return listTables.apply(session, schemaNameOrNull); + } + + public void setTableProperties(ConnectorSession session, ConnectorTableHandle tableHandle, Map properties) + { + } + + @Override + public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return (Map) (Map) getColumnHandles.apply(session, tableHandle); + } + + @Override + public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle) + { + if (columnHandle instanceof TestTVFConnectorColumnHandle) { + TestTVFConnectorColumnHandle testTVFColumnHandle = (TestTVFConnectorColumnHandle) columnHandle; + return ColumnMetadata.builder().setName(testTVFColumnHandle.getName()).setType(testTVFColumnHandle.getType()).build(); + } + else { + TpchColumnHandle tpchColumnHandle = (TpchColumnHandle) columnHandle; + return ColumnMetadata.builder().setName(tpchColumnHandle.getColumnName()).setType(tpchColumnHandle.getType()).build(); + } + } + + @Override + public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) + { + return listTables(session, prefix.getSchemaName()).stream() + .collect(toImmutableMap(table -> table, table -> IntStream.range(0, 100) + .boxed() + .map(i -> ColumnMetadata.builder().setName("column_" + i).setType(createUnboundedVarcharType()).build()) + .collect(toImmutableList()))); + } + + @Override + public ConnectorTableLayoutResult 
getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) + { + // TODO: Currently not supporting constraints + TestTVFTableLayoutHandle tvfLayout = new TestTVFTableLayoutHandle((TestTVFConnectorTableHandle) table, TupleDomain.none()); + return new ConnectorTableLayoutResult(new ConnectorTableLayout(tvfLayout, + Optional.empty(), + tvfLayout.getPredicate(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Collections.emptyList(), + Optional.empty()), TupleDomain.none()); + } + + @Override + public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) + { + TestTVFTableLayoutHandle tvfTableLayout = (TestTVFTableLayoutHandle) handle; + return new ConnectorTableLayout(tvfTableLayout); + } + + @Override + public Map getViews(ConnectorSession session, SchemaTablePrefix prefix) + { + return getViews.apply(session, prefix); + } + + @Override + public TableStatistics getTableStatistics(ConnectorSession session, ConnectorTableHandle tableHandle, Optional tableLayoutHandle, List columnHandles, Constraint constraint) + { + return getTableStatistics.get(); + } + + @Override + public Optional> applyTableFunction(ConnectorSession session, ConnectorTableFunctionHandle handle) + { + return applyTableFunction.apply(session, handle); + } + } + + private class TestTVFConnectorPageSourceProvider + implements ConnectorPageSourceProvider + { + @Override + public ConnectorPageSource createPageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorSplit split, ConnectorTableLayoutHandle layout, List columns, SplitContext splitContext, RuntimeStats runtimeStats) + { + TestTVFConnectorTableHandle handle = ((TestTVFTableLayoutHandle) layout).getTable(); + SchemaTableName tableName = handle.getTableName(); + List projection = columns.stream() + .map(TestTVFConnectorColumnHandle.class::cast) + .collect(toImmutableList()); + 
List types = columns.stream() + .map(TestTVFConnectorColumnHandle.class::cast) + .map(TestTVFConnectorColumnHandle::getType) + .collect(toImmutableList()); + return new TestTVFConnectorPageSource(new RecordPageSource(new InMemoryRecordSet(types, ImmutableList.of()))); + } + + private Map getColumnIndexes(SchemaTableName tableName) + { + ImmutableMap.Builder columnIndexes = ImmutableMap.builder(); + List columnMetadata = defaultGetColumns().apply(tableName); + for (int index = 0; index < columnMetadata.size(); index++) { + columnIndexes.put(columnMetadata.get(index).getName(), index); + } + return columnIndexes.buildOrThrow(); + } + } + } + + public static class TestTvfTableFunctionHandleResolver + implements TableFunctionHandleResolver + { + Set> handles = Sets.newHashSet(); + + @Override + public Set> getTableFunctionHandleClasses() + { + return handles; + } + + public void addTableFunctionHandle(Class tableFunctionHandleClass) + { + handles.add(tableFunctionHandleClass); + } + } + + public static class TestTvfTableFunctionSplitResolver + implements TableFunctionSplitResolver + { + Set> handles = Sets.newHashSet(); + + @Override + public Set> getTableFunctionSplitClasses() + { + return handles; + } + + public void addSplitClass(Class splitClass) + { + handles.add(splitClass); + } + } + + public static final class Builder + { + private Function> listSchemaNames = (session) -> ImmutableList.of(); + private BiFunction> listTables = (session, schemaName) -> ImmutableList.of(); + private BiFunction> getViews = (session, schemaTablePrefix) -> ImmutableMap.of(); + private BiFunction> getColumnHandles = (session, tableHandle) -> { + TestTVFConnectorTableHandle table = (TestTVFConnectorTableHandle) tableHandle; + return defaultGetColumns().apply(table.getTableName()).stream() + .collect(toImmutableMap(ColumnMetadata::getName, column -> + new TestTVFConnectorColumnHandle(column.getName(), column.getType()))); + }; + private Supplier getTableStatistics = 
TableStatistics::empty; + private ApplyTableFunction applyTableFunction = (session, handle) -> Optional.empty(); + private Set tableFunctions = ImmutableSet.of(); + private Function tableFunctionProcessorProvider = handle -> null; + private final TestTvfTableFunctionHandleResolver tableFunctionHandleResolver = new TestTvfTableFunctionHandleResolver(); + private TestTvfTableFunctionSplitResolver tableFunctionSplitResolver = new TestTvfTableFunctionSplitResolver(); + private Function tableFunctionSplitsSources = handle -> null; + + public Builder withListSchemaNames(Function> listSchemaNames) + { + this.listSchemaNames = requireNonNull(listSchemaNames, "listSchemaNames is null"); + return this; + } + + public Builder withListTables(BiFunction> listTables) + { + this.listTables = requireNonNull(listTables, "listTables is null"); + return this; + } + + public Builder withGetViews(BiFunction> getViews) + { + this.getViews = requireNonNull(getViews, "getViews is null"); + return this; + } + + public Builder withGetColumnHandles(BiFunction> getColumnHandles) + { + this.getColumnHandles = requireNonNull(getColumnHandles, "getColumnHandles is null"); + return this; + } + + public Builder withGetTableStatistics(Supplier getTableStatistics) + { + this.getTableStatistics = requireNonNull(getTableStatistics, "getTableStatistics is null"); + return this; + } + + public Builder withApplyTableFunction(ApplyTableFunction applyTableFunction) + { + this.applyTableFunction = applyTableFunction; + return this; + } + + public Builder withTableFunctions(Iterable tableFunctions) + { + this.tableFunctions = ImmutableSet.copyOf(tableFunctions); + return this; + } + + public Builder withTableFunctionProcessorProvider(Function tableFunctionProcessorProvider) + { + this.tableFunctionProcessorProvider = tableFunctionProcessorProvider; + return this; + } + + public Builder withTableFunctionResolver(Class tableFunctionHandleclass) + { + 
this.tableFunctionHandleResolver.addTableFunctionHandle(tableFunctionHandleclass); + return this; + } + + public Builder withTableFunctionSplitResolver(Class splitClass) + { + this.tableFunctionSplitResolver.addSplitClass(splitClass); + return this; + } + + public TestTVFConnectorFactory build() + { + return new TestTVFConnectorFactory(listSchemaNames, listTables, getViews, getColumnHandles, getTableStatistics, applyTableFunction, tableFunctions, tableFunctionProcessorProvider, tableFunctionHandleResolver, tableFunctionSplitResolver, tableFunctionSplitsSources); + } + + private static T notSupported() + { + throw new UnsupportedOperationException(); + } + + public Builder withTableFunctionSplitSource(Function sourceProvider) + { + tableFunctionSplitsSources = requireNonNull(sourceProvider, "sourceProvider is null"); + return this; + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorPageSource.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorPageSource.java new file mode 100644 index 0000000000000..126b23a48e796 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorPageSource.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.Page; +import com.facebook.presto.spi.ConnectorPageSource; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; + +import static java.util.Objects.requireNonNull; + +public class TestTVFConnectorPageSource + implements ConnectorPageSource +{ + private final ConnectorPageSource delegate; + + public TestTVFConnectorPageSource(ConnectorPageSource delegate) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + } + + @Override + public long getCompletedBytes() + { + return delegate.getCompletedBytes(); + } + + @Override + public long getCompletedPositions() + { + return 0; + } + + @Override + public long getReadTimeNanos() + { + return delegate.getReadTimeNanos(); + } + + @Override + public boolean isFinished() + { + return delegate.isFinished(); + } + + @Override + public Page getNextPage() + { + return delegate.getNextPage(); + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public void close() + throws IOException + { + delegate.close(); + } + + @Override + public CompletableFuture isBlocked() + { + return delegate.isBlocked(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorPlugin.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorPlugin.java new file mode 100644 index 0000000000000..a26ef9c34466a --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorPlugin.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.google.common.collect.ImmutableList; + +import static java.util.Objects.requireNonNull; + +public class TestTVFConnectorPlugin + implements Plugin +{ + private final ConnectorFactory connectorFactory; + + public TestTVFConnectorPlugin(ConnectorFactory connectorFactory) + { + this.connectorFactory = requireNonNull(connectorFactory, "connectorFactory is null"); + } + + @Override + public Iterable getConnectorFactories() + { + return ImmutableList.of(connectorFactory); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorTableHandle.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorTableHandle.java new file mode 100644 index 0000000000000..9c90975cbc82b --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorTableHandle.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.SchemaTableName; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class TestTVFConnectorTableHandle + implements ConnectorTableHandle +{ + // These are example fields for a Connector's Table Handle. + // For other examples, see TpchTableHandle or any other implementations + // of ConnectorTableHandle. + private final SchemaTableName tableName; + private final Optional> columns; + private final TupleDomain constraint; + + public TestTVFConnectorTableHandle(SchemaTableName tableName) + { + this(tableName, Optional.empty(), TupleDomain.all()); + } + + @JsonCreator + public TestTVFConnectorTableHandle( + @JsonProperty SchemaTableName tableName, + @JsonProperty("columns") Optional> columns, + @JsonProperty("constraint") TupleDomain constraint) + { + this.tableName = requireNonNull(tableName, "tableName is null"); + requireNonNull(columns, "columns is null"); + this.columns = columns.map(ImmutableList::copyOf); + this.constraint = requireNonNull(constraint, "constraint is null"); + } + + @JsonProperty + public SchemaTableName getTableName() + { + return tableName; + } + + @JsonProperty + public Optional> getColumns() + { + return columns; + } + + @JsonProperty + public TupleDomain getConstraint() + { + return constraint; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + 
TestTVFConnectorTableHandle other = (TestTVFConnectorTableHandle) o; + return Objects.equals(tableName, other.tableName) && + Objects.equals(constraint, other.constraint) && + Objects.equals(columns, other.columns); + } + + @Override + public int hashCode() + { + return Objects.hash(tableName, constraint, columns); + } + + @Override + public String toString() + { + return tableName.toString(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorTransactionHandle.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorTransactionHandle.java new file mode 100644 index 0000000000000..668d3fb6575b2 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFConnectorTransactionHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public enum TestTVFConnectorTransactionHandle + implements ConnectorTransactionHandle +{ + INSTANCE +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFHandleResolver.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFHandleResolver.java new file mode 100644 index 0000000000000..1d6771ef0bfb7 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFHandleResolver.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public class TestTVFHandleResolver + implements ConnectorHandleResolver +{ + @Override + public Class getTableHandleClass() + { + return TestTVFConnectorTableHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return TestTVFConnectorColumnHandle.class; + } + + @Override + public Class getSplitClass() + { + return TestTVFConnectorFactory.TestTVFConnector.TestTVFConnectorSplit.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return TestTVFTableLayoutHandle.class; + } + + @Override + public Class getTransactionHandleClass() + { + return TestTVFConnectorTransactionHandle.class; + } + + @Override + public Class getPartitioningHandleClass() + { + return TestTVFPartitioningHandle.class; + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFPartitioningHandle.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFPartitioningHandle.java new file mode 100644 index 0000000000000..cca6fc0a48bb4 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFPartitioningHandle.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class TestTVFPartitioningHandle + implements ConnectorPartitioningHandle +{ + private final String table; + private final long totalRows; + + @JsonCreator + public TestTVFPartitioningHandle(@JsonProperty("table") String table, @JsonProperty("totalRows") long totalRows) + { + this.table = requireNonNull(table, "table is null"); + + checkArgument(totalRows > 0, "totalRows must be at least 1"); + this.totalRows = totalRows; + } + + @JsonProperty + public String getTable() + { + return table; + } + + @JsonProperty + public long getTotalRows() + { + return totalRows; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestTVFPartitioningHandle that = (TestTVFPartitioningHandle) o; + return Objects.equals(table, that.table) && + totalRows == that.totalRows; + } + + @Override + public int hashCode() + { + return Objects.hash(table, totalRows); + } + + @Override + public String toString() + { + return table + ":" + totalRows; + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFTableLayoutHandle.java 
b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFTableLayoutHandle.java new file mode 100644 index 0000000000000..c9f61203cf1de --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestTVFTableLayoutHandle.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.plan.PlanCanonicalizationStrategy; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableMap; + +import java.util.Optional; + +public class TestTVFTableLayoutHandle + implements ConnectorTableLayoutHandle +{ + private final TestTVFConnectorTableHandle table; + private final TupleDomain predicate; + + @JsonCreator + public TestTVFTableLayoutHandle(@JsonProperty("table") TestTVFConnectorTableHandle table, @JsonProperty("predicate") TupleDomain predicate) + { + this.table = table; + this.predicate = predicate; + } + + @JsonProperty + public TestTVFConnectorTableHandle getTable() + { + return table; + } + + @JsonProperty + public TupleDomain getPredicate() + { + return predicate; + } + + @Override + public String toString() + { + return 
table.toString(); + } + + @Override + public Object getIdentifier(Optional split, PlanCanonicalizationStrategy strategy) + { + return ImmutableMap.builder() + .put("table", table) + .put("predicate", predicate.canonicalize(ignored -> false)) + .build(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java new file mode 100644 index 0000000000000..21e739cb21feb --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java @@ -0,0 +1,1478 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.connector.tvf; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.function.SchemaFunctionName; +import com.facebook.presto.spi.function.table.AbstractConnectorTableFunction; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgumentSpecification; +import com.facebook.presto.spi.function.table.ReturnTypeSpecification; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.ScalarArgumentSpecification; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.function.table.TableArgumentSpecification; +import com.facebook.presto.spi.function.table.TableFunctionAnalysis; +import com.facebook.presto.spi.function.table.TableFunctionDataProcessor; +import com.facebook.presto.spi.function.table.TableFunctionProcessorProvider; +import com.facebook.presto.spi.function.table.TableFunctionProcessorState; +import com.facebook.presto.spi.function.table.TableFunctionSplitProcessor; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import 
com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import org.openjdk.jol.info.ClassLayout; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.stream.IntStream; + +import static com.facebook.presto.common.Utils.checkArgument; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.DescribedTable; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.GenericTable.GENERIC_TABLE; +import static com.facebook.presto.spi.function.table.ReturnTypeSpecification.OnlyPassThrough.ONLY_PASS_THROUGH; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Finished.FINISHED; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.produced; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInput; +import static com.facebook.presto.spi.function.table.TableFunctionProcessorState.Processed.usedInputAndProduced; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.NO_PREFERENCE; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.slice.Slices.utf8Slice; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + +public class TestingTableFunctions +{ + private static final String SCHEMA_NAME = "system"; + private static final String 
TABLE_NAME = "table"; + private static final String COLUMN_NAME = "column"; + private static final ConnectorTableFunctionHandle HANDLE = new TestingTableFunctionPushdownHandle(); + private static final TableFunctionAnalysis ANALYSIS = TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .build(); + + private static final TableFunctionAnalysis NO_DESCRIPTOR_ANALYSIS = TableFunctionAnalysis.builder() + .handle(HANDLE) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); + + public static class TestConnectorTableFunction + extends AbstractConnectorTableFunction + { + private static final String FUNCTION_NAME = "test_function"; + public TestConnectorTableFunction() + { + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), ReturnTypeSpecification.GenericTable.GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("c1", Optional.of(BOOLEAN))))) + .build(); + } + } + + public static class TestConnectorTableFunction2 + extends AbstractConnectorTableFunction + { + private static final String FUNCTION_NAME = "test_function2"; + public TestConnectorTableFunction2() + { + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return null; + } + } + + public static class NullArgumentsTableFunction + extends AbstractConnectorTableFunction + { + private static final String FUNCTION_NAME = "null_arguments_function"; + public NullArgumentsTableFunction() + { + super(SCHEMA_NAME, FUNCTION_NAME, null, 
ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return null; + } + } + + public static class DuplicateArgumentsTableFunction + extends AbstractConnectorTableFunction + { + private static final String FUNCTION_NAME = "duplicate_arguments_function"; + public DuplicateArgumentsTableFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + ScalarArgumentSpecification.builder().name("a").type(INTEGER).build(), + ScalarArgumentSpecification.builder().name("a").type(INTEGER).build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return null; + } + } + + public static class MultipleRSTableFunction + extends AbstractConnectorTableFunction + { + private static final String FUNCTION_NAME = "multiple_sources_function"; + public MultipleRSTableFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder().name("t").rowSemantics().build(), + TableArgumentSpecification.builder().name("t2").rowSemantics().build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return null; + } + } + + /** + * A table function returning a table with single empty column of type BOOLEAN. + * The argument `COLUMN` is the column name. + * The argument `IGNORED` is ignored. + * Both arguments are optional. 
+ */ + public static class SimpleTableFunction + extends AbstractConnectorTableFunction + { + private static final String FUNCTION_NAME = "simple_table_function"; + private static final String TABLE_NAME = "simple_table"; + public SimpleTableFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + Arrays.asList( + ScalarArgumentSpecification.builder() + .name("COLUMN") + .type(VARCHAR) + .defaultValue(utf8Slice("col")) + .build(), + ScalarArgumentSpecification.builder() + .name("IGNORED") + .type(BIGINT) + .defaultValue(0L) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument argument = (ScalarArgument) arguments.get("COLUMN"); + String columnName = ((Slice) argument.getValue()).toStringUtf8(); + + return TableFunctionAnalysis.builder() + .handle(new SimpleTableFunctionHandle(getSchema(), TABLE_NAME, columnName)) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(columnName, Optional.of(BOOLEAN))))) + .build(); + } + + public static class SimpleTableFunctionHandle + implements ConnectorTableFunctionHandle + { + private final TestTVFConnectorTableHandle tableHandle; + + public SimpleTableFunctionHandle(String schema, String table, String column) + { + this.tableHandle = new TestTVFConnectorTableHandle( + new SchemaTableName(schema, table), + Optional.of(ImmutableList.of(new TestTVFConnectorColumnHandle(column, BOOLEAN))), + TupleDomain.all()); + } + + public TestTVFConnectorTableHandle getTableHandle() + { + return tableHandle; + } + } + } + + public static class TwoScalarArgumentsFunction + extends AbstractConnectorTableFunction + { + private static final String FUNCTION_NAME = "two_scalar_arguments_function"; + public TwoScalarArgumentsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + ScalarArgumentSpecification.builder() + .name("TEXT") + .type(VARCHAR) + .build(), + 
ScalarArgumentSpecification.builder() + .name("NUMBER") + .type(BIGINT) + .defaultValue(null) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return ANALYSIS; + } + } + + public static class TableArgumentFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "table_argument_function"; + public TableArgumentFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); + } + } + + public static class DescriptorArgumentFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "descriptor_argument_function"; + public DescriptorArgumentFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + DescriptorArgumentSpecification.builder() + .name("SCHEMA") + .defaultValue(null) + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return ANALYSIS; + } + } + + public static class TestingTableFunctionPushdownHandle + implements ConnectorTableFunctionHandle + { + private final TestTVFConnectorTableHandle tableHandle; + + public TestingTableFunctionPushdownHandle() + { + this.tableHandle = new TestTVFConnectorTableHandle( + new SchemaTableName(SCHEMA_NAME, TABLE_NAME), + 
Optional.of(ImmutableList.of(new TestTVFConnectorColumnHandle(COLUMN_NAME, BOOLEAN))), + TupleDomain.all()); + } + + public TestTVFConnectorTableHandle getTableHandle() + { + return tableHandle; + } + } + + @JsonInclude(JsonInclude.Include.ALWAYS) + public static class TestingTableFunctionHandle + implements ConnectorTableFunctionHandle + { + private final TestTVFConnectorTableHandle tableHandle; + private final SchemaFunctionName schemaFunctionName; + + @JsonCreator + public TestingTableFunctionHandle(@JsonProperty("schemaFunctionName") SchemaFunctionName schemaFunctionName) + { + this.tableHandle = new TestTVFConnectorTableHandle( + new SchemaTableName(SCHEMA_NAME, TABLE_NAME), + Optional.of(ImmutableList.of(new TestTVFConnectorColumnHandle(COLUMN_NAME, BOOLEAN))), + TupleDomain.all()); + this.schemaFunctionName = requireNonNull(schemaFunctionName, "schemaFunctionName is null"); + } + + @JsonProperty + public SchemaFunctionName getSchemaFunctionName() + { + return schemaFunctionName; + } + + public TestTVFConnectorTableHandle getTableHandle() + { + return tableHandle; + } + } + + public static class TableArgumentRowSemanticsFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "table_argument_row_semantics_function"; + public TableArgumentRowSemanticsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .rowSemantics() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); + } + } + + public static class TwoTableArgumentsFunction + extends AbstractConnectorTableFunction + { + public static final String 
FUNCTION_NAME = "two_table_arguments_function"; + public TwoTableArgumentsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT1") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT2") + .keepWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT1", ImmutableList.of(0)) + .requiredColumns("INPUT2", ImmutableList.of(0)) + .build(); + } + } + + public static class OnlyPassThroughFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "only_pass_through_function"; + public OnlyPassThroughFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return NO_DESCRIPTOR_ANALYSIS; + } + } + + public static class MonomorphicStaticReturnTypeFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "monomorphic_static_return_type_function"; + public MonomorphicStaticReturnTypeFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(), + new DescribedTable(Descriptor.descriptor( + ImmutableList.of("a", "b"), + ImmutableList.of(BOOLEAN, INTEGER)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .build(); + } + } + + public static 
class PolymorphicStaticReturnTypeFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "polymorphic_static_return_type_function"; + public PolymorphicStaticReturnTypeFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .build()), + new DescribedTable(Descriptor.descriptor( + ImmutableList.of("a", "b"), + ImmutableList.of(BOOLEAN, INTEGER)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return NO_DESCRIPTOR_ANALYSIS; + } + } + + public static class PassThroughFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "pass_through_function"; + public PassThroughFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build()), + new DescribedTable(Descriptor.descriptor( + ImmutableList.of("x"), + ImmutableList.of(BOOLEAN)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return NO_DESCRIPTOR_ANALYSIS; + } + } + + public static class RequiredColumnsFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "required_columns_function"; + public RequiredColumnsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN))))) + 
.requiredColumns("INPUT", ImmutableList.of(0, 1)) + .build(); + } + } + + public static class DifferentArgumentTypesFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "different_arguments_function"; + public DifferentArgumentTypesFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + DescriptorArgumentSpecification.builder() + .name("LAYOUT") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .rowSemantics() + .passThroughColumns() + .build(), + ScalarArgumentSpecification.builder() + .name("ID") + .type(BIGINT) + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .pruneWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .requiredColumns("INPUT_3", ImmutableList.of(0)) + .build(); + } + } + + // for testing execution by operator + + public static class IdentityFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "identity_function"; + public IdentityFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + List inputColumns = ((TableArgument) 
arguments.get("INPUT")).getRowType().getFields(); + Descriptor returnedType = new Descriptor(inputColumns.stream() + .map(field -> new Descriptor.Field(field.getName().orElse("anonymous_column"), Optional.of(field.getType()))) + .collect(toImmutableList())); + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(returnedType) + .requiredColumns("INPUT", IntStream.range(0, inputColumns.size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class IdentityFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return input -> { + if (input == null) { + return FINISHED; + } + Optional inputPage = getOnlyElement(input); + return inputPage.map(TableFunctionProcessorState.Processed::usedInputAndProduced).orElseThrow(NoSuchElementException::new); + }; + } + } + } + + public static class IdentityPassThroughFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "identity_pass_through_function"; + public IdentityPassThroughFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", ImmutableList.of(0)) // per spec, function must require at least one column + .build(); + } + + public static class IdentityPassThroughFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor 
getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new IdentityPassThroughFunctionDataProcessor(); + } + } + + public static class IdentityPassThroughFunctionDataProcessor + implements TableFunctionDataProcessor + { + private long processedPositions; // stateful + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + + Page page = getOnlyElement(input).orElseThrow(NoSuchElementException::new); + BlockBuilder builder = BIGINT.createBlockBuilder(null, page.getPositionCount()); + for (long index = processedPositions; index < processedPositions + page.getPositionCount(); index++) { + // TODO check for long overflow + builder.writeLong(index); + } + processedPositions = processedPositions + page.getPositionCount(); + return usedInputAndProduced(new Page(builder.build())); + } + } + } + + public static class RepeatFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "repeat"; + public RepeatFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .passThroughColumns() + .keepWhenEmpty() + .build(), + ScalarArgumentSpecification.builder() + .name("N") + .type(INTEGER) + .defaultValue(2L) + .build()), + ONLY_PASS_THROUGH); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument count = (ScalarArgument) arguments.get("N"); + requireNonNull(count.getValue(), "count value for function repeat() is null"); + checkArgument((long) count.getValue() > 0, "count value for function repeat() must be positive"); + + return TableFunctionAnalysis.builder() + .handle(new RepeatFunctionHandle((long) count.getValue())) + .requiredColumns("INPUT", ImmutableList.of(0)) // per spec, function must require at least one column + .build(); + } + + public static class RepeatFunctionHandle + 
implements ConnectorTableFunctionHandle + { + private final long count; + + @JsonCreator + public RepeatFunctionHandle(@JsonProperty("count") long count) + { + this.count = count; + } + + @JsonProperty + public long getCount() + { + return count; + } + } + + public static class RepeatFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new RepeatFunctionDataProcessor(((RepeatFunctionHandle) handle).getCount()); + } + } + + public static class RepeatFunctionDataProcessor + implements TableFunctionDataProcessor + { + private final long count; + + // stateful + private long processedPositions; + private long processedRounds; + private Block indexes; + boolean usedData; + + public RepeatFunctionDataProcessor(long count) + { + this.count = count; + } + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + if (processedRounds < count && indexes != null) { + processedRounds++; + return produced(new Page(indexes)); + } + return FINISHED; + } + + Page page = getOnlyElement(input).orElseThrow(NoSuchElementException::new); + if (processedRounds == 0) { + BlockBuilder builder = BIGINT.createBlockBuilder(null, page.getPositionCount()); + for (long index = processedPositions; index < processedPositions + page.getPositionCount(); index++) { + // TODO check for long overflow + builder.writeLong(index); + } + processedPositions = processedPositions + page.getPositionCount(); + indexes = builder.build(); + usedData = true; + } + else { + usedData = false; + } + processedRounds++; + + Page result = new Page(indexes); + + if (processedRounds == count) { + processedRounds = 0; + indexes = null; + } + + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + } + } + + public static class EmptyOutputFunction + extends AbstractConnectorTableFunction + { + public static final 
String FUNCTION_NAME = "empty_output"; + public EmptyOutputFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class EmptyOutputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptyOutputDataProcessor(); + } + } + + // returns an empty Page (one column, zero rows) for each Page of input + private static class EmptyOutputDataProcessor + implements TableFunctionDataProcessor + { + private static final Page EMPTY_PAGE = new Page(BOOLEAN.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(EMPTY_PAGE); + } + } + } + + public static class EmptyOutputWithPassThroughFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "empty_output_with_pass_through"; + public EmptyOutputWithPassThroughFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .passThroughColumns() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + 
@Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class EmptyOutputWithPassThroughProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptyOutputWithPassThroughDataProcessor(); + } + } + + // returns an empty Page (one proper column and pass-through, zero rows) for each Page of input + private static class EmptyOutputWithPassThroughDataProcessor + implements TableFunctionDataProcessor + { + // one proper channel, and one pass-through index channel + private static final Page EMPTY_PAGE = new Page( + BOOLEAN.createBlockBuilder(null, 0).build(), + BIGINT.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(List> input) + { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(EMPTY_PAGE); + } + } + } + + public static class TestInputsFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "test_inputs_function"; + public TestInputsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .rowSemantics() + .name("INPUT_1") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_4") + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new 
Descriptor.Field("boolean_result", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT_1", IntStream.range(0, ((TableArgument) arguments.get("INPUT_1")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_2", IntStream.range(0, ((TableArgument) arguments.get("INPUT_2")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_3", IntStream.range(0, ((TableArgument) arguments.get("INPUT_3")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .requiredColumns("INPUT_4", IntStream.range(0, ((TableArgument) arguments.get("INPUT_4")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestInputsFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + BlockBuilder resultBuilder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(resultBuilder, true); + + Page result = new Page(resultBuilder.build()); + + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(result); + }; + } + } + } + + public static class PassThroughInputFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "pass_through"; + public PassThroughInputFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .passThroughColumns() + .keepWhenEmpty() + .build()), + new 
DescribedTable(new Descriptor(ImmutableList.of( + new Descriptor.Field("input_1_present", Optional.of(BOOLEAN)), + new Descriptor.Field("input_2_present", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .build(); + } + + public static class PassThroughInputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new PassThroughInputDataProcessor(); + } + } + + private static class PassThroughInputDataProcessor + implements TableFunctionDataProcessor + { + private boolean input1Present; + private boolean input2Present; + private int input1EndIndex; + private int input2EndIndex; + private boolean finished; + + @Override + public TableFunctionProcessorState process(List> input) + { + if (finished) { + return FINISHED; + } + if (input == null) { + finished = true; + + // proper column input_1_present + BlockBuilder input1Builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(input1Builder, input1Present); + + // proper column input_2_present + BlockBuilder input2Builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(input2Builder, input2Present); + + // pass-through index for input_1 + BlockBuilder input1PassThroughBuilder = BIGINT.createBlockBuilder(null, 1); + if (input1Present) { + input1PassThroughBuilder.writeLong(input1EndIndex - 1); + } + else { + input1PassThroughBuilder.appendNull(); + } + + // pass-through index for input_2 + BlockBuilder input2PassThroughBuilder = BIGINT.createBlockBuilder(null, 1); + if (input2Present) { + 
input2PassThroughBuilder.writeLong(input2EndIndex - 1); + } + else { + input2PassThroughBuilder.appendNull(); + } + + return produced(new Page(input1Builder.build(), input2Builder.build(), input1PassThroughBuilder.build(), input2PassThroughBuilder.build())); + } + input.get(0).ifPresent(page -> { + input1Present = true; + input1EndIndex += page.getPositionCount(); + }); + input.get(1).ifPresent(page -> { + input2Present = true; + input2EndIndex += page.getPositionCount(); + }); + return usedInput(); + } + } + } + + public static class TestInputFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "test_input"; + public TestInputFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .name("INPUT") + .keepWhenEmpty() + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("got_input", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestInputProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + return new TestInputDataProcessor(); + } + } + + private static class TestInputDataProcessor + implements TableFunctionDataProcessor + { + private boolean processorGotInput; + private boolean finished; + + @Override + public TableFunctionProcessorState process(List> input) + { + if (finished) { + return FINISHED; + } + if (input == null) { + finished = true; + BlockBuilder builder = 
BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(builder, processorGotInput); + return produced(new Page(builder.build())); + } + processorGotInput = true; + return usedInput(); + } + } + } + + public static class TestSingleInputRowSemanticsFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "test_single_input_function"; + public TestSingleInputRowSemanticsFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(TableArgumentSpecification.builder() + .rowSemantics() + .name("INPUT") + .build()), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("boolean_result", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .requiredColumns("INPUT", IntStream.range(0, ((TableArgument) arguments.get("INPUT")).getRowType().getFields().size()).boxed().collect(toImmutableList())) + .build(); + } + + public static class TestSingleInputFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionDataProcessor getDataProcessor(ConnectorTableFunctionHandle handle) + { + BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1); + BOOLEAN.writeBoolean(builder, true); + Page result = new Page(builder.build()); + + return input -> { + if (input == null) { + return FINISHED; + } + return usedInputAndProduced(result); + }; + } + } + } + + public static class ConstantFunction + extends AbstractConnectorTableFunction + { + static final String FUNCTION_NAME = "constant"; + public ConstantFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + ScalarArgumentSpecification.builder() + .name("VALUE") + .type(INTEGER) + .build(), + ScalarArgumentSpecification.builder() + 
.name("N") + .type(INTEGER) + .defaultValue(1L) + .build()), + new DescribedTable(Descriptor.descriptor( + ImmutableList.of("constant_column"), + ImmutableList.of(INTEGER)))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + ScalarArgument count = (ScalarArgument) arguments.get("N"); + requireNonNull(count.getValue(), "count value for function repeat() is null"); + checkArgument((long) count.getValue() > 0, "count value for function repeat() must be positive"); + + return TableFunctionAnalysis.builder() + .handle(new ConstantFunctionHandle((Long) ((ScalarArgument) arguments.get("VALUE")).getValue(), (long) count.getValue())) + .build(); + } + + public static class ConstantFunctionHandle + implements ConnectorTableFunctionHandle + { + private final Long value; + private final long count; + + @JsonCreator + public ConstantFunctionHandle(@JsonProperty("value") Long value, @JsonProperty("count") long count) + { + this.value = value; + this.count = count; + } + + @JsonProperty + public Long getValue() + { + return value; + } + + @JsonProperty + public long getCount() + { + return count; + } + } + + public static class ConstantFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new ConstantFunctionProcessor(((ConstantFunctionHandle) handle).getValue()); + } + } + + public static class ConstantFunctionProcessor + implements TableFunctionSplitProcessor + { + private static final int PAGE_SIZE = 1000; + + private final Long value; + + private long fullPagesCount; + private long processedPages; + private int reminder; + private Block block; + + public ConstantFunctionProcessor(Long value) + { + this.value = value; + } + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + boolean usedData = false; + + if (split != 
null) { + long count = ((ConstantFunctionSplit) split).getCount(); + this.fullPagesCount = count / PAGE_SIZE; + this.reminder = toIntExact(count % PAGE_SIZE); + if (fullPagesCount > 0) { + BlockBuilder builder = INTEGER.createBlockBuilder(null, PAGE_SIZE); + if (value == null) { + for (int i = 0; i < PAGE_SIZE; i++) { + builder.appendNull(); + } + } + else { + for (int i = 0; i < PAGE_SIZE; i++) { + builder.writeInt(toIntExact(value)); + } + } + this.block = builder.build(); + } + else { + BlockBuilder builder = INTEGER.createBlockBuilder(null, reminder); + if (value == null) { + for (int i = 0; i < reminder; i++) { + builder.appendNull(); + } + } + else { + for (int i = 0; i < reminder; i++) { + builder.writeInt(toIntExact(value)); + } + } + this.block = builder.build(); + } + usedData = true; + } + + if (processedPages < fullPagesCount) { + processedPages++; + Page result = new Page(block); + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + + if (reminder > 0) { + Page result = new Page(block.getRegion(0, toIntExact(reminder))); + reminder = 0; + if (usedData) { + return usedInputAndProduced(result); + } + return produced(result); + } + + return FINISHED; + } + } + + public static ConnectorSplitSource getConstantFunctionSplitSource(ConstantFunctionHandle handle) + { + long splitSize = ConstantFunctionSplit.DEFAULT_SPLIT_SIZE; + ImmutableList.Builder splits = ImmutableList.builder(); + for (long i = 0; i < handle.getCount() / splitSize; i++) { + splits.add(new ConstantFunctionSplit(splitSize)); + } + long remainingSize = handle.getCount() % splitSize; + if (remainingSize > 0) { + splits.add(new ConstantFunctionSplit(remainingSize)); + } + return new FixedSplitSource(splits.build()); + } + + public static final class ConstantFunctionSplit + implements ConnectorSplit + { + private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(ConstantFunctionSplit.class).instanceSize()); + public static final int 
DEFAULT_SPLIT_SIZE = 5500; + + private final long count; + + @JsonCreator + public ConstantFunctionSplit(@JsonProperty("count") long count) + { + this.count = count; + } + + @JsonProperty + public long getCount() + { + return count; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return count; + } + } + } + + public static class EmptySourceFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "empty_source"; + public EmptySourceFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of(), + new DescribedTable(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN)))))); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .build(); + } + + public static class EmptySourceFunctionProcessorProvider + implements TableFunctionProcessorProvider + { + @Override + public TableFunctionSplitProcessor getSplitProcessor(ConnectorTableFunctionHandle handle) + { + return new EmptySourceFunctionProcessor(); + } + } + + public static class EmptySourceFunctionProcessor + implements TableFunctionSplitProcessor + { + private static final Page EMPTY_PAGE = new Page(BOOLEAN.createBlockBuilder(null, 0).build()); + + @Override + public TableFunctionProcessorState process(ConnectorSplit split) + { + if (split == null) { + return FINISHED; + } + + return usedInputAndProduced(EMPTY_PAGE); + } + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java 
b/presto-main-base/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java index 5d757b68d1e56..8de52e2c67132 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java +++ b/presto-main-base/src/test/java/com/facebook/presto/cost/TestJoinStatsRule.java @@ -128,8 +128,9 @@ public void testStatsForInnerJoin() @Test public void testStatsForInnerJoinWithRepeatedClause() { - double innerJoinRowCount = LEFT_ROWS_COUNT * RIGHT_ROWS_COUNT / LEFT_JOIN_COLUMN_NDV * LEFT_JOIN_COLUMN_NON_NULLS * RIGHT_JOIN_COLUMN_NON_NULLS // driver join clause - * UNKNOWN_FILTER_COEFFICIENT; // auxiliary join clause + // When duplicate join clauses are passed, JoinNode deduplicates them, + // so we only have a single join clause and no auxiliary filter coefficient is applied. + double innerJoinRowCount = LEFT_ROWS_COUNT * RIGHT_ROWS_COUNT / LEFT_JOIN_COLUMN_NDV * LEFT_JOIN_COLUMN_NON_NULLS * RIGHT_JOIN_COLUMN_NON_NULLS; PlanNodeStatsEstimate innerJoinStats = planNodeStats(innerJoinRowCount, variableStatistics(LEFT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), variableStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV), diff --git a/presto-main-base/src/test/java/com/facebook/presto/event/TestQueryMonitorConfig.java b/presto-main-base/src/test/java/com/facebook/presto/event/TestQueryMonitorConfig.java index 5c1df0ce98c64..3f44858b4bb01 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/event/TestQueryMonitorConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/event/TestQueryMonitorConfig.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.event; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; @@ -24,7 +24,7 @@ import static 
com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.DataSize.Unit; +import static com.facebook.airlift.units.DataSize.Unit; public class TestQueryMonitorConfig { diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/MockManagedQueryExecution.java b/presto-main-base/src/test/java/com/facebook/presto/execution/MockManagedQueryExecution.java index 3e8ca37c5795c..c2213c099df6c 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/MockManagedQueryExecution.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/MockManagedQueryExecution.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.ErrorCode; import com.facebook.presto.execution.StateMachine.StateChangeListener; @@ -24,8 +26,6 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroupQueryLimits; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.net.URI; import java.util.ArrayList; @@ -33,6 +33,7 @@ import java.util.Optional; import java.util.OptionalDouble; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.SystemSessionProperties.QUERY_PRIORITY; import static com.facebook.presto.execution.QueryState.FAILED; import static com.facebook.presto.execution.QueryState.FINISHED; @@ -40,7 +41,6 @@ import static com.facebook.presto.execution.QueryState.RUNNING; import static com.facebook.presto.execution.QueryState.WAITING_FOR_PREREQUISITES; import static 
com.facebook.presto.testing.TestingSession.testSessionBuilder; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -136,6 +136,14 @@ public BasicQueryInfo getBasicQueryInfo() 7, 8, 9, + 6, + 7, + 8, + 9, + 6, + 7, + 8, + 9, new DataSize(14, BYTE), 15, 16.0, diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/MockQueryExecution.java b/presto-main-base/src/test/java/com/facebook/presto/execution/MockQueryExecution.java index f3da03ad28744..890df43e12345 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/MockQueryExecution.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/MockQueryExecution.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.memory.VersionedMemoryPoolId; import com.facebook.presto.server.BasicQueryInfo; @@ -20,12 +21,14 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroupQueryLimits; import com.facebook.presto.sql.planner.Plan; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import static com.facebook.airlift.units.Duration.succinctDuration; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + public class MockQueryExecution implements QueryExecution { @@ -179,6 +182,12 @@ public long getCreateTimeInMillis() return 0L; } + @Override + public Duration getQueuedTime() + { + return succinctDuration(0, MILLISECONDS); + } + @Override public long getExecutionStartTimeInMillis() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java b/presto-main-base/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java index 
cc810353617d3..ecbe7c01c33a1 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/MockRemoteTaskFactory.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.cost.StatsAndCosts; @@ -58,9 +59,7 @@ import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; - -import javax.annotation.concurrent.GuardedBy; +import com.google.errorprone.annotations.concurrent.GuardedBy; import java.net.URI; import java.util.ArrayList; @@ -79,20 +78,19 @@ import java.util.stream.Stream; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.execution.StateMachine.StateChangeListener; import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.BROADCAST; import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; -import static com.facebook.presto.metadata.MetadataUpdates.DEFAULT_METADATA_UPDATES; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static com.facebook.presto.util.Failures.toFailures; import static 
com.google.common.base.Preconditions.checkArgument; import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.Math.addExact; import static java.lang.System.currentTimeMillis; import static java.util.Objects.requireNonNull; @@ -128,6 +126,7 @@ public MockRemoteTask createTableScanTask(TaskId taskId, InternalNode newNode, L SOURCE_DISTRIBUTION, ImmutableList.of(sourceId), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), @@ -302,7 +301,6 @@ public TaskInfo getTaskInfo() ImmutableSet.of(), taskContext.getTaskStats(), true, - DEFAULT_METADATA_UPDATES, nodeId); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TaskTestUtils.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TaskTestUtils.java index e6282d6a05a25..a13f796547479 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TaskTestUtils.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TaskTestUtils.java @@ -31,7 +31,6 @@ import com.facebook.presto.execution.scheduler.nodeSelection.SimpleTtlNodeSelectorConfig; import com.facebook.presto.index.IndexManager; import com.facebook.presto.memory.MemoryManagerConfig; -import com.facebook.presto.metadata.ConnectorMetadataUpdaterManager; import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.metadata.Split; @@ -127,6 +126,7 @@ public static PlanFragment createPlanFragment() ImmutableList.of(TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(VARIABLE)) .withBucketToPartition(Optional.of(new int[1])), + 
Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), @@ -163,7 +163,6 @@ public static LocalExecutionPlanner createTestingPlanner() partitioningProviderManager, nodePartitioningManager, new PageSinkManager(), - new ConnectorMetadataUpdaterManager(), new ExpressionCompiler(metadata, pageFunctionCompiler), pageFunctionCompiler, new JoinFilterFunctionCompiler(metadata), diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TaskWithConnectorType.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TaskWithConnectorType.java deleted file mode 100644 index 0a8aa7d43c7f7..0000000000000 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TaskWithConnectorType.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.execution; - -import com.facebook.drift.annotations.ThriftConstructor; -import com.facebook.drift.annotations.ThriftField; -import com.facebook.drift.annotations.ThriftStruct; -import com.facebook.presto.server.thrift.Any; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; - -@ThriftStruct -public class TaskWithConnectorType -{ - private final int value; - //Connector specific Type - private ConnectorMetadataUpdateHandle connectorMetadataUpdateHandle; - //Connector specific Type serialized as Any - private Any connectorMetadataUpdateHandleAny; - - @JsonCreator - public TaskWithConnectorType( - int value, - ConnectorMetadataUpdateHandle connectorMetadataUpdateHandle) - { - this.value = value; - this.connectorMetadataUpdateHandle = connectorMetadataUpdateHandle; - } - - @ThriftConstructor - public TaskWithConnectorType( - int value, - Any connectorMetadataUpdateHandleAny) - { - this.value = value; - this.connectorMetadataUpdateHandleAny = connectorMetadataUpdateHandleAny; - } - - @ThriftField(1) - @JsonProperty - public int getValue() - { - return value; - } - - @ThriftField(2) - public Any getConnectorMetadataUpdateHandleAny() - { - return connectorMetadataUpdateHandleAny; - } - - @JsonProperty - public ConnectorMetadataUpdateHandle getConnectorMetadataUpdateHandle() - { - return connectorMetadataUpdateHandle; - } -} diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestClusterOverloadConfig.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestClusterOverloadConfig.java new file mode 100644 index 0000000000000..b28aee410fb20 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestClusterOverloadConfig.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution; + +import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.presto.execution.ClusterOverloadConfig.OVERLOAD_POLICY_CNT_BASED; + +public class TestClusterOverloadConfig +{ + @Test + public void testDefaults() + { + ConfigAssertions.assertRecordedDefaults(ConfigAssertions.recordDefaults(ClusterOverloadConfig.class) + .setClusterOverloadThrottlingEnabled(false) + .setAllowedOverloadWorkersPct(0.01) + .setAllowedOverloadWorkersCnt(0) + .setOverloadPolicyType(OVERLOAD_POLICY_CNT_BASED) + .setOverloadCheckCacheTtlInSecs(5)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("cluster-overload.enable-throttling", "true") + .put("cluster-overload.allowed-overload-workers-pct", "0.05") + .put("cluster-overload.allowed-overload-workers-cnt", "5") + .put("cluster-overload.overload-policy-type", "overload_worker_pct_based_throttling") + .put("cluster.overload-check-cache-ttl-secs", "10") + .build(); + + ClusterOverloadConfig expected = new ClusterOverloadConfig() + .setClusterOverloadThrottlingEnabled(true) + .setAllowedOverloadWorkersPct(0.05) + .setAllowedOverloadWorkersCnt(5) + .setOverloadPolicyType("overload_worker_pct_based_throttling") + .setOverloadCheckCacheTtlInSecs(10); + + ConfigAssertions.assertFullMapping(properties, expected); + } +} diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/execution/TestClusterSizeMonitor.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestClusterSizeMonitor.java index 9ef10e188e0d7..89ac8c3caa263 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestClusterSizeMonitor.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestClusterSizeMonitor.java @@ -13,12 +13,12 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.NodeVersion; import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.spi.ConnectorId; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestCreateMaterializedViewTask.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestCreateMaterializedViewTask.java index 95dd8cdfbd3df..144555f3ecb7f 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestCreateMaterializedViewTask.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestCreateMaterializedViewTask.java @@ -20,11 +20,18 @@ import com.facebook.presto.common.block.TestingBlockEncodingSerde; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.cost.HistoryBasedOptimizationConfig; +import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; +import com.facebook.presto.execution.warnings.WarningCollectorConfig; +import com.facebook.presto.memory.MemoryManagerConfig; +import com.facebook.presto.memory.NodeMemoryConfig; import com.facebook.presto.metadata.AbstractMockMetadata; import com.facebook.presto.metadata.Catalog; 
import com.facebook.presto.metadata.CatalogManager; import com.facebook.presto.metadata.ColumnPropertyManager; import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.MaterializedViewPropertyManager; +import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.metadata.TablePropertyManager; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; @@ -40,12 +47,21 @@ import com.facebook.presto.spi.analyzer.MetadataResolver; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.AllowAllAccessControl; +import com.facebook.presto.spiller.NodeSpillConfig; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; +import com.facebook.presto.sql.analyzer.JavaFeaturesConfig; +import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.parser.ParsingOptions; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.CompilerConfig; import com.facebook.presto.sql.tree.CreateMaterializedView; +import com.facebook.presto.testing.TestProcedureRegistry; +import com.facebook.presto.tracing.TracingConfig; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableList; import org.testng.annotations.BeforeMethod; @@ -62,16 +78,20 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; +import static 
com.facebook.presto.spi.session.PropertyMetadata.durationProperty; import static com.facebook.presto.spi.session.PropertyMetadata.stringProperty; import static com.facebook.presto.testing.TestingSession.createBogusTestingCatalog; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager; +import static com.google.common.base.Throwables.getRootCause; +import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; import static org.testng.Assert.fail; @Test(singleThreaded = true) @@ -85,6 +105,7 @@ public class TestCreateMaterializedViewTask private TransactionManager transactionManager; private Session testSession; + private SessionPropertyManager sessionPropertyManager; private AccessControl accessControl; @@ -106,10 +127,20 @@ public void setUp() ColumnPropertyManager columnPropertyManager = new ColumnPropertyManager(); columnPropertyManager.addProperties(testCatalog.getConnectorId(), ImmutableList.of()); + MaterializedViewPropertyManager materializedViewPropertyManager = new MaterializedViewPropertyManager(); + materializedViewPropertyManager.addProperties(testCatalog.getConnectorId(), ImmutableList.of( + stringProperty("storage_schema", "Schema for the materialized view storage table", null, false), + stringProperty("storage_table", "Custom name for the materialized view storage table", null, false), + stringProperty("stale_read_behavior", "Behavior when reading from a stale materialized view", null, false), + durationProperty("staleness_window", "Staleness window for materialized view", null, false), + 
stringProperty("refresh_type", "Refresh type for materialized view", null, false))); + FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + sessionPropertyManager = createSessionPropertyManager(); + transactionManager = createTestTransactionManager(catalogManager); - testSession = testSessionBuilder() + testSession = testSessionBuilder(sessionPropertyManager) .setTransactionId(transactionManager.beginTransaction(false)) .build(); @@ -119,8 +150,10 @@ public void setUp() metadata = new MockMetadata( functionAndTypeManager, + new TestProcedureRegistry(), tablePropertyManager, columnPropertyManager, + materializedViewPropertyManager, testCatalog.getConnectorId()); } @@ -128,23 +161,10 @@ public void setUp() public void testCreateMaterializedViewNotExistsTrue() { SqlParser parser = new SqlParser(); - String sql = String.format("CREATE MATERIALIZED VIEW IF NOT EXISTS %s AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_A, TABLE_A); + String sql = format("CREATE MATERIALIZED VIEW IF NOT EXISTS %s AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_A, TABLE_A); CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, ParsingOptions.builder().build()); - QueryStateMachine stateMachine = QueryStateMachine.begin( - sql, - Optional.empty(), - testSession, - URI.create("fake://uri"), - new ResourceGroupId("test"), - Optional.empty(), - false, - transactionManager, - accessControl, - executorService, - metadata, - WarningCollector.NOOP); - WarningCollector warningCollector = stateMachine.getWarningCollector(); + WarningCollector warningCollector = createWarningCollector(sql, testSession); CreateMaterializedViewTask createMaterializedViewTask = new CreateMaterializedViewTask(parser); getFutureValue(createMaterializedViewTask.execute(statement, transactionManager, metadata, accessControl, testSession, emptyList(), warningCollector, sql)); @@ -155,23 +175,10 @@ public void testCreateMaterializedViewNotExistsTrue() 
public void testCreateMaterializedViewExistsFalse() { SqlParser parser = new SqlParser(); - String sql = String.format("CREATE MATERIALIZED VIEW %s AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_B, TABLE_A); + String sql = format("CREATE MATERIALIZED VIEW %s AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_B, TABLE_A); CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, ParsingOptions.builder().build()); - QueryStateMachine stateMachine = QueryStateMachine.begin( - sql, - Optional.empty(), - testSession, - URI.create("fake://uri"), - new ResourceGroupId("test"), - Optional.empty(), - false, - transactionManager, - accessControl, - executorService, - metadata, - WarningCollector.NOOP); - WarningCollector warningCollector = stateMachine.getWarningCollector(); + WarningCollector warningCollector = createWarningCollector(sql, testSession); try { getFutureValue(new CreateMaterializedViewTask(parser).execute(statement, transactionManager, metadata, accessControl, testSession, emptyList(), warningCollector, sql)); fail("expected exception"); @@ -186,35 +193,277 @@ public void testCreateMaterializedViewExistsFalse() assertEquals(metadata.getCreateMaterializedViewCallCount(), 0); } + @Test + public void testCreateMaterializedViewWithDefinerSecurity() + { + SqlParser parser = new SqlParser(); + String sql = format("CREATE MATERIALIZED VIEW %s SECURITY DEFINER AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_A, TABLE_A); + CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, ParsingOptions.builder().build()); + + Session sessionWithNonLegacyMV = testSessionBuilder(sessionPropertyManager) + .setTransactionId(transactionManager.beginTransaction(false)) + .setSystemProperty("legacy_materialized_views", "false") + .build(); + + WarningCollector warningCollector = createWarningCollector(sql, sessionWithNonLegacyMV); + CreateMaterializedViewTask createMaterializedViewTask = new 
CreateMaterializedViewTask(parser); + getFutureValue(createMaterializedViewTask.execute(statement, transactionManager, metadata, accessControl, sessionWithNonLegacyMV, emptyList(), warningCollector, sql)); + + assertEquals(metadata.getCreateMaterializedViewCallCount(), 1); + MaterializedViewDefinition createdView = metadata.getLastCreatedMaterializedViewDefinition(); + assertTrue(createdView.getOwner().isPresent(), "DEFINER security should have owner set"); + assertEquals(createdView.getOwner().get(), sessionWithNonLegacyMV.getUser()); + } + + @Test + public void testCreateMaterializedViewWithInvokerSecurity() + { + SqlParser parser = new SqlParser(); + String sql = format("CREATE MATERIALIZED VIEW %s SECURITY INVOKER AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_A, TABLE_A); + CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, ParsingOptions.builder().build()); + + Session sessionWithNonLegacyMV = testSessionBuilder(sessionPropertyManager) + .setTransactionId(transactionManager.beginTransaction(false)) + .setSystemProperty("legacy_materialized_views", "false") + .build(); + + WarningCollector warningCollector = createWarningCollector(sql, sessionWithNonLegacyMV); + CreateMaterializedViewTask createMaterializedViewTask = new CreateMaterializedViewTask(parser); + getFutureValue(createMaterializedViewTask.execute(statement, transactionManager, metadata, accessControl, sessionWithNonLegacyMV, emptyList(), warningCollector, sql)); + + assertEquals(metadata.getCreateMaterializedViewCallCount(), 1); + MaterializedViewDefinition createdView = metadata.getLastCreatedMaterializedViewDefinition(); + assertTrue(createdView.getOwner().isPresent(), "INVOKER security should have owner set"); + } + + @Test + public void testCreateMaterializedViewWithDefaultDefinerSecurity() + { + SqlParser parser = new SqlParser(); + String sql = format("CREATE MATERIALIZED VIEW %s AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_A, TABLE_A); + 
CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, ParsingOptions.builder().build()); + + Session sessionWithNonLegacyMV = testSessionBuilder(sessionPropertyManager) + .setTransactionId(transactionManager.beginTransaction(false)) + .setSystemProperty("legacy_materialized_views", "false") + .setSystemProperty("default_view_security_mode", "DEFINER") + .build(); + + WarningCollector warningCollector = createWarningCollector(sql, sessionWithNonLegacyMV); + CreateMaterializedViewTask createMaterializedViewTask = new CreateMaterializedViewTask(parser); + getFutureValue(createMaterializedViewTask.execute(statement, transactionManager, metadata, accessControl, sessionWithNonLegacyMV, emptyList(), warningCollector, sql)); + + assertEquals(metadata.getCreateMaterializedViewCallCount(), 1); + MaterializedViewDefinition createdView = metadata.getLastCreatedMaterializedViewDefinition(); + assertTrue(createdView.getOwner().isPresent(), "Default DEFINER security should have owner set"); + assertEquals(createdView.getOwner().get(), sessionWithNonLegacyMV.getUser()); + } + + @Test + public void testCreateMaterializedViewWithDefaultInvokerSecurity() + { + SqlParser parser = new SqlParser(); + String sql = format("CREATE MATERIALIZED VIEW %s AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_A, TABLE_A); + CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, ParsingOptions.builder().build()); + + Session sessionWithNonLegacyMV = testSessionBuilder(sessionPropertyManager) + .setTransactionId(transactionManager.beginTransaction(false)) + .setSystemProperty("legacy_materialized_views", "false") + .setSystemProperty("default_view_security_mode", "INVOKER") + .build(); + + WarningCollector warningCollector = createWarningCollector(sql, sessionWithNonLegacyMV); + CreateMaterializedViewTask createMaterializedViewTask = new CreateMaterializedViewTask(parser); + 
getFutureValue(createMaterializedViewTask.execute(statement, transactionManager, metadata, accessControl, sessionWithNonLegacyMV, emptyList(), warningCollector, sql)); + + assertEquals(metadata.getCreateMaterializedViewCallCount(), 1); + MaterializedViewDefinition createdView = metadata.getLastCreatedMaterializedViewDefinition(); + assertTrue(createdView.getOwner().isPresent(), "Default INVOKER security should have owner set"); + } + + @Test + public void testCreateMaterializedViewWithSecurityInLegacyMode() + { + SqlParser parser = new SqlParser(); + String sql = format("CREATE MATERIALIZED VIEW %s SECURITY INVOKER AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_A, TABLE_A); + CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, ParsingOptions.builder().build()); + + Session sessionWithLegacyMV = testSessionBuilder(sessionPropertyManager) + .setTransactionId(transactionManager.beginTransaction(false)) + .setSystemProperty("legacy_materialized_views", "true") + .build(); + + WarningCollector warningCollector = createWarningCollector(sql, sessionWithLegacyMV); + CreateMaterializedViewTask createMaterializedViewTask = new CreateMaterializedViewTask(parser); + + SemanticException exception = expectThrows(SemanticException.class, () -> + getFutureValue(createMaterializedViewTask.execute(statement, transactionManager, metadata, accessControl, sessionWithLegacyMV, emptyList(), warningCollector, sql))); + assertTrue(exception.getMessage().contains("SECURITY clause is not supported when legacy_materialized_views is enabled")); + + assertEquals(metadata.getCreateMaterializedViewCallCount(), 0); + } + + @Test + public void testCreateMaterializedViewInLegacyModeAlwaysHasOwner() + { + SqlParser parser = new SqlParser(); + String sql = format("CREATE MATERIALIZED VIEW %s AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_A, TABLE_A); + CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, 
ParsingOptions.builder().build()); + + Session sessionWithLegacyMV = testSessionBuilder(sessionPropertyManager) + .setTransactionId(transactionManager.beginTransaction(false)) + .setSystemProperty("legacy_materialized_views", "true") + .build(); + + WarningCollector warningCollector = createWarningCollector(sql, sessionWithLegacyMV); + CreateMaterializedViewTask createMaterializedViewTask = new CreateMaterializedViewTask(parser); + getFutureValue(createMaterializedViewTask.execute(statement, transactionManager, metadata, accessControl, sessionWithLegacyMV, emptyList(), warningCollector, sql)); + + assertEquals(metadata.getCreateMaterializedViewCallCount(), 1); + MaterializedViewDefinition createdView = metadata.getLastCreatedMaterializedViewDefinition(); + assertTrue(createdView.getOwner().isPresent(), "Legacy mode should always have owner set"); + assertEquals(createdView.getOwner().get(), sessionWithLegacyMV.getUser()); + } + + @Test + public void testCreateMaterializedViewWithMaterializedViewProperties() + { + SqlParser parser = new SqlParser(); + String sql = format( + "CREATE MATERIALIZED VIEW %s " + + "WITH (stale_read_behavior = 'FAIL', " + + "staleness_window = '1h', " + + "refresh_type = 'FULL') " + + "AS SELECT 2021 AS col_0 FROM %s", + MATERIALIZED_VIEW_A, TABLE_A); + CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, ParsingOptions.builder().build()); + + Session sessionWithNonLegacyMV = testSessionBuilder(sessionPropertyManager) + .setTransactionId(transactionManager.beginTransaction(false)) + .setSystemProperty("legacy_materialized_views", "false") + .build(); + + WarningCollector warningCollector = createWarningCollector(sql, sessionWithNonLegacyMV); + CreateMaterializedViewTask createMaterializedViewTask = new CreateMaterializedViewTask(parser); + getFutureValue(createMaterializedViewTask.execute(statement, transactionManager, metadata, accessControl, sessionWithNonLegacyMV, emptyList(), warningCollector, sql)); 
+ + assertEquals(metadata.getCreateMaterializedViewCallCount(), 1); + + ConnectorTableMetadata viewMetadata = metadata.getLastCreatedViewMetadata(); + Map properties = viewMetadata.getProperties(); + + assertEquals(properties.get("stale_read_behavior"), "FAIL"); + assertEquals(properties.get("staleness_window").toString(), "1.00h"); + assertEquals(properties.get("refresh_type"), "FULL"); + } + + @Test + public void testCreateMaterializedViewWithInvalidDefaultViewSecurityMode() + { + SqlParser parser = new SqlParser(); + String sql = format("CREATE MATERIALIZED VIEW %s AS SELECT 2021 AS col_0 FROM %s", MATERIALIZED_VIEW_A, TABLE_A); + CreateMaterializedView statement = (CreateMaterializedView) parser.createStatement(sql, ParsingOptions.builder().build()); + + Session sessionWithInvalidSecurityMode = testSessionBuilder(sessionPropertyManager) + .setTransactionId(transactionManager.beginTransaction(false)) + .setSystemProperty("legacy_materialized_views", "false") + .setSystemProperty("default_view_security_mode", "INVALID") + .build(); + + WarningCollector warningCollector = createWarningCollector(sql, sessionWithInvalidSecurityMode); + CreateMaterializedViewTask createMaterializedViewTask = new CreateMaterializedViewTask(parser); + + Exception exception = expectThrows(Exception.class, () -> + getFutureValue(createMaterializedViewTask.execute(statement, transactionManager, metadata, accessControl, sessionWithInvalidSecurityMode, emptyList(), warningCollector, sql))); + + Throwable rootCause = getRootCause(exception); + assertTrue(rootCause instanceof IllegalArgumentException, + "Expected IllegalArgumentException but got: " + rootCause.getClass().getName()); + assertTrue(rootCause.getMessage().contains("INVALID") || rootCause.getMessage().contains("ViewSecurity"), + "Exception message should mention INVALID or ViewSecurity but was: " + rootCause.getMessage()); + assertEquals(metadata.getCreateMaterializedViewCallCount(), 0); + } + + private static 
SessionPropertyManager createSessionPropertyManager() + { + FeaturesConfig featuresConfig = new FeaturesConfig() + .setAllowLegacyMaterializedViewsToggle(true); + + return SessionPropertyManager.createTestingSessionPropertyManager( + new com.facebook.presto.SystemSessionProperties( + new QueryManagerConfig(), + new TaskManagerConfig(), + new MemoryManagerConfig(), + featuresConfig, + new FunctionsConfig(), + new NodeMemoryConfig(), + new WarningCollectorConfig(), + new NodeSchedulerConfig(), + new NodeSpillConfig(), + new TracingConfig(), + new CompilerConfig(), + new HistoryBasedOptimizationConfig()).getSessionProperties(), + featuresConfig, + new JavaFeaturesConfig(), + new NodeSpillConfig()); + } + + private WarningCollector createWarningCollector(String sql, Session session) + { + QueryStateMachine stateMachine = QueryStateMachine.begin( + sql, + Optional.empty(), + session, + URI.create("fake://uri"), + new ResourceGroupId("test"), + Optional.empty(), + false, + transactionManager, + accessControl, + executorService, + metadata, + WarningCollector.NOOP); + return stateMachine.getWarningCollector(); + } + private static class MockMetadata extends AbstractMockMetadata { private final FunctionAndTypeManager functionAndTypeManager; + private final ProcedureRegistry procedureRegistry; private final TablePropertyManager tablePropertyManager; private final ColumnPropertyManager columnPropertyManager; + private final MaterializedViewPropertyManager materializedViewPropertyManager; private final ConnectorId catalogHandle; private final List materializedViews = new CopyOnWriteArrayList<>(); + private MaterializedViewDefinition lastCreatedMaterializedViewDefinition; public MockMetadata( FunctionAndTypeManager functionAndTypeManager, + ProcedureRegistry procedureRegistry, TablePropertyManager tablePropertyManager, ColumnPropertyManager columnPropertyManager, + MaterializedViewPropertyManager materializedViewPropertyManager, ConnectorId catalogHandle) { 
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); this.tablePropertyManager = requireNonNull(tablePropertyManager, "tablePropertyManager is null"); this.columnPropertyManager = requireNonNull(columnPropertyManager, "columnPropertyManager is null"); + this.materializedViewPropertyManager = requireNonNull(materializedViewPropertyManager, "materializedViewPropertyManager is null"); this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); } @Override public void createMaterializedView(Session session, String catalogName, ConnectorTableMetadata viewMetadata, MaterializedViewDefinition viewDefinition, boolean ignoreExisting) { - if (!ignoreExisting) { + // Check if materialized view already exists (MATERIALIZED_VIEW_B always exists for testing) + if (viewMetadata.getTable().getTableName().equals(MATERIALIZED_VIEW_B) && !ignoreExisting) { throw new PrestoException(ALREADY_EXISTS, "Materialized view already exists"); } this.materializedViews.add(viewMetadata); + this.lastCreatedMaterializedViewDefinition = viewDefinition; } public int getCreateMaterializedViewCallCount() @@ -222,6 +471,16 @@ public int getCreateMaterializedViewCallCount() return materializedViews.size(); } + public MaterializedViewDefinition getLastCreatedMaterializedViewDefinition() + { + return lastCreatedMaterializedViewDefinition; + } + + public ConnectorTableMetadata getLastCreatedViewMetadata() + { + return materializedViews.isEmpty() ? 
null : materializedViews.get(materializedViews.size() - 1); + } + @Override public TablePropertyManager getTablePropertyManager() { @@ -234,12 +493,24 @@ public ColumnPropertyManager getColumnPropertyManager() return columnPropertyManager; } + @Override + public MaterializedViewPropertyManager getMaterializedViewPropertyManager() + { + return materializedViewPropertyManager; + } + @Override public FunctionAndTypeManager getFunctionAndTypeManager() { return functionAndTypeManager; } + @Override + public ProcedureRegistry getProcedureRegistry() + { + return procedureRegistry; + } + @Override public Type getType(TypeSignature signature) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestExecuteProcedureHandle.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestExecuteProcedureHandle.java new file mode 100644 index 0000000000000..df517a5e16027 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestExecuteProcedureHandle.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.execution; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonModule; +import com.facebook.drift.codec.guice.ThriftCodecModule; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.ExecuteProcedureHandle; +import com.facebook.presto.metadata.DistributedProcedureHandle; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.server.SliceDeserializer; +import com.facebook.presto.server.SliceSerializer; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.Serialization; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.testing.TestingHandle; +import com.facebook.presto.testing.TestingHandleResolver; +import com.facebook.presto.testing.TestingTransactionHandle; +import com.facebook.presto.type.TypeDeserializer; +import com.google.common.collect.ImmutableList; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Module; +import com.google.inject.Scopes; +import io.airlift.slice.Slice; +import org.testng.annotations.Test; + +import java.util.UUID; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static 
com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static org.testng.Assert.assertEquals; + +public class TestExecuteProcedureHandle +{ + @Test + public void testExecuteProcedureHandleRoundTrip() + { + String catalogName = "test_catalog"; + JsonCodec codec = createJsonCodec(catalogName); + UUID uuid = UUID.randomUUID(); + ExecuteProcedureHandle expected = createExecuteProcedureHandle(catalogName, uuid); + ExecuteProcedureHandle actual = codec.fromJson(codec.toJson(expected)); + + assertEquals(actual.getProcedureName(), expected.getProcedureName()); + assertEquals(actual.getSchemaTableName(), expected.getSchemaTableName()); + assertEquals(actual.getHandle().getClass(), expected.getHandle().getClass()); + assertEquals(actual.getHandle().getConnectorId(), expected.getHandle().getConnectorId()); + assertEquals(actual.getHandle().getTransactionHandle(), expected.getHandle().getTransactionHandle()); + assertEquals(actual.getHandle().getConnectorHandle(), expected.getHandle().getConnectorHandle()); + } + + private static JsonCodec createJsonCodec(String catalogName) + { + Module module = binder -> { + SqlParser sqlParser = new SqlParser(); + FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + binder.install(new JsonModule()); + binder.install(new HandleJsonModule()); + binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); + binder.install(new ThriftCodecModule()); + binder.bind(SqlParser.class).toInstance(sqlParser); + binder.bind(TypeManager.class).toInstance(functionAndTypeManager); + configBinder(binder).bindConfig(FeaturesConfig.class); + newSetBinder(binder, Type.class); + jsonBinder(binder).addSerializerBinding(Slice.class).to(SliceSerializer.class); + 
jsonBinder(binder).addDeserializerBinding(Slice.class).to(SliceDeserializer.class); + jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + jsonBinder(binder).addSerializerBinding(Expression.class).to(Serialization.ExpressionSerializer.class); + jsonBinder(binder).addDeserializerBinding(Expression.class).to(Serialization.ExpressionDeserializer.class); + jsonBinder(binder).addDeserializerBinding(FunctionCall.class).to(Serialization.FunctionCallDeserializer.class); + jsonBinder(binder).addKeySerializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionSerializer.class); + jsonBinder(binder).addKeyDeserializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionDeserializer.class); + jsonCodecBinder(binder).bindJsonCodec(ExecuteProcedureHandle.class); + }; + Bootstrap app = new Bootstrap(ImmutableList.of(module)); + Injector injector = app + .doNotInitializeLogging() + .quiet() + .initialize(); + injector.getInstance(HandleResolver.class) + .addConnectorName(catalogName, new TestingHandleResolver()); + return injector.getInstance(new Key>() {}); + } + + private static ExecuteProcedureHandle createExecuteProcedureHandle(String catalogName, UUID uuid) + { + DistributedProcedureHandle distributedProcedureHandle = new DistributedProcedureHandle( + new ConnectorId(catalogName), + new TestingTransactionHandle(uuid), + TestingHandle.INSTANCE); + return new ExecuteProcedureHandle(distributedProcedureHandle, + new SchemaTableName("schema1", "table1"), + QualifiedObjectName.valueOf(catalogName, "schema1", "table1")); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestInput.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestInput.java index adaa1f9bccf95..3442b6f0bcfb9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestInput.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/execution/TestInput.java @@ -34,7 +34,7 @@ public void testRoundTrip() new Column("column2", "string"), new Column("column3", "string")), Optional.empty(), - ""); + Optional.empty()); String json = codec.toJson(expected); Input actual = codec.fromJson(json); diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestNodeScheduler.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestNodeScheduler.java index 93ef33c6e797b..281997a3246b9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestNodeScheduler.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestNodeScheduler.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.client.NodeVersion; import com.facebook.presto.dispatcher.NoOpQueryManager; @@ -52,6 +54,7 @@ import com.facebook.presto.ttl.nodettlfetchermanagers.ThrowingNodeTtlFetcherManager; import com.facebook.presto.util.FinalizerService; import com.google.common.base.Splitter; +import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMultimap; @@ -59,8 +62,6 @@ import com.google.common.collect.Iterables; import com.google.common.collect.Multimap; import com.google.common.collect.Sets; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -190,6 +191,14 @@ private BasicQueryStats getBasicQueryStats(Duration executionTime) 0, 0, 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, DataSize.valueOf("1MB"), 0, 0, @@ -1128,7 +1137,7 @@ public void testMemoryUsage() } @Test - public void 
testMaxTasksPerStageWittLimit() + public void testMaxTasksPerStageWithLimit() { NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); TestingTransactionHandle transactionHandle = TestingTransactionHandle.create(); @@ -1233,6 +1242,119 @@ private static Session sessionWithTtlAwareSchedulingStrategyAndEstimatedExecutio .build(); } + private static Session sessionWithScheduleSplitsBasedOnTaskLoad(boolean scheduleSplitsBasedOnTaskLoad) + { + return TestingSession.testSessionBuilder() + .setSystemProperty("schedule_splits_based_on_task_load", String.valueOf(scheduleSplitsBasedOnTaskLoad)) + .build(); + } + + @Test + public void testScheduleSplitsBasedOnTaskLoad() + { + List existingTasks = new ArrayList<>(); + try { + // Test with scheduleSplitsBasedOnTaskLoad enabled + Session taskLoadSession = sessionWithScheduleSplitsBasedOnTaskLoad(true); + NodeSelector taskLoadNodeSelector = nodeScheduler.createNodeSelector(taskLoadSession, CONNECTOR_ID); + + TestingTransactionHandle transactionHandle = TestingTransactionHandle.create(); + MockRemoteTaskFactory remoteTaskFactory = new MockRemoteTaskFactory(remoteTaskExecutor, remoteTaskScheduledExecutor); + + // Create existing tasks with different split weights to test task load selection. + // We will have two queries 'test1' and 'test2'. + // 'test1' would have more load on node 1 and less on node 2 and 3. + // 'test2' would have more load on nodes 2 and 3 and very little on node 1. + // Thus, we will have more total load on nodes 2 and 3, but less for 'test1' on them. + Set nodes = nodeManager.getActiveConnectorNodes(CONNECTOR_ID); + Map nodeToTaskMap = new HashMap<>(); + for (InternalNode node : nodes) { + int nodeIndex = Integer.parseInt(node.getNodeIdentifier().substring("other".length())); + + // Create tasks for query 'test1' with different loads: task 1 (for node 1) has more load. + int initialSplitsCount = (nodeIndex == 1) ? 5 : (nodeIndex == 2) ? 
3 : 2; // First task more loaded + List initialSplits = new ArrayList<>(); + for (int j = 0; j < initialSplitsCount; j++) { + initialSplits.add(new Split(CONNECTOR_ID, transactionHandle, new TestSplitRemote())); + } + + TaskId taskId = new TaskId("test1", 1, 0, nodeIndex, 0); + MockRemoteTaskFactory.MockRemoteTask remoteTask = remoteTaskFactory.createTableScanTask( + taskId, node, initialSplits, nodeTaskMap.createTaskStatsTracker(node, taskId)); + remoteTask.startSplits(1); + + nodeTaskMap.addTask(node, remoteTask); + nodeToTaskMap.put(node, remoteTask); + existingTasks.add(remoteTask); + + // Create tasks for query 'test2' with different loads: tasks 2 and 3 (for nodes 2 and 3) have more load. + initialSplitsCount = (nodeIndex == 1) ? 1 : 7; // First task less loaded + initialSplits = new ArrayList<>(); + for (int j = 0; j < initialSplitsCount; j++) { + initialSplits.add(new Split(CONNECTOR_ID, transactionHandle, new TestSplitRemote())); + } + + taskId = new TaskId("test2", 1, 0, nodeIndex, 0); + remoteTask = remoteTaskFactory.createTableScanTask( + taskId, node, initialSplits, nodeTaskMap.createTaskStatsTracker(node, taskId)); + remoteTask.startSplits(1); + + nodeTaskMap.addTask(node, remoteTask); + } + + // Split situation is now the following (initial + second query + assigned): + // other1: 5 + 1 = 6 + // other2: 3 + 7 = 10 + // other3: 2 + 7 = 9 + // The task-based assignment would pick nodes where the tasks of the 1st query have fewer splits, + // namely nodes 2 and 3, even though they have more splits in total. 
+ + // Create new splits to assign + Set newSplits = new HashSet<>(); + int numNewSplits = 4; + for (int i = 0; i < numNewSplits; i++) { + newSplits.add(new Split(CONNECTOR_ID, transactionHandle, new TestSplitRemote())); + } + + // Verify that splits were assigned only to nodes 2 and 3 + SplitPlacementResult result = taskLoadNodeSelector.computeAssignments(newSplits, existingTasks); + Multimap assignments = result.getAssignments(); + assertEquals(assignments.size(), numNewSplits); + for (InternalNode node : assignments.keySet()) { + assertTrue(node.getNodeIdentifier().equals("other2") || node.getNodeIdentifier().equals("other3")); + } + + PlanNodeId planNodeId = new PlanNodeId("sourceId"); + for (InternalNode node : assignments.keySet()) { + Multimap splits = ArrayListMultimap.create(); + for (Split split : assignments.get(node)) { + splits.put(planNodeId, split); + } + nodeToTaskMap.get(node).addSplits(splits); + } + // Split situation is now the following (initial + second query + assigned) = task/node: + // other1: 5 + 1 + 0 = 5/6 + // other2: 3 + 7 + 2 = 5/12 + // other3: 2 + 7 + 2 = 4/12 + // The task-based assignment would pick nodes where the tasks of the 1st query have fewer splits, + // this time all nodes would be included as the low loaded ones catch up with the high loaded. + + // Verify that splits were assigned to all nodes. 
+ result = taskLoadNodeSelector.computeAssignments(newSplits, existingTasks); + assignments = result.getAssignments(); + assertEquals(assignments.size(), numNewSplits); + for (InternalNode node : assignments.keySet()) { + assertTrue(node.getNodeIdentifier().equals("other1") || node.getNodeIdentifier().equals("other2") || node.getNodeIdentifier().equals("other3")); + } + } + finally { + // Cleanup + for (RemoteTask task : existingTasks) { + task.abort(); + } + } + } + private static PartitionedSplitsInfo standardWeightSplitsInfo(int splitCount) { return PartitionedSplitsInfo.forSplitCountAndWeightSum(splitCount, SplitWeight.rawValueForStandardSplitCount(splitCount)); diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestNodeSchedulerConfig.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestNodeSchedulerConfig.java index dc41f7026bdc9..45e3b64b52e57 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestNodeSchedulerConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestNodeSchedulerConfig.java @@ -35,6 +35,8 @@ public void testDefaults() .setNetworkTopology(LEGACY) .setMinCandidates(10) .setMaxSplitsPerNode(100) + .setMaxSplitsPerTask(10) + .setScheduleSplitsBasedOnTaskLoad(false) .setMaxPendingSplitsPerTask(10) .setMaxUnacknowledgedSplitsPerTask(500) .setIncludeCoordinator(true) @@ -54,6 +56,8 @@ public void testExplicitPropertyMappings() .put("node-scheduler.max-pending-splits-per-task", "11") .put("node-scheduler.max-unacknowledged-splits-per-task", "501") .put("node-scheduler.max-splits-per-node", "101") + .put("node-scheduler.max-splits-per-task", "17") + .put("node-scheduler.schedule-splits-based-on-task-load", "true") .put("node-scheduler.node-selection-hash-strategy", "CONSISTENT_HASHING") .put("node-scheduler.consistent-hashing-min-virtual-node-count", "2000") .put("experimental.resource-aware-scheduling-strategy", "TTL") @@ -64,6 +68,8 @@ public void 
testExplicitPropertyMappings() .setNetworkTopology("flat") .setIncludeCoordinator(false) .setMaxSplitsPerNode(101) + .setMaxSplitsPerTask(17) + .setScheduleSplitsBasedOnTaskLoad(true) .setMaxPendingSplitsPerTask(11) .setMaxUnacknowledgedSplitsPerTask(501) .setMinCandidates(11) diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestOutput.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestOutput.java index a427cf4f4878b..454b29bfcfde0 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestOutput.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestOutput.java @@ -14,13 +14,16 @@ package com.facebook.presto.execution; import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.SourceColumn; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; import java.util.Optional; -import static com.facebook.presto.spi.connector.ConnectorCommitHandle.EMPTY_COMMIT_OUTPUT; import static org.testng.Assert.assertEquals; public class TestOutput @@ -34,10 +37,13 @@ public void testRoundTrip() new ConnectorId("connectorId"), "schema", "table", - EMPTY_COMMIT_OUTPUT, Optional.of( ImmutableList.of( - new Column("column", "type")))); + new OutputColumnMetadata( + "column", "type", + ImmutableSet.of( + new SourceColumn(QualifiedObjectName.valueOf("catalog.schema.table"), "column"))))), + Optional.empty()); String json = codec.toJson(expected); Output actual = codec.fromJson(json); diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryInfo.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryInfo.java index 5bc55f203713b..3428fa883e50b 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryInfo.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryInfo.java @@ -16,11 +16,13 @@ import com.facebook.airlift.bootstrap.Bootstrap; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.JsonModule; +import com.facebook.drift.codec.guice.ThriftCodecModule; import com.facebook.presto.common.plan.PlanCanonicalizationStrategy; import com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.cost.StatsAndCosts; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleJsonModule; @@ -30,6 +32,7 @@ import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.WarningCode; +import com.facebook.presto.spi.analyzer.UpdateInfo; import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.ValuesNode; @@ -101,7 +104,7 @@ public void testQueryInfoRoundTrip() assertEquals(actual.getStartedTransactionId(), expected.getStartedTransactionId()); assertEquals(actual.isClearTransactionId(), expected.isClearTransactionId()); - assertEquals(actual.getUpdateType(), expected.getUpdateType()); + assertEquals(actual.getUpdateInfo(), expected.getUpdateInfo()); assertEquals(actual.getOutputStage(), expected.getOutputStage()); assertEquals(actual.getFailureInfo(), expected.getFailureInfo()); @@ -132,10 +135,12 @@ private static JsonCodec createJsonCodec() SqlParser sqlParser = new SqlParser(); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.install(new JsonModule()); + binder.install(new ThriftCodecModule()); 
binder.install(new HandleJsonModule()); binder.bind(SqlParser.class).toInstance(sqlParser); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); configBinder(binder).bindConfig(FeaturesConfig.class); + binder.bind(ConnectorManager.class).toProvider(() -> null); newSetBinder(binder, Type.class); jsonBinder(binder).addSerializerBinding(Slice.class).to(SliceSerializer.class); jsonBinder(binder).addDeserializerBinding(Slice.class).to(SliceDeserializer.class); @@ -178,12 +183,12 @@ private static QueryInfo createQueryInfo() ImmutableSet.of("deallocated_prepared_statement", "statement"), Optional.of(TransactionId.create()), true, - "update_type", + new UpdateInfo("update_type", ""), Optional.empty(), null, null, ImmutableList.of(new PrestoWarning(new WarningCode(1, "name"), "message")), - ImmutableSet.of(new Input(new ConnectorId("connector"), "schema", "table", Optional.empty(), ImmutableList.of(new Column("name", "type")), Optional.empty(), "")), + ImmutableSet.of(new Input(new ConnectorId("connector"), "schema", "table", Optional.empty(), ImmutableList.of(new Column("name", "type")), Optional.empty(), Optional.empty())), Optional.empty(), true, Optional.empty(), diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryLimit.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryLimit.java index 28534f907d749..6a9778f3f2c67 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryLimit.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryLimit.java @@ -13,10 +13,12 @@ */ package com.facebook.presto.execution; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import org.testng.annotations.Test; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static 
com.facebook.presto.execution.QueryLimit.Source.QUERY; import static com.facebook.presto.execution.QueryLimit.Source.RESOURCE_GROUP; import static com.facebook.presto.execution.QueryLimit.Source.SYSTEM; @@ -24,8 +26,6 @@ import static com.facebook.presto.execution.QueryLimit.createDurationLimit; import static com.facebook.presto.execution.QueryLimit.getMinimum; import static com.facebook.presto.testing.assertions.Assert.assertEquals; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.TimeUnit.HOURS; import static java.util.concurrent.TimeUnit.MINUTES; import static org.testng.Assert.assertThrows; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryManager.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryManager.java index c852d70c9cb8e..dbebdae2a81ba 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryManager.java @@ -53,6 +53,12 @@ public QueryInfo getFullQueryInfo(QueryId queryId) return null; } + @Override + public long getDurationUntilExpirationInMillis(QueryId queryId) + { + return 0; + } + public Session getQuerySession(QueryId queryId) { return null; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryManagerConfig.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryManagerConfig.java index 4d06388452f29..109a1be212e3c 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryManagerConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryManagerConfig.java @@ -14,18 +14,18 @@ package com.facebook.presto.execution; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import 
com.facebook.presto.execution.QueryManagerConfig.ExchangeMaterializationStrategy; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; import java.util.concurrent.TimeUnit; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.Unit.PETABYTE; -import static io.airlift.units.DataSize.Unit.TERABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.PETABYTE; +import static com.facebook.airlift.units.DataSize.Unit.TERABYTE; import static java.util.concurrent.TimeUnit.HOURS; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MINUTES; @@ -62,6 +62,7 @@ public void testDefaults() .setRemoteTaskMaxCallbackThreads(Runtime.getRuntime().availableProcessors()) .setQueryExecutionPolicy("all-at-once") .setQueryMaxRunTime(new Duration(100, TimeUnit.DAYS)) + .setQueryMaxQueuedTime(new Duration(100, TimeUnit.DAYS)) .setQueryMaxExecutionTime(new Duration(100, TimeUnit.DAYS)) .setQueryMaxCpuTime(new Duration(1_000_000_000, TimeUnit.DAYS)) .setQueryMaxScanRawInputBytes(new DataSize(1000, PETABYTE)) @@ -84,7 +85,9 @@ public void testDefaults() .setRateLimiterCacheLimit(1000) .setRateLimiterCacheWindowMinutes(5) .setEnableWorkerIsolation(false) - .setMinColumnarEncodingChannelsToPreferRowWiseEncoding(1000)); + .setMinColumnarEncodingChannelsToPreferRowWiseEncoding(1000) + .setMaxQueryAdmissionsPerSecond(Integer.MAX_VALUE) + .setMinRunningQueriesForPacing(30)); } @Test @@ -115,6 +118,7 @@ public void testExplicitPropertyMappings() .put("query.remote-task.max-callback-threads", "11") .put("query.execution-policy", "phased") .put("query.max-run-time", "2h") + .put("query.max-queued-time", "1h") .put("query.max-execution-time", "3h") .put("query.max-cpu-time", "2d") .put("query.max-scan-raw-input-bytes", 
"1MB") @@ -139,6 +143,8 @@ public void testExplicitPropertyMappings() .put("query.cte-partitioning-provider-catalog", "hive") .put("query-manager.enable-worker-isolation", "true") .put("min-columnar-encoding-channels-to-prefer-row-wise-encoding", "123") + .put("query-manager.query-pacing.max-queries-per-second", "10") + .put("query-manager.query-pacing.min-running-queries", "5") .build(); QueryManagerConfig expected = new QueryManagerConfig() @@ -167,6 +173,7 @@ public void testExplicitPropertyMappings() .setRemoteTaskMaxCallbackThreads(11) .setQueryExecutionPolicy("phased") .setQueryMaxRunTime(new Duration(2, TimeUnit.HOURS)) + .setQueryMaxQueuedTime(new Duration(1, TimeUnit.HOURS)) .setQueryMaxExecutionTime(new Duration(3, TimeUnit.HOURS)) .setQueryMaxCpuTime(new Duration(2, TimeUnit.DAYS)) .setQueryMaxScanRawInputBytes(new DataSize(1, MEGABYTE)) @@ -190,7 +197,9 @@ public void testExplicitPropertyMappings() .setRateLimiterCacheWindowMinutes(60) .setCtePartitioningProviderCatalog("hive") .setEnableWorkerIsolation(true) - .setMinColumnarEncodingChannelsToPreferRowWiseEncoding(123); + .setMinColumnarEncodingChannelsToPreferRowWiseEncoding(123) + .setMaxQueryAdmissionsPerSecond(10) + .setMinRunningQueriesForPacing(5); ConfigAssertions.assertFullMapping(properties, expected); } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryStateMachine.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryStateMachine.java index 8172ce296a0c9..e5fbe269ba9db 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryStateMachine.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryStateMachine.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution; import com.facebook.airlift.testing.TestingTicker; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.client.FailureInfo; import 
com.facebook.presto.common.resourceGroups.QueryType; @@ -26,6 +27,7 @@ import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.UpdateInfo; import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.facebook.presto.spi.security.AccessControl; @@ -36,7 +38,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; @@ -82,7 +83,7 @@ public class TestQueryStateMachine private static final String QUERY = "sql"; private static final URI LOCATION = URI.create("fake://fake-query"); private static final SQLException FAILED_CAUSE = new SQLException("FAILED"); - private static final List INPUTS = ImmutableList.of(new Input(new ConnectorId("connector"), "schema", "table", Optional.empty(), ImmutableList.of(new Column("a", "varchar")), Optional.empty(), "")); + private static final List INPUTS = ImmutableList.of(new Input(new ConnectorId("connector"), "schema", "table", Optional.empty(), ImmutableList.of(new Column("a", "varchar")), Optional.empty(), Optional.empty())); private static final Optional OUTPUT = Optional.empty(); private static final List OUTPUT_FIELD_NAMES = ImmutableList.of("a", "b", "c"); private static final List OUTPUT_FIELD_TYPES = ImmutableList.of(BIGINT, BIGINT, BIGINT); @@ -558,7 +559,7 @@ private static void assertState(QueryStateMachine stateMachine, QueryState expec assertEquals(queryInfo.getInputs(), INPUTS); assertEquals(queryInfo.getOutput(), OUTPUT); assertEquals(queryInfo.getFieldNames(), OUTPUT_FIELD_NAMES); - assertEquals(queryInfo.getUpdateType(), UPDATE_TYPE); + assertEquals(queryInfo.getUpdateInfo(), new UpdateInfo(UPDATE_TYPE, "")); 
assertEquals(queryInfo.getMemoryPool(), MEMORY_POOL.getId()); assertEquals(queryInfo.getQueryType(), QUERY_TYPE); @@ -644,7 +645,7 @@ private QueryStateMachine createQueryStateMachineWithTicker(Ticker ticker, Trans stateMachine.setInputs(INPUTS); stateMachine.setOutput(OUTPUT); stateMachine.setColumns(OUTPUT_FIELD_NAMES, OUTPUT_FIELD_TYPES); - stateMachine.setUpdateType(UPDATE_TYPE); + stateMachine.setUpdateInfo(new UpdateInfo(UPDATE_TYPE, "")); stateMachine.setMemoryPool(MEMORY_POOL); for (Entry entry : SET_SESSION_PROPERTIES.entrySet()) { stateMachine.addSetSessionProperties(entry.getKey(), entry.getValue()); diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryStats.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryStats.java index 340f8737fe78d..1df4875ac6492 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryStats.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryStats.java @@ -16,6 +16,8 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.stats.Distribution; import com.facebook.airlift.testing.TestingTicker; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeMetric; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.operator.DynamicFilterStats; @@ -36,8 +38,6 @@ import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.joda.time.DateTime; import org.testng.annotations.Test; @@ -45,8 +45,8 @@ import java.util.List; import java.util.Optional; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.common.RuntimeUnit.NONE; -import static io.airlift.units.DataSize.Unit.BYTE; import static 
java.util.concurrent.TimeUnit.NANOSECONDS; import static org.joda.time.DateTimeZone.UTC; import static org.testng.Assert.assertEquals; @@ -236,6 +236,16 @@ public class TestQueryStats 30, 16, + 12, + 13, + 15, + 16, + + 12, + 13, + 15, + 16, + 17.0, 43.0, new DataSize(18, BYTE), @@ -421,6 +431,192 @@ public void testInputAndOutputStatsCalculation() assertEquals(queryStats.getOutputPositions(), 100); } + @Test + public void testPrestoSparkRemoteSourceOperatorShuffleTracking() + { + // Test that PrestoSparkRemoteSourceOperator is correctly tracked as shuffle operator + // This test simulates a Sapphire Java query with PrestoSparkRemoteSourceOperator + PlanFragment testPlanFragment = TaskTestUtils.createPlanFragment(); + + // build stage with PrestoSparkRemoteSourceOperator + int stageId = 0; + int stageExecutionId = 1; + List operatorStats = ImmutableList.of( + createOperatorStatsWithType(stageId, stageExecutionId, 0, 0, new PlanNodeId("1"), + "PrestoSparkRemoteSourceOperator", + 1000L, 50L, + 1000L, 50L, + 1000L, 50L), + createOperatorStatsWithType(stageId, stageExecutionId, 0, 1, new PlanNodeId("2"), + TaskOutputOperator.class.getSimpleName(), + 0L, 0L, + 1000L, 50L, + 1000L, 50L)); + + StageExecutionStats stageExecutionStats = createStageStats(stageId, stageExecutionId, + 0L, 0L, + 1000L, 50L, + 1000L, 50L, + operatorStats); + + StageExecutionInfo stageExecutionInfo = new StageExecutionInfo( + StageExecutionState.FINISHED, + stageExecutionStats, + ImmutableList.of(), + Optional.empty()); + + StageInfo stageInfo = new StageInfo(StageId.valueOf("0.0"), URI.create("127.0.0.1"), + Optional.of(testPlanFragment), + stageExecutionInfo, ImmutableList.of(), ImmutableList.of(), false); + + // calculate query stats + Optional rootStage = Optional.of(stageInfo); + List allStages = StageInfo.getAllStages(rootStage); + QueryStats queryStats = QueryStats.create(new QueryStateTimer(new TestingTicker()), rootStage, allStages, 0, + 0L, 0L, 0L, 0L, 0L, + new RuntimeStats()); + + // 
verify that PrestoSparkRemoteSourceOperator data is counted as shuffled data + assertEquals(queryStats.getShuffledDataSize().toBytes(), 1000); + assertEquals(queryStats.getShuffledPositions(), 50); + // verify that raw input (table scan) is not counted since we have no table scan operators + assertEquals(queryStats.getRawInputDataSize().toBytes(), 0); + assertEquals(queryStats.getRawInputPositions(), 0); + } + + @Test + public void testShuffleReadOperatorShuffleTracking() + { + // Test that ShuffleRead is correctly tracked as shuffle operator + // This test simulates a Sapphire Velox query with ShuffleRead + PlanFragment testPlanFragment = TaskTestUtils.createPlanFragment(); + + // build stage with ShuffleRead + int stageId = 0; + int stageExecutionId = 1; + List operatorStats = ImmutableList.of( + createOperatorStatsWithType(stageId, stageExecutionId, 0, 0, new PlanNodeId("1"), + "ShuffleRead", + 2000L, 100L, + 2000L, 100L, + 2000L, 100L), + createOperatorStatsWithType(stageId, stageExecutionId, 0, 1, new PlanNodeId("2"), + TaskOutputOperator.class.getSimpleName(), + 0L, 0L, + 2000L, 100L, + 2000L, 100L)); + + StageExecutionStats stageExecutionStats = createStageStats(stageId, stageExecutionId, + 0L, 0L, + 2000L, 100L, + 2000L, 100L, + operatorStats); + + StageExecutionInfo stageExecutionInfo = new StageExecutionInfo( + StageExecutionState.FINISHED, + stageExecutionStats, + ImmutableList.of(), + Optional.empty()); + + StageInfo stageInfo = new StageInfo(StageId.valueOf("0.0"), URI.create("127.0.0.1"), + Optional.of(testPlanFragment), + stageExecutionInfo, ImmutableList.of(), ImmutableList.of(), false); + + // calculate query stats + Optional rootStage = Optional.of(stageInfo); + List allStages = StageInfo.getAllStages(rootStage); + QueryStats queryStats = QueryStats.create(new QueryStateTimer(new TestingTicker()), rootStage, allStages, 0, + 0L, 0L, 0L, 0L, 0L, + new RuntimeStats()); + + // verify that ShuffleRead data is counted as shuffled data + 
assertEquals(queryStats.getShuffledDataSize().toBytes(), 2000); + assertEquals(queryStats.getShuffledPositions(), 100); + // verify that raw input (table scan) is not counted since we have no table scan operators + assertEquals(queryStats.getRawInputDataSize().toBytes(), 0); + assertEquals(queryStats.getRawInputPositions(), 0); + } + + @Test + public void testMixedShuffleOperators() + { + // Test that ExchangeOperator, PrestoSparkRemoteSourceOperator and CoscoShuffleRead are all counted as shuffle + PlanFragment testPlanFragment = TaskTestUtils.createPlanFragment(); + + // Stage 0: ExchangeOperator + int stageId0 = 0; + int stageExecutionId0 = 1; + List operatorStats0 = ImmutableList.of( + createOperatorStats(stageId0, stageExecutionId0, 0, 0, new PlanNodeId("1"), + ExchangeOperator.class, + 500L, 25L, + 500L, 25L, + 500L, 25L), + createOperatorStats(stageId0, stageExecutionId0, 0, 1, new PlanNodeId("2"), + TaskOutputOperator.class, + 0L, 0L, + 500L, 25L, + 500L, 25L)); + + StageExecutionStats stageExecutionStats0 = createStageStats(stageId0, stageExecutionId0, + 0L, 0L, + 500L, 25L, + 500L, 25L, + operatorStats0); + + StageExecutionInfo stageExecutionInfo0 = new StageExecutionInfo( + StageExecutionState.FINISHED, + stageExecutionStats0, + ImmutableList.of(), + Optional.empty()); + + // Stage 1: PrestoSparkRemoteSourceOperator + int stageId1 = 1; + int stageExecutionId1 = 11; + List operatorStats1 = ImmutableList.of( + createOperatorStatsWithType(stageId1, stageExecutionId1, 0, 0, new PlanNodeId("101"), + "PrestoSparkRemoteSourceOperator", + 1500L, 75L, + 1500L, 75L, + 1500L, 75L), + createOperatorStats(stageId1, stageExecutionId1, 0, 1, new PlanNodeId("102"), + TaskOutputOperator.class, + 0L, 0L, + 1500L, 75L, + 1500L, 75L)); + + StageExecutionStats stageExecutionStats1 = createStageStats(stageId1, stageExecutionId1, + 0L, 0L, + 1500L, 75L, + 1500L, 75L, + operatorStats1); + + StageExecutionInfo stageExecutionInfo1 = new StageExecutionInfo( + 
StageExecutionState.FINISHED, + stageExecutionStats1, + ImmutableList.of(), + Optional.empty()); + + // Build stage hierarchy + StageInfo stageInfo1 = new StageInfo(StageId.valueOf("0.1"), URI.create("127.0.0.1"), + Optional.of(testPlanFragment), + stageExecutionInfo1, ImmutableList.of(), ImmutableList.of(), false); + StageInfo stageInfo0 = new StageInfo(StageId.valueOf("0.0"), URI.create("127.0.0.1"), + Optional.of(testPlanFragment), + stageExecutionInfo0, ImmutableList.of(), ImmutableList.of(stageInfo1), false); + + // calculate query stats + Optional rootStage = Optional.of(stageInfo0); + List allStages = StageInfo.getAllStages(rootStage); + QueryStats queryStats = QueryStats.create(new QueryStateTimer(new TestingTicker()), rootStage, allStages, 0, + 0L, 0L, 0L, 0L, 0L, + new RuntimeStats()); + + // verify that both operators' data is counted as shuffled + assertEquals(queryStats.getShuffledDataSize().toBytes(), 2000); // 500 + 1500 + assertEquals(queryStats.getShuffledPositions(), 100); // 25 + 75 + } + @Test public void testJson() { @@ -525,6 +721,19 @@ private static OperatorStats createOperatorStats(int stageId, int stageExecution long rawInputDataSize, long rawInputPositions, long inputDataSize, long inputPositions, long outputDataSize, long outputPositions) + { + return createOperatorStatsWithType(stageId, stageExecutionId, pipelineId, operatorId, planNodeId, + operatorCls.getSimpleName(), + rawInputDataSize, rawInputPositions, + inputDataSize, inputPositions, + outputDataSize, outputPositions); + } + + private static OperatorStats createOperatorStatsWithType(int stageId, int stageExecutionId, int pipelineId, + int operatorId, PlanNodeId planNodeId, String operatorType, + long rawInputDataSize, long rawInputPositions, + long inputDataSize, long inputPositions, + long outputDataSize, long outputPositions) { return new OperatorStats( stageId, @@ -532,7 +741,7 @@ private static OperatorStats createOperatorStats(int stageId, int stageExecution pipelineId, 
operatorId, planNodeId, - operatorCls.getSimpleName(), + operatorType, 0L, 0L, new Duration(0, NANOSECONDS), @@ -594,6 +803,14 @@ private static StageExecutionStats createStageStats(int stageId, int stageExecut 0, 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, 0, 0, 0, diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryTrackerQueuedTime.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryTrackerQueuedTime.java new file mode 100644 index 0000000000000..22ebe5256436b --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestQueryTrackerQueuedTime.java @@ -0,0 +1,353 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.execution; + +import com.facebook.airlift.units.Duration; +import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.execution.QueryTracker.TrackedQuery; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.resourceGroups.ResourceGroupQueryLimits; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.airlift.units.Duration.succinctDuration; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_TIME_LIMIT; +import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestQueryTrackerQueuedTime +{ + private ScheduledExecutorService executor; + private QueryTracker queryTracker; + + @BeforeMethod + public void setUp() + { + executor = newSingleThreadScheduledExecutor(); + QueryManagerConfig config = new QueryManagerConfig(); + queryTracker = new QueryTracker<>(config, executor, Optional.empty()); + queryTracker.start(); + } + + @AfterMethod + public void tearDown() + { + if (queryTracker != null) { + queryTracker.stop(); + } + if (executor != null) { + executor.shutdownNow(); + } + } + + @Test + public void testQueryExceedsQueuedTimeLimit() + throws Exception + { + // Create a session with 1 second queued time limit + Session session = 
Session.builder(TEST_SESSION) + .setSystemProperty(SystemSessionProperties.QUERY_MAX_QUEUED_TIME, "1s") + .build(); + + AtomicReference failureException = new AtomicReference<>(); + AtomicBoolean queryFailed = new AtomicBoolean(false); + + // Create a mock query that has been queued for 2 seconds (exceeds limit) + long currentTime = System.currentTimeMillis(); + MockTrackedQuery query = new MockTrackedQuery( + new QueryId("test_query_1"), + session, + currentTime - 2000, // Created 2 seconds ago + 0, // Not started execution yet + currentTime, + failureException, + queryFailed); + + queryTracker.addQuery(query); + + // Manually trigger time limit enforcement + queryTracker.enforceTimeLimits(); + + // Verify the query was failed due to exceeding queued time limit + assertTrue(queryFailed.get(), "Query should have been failed"); + assertNotNull(failureException.get(), "Failure exception should be set"); + assertEquals(failureException.get().getErrorCode(), EXCEEDED_TIME_LIMIT.toErrorCode()); + assertTrue(failureException.get().getMessage().contains("Query exceeded maximum queued time limit")); + } + + @Test + public void testQueryWithinQueuedTimeLimit() + throws Exception + { + // Create a session with 5 second queued time limit + Session session = Session.builder(TEST_SESSION) + .setSystemProperty(SystemSessionProperties.QUERY_MAX_QUEUED_TIME, "5s") + .build(); + + AtomicReference failureException = new AtomicReference<>(); + AtomicBoolean queryFailed = new AtomicBoolean(false); + + // Create a mock query that has been queued for 1 second (within limit) + long currentTime = System.currentTimeMillis(); + MockTrackedQuery query = new MockTrackedQuery( + new QueryId("test_query_2"), + session, + currentTime - 1000, // Created 1 second ago + 0, // Not started execution yet + currentTime, + failureException, + queryFailed); + + queryTracker.addQuery(query); + + // Manually trigger time limit enforcement + queryTracker.enforceTimeLimits(); + + // Verify the query was not 
failed + assertFalse(queryFailed.get(), "Query should not have been failed"); + } + + @Test + public void testQueryStartedExecutionQueuedTimeCalculation() + throws Exception + { + // Create a session with 1 second queued time limit + Session session = Session.builder(TEST_SESSION) + .setSystemProperty(SystemSessionProperties.QUERY_MAX_QUEUED_TIME, "1s") + .build(); + + AtomicReference failureException = new AtomicReference<>(); + AtomicBoolean queryFailed = new AtomicBoolean(false); + + // Create a mock query that was queued for 2 seconds but started execution + long currentTime = System.currentTimeMillis(); + MockTrackedQuery query = new MockTrackedQuery( + new QueryId("test_query_3"), + session, + currentTime - 3000, // Created 3 seconds ago + currentTime - 1000, // Started execution 1 second ago (queued for 2 seconds) + currentTime, + failureException, + queryFailed); + + queryTracker.addQuery(query); + + // Manually trigger time limit enforcement + queryTracker.enforceTimeLimits(); + + // Verify the query was failed because it was queued for 2 seconds (exceeds 1s limit) + assertTrue(queryFailed.get(), "Query should have been failed"); + assertNotNull(failureException.get(), "Failure exception should be set"); + assertEquals(failureException.get().getErrorCode(), EXCEEDED_TIME_LIMIT.toErrorCode()); + assertTrue(failureException.get().getMessage().contains("Query exceeded maximum queued time limit")); + } + + @Test + public void testQueryStartedExecutionWithinQueuedTimeLimit() + throws Exception + { + // Create a session with 5 second queued time limit + Session session = Session.builder(TEST_SESSION) + .setSystemProperty(SystemSessionProperties.QUERY_MAX_QUEUED_TIME, "5s") + .build(); + + AtomicReference failureException = new AtomicReference<>(); + AtomicBoolean queryFailed = new AtomicBoolean(false); + + // Create a mock query that was queued for 1 second and started execution + long currentTime = System.currentTimeMillis(); + MockTrackedQuery query = new 
MockTrackedQuery( + new QueryId("test_query_4"), + session, + currentTime - 2000, // Created 2 seconds ago + currentTime - 1000, // Started execution 1 second ago (queued for 1 second) + currentTime, + failureException, + queryFailed); + + queryTracker.addQuery(query); + + // Manually trigger time limit enforcement + queryTracker.enforceTimeLimits(); + + // Verify the query was not failed + assertFalse(queryFailed.get(), "Query should not have been failed"); + } + + @Test + public void testCompletedQueryNotChecked() + throws Exception + { + // Create a session with 1 second queued time limit + Session session = Session.builder(TEST_SESSION) + .setSystemProperty(SystemSessionProperties.QUERY_MAX_QUEUED_TIME, "1s") + .build(); + + AtomicReference failureException = new AtomicReference<>(); + AtomicBoolean queryFailed = new AtomicBoolean(false); + + // Create a mock query that is already completed + long currentTime = System.currentTimeMillis(); + MockTrackedQuery query = new MockTrackedQuery( + new QueryId("test_query_5"), + session, + currentTime - 5000, // Created 5 seconds ago + 0, // Not started execution yet + currentTime, + failureException, + queryFailed); + query.setDone(true); // Mark as completed + + queryTracker.addQuery(query); + + // Manually trigger time limit enforcement + queryTracker.enforceTimeLimits(); + + // Verify the completed query was not failed + assertFalse(queryFailed.get(), "Completed query should not be checked for time limits"); + } + + private static class MockTrackedQuery + implements TrackedQuery + { + private final QueryId queryId; + private final Session session; + private final long createTimeInMillis; + private final long executionStartTimeInMillis; + private final long lastHeartbeatInMillis; + private final AtomicReference failureException; + private final AtomicBoolean queryFailed; + private boolean done; + + public MockTrackedQuery( + QueryId queryId, + Session session, + long createTimeInMillis, + long 
executionStartTimeInMillis, + long lastHeartbeatInMillis, + AtomicReference failureException, + AtomicBoolean queryFailed) + { + this.queryId = queryId; + this.session = session; + this.createTimeInMillis = createTimeInMillis; + this.executionStartTimeInMillis = executionStartTimeInMillis; + this.lastHeartbeatInMillis = lastHeartbeatInMillis; + this.failureException = failureException; + this.queryFailed = queryFailed; + } + + public void setDone(boolean done) + { + this.done = done; + } + + @Override + public QueryId getQueryId() + { + return queryId; + } + + @Override + public boolean isDone() + { + return done; + } + + @Override + public Session getSession() + { + return session; + } + + @Override + public long getCreateTimeInMillis() + { + return createTimeInMillis; + } + + @Override + public Duration getQueuedTime() + { + long queuedTimeInMillis; + if (executionStartTimeInMillis > 0) { + queuedTimeInMillis = executionStartTimeInMillis - createTimeInMillis; + } + else { + queuedTimeInMillis = System.currentTimeMillis() - createTimeInMillis; + } + return succinctDuration(queuedTimeInMillis, MILLISECONDS); + } + + @Override + public long getExecutionStartTimeInMillis() + { + return executionStartTimeInMillis; + } + + @Override + public long getLastHeartbeatInMillis() + { + return lastHeartbeatInMillis; + } + + @Override + public long getEndTimeInMillis() + { + return done ? 
System.currentTimeMillis() : 0; + } + + @Override + public Optional getResourceGroupQueryLimits() + { + return Optional.empty(); + } + + @Override + public void fail(Throwable cause) + { + if (cause instanceof PrestoException) { + failureException.set((PrestoException) cause); + } + queryFailed.set(true); + done = true; + } + + @Override + public void pruneExpiredQueryInfo() + { + // No-op for test + } + + @Override + public void pruneFinishedQueryInfo() + { + // No-op for test + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSplitConcurrencyController.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSplitConcurrencyController.java index 1b22cc5069f03..a9707353c9714 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSplitConcurrencyController.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSplitConcurrencyController.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.execution; -import io.airlift.units.Duration; +import com.facebook.airlift.units.Duration; import org.testng.annotations.Test; import static java.util.concurrent.TimeUnit.MILLISECONDS; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java index bce1acccd12f3..cac4f07885d74 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlStageExecution.java @@ -174,6 +174,7 @@ private static PlanFragment createExchangePlanFragment() SOURCE_DISTRIBUTION, ImmutableList.of(planNode.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputVariables()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlTaskExecution.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlTaskExecution.java index 2758637601d44..23cf1359a8cbc 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlTaskExecution.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlTaskExecution.java @@ -14,6 +14,8 @@ package com.facebook.presto.execution; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.CompressionCodec; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.BlockEncodingManager; @@ -64,8 +66,6 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -88,6 +88,8 @@ import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; import static com.facebook.airlift.concurrent.Threads.threadsNamed; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createStringSequenceBlock; import static com.facebook.presto.block.BlockAssertions.createStringsBlock; @@ -108,8 +110,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static 
io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.HOURS; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlTaskManager.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlTaskManager.java index 311c7081a76ad..8787dcddd4986 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlTaskManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestSqlTaskManager.java @@ -15,6 +15,9 @@ import com.facebook.airlift.node.NodeInfo; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.execution.buffer.BufferInfo; import com.facebook.presto.execution.buffer.BufferResult; @@ -43,9 +46,6 @@ import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestStageExecutionStats.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestStageExecutionStats.java index fd477c0a3ea9d..da8bbc14ed5ba 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestStageExecutionStats.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestStageExecutionStats.java @@ -16,11 +16,11 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.stats.Distribution; import com.facebook.airlift.stats.Distribution.DistributionSnapshot; 
+import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.spi.eventlistener.StageGcStatistics; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import org.testng.annotations.Test; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -45,6 +45,14 @@ public class TestStageExecutionStats 10, 26, 11, + 7, + 8, + 10, + 11, + 7, + 8, + 10, + 11, 12.0, 27.0, diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestStartTransactionTask.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestStartTransactionTask.java index 1acc47d1b8a9f..c68b2719c872d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestStartTransactionTask.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestStartTransactionTask.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.Session.SessionBuilder; import com.facebook.presto.common.transaction.TransactionId; @@ -30,7 +31,6 @@ import com.facebook.presto.transaction.TransactionManager; import com.facebook.presto.transaction.TransactionManagerConfig; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestTaskManagerConfig.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestTaskManagerConfig.java index 0fce5aab8aa1a..45b64ef7dc345 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestTaskManagerConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestTaskManagerConfig.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.execution; +import 
com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.memory.HighMemoryTaskKillerStrategy; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.math.BigDecimal; @@ -26,9 +26,9 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static com.facebook.airlift.units.DataSize.Unit; import static com.facebook.presto.execution.TaskManagerConfig.TaskPriorityTracking.QUERY_FAIR; import static com.facebook.presto.execution.TaskManagerConfig.TaskPriorityTracking.TASK_FAIR; -import static io.airlift.units.DataSize.Unit; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; @@ -52,8 +52,8 @@ public void testDefaults() .setMinDriversPerTask(3) .setMaxDriversPerTask(Integer.MAX_VALUE) .setMaxTasksPerStage(Integer.MAX_VALUE) - .setInfoMaxAge(new Duration(15, TimeUnit.MINUTES)) - .setClientTimeout(new Duration(2, TimeUnit.MINUTES)) + .setInfoMaxAge(new Duration(15, MINUTES)) + .setClientTimeout(new Duration(2, MINUTES)) .setMaxIndexMemoryUsage(new DataSize(64, Unit.MEGABYTE)) .setShareIndexLoading(false) .setMaxPartialAggregationMemoryUsage(new DataSize(16, Unit.MEGABYTE)) @@ -80,8 +80,7 @@ public void testDefaults() .setHighMemoryTaskKillerFrequentFullGCDurationThreshold(new Duration(1, SECONDS)) .setHighMemoryTaskKillerHeapMemoryThreshold(0.9) .setTaskUpdateSizeTrackingEnabled(true) - .setSlowMethodThresholdOnEventLoop(new Duration(0, SECONDS)) - .setEventLoopEnabled(false)); + .setSlowMethodThresholdOnEventLoop(new Duration(0, SECONDS))); } @Test @@ -130,7 +129,6 @@ public void testExplicitPropertyMappings() 
.put("experimental.task.high-memory-task-killer-frequent-full-gc-duration-threshold", "2s") .put("experimental.task.high-memory-task-killer-heap-memory-threshold", "0.8") .put("task.update-size-tracking-enabled", "false") - .put("task.enable-event-loop", "true") .put("task.event-loop-slow-method-threshold", "10m") .build(); @@ -153,7 +151,7 @@ public void testExplicitPropertyMappings() .setMinDriversPerTask(5) .setMaxDriversPerTask(13) .setMaxTasksPerStage(999) - .setInfoMaxAge(new Duration(22, TimeUnit.MINUTES)) + .setInfoMaxAge(new Duration(22, MINUTES)) .setClientTimeout(new Duration(10, SECONDS)) .setSinkMaxBufferSize(new DataSize(42, Unit.MEGABYTE)) .setMaxPagePartitioningBufferSize(new DataSize(40, Unit.MEGABYTE)) @@ -177,7 +175,6 @@ public void testExplicitPropertyMappings() .setHighMemoryTaskKillerFrequentFullGCDurationThreshold(new Duration(2, SECONDS)) .setHighMemoryTaskKillerHeapMemoryThreshold(0.8) .setTaskUpdateSizeTrackingEnabled(false) - .setEventLoopEnabled(true) .setSlowMethodThresholdOnEventLoop(new Duration(10, MINUTES)); assertFullMapping(properties, expected); diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/TestThriftResourceGroupInfo.java b/presto-main-base/src/test/java/com/facebook/presto/execution/TestThriftResourceGroupInfo.java index ee18213a7501e..43e8610923c5a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/TestThriftResourceGroupInfo.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/TestThriftResourceGroupInfo.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution; +import com.facebook.airlift.units.DataSize; import com.facebook.drift.codec.ThriftCodec; import com.facebook.drift.codec.ThriftCodecManager; import com.facebook.drift.codec.internal.compiler.CompilerThriftCodecFactory; @@ -37,7 +38,6 @@ import com.facebook.presto.spi.resourceGroups.SchedulingPolicy; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; 
-import io.airlift.units.DataSize; import org.joda.time.DateTime; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; @@ -343,6 +343,12 @@ private void setUpQueryProgressStats() FAKE_PROGRESS_PERCENTAGE_1, FAKE_QUEUED_DRIVERS, FAKE_RUNNING_DRIVERS, + FAKE_COMPLETED_DRIVERS, + FAKE_QUEUED_DRIVERS, + FAKE_RUNNING_DRIVERS, + FAKE_COMPLETED_DRIVERS, + FAKE_QUEUED_DRIVERS, + FAKE_RUNNING_DRIVERS, FAKE_COMPLETED_DRIVERS)); queryProgressStats.add(new QueryProgressStats( FAKE_ELAPSED_TIME_MILLIS, @@ -363,6 +369,12 @@ private void setUpQueryProgressStats() FAKE_PROGRESS_PERCENTAGE_2, FAKE_QUEUED_DRIVERS, FAKE_RUNNING_DRIVERS, + FAKE_COMPLETED_DRIVERS, + FAKE_QUEUED_DRIVERS, + FAKE_RUNNING_DRIVERS, + FAKE_COMPLETED_DRIVERS, + FAKE_QUEUED_DRIVERS, + FAKE_RUNNING_DRIVERS, FAKE_COMPLETED_DRIVERS)); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/BufferTestUtils.java b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/BufferTestUtils.java index 03801a8f21167..3ba233e902182 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/BufferTestUtils.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/BufferTestUtils.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.execution.buffer; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.block.BlockAssertions; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.Type; @@ -23,8 +25,6 @@ import com.facebook.presto.spi.page.SerializedPage; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import java.util.List; import java.util.Optional; @@ -32,10 +32,10 @@ import java.util.stream.Collectors; import static com.facebook.airlift.concurrent.MoreFutures.tryGetFutureValue; +import static 
com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.execution.buffer.TestingPagesSerdeFactory.testingPagesSerde; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestArbitraryOutputBuffer.java b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestArbitraryOutputBuffer.java index ee0c641cf2b5a..cde1c631c69a9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestArbitraryOutputBuffer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestArbitraryOutputBuffer.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.execution.buffer; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.execution.Lifespan; @@ -21,8 +23,6 @@ import com.facebook.presto.memory.context.SimpleLocalMemoryContext; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -34,6 +34,7 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.execution.buffer.BufferResult.emptyResults; import static 
com.facebook.presto.execution.buffer.BufferState.OPEN; @@ -53,7 +54,6 @@ import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.ARBITRARY; import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.stream.Collectors.toList; import static org.assertj.core.api.Assertions.assertThat; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestBroadcastOutputBuffer.java b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestBroadcastOutputBuffer.java index 3fcd01d68907f..aaf64f5cd1dab 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestBroadcastOutputBuffer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestBroadcastOutputBuffer.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution.buffer; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.execution.StateMachine; @@ -24,7 +25,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -35,6 +35,7 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.common.type.BigintType.BIGINT; import static 
com.facebook.presto.execution.buffer.BufferResult.emptyResults; import static com.facebook.presto.execution.buffer.BufferState.OPEN; @@ -61,7 +62,6 @@ import static com.facebook.presto.memory.context.AggregatedMemoryContext.newRootAggregatedMemoryContext; import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestClientBuffer.java b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestClientBuffer.java index 393d81be898e5..99aea6a2ab2f3 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestClientBuffer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestClientBuffer.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution.buffer; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.execution.Lifespan; @@ -21,12 +22,10 @@ import com.facebook.presto.execution.buffer.SerializedPageReference.PagesReleasedListener; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.testng.annotations.Test; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; - import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Deque; diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestPartitionedOutputBuffer.java b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestPartitionedOutputBuffer.java index e15bc4a96b483..85f21b214b1e3 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestPartitionedOutputBuffer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestPartitionedOutputBuffer.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution.buffer; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.execution.StateMachine; @@ -20,7 +21,6 @@ import com.facebook.presto.memory.context.SimpleLocalMemoryContext; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -30,6 +30,7 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.execution.buffer.BufferResult.emptyResults; import static com.facebook.presto.execution.buffer.BufferState.OPEN; @@ -53,7 +54,6 @@ import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.PARTITIONED; import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; import static 
org.testng.Assert.assertFalse; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestSpoolingOutputBuffer.java b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestSpoolingOutputBuffer.java index e12a48e9bda17..50ab99ba79adb 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestSpoolingOutputBuffer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/buffer/TestSpoolingOutputBuffer.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.execution.buffer; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.execution.QueryIdGenerator; @@ -23,7 +24,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -35,6 +35,7 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.execution.buffer.BufferResult.emptyResults; import static com.facebook.presto.execution.buffer.BufferState.OPEN; @@ -52,7 +53,6 @@ import static com.facebook.presto.execution.buffer.BufferTestUtils.sizeOfPages; import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.PARTITIONED; import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.SPOOLING; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/execution/executor/TestTaskExecutor.java b/presto-main-base/src/test/java/com/facebook/presto/execution/executor/TestTaskExecutor.java index 0a0e86da43d70..d5bbf86036e33 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/executor/TestTaskExecutor.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/executor/TestTaskExecutor.java @@ -14,6 +14,7 @@ package com.facebook.presto.execution.executor; import com.facebook.airlift.testing.TestingTicker; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.SplitRunner; import com.facebook.presto.execution.TaskId; import com.facebook.presto.server.ServerConfig; @@ -23,7 +24,6 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Arrays; @@ -262,7 +262,8 @@ public void testLevelMultipliers() int phasesForNextLevel = LEVEL_THRESHOLD_SECONDS[i + 1] - LEVEL_THRESHOLD_SECONDS[i]; TestingJob[] drivers = new TestingJob[6]; for (int j = 0; j < 6; j++) { - drivers[j] = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), phasesForNextLevel, 1000); + // shouldn't deregister the global phaser upon the completion of process + drivers[j] = new TestingJob(ticker, globalPhaser, new Phaser(), new Phaser(), phasesForNextLevel, 1000, false); } taskExecutor.enqueueSplits(taskHandles[0], true, ImmutableList.of(drivers[0], drivers[1])); @@ -525,6 +526,7 @@ private static class TestingJob private final Phaser endQuantaPhaser; private final int requiredPhases; private final int quantaTimeMillis; + private final boolean deregisterGlobalPhaser; private final AtomicInteger completedPhases = new AtomicInteger(); private final AtomicInteger firstPhase = new AtomicInteger(-1); @@ -534,6 +536,11 @@ private static class 
TestingJob private final SettableFuture completed = SettableFuture.create(); public TestingJob(TestingTicker ticker, Phaser globalPhaser, Phaser beginQuantaPhaser, Phaser endQuantaPhaser, int requiredPhases, int quantaTimeMillis) + { + this(ticker, globalPhaser, beginQuantaPhaser, endQuantaPhaser, requiredPhases, quantaTimeMillis, true); + } + + public TestingJob(TestingTicker ticker, Phaser globalPhaser, Phaser beginQuantaPhaser, Phaser endQuantaPhaser, int requiredPhases, int quantaTimeMillis, boolean deregisterGlobalPhaser) { this.ticker = ticker; this.globalPhaser = globalPhaser; @@ -541,6 +548,7 @@ public TestingJob(TestingTicker ticker, Phaser globalPhaser, Phaser beginQuantaP this.endQuantaPhaser = endQuantaPhaser; this.requiredPhases = requiredPhases; this.quantaTimeMillis = quantaTimeMillis; + this.deregisterGlobalPhaser = deregisterGlobalPhaser; beginQuantaPhaser.register(); endQuantaPhaser.register(); @@ -578,7 +586,9 @@ public ListenableFuture processFor(Duration duration) if (completedPhases.incrementAndGet() >= requiredPhases) { endQuantaPhaser.arriveAndDeregister(); beginQuantaPhaser.arriveAndDeregister(); - globalPhaser.arriveAndDeregister(); + if (deregisterGlobalPhaser) { + globalPhaser.arriveAndDeregister(); + } completed.set(null); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/BenchmarkResourceGroup.java b/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/BenchmarkResourceGroup.java index ac60da6966f7c..7fba90f5fdc63 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/BenchmarkResourceGroup.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/BenchmarkResourceGroup.java @@ -13,10 +13,14 @@ */ package com.facebook.presto.execution.resourceGroups; +import com.facebook.airlift.units.DataSize; +import com.facebook.presto.execution.ClusterOverloadConfig; import 
com.facebook.presto.execution.MockManagedQueryExecution; import com.facebook.presto.execution.resourceGroups.InternalResourceGroup.RootInternalResourceGroup; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterOverloadPolicy; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterResourceChecker; import com.facebook.presto.metadata.InMemoryNodeManager; -import io.airlift.units.DataSize; +import com.facebook.presto.metadata.InternalNodeManager; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -39,7 +43,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; @SuppressWarnings("MethodMayBeStatic") @State(Scope.Thread) @@ -73,7 +77,7 @@ public static class BenchmarkData @Setup public void setup() { - root = new RootInternalResourceGroup("root", (group, export) -> {}, executor, ignored -> Optional.empty(), rg -> false, new InMemoryNodeManager()); + root = new RootInternalResourceGroup("root", (group, export) -> {}, executor, ignored -> Optional.empty(), rg -> false, new InMemoryNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(queries); root.setHardConcurrencyLimit(queries); @@ -89,6 +93,31 @@ public void setup() } } + private ClusterResourceChecker createClusterResourceChecker() + { + // Create a mock cluster overload policy that never reports overload + ClusterOverloadPolicy mockPolicy = new ClusterOverloadPolicy() + { + @Override + public boolean isClusterOverloaded(InternalNodeManager nodeManager) + { + return false; // Never overloaded for benchmarks + } + + @Override + public String getName() + { + return "benchmark-policy"; + } + }; + + // Create a config with throttling disabled for benchmarks + 
ClusterOverloadConfig config = new ClusterOverloadConfig() + .setClusterOverloadThrottlingEnabled(false); + + return new ClusterResourceChecker(mockPolicy, config, new InMemoryNodeManager()); + } + @TearDown public void tearDown() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/TestInternalResourceGroupManager.java b/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/TestInternalResourceGroupManager.java index bc4cf99103531..b54e1984fcec2 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/TestInternalResourceGroupManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/TestInternalResourceGroupManager.java @@ -15,8 +15,11 @@ package com.facebook.presto.execution.resourceGroups; import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.execution.ClusterOverloadConfig; import com.facebook.presto.execution.MockManagedQueryExecution; import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterResourceChecker; +import com.facebook.presto.execution.scheduler.clusterOverload.CpuMemoryOverloadPolicy; import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.server.ServerConfig; import com.facebook.presto.spi.PrestoException; @@ -27,13 +30,17 @@ import org.weakref.jmx.MBeanExporter; import org.weakref.jmx.testing.TestingMBeanServer; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + public class TestInternalResourceGroupManager { @Test(expectedExceptions = PrestoException.class, expectedExceptionsMessageRegExp = ".*Presto server is still initializing.*") public void testQueryFailsWithInitializingConfigurationManager() { - InternalResourceGroupManager> 
internalResourceGroupManager = new InternalResourceGroupManager<>((poolId, listener) -> {}, - new QueryManagerConfig(), new NodeInfo("test"), new MBeanExporter(new TestingMBeanServer()), () -> null, new ServerConfig(), new InMemoryNodeManager()); + InternalResourceGroupManager> internalResourceGroupManager = new InternalResourceGroupManager<>((poolId, listener) -> {}, new QueryManagerConfig(), new NodeInfo("test"), new MBeanExporter(new TestingMBeanServer()), () -> null, new ServerConfig(), new InMemoryNodeManager(), new ClusterResourceChecker(new CpuMemoryOverloadPolicy(new ClusterOverloadConfig()), new ClusterOverloadConfig(), new InMemoryNodeManager())); internalResourceGroupManager.submit(new MockManagedQueryExecution(0), new SelectionContext<>(new ResourceGroupId("global"), ImmutableMap.of()), command -> {}); } @@ -42,8 +49,222 @@ public void testQuerySucceedsWhenConfigurationManagerLoaded() throws Exception { InternalResourceGroupManager> internalResourceGroupManager = new InternalResourceGroupManager<>((poolId, listener) -> {}, - new QueryManagerConfig(), new NodeInfo("test"), new MBeanExporter(new TestingMBeanServer()), () -> null, new ServerConfig(), new InMemoryNodeManager()); + new QueryManagerConfig(), new NodeInfo("test"), new MBeanExporter(new TestingMBeanServer()), () -> null, new ServerConfig(), new InMemoryNodeManager(), new ClusterResourceChecker(new CpuMemoryOverloadPolicy(new ClusterOverloadConfig()), new ClusterOverloadConfig(), new InMemoryNodeManager())); internalResourceGroupManager.loadConfigurationManager(); internalResourceGroupManager.submit(new MockManagedQueryExecution(0), new SelectionContext<>(new ResourceGroupId("global"), ImmutableMap.of()), command -> {}); } + + // Tests that admission always succeeds when pacing is disabled (default config) + @Test + public void testAdmissionPacingUnlimited() + { + // When maxQueryAdmissionsPerSecond is Integer.MAX_VALUE (default), admission should always succeed + QueryManagerConfig config = new 
QueryManagerConfig(); + InternalResourceGroupManager> manager = new InternalResourceGroupManager<>( + (poolId, listener) -> {}, + config, + new NodeInfo("test"), + new MBeanExporter(new TestingMBeanServer()), + () -> null, + new ServerConfig(), + new InMemoryNodeManager(), + new ClusterResourceChecker(new CpuMemoryOverloadPolicy(new ClusterOverloadConfig()), new ClusterOverloadConfig(), new InMemoryNodeManager())); + + // Multiple consecutive calls should all succeed + assertTrue(manager.tryAcquireAdmissionSlot()); + assertTrue(manager.tryAcquireAdmissionSlot()); + assertTrue(manager.tryAcquireAdmissionSlot()); + } + + // Tests that admission respects 1 query/second rate limit + @Test + public void testAdmissionPacingOnePerSecond() + throws InterruptedException + { + // When maxQueryAdmissionsPerSecond is 1, verify admission succeeds after waiting + QueryManagerConfig config = new QueryManagerConfig().setMaxQueryAdmissionsPerSecond(1); + InternalResourceGroupManager> manager = new InternalResourceGroupManager<>( + (poolId, listener) -> {}, + config, + new NodeInfo("test"), + new MBeanExporter(new TestingMBeanServer()), + () -> null, + new ServerConfig(), + new InMemoryNodeManager(), + new ClusterResourceChecker(new CpuMemoryOverloadPolicy(new ClusterOverloadConfig()), new ClusterOverloadConfig(), new InMemoryNodeManager())); + + // First admission should succeed + assertTrue(manager.tryAcquireAdmissionSlot()); + + // Wait for 1 second (required interval) and verify next admission succeeds + Thread.sleep(1100); + assertTrue(manager.tryAcquireAdmissionSlot()); + } + + // Tests that admission respects 10 queries/second rate limit + @Test + public void testAdmissionPacingMultiplePerSecond() + throws InterruptedException + { + // When maxQueryAdmissionsPerSecond is 10, verify admission succeeds after waiting appropriate interval + QueryManagerConfig config = new QueryManagerConfig().setMaxQueryAdmissionsPerSecond(10); + InternalResourceGroupManager> manager = new 
InternalResourceGroupManager<>( + (poolId, listener) -> {}, + config, + new NodeInfo("test"), + new MBeanExporter(new TestingMBeanServer()), + () -> null, + new ServerConfig(), + new InMemoryNodeManager(), + new ClusterResourceChecker(new CpuMemoryOverloadPolicy(new ClusterOverloadConfig()), new ClusterOverloadConfig(), new InMemoryNodeManager())); + + // First admission should succeed + assertTrue(manager.tryAcquireAdmissionSlot()); + + // Wait for 150ms (more than the 100ms interval required for 10 queries/sec) and verify next admission succeeds + Thread.sleep(150); + assertTrue(manager.tryAcquireAdmissionSlot()); + } + + // Tests that pacing is bypassed when running queries are below threshold + @Test + public void testAdmissionPacingBypassedBelowRunningQueryThreshold() + throws Exception + { + // Configure pacing with a threshold of 5 running queries + // When running queries are below threshold, pacing should be bypassed + QueryManagerConfig config = new QueryManagerConfig() + .setMaxQueryAdmissionsPerSecond(1) // Very slow pacing: 1 per second + .setMinRunningQueriesForPacing(5); // Threshold of 5 running queries + + InternalResourceGroupManager> manager = new InternalResourceGroupManager<>( + (poolId, listener) -> {}, + config, + new NodeInfo("test"), + new MBeanExporter(new TestingMBeanServer()), + () -> null, + new ServerConfig(), + new InMemoryNodeManager(), + new ClusterResourceChecker(new CpuMemoryOverloadPolicy(new ClusterOverloadConfig()), new ClusterOverloadConfig(), new InMemoryNodeManager())); + + manager.loadConfigurationManager(); + + // Create a resource group with some running queries (but below threshold) + MockManagedQueryExecution query1 = new MockManagedQueryExecution(0); + MockManagedQueryExecution query2 = new MockManagedQueryExecution(0); + manager.submit(query1, new SelectionContext<>(new ResourceGroupId("global"), ImmutableMap.of()), directExecutor()); + manager.submit(query2, new SelectionContext<>(new ResourceGroupId("global"), 
ImmutableMap.of()), directExecutor()); + + // With only 2 running queries (below threshold of 5), pacing should be bypassed + // Multiple rapid admissions should all succeed without waiting + assertTrue(manager.tryAcquireAdmissionSlot()); + assertTrue(manager.tryAcquireAdmissionSlot()); + assertTrue(manager.tryAcquireAdmissionSlot()); + + // Verify metrics are NOT tracked when pacing is bypassed + assertEquals(manager.getTotalAdmissionAttempts(), 0); + assertEquals(manager.getTotalAdmissionsGranted(), 0); + assertEquals(manager.getTotalAdmissionsDenied(), 0); + } + + // Tests that pacing is enforced when running queries exceed threshold + @Test + public void testAdmissionPacingAppliedAboveRunningQueryThreshold() + throws Exception + { + // Configure pacing with a threshold of 2 running queries + QueryManagerConfig config = new QueryManagerConfig() + .setMaxQueryAdmissionsPerSecond(1) // 1 per second + .setMinRunningQueriesForPacing(2); // Threshold of 2 running queries + + InternalResourceGroupManager> manager = new InternalResourceGroupManager<>( + (poolId, listener) -> {}, + config, + new NodeInfo("test"), + new MBeanExporter(new TestingMBeanServer()), + () -> null, + new ServerConfig(), + new InMemoryNodeManager(), + new ClusterResourceChecker(new CpuMemoryOverloadPolicy(new ClusterOverloadConfig()), new ClusterOverloadConfig(), new InMemoryNodeManager())); + + manager.loadConfigurationManager(); + + // Create resource groups with enough running queries to exceed threshold + MockManagedQueryExecution query1 = new MockManagedQueryExecution(0); + MockManagedQueryExecution query2 = new MockManagedQueryExecution(0); + MockManagedQueryExecution query3 = new MockManagedQueryExecution(0); + manager.submit(query1, new SelectionContext<>(new ResourceGroupId("global"), ImmutableMap.of()), directExecutor()); + manager.submit(query2, new SelectionContext<>(new ResourceGroupId("global"), ImmutableMap.of()), directExecutor()); + manager.submit(query3, new 
SelectionContext<>(new ResourceGroupId("global"), ImmutableMap.of()), directExecutor()); + + // Wait for rate limit window to expire after query submissions (which internally call tryAcquireAdmissionSlot) + Thread.sleep(1100); + + // With 3 running queries (above threshold of 2), pacing should be applied + // First admission should succeed + assertTrue(manager.tryAcquireAdmissionSlot()); + + // Immediate second attempt should be denied (need to wait 1 second) + assertFalse(manager.tryAcquireAdmissionSlot()); + + // Verify metrics ARE tracked when pacing is applied + // Note: Query 3's submission also triggered pacing (running queries = 2 at submission time), + // so we have 3 total attempts: 1 from query3 submission + 2 from explicit calls + assertEquals(manager.getTotalAdmissionAttempts(), 3); + assertEquals(manager.getTotalAdmissionsGranted(), 2); + assertEquals(manager.getTotalAdmissionsDenied(), 1); + } + + // Tests that pacing turns off when running queries drop below the threshold + @Test + public void testAdmissionPacingTurnsOffWhenRunningQueriesDropBelowThreshold() + throws Exception + { + // Configure pacing with a threshold of 2 running queries and a slow rate + QueryManagerConfig config = new QueryManagerConfig() + .setMaxQueryAdmissionsPerSecond(1) // 1 per second, so pacing should be visible + .setMinRunningQueriesForPacing(2); // Threshold of 2 running queries + + InternalResourceGroupManager> manager = new InternalResourceGroupManager<>( + (poolId, listener) -> {}, + config, + new NodeInfo("test"), + new MBeanExporter(new TestingMBeanServer()), + () -> null, + new ServerConfig(), + new InMemoryNodeManager(), + new ClusterResourceChecker(new CpuMemoryOverloadPolicy(new ClusterOverloadConfig()), new ClusterOverloadConfig(), new InMemoryNodeManager())); + + // Simulate being above the threshold by incrementing running queries counter + manager.incrementRunningQueries(); + manager.incrementRunningQueries(); + + // With 2 running queries (at threshold), 
pacing should be applied + // First admission should succeed and set the lastAdmittedQueryNanos timestamp + assertTrue(manager.tryAcquireAdmissionSlot()); + + // Immediate second attempt should be denied (need to wait 1 second) + assertFalse(manager.tryAcquireAdmissionSlot()); + + // Verify metrics are tracked when pacing is applied + assertEquals(manager.getTotalAdmissionAttempts(), 2); + assertEquals(manager.getTotalAdmissionsGranted(), 1); + assertEquals(manager.getTotalAdmissionsDenied(), 1); + + // Now simulate queries finishing so that we drop below the threshold + manager.decrementRunningQueries(); + manager.decrementRunningQueries(); + + // With 0 running queries (below threshold of 2), pacing should be bypassed + // Multiple rapid admissions should all succeed without waiting + assertTrue(manager.tryAcquireAdmissionSlot()); + assertTrue(manager.tryAcquireAdmissionSlot()); + assertTrue(manager.tryAcquireAdmissionSlot()); + + // Verify metrics did NOT increase when pacing was bypassed + // (should still be the same as before the decrement) + assertEquals(manager.getTotalAdmissionAttempts(), 2); + assertEquals(manager.getTotalAdmissionsGranted(), 1); + assertEquals(manager.getTotalAdmissionsDenied(), 1); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java b/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java index 1503f97b74723..43e9ba467e92b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/resourceGroups/TestResourceGroups.java @@ -13,8 +13,13 @@ */ package com.facebook.presto.execution.resourceGroups; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.presto.execution.ClusterOverloadConfig; import com.facebook.presto.execution.MockManagedQueryExecution; 
import com.facebook.presto.execution.resourceGroups.InternalResourceGroup.RootInternalResourceGroup; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterOverloadPolicy; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterResourceChecker; import com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.InternalNodeManager; @@ -22,8 +27,6 @@ import com.facebook.presto.server.ResourceGroupInfo; import com.facebook.presto.spi.ConnectorId; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.apache.commons.math3.distribution.BinomialDistribution; import org.testng.annotations.Test; @@ -43,6 +46,9 @@ import static com.facebook.airlift.testing.Assertions.assertBetweenInclusive; import static com.facebook.airlift.testing.Assertions.assertGreaterThan; import static com.facebook.airlift.testing.Assertions.assertLessThan; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.execution.QueryState.FAILED; import static com.facebook.presto.execution.QueryState.QUEUED; import static com.facebook.presto.execution.QueryState.RUNNING; @@ -53,9 +59,6 @@ import static com.facebook.presto.spi.resourceGroups.SchedulingPolicy.WEIGHTED; import static com.facebook.presto.spi.resourceGroups.SchedulingPolicy.WEIGHTED_FAIR; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Collections.reverse; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; @@ -69,7 +72,7 @@ 
public class TestResourceGroups @Test(timeOut = 10_000) public void testQueueFull() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(1); root.setHardConcurrencyLimit(1); @@ -91,7 +94,7 @@ public void testQueueFull() @Test(timeOut = 10_000) public void testFairEligibility() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(4); root.setHardConcurrencyLimit(1); @@ -151,7 +154,7 @@ public void testFairEligibility() @Test public void testSetSchedulingPolicy() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(4); root.setHardConcurrencyLimit(1); @@ -197,7 +200,7 @@ public void testSetSchedulingPolicy() @Test(timeOut = 10_000) public void testFairQueuing() { - 
RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(4); root.setHardConcurrencyLimit(1); @@ -243,7 +246,7 @@ public void testFairQueuing() @Test(timeOut = 10_000) public void testMemoryLimit() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, BYTE)); root.setMaxQueuedQueries(4); root.setHardConcurrencyLimit(3); @@ -271,7 +274,7 @@ public void testMemoryLimit() @Test public void testSubgroupMemoryLimit() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(10, BYTE)); root.setMaxQueuedQueries(4); root.setHardConcurrencyLimit(3); @@ -304,7 +307,7 @@ public void testSubgroupMemoryLimit() @Test(timeOut = 10_000) public void testSoftCpuLimit() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, 
directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, BYTE)); root.setSoftCpuLimit(new Duration(1, SECONDS)); root.setHardCpuLimit(new Duration(2, SECONDS)); @@ -341,7 +344,7 @@ public void testSoftCpuLimit() @Test(timeOut = 10_000) public void testPerWorkerQueryLimit() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setWorkersPerQueryLimit(5); root.setMaxQueuedQueries(2); root.setHardConcurrencyLimit(2); @@ -374,7 +377,7 @@ public void testPerWorkerQueryLimit() @Test(timeOut = 10_000) public void testPerWorkerQueryLimitMultipleGroups() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setWorkersPerQueryLimit(5); root.setMaxQueuedQueries(5); root.setHardConcurrencyLimit(2); @@ -417,7 +420,7 @@ public void testPerWorkerQueryLimitMultipleGroups() @Test(timeOut = 10_000) public void testHardCpuLimit() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), 
rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, BYTE)); root.setHardCpuLimit(new Duration(1, SECONDS)); root.setCpuQuotaGenerationMillisPerSecond(2000); @@ -444,7 +447,7 @@ public void testHardCpuLimit() @Test(timeOut = 10_000) public void testPriorityScheduling() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(100); // Start with zero capacity, so that nothing starts running until we've added all the queries @@ -494,7 +497,7 @@ public void testPriorityScheduling() @Test(timeOut = 20_000) public void testWeightedScheduling() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(4); // Start with zero capacity, so that nothing starts running until we've added all the queries @@ -543,7 +546,7 @@ public void testWeightedScheduling() @Test(timeOut = 30_000) public void testWeightedFairScheduling() { - RootInternalResourceGroup root = new 
RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(50); // Start with zero capacity, so that nothing starts running until we've added all the queries @@ -586,7 +589,7 @@ public void testWeightedFairScheduling() @Test(timeOut = 10_000) public void testWeightedFairSchedulingEqualWeights() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(50); // Start with zero capacity, so that nothing starts running until we've added all the queries @@ -645,7 +648,7 @@ public void testWeightedFairSchedulingEqualWeights() @Test(timeOut = 20_000) public void testWeightedFairSchedulingNoStarvation() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(50); // Start with zero capacity, so that nothing starts running until we've 
added all the queries @@ -686,7 +689,7 @@ public void testWeightedFairSchedulingNoStarvation() @Test public void testGetInfo() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(40); // Start with zero capacity, so that nothing starts running until we've added all the queries @@ -776,7 +779,7 @@ public void testGetInfo() @Test public void testGetResourceGroupStateInfo() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, GIGABYTE)); root.setMaxQueuedQueries(40); root.setHardConcurrencyLimit(10); @@ -844,7 +847,7 @@ public void testGetResourceGroupStateInfo() @Test public void testGetStaticResourceGroupInfo() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, GIGABYTE)); root.setMaxQueuedQueries(100); root.setHardConcurrencyLimit(10); @@ 
-921,7 +924,7 @@ private Optional getResourceGroupInfoForId(InternalResourceGr @Test public void testGetBlockedQueuedQueries() { - RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager()); + RootInternalResourceGroup root = new RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, createNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(40); // Start with zero capacity, so that nothing starts running until we've added all the queries @@ -1072,4 +1075,27 @@ private InternalNodeManager createNodeManager() false)); return internalNodeManager; } + + private ClusterResourceChecker createClusterResourceChecker() + { + ClusterOverloadPolicy mockPolicy = new ClusterOverloadPolicy() + { + @Override + public boolean isClusterOverloaded(InternalNodeManager nodeManager) + { + return false; + } + + @Override + public String getName() + { + return "test-policy"; + } + }; + + ClusterOverloadConfig config = new ClusterOverloadConfig() + .setClusterOverloadThrottlingEnabled(false); + + return new ClusterResourceChecker(mockPolicy, config, createNodeManager()); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestAdaptivePhasedExecutionPolicy.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestAdaptivePhasedExecutionPolicy.java index 8f50e98cd8d02..3a9ca81197856 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestAdaptivePhasedExecutionPolicy.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestAdaptivePhasedExecutionPolicy.java @@ -159,6 +159,7 @@ private static PlanFragment createPlanFragment(PlanFragmentId fragmentId, PlanNo SOURCE_DISTRIBUTION, 
ImmutableList.of(remoteSourcePlanNode.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), remoteSourcePlanNode.getOutputVariables()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java index 8c932566c4453..a42ed6eb014f8 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestPhasedExecutionSchedule.java @@ -280,6 +280,7 @@ private static PlanFragment createFragment(PlanNode planNode) SOURCE_DISTRIBUTION, ImmutableList.of(planNode.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputVariables()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java index 5ea371568b9c3..4fdb5325e2011 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/TestSourcePartitionedScheduler.java @@ -336,7 +336,8 @@ public void testNoNodes() TABLE_SCAN_NODE_ID, new ConnectorAwareSplitSource(CONNECTOR_ID, TestingTransactionHandle.create(), createFixedSplitSource(20, TestingSplit::createRemoteSplit)), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(TestingSession.testSessionBuilder().build(), CONNECTOR_ID), stage::getAllTasks), - 2); + 2, 
+ new CTEMaterializationTracker()); scheduler.schedule(); fail("expected PrestoException"); @@ -474,7 +475,7 @@ private static StageScheduler getSourcePartitionedScheduler( new SimpleTtlNodeSelectorConfig()); SplitSource splitSource = new ConnectorAwareSplitSource(CONNECTOR_ID, TestingTransactionHandle.create(), connectorSplitSource); SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(TestingSession.testSessionBuilder().build(), splitSource.getConnectorId()), stage::getAllTasks); - return newSourcePartitionedSchedulerAsStageScheduler(stage, TABLE_SCAN_NODE_ID, splitSource, placementPolicy, splitBatchSize); + return newSourcePartitionedSchedulerAsStageScheduler(stage, TABLE_SCAN_NODE_ID, splitSource, placementPolicy, splitBatchSize, new CTEMaterializationTracker()); } private static SubPlan createPlan() @@ -515,6 +516,7 @@ private static SubPlan createPlan() SOURCE_DISTRIBUTION, ImmutableList.of(TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/clusterOverload/TestClusterResourceChecker.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/clusterOverload/TestClusterResourceChecker.java new file mode 100644 index 0000000000000..1a5d9c4af8371 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/clusterOverload/TestClusterResourceChecker.java @@ -0,0 +1,305 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution.scheduler.clusterOverload; + +import com.facebook.presto.execution.ClusterOverloadConfig; +import com.facebook.presto.metadata.AllNodes; +import com.facebook.presto.metadata.InternalNode; +import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.NodeLoadMetrics; +import com.facebook.presto.spi.NodeState; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Collections; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestClusterResourceChecker +{ + private static final int CACHE_TTL_SECS = 1; + private static final int SLEEP_BUFFER_MILLIS = 200; + + private ClusterOverloadConfig config; + private TestingClusterOverloadPolicy clusterOverloadPolicy; + private ClusterResourceChecker clusterResourceChecker; + private TestingInternalNodeManager nodeManager; + + @BeforeMethod + public void setUp() + { + config = new ClusterOverloadConfig() + .setOverloadCheckCacheTtlInSecs(CACHE_TTL_SECS) + .setClusterOverloadThrottlingEnabled(true); + + clusterOverloadPolicy = new TestingClusterOverloadPolicy(); + nodeManager = new TestingInternalNodeManager(); + + clusterResourceChecker = new 
ClusterResourceChecker(clusterOverloadPolicy, config, nodeManager); + } + + public void testInitialState() + { + assertFalse(clusterResourceChecker.isClusterOverloaded()); + assertEquals(clusterResourceChecker.getOverloadDetectionCount().getTotalCount(), 0); + assertEquals(clusterResourceChecker.getOverloadDurationMillis(), 0); + assertTrue(clusterResourceChecker.isClusterOverloadThrottlingEnabled()); + } + + @Test + public void testIsClusterCurrentlyOverloaded() + { + // Start the periodic task + clusterResourceChecker.start(); + + // Initially not overloaded + clusterOverloadPolicy.setOverloaded(false); + assertFalse(clusterResourceChecker.isClusterCurrentlyOverloaded()); + + // Wait for periodic check to update state + clusterOverloadPolicy.setOverloaded(true); + sleep((CACHE_TTL_SECS * 1000) + SLEEP_BUFFER_MILLIS); + assertTrue(clusterResourceChecker.isClusterCurrentlyOverloaded()); + assertTrue(clusterResourceChecker.isClusterOverloaded()); + assertEquals(clusterResourceChecker.getOverloadDetectionCount().getTotalCount(), 1); + + // Set back to not overloaded + clusterOverloadPolicy.setOverloaded(false); + sleep((CACHE_TTL_SECS * 1000) + SLEEP_BUFFER_MILLIS); + assertFalse(clusterResourceChecker.isClusterCurrentlyOverloaded()); + assertFalse(clusterResourceChecker.isClusterOverloaded()); + + // Stop the periodic task + clusterResourceChecker.stop(); + } + + @Test + public void testOverloadDurationMetric() + { + // Start the periodic task + clusterResourceChecker.start(); + + // Set to overloaded and wait for periodic check + clusterOverloadPolicy.setOverloaded(true); + sleep((CACHE_TTL_SECS * 1000) + SLEEP_BUFFER_MILLIS); + assertTrue(clusterResourceChecker.isClusterCurrentlyOverloaded()); + assertTrue(clusterResourceChecker.isClusterOverloaded()); + + // Duration should be greater than 0 + sleep(100); + assertTrue(clusterResourceChecker.getOverloadDurationMillis() > 0); + + // Set back to not overloaded + clusterOverloadPolicy.setOverloaded(false); + 
sleep((CACHE_TTL_SECS * 1000) + SLEEP_BUFFER_MILLIS); + assertFalse(clusterResourceChecker.isClusterCurrentlyOverloaded()); + + // Duration should be 0 again + assertEquals(clusterResourceChecker.getOverloadDurationMillis(), 0); + + // Stop the periodic task + clusterResourceChecker.stop(); + } + + @Test + public void testMultipleOverloadTransitions() + { + // Start the periodic task + clusterResourceChecker.start(); + + // First transition to overloaded + clusterOverloadPolicy.setOverloaded(true); + sleep((CACHE_TTL_SECS * 1000) + SLEEP_BUFFER_MILLIS); + assertTrue(clusterResourceChecker.isClusterCurrentlyOverloaded()); + assertEquals(clusterResourceChecker.getOverloadDetectionCount().getTotalCount(), 1); + + // Wait and transition back to not overloaded + clusterOverloadPolicy.setOverloaded(false); + sleep((CACHE_TTL_SECS * 1000) + SLEEP_BUFFER_MILLIS); + assertFalse(clusterResourceChecker.isClusterCurrentlyOverloaded()); + + // Second transition to overloaded + clusterOverloadPolicy.setOverloaded(true); + sleep((CACHE_TTL_SECS * 1000) + SLEEP_BUFFER_MILLIS); + assertTrue(clusterResourceChecker.isClusterCurrentlyOverloaded()); + assertEquals(clusterResourceChecker.getOverloadDetectionCount().getTotalCount(), 2); + + // Stop the periodic task + clusterResourceChecker.stop(); + } + + @Test + public void testClusterOverloadThrottlingEnabled() + { + // Default is enabled (set in setUp) + assertTrue(clusterResourceChecker.isClusterOverloadThrottlingEnabled()); + + // Create a new config with throttling disabled + ClusterOverloadConfig disabledConfig = new ClusterOverloadConfig() + .setOverloadCheckCacheTtlInSecs(CACHE_TTL_SECS) + .setClusterOverloadThrottlingEnabled(false); + + // Create a new checker with throttling disabled + ClusterResourceChecker disabledChecker = new ClusterResourceChecker(clusterOverloadPolicy, disabledConfig, nodeManager); + assertFalse(disabledChecker.isClusterOverloadThrottlingEnabled()); + + // Even when cluster is overloaded, 
isClusterCurrentlyOverloaded should return false if throttling is disabled + clusterOverloadPolicy.setOverloaded(true); + assertFalse(disabledChecker.isClusterCurrentlyOverloaded()); + } + + private void sleep(long millis) + { + try { + Thread.sleep(millis); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + private static class TestingClusterOverloadPolicy + implements ClusterOverloadPolicy + { + private boolean overloaded; + private final AtomicInteger checkCount = new AtomicInteger(); + + @Override + public boolean isClusterOverloaded(InternalNodeManager nodeManager) + { + checkCount.incrementAndGet(); + return overloaded; + } + + @Override + public String getName() + { + return "test-policy"; + } + + public void setOverloaded(boolean overloaded) + { + this.overloaded = overloaded; + } + + public int getCheckCount() + { + return checkCount.get(); + } + } + + private static class TestingInternalNodeManager + implements InternalNodeManager + { + @Override + public Set getNodes(NodeState state) + { + return Collections.emptySet(); + } + + @Override + public Set getActiveConnectorNodes(ConnectorId connectorId) + { + return Collections.emptySet(); + } + + @Override + public Set getAllConnectorNodes(ConnectorId connectorId) + { + return Collections.emptySet(); + } + + @Override + public InternalNode getCurrentNode() + { + return null; + } + + @Override + public Set getCoordinators() + { + return Collections.emptySet(); + } + + @Override + public Set getShuttingDownCoordinator() + { + return Collections.emptySet(); + } + + @Override + public Set getResourceManagers() + { + return Collections.emptySet(); + } + + @Override + public Set getCatalogServers() + { + return Collections.emptySet(); + } + + @Override + public Set getCoordinatorSidecars() + { + return Collections.emptySet(); + } + + @Override + public AllNodes getAllNodes() + { + return new AllNodes( + Collections.emptySet(), + 
Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet()); + } + + @Override + public void refreshNodes() + { + // No-op for testing + } + + @Override + public void addNodeChangeListener(Consumer listener) + { + // No-op for testing + } + + @Override + public void removeNodeChangeListener(Consumer listener) + { + // No-op for testing + } + + @Override + public Optional getNodeLoadMetrics(String nodeIdentifier) + { + return Optional.empty(); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/clusterOverload/TestCpuMemoryOverloadPolicy.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/clusterOverload/TestCpuMemoryOverloadPolicy.java new file mode 100644 index 0000000000000..f3efbad3f549f --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/clusterOverload/TestCpuMemoryOverloadPolicy.java @@ -0,0 +1,254 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.execution.scheduler.clusterOverload; + +import com.facebook.presto.client.NodeVersion; +import com.facebook.presto.execution.ClusterOverloadConfig; +import com.facebook.presto.metadata.AllNodes; +import com.facebook.presto.metadata.InternalNode; +import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.NodeLoadMetrics; +import com.facebook.presto.spi.NodeState; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.function.Consumer; + +import static com.facebook.presto.execution.ClusterOverloadConfig.OVERLOAD_POLICY_CNT_BASED; +import static com.facebook.presto.execution.ClusterOverloadConfig.OVERLOAD_POLICY_PCT_BASED; +import static com.facebook.presto.metadata.InternalNode.NodeStatus.ALIVE; +import static com.facebook.presto.spi.NodePoolType.DEFAULT; +import static com.facebook.presto.spi.NodeState.ACTIVE; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestCpuMemoryOverloadPolicy +{ + private static final NodeVersion TEST_VERSION = new NodeVersion("test"); + private static final URI TEST_URI = URI.create("http://test.example.com"); + + @Test + public void testIsClusterOverloadedCountBasedNoOverload() + { + ClusterOverloadConfig config = new ClusterOverloadConfig() + .setAllowedOverloadWorkersCnt(1) + .setOverloadPolicyType(OVERLOAD_POLICY_CNT_BASED); + CpuMemoryOverloadPolicy policy = new CpuMemoryOverloadPolicy(config); + + // One node is overloaded, but allowed count is 1, so not overloaded + InternalNodeManager nodeManager = createNodeManager(ImmutableSet.of(createNode("node1", true, false), createNode("node2", 
false, false), createNode("node3", false, false))); + assertFalse(policy.isClusterOverloaded(nodeManager)); + } + + @Test + public void testIsClusterOverloadedCountBasedOverload() + { + ClusterOverloadConfig config = new ClusterOverloadConfig() + .setAllowedOverloadWorkersCnt(1) + .setOverloadPolicyType(OVERLOAD_POLICY_CNT_BASED); + CpuMemoryOverloadPolicy policy = new CpuMemoryOverloadPolicy(config); + + // Two nodes are overloaded, but allowed count is 1, so overloaded + InternalNodeManager nodeManager = createNodeManager(ImmutableSet.of(createNode("node1", true, false), createNode("node2", false, true), createNode("node3", false, false))); + assertTrue(policy.isClusterOverloaded(nodeManager)); + } + + @Test + public void testIsClusterOverloadedPctBasedNoOverload() + { + ClusterOverloadConfig config = new ClusterOverloadConfig().setAllowedOverloadWorkersPct(0.4).setOverloadPolicyType(OVERLOAD_POLICY_PCT_BASED); + CpuMemoryOverloadPolicy policy = new CpuMemoryOverloadPolicy(config); + + // 1 out of 3 nodes (33%) are overloaded, allowed is 40%, so not overloaded + InternalNodeManager nodeManager = createNodeManager(ImmutableSet.of(createNode("node1", true, false), createNode("node2", false, false), createNode("node3", false, false))); + assertFalse(policy.isClusterOverloaded(nodeManager)); + } + + @Test + public void testIsClusterOverloadedPctBasedOverload() + { + ClusterOverloadConfig config = new ClusterOverloadConfig().setAllowedOverloadWorkersPct(0.3).setOverloadPolicyType(OVERLOAD_POLICY_PCT_BASED); + CpuMemoryOverloadPolicy policy = new CpuMemoryOverloadPolicy(config); + + // 2 out of 5 nodes (40%) are overloaded, allowed is 30%, so overloaded + InternalNodeManager nodeManager = createNodeManager(ImmutableSet.of(createNode("node1", true, false), createNode("node2", false, true), createNode("node3", false, false), createNode("node4", false, false), createNode("node5", false, false))); + assertTrue(policy.isClusterOverloaded(nodeManager)); + } + + @Test + 
public void testIsClusterOverloadedBothMetricsOverloaded() + { + ClusterOverloadConfig config = new ClusterOverloadConfig() + .setAllowedOverloadWorkersCnt(0) + .setOverloadPolicyType(OVERLOAD_POLICY_CNT_BASED); + CpuMemoryOverloadPolicy policy = new CpuMemoryOverloadPolicy(config); + + // Node has both CPU and memory overloaded, should only count as one overloaded node + InternalNodeManager nodeManager = createNodeManager(ImmutableSet.of(createNode("node1", true, true))); + assertTrue(policy.isClusterOverloaded(nodeManager)); + } + + @Test + public void testIsClusterOverloadedNoNodes() + { + ClusterOverloadConfig config = new ClusterOverloadConfig() + .setAllowedOverloadWorkersCnt(0) + .setOverloadPolicyType(OVERLOAD_POLICY_CNT_BASED); + CpuMemoryOverloadPolicy policy = new CpuMemoryOverloadPolicy(config); + + // No nodes, should not be overloaded + InternalNodeManager nodeManager = createNodeManager(ImmutableSet.of()); + assertFalse(policy.isClusterOverloaded(nodeManager)); + } + + @Test + public void testGetNameCountBased() + { + ClusterOverloadConfig config = new ClusterOverloadConfig() + .setOverloadPolicyType(OVERLOAD_POLICY_CNT_BASED); + CpuMemoryOverloadPolicy policy = new CpuMemoryOverloadPolicy(config); + + assertEquals(policy.getName(), "cpu-memory-overload-cnt"); + } + + @Test + public void testGetNamePctBased() + { + ClusterOverloadConfig config = new ClusterOverloadConfig() + .setOverloadPolicyType(OVERLOAD_POLICY_PCT_BASED); + CpuMemoryOverloadPolicy policy = new CpuMemoryOverloadPolicy(config); + assertEquals(policy.getName(), "cpu-memory-overload-pct"); + } + + // Store metrics separately since they're no longer part of InternalNode + private static final Map NODE_METRICS = new HashMap<>(); + + private static InternalNode createNode(String nodeId, boolean cpuOverload, boolean memoryOverload) + { + // Store metrics in the map for later retrieval + NODE_METRICS.put(nodeId, new NodeLoadMetrics(0.0, 0.0, 0, cpuOverload, memoryOverload)); + return new 
InternalNode( + nodeId, + TEST_URI, + OptionalInt.empty(), + TEST_VERSION, + false, + false, + false, + false, + ALIVE, + OptionalInt.empty(), + DEFAULT); + } + + private static InternalNodeManager createNodeManager(Set nodes) + { + return new InternalNodeManager() + { + @Override + public Set getNodes(NodeState state) + { + if (state == ACTIVE) { + return nodes; + } + return ImmutableSet.of(); + } + + @Override + public Set getActiveConnectorNodes(ConnectorId connectorId) + { + return ImmutableSet.of(); + } + + @Override + public Set getAllConnectorNodes(ConnectorId connectorId) + { + return Collections.emptySet(); + } + + @Override + public InternalNode getCurrentNode() + { + return null; + } + + @Override + public Set getCoordinators() + { + return ImmutableSet.of(); + } + + @Override + public Set getShuttingDownCoordinator() + { + return ImmutableSet.of(); + } + + @Override + public Set getResourceManagers() + { + return ImmutableSet.of(); + } + + @Override + public Set getCatalogServers() + { + return ImmutableSet.of(); + } + + @Override + public Set getCoordinatorSidecars() + { + return ImmutableSet.of(); + } + + @Override + public AllNodes getAllNodes() + { + return new AllNodes(nodes, ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of(), ImmutableSet.of()); + } + + @Override + public void refreshNodes() + { + } + + @Override + public void addNodeChangeListener(Consumer listener) + { + } + + @Override + public void removeNodeChangeListener(Consumer listener) + { + } + + @Override + public Optional getNodeLoadMetrics(String nodeIdentifier) + { + NodeLoadMetrics metrics = NODE_METRICS.get(nodeIdentifier); + return metrics != null ? 
Optional.of(metrics) : Optional.empty(); + } + }; + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/nodeselection/TestSimpleTtlNodeSelector.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/nodeselection/TestSimpleTtlNodeSelector.java index eaf27ad3c87c3..6f0cded44be3e 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/nodeselection/TestSimpleTtlNodeSelector.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/nodeselection/TestSimpleTtlNodeSelector.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.execution.scheduler.nodeselection; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.scheduler.nodeSelection.SimpleTtlNodeSelector; import com.facebook.presto.spi.ttl.ConfidenceBasedTtlInfo; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.concurrent.TimeUnit; diff --git a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/nodeselection/TestSimpleTtlNodeSelectorConfig.java b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/nodeselection/TestSimpleTtlNodeSelectorConfig.java index 43a558799528f..1f861adad98a2 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/nodeselection/TestSimpleTtlNodeSelectorConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/execution/scheduler/nodeselection/TestSimpleTtlNodeSelectorConfig.java @@ -14,9 +14,9 @@ package com.facebook.presto.execution.scheduler.nodeselection; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.scheduler.nodeSelection.SimpleTtlNodeSelectorConfig; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/geospatial/TestGeoFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/geospatial/TestGeoFunctions.java index b651f9188161b..28a20d1be8a63 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/geospatial/TestGeoFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/geospatial/TestGeoFunctions.java @@ -457,12 +457,12 @@ public void testGeometryInvalidReason() assertInvalidReason("POLYGON ((0 0, 1 1, 0 1, 1 0, 0 0))", "Error constructing Polygon: shell is empty but holes are not"); assertInvalidReason("POLYGON ((0 0, 0 1, 0 1, 1 1, 1 0, 0 0), (2 2, 2 3, 3 3, 3 2, 2 2))", "Hole lies outside shell"); assertInvalidReason("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0), (2 2, 2 3, 3 3, 3 2, 2 2))", "Hole lies outside shell"); - assertInvalidReason("POLYGON ((0 0, 0 1, 2 1, 1 1, 1 0, 0 0))", "Ring Self-intersection"); + assertInvalidReason("POLYGON ((0 0, 0 1, 2 1, 1 1, 1 0, 0 0))", "Self-intersection"); assertInvalidReason("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0), (0 1, 1 1, 0.5 0.5, 0 1))", "Self-intersection"); assertInvalidReason("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0), (0 0, 0.5 0.7, 1 1, 0.5 0.4, 0 0))", "Interior is disconnected"); assertInvalidReason("POLYGON ((0 0, -1 0.5, 0 1, 1 1, 1 0, 0 1, 0 0))", "Ring Self-intersection"); assertInvalidReason("MULTIPOLYGON (((0 0, 0 1, 1 1, 1 0, 0 0)), ((0.5 0.5, 0.5 2, 2 2, 2 0.5, 0.5 0.5)))", "Self-intersection"); - assertInvalidReason("GEOMETRYCOLLECTION (POINT (1 2), POLYGON ((0 0, 0 1, 2 1, 1 1, 1 0, 0 0)))", "Ring Self-intersection"); + assertInvalidReason("GEOMETRYCOLLECTION (POINT (1 2), POLYGON ((0 0, 0 1, 2 1, 1 1, 1 0, 0 0)))", "Self-intersection"); // non-simple geometries assertInvalidReason("MULTIPOINT (1 2, 2 4, 3 6, 1 2)", "[MultiPoint] Repeated point: (1.0 2.0)"); diff --git a/presto-main-base/src/test/java/com/facebook/presto/geospatial/TestSpatialJoinOperator.java 
b/presto-main-base/src/test/java/com/facebook/presto/geospatial/TestSpatialJoinOperator.java index bdde2214538af..c29e9dd104c69 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/geospatial/TestSpatialJoinOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/geospatial/TestSpatialJoinOperator.java @@ -32,7 +32,7 @@ import com.facebook.presto.operator.ValuesOperator; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.spi.plan.SpatialJoinNode.Type; +import com.facebook.presto.spi.plan.SpatialJoinNode.SpatialJoinType; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.TestingTaskContext; @@ -64,8 +64,8 @@ import static com.facebook.presto.geospatial.GeoFunctions.stPoint; import static com.facebook.presto.geospatial.type.GeometryType.GEOMETRY; import static com.facebook.presto.operator.OperatorAssertion.assertOperatorEqualsIgnoreOrder; -import static com.facebook.presto.spi.plan.SpatialJoinNode.Type.INNER; -import static com.facebook.presto.spi.plan.SpatialJoinNode.Type.LEFT; +import static com.facebook.presto.spi.plan.SpatialJoinNode.SpatialJoinType.INNER; +import static com.facebook.presto.spi.plan.SpatialJoinNode.SpatialJoinType.LEFT; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.SECONDS; @@ -227,7 +227,7 @@ public void testSpatialLeftJoin() assertSpatialJoin(taskContext, LEFT, buildPages, probePages, expected); } - private void assertSpatialJoin(TaskContext taskContext, Type joinType, RowPagesBuilder buildPages, RowPagesBuilder probePages, MaterializedResult expected) + private void assertSpatialJoin(TaskContext taskContext, SpatialJoinType joinType, RowPagesBuilder buildPages, RowPagesBuilder probePages, MaterializedResult 
expected) { DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); PagesSpatialIndexFactory pagesSpatialIndexFactory = buildIndex(driverContext, (build, probe, r) -> build.intersects(probe), Optional.empty(), Optional.empty(), buildPages); diff --git a/presto-main-base/src/test/java/com/facebook/presto/memory/LowMemoryKillerTestingUtils.java b/presto-main-base/src/test/java/com/facebook/presto/memory/LowMemoryKillerTestingUtils.java index e3f3650e89669..34fc05e034fa7 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/memory/LowMemoryKillerTestingUtils.java +++ b/presto-main-base/src/test/java/com/facebook/presto/memory/LowMemoryKillerTestingUtils.java @@ -14,6 +14,7 @@ package com.facebook.presto.memory; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.client.NodeVersion; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.spi.QueryId; @@ -21,16 +22,15 @@ import com.facebook.presto.spi.memory.MemoryPoolInfo; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import java.net.URI; import java.util.HashMap; import java.util.List; import java.util.Map; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.memory.LocalMemoryManager.GENERAL_POOL; import static com.facebook.presto.memory.LocalMemoryManager.RESERVED_POOL; -import static io.airlift.units.DataSize.Unit.BYTE; public class LowMemoryKillerTestingUtils { diff --git a/presto-main-base/src/test/java/com/facebook/presto/memory/TestLocalMemoryManager.java b/presto-main-base/src/test/java/com/facebook/presto/memory/TestLocalMemoryManager.java index 05ec026e2db5d..414ed2756c3b0 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/memory/TestLocalMemoryManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/memory/TestLocalMemoryManager.java @@ -13,10 +13,10 @@ */ 
package com.facebook.presto.memory; -import io.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize; import org.testng.annotations.Test; -import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; diff --git a/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryManagerConfig.java b/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryManagerConfig.java index 43313580b64b5..1738e38d58bdc 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryManagerConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryManagerConfig.java @@ -14,18 +14,18 @@ package com.facebook.presto.memory; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static com.facebook.presto.memory.MemoryManagerConfig.LowMemoryKillerPolicy.NONE; import static com.facebook.presto.memory.MemoryManagerConfig.LowMemoryKillerPolicy.TOTAL_RESERVATION_ON_BLOCKED_NODES; -import static io.airlift.units.DataSize.Unit.GIGABYTE; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryPools.java 
b/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryPools.java index 71ff26aa1a7e2..bfada8044be1a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryPools.java +++ b/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryPools.java @@ -14,6 +14,7 @@ package com.facebook.presto.memory; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.CompressionCodec; import com.facebook.presto.Session; import com.facebook.presto.common.Page; @@ -40,7 +41,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import org.testng.annotations.AfterMethod; import org.testng.annotations.Test; @@ -52,14 +52,14 @@ import java.util.function.Function; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SystemSessionProperties.REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN; import static com.facebook.presto.testing.LocalQueryRunner.queryRunnerWithInitialTransaction; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static com.google.common.base.Preconditions.checkState; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNull; diff --git a/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryTracking.java 
b/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryTracking.java index e635923781150..5f713c317ae25 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryTracking.java +++ b/presto-main-base/src/test/java/com/facebook/presto/memory/TestMemoryTracking.java @@ -14,6 +14,7 @@ package com.facebook.presto.memory; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskStateMachine; @@ -32,7 +33,6 @@ import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spiller.SpillSpaceTracker; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.BeforeMethod; @@ -46,9 +46,9 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static com.facebook.presto.execution.TaskTestUtils.PLAN_FRAGMENT; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; -import static io.airlift.units.DataSize.Unit.GIGABYTE; import static java.lang.String.format; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/memory/TestNodeMemoryConfig.java b/presto-main-base/src/test/java/com/facebook/presto/memory/TestNodeMemoryConfig.java index 2b8b134dd8c31..b3057d46c51ca 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/memory/TestNodeMemoryConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/memory/TestNodeMemoryConfig.java @@ -14,17 +14,17 @@ package 
com.facebook.presto.memory; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.DataSize; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Map; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static com.facebook.presto.memory.LocalMemoryManager.validateHeapHeadroom; import static com.facebook.presto.memory.NodeMemoryConfig.AVAILABLE_HEAP_MEMORY; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; public class TestNodeMemoryConfig { diff --git a/presto-main-base/src/test/java/com/facebook/presto/memory/TestQueryContext.java b/presto-main-base/src/test/java/com/facebook/presto/memory/TestQueryContext.java index e389403847c00..cdd588119c127 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/memory/TestQueryContext.java +++ b/presto-main-base/src/test/java/com/facebook/presto/memory/TestQueryContext.java @@ -14,6 +14,7 @@ package com.facebook.presto.memory; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskStateMachine; @@ -27,7 +28,6 @@ import com.facebook.presto.spiller.SpillSpaceTracker; import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -38,12 +38,12 @@ import static com.facebook.airlift.concurrent.Threads.threadsNamed; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; +import 
static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.execution.TaskTestUtils.PLAN_FRAGMENT; import static com.facebook.presto.memory.LocalMemoryManager.GENERAL_POOL; import static com.facebook.presto.memory.LocalMemoryManager.RESERVED_POOL; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNull; diff --git a/presto-main-base/src/test/java/com/facebook/presto/memory/TestReservedSystemMemoryConfig.java b/presto-main-base/src/test/java/com/facebook/presto/memory/TestReservedSystemMemoryConfig.java index ef9a0ed40bb86..0e52c87971539 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/memory/TestReservedSystemMemoryConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/memory/TestReservedSystemMemoryConfig.java @@ -13,15 +13,15 @@ */ package com.facebook.presto.memory; +import com.facebook.airlift.units.DataSize; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Map; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static org.testng.Assert.fail; public class TestReservedSystemMemoryConfig diff --git a/presto-main-base/src/test/java/com/facebook/presto/memory/TestSystemMemoryBlocking.java b/presto-main-base/src/test/java/com/facebook/presto/memory/TestSystemMemoryBlocking.java index f49952457825f..8d7ef6964aef3 
100644 --- a/presto-main-base/src/test/java/com/facebook/presto/memory/TestSystemMemoryBlocking.java +++ b/presto-main-base/src/test/java/com/facebook/presto/memory/TestSystemMemoryBlocking.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.memory; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.type.Type; import com.facebook.presto.execution.ScheduledSplit; import com.facebook.presto.execution.TaskSource; @@ -39,8 +41,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java index e29d5a4777415..7ee833b6c9529 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/AbstractMockMetadata.java @@ -20,15 +20,15 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; -import com.facebook.presto.execution.QueryManager; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.MaterializedViewStatus; +import com.facebook.presto.spi.MergeHandle; import com.facebook.presto.spi.NewTableLayout; -import com.facebook.presto.spi.QueryId; import 
com.facebook.presto.spi.SystemTable; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.TableMetadata; @@ -37,9 +37,12 @@ import com.facebook.presto.spi.connector.ConnectorCapabilities; import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.facebook.presto.spi.connector.ConnectorTableVersion; +import com.facebook.presto.spi.connector.RowChangeParadigm; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.plan.PartitioningHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.security.GrantInfo; import com.facebook.presto.spi.security.PrestoPrincipal; import com.facebook.presto.spi.security.Privilege; @@ -59,6 +62,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; +import static java.util.Locale.ENGLISH; public abstract class AbstractMockMetadata implements Metadata @@ -86,10 +90,17 @@ public void registerBuiltInFunctions(List functions) throw new UnsupportedOperationException(); } + @Override + public void registerConnectorFunctions(String catalogName, List functionInfos) + { + throw new UnsupportedOperationException(); + } + @Override public MetadataResolver getMetadataResolver(Session session) { - return new MetadataResolver() { + return new MetadataResolver() + { @Override public boolean catalogExists(String catalogName) { @@ -393,13 +404,13 @@ public Optional finishInsert(Session session, InsertTab } @Override - public ColumnHandle getDeleteRowIdColumnHandle(Session session, TableHandle tableHandle) + public Optional getDeleteRowIdColumn(Session session, TableHandle tableHandle) { throw new UnsupportedOperationException(); } @Override - public ColumnHandle getUpdateRowIdColumnHandle(Session session, TableHandle tableHandle, List updatedColumns) + public Optional 
getUpdateRowIdColumn(Session session, TableHandle tableHandle, List updatedColumns) { throw new UnsupportedOperationException(); } @@ -423,7 +434,20 @@ public DeleteTableHandle beginDelete(Session session, TableHandle tableHandle) } @Override - public void finishDelete(Session session, DeleteTableHandle tableHandle, Collection fragments) + public Optional finishDeleteWithOutput(Session session, DeleteTableHandle tableHandle, Collection fragments) + { + throw new UnsupportedOperationException(); + } + + @Override + public DistributedProcedureHandle beginCallDistributedProcedure(Session session, QualifiedObjectName procedureName, + TableHandle tableHandle, Object[] arguments, boolean sourceTableEliminated) + { + throw new UnsupportedOperationException(); + } + + @Override + public void finishCallDistributedProcedure(Session session, DistributedProcedureHandle procedureHandle, QualifiedObjectName procedureName, Collection fragments) { throw new UnsupportedOperationException(); } @@ -440,6 +464,49 @@ public void finishUpdate(Session session, TableHandle tableHandle, Collection getMergeUpdateLayout(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + /** + * Begin merge query + */ + public MergeHandle beginMerge(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + + /** + * Finish merge query + */ + public void finishMerge(Session session, MergeHandle tableHandle, Collection fragments, Collection computedStatistics) + { + throw new UnsupportedOperationException(); + } + @Override public Optional getCatalogHandle(Session session, String catalogName) { @@ -464,6 +531,24 @@ public Map getViews(Session session, Qualif throw new UnsupportedOperationException(); } + @Override + public List listMaterializedViews(Session session, QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getMaterializedViews(Session session, 
QualifiedTablePrefix prefix) + { + throw new UnsupportedOperationException(); + } + + @Override + public MaterializedViewStatus getMaterializedViewStatus(Session session, QualifiedObjectName viewName, TupleDomain baseQueryDomain) + { + throw new UnsupportedOperationException(); + } + @Override public void createView(Session session, String catalogName, ConnectorTableMetadata viewMetadata, String viewData, boolean replace) { @@ -602,12 +687,6 @@ public ListenableFuture commitPageSinkAsync(Session session, DeleteTableHa throw new UnsupportedOperationException(); } - @Override - public MetadataUpdates getMetadataUpdateResults(Session session, QueryManager queryManager, MetadataUpdates metadataUpdateRequests, QueryId queryId) - { - throw new UnsupportedOperationException(); - } - @Override public FunctionAndTypeManager getFunctionAndTypeManager() { @@ -644,6 +723,12 @@ public TablePropertyManager getTablePropertyManager() throw new UnsupportedOperationException(); } + @Override + public MaterializedViewPropertyManager getMaterializedViewPropertyManager() + { + throw new UnsupportedOperationException(); + } + @Override public ColumnPropertyManager getColumnPropertyManager() { @@ -668,6 +753,18 @@ public Set getConnectorCapabilities(Session session, Conn throw new UnsupportedOperationException(); } + @Override + public void dropBranch(Session session, TableHandle tableHandle, String branchName, boolean branchExists) + { + throw new UnsupportedOperationException(); + } + + @Override + public void dropTag(Session session, TableHandle tableHandle, String tagName, boolean tagExists) + { + throw new UnsupportedOperationException(); + } + @Override public void dropConstraint(Session session, TableHandle tableHandle, Optional constraintName, Optional columnName) { @@ -679,4 +776,16 @@ public void addConstraint(Session session, TableHandle tableHandle, TableConstra { throw new UnsupportedOperationException(); } + + @Override + public String normalizeIdentifier(Session session, 
String catalogName, String identifier) + { + return identifier.toLowerCase(ENGLISH); + } + + @Override + public Optional> applyTableFunction(Session session, TableFunctionHandle handle) + { + return Optional.empty(); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestAbstractTypedJacksonModule.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestAbstractTypedJacksonModule.java new file mode 100644 index 0000000000000..9016d6a21fac0 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestAbstractTypedJacksonModule.java @@ -0,0 +1,607 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.airlift.json.JsonModule; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.inject.Guice; +import com.google.inject.Injector; +import com.google.inject.Module; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Base64; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + +@Test(singleThreaded = true) +public class TestAbstractTypedJacksonModule +{ + private ObjectMapper objectMapper; + @BeforeMethod + public void setup() + { + // Default setup with binary serialization disabled + setupInjector(false, null); + } + + private void setupInjector(boolean binarySerializationEnabled, ConnectorCodecProvider codecProvider) + { + Module testModule = binder -> { + binder.install(new JsonModule()); + + // Configure FeaturesConfig + FeaturesConfig featuresConfig = new FeaturesConfig(); + featuresConfig.setUseConnectorProvidedSerializationCodecs(binarySerializationEnabled); + binder.bind(FeaturesConfig.class).toInstance(featuresConfig); + + // Bind HandleResolver + binder.bind(HandleResolver.class).toInstance(new 
TestHandleResolver()); + + // Bind TestConnectorManager as a singleton + TestConnectorManager testConnectorManager = new TestConnectorManager(codecProvider); + binder.bind(TestConnectorManager.class).toInstance(testConnectorManager); + + // Register the test Jackson module + jsonBinder(binder).addModuleBinding().to(TestHandleJacksonModule.class); + }; + + Injector injector = Guice.createInjector(testModule); + objectMapper = injector.getInstance(ObjectMapper.class); + } + + @Test + public void testLegacyJsonSerializationWithoutCodec() + throws Exception + { + // Setup with binary serialization disabled + setupInjector(false, null); + + TestHandle original = new TestHandle("connector1", "value1", 42); + String json = objectMapper.writeValueAsString(original); + + // Should have @type field but no binary data + assertJsonContains(json, "@type", "connector1"); + assertJsonContains(json, "id", "value1"); + assertJsonContains(json, "count", "42"); + assertJsonNotContains(json, "customSerializedValue"); + } + + @Test + public void testBinarySerializationWithCodec() + throws Exception + { + // Create a simple codec that serializes to a custom format + ConnectorCodec codec = new SimpleCodec(); + + // Setup with binary serialization enabled and codec provider + ConnectorCodecProvider codecProvider = new ConnectorCodecProvider() + { + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + }; + setupInjector(true, codecProvider); + + TestHandle original = new TestHandle("connector1", "value1", 42); + String json = objectMapper.writeValueAsString(original); + + // Should have @type and binary data fields + assertJsonContains(json, "@type", "connector1"); + assertJsonContains(json, "customSerializedValue"); + assertJsonNotContains(json, "id", "value1"); // Should not have regular fields + + // Test deserialization + TestHandle deserialized = objectMapper.readValue(json, TestHandle.class); + 
assertEquals(deserialized.getId(), original.getId()); + assertEquals(deserialized.getCount(), original.getCount()); + } + + @Test + public void testBinarySerializationDisabled() + throws Exception + { + // This test verifies that when binary serialization is disabled via the feature flag, + // the module falls back to legacy JSON serialization even if codecs are available + + // Setup with binary serialization disabled even though codec is available + ConnectorCodec codec = new SimpleCodec(); + ConnectorCodecProvider codecProvider = new ConnectorCodecProvider() + { + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + }; + setupInjector(false, codecProvider); // false = binary serialization disabled + + TestHandle original = new TestHandle("connector1", "value1", 42); + String json = objectMapper.writeValueAsString(original); + + // Should use legacy JSON serialization even though codec is available + assertJsonContains(json, "@type", "connector1"); + assertJsonContains(json, "id", "value1"); + assertJsonContains(json, "count", "42"); + assertJsonNotContains(json, "customSerializedValue"); + } + + @Test + public void testFallbackToJsonWhenNoCodec() + throws Exception + { + // Setup with binary serialization enabled but no codec available + setupInjector(true, null); + + // Test with connector2 (no codec available) + TestHandle original = new TestHandle("connector2", "value2", 84); + String json = objectMapper.writeValueAsString(original); + + // Should fall back to JSON serialization + assertJsonContains(json, "@type", "connector2"); + assertJsonContains(json, "id", "value2"); + assertJsonContains(json, "count", "84"); + assertJsonNotContains(json, "customSerializedValue"); + } + + @Test + public void testInternalHandlesAlwaysUseJson() + throws Exception + { + // Setup with codec that would handle all connectors + ConnectorCodec codec = new SimpleCodec(); + ConnectorCodecProvider codecProvider = new 
ConnectorCodecProvider() + { + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + }; + setupInjector(true, codecProvider); + + // Test with internal handle (starts with $) + TestHandle original = new TestHandle("$remote", "internal", 99); + String json = objectMapper.writeValueAsString(original); + + // Should use JSON serialization for internal handles + assertJsonContains(json, "@type", "$remote"); + assertJsonContains(json, "id", "internal"); + assertJsonContains(json, "count", "99"); + assertJsonNotContains(json, "customSerializedValue"); + } + + @Test + public void testNullValueSerialization() + throws Exception + { + setupInjector(false, null); + + String json = objectMapper.writeValueAsString(null); + assertEquals(json, "null"); + + TestHandle deserialized = objectMapper.readValue("null", TestHandle.class); + assertNull(deserialized); + } + + @Test + public void testRoundTripWithMixedHandles() + throws Exception + { + // Create a TestConnectorManager that only provides codec for "binary-connector" + setupInjector(true, new SelectiveCodecProvider("binary-connector")); + + // Test multiple handles with different serialization methods + TestHandle[] handles = new TestHandle[] { + new TestHandle("binary-connector", "binary1", 1), + new TestHandle("json-connector", "json1", 2), + new TestHandle("$internal", "internal1", 3), + new TestHandle("binary-connector", "binary2", 4), + }; + + for (TestHandle original : handles) { + String json = objectMapper.writeValueAsString(original); + + // Verify serialization format based on handle type + if (original.getConnectorId().equals("binary-connector")) { + // Should use binary serialization + assertJsonContains(json, "customSerializedValue"); + assertJsonNotContains(json, "\"id\":"); + + // Test deserialization for binary-serialized handles + TestHandle deserialized = objectMapper.readValue(json, TestHandle.class); + assertEquals(deserialized.getId(), 
original.getId()); + assertEquals(deserialized.getCount(), original.getCount()); + } + else { + // Should use JSON serialization + assertJsonNotContains(json, "customSerializedValue"); + assertJsonContains(json, "id", original.getId()); + } + } + } + + @Test + public void testDirectBinaryDataDeserialization() + throws Exception + { + // Test deserialization of manually crafted binary data JSON + ConnectorCodec codec = new SimpleCodec(); + ConnectorCodecProvider codecProvider = new ConnectorCodecProvider() + { + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + }; + setupInjector(true, codecProvider); + + // Manually create JSON with binary data + String encodedData = Base64.getEncoder().encodeToString("connector1|testValue|999".getBytes(UTF_8)); + String json = String.format("{\"@type\":\"connector1\",\"customSerializedValue\":\"%s\"}", encodedData); + + // Deserialize + TestHandle deserialized = objectMapper.readValue(json, TestHandle.class); + assertEquals(deserialized.getConnectorId(), "connector1"); + assertEquals(deserialized.getId(), "testValue"); + assertEquals(deserialized.getCount(), 999); + } + + @Test + public void testMixedSerializationRoundTrip() + throws Exception + { + // Test that we can serialize and deserialize a mix of binary and JSON in sequence + setupInjector(true, new SelectiveCodecProvider("binary-connector")); + + // Create handles with different serialization methods + TestHandle binaryHandle = new TestHandle("binary-connector", "binary-data", 100); + TestHandle jsonHandle = new TestHandle("json-connector", "json-data", 200); + + // Serialize both + String binaryJson = objectMapper.writeValueAsString(binaryHandle); + String jsonJson = objectMapper.writeValueAsString(jsonHandle); + + // Deserialize both + TestHandle deserializedBinary = objectMapper.readValue(binaryJson, TestHandle.class); + // For JSON deserialization, we skip due to complex type handling in isolated tests + 
+ // Verify binary deserialization worked + assertEquals(deserializedBinary.getId(), binaryHandle.getId()); + assertEquals(deserializedBinary.getCount(), binaryHandle.getCount()); + + // Verify JSON format is correct (even if we can't deserialize in this test) + assertJsonContains(jsonJson, "\"id\":\"json-data\""); + assertJsonContains(jsonJson, "\"count\":200"); + } + + private void assertJsonContains(String json, String... values) + { + for (String value : values) { + if (!json.contains(value)) { + throw new AssertionError("JSON does not contain: " + value + "\nJSON: " + json); + } + } + } + + private void assertJsonNotContains(String json, String... values) + { + for (String value : values) { + if (json.contains(value)) { + throw new AssertionError("JSON should not contain: " + value + "\nJSON: " + json); + } + } + } + + // Simple codec implementation for testing + private static class SimpleCodec + implements ConnectorCodec + { + @Override + public byte[] serialize(TestHandle value) + { + return String.format("%s|%s|%d", value.getConnectorId(), value.getId(), value.getCount()).getBytes(UTF_8); + } + + @Override + public TestHandle deserialize(byte[] data) + { + String[] parts = new String(data, UTF_8).split("\\|"); + return new TestHandle(parts[0], parts[1], Integer.parseInt(parts[2])); + } + } + + // Codec provider that only provides codec for specific connectors + private static class SelectiveCodecProvider + implements ConnectorCodecProvider + { + private final String connectorIdWithCodec; + private final ConnectorCodec codec = new SimpleCodec(); + + public SelectiveCodecProvider(String connectorIdWithCodec) + { + this.connectorIdWithCodec = connectorIdWithCodec; + } + + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of((ConnectorCodec) codec); + } + } + + // Test handle that implements multiple connector interfaces for testing + public static class TestHandle + implements com.facebook.presto.spi.ConnectorTableHandle, + 
com.facebook.presto.spi.ConnectorSplit, + com.facebook.presto.spi.ColumnHandle, + com.facebook.presto.spi.ConnectorTableLayoutHandle, + com.facebook.presto.spi.ConnectorOutputTableHandle, + com.facebook.presto.spi.ConnectorInsertTableHandle, + com.facebook.presto.spi.ConnectorDeleteTableHandle, + com.facebook.presto.spi.ConnectorIndexHandle, + com.facebook.presto.spi.connector.ConnectorPartitioningHandle, + com.facebook.presto.spi.connector.ConnectorTransactionHandle, + com.facebook.presto.spi.ConnectorDistributedProcedureHandle + { + private final String connectorId; + private final String id; + private final int count; + + // Constructor for programmatic creation + public TestHandle(String connectorId, String id, int count) + { + this.connectorId = connectorId; + this.id = id; + this.count = count; + } + + // Constructor for Jackson deserialization + @JsonCreator + public TestHandle( + @JsonProperty("id") String id, + @JsonProperty("count") int count) + { + // When deserializing, the connector ID is determined by the @type field + // For simplicity in tests, we use a fixed value + this("deserialized", id, count); + } + + // This field is excluded from JSON serialization but used internally for type resolution + @JsonIgnore + public String getConnectorId() + { + return connectorId; + } + + @JsonProperty + public String getId() + { + return id; + } + + @JsonProperty + public int getCount() + { + return count; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestHandle that = (TestHandle) o; + return count == that.count && + Objects.equals(connectorId, that.connectorId) && + Objects.equals(id, that.id); + } + + @Override + public int hashCode() + { + return Objects.hash(connectorId, id, count); + } + + @Override + public String toString() + { + return "TestHandle{" + + "connectorId='" + connectorId + '\'' + + ", id='" + id + '\'' + + ", count=" + count 
+ + '}'; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return null; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return ImmutableList.of(); + } + + @Override + public Object getInfo() + { + return null; + } + } + + // Test ConnectorHandleResolver implementation + private static class TestConnectorHandleResolver + implements com.facebook.presto.spi.ConnectorHandleResolver + { + @Override + public Class getTableHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getSplitClass() + { + return TestHandle.class; + } + + @Override + public Class getIndexHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getOutputTableHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getInsertTableHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getDeleteTableHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getPartitioningHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getTransactionHandleClass() + { + return TestHandle.class; + } + + @Override + public Class getDistributedProcedureHandleClass() + { + return TestHandle.class; + } + } + + // Test HandleResolver implementation + private static class TestHandleResolver + extends HandleResolver + { + public TestHandleResolver() + { + super(); + // Register the test handle resolver for all test connectors + TestConnectorHandleResolver resolver = new TestConnectorHandleResolver(); + addConnectorName("connector1", resolver); + addConnectorName("connector2", resolver); + addConnectorName("binary-connector", resolver); + addConnectorName("json-connector", resolver); + addConnectorName("$internal", resolver); + 
addConnectorName("deserialized", resolver); + } + } + + // Mock ConnectorManager implementation + private static class TestConnectorManager + { + private final ConnectorCodecProvider codecProvider; + + public TestConnectorManager(ConnectorCodecProvider codecProvider) + { + this.codecProvider = codecProvider; + } + + public Optional getConnectorCodecProvider(ConnectorId connectorId) + { + // Only return codec provider for specific connectors if it's a SelectiveCodecProvider + if (codecProvider instanceof SelectiveCodecProvider) { + SelectiveCodecProvider selective = (SelectiveCodecProvider) codecProvider; + if (connectorId.getCatalogName().equals(selective.connectorIdWithCodec)) { + return Optional.of(codecProvider); + } + return Optional.empty(); + } + return Optional.ofNullable(codecProvider); + } + } + + // Test Jackson module that uses TestHandle + public static class TestHandleJacksonModule + extends AbstractTypedJacksonModule + { + @jakarta.inject.Inject + public TestHandleJacksonModule( + HandleResolver handleResolver, + TestConnectorManager testConnectorManager, + FeaturesConfig featuresConfig) + { + super(TestHandle.class, + TestHandle::getConnectorId, + id -> TestHandle.class, + featuresConfig.isUseConnectorProvidedSerializationCodecs(), + connectorId -> testConnectorManager + .getConnectorCodecProvider(connectorId) + .flatMap(provider -> { + Optional> codec = + provider.getConnectorTableHandleCodec(); + // Cast is safe because TestHandle implements ConnectorTableHandle + return (Optional>) (Optional) codec; + })); + } + } +} diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestConvertApplicableTypeToVariable.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestConvertApplicableTypeToVariable.java similarity index 78% rename from presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestConvertApplicableTypeToVariable.java rename to 
presto-main-base/src/test/java/com/facebook/presto/metadata/TestConvertApplicableTypeToVariable.java index 93f9ee75f2b82..33757fb9d57d8 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestConvertApplicableTypeToVariable.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestConvertApplicableTypeToVariable.java @@ -11,9 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.sidecar; +package com.facebook.presto.metadata; import com.facebook.presto.common.type.NamedTypeSignature; +import com.facebook.presto.common.type.RowFieldName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.common.type.TypeSignatureParameter; import org.testng.annotations.Test; @@ -21,8 +22,8 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.builtin.tools.WorkerFunctionUtil.convertApplicableTypeToVariable; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.sidecar.functionNamespace.NativeFunctionNamespaceManager.convertApplicableTypeToVariable; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -265,4 +266,71 @@ public void testConvertApplicableTypeToVariableGenericArrayParamType() convertApplicableTypeToVariable(actualTypeSignature.getTypeOrNamedTypeParametersAsTypeSignatures()); assertEquals(expectedTypeSignature.getTypeOrNamedTypeParametersAsTypeSignatures(), resolvedTypeSignaturesList); } + + @Test + public void testConvertApplicableTypeToVariableNamedRowFieldsWithVarchar() + { + TypeSignature actualTypeSignature = parseTypeSignature("row(format_type varchar, num_vectors bigint)"); + TypeSignature expectedTypeSignature = new TypeSignature( + "row", + TypeSignatureParameter.of( + new NamedTypeSignature( + Optional.of(new RowFieldName("format_type", false)), + 
parseTypeSignature("varchar"))), + TypeSignatureParameter.of( + new NamedTypeSignature( + Optional.of(new RowFieldName("num_vectors", false)), + parseTypeSignature("bigint")))); + TypeSignature resolvedTypeSignature = convertApplicableTypeToVariable(actualTypeSignature); + assertEquals(expectedTypeSignature, resolvedTypeSignature); + } + + @Test + public void testConvertApplicableTypeToVariableNamedRowFieldsWithMap() + { + TypeSignature actualTypeSignature = parseTypeSignature("row(metadata map(varchar, varchar), count bigint)"); + TypeSignature expectedTypeSignature = new TypeSignature( + "row", + TypeSignatureParameter.of( + new NamedTypeSignature( + Optional.of(new RowFieldName("metadata", false)), + new TypeSignature( + "map", + TypeSignatureParameter.of(parseTypeSignature("varchar")), + TypeSignatureParameter.of(parseTypeSignature("varchar"))))), + TypeSignatureParameter.of( + new NamedTypeSignature( + Optional.of(new RowFieldName("count", false)), + parseTypeSignature("bigint")))); + TypeSignature resolvedTypeSignature = convertApplicableTypeToVariable(actualTypeSignature); + assertEquals(expectedTypeSignature, resolvedTypeSignature); + } + + @Test + public void testConvertApplicableTypeToVariableSignatureWithVarcharMapBigint() + { + TypeSignature actualTypeSignature = parseTypeSignature( + "row(format_type varchar, num_vectors bigint, dimension integer, " + + "index_type varchar, distance_metric varchar, id_type varchar, " + + "metadata map(varchar, varchar))"); + TypeSignature resolvedTypeSignature = convertApplicableTypeToVariable(actualTypeSignature); + + List params = resolvedTypeSignature.getParameters(); + assertEquals(params.size(), 7); + + assertNamedField(params.get(0), "format_type", "varchar"); + assertNamedField(params.get(1), "num_vectors", "bigint"); + assertNamedField(params.get(2), "dimension", "integer"); + assertNamedField(params.get(3), "index_type", "varchar"); + assertNamedField(params.get(4), "distance_metric", "varchar"); + 
assertNamedField(params.get(5), "id_type", "varchar"); + assertNamedField(params.get(6), "metadata", "map"); + } + + private static void assertNamedField(TypeSignatureParameter typeSignatureParameter, String expectedFieldName, String expectedTypeBase) + { + assertTrue(typeSignatureParameter.isNamedTypeSignature()); + assertEquals(typeSignatureParameter.getNamedTypeSignature().getFieldName(), Optional.of(new RowFieldName(expectedFieldName, false))); + assertEquals(typeSignatureParameter.getNamedTypeSignature().getTypeSignature().getBase(), expectedTypeBase); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestFunctionAndTypeManager.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestFunctionAndTypeManager.java index ece53b2e9d8d7..52c1d128b3358 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestFunctionAndTypeManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestFunctionAndTypeManager.java @@ -18,9 +18,14 @@ import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; +import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor; +import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors; +import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager; import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation; import com.facebook.presto.operator.scalar.CustomFunctions; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionImplementationType; import com.facebook.presto.spi.function.Parameter; import com.facebook.presto.spi.function.RoutineCharacteristics; import com.facebook.presto.spi.function.ScalarFunction; @@ -191,6 +196,149 @@ public void 
testListingVisibilityBetaFunctionsEnabled() assertFalse(names.contains("max_data_size_for_stats"), "Expected function names " + names + " not to contain 'max_data_size_for_stats'"); } + @Test + public void testListFunctionsWithNonBuiltInFunctionNamespacesFilter() + { + FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + + functionAndTypeManager.addFunctionNamespace( + "catalog_a", + new InMemoryFunctionNamespaceManager( + "catalog_a", + new SqlFunctionExecutors( + ImmutableMap.of( + RoutineCharacteristics.Language.SQL, FunctionImplementationType.SQL, + new RoutineCharacteristics.Language("java"), FunctionImplementationType.THRIFT), + new NoopSqlFunctionExecutor()), + new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("sql,java"))); + + functionAndTypeManager.addFunctionNamespace( + "catalog_b", + new InMemoryFunctionNamespaceManager( + "catalog_b", + new SqlFunctionExecutors( + ImmutableMap.of( + RoutineCharacteristics.Language.SQL, FunctionImplementationType.SQL, + new RoutineCharacteristics.Language("java"), FunctionImplementationType.THRIFT), + new NoopSqlFunctionExecutor()), + new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("sql,java"))); + + SqlInvokedFunction funcA = new SqlInvokedFunction( + QualifiedObjectName.valueOf("catalog_a", "schema", "func_a"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "func_a(x)", + RoutineCharacteristics.builder().setLanguage(RoutineCharacteristics.Language.SQL).build(), + "", + notVersioned()); + functionAndTypeManager.createFunction(funcA, true); + + SqlInvokedFunction funcB = new SqlInvokedFunction( + QualifiedObjectName.valueOf("catalog_b", "schema", "func_b"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "func_b(x)", + 
RoutineCharacteristics.builder().setLanguage(RoutineCharacteristics.Language.SQL).build(), + "", + notVersioned()); + functionAndTypeManager.createFunction(funcB, true); + + // Test with list_built_in_functions_only = false and no namespace filter (empty string) + // Should list functions from all namespaces + Session sessionNoFilter = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setSystemProperty("list_built_in_functions_only", "false") + .setSystemProperty("non_built_in_function_namespaces_to_list_functions", "") + .build(); + + List functionsNoFilter = functionAndTypeManager.listFunctions(sessionNoFilter, Optional.empty(), Optional.empty()); + List namesNoFilter = transform(functionsNoFilter, input -> input.getSignature().getNameSuffix()); + + assertTrue(namesNoFilter.contains("func_a"), "Expected function names to contain 'func_a' when no namespace filter is set"); + assertTrue(namesNoFilter.contains("func_b"), "Expected function names to contain 'func_b' when no namespace filter is set"); + + Session sessionFilterA = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setSystemProperty("list_built_in_functions_only", "false") + .setSystemProperty("non_built_in_function_namespaces_to_list_functions", "catalog_a") + .build(); + + List functionsFilterA = functionAndTypeManager.listFunctions(sessionFilterA, Optional.empty(), Optional.empty()); + List namesFilterA = transform(functionsFilterA, input -> input.getSignature().getNameSuffix()); + + assertTrue(namesFilterA.contains("func_a"), "Expected function names to contain 'func_a' when filtering for catalog_a"); + assertFalse(namesFilterA.contains("func_b"), "Expected function names NOT to contain 'func_b' when filtering for catalog_a"); + + Session sessionFilterB = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setSystemProperty("list_built_in_functions_only", "false") + 
.setSystemProperty("non_built_in_function_namespaces_to_list_functions", "catalog_b") + .build(); + + List functionsFilterB = functionAndTypeManager.listFunctions(sessionFilterB, Optional.empty(), Optional.empty()); + List namesFilterB = transform(functionsFilterB, input -> input.getSignature().getNameSuffix()); + + assertFalse(namesFilterB.contains("func_a"), "Expected function names NOT to contain 'func_a' when filtering for catalog_b"); + assertTrue(namesFilterB.contains("func_b"), "Expected function names to contain 'func_b' when filtering for catalog_b"); + + Session sessionFilterBoth = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setSystemProperty("list_built_in_functions_only", "false") + .setSystemProperty("non_built_in_function_namespaces_to_list_functions", "catalog_a,catalog_b") + .build(); + + List functionsFilterBoth = functionAndTypeManager.listFunctions(sessionFilterBoth, Optional.empty(), Optional.empty()); + List namesFilterBoth = transform(functionsFilterBoth, input -> input.getSignature().getNameSuffix()); + + assertTrue(namesFilterBoth.contains("func_a"), "Expected function names to contain 'func_a' when filtering for both catalogs"); + assertTrue(namesFilterBoth.contains("func_b"), "Expected function names to contain 'func_b' when filtering for both catalogs"); + } + + @Test + public void testListFunctionsWithBuiltInFunctionsOnlyTrue() + { + FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + + functionAndTypeManager.addFunctionNamespace( + "custom_catalog", + new InMemoryFunctionNamespaceManager( + "custom_catalog", + new SqlFunctionExecutors( + ImmutableMap.of( + RoutineCharacteristics.Language.SQL, FunctionImplementationType.SQL), + new NoopSqlFunctionExecutor()), + new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("sql"))); + + SqlInvokedFunction customFunc = new SqlInvokedFunction( + QualifiedObjectName.valueOf("custom_catalog", "schema", 
"custom_func"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "custom_func(x)", + RoutineCharacteristics.builder().setLanguage(RoutineCharacteristics.Language.SQL).build(), + "", + notVersioned()); + functionAndTypeManager.createFunction(customFunc, true); + + // Test with list_built_in_functions_only = true (default) + // The non_built_in_function_namespaces_to_list_functions should be ignored + Session session = testSessionBuilder() + .setCatalog("tpch") + .setSchema(TINY_SCHEMA_NAME) + .setSystemProperty("list_built_in_functions_only", "true") + .setSystemProperty("non_built_in_function_namespaces_to_list_functions", "custom_catalog") + .build(); + + List functions = functionAndTypeManager.listFunctions(session, Optional.empty(), Optional.empty()); + List names = transform(functions, input -> input.getSignature().getNameSuffix()); + + assertTrue(names.contains("length"), "Expected built-in function 'length' to be present"); + assertFalse(names.contains("custom_func"), "Expected custom function NOT to be present when list_built_in_functions_only is true"); + } + @Test public void testOperatorTypes() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaMetadata.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaMetadata.java index ded11a9fe6652..a8eb7ba6399af 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaMetadata.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaMetadata.java @@ -44,6 +44,7 @@ import java.util.Optional; +import static com.facebook.presto.SystemSessionProperties.MAX_PREFIXES_COUNT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static 
com.facebook.presto.spi.ConnectorId.createInformationSchemaConnectorId; @@ -59,6 +60,7 @@ public class TestInformationSchemaMetadata private final TransactionManager transactionManager; private final Metadata metadata; + private static final String TEST_CATALOG = "test_catalog"; public TestInformationSchemaMetadata() { @@ -152,6 +154,27 @@ public void testInformationSchemaPredicatePushdownWithConstraintPredicate() assertEquals(tableHandle.getPrefixes(), ImmutableSet.of(new QualifiedTablePrefix("test_catalog", "test_schema", "test_view"))); } + @Test + public void testInformationSchemaMaxPrefixesCount() + { + InformationSchemaMetadata info = new InformationSchemaMetadata(TEST_CATALOG, metadata); + + TransactionId transactionId = transactionManager.beginTransaction(false); + ConnectorSession session = createNewSessionWithMaxPrefixes(transactionId, 10); + Constraint constraint = new Constraint<>(TupleDomain.all()); + + ConnectorTableLayoutResult result = info.getTableLayoutForConstraint( + session, + new InformationSchemaTableHandle(TEST_CATALOG, "information_schema", "tables"), + constraint, + Optional.empty()); + + InformationSchemaTableLayoutHandle handle = + (InformationSchemaTableLayoutHandle) result.getTableLayout().getHandle(); + + assertTrue(handle.getPrefixes().size() <= 10); + } + private ConnectorSession createNewSession(TransactionId transactionId) { return testSessionBuilder() @@ -161,4 +184,15 @@ private ConnectorSession createNewSession(TransactionId transactionId) .build() .toConnectorSession(); } + + private ConnectorSession createNewSessionWithMaxPrefixes(TransactionId transactionId, int maxPrefixesCount) + { + return testSessionBuilder() + .setCatalog(TEST_CATALOG) + .setSchema("information_schema") + .setTransactionId(transactionId) + .setSystemProperty(MAX_PREFIXES_COUNT, String.valueOf(maxPrefixesCount)) + .build() + .toConnectorSession(); + } } diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaTableHandle.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaTableHandle.java index 8b072926db5c1..61e1c0bca3f65 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaTableHandle.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestInformationSchemaTableHandle.java @@ -14,6 +14,8 @@ package com.facebook.presto.metadata; import com.facebook.airlift.json.JsonModule; +import com.facebook.drift.codec.guice.ThriftCodecModule; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.connector.informationSchema.InformationSchemaTableHandle; import com.facebook.presto.spi.ConnectorTableHandle; import com.fasterxml.jackson.core.type.TypeReference; @@ -21,6 +23,7 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Guice; import com.google.inject.Injector; +import com.google.inject.Scopes; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -44,7 +47,13 @@ public class TestInformationSchemaTableHandle @BeforeMethod public void startUp() { - Injector injector = Guice.createInjector(new JsonModule(), new HandleJsonModule()); + Injector injector = Guice.createInjector( + new JsonModule(), + binder -> { + binder.install(new HandleJsonModule()); + binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); + binder.install(new ThriftCodecModule()); + }); objectMapper = injector.getInstance(ObjectMapper.class); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSessionPropertyProviderConfig.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSessionPropertyProviderConfig.java new file mode 100644 index 0000000000000..cfa205bb637e3 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSessionPropertyProviderConfig.java @@ -0,0 
+1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.io.File; +import java.util.Map; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestSessionPropertyProviderConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(SessionPropertyProviderConfig.class) + .setSessionPropertyProvidersConfigurationDir(new File("etc/session-property-providers"))); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("session-property-provider.config-dir", "/foo") + .build(); + + SessionPropertyProviderConfig expected = new SessionPropertyProviderConfig() + .setSessionPropertyProvidersConfigurationDir(new File("/foo")); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestStaticTypeManagerStoreConfig.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestStaticTypeManagerStoreConfig.java new file mode 100644 index 0000000000000..424078a6b4b29 --- /dev/null +++ 
b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestStaticTypeManagerStoreConfig.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.io.File; +import java.util.Map; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestStaticTypeManagerStoreConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(StaticTypeManagerStoreConfig.class) + .setTypeManagerConfigurationDir(new File("etc/type-managers"))); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("type-manager.config-dir", "/foo") + .build(); + + StaticTypeManagerStoreConfig expected = new StaticTypeManagerStoreConfig() + .setTypeManagerConfigurationDir(new File("/foo")); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSystemTableHandle.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSystemTableHandle.java index 0e291c6a8fb59..5083171e881b1 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSystemTableHandle.java +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestSystemTableHandle.java @@ -14,6 +14,8 @@ package com.facebook.presto.metadata; import com.facebook.airlift.json.JsonModule; +import com.facebook.drift.codec.guice.ThriftCodecModule; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.connector.system.SystemTableHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableHandle; @@ -23,6 +25,7 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Guice; import com.google.inject.Injector; +import com.google.inject.Scopes; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -47,7 +50,13 @@ public class TestSystemTableHandle @BeforeMethod public void startUp() { - Injector injector = Guice.createInjector(new JsonModule(), new HandleJsonModule()); + Injector injector = Guice.createInjector( + new JsonModule(), + binder -> { + binder.install(new HandleJsonModule()); + binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); + binder.install(new ThriftCodecModule()); + }); objectMapper = injector.getInstance(ObjectMapper.class); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/metadata/TestTableFunctionRegistry.java b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestTableFunctionRegistry.java new file mode 100644 index 0000000000000..e7670a5af9b5d --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/metadata/TestTableFunctionRegistry.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.connector.tvf.TestingTableFunctions.DuplicateArgumentsTableFunction; +import static com.facebook.presto.connector.tvf.TestingTableFunctions.MultipleRSTableFunction; +import static com.facebook.presto.connector.tvf.TestingTableFunctions.NullArgumentsTableFunction; +import static com.facebook.presto.connector.tvf.TestingTableFunctions.TestConnectorTableFunction; +import static com.facebook.presto.connector.tvf.TestingTableFunctions.TestConnectorTableFunction2; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; + +public class TestTableFunctionRegistry +{ + private static final String CATALOG = "test_catalog"; + private static final String USER = "user"; + private static final String SCHEMA = "system"; + private static final Session SESSION = testSessionBuilder() + .setCatalog(CATALOG) + .setSchema(SCHEMA) + .setIdentity(new Identity(USER, Optional.empty())).build(); + private static final Session MISMATCH_SESSION = testSessionBuilder() + 
.setCatalog(CATALOG) + .setSchema("other") + .setIdentity(new Identity(USER, Optional.empty())).build(); + private static final String TEST_FUNCTION = "test_function"; + private static final String TEST_FUNCTION_2 = "test_function2"; + + @Test + public void testTableFunctionRegistry() + { + TableFunctionRegistry testFunctionRegistry = new TableFunctionRegistry(); + ConnectorId id = new ConnectorId(CATALOG); + + // Verify registration with multiple table functions. + testFunctionRegistry.addTableFunctions(id, ImmutableList.of(new TestConnectorTableFunction(), new TestConnectorTableFunction2())); + + // Verify that table functions cannot be overridden for the same catalog. + RuntimeException ex = expectThrows(IllegalStateException.class, () -> testFunctionRegistry.addTableFunctions(id, ImmutableList.of(new TestConnectorTableFunction()))); + assertTrue(ex.getMessage().contains("Table functions already registered for catalog: test"), ex.getMessage()); + + // Verify table function resolution. + assertTrue(testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION)).isPresent()); + assertTrue(testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION_2)).isPresent()); + assertFalse(testFunctionRegistry.resolve(SESSION, QualifiedName.of("none")).isPresent()); + assertFalse(testFunctionRegistry.resolve(MISMATCH_SESSION, QualifiedName.of("none")).isPresent()); + + // Verify metadata. + TableFunctionMetadata data = testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION)).get(); + assertEquals(data.getConnectorId(), id); + assertTrue(data.getFunction() instanceof TestConnectorTableFunction); + + // Verify the removal of table functions. + testFunctionRegistry.removeTableFunctions(id); + assertFalse(testFunctionRegistry.resolve(SESSION, QualifiedName.of(TEST_FUNCTION)).isPresent()); + + // Verify that null arguments table functions cannot be added. 
+ ex = expectThrows(NullPointerException.class, () -> testFunctionRegistry.addTableFunctions(id, ImmutableList.of(new NullArgumentsTableFunction()))); + assertTrue(ex.getMessage().contains("arguments is null"), ex.getMessage()); + + // Verify that duplicate arguments table functions cannot be added. + ex = expectThrows(IllegalArgumentException.class, () -> testFunctionRegistry.addTableFunctions(id, ImmutableList.of(new DuplicateArgumentsTableFunction()))); + assertTrue(ex.getMessage().contains("duplicate argument name: a"), ex.getMessage()); + + // Verify that two row semantic table function arguments functions cannot be added. + ex = expectThrows(IllegalArgumentException.class, () -> testFunctionRegistry.addTableFunctions(id, ImmutableList.of(new MultipleRSTableFunction()))); + assertTrue(ex.getMessage().contains("more than one table argument with row semantics"), ex.getMessage()); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkDynamicFilterSourceOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkDynamicFilterSourceOperator.java index 0b3dbf5e75526..2556805f5e5df 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkDynamicFilterSourceOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkDynamicFilterSourceOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.PageBuilder; import com.facebook.presto.spi.plan.PlanNodeId; @@ -20,7 +21,6 @@ import com.google.common.collect.ImmutableList; import io.airlift.tpch.LineItem; import io.airlift.tpch.LineItemGenerator; -import io.airlift.units.DataSize; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.Fork; import org.openjdk.jmh.annotations.Measurement; @@ -45,12 +45,12 @@ import java.util.concurrent.TimeUnit; import 
static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.SystemSessionProperties.getDynamicFilteringMaxPerDriverRowCount; import static com.facebook.presto.SystemSessionProperties.getDynamicFilteringMaxPerDriverSize; import static com.facebook.presto.SystemSessionProperties.getDynamicFilteringRangeRowLimitPerDriver; import static com.facebook.presto.common.type.BigintType.BIGINT; -import static io.airlift.units.DataSize.Unit.GIGABYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java index 139ce5f2b2b37..bc94cbce32811 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashAndSegmentedAggregationOperators.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.BlockBuilder; @@ -26,7 +27,6 @@ import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -50,6 +50,9 @@ import java.util.concurrent.ScheduledExecutorService; import static 
com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createLongRepeatBlock; import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; @@ -59,9 +62,6 @@ import static com.facebook.presto.operator.BenchmarkHashAndSegmentedAggregationOperators.Context.TOTAL_PAGES; import static com.facebook.presto.operator.aggregation.GenericAccumulatorFactory.generateAccumulatorFactory; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.succinctBytes; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.MILLISECONDS; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java index 6a097684675b7..33bf6651d4b95 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashAndStreamingAggregationOperators.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.BlockBuilder; @@ -27,7 +28,6 @@ import com.facebook.presto.sql.gen.JoinCompiler; import 
com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -51,6 +51,9 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -59,9 +62,6 @@ import static com.facebook.presto.operator.BenchmarkHashAndStreamingAggregationOperators.Context.TOTAL_PAGES; import static com.facebook.presto.operator.aggregation.GenericAccumulatorFactory.generateAccumulatorFactory; import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java index 030993b5fa56b..c2491ebbfbc3f 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkHashBuildAndJoinOperators.java @@ -13,6 +13,7 @@ */ package 
com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.Type; @@ -25,7 +26,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -51,13 +51,13 @@ import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spiller.PartitioningSpillerFactory.unsupportedPartitioningSpillerFactory; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.Unit.GIGABYTE; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkUnnestOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkUnnestOperator.java index d240874106aea..8bbca76adc103 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkUnnestOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkUnnestOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; 
import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; @@ -21,7 +22,6 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -50,10 +50,10 @@ import java.util.concurrent.TimeUnit; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.operator.PageAssertions.createPageWithRandomData; -import static io.airlift.units.DataSize.Unit.GIGABYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkWindowOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkWindowOperator.java index 14a8e388b7e12..223a79b99fe7a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkWindowOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/BenchmarkWindowOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.BlockBuilder; @@ -21,7 +22,6 @@ import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; import org.openjdk.jmh.annotations.Benchmark; import 
org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -45,6 +45,7 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -52,7 +53,6 @@ import static com.facebook.presto.operator.BenchmarkWindowOperator.Context.TOTAL_PAGES; import static com.facebook.presto.operator.TestWindowOperator.ROW_NUMBER; import static com.facebook.presto.operator.TestWindowOperator.createFactoryUnbounded; -import static io.airlift.units.DataSize.Unit.GIGABYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.MILLISECONDS; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/GroupByHashYieldAssertion.java b/presto-main-base/src/test/java/com/facebook/presto/operator/GroupByHashYieldAssertion.java index 296b00c46049e..2892d64ac5619 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/GroupByHashYieldAssertion.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/GroupByHashYieldAssertion.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.Type; @@ -23,7 +24,6 @@ import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spiller.SpillSpaceTracker; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import java.util.LinkedList; import java.util.List; @@ 
-36,14 +36,14 @@ import static com.facebook.airlift.testing.Assertions.assertBetweenInclusive; import static com.facebook.airlift.testing.Assertions.assertGreaterThan; import static com.facebook.airlift.testing.Assertions.assertLessThan; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.operator.OperatorAssertion.finishOperator; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static com.facebook.presto.testing.assertions.Assert.assertEquals; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/OperatorAssertion.java b/presto-main-base/src/test/java/com/facebook/presto/operator/OperatorAssertion.java index 2175292a8bfd0..3b7b9d0a8a3cf 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/OperatorAssertion.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/OperatorAssertion.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; @@ -24,7 +25,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import java.util.ArrayList; import java.util.Collection; diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java index 5927f1fc65ee6..220b91ba548d8 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForAggregates.java @@ -43,6 +43,7 @@ import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.InputFunction; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.LongVariableConstraint; import com.facebook.presto.spi.function.OperatorDependency; @@ -53,7 +54,6 @@ import com.facebook.presto.spi.function.TypeParameterSpecialization; import com.facebook.presto.spi.function.aggregation.AggregationMetadata; import com.facebook.presto.type.Constraint; -import com.facebook.presto.type.LiteralParameter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForScalars.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForScalars.java index 80ab4502de2e6..787a2a5f8da27 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForScalars.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForScalars.java @@ -30,13 +30,13 @@ import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.FunctionKind; import com.facebook.presto.spi.function.IsNull; +import com.facebook.presto.spi.function.LiteralParameter; import 
com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; -import com.facebook.presto.type.LiteralParameter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java index 053635444a321..6ed8a5dacb20b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java @@ -51,7 +51,7 @@ public void testParseFunctionDefinition() new ArrayType(BIGINT).getTypeSignature(), ImmutableList.of(INTEGER.getTypeSignature())); - List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunction.class); + List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunction.class, JAVA_BUILTIN_NAMESPACE); assertEquals(functions.size(), 1); SqlInvokedFunction f = functions.get(0); @@ -75,7 +75,7 @@ public void testParseFunctionDefinitionWithTypeParameter() ImmutableList.of(new TypeSignature("T")), false); - List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunctionWithTypeParameter.class); + List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunctionWithTypeParameter.class, JAVA_BUILTIN_NAMESPACE); assertEquals(functions.size(), 1); 
SqlInvokedFunction f = functions.get(0); diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestDriver.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestDriver.java index f1b1256596292..b7ec5bafd5e73 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestDriver.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestDriver.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.Page; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; @@ -48,7 +49,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.Duration; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestDriverStats.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestDriverStats.java index 6750dc4d30170..e1ea49bbb6301 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestDriverStats.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestDriverStats.java @@ -14,10 +14,10 @@ package com.facebook.presto.operator; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.Lifespan; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import org.joda.time.DateTime; import org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestDynamicFilterSourceOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestDynamicFilterSourceOperator.java index 
e727c35ae4bbf..73230f934e9f9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestDynamicFilterSourceOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestDynamicFilterSourceOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; @@ -24,7 +25,6 @@ import com.facebook.presto.testing.MaterializedResult; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -37,6 +37,7 @@ import java.util.stream.IntStream; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createBlockOfReals; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; @@ -67,7 +68,6 @@ import static com.google.common.base.Strings.repeat; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.Slices.utf8Slice; -import static io.airlift.units.DataSize.Unit.KILOBYTE; import static java.lang.Float.floatToRawIntBits; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestExchangeClientConfig.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestExchangeClientConfig.java index d56eafb5f84f3..bf5c5a3867e34 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestExchangeClientConfig.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/operator/TestExchangeClientConfig.java @@ -13,9 +13,9 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; @@ -24,9 +24,9 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.DataSize.Unit; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.succinctDataSize; +import static com.facebook.airlift.units.DataSize.Unit; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctDataSize; public class TestExchangeClientConfig { diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestFileFragmentResultCacheConfig.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestFileFragmentResultCacheConfig.java index d93de79164d2c..8b422ea071779 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestFileFragmentResultCacheConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestFileFragmentResultCacheConfig.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.CompressionCodec; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.net.URI; @@ -25,8 +25,8 @@ import static 
com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.TimeUnit.DAYS; public class TestFileFragmentResultCacheConfig diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestFileFragmentResultCacheManager.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestFileFragmentResultCacheManager.java index 8719124d0d6d4..6367ce3586782 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestFileFragmentResultCacheManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestFileFragmentResultCacheManager.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.TestingBlockEncodingSerde; import com.facebook.presto.metadata.Split; @@ -23,7 +24,6 @@ import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.schedule.NodeSelectionStrategy; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestFilterAndProjectOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestFilterAndProjectOperator.java index 94ac9812c1bf8..b1819fa39fba9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestFilterAndProjectOperator.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/operator/TestFilterAndProjectOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.MetadataManager; @@ -23,7 +24,6 @@ import com.facebook.presto.sql.gen.PageFunctionCompiler; import com.facebook.presto.testing.MaterializedResult; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -35,6 +35,8 @@ import java.util.function.Supplier; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.function.OperatorType.ADD; @@ -50,8 +52,6 @@ import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.sql.relational.Expressions.field; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.KILOBYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java index 6265b47e6e9a3..ad2d90829f0b9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashAggregationOperator.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.common.Page; @@ -41,8 +43,6 @@ import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slices; -import io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; @@ -59,6 +59,10 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.testing.Assertions.assertEqualsIgnoreOrder; import static com.facebook.airlift.testing.Assertions.assertGreaterThan; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctBytes; +import static com.facebook.airlift.units.DataSize.succinctDataSize; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createLongRepeatBlock; @@ -85,10 +89,6 @@ import static com.google.common.util.concurrent.Futures.immediateFailedFuture; import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE; import static io.airlift.slice.SizeOf.SIZE_OF_LONG; -import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.succinctBytes; -import static io.airlift.units.DataSize.succinctDataSize; import static java.lang.String.format; import static java.util.Collections.emptyIterator; import 
static java.util.concurrent.Executors.newCachedThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java index 3bc91c526211d..60cb1d17734f2 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashAggregationOperatorInSegmentedAggregationMode.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.LongArrayBlock; @@ -26,7 +27,6 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -37,14 +37,14 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.operator.OperatorAssertion.assertPagesEqualIgnoreOrder; import static com.facebook.presto.operator.OperatorAssertion.toPages; import static com.facebook.presto.operator.aggregation.GenericAccumulatorFactory.generateAccumulatorFactory; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; -import static io.airlift.units.DataSize.Unit.MEGABYTE; 
-import static io.airlift.units.DataSize.succinctBytes; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java index 7da0680696695..01ce32dc6c459 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashJoinOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.Session; @@ -47,7 +48,6 @@ import com.google.common.collect.Iterators; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; @@ -73,6 +73,8 @@ import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.testing.Assertions.assertEqualsIgnoreOrder; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_MEMORY_PER_NODE; @@ -93,8 +95,6 @@ import static com.google.common.collect.Iterators.unmodifiableIterator; import static com.google.common.util.concurrent.Futures.immediateFailedFuture; import static 
com.google.common.util.concurrent.Futures.immediateFuture; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.Collections.nCopies; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashSemiJoinOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashSemiJoinOperator.java index a2a754ea3ae1b..39618409990bf 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashSemiJoinOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestHashSemiJoinOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.common.Page; @@ -24,7 +25,6 @@ import com.facebook.presto.testing.MaterializedResult; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; @@ -38,6 +38,7 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.testing.Assertions.assertGreaterThan; import static com.facebook.airlift.testing.Assertions.assertGreaterThanOrEqual; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -49,7 +50,6 @@ import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static 
com.google.common.collect.Iterables.concat; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestOperatorAssertion.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestOperatorAssertion.java index 97e0b10a6ffcc..07e72e31c416b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestOperatorAssertion.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestOperatorAssertion.java @@ -13,12 +13,12 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.Page; import com.facebook.presto.testing.assertions.Assert; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestOperatorStats.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestOperatorStats.java index b2a4f01dc5a55..20f09bc357b4a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestOperatorStats.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestOperatorStats.java @@ -14,13 +14,13 @@ package com.facebook.presto.operator; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeMetric; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.operator.repartition.PartitionedOutputInfo; import com.facebook.presto.spi.plan.PlanNodeId; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Arrays; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestOrderByOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestOrderByOperator.java index 74cd44555d1c6..de055caf8c547 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestOrderByOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestOrderByOperator.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.common.Page; import com.facebook.presto.operator.OrderByOperator.OrderByOperatorFactory; @@ -21,8 +23,6 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; @@ -35,6 +35,7 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.testing.Assertions.assertGreaterThan; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; @@ -47,7 +48,6 @@ import static com.facebook.presto.operator.OperatorAssertion.toPages; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; -import static 
io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestPartitionedOutputOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestPartitionedOutputOperator.java index 8110b4139cb61..089b204f1b55d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestPartitionedOutputOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestPartitionedOutputOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.CompressionCodec; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; @@ -30,7 +31,6 @@ import com.facebook.presto.sql.planner.OutputPartitioning; import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.List; @@ -42,6 +42,9 @@ import java.util.function.Function; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createLongDictionaryBlock; import static com.facebook.presto.block.BlockAssertions.createLongSequenceBlock; @@ -52,9 +55,6 @@ import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.PARTITIONED; import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static 
com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestScanFilterAndProjectOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestScanFilterAndProjectOperator.java index cbce577e9e919..56221b726fed9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestScanFilterAndProjectOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestScanFilterAndProjectOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.SequencePageBuilder; import com.facebook.presto.block.BlockAssertions; import com.facebook.presto.common.Page; @@ -48,7 +49,6 @@ import com.facebook.presto.testing.TestingTransactionHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterators; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Iterator; @@ -59,6 +59,8 @@ import java.util.function.Supplier; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.toValues; @@ -77,8 +79,6 @@ import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; import static 
com.facebook.presto.testing.assertions.Assert.assertEquals; import static com.google.common.base.Preconditions.checkState; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.KILOBYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestTableWriterOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTableWriterOperator.java index 780ba4f6d9d15..828862139e086 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestTableWriterOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTableWriterOperator.java @@ -21,10 +21,8 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.execution.Lifespan; import com.facebook.presto.execution.TaskId; -import com.facebook.presto.execution.TaskMetadataContext; import com.facebook.presto.execution.scheduler.ExecutionWriterTarget.CreateHandle; import com.facebook.presto.memory.context.MemoryTrackingContext; -import com.facebook.presto.metadata.ConnectorMetadataUpdaterManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.OutputTableHandle; import com.facebook.presto.operator.AggregationOperator.AggregationOperatorFactory; @@ -214,7 +212,6 @@ public void testStatisticsAggregation() .build(); TaskContext taskContext = createTaskContext(executor, scheduledExecutor, session); DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); - TaskMetadataContext taskMetadataContext = taskContext.getTaskMetadataContext(); FunctionAndTypeManager functionAndTypeManager = createTestMetadataManager().getFunctionAndTypeManager(); JavaAggregationFunctionImplementation longMaxFunction = 
functionAndTypeManager.getJavaAggregateFunctionImplementation( functionAndTypeManager.lookupFunction("max", fromTypes(BIGINT))); @@ -228,7 +225,6 @@ public void testStatisticsAggregation() true), outputTypes, session, - taskMetadataContext, driverContext); operator.addInput(rowPagesBuilder(BIGINT).row(42).build().get(0)); @@ -309,8 +305,7 @@ private Operator createTableWriterOperator(PageSinkManager pageSinkManager, Oper { TaskContext taskContext = createTaskContext(executor, scheduledExecutor, session); DriverContext driverContext = taskContext.addPipelineContext(0, true, true, false).addDriverContext(); - TaskMetadataContext taskMetadataContext = taskContext.getTaskMetadataContext(); - return createTableWriterOperator(pageSinkManager, statisticsAggregation, outputTypes, session, taskMetadataContext, driverContext); + return createTableWriterOperator(pageSinkManager, statisticsAggregation, outputTypes, session, driverContext); } private Operator createTableWriterOperator( @@ -318,7 +313,6 @@ private Operator createTableWriterOperator( OperatorFactory statisticsAggregation, List outputTypes, Session session, - TaskMetadataContext taskMetadataContext, DriverContext driverContext) { List notNullColumnNames = new ArrayList<>(1); @@ -327,8 +321,6 @@ private Operator createTableWriterOperator( 0, new PlanNodeId("test"), pageSinkManager, - new ConnectorMetadataUpdaterManager(), - taskMetadataContext, new CreateHandle(new OutputTableHandle( CONNECTOR_ID, new ConnectorTransactionHandle() {}, diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestTaskContextRuntimeStats.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTaskContextRuntimeStats.java new file mode 100644 index 0000000000000..14bb790060592 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTaskContextRuntimeStats.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator; + +import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.presto.Session; +import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.execution.StageExecutionId; +import com.facebook.presto.execution.StageId; +import com.facebook.presto.execution.TaskId; +import com.facebook.presto.execution.TaskStateMachine; +import com.facebook.presto.memory.QueryContext; +import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.memory.MemoryPoolId; +import com.facebook.presto.spiller.SpillSpaceTracker; +import com.google.common.util.concurrent.MoreExecutors; +import org.testng.annotations.AfterClass; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.concurrent.ScheduledExecutorService; + +import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.json.JsonCodec.listJsonCodec; +import static com.facebook.airlift.units.DataSize.succinctBytes; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.util.concurrent.Executors.newScheduledThreadPool; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestTaskContextRuntimeStats +{ + private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed("test-%s")); + + @AfterClass(alwaysRun = true) + public void tearDown() + { 
+ scheduledExecutor.shutdownNow(); + } + + @Test + public void testTaskStatsIncludesCreateTimeAndEndTime() + { + Session session = testSessionBuilder().build(); + QueryContext queryContext = createQueryContext(session); + + TaskStateMachine taskStateMachine = new TaskStateMachine( + new TaskId(new StageExecutionId(new StageId(new QueryId("test_query"), 0), 0), 0, 0), + MoreExecutors.directExecutor()); + + TaskContext taskContext = queryContext.addTaskContext( + taskStateMachine, + session, + Optional.empty(), + false, + false, + false, + false, + false); + + long createTimeBeforeStats = taskStateMachine.getCreatedTimeInMillis(); + + // Get task stats + TaskStats taskStats = taskContext.getTaskStats(); + + // Verify RuntimeStats contains createTime + RuntimeStats runtimeStats = taskStats.getRuntimeStats(); + assertNotNull(runtimeStats, "RuntimeStats should not be null"); + assertTrue(runtimeStats.getMetrics().containsKey("createTime"), "RuntimeStats should contain createTime metric"); + + // Verify createTime value is reasonable + long createTimeFromStats = (long) runtimeStats.getMetric("createTime").getSum(); + assertEquals(createTimeFromStats, createTimeBeforeStats, "createTime should match task creation time"); + + // Mark task as finished to trigger endTime + taskStateMachine.finished(); + TaskStats finalTaskStats = taskContext.getTaskStats(); + RuntimeStats finalRuntimeStats = finalTaskStats.getRuntimeStats(); + + // Verify endTime is now present + assertTrue(finalRuntimeStats.getMetrics().containsKey("endTime"), "RuntimeStats should contain endTime metric after task finishes"); + long endTimeFromStats = (long) finalRuntimeStats.getMetric("endTime").getSum(); + assertTrue(endTimeFromStats > 0, "endTime should be greater than 0"); + assertTrue(endTimeFromStats >= createTimeFromStats, "endTime should be >= createTime"); + } + + @Test + public void testTaskStatsRuntimeStatsNotNullBeforeTaskFinish() + { + Session session = testSessionBuilder().build(); + 
QueryContext queryContext = createQueryContext(session); + + TaskStateMachine taskStateMachine = new TaskStateMachine( + new TaskId(new StageExecutionId(new StageId(new QueryId("test_query_2"), 0), 0), 0, 0), + MoreExecutors.directExecutor()); + + TaskContext taskContext = queryContext.addTaskContext( + taskStateMachine, + session, + Optional.empty(), + false, + false, + false, + false, + false); + + // Get stats before task finishes + TaskStats taskStats = taskContext.getTaskStats(); + RuntimeStats runtimeStats = taskStats.getRuntimeStats(); + + // Verify RuntimeStats is not null and contains createTime even before task finishes + assertNotNull(runtimeStats, "RuntimeStats should not be null"); + assertTrue(runtimeStats.getMetrics().containsKey("createTime"), "RuntimeStats should contain createTime even before task finishes"); + + // endTime should not be present yet (or be 0) + if (runtimeStats.getMetrics().containsKey("endTime")) { + long endTime = (long) runtimeStats.getMetric("endTime").getSum(); + assertEquals(endTime, 0L, "endTime should be 0 before task finishes"); + } + } + + private QueryContext createQueryContext(Session session) + { + return new QueryContext( + session.getQueryId(), + succinctBytes(1 * 1024 * 1024), + succinctBytes(1 * 1024 * 1024 * 1024), + succinctBytes(1 * 1024 * 1024 * 1024), + succinctBytes(1 * 1024 * 1024 * 1024), + new TestingMemoryPool(succinctBytes(1 * 1024 * 1024 * 1024)), + new TestingGcMonitor(), + MoreExecutors.directExecutor(), + scheduledExecutor, + succinctBytes(1 * 1024 * 1024 * 1024), + new SpillSpaceTracker(succinctBytes(1 * 1024 * 1024 * 1024)), + listJsonCodec(TaskMemoryReservationSummary.class)); + } + + private static class TestingMemoryPool + extends com.facebook.presto.memory.MemoryPool + { + public TestingMemoryPool(com.facebook.airlift.units.DataSize maxMemory) + { + super(new MemoryPoolId("test"), maxMemory); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestTaskStats.java 
b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTaskStats.java index 775ca3bd9c91c..dfa4328ce3b58 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestTaskStats.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTaskStats.java @@ -44,6 +44,14 @@ public class TestTaskStats 29L, 24, 10, + 6, + 7, + 8, + 10, + 6, + 7, + 8, + 10, 11.0, 43.0, diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestTopNOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTopNOperator.java index eff0c91c22515..0eeca35157918 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestTopNOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTopNOperator.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.common.Page; import com.facebook.presto.operator.TopNOperator.TopNOperatorFactory; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.testing.MaterializedResult; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; @@ -29,6 +29,7 @@ import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; @@ -39,7 +40,6 @@ import static com.facebook.presto.operator.OperatorAssertion.assertOperatorEquals; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static 
com.facebook.presto.testing.TestingTaskContext.createTaskContext; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertFalse; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestTopNRowNumberOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTopNRowNumberOperator.java index 73c0e593a5dc8..ee3ce16a66a0e 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestTopNRowNumberOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestTopNRowNumberOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.RowPagesBuilder; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.SortOrder; @@ -24,7 +25,6 @@ import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestWindowOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestWindowOperator.java index 1133a0e21a144..b5f9ebbe5c8da 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestWindowOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestWindowOperator.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; import com.facebook.presto.ExceededMemoryLimitException; import com.facebook.presto.common.Page; import 
com.facebook.presto.common.block.SortOrder; @@ -33,8 +35,6 @@ import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; import com.google.common.primitives.Ints; -import io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.DataProvider; @@ -47,6 +47,7 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.testing.Assertions.assertGreaterThan; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.RowPagesBuilder.rowPagesBuilder; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -63,7 +64,6 @@ import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/TestNumericHistogram.java b/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/TestNumericHistogram.java index 781870da0e81f..c3cbb18c0737e 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/TestNumericHistogram.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/TestNumericHistogram.java @@ -133,4 +133,21 @@ public void testMergeDifferent() histogram1.mergeWith(histogram2); assertEquals(histogram1.getBuckets(), expected.getBuckets()); } + + @Test + public void testNaN() + { + NumericHistogram 
histogram = new NumericHistogram(2, 100); + + histogram.add(Double.NaN, 1); + histogram.add(2, 1); + histogram.add(Double.NaN, 1); + + Map expected = ImmutableMap.builder() + .put(Double.NaN, 2.0) + .put(2.0, 1.0) + .build(); + + assertEquals(histogram.getBuckets(), expected); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisySumGaussianLongAggregation.java b/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisySumGaussianLongAggregation.java index 5f55aa0418802..b62ec0f423657 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisySumGaussianLongAggregation.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisySumGaussianLongAggregation.java @@ -24,6 +24,8 @@ import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; +import com.facebook.presto.util.RetryAnalyzer; +import com.facebook.presto.util.RetryCount; import org.testng.annotations.Test; import java.util.Arrays; @@ -422,7 +424,8 @@ public void testNoisySumGaussianLongClippingSomeNoiseScale() expected); } - @Test + @Test(retryAnalyzer = RetryAnalyzer.class) + @RetryCount(100) public void testNoisySumGaussianLongClippingSomeNoiseScaleWithinSomeStd() { JavaAggregationFunctionImplementation noisySumGaussian = getFunction(BIGINT, DOUBLE, DOUBLE, DOUBLE); diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisySumGaussianLongDecimalAggregation.java b/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisySumGaussianLongDecimalAggregation.java index 6e631753b08fa..0a9fa210f2ef6 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisySumGaussianLongDecimalAggregation.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/aggregation/noisyaggregation/TestNoisySumGaussianLongDecimalAggregation.java @@ -25,6 +25,8 @@ import com.facebook.presto.testing.LocalQueryRunner; import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; +import com.facebook.presto.util.RetryAnalyzer; +import com.facebook.presto.util.RetryCount; import org.testng.annotations.Test; import java.util.Arrays; @@ -281,7 +283,8 @@ public void testNoisySumGaussianLongDecimalClippingSomeNoiseScale() expected); } - @Test + @Test(retryAnalyzer = RetryAnalyzer.class) + @RetryCount(100) public void testNoisySumGaussianLongDecimalClippingSomeNoiseScaleWithinSomeStd() { JavaAggregationFunctionImplementation noisySumGaussian = getFunction(LONG_DECIMAL_TYPE, DOUBLE, DOUBLE, DOUBLE); diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/exchange/TestLocalExchange.java b/presto-main-base/src/test/java/com/facebook/presto/operator/exchange/TestLocalExchange.java index b363940ed0b33..465a374eca960 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/exchange/TestLocalExchange.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/exchange/TestLocalExchange.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.exchange; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.SequencePageBuilder; import com.facebook.presto.Session; import com.facebook.presto.common.Page; @@ -38,7 +39,6 @@ import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import 
org.testng.annotations.DataProvider; @@ -51,6 +51,7 @@ import java.util.stream.Stream; import static com.facebook.airlift.testing.Assertions.assertContains; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.operator.PipelineExecutionStrategy.GROUPED_EXECUTION; import static com.facebook.presto.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION; @@ -64,7 +65,6 @@ import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.airlift.units.DataSize.Unit.BYTE; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotNull; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/project/TestPageProcessor.java b/presto-main-base/src/test/java/com/facebook/presto/operator/project/TestPageProcessor.java index b4bf55690bcd8..f90e7d9d92671 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/project/TestPageProcessor.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/project/TestPageProcessor.java @@ -14,6 +14,7 @@ package com.facebook.presto.operator.project; import com.facebook.airlift.testing.TestingTicker; +import com.facebook.airlift.units.Duration; import com.facebook.presto.block.BlockAssertions; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; @@ -34,7 +35,6 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.airlift.units.Duration; import org.openjdk.jol.info.ClassLayout; import org.testng.annotations.Test; diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/operator/repartition/BenchmarkPartitionedOutputOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/repartition/BenchmarkPartitionedOutputOperator.java index f7a268f4a8170..62780860fc640 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/repartition/BenchmarkPartitionedOutputOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/repartition/BenchmarkPartitionedOutputOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.repartition; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.CompressionCodec; import com.facebook.presto.Session; import com.facebook.presto.common.Page; @@ -40,7 +41,6 @@ import com.facebook.presto.sql.planner.OutputPartitioning; import com.facebook.presto.testing.TestingTaskContext; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; @@ -70,6 +70,9 @@ import java.util.stream.IntStream; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.block.BlockAssertions.createMapType; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; @@ -90,9 +93,6 @@ import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SystemPartitionFunction.HASH; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; 
-import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Collections.nCopies; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/repartition/TestOptimizedPartitionedOutputOperator.java b/presto-main-base/src/test/java/com/facebook/presto/operator/repartition/TestOptimizedPartitionedOutputOperator.java index d2e576d71bdcf..4ed538593ba58 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/repartition/TestOptimizedPartitionedOutputOperator.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/repartition/TestOptimizedPartitionedOutputOperator.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.repartition; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.CompressionCodec; import com.facebook.presto.Session; import com.facebook.presto.common.Page; @@ -48,7 +49,6 @@ import com.google.common.collect.Maps; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.ArrayList; @@ -67,6 +67,10 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.testing.Assertions.assertBetweenInclusive; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.block.BlockAssertions.Encoding.DICTIONARY; import static com.facebook.presto.block.BlockAssertions.Encoding.RUN_LENGTH; import static com.facebook.presto.block.BlockAssertions.createLongDictionaryBlock; @@ -95,10 +99,6 @@ import static 
com.facebook.presto.operator.PageAssertions.updateBlockTypesWithHashBlockAndNullBlock; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java index d31ae0b09eb02..e9b3e8641c885 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/AbstractTestFunctions.java @@ -30,6 +30,7 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.facebook.presto.sql.analyzer.SemanticErrorCode; +import com.facebook.presto.tests.operator.scalar.TestFunctions; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import org.intellij.lang.annotations.Language; @@ -52,6 +53,7 @@ import static org.testng.Assert.fail; public abstract class AbstractTestFunctions + implements TestFunctions { private static final double DELTA = 1e-5; @@ -59,6 +61,7 @@ public abstract class AbstractTestFunctions private final FeaturesConfig featuresConfig; private final FunctionsConfig functionsConfig; protected FunctionAssertions functionAssertions; + private final boolean loadInlinedSqlInvokedFunctionsPlugin; protected AbstractTestFunctions() { @@ -81,18 +84,23 @@ protected 
AbstractTestFunctions(FunctionsConfig functionsConfig) } protected AbstractTestFunctions(Session session, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig) + { + this(session, featuresConfig, functionsConfig, true); + } + protected AbstractTestFunctions(Session session, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, boolean loadInlinedSqlInvokedFunctionsPlugin) { this.session = requireNonNull(session, "session is null"); this.featuresConfig = requireNonNull(featuresConfig, "featuresConfig is null"); this.functionsConfig = requireNonNull(functionsConfig, "config is null") .setLegacyLogFunction(true) .setUseNewNanDefinition(true); + this.loadInlinedSqlInvokedFunctionsPlugin = loadInlinedSqlInvokedFunctionsPlugin; } @BeforeClass public final void initTestFunctions() { - functionAssertions = new FunctionAssertions(session, featuresConfig, functionsConfig, false); + functionAssertions = new FunctionAssertions(session, featuresConfig, functionsConfig, false, loadInlinedSqlInvokedFunctionsPlugin); } @AfterClass(alwaysRun = true) @@ -107,7 +115,8 @@ public FunctionAndTypeManager getFunctionAndTypeManager() return functionAssertions.getFunctionAndTypeManager(); } - protected void assertFunction(String projection, Type expectedType, Object expected) + @Override + public void assertFunction(String projection, Type expectedType, Object expected) { functionAssertions.assertFunction(projection, expectedType, expected); } @@ -209,7 +218,8 @@ public void assertCachedInstanceHasBoundedRetainedSize(String projection) functionAssertions.assertCachedInstanceHasBoundedRetainedSize(projection); } - protected void assertNotSupported(String projection, String message) + @Override + public void assertNotSupported(String projection, String message) { try { functionAssertions.executeProjectionWithFullEngine(projection); diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/BenchmarkDateTimeFunctions.java 
b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/BenchmarkDateTimeFunctions.java new file mode 100644 index 0000000000000..1b84f57715c23 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/BenchmarkDateTimeFunctions.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar; + +import com.facebook.airlift.units.Duration; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Fork(value = 1, jvmArgs = {"-Xms2G", "-Xmx2G"}) +@Warmup(iterations = 5, time = 1) +@Measurement(iterations = 
10, time = 1) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkDateTimeFunctions +{ + @Param({"ns", "us", "ms", "s", "m", "h", "d"}) + private String unit = "ns"; + + private Random random = new Random(); + + @Setup + public void setup() + { + random = new Random(42); // Fixed seed for reproducibility + } + + @Benchmark + public void testBaseline(Blackhole bh) + { + int v1 = random.nextInt(10000); + int v2 = random.nextInt(10000); + Slice value = Slices.utf8Slice(v1 + "." + v2 + " " + unit); + bh.consume(value.toStringUtf8()); + } + + @Benchmark + public void testUseBigDecimal(Blackhole bh) + { + int v1 = random.nextInt(10000); + int v2 = random.nextInt(10000); + Slice value = Slices.utf8Slice(v1 + "." + v2 + " " + unit); + bh.consume(DateTimeFunctions.parseDuration(value)); + } + + @Benchmark + public void testUseDouble(Blackhole bh) + { + int v1 = random.nextInt(10000); + int v2 = random.nextInt(10000); + Slice value = Slices.utf8Slice(v1 + "." + v2 + " " + unit); + bh.consume(Duration.valueOf(value.toStringUtf8()).toMillis()); + } + + public static void main(String[] args) + throws RunnerException + { + Options opt = new OptionsBuilder() + .include(BenchmarkDateTimeFunctions.class.getSimpleName()) + .build(); + + new Runner(opt).run(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java index c12cf7e62fa88..d46bb9a0d0bd2 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/FunctionAssertions.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.common.InvalidTypeDefinitionException; import com.facebook.presto.common.Page; @@ -40,6 +41,7 @@ import 
com.facebook.presto.operator.project.CursorProcessor; import com.facebook.presto.operator.project.PageProcessor; import com.facebook.presto.operator.project.PageProjectionWithOutputs; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorPageSource; @@ -91,7 +93,6 @@ import com.google.common.util.concurrent.UncheckedExecutionException; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.airlift.units.DataSize; import org.intellij.lang.annotations.Language; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; @@ -115,6 +116,7 @@ import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.testing.Assertions.assertInstanceOf; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.block.BlockAssertions.createBooleansBlock; import static com.facebook.presto.block.BlockAssertions.createDoublesBlock; @@ -149,7 +151,6 @@ import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; import static com.facebook.presto.util.Failures.toFailure; import static io.airlift.slice.SizeOf.sizeOf; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.String.format; import static java.util.Collections.emptyMap; import static java.util.Objects.requireNonNull; @@ -221,20 +222,25 @@ public FunctionAssertions() public FunctionAssertions(Session session) { - this(session, new FeaturesConfig(), new FunctionsConfig(), false); + this(session, new FeaturesConfig(), new FunctionsConfig(), false, true); } public FunctionAssertions(Session session, FeaturesConfig featuresConfig) { - this(session, featuresConfig, new FunctionsConfig(), false); + this(session, featuresConfig, new FunctionsConfig(), false, true); } public 
FunctionAssertions(Session session, FunctionsConfig functionsConfig) { - this(session, new FeaturesConfig(), functionsConfig, false); + this(session, new FeaturesConfig(), functionsConfig, false, true); } public FunctionAssertions(Session session, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, boolean refreshSession) + { + this(session, featuresConfig, functionsConfig, refreshSession, true); + } + + public FunctionAssertions(Session session, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, boolean refreshSession, boolean loadInlinedSqlInvokedFunctionsPlugin) { requireNonNull(session, "session is null"); runner = new LocalQueryRunner(session, featuresConfig, functionsConfig); @@ -244,6 +250,9 @@ public FunctionAssertions(Session session, FeaturesConfig featuresConfig, Functi else { this.session = session; } + if (loadInlinedSqlInvokedFunctionsPlugin) { + runner.installPlugin(new SqlInvokedFunctionsPlugin()); + } metadata = runner.getMetadata(); compiler = runner.getExpressionCompiler(); } @@ -269,6 +278,12 @@ public FunctionAssertions addScalarFunctions(Class clazz) return this; } + public FunctionAssertions addConnectorFunctions(List functionInfos, String namespace) + { + metadata.registerConnectorFunctions(namespace, functionInfos); + return this; + } + public void assertFunction(String projection, Type expectedType, Object expected) { if (expected instanceof Slice) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java index 96c2db4dabc30..61ab2058d9011 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArrayExceptFunction.java @@ -14,35 +14,21 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.common.type.ArrayType; -import 
com.facebook.presto.common.type.RowType; +import com.facebook.presto.tests.operator.scalar.AbstractTestArrayExcept; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; -import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.common.type.BooleanType.BOOLEAN; -import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.common.type.VarcharType.createVarcharType; -import static java.util.Arrays.asList; import static java.util.Collections.singletonList; public class TestArrayExceptFunction extends AbstractTestFunctions + implements AbstractTestArrayExcept { @Test - public void testBasic() - { - assertFunction("array_except(ARRAY[1, 5, 3], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); - assertFunction("array_except(ARRAY[CAST(1 as BIGINT), 5, 3], ARRAY[5])", new ArrayType(BIGINT), ImmutableList.of(1L, 3L)); - assertFunction("array_except(ARRAY[CAST('x' as VARCHAR), 'y', 'z'], ARRAY['x'])", new ArrayType(VARCHAR), ImmutableList.of("y", "z")); - assertFunction("array_except(ARRAY[true, false, null], ARRAY[true])", new ArrayType(BOOLEAN), asList(false, null)); - assertFunction("array_except(ARRAY[1.1E0, 5.4E0, 3.9E0], ARRAY[5, 5.4E0])", new ArrayType(DOUBLE), ImmutableList.of(1.1, 3.9)); - } - - @Test - public void testEmpty() + void testEmpty() { assertFunction("array_except(ARRAY[], ARRAY[])", new ArrayType(UNKNOWN), ImmutableList.of()); assertFunction("array_except(ARRAY[], ARRAY[1, 3])", new ArrayType(INTEGER), ImmutableList.of()); @@ -59,40 +45,4 @@ public void testNull() assertFunction("array_except(ARRAY[], ARRAY[NULL])", new ArrayType(UNKNOWN), ImmutableList.of()); assertFunction("array_except(ARRAY[NULL], ARRAY[])", new ArrayType(UNKNOWN), 
singletonList(null)); } - - @Test - public void testDuplicates() - { - assertFunction("array_except(ARRAY[1, 5, 3, 5, 1], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); - assertFunction("array_except(ARRAY[CAST(1 as BIGINT), 5, 5, 3, 3, 3, 1], ARRAY[3, 5])", new ArrayType(BIGINT), ImmutableList.of(1L)); - assertFunction("array_except(ARRAY[CAST('x' as VARCHAR), 'x', 'y', 'z'], ARRAY['x', 'y', 'x'])", new ArrayType(VARCHAR), ImmutableList.of("z")); - assertFunction("array_except(ARRAY[true, false, null, true, false, null], ARRAY[true, true, true])", new ArrayType(BOOLEAN), asList(false, null)); - } - - @Test - public void testIndeterminateRows() - { - // test unsupported - assertFunction( - "array_except(ARRAY[(123, 'abc'), (123, NULL)], ARRAY[(123, 'abc'), (123, NULL)])", - new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), - ImmutableList.of()); - assertFunction( - "array_except(ARRAY[(NULL, 'abc'), (123, null), (123, 'abc')], ARRAY[(456, 'def'),(NULL, 'abc')])", - new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), - ImmutableList.of(asList(123, null), asList(123, "abc"))); - } - - @Test - public void testIndeterminateArrays() - { - assertFunction( - "array_except(ARRAY[ARRAY[123, 456], ARRAY[123, NULL]], ARRAY[ARRAY[123, 456], ARRAY[123, NULL]])", - new ArrayType(new ArrayType(INTEGER)), - ImmutableList.of()); - assertFunction( - "array_except(ARRAY[ARRAY[NULL, 456], ARRAY[123, null], ARRAY[123, 456]], ARRAY[ARRAY[456, 456],ARRAY[NULL, 456]])", - new ArrayType(new ArrayType(INTEGER)), - ImmutableList.of(asList(123, null), asList(123, 456))); - } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java index 4192b410c0639..cc773a09a351b 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArrayNormalizeFunction.java @@ -22,6 +22,7 @@ import static com.facebook.presto.common.type.RealType.REAL; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; public class TestArrayNormalizeFunction @@ -93,7 +94,7 @@ public void testUnsupportedType() "Unsupported array element type for array_normalize function: integer"); assertInvalidFunction( "array_normalize(ARRAY['a', 'b', 'c'], 'd')", - FUNCTION_IMPLEMENTATION_MISSING, + NOT_SUPPORTED, "Unsupported type parameters.*"); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArraySortByKeyFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArraySortByKeyFunctions.java new file mode 100644 index 0000000000000..dee4f22e42a96 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArraySortByKeyFunctions.java @@ -0,0 +1,309 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.RowType; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.UnknownType.UNKNOWN; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static java.util.Arrays.asList; + +public class TestArraySortByKeyFunctions + extends AbstractTestFunctions +{ + @Test + public void testBasic() + { + // Test array_sort + assertFunction( + "array_sort(ARRAY['pear', 'apple', 'banana', 'kiwi'], x -> length(x))", + new ArrayType(createVarcharType(6)), + asList("pear", "kiwi", "apple", "banana")); + + assertFunction( + "array_sort(ARRAY['pear', 'apple', 'banana', 'kiwi'], x -> substr(x, length(x), 1))", + new ArrayType(createVarcharType(6)), + asList("banana", "apple", "kiwi", "pear")); + + // Test array_sort_desc + assertFunction( + "array_sort_desc(ARRAY['pear', 'apple', 'banana', 'kiwi'], x -> length(x))", + new ArrayType(createVarcharType(6)), + asList("banana", "apple", "pear", "kiwi")); + + assertFunction( + "array_sort_desc(ARRAY['pear', 'apple', 'banana', 'kiwi'], x -> substr(x, length(x), 1))", + new ArrayType(createVarcharType(6)), + asList("pear", "kiwi", "apple", "banana")); + } + + @Test + public void testNulls() + { + // Test array_sort + assertFunction( + "array_sort(ARRAY['apple', NULL, 'banana', NULL], x -> length(x))", + new ArrayType(createVarcharType(6)), + asList("apple", "banana", null, null)); + + assertFunction( + "array_sort(ARRAY['apple', 'banana', 'pear'], x -> IF(x = 'banana', NULL, length(x)))", + new ArrayType(createVarcharType(6)), + 
asList("pear", "apple", "banana")); + + assertFunction( + "array_sort(ARRAY['apple', NULL, 'banana', 'pear', NULL], x -> length(x))", + new ArrayType(createVarcharType(6)), + asList("pear", "apple", "banana", null, null)); + + // Test array_sort_desc + assertFunction( + "array_sort_desc(ARRAY['apple', NULL, 'banana', NULL], x -> length(x))", + new ArrayType(createVarcharType(6)), + asList("banana", "apple", null, null)); + + assertFunction( + "array_sort_desc(ARRAY['apple', 'banana', 'pear'], x -> IF(x = 'banana', NULL, length(x)))", + new ArrayType(createVarcharType(6)), + asList("apple", "pear", "banana")); + + assertFunction( + "array_sort_desc(ARRAY['apple', NULL, 'banana', 'pear', NULL], x -> length(x))", + new ArrayType(createVarcharType(6)), + asList("banana", "apple", "pear", null, null)); + } + + @Test + public void testSpecialDoubleValues() + { + // Test array_sort + assertFunction( + "array_sort(ARRAY[CAST(0.0 AS DOUBLE), CAST('NaN' AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE)], x -> x)", + new ArrayType(DOUBLE), + asList(Double.NEGATIVE_INFINITY, 0.0, Double.POSITIVE_INFINITY, Double.NaN)); + + // Test array_sort_desc + assertFunction( + "array_sort_desc(ARRAY[CAST(0.0 AS DOUBLE), CAST('NaN' AS DOUBLE), CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE)], x -> x)", + new ArrayType(DOUBLE), + asList(Double.NaN, Double.POSITIVE_INFINITY, 0.0, Double.NEGATIVE_INFINITY)); + } + + @Test + public void testNumericKeys() + { + // Test array_sort + assertFunction( + "array_sort(ARRAY[5, 20, 3, 9, 100], x -> x)", + new ArrayType(INTEGER), + asList(3, 5, 9, 20, 100)); + + assertFunction( + "array_sort(ARRAY[CAST(5000000000 AS BIGINT), CAST(20000000000 AS BIGINT), CAST(3000000000 AS BIGINT), CAST(9000000000 AS BIGINT), CAST(100000000000 AS BIGINT)], x -> x)", + new ArrayType(BIGINT), + asList(3000000000L, 5000000000L, 9000000000L, 20000000000L, 100000000000L)); + + assertFunction( + "array_sort(ARRAY[CAST(5.5 AS DOUBLE), CAST(20.1 AS 
DOUBLE), CAST(3.9 AS DOUBLE), CAST(9.0 AS DOUBLE), CAST(100.0 AS DOUBLE)], x -> x)", + new ArrayType(DOUBLE), + asList(3.9, 5.5, 9.0, 20.1, 100.0)); + + assertFunction( + "array_sort(ARRAY[5, 20, 3, 9, 100], x -> x % 10)", + new ArrayType(INTEGER), + asList(20, 100, 3, 5, 9)); + + assertFunction( + "array_sort(ARRAY[CAST(5000000000 AS BIGINT), CAST(20000000000 AS BIGINT), CAST(3000000000 AS BIGINT), CAST(9000000000 AS BIGINT), CAST(100000000000 AS BIGINT)], x -> x % CAST(10000000000 AS BIGINT))", + new ArrayType(BIGINT), + asList(20000000000L, 100000000000L, 3000000000L, 5000000000L, 9000000000L)); + + // Test array_sort_desc + assertFunction( + "array_sort_desc(ARRAY[5, 20, 3, 9, 100], x -> x)", + new ArrayType(INTEGER), + asList(100, 20, 9, 5, 3)); + + assertFunction( + "array_sort_desc(ARRAY[CAST(5000000000 AS BIGINT), CAST(20000000000 AS BIGINT), CAST(3000000000 AS BIGINT), CAST(9000000000 AS BIGINT), CAST(100000000000 AS BIGINT)], x -> x)", + new ArrayType(BIGINT), + asList(100000000000L, 20000000000L, 9000000000L, 5000000000L, 3000000000L)); + + assertFunction( + "array_sort_desc(ARRAY[CAST(5.5 AS DOUBLE), CAST(20.1 AS DOUBLE), CAST(3.9 AS DOUBLE), CAST(9.0 AS DOUBLE), CAST(100.0 AS DOUBLE)], x -> x)", + new ArrayType(DOUBLE), + asList(100.0, 20.1, 9.0, 5.5, 3.9)); + + assertFunction( + "array_sort_desc(ARRAY[5, 20, 3, 9, 100], x -> x % 10)", + new ArrayType(INTEGER), + asList(9, 5, 3, 20, 100)); + + assertFunction( + "array_sort_desc(ARRAY[CAST(5000000000 AS BIGINT), CAST(20000000000 AS BIGINT), CAST(3000000000 AS BIGINT), CAST(9000000000 AS BIGINT), CAST(100000000000 AS BIGINT)], x -> x % CAST(10000000000 AS BIGINT))", + new ArrayType(BIGINT), + asList(9000000000L, 5000000000L, 3000000000L, 20000000000L, 100000000000L)); + } + + @Test + public void testBooleanKeys() + { + // Test array_sort + assertFunction( + "array_sort(ARRAY[true, false, true, false], x -> x)", + new ArrayType(BOOLEAN), + asList(false, false, true, true)); + + assertFunction( + 
"array_sort(ARRAY[true, false, true, false], x -> NOT x)", + new ArrayType(BOOLEAN), + asList(true, true, false, false)); + + // Test array_sort_desc + assertFunction( + "array_sort_desc(ARRAY[true, false, true, false], x -> x)", + new ArrayType(BOOLEAN), + asList(true, true, false, false)); + + assertFunction( + "array_sort_desc(ARRAY[true, false, true, false], x -> NOT x)", + new ArrayType(BOOLEAN), + asList(false, false, true, true)); + } + + @Test + public void testComplexTypes() + { + // Test array_sort + assertFunction( + "array_sort(ARRAY[ARRAY[1, 2, 3], ARRAY[4, 5], ARRAY[6, 7, 8, 9]], x -> cardinality(x))", + new ArrayType(new ArrayType(INTEGER)), + asList(asList(4, 5), asList(1, 2, 3), asList(6, 7, 8, 9))); + + assertFunction( + "array_sort(ARRAY[ROW('a', 3), ROW('b', 1), ROW('c', 2)], x -> x[2])", + new ArrayType(RowType.anonymous(ImmutableList.of(createVarcharType(1), INTEGER))), + asList(asList("b", 1), asList("c", 2), asList("a", 3))); + + assertFunction( + "array_sort(ARRAY[ROW('a', CAST(3000000000 AS BIGINT)), ROW('b', CAST(1000000000 AS BIGINT)), ROW('c', CAST(2000000000 AS BIGINT))], x -> x[2])", + new ArrayType(RowType.anonymous(ImmutableList.of(createVarcharType(1), BIGINT))), + asList(asList("b", 1000000000L), asList("c", 2000000000L), asList("a", 3000000000L))); + + // Test array_sort_desc + assertFunction( + "array_sort_desc(ARRAY[ARRAY[1, 2, 3], ARRAY[4, 5], ARRAY[6, 7, 8, 9]], x -> cardinality(x))", + new ArrayType(new ArrayType(INTEGER)), + asList(asList(6, 7, 8, 9), asList(1, 2, 3), asList(4, 5))); + + assertFunction( + "array_sort_desc(ARRAY[ROW('a', 3), ROW('b', 1), ROW('c', 2)], x -> x[2])", + new ArrayType(RowType.anonymous(ImmutableList.of(createVarcharType(1), INTEGER))), + asList(asList("a", 3), asList("c", 2), asList("b", 1))); + + assertFunction( + "array_sort_desc(ARRAY[ROW('a', CAST(3000000000 AS BIGINT)), ROW('b', CAST(1000000000 AS BIGINT)), ROW('c', CAST(2000000000 AS BIGINT))], x -> x[2])", + new 
ArrayType(RowType.anonymous(ImmutableList.of(createVarcharType(1), BIGINT))), + asList(asList("a", 3000000000L), asList("c", 2000000000L), asList("b", 1000000000L))); + } + + @Test + public void testEdgeCases() + { + // Test array_sort + assertFunction( + "array_sort(ARRAY[], x -> x)", + new ArrayType(UNKNOWN), + ImmutableList.of()); + + assertFunction( + "array_sort(ARRAY[5], x -> x)", + new ArrayType(INTEGER), + asList(5)); + + assertFunction( + "array_sort(ARRAY[NULL, NULL, NULL], x -> x)", + new ArrayType(UNKNOWN), + asList(null, null, null)); + + // Test array_sort_desc + assertFunction( + "array_sort_desc(ARRAY[], x -> x)", + new ArrayType(UNKNOWN), + ImmutableList.of()); + + assertFunction( + "array_sort_desc(ARRAY[5], x -> x)", + new ArrayType(INTEGER), + asList(5)); + + assertFunction( + "array_sort_desc(ARRAY[NULL, NULL, NULL], x -> x)", + new ArrayType(UNKNOWN), + asList(null, null, null)); + } + + @Test + public void testTypeCoercion() + { + // Test array_sort + assertFunction( + "array_sort(ARRAY[5, 20, 3, 9, 100], x -> x + CAST(0.5 AS DOUBLE))", + new ArrayType(INTEGER), + asList(3, 5, 9, 20, 100)); + + assertFunction( + "array_sort(ARRAY[5, 20, 3, 9, 100], x -> x * CAST(1000000000 AS BIGINT))", + new ArrayType(INTEGER), + asList(3, 5, 9, 20, 100)); + + assertFunction( + "array_sort(ARRAY['5', '20', '3', '9', '100'], x -> cast(x as integer))", + new ArrayType(createVarcharType(3)), + asList("3", "5", "9", "20", "100")); + + assertFunction( + "array_sort(ARRAY['5000000000', '20000000000', '3000000000', '9000000000', '100000000000'], x -> cast(x as bigint))", + new ArrayType(createVarcharType(12)), + asList("3000000000", "5000000000", "9000000000", "20000000000", "100000000000")); + + // Test array_sort_desc + assertFunction( + "array_sort_desc(ARRAY[5, 20, 3, 9, 100], x -> x + CAST(0.5 AS DOUBLE))", + new ArrayType(INTEGER), + asList(100, 20, 9, 5, 3)); + + assertFunction( + "array_sort_desc(ARRAY[5, 20, 3, 9, 100], x -> x * CAST(1000000000 AS 
BIGINT))", + new ArrayType(INTEGER), + asList(100, 20, 9, 5, 3)); + + assertFunction( + "array_sort_desc(ARRAY['5', '20', '3', '9', '100'], x -> cast(x as integer))", + new ArrayType(createVarcharType(3)), + asList("100", "20", "9", "5", "3")); + + assertFunction( + "array_sort_desc(ARRAY['5000000000', '20000000000', '3000000000', '9000000000', '100000000000'], x -> cast(x as bigint))", + new ArrayType(createVarcharType(12)), + asList("100000000000", "20000000000", "9000000000", "5000000000", "3000000000")); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArraySortFunction.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArraySortFunction.java index a0df75f03b29d..cdb888a4d8473 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArraySortFunction.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestArraySortFunction.java @@ -13,30 +13,10 @@ */ package com.facebook.presto.operator.scalar; -import com.facebook.presto.common.type.ArrayType; -import com.google.common.collect.ImmutableList; -import org.testng.annotations.Test; - -import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.common.type.IntegerType.INTEGER; -import static com.facebook.presto.common.type.VarcharType.createVarcharType; -import static java.util.Arrays.asList; +import com.facebook.presto.tests.operator.scalar.AbstractTestArraySort; public class TestArraySortFunction extends AbstractTestFunctions + implements AbstractTestArraySort { - @Test - public void testArraySort() - { - assertFunction("array_sort(ARRAY [5, 20, null, 5, 3, 50]) ", new ArrayType(INTEGER), - asList(3, 5, 5, 20, 50, null)); - assertFunction("array_sort(array['x', 'a', 'a', 'a', 'a', 'm', 'j', 'p'])", - new ArrayType(createVarcharType(1)), ImmutableList.of("a", "a", "a", "a", "j", "m", "p", "x")); - assertFunction("array_sort(sequence(-4, 3))", new 
ArrayType(BIGINT), - asList(-4L, -3L, -2L, -1L, 0L, 1L, 2L, 3L)); - assertFunction("array_sort(reverse(sequence(-4, 3)))", new ArrayType(BIGINT), - asList(-4L, -3L, -2L, -1L, 0L, 1L, 2L, 3L)); - assertFunction("repeat(1,4)", new ArrayType(INTEGER), asList(1, 1, 1, 1)); - assertFunction("cast(array[] as array)", new ArrayType(INTEGER), asList()); - } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java index be335dd54c5c9..6d0b4f0f65052 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java @@ -24,6 +24,7 @@ import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; public class TestCustomFunctions extends AbstractTestFunctions @@ -41,7 +42,7 @@ protected TestCustomFunctions(FeaturesConfig config) public void setupClass() { registerScalar(CustomFunctions.class); - List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(CustomFunctions.class); + List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(CustomFunctions.class, JAVA_BUILTIN_NAMESPACE); this.functionAssertions.addFunctions(functions); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctions.java index 1a939a01b3461..cf45e5643f9cb 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctions.java @@ -15,12 +15,10 @@ 
package com.facebook.presto.operator.scalar; import com.facebook.presto.Session; -import com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimestampType; import org.joda.time.DateTime; import org.testng.annotations.Test; -import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.common.type.VarcharType.createVarcharType; @@ -46,31 +44,6 @@ public void testFormatDateCannotImplicitlyAddTimeZoneToTimestampLiteral() "format_datetime for TIMESTAMP type, cannot use 'Z' nor 'z' in format, as this type does not contain TZ information"); } - @Test - public void testLocalTime() - { - Session localSession = Session.builder(session) - .setStartTime(new DateTime(2017, 3, 1, 14, 30, 0, 0, DATE_TIME_ZONE).getMillis()) - .build(); - try (FunctionAssertions localAssertion = new FunctionAssertions(localSession)) { - localAssertion.assertFunctionString("LOCALTIME", TimeType.TIME, "14:30:00.000"); - } - } - - @Test - public void testCurrentTime() - { - Session localSession = Session.builder(session) - // we use Asia/Kathmandu here to test the difference in semantic change of current_time - // between legacy and non-legacy timestamp - .setTimeZoneKey(KATHMANDU_ZONE_KEY) - .setStartTime(new DateTime(2017, 3, 1, 15, 45, 0, 0, KATHMANDU_ZONE).getMillis()) - .build(); - try (FunctionAssertions localAssertion = new FunctionAssertions(localSession)) { - localAssertion.assertFunctionString("CURRENT_TIME", TIME_WITH_TIME_ZONE, "15:45:00.000 Asia/Kathmandu"); - } - } - @Test public void testLocalTimestamp() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctionsBase.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctionsBase.java index bb4066049496d..733108e7d8dcc 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctionsBase.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctionsBase.java @@ -186,6 +186,30 @@ private static long epochDaysInZone(TimeZoneKey timeZoneKey, long instant) return LocalDate.from(Instant.ofEpochMilli(instant).atZone(ZoneId.of(timeZoneKey.getId()))).toEpochDay(); } + @Test + public void testLocalTime() + { + Session localSession = Session.builder(session) + .setStartTime(new DateTime(2017, 3, 1, 14, 30, 0, 0, DATE_TIME_ZONE).getMillis()) + .build(); + try (FunctionAssertions localAssertion = new FunctionAssertions(localSession)) { + localAssertion.assertFunctionString("LOCALTIME", TimeType.TIME, "14:30:00.000"); + } + } + + @Test + public void testCurrentTime() + { + Session localSession = Session.builder(session) + // we use Asia/Kathmandu here, as it has different zone offset on 2017-03-01 and on 1970-01-01 + .setTimeZoneKey(KATHMANDU_ZONE_KEY) + .setStartTime(new DateTime(2017, 3, 1, 15, 45, 0, 0, KATHMANDU_ZONE).getMillis()) + .build(); + try (FunctionAssertions localAssertion = new FunctionAssertions(localSession)) { + localAssertion.assertFunctionString("CURRENT_TIME", TIME_WITH_TIME_ZONE, "15:45:00.000 Asia/Kathmandu"); + } + } + @Test public void testFromUnixTime() { @@ -1226,10 +1250,23 @@ public void testParseDuration() assertFunction("parse_duration('1234.567h')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(51, 10, 34, 1, 200)); assertFunction("parse_duration('1234.567d')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(1234, 13, 36, 28, 800)); + // trailing spaces + assertFunction("parse_duration('1234 ns ')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 0)); + assertFunction("parse_duration('1234 us ')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 0, 1)); + assertFunction("parse_duration('1234ms ')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(0, 0, 0, 1, 234)); + // invalid function calls 
assertInvalidFunction("parse_duration('')", "duration is empty"); assertInvalidFunction("parse_duration('1f')", "Unknown time unit: f"); assertInvalidFunction("parse_duration('abc')", "duration is not a valid data duration string: abc"); + + // long milliseconds edge cases + assertFunction("parse_duration('7702741401940153ms')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(89152099, 13, 25, 40, 153)); + assertFunction("parse_duration('9117756383778565ms')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(105529587, 18, 36, 18, 565)); + + // Test precision for large values with fractional seconds + assertFunction("parse_duration('7702741401940.153s')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(89152099, 13, 25, 40, 153)); + assertFunction("parse_duration('7702741401940.153 s')", INTERVAL_DAY_TIME, new SqlIntervalDayTime(89152099, 13, 25, 40, 153)); } @Test diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctionsLegacy.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctionsLegacy.java index 97527e6b00d10..2293dd0ceb865 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctionsLegacy.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestDateTimeFunctionsLegacy.java @@ -15,12 +15,10 @@ package com.facebook.presto.operator.scalar; import com.facebook.presto.Session; -import com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimestampType; import org.joda.time.DateTime; import org.testng.annotations.Test; -import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createVarcharType; @@ -45,31 +43,6 @@ public void 
testFormatDateCanImplicitlyAddTimeZoneToTimestampLiteral() assertFunction("format_datetime(" + TIMESTAMP_LITERAL + ", 'YYYY/MM/dd HH:mm ZZZZ')", VARCHAR, "2001/08/22 03:04 " + DATE_TIME_ZONE.getID()); } - @Test - public void testLocalTime() - { - Session localSession = Session.builder(session) - .setStartTime(new DateTime(2017, 3, 1, 14, 30, 0, 0, DATE_TIME_ZONE).getMillis()) - .build(); - try (FunctionAssertions localAssertion = new FunctionAssertions(localSession)) { - localAssertion.assertFunctionString("LOCALTIME", TimeType.TIME, "13:30:00.000"); - } - } - - @Test - public void testCurrentTime() - { - Session localSession = Session.builder(session) - // we use Asia/Kathmandu here to test the difference in semantic change of current_time - // between legacy and non-legacy timestamp - .setTimeZoneKey(KATHMANDU_ZONE_KEY) - .setStartTime(new DateTime(2017, 3, 1, 15, 45, 0, 0, KATHMANDU_ZONE).getMillis()) - .build(); - try (FunctionAssertions localAssertion = new FunctionAssertions(localSession)) { - localAssertion.assertFunctionString("CURRENT_TIME", TIME_WITH_TIME_ZONE, "15:30:00.000 Asia/Kathmandu"); - } - } - @Test public void testLocalTimestamp() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java index 0fc5d8f1407c1..1e2902fa42607 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java @@ -1392,9 +1392,7 @@ public void testArrayCosineSimilarity() DOUBLE, null); - assertFunction("cosine_similarity(array [1.0E0, null], array [1.0E0, 3.0E0])", - DOUBLE, - null); + assertInvalidFunction("cosine_similarity(array [1.0E0, null], array [1.0E0, 3.0E0])", "Both arrays must not have nulls"); assertInvalidFunction("cosine_similarity(array [], array [1.0E0, 3.0E0])", "Both array arguments 
need to have identical size"); @@ -1405,6 +1403,172 @@ public void testArrayCosineSimilarity() assertFunction("cosine_similarity(array [], null)", DOUBLE, null); + + assertInvalidFunction( + "cosine_similarity(array[1.0, null, 3.0], array[1.0, 2.0, 3.0])", "Both arrays must not have nulls"); + + assertInvalidFunction( + "cosine_similarity(array[1.0, 2.0, 3.0], array[1.0, null, 3.0])", "Both arrays must not have nulls"); + } + + @Test + public void testArrayL2Squared() + { + assertFunction( + "l2_squared(array[REAL '1.0', REAL '2.0', REAL '3.0'], array[REAL '1.0', REAL '2.0', REAL '3.0'])", + REAL, 0.0f); + assertFunction( + "l2_squared(array[REAL '1.0', REAL '2.0', REAL '3.0'], array[REAL '4.0', REAL '5.0', REAL '6.0'])", + REAL, 27.0f); + assertFunction( + "l2_squared(array[REAL '-1.0', REAL '-2.0', REAL '-3.0'], array[REAL '1.0', REAL '2.0', REAL '3.0'])", + REAL, 56.0f); + assertFunction( + "l2_squared(array[REAL '0.0', REAL '0.0', REAL '0.0'], array[REAL '0.0', REAL '0.0', REAL '0.0'])", + REAL, 0.0f); + assertInvalidFunction( + "l2_squared(array[REAL '1.0', REAL '2.0'], array[REAL '1.0', REAL '2.0', REAL '3.0'])", + "Both array arguments need to have identical size"); + assertFunction( + "l2_squared(CAST(null AS array(real)), array[REAL '1.0', REAL '2.0', REAL '3.0'])", + REAL, null); + assertFunction( + "l2_squared(array[REAL '1.0', REAL '2.0', REAL '3.0'], CAST(null AS array(real)))", + REAL, null); + assertFunction( + "l2_squared(CAST(null AS array(real)), CAST(null AS array(real)))", + REAL, null); + assertInvalidFunction( + "l2_squared(array[REAL '1.0', null, REAL '3.0'], array[REAL '1.0', REAL '2.0', REAL '3.0'])", "Both arrays must not have nulls"); + assertInvalidFunction( + "l2_squared(array[REAL '1.0', REAL '2.0', REAL '3.0'], array[REAL '1.0', null, REAL '3.0'])", "Both arrays must not have nulls"); + } + + @Test + public void testArrayL2SquaredDouble() + { + assertFunction( + "l2_squared(array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], array[DOUBLE 
'1.0', DOUBLE '2.0', DOUBLE '3.0'])", + DOUBLE, 0.0d); + assertFunction( + "l2_squared(array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], array[DOUBLE '4.0', DOUBLE '5.0', DOUBLE '6.0'])", + DOUBLE, 27.0d); + assertFunction( + "l2_squared(array[DOUBLE '-1.0', DOUBLE '-2.0', DOUBLE '-3.0'], array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", + DOUBLE, 56.0d); + assertFunction( + "l2_squared(array[DOUBLE '0.0', DOUBLE '0.0', DOUBLE '0.0'], array[DOUBLE '0.0', DOUBLE '0.0', DOUBLE '0.0'])", + DOUBLE, 0.0d); + assertInvalidFunction( + "l2_squared(array[DOUBLE '1.0', DOUBLE '2.0'], array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", + "Both array arguments need to have identical size"); + assertFunction( + "l2_squared(CAST(null AS array(double)), array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", + DOUBLE, null); + assertFunction( + "l2_squared(array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], CAST(null AS array(double)))", + DOUBLE, null); + assertFunction( + "l2_squared(CAST(null AS array(double)), CAST(null AS array(double)))", + DOUBLE, null); + assertInvalidFunction( + "l2_squared(array[DOUBLE '1.0', null, DOUBLE '3.0'], array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", "Both arrays must not have nulls"); + assertInvalidFunction( + "l2_squared(array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], array[DOUBLE '1.0', null, DOUBLE '3.0'])", "Both arrays must not have nulls"); + } + + @Test + public void testArrayDotProduct() + { + // functionality test + assertFunction( + "dot_product(array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", + DOUBLE, 14.0d); + assertFunction( + "dot_product(array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], array[DOUBLE '4.0', DOUBLE '5.0', DOUBLE '6.0'])", + DOUBLE, 32.0d); + assertFunction( + "dot_product(array[DOUBLE '-1.0', DOUBLE '-2.0', DOUBLE '-3.0'], array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", + DOUBLE, -14.0d); + assertFunction( + "dot_product(array[DOUBLE '0.0', DOUBLE '0.0', DOUBLE 
'0.0'], array[DOUBLE '0.0', DOUBLE '0.0', DOUBLE '0.0'])", + DOUBLE, 0.0d); + // identical size test + assertInvalidFunction( + "dot_product(array[DOUBLE '1.0', DOUBLE '2.0'], array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", + "Both array arguments must have identical sizes"); + // null test + assertFunction( + "dot_product(CAST(null AS array(double)), array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", + DOUBLE, null); + assertFunction( + "dot_product(array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], CAST(null AS array(double)))", + DOUBLE, null); + assertFunction( + "dot_product(CAST(null AS array(double)), CAST(null AS array(double)))", + DOUBLE, null); + // any null inside the equal sized arrays must throw error + assertInvalidFunction( + "dot_product(array[DOUBLE '1.0', null, DOUBLE '3.0'], array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", + "Both array arguments must not have nulls"); + assertInvalidFunction( + "dot_product(array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], array[DOUBLE '1.0', null, DOUBLE '3.0'])", + "Both array arguments must not have nulls"); + // NaN test + assertFunction("dot_product(array[nan()], array[nan()])", + DOUBLE, Double.NaN); + // infinity test + assertFunction("dot_product(array[infinity()], array[infinity()])", + DOUBLE, Double.POSITIVE_INFINITY); + assertFunction("dot_product(array[infinity()], array[-1.0])", + DOUBLE, Double.NEGATIVE_INFINITY); + } + + @Test + public void testArrayDotProductReal() + { + // functionality test + assertFunction( + "dot_product(array[REAL '1.0', REAL '2.0', REAL '3.0'], array[REAL '1.0', REAL '2.0', REAL '3.0'])", + REAL, 14.0f); + assertFunction( + "dot_product(array[REAL '1.0', REAL '2.0', REAL '3.0'], array[REAL '4.0', REAL '5.0', REAL '6.0'])", + REAL, 32.0f); + assertFunction( + "dot_product(array[REAL '-1.0', REAL '-2.0', REAL '-3.0'], array[REAL '1.0', REAL '2.0', REAL '3.0'])", + REAL, -14.0f); + assertFunction( + "dot_product(array[REAL '0.0', REAL '0.0', REAL '0.0'], array[REAL '0.0', 
REAL '0.0', REAL '0.0'])", + REAL, 0.0f); + // identical size test + assertInvalidFunction( + "dot_product(array[REAL '1.0', REAL '2.0'], array[REAL '1.0', REAL '2.0', REAL '3.0'])", + "Both array arguments must have identical sizes"); + // null test + assertFunction( + "dot_product(CAST(null AS array(real)), array[REAL '1.0', REAL '2.0', REAL '3.0'])", + REAL, null); + assertFunction( + "dot_product(array[REAL '1.0', REAL '2.0', REAL '3.0'], CAST(null AS array(real)))", + REAL, null); + assertFunction( + "dot_product(CAST(null AS array(real)), CAST(null AS array(real)))", + REAL, null); + // any null inside the equal sized arrays must throw error + assertInvalidFunction( + "dot_product(array[REAL '1.0', null, REAL '3.0'], array[REAL '1.0', REAL '2.0', REAL '3.0'])", + "Both array arguments must not have nulls"); + assertInvalidFunction( + "dot_product(array[REAL '1.0', REAL '2.0', REAL '3.0'], array[REAL '1.0', null, REAL '3.0'])", + "Both array arguments must not have nulls"); + // NaN test + assertFunction("dot_product(array[CAST(nan() AS REAL)], array[CAST(nan() AS REAL)])", + REAL, Float.NaN); + // infinity test + assertFunction("dot_product(array[CAST(infinity() AS REAL)], array[CAST(infinity() AS REAL)])", + REAL, Float.POSITIVE_INFINITY); + assertFunction("dot_product(array[CAST(infinity() AS REAL)], array[REAL '-1.0'])", + REAL, Float.NEGATIVE_INFINITY); } @Test @@ -1645,6 +1809,31 @@ public void testPoissonCdf() assertInvalidFunction("poisson_cdf(3, -10)", "poissonCdf Function: value must be a non-negative integer"); } + @Test + public void testInverseTCdf() + { + assertFunction("inverse_t_cdf(1000, 0.5)", DOUBLE, 0.0); + assertFunction("inverse_t_cdf(1000, 0.0)", DOUBLE, Double.NEGATIVE_INFINITY); + assertFunction("inverse_t_cdf(1000, 1.0)", DOUBLE, Double.POSITIVE_INFINITY); + + assertInvalidFunction("inverse_t_cdf(0, 0.5)", "df must be greater than 0"); + assertInvalidFunction("inverse_t_cdf(-1, 0.5)", "df must be greater than 0"); + 
assertInvalidFunction("inverse_t_cdf(3, -0.1)", "p must be in the interval [0, 1]"); + assertInvalidFunction("inverse_t_cdf(3, 1.1)", "p must be in the interval [0, 1]"); + } + + @Test + public void testTCdf() + throws Exception + { + assertFunction("t_cdf(1000, 0.0)", DOUBLE, 0.5); + assertFunction("t_cdf(1000, infinity())", DOUBLE, 1.0); + assertFunction("t_cdf(1000, -infinity())", DOUBLE, 0.0); + + assertInvalidFunction("t_cdf(0, 0.5)", "df must be greater than 0"); + assertInvalidFunction("t_cdf(-1, 0.5)", "df must be greater than 0"); + } + @Test public void testInverseWeibullCdf() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestQuantileDigestFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestQuantileDigestFunctions.java index 8f8c637531f58..801068d7cf1b8 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestQuantileDigestFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestQuantileDigestFunctions.java @@ -20,11 +20,13 @@ import com.facebook.presto.operator.aggregation.FloatingPointBitsConverterUtil; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.QuantileDigestParametricType.QDIGEST; +import static com.facebook.presto.common.type.RealType.REAL; import static io.airlift.slice.Slices.wrappedBuffer; import static java.lang.String.format; import static java.util.Arrays.asList; @@ -60,6 +62,37 @@ public void testGetValueAtQuantileBelowZero() null); } + @DataProvider(name = "nullQuantileScenarios") + public Object[][] nullQuantileScenarios() + { + return new Object[][]{ + {"bigint", BIGINT, "ARRAY[0.25, NULL, 0.75]"}, + 
{"double", DOUBLE, "ARRAY[0.25, NULL, 0.75]"}, + {"real", REAL, "ARRAY[0.25, NULL, 0.75]"}, + + {"bigint", BIGINT, "ARRAY[NULL, NULL, NULL]"}, + {"double", DOUBLE, "ARRAY[NULL, NULL, NULL]"}, + {"real", REAL, "ARRAY[NULL, NULL, NULL]"}, + + {"bigint", BIGINT, "ARRAY[NULL]"}, + {"double", DOUBLE, "ARRAY[NULL]"}, + {"real", REAL, "ARRAY[NULL]"} + }; + } + @Test(dataProvider = "nullQuantileScenarios", + expectedExceptions = PrestoException.class, + expectedExceptionsMessageRegExp = "All quantiles should be non-null.") + public void testValuesAtQuantilesWithNullsThrowsError(String typeName, Type type, String arrayExpression) + { + QuantileDigest qdigest = new QuantileDigest(1); + + functionAssertions.assertFunction( + format("values_at_quantiles(CAST(X'%s' AS qdigest(%s)), %s)", + toHexString(qdigest), typeName, arrayExpression), + type, + null); + } + @Test public void testValueAtQuantileBigint() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java index a12ef98cec3fb..e92c24af09b52 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestScalarValidation.java @@ -23,10 +23,9 @@ import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; +import jakarta.annotation.Nullable; import org.testng.annotations.Test; -import javax.annotation.Nullable; - @SuppressWarnings("UtilityClassWithoutPrivateConstructor") public class TestScalarValidation { diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java index 9197fae5aa811..0023402dd9e91 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestStringFunctions.java @@ -18,10 +18,10 @@ import com.facebook.presto.common.type.SqlVarbinary; import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.LiteralParameter; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; import com.facebook.presto.spi.function.SqlType; -import com.facebook.presto.type.LiteralParameter; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java index 30b7914933e11..bc50a40fcb368 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestVarbinaryFunctions.java @@ -375,6 +375,8 @@ public void testXxhash64() { assertFunction("xxhash64(CAST('' AS VARBINARY))", VARBINARY, sqlVarbinaryHex("EF46DB3751D8E999")); assertFunction("xxhash64(CAST('hashme' AS VARBINARY))", VARBINARY, sqlVarbinaryHex("F9D96E0E1165E892")); + assertFunction("xxhash64(CAST('' AS VARBINARY), 0)", VARBINARY, sqlVarbinaryHex("EF46DB3751D8E999")); + assertFunction("xxhash64(CAST('hashme' AS VARBINARY), 0)", VARBINARY, sqlVarbinaryHex("F9D96E0E1165E892")); } @Test diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java index 2e0e865f13be4..9456372e9632e 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/sql/TestArraySqlFunctions.java @@ -401,4 +401,17 @@ public void testArrayTopNEdgeAndErrorCase() assertFunction("ARRAY_TOP_N(ARRAY [], 3)", new ArrayType(UNKNOWN), emptyList()); assertFunction("ARRAY_TOP_N(ARRAY [1, 4], 3)", new ArrayType(INTEGER), ImmutableList.of(4, 1)); } + + @Test + public void testArrayTranspose() + { + assertFunction("ARRAY_TRANSPOSE(ARRAY[ARRAY[1, 2, 3], ARRAY[4, 5, 6]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1, 4), ImmutableList.of(2, 5), ImmutableList.of(3, 6))); + assertFunction("ARRAY_TRANSPOSE(ARRAY[ARRAY[1]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1))); + assertFunction("ARRAY_TRANSPOSE(ARRAY[])", new ArrayType(new ArrayType(UNKNOWN)), emptyList()); + assertFunction("ARRAY_TRANSPOSE(ARRAY[ARRAY[VARCHAR 'a', VARCHAR 'b'], ARRAY[VARCHAR 'c', VARCHAR 'd'], ARRAY[VARCHAR 'e', VARCHAR 'f']])", new ArrayType(new ArrayType(VARCHAR)), ImmutableList.of(ImmutableList.of("a", "c", "e"), ImmutableList.of("b", "d", "f"))); + assertFunction("array_transpose(array[array[1, null, 3], array[4, 5, null]])", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1, 4), asList(null, 5), asList(3, null))); + assertFunction("array_transpose(array_transpose(array[array[1, 2, 3], array[4, 5, 6]]))", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(ImmutableList.of(1, 2, 3), ImmutableList.of(4, 5, 6))); + + assertInvalidFunction("array_transpose(array[array[1, 2, 3], array[4, 5]])", StandardErrorCode.GENERIC_USER_ERROR, "All rows must have the same length for matrix transpose"); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapSqlFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapSqlFunctions.java new file 
mode 100644 index 0000000000000..73996c37ed690 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/sql/TestMapSqlFunctions.java @@ -0,0 +1,105 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.operator.scalar.sql; + +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.facebook.presto.spi.StandardErrorCode; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.util.StructuralTestUtil.mapType; +import static java.util.Arrays.asList; + +public class TestMapSqlFunctions + extends AbstractTestFunctions +{ + @Test + public void testMapIntKeysToArray() + { + assertFunction("MAP_INT_KEYS_TO_ARRAY(CAST(MAP() AS MAP))", + new ArrayType(INTEGER), + null); + + assertFunction("MAP_INT_KEYS_TO_ARRAY(MAP(ARRAY[1, 3], ARRAY['a', 'b']))", + new ArrayType(createVarcharType(1)), + asList("a", null, "b")); + + assertFunction("MAP_INT_KEYS_TO_ARRAY(MAP(ARRAY[3, 5, 6, 9], ARRAY['a', 'b', 'c', 'd']))", + new ArrayType(createVarcharType(1)), + asList(null, null, "a", null, "b", "c", null, null, "d")); + + assertFunction("MAP_INT_KEYS_TO_ARRAY(MAP(ARRAY[3, 5, 6, 9], ARRAY['a', null, 'c', 'd']))", + new 
ArrayType(createVarcharType(1)), + asList(null, null, "a", null, null, "c", null, null, "d")); + + assertFunction("MAP_INT_KEYS_TO_ARRAY(CAST(MAP() AS MAP))", + new ArrayType(INTEGER), + null); + + assertFunction("MAP_INT_KEYS_TO_ARRAY(MAP(ARRAY[2], ARRAY[MAP(ARRAY[3, 5, 6, 9], ARRAY['a', null, 'c', 'd'])]))", + new ArrayType(mapType(INTEGER, createVarcharType(1))), + asList(null, asMap(asList(3, 5, 6, 9), asList("a", null, "c", "d")))); + + assertFunction("MAP_INT_KEYS_TO_ARRAY(CAST(MAP() AS MAP))", + new ArrayType(INTEGER), + null); + + assertInvalidFunction( + "MAP_INT_KEYS_TO_ARRAY(MAP(CAST(SEQUENCE(1,10000)||ARRAY[10001] AS ARRAY),SEQUENCE(1,10000)||ARRAY[10001]))", + StandardErrorCode.GENERIC_USER_ERROR, + "Max key value must be <= 10k for map_int_keys_to_array function"); + + assertInvalidFunction( + "MAP_INT_KEYS_TO_ARRAY(MAP(ARRAY[0],ARRAY[1]))", + StandardErrorCode.GENERIC_USER_ERROR, + "Only positive keys allowed in map_int_keys_to_array function, but got: 0"); + + assertInvalidFunction( + "MAP_INT_KEYS_TO_ARRAY(MAP(ARRAY[-1, 2], ARRAY['a', 'b']))", + StandardErrorCode.GENERIC_USER_ERROR, + "Only positive keys allowed in map_int_keys_to_array function, but got: -1"); + + assertInvalidFunction("MAP_INT_KEYS_TO_ARRAY(MAP(ARRAY[0, 2], ARRAY['a', 'b']))", + StandardErrorCode.GENERIC_USER_ERROR, + "Only positive keys allowed in map_int_keys_to_array function, but got: 0"); + } + + @Test + public void testArrayToMapIntKeys() + { + assertFunction("ARRAY_TO_MAP_INT_KEYS(CAST(ARRAY[3, 5, 6, 9] AS ARRAY))", + mapType(INTEGER, INTEGER), + ImmutableMap.of(1, 3, 2, 5, 3, 6, 4, 9)); + + assertFunction("ARRAY_TO_MAP_INT_KEYS(CAST(ARRAY[3, 5, null, 6, 9] AS ARRAY))", + mapType(INTEGER, INTEGER), + asMap(asList(1, 2, 4, 5), asList(3, 5, 6, 9))); + + assertFunction("ARRAY_TO_MAP_INT_KEYS(CAST(ARRAY[3, 5, null, 6, 9, null, null, 1] AS ARRAY))", + mapType(INTEGER, INTEGER), + asMap(asList(1, 2, 4, 5, 8), asList(3, 5, 6, 9, 1))); + + 
assertFunction("ARRAY_TO_MAP_INT_KEYS(CAST(NULL AS ARRAY))", + mapType(INTEGER, INTEGER), + null); + + assertInvalidFunction( + "ARRAY_TO_MAP_INT_KEYS(SEQUENCE(1,10000)||ARRAY[10001])", + StandardErrorCode.GENERIC_USER_ERROR, + "Max number of elements must be <= 10k for array_to_map_int_keys function"); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStateProvider.java b/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStateProvider.java index 29801427f6982..4afefed16b9a2 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStateProvider.java +++ b/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStateProvider.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.resourcemanager; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.NodeVersion; import com.facebook.presto.execution.QueryState; import com.facebook.presto.execution.resourceGroups.ResourceGroupRuntimeInfo; @@ -32,8 +34,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.joda.time.DateTime; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -45,6 +45,8 @@ import java.util.Optional; import java.util.OptionalDouble; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.succinctBytes; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.execution.QueryState.DISPATCHING; import static com.facebook.presto.execution.QueryState.FAILED; @@ -60,8 +62,6 @@ import static com.facebook.presto.memory.LocalMemoryManager.RESERVED_POOL; 
import static com.facebook.presto.metadata.SessionPropertyManager.createTestingSessionPropertyManager; import static com.facebook.presto.operator.BlockedReason.WAITING_FOR_MEMORY; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; import static java.util.concurrent.TimeUnit.SECONDS; @@ -760,6 +760,14 @@ private static BasicQueryInfo createQueryInfo(String queryId, QueryState state, 14, 15, 100, + 13, + 14, + 15, + 100, + 13, + 14, + 15, + 100, DataSize.valueOf("21GB"), 22, 23, diff --git a/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerConfig.java b/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerConfig.java index 89dd90527e383..48c47a24731b6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerConfig.java @@ -14,8 +14,8 @@ package com.facebook.presto.resourcemanager; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-main-base/src/test/java/com/facebook/presto/security/TestAccessControlManager.java b/presto-main-base/src/test/java/com/facebook/presto/security/TestAccessControlManager.java index 91b6018814476..480c7fb2f44f0 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/security/TestAccessControlManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/security/TestAccessControlManager.java @@ -173,6 +173,7 @@ public void testCheckQueryIntegrity() accessControlManager.addSystemAccessControlFactory(accessControlFactory); 
accessControlManager.setSystemAccessControl("test", ImmutableMap.of()); String testQuery = "test_query"; + Map preparedStatements = ImmutableMap.of(); Map viewDefinitions = ImmutableMap.of(); Map materializedViewDefinitions = ImmutableMap.of(); @@ -187,6 +188,7 @@ public void testCheckQueryIntegrity() Optional.empty()), context, testQuery, + preparedStatements, viewDefinitions, materializedViewDefinitions); assertEquals(accessControlFactory.getCheckedUserName(), USER_NAME); @@ -206,6 +208,7 @@ public void testCheckQueryIntegrity() Optional.empty()), context, testQuery, + preparedStatements, viewDefinitions, materializedViewDefinitions)); } @@ -306,7 +309,7 @@ public void checkCanSetUser(Identity identity, AccessControlContext context, Opt } @Override - public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map viewDefinitions, Map materializedViewDefinitionMap) + public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitionMap) { } @@ -420,7 +423,7 @@ public void checkCanSetUser(Identity identity, AccessControlContext context, Opt } @Override - public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map viewDefinitions, Map materializedViewDefinitions) + public void checkQueryIntegrity(Identity identity, AccessControlContext context, String query, Map preparedStatements, Map viewDefinitions, Map materializedViewDefinitions) { if (!query.equals(identity.getExtraCredentials().get(QUERY_TOKEN_FIELD))) { denyQueryIntegrityCheck(); @@ -487,6 +490,12 @@ public void checkCanRenameSchema(ConnectorTransactionHandle transactionHandle, C throw new UnsupportedOperationException(); } + @Override + public void checkCanShowCreateTable(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + throw new 
UnsupportedOperationException(); + } + @Override public void checkCanCreateTable(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { @@ -565,6 +574,12 @@ public void checkCanCreateViewWithSelectFromColumns(ConnectorTransactionHandle t throw new UnsupportedOperationException(); } + @Override + public void checkCanCallProcedure(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName procedureName) + { + throw new UnsupportedOperationException(); + } + @Override public void checkCanSetCatalogSessionProperty(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, String propertyName) { @@ -583,6 +598,18 @@ public void checkCanRevokeTablePrivilege(ConnectorTransactionHandle transactionH throw new UnsupportedOperationException(); } + @Override + public void checkCanDropBranch(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + throw new UnsupportedOperationException(); + } + + @Override + public void checkCanDropTag(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) + { + throw new UnsupportedOperationException(); + } + @Override public void checkCanDropConstraint(ConnectorTransactionHandle transactionHandle, ConnectorIdentity identity, AccessControlContext context, SchemaTableName tableName) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/security/TestStatsRecordingSystemAccessControl.java b/presto-main-base/src/test/java/com/facebook/presto/security/TestStatsRecordingSystemAccessControl.java new file mode 100644 index 0000000000000..70b1109c71b3c --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/security/TestStatsRecordingSystemAccessControl.java @@ -0,0 +1,51 @@ +/* + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.security; + +import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.security.AccessControlContext; +import com.facebook.presto.spi.security.SystemAccessControl; +import org.testng.annotations.Test; + +import java.util.Collections; +import java.util.Optional; + +import static com.facebook.presto.spi.testing.InterfaceTestUtils.assertAllMethodsOverridden; +import static org.testng.Assert.assertEquals; + +public class TestStatsRecordingSystemAccessControl +{ + public static final AccessControlContext CONTEXT = new AccessControlContext(new QueryId("query_id"), Optional.empty(), Collections.emptySet(), Optional.empty(), WarningCollector.NOOP, new RuntimeStats(), Optional.empty(), Optional.empty(), Optional.empty()); + + @Test + public void testEverythingDelegated() + { + assertAllMethodsOverridden(SystemAccessControl.class, StatsRecordingSystemAccessControl.class); + } + + @Test + public void testStatsRecording() + { + SystemAccessControl delegate = new AllowAllSystemAccessControl(); + StatsRecordingSystemAccessControl statsRecordingAccessControl = new StatsRecordingSystemAccessControl(delegate); + + assertEquals(statsRecordingAccessControl.getStats().getCheckCanAccessCatalog().getTime().getAllTime().getCount(), 0.0); + + statsRecordingAccessControl.checkCanAccessCatalog(null, CONTEXT, 
"test-catalog"); + + assertEquals(statsRecordingAccessControl.getStats().getCheckCanAccessCatalog().getTime().getAllTime().getCount(), 1.0); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/TestBasicQueryInfo.java b/presto-main-base/src/test/java/com/facebook/presto/server/TestBasicQueryInfo.java index b30ce4c921f04..288812f173894 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/TestBasicQueryInfo.java +++ b/presto-main-base/src/test/java/com/facebook/presto/server/TestBasicQueryInfo.java @@ -13,6 +13,8 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.cost.StatsAndCosts; @@ -21,13 +23,12 @@ import com.facebook.presto.operator.BlockedReason; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.StandardErrorCode; +import com.facebook.presto.spi.analyzer.UpdateInfo; import com.facebook.presto.spi.eventlistener.StageGcStatistics; import com.facebook.presto.spi.memory.MemoryPoolId; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.joda.time.DateTime; import org.testng.annotations.Test; @@ -81,6 +82,14 @@ public void testConstructor() 18, 34, 19, + 16, + 17, + 18, + 19, + 16, + 17, + 18, + 19, 20.0, 43.0, DataSize.valueOf("21GB"), @@ -130,7 +139,7 @@ public void testConstructor() ImmutableSet.of(), Optional.empty(), false, - "33", + new UpdateInfo("UPDATE TYPE", ""), Optional.empty(), null, StandardErrorCode.ABANDONED_QUERY.toErrorCode(), diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/TestInternalCommunicationConfig.java 
b/presto-main-base/src/test/java/com/facebook/presto/server/TestInternalCommunicationConfig.java index b171d5fdda8c3..de88e6b3fc877 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/TestInternalCommunicationConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/server/TestInternalCommunicationConfig.java @@ -13,10 +13,10 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.DataSize; import com.facebook.drift.transport.netty.codec.Protocol; import com.facebook.presto.server.InternalCommunicationConfig.CommunicationProtocol; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Map; @@ -24,7 +24,7 @@ import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class TestInternalCommunicationConfig { @@ -52,7 +52,9 @@ public void testDefaults() .setSharedSecret(null) .setTaskUpdateRequestThriftSerdeEnabled(false) .setTaskInfoResponseThriftSerdeEnabled(false) - .setInternalJwtEnabled(false)); + .setInternalJwtEnabled(false) + .setNodeStatsRefreshIntervalMillis(1_000) + .setNodeDiscoveryPollingIntervalMillis(5_000)); } @Test @@ -80,6 +82,8 @@ public void testExplicitPropertyMappings() .put("internal-communication.jwt.enabled", "true") .put("experimental.internal-communication.task-update-request-thrift-serde-enabled", "true") .put("experimental.internal-communication.task-info-response-thrift-serde-enabled", "true") + .put("internal-communication.node-stats-refresh-interval-millis", "2000") + .put("internal-communication.node-discovery-polling-interval-millis", "3000") .build(); 
InternalCommunicationConfig expected = new InternalCommunicationConfig() @@ -103,7 +107,9 @@ public void testExplicitPropertyMappings() .setSharedSecret("secret") .setInternalJwtEnabled(true) .setTaskUpdateRequestThriftSerdeEnabled(true) - .setTaskInfoResponseThriftSerdeEnabled(true); + .setTaskInfoResponseThriftSerdeEnabled(true) + .setNodeStatsRefreshIntervalMillis(2000) + .setNodeDiscoveryPollingIntervalMillis(3000); assertFullMapping(properties, expected); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/TestQueryProgressStats.java b/presto-main-base/src/test/java/com/facebook/presto/server/TestQueryProgressStats.java index b0285e09994e3..9975393bdadcd 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/TestQueryProgressStats.java +++ b/presto-main-base/src/test/java/com/facebook/presto/server/TestQueryProgressStats.java @@ -48,6 +48,12 @@ public void testJson() OptionalDouble.of(33.33), 1200, 1100, + 1000, + 1200, + 1100, + 1000, + 1200, + 1100, 1000); JsonCodec codec = JsonCodec.jsonCodec(QueryProgressStats.class); diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java b/presto-main-base/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java index 8bdc2b62aec0e..f14a66b5a2629 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java +++ b/presto-main-base/src/test/java/com/facebook/presto/server/TestQueryStateInfo.java @@ -13,23 +13,29 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.cost.StatsAndCosts; +import com.facebook.presto.execution.ClusterOverloadConfig; import com.facebook.presto.execution.QueryInfo; import com.facebook.presto.execution.QueryState; import com.facebook.presto.execution.QueryStats; import 
com.facebook.presto.execution.resourceGroups.InternalResourceGroup; +import com.facebook.presto.execution.resourceGroups.QueryPacingContext; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterOverloadPolicy; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterResourceChecker; import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.spi.PrestoWarning; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.WarningCode; +import com.facebook.presto.spi.analyzer.UpdateInfo; import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.joda.time.DateTime; import org.testng.annotations.Test; @@ -38,6 +44,7 @@ import java.util.Optional; import java.util.OptionalInt; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.execution.QueryState.FINISHED; import static com.facebook.presto.execution.QueryState.QUEUED; @@ -47,7 +54,6 @@ import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_GLOBAL_MEMORY_LIMIT; import static com.facebook.presto.spi.resourceGroups.SchedulingPolicy.WEIGHTED; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -57,7 +63,7 @@ public class TestQueryStateInfo @Test public void testQueryStateInfo() { - InternalResourceGroup.RootInternalResourceGroup root = new 
InternalResourceGroup.RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, new InMemoryNodeManager()); + InternalResourceGroup.RootInternalResourceGroup root = new InternalResourceGroup.RootInternalResourceGroup("root", (group, export) -> {}, directExecutor(), ignored -> Optional.empty(), rg -> false, new InMemoryNodeManager(), createClusterResourceChecker(), QueryPacingContext.NOOP); root.setSoftMemoryLimit(new DataSize(1, MEGABYTE)); root.setMaxQueuedQueries(40); root.setHardConcurrencyLimit(0); @@ -241,6 +247,14 @@ private QueryInfo createQueryInfo(String queryId, ResourceGroupId resourceGroupI 18, 34, 19, + 100, + 17, + 18, + 19, + 100, + 17, + 18, + 19, 20.0, 43.0, DataSize.valueOf("21GB"), @@ -282,7 +296,7 @@ private QueryInfo createQueryInfo(String queryId, ResourceGroupId resourceGroupI ImmutableSet.of(), Optional.empty(), false, - "33", + new UpdateInfo("UPDATE TYPE", ""), Optional.empty(), null, EXCEEDED_GLOBAL_MEMORY_LIMIT.toErrorCode(), @@ -309,4 +323,29 @@ private QueryInfo createQueryInfo(String queryId, ResourceGroupId resourceGroupI ImmutableMap.of(), Optional.empty()); } + + private ClusterResourceChecker createClusterResourceChecker() + { + // Create a mock cluster overload policy that never reports overload + ClusterOverloadPolicy mockPolicy = new ClusterOverloadPolicy() + { + @Override + public boolean isClusterOverloaded(InternalNodeManager nodeManager) + { + return false; // Never overloaded for tests + } + + @Override + public String getName() + { + return "test-policy"; + } + }; + + // Create a config with throttling disabled for tests + ClusterOverloadConfig config = new ClusterOverloadConfig() + .setClusterOverloadThrottlingEnabled(false); + + return new ClusterResourceChecker(mockPolicy, config, new InMemoryNodeManager()); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/TestRetryConfig.java 
b/presto-main-base/src/test/java/com/facebook/presto/server/TestRetryConfig.java new file mode 100644 index 0000000000000..cec8cb1ac853e --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/server/TestRetryConfig.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; + +public class TestRetryConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(ConfigAssertions.recordDefaults(RetryConfig.class) + .setRetryEnabled(true) + .setRequireHttps(false) + .setAllowedRetryDomains(null) + .setCrossClusterRetryErrorCodes("REMOTE_TASK_ERROR")); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("retry.enabled", "false") + .put("retry.allowed-domains", "*.foo.bar,*.baz.qux") + .put("retry.require-https", "true") + .put("retry.cross-cluster-error-codes", "QUERY_QUEUE_FULL") + .build(); + + RetryConfig expected = new RetryConfig() + .setRetryEnabled(false) + .setRequireHttps(true) + .setAllowedRetryDomains("*.foo.bar,*.baz.qux") + 
.setCrossClusterRetryErrorCodes("QUERY_QUEUE_FULL"); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/TestServerConfig.java b/presto-main-base/src/test/java/com/facebook/presto/server/TestServerConfig.java index 56f94a1ec216b..b570f157a36f4 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/TestServerConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/server/TestServerConfig.java @@ -14,8 +14,8 @@ package com.facebook.presto.server; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; @@ -49,7 +49,8 @@ public void testDefaults() .setPoolType(DEFAULT) .setClusterStatsExpirationDuration(new Duration(0, MILLISECONDS)) .setNestedDataSerializationEnabled(true) - .setClusterResourceGroupStateInfoExpirationDuration(new Duration(0, MILLISECONDS))); + .setClusterResourceGroupStateInfoExpirationDuration(new Duration(0, MILLISECONDS)) + .setClusterTag(null)); } @Test @@ -72,6 +73,7 @@ public void testExplicitPropertyMappings() .put("cluster-stats-expiration-duration", "10s") .put("nested-data-serialization-enabled", "false") .put("cluster-resource-group-state-info-expiration-duration", "10s") + .put("cluster-tag", "test-cluster") .build(); ServerConfig expected = new ServerConfig() @@ -90,7 +92,8 @@ public void testExplicitPropertyMappings() .setPoolType(LEAF) .setClusterStatsExpirationDuration(new Duration(10, SECONDS)) .setNestedDataSerializationEnabled(false) - .setClusterResourceGroupStateInfoExpirationDuration(new Duration(10, SECONDS)); + .setClusterResourceGroupStateInfoExpirationDuration(new Duration(10, SECONDS)) + .setClusterTag("test-cluster"); assertFullMapping(properties, expected); } diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/server/TestThriftTaskIntegration.java b/presto-main-base/src/test/java/com/facebook/presto/server/TestThriftTaskIntegration.java index cea5e4cea70f6..64009fd931cc6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/TestThriftTaskIntegration.java +++ b/presto-main-base/src/test/java/com/facebook/presto/server/TestThriftTaskIntegration.java @@ -40,7 +40,6 @@ import com.facebook.presto.execution.buffer.ThriftBufferResult; import com.facebook.presto.execution.scheduler.TableWriteInfo; import com.facebook.presto.memory.MemoryPoolAssignmentsRequest; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.server.thrift.ThriftTaskClient; import com.facebook.presto.server.thrift.ThriftTaskService; import com.facebook.presto.sql.planner.PlanFragment; @@ -53,12 +52,11 @@ import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; +import jakarta.inject.Singleton; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import javax.inject.Singleton; - import java.util.List; import java.util.Optional; import java.util.concurrent.ExecutorService; @@ -69,6 +67,7 @@ import static com.facebook.drift.server.guice.DriftServerBinder.driftServerBinder; import static com.facebook.drift.transport.netty.client.DriftNettyMethodInvokerFactory.createStaticDriftNettyMethodInvokerFactory; import static com.facebook.presto.execution.buffer.BufferResult.emptyResults; +import static io.netty.buffer.ByteBufAllocator.DEFAULT; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; @@ -110,7 +109,7 @@ public void testServer() AddressSelector addressSelector = new SimpleAddressSelector( ImmutableSet.of(HostAndPort.fromParts("localhost", thriftServerPort)), true); - try 
(DriftNettyMethodInvokerFactory invokerFactory = createStaticDriftNettyMethodInvokerFactory(new DriftNettyClientConfig())) { + try (DriftNettyMethodInvokerFactory invokerFactory = createStaticDriftNettyMethodInvokerFactory(new DriftNettyClientConfig(), DEFAULT)) { DriftClientFactory clientFactory = new DriftClientFactory(new ThriftCodecManager(), invokerFactory, addressSelector, NORMAL_RESULT); ThriftTaskClient client = clientFactory.createDriftClient(ThriftTaskClient.class).get(); @@ -280,12 +279,6 @@ public void removeRemoteSource(TaskId taskId, TaskId remoteSourceTaskId) { throw new UnsupportedOperationException(); } - - @Override - public void updateMetadataResults(TaskId taskId, MetadataUpdates metadataUpdates) - { - throw new UnsupportedOperationException(); - } }; } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/remotetask/TestBackoff.java b/presto-main-base/src/test/java/com/facebook/presto/server/remotetask/TestBackoff.java index 886bd7a6433c4..589981cada9d6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/remotetask/TestBackoff.java +++ b/presto-main-base/src/test/java/com/facebook/presto/server/remotetask/TestBackoff.java @@ -14,8 +14,8 @@ package com.facebook.presto.server.remotetask; import com.facebook.airlift.testing.TestingTicker; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableList; -import io.airlift.units.Duration; import org.testng.annotations.Test; import static java.util.concurrent.TimeUnit.MICROSECONDS; diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/security/TestSecurityConfig.java b/presto-main-base/src/test/java/com/facebook/presto/server/security/TestSecurityConfig.java index eb27efd54d52b..9211233bd8799 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/security/TestSecurityConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/server/security/TestSecurityConfig.java @@ -31,7 +31,8 @@ public void 
testDefaults() ConfigAssertions.assertRecordedDefaults(ConfigAssertions.recordDefaults(SecurityConfig.class) .setAuthenticationTypes("") .setAllowForwardedHttps(false) - .setAuthorizedIdentitySelectionEnabled(false)); + .setAuthorizedIdentitySelectionEnabled(false) + .setEnableSqlQueryTextContextField(false)); } @Test @@ -41,12 +42,14 @@ public void testExplicitPropertyMappings() .put("http-server.authentication.type", "KERBEROS,PASSWORD") .put("http-server.authentication.allow-forwarded-https", "true") .put("permissions.authorized-identity-selection-enabled", "true") + .put("permissions.enable-sql-query-text-context-field", "true") .build(); SecurityConfig expected = new SecurityConfig() .setAuthenticationTypes(ImmutableList.of(KERBEROS, PASSWORD)) .setAllowForwardedHttps(true) - .setAuthorizedIdentitySelectionEnabled(true); + .setAuthorizedIdentitySelectionEnabled(true) + .setEnableSqlQueryTextContextField(true); ConfigAssertions.assertFullMapping(properties, expected); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/spiller/TestNodeSpillConfig.java b/presto-main-base/src/test/java/com/facebook/presto/spiller/TestNodeSpillConfig.java index 54c3749d6cc0e..c4ee47d96b020 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/spiller/TestNodeSpillConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/spiller/TestNodeSpillConfig.java @@ -14,18 +14,18 @@ package com.facebook.presto.spiller; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.CompressionCodec; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Map; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; -import static io.airlift.units.DataSize.Unit.GIGABYTE; 
-import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; public class TestNodeSpillConfig { diff --git a/presto-main-base/src/test/java/com/facebook/presto/spiller/TestSpillSpaceTracker.java b/presto-main-base/src/test/java/com/facebook/presto/spiller/TestSpillSpaceTracker.java index e7f10941c75a2..b058b44295ffb 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/spiller/TestSpillSpaceTracker.java +++ b/presto-main-base/src/test/java/com/facebook/presto/spiller/TestSpillSpaceTracker.java @@ -13,12 +13,12 @@ */ package com.facebook.presto.spiller; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.ExceededSpillLimitException; -import io.airlift.units.DataSize; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static org.testng.Assert.assertEquals; @Test(singleThreaded = true) diff --git a/presto-main-base/src/test/java/com/facebook/presto/split/MockSplitSource.java b/presto-main-base/src/test/java/com/facebook/presto/split/MockSplitSource.java index 99a20ca970314..934257c7eedd7 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/split/MockSplitSource.java +++ b/presto-main-base/src/test/java/com/facebook/presto/split/MockSplitSource.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.split; +import com.facebook.airlift.concurrent.NotThreadSafe; import com.facebook.presto.execution.Lifespan; import com.facebook.presto.metadata.Split; import com.facebook.presto.spi.ConnectorId; @@ -27,8 +28,6 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import 
javax.annotation.concurrent.NotThreadSafe; - import java.util.Collections; import java.util.List; diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java index f07caf36cf5e8..221685a9cf85d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java @@ -13,44 +13,19 @@ */ package com.facebook.presto.sql; -import com.facebook.presto.common.CatalogSchemaName; -import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.Block; -import com.facebook.presto.common.block.BlockEncodingManager; -import com.facebook.presto.common.block.BlockEncodingSerde; -import com.facebook.presto.common.block.BlockSerdeUtil; -import com.facebook.presto.common.type.ArrayType; -import com.facebook.presto.common.type.Decimals; -import com.facebook.presto.common.type.SqlTimestampWithTimeZone; -import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.VarbinaryType; import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.operator.scalar.FunctionAssertions; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.WarningCollector; -import com.facebook.presto.spi.function.AggregationFunctionMetadata; -import com.facebook.presto.spi.function.FunctionKind; -import com.facebook.presto.spi.function.Parameter; -import com.facebook.presto.spi.function.RoutineCharacteristics; -import com.facebook.presto.spi.function.SqlInvokedFunction; -import 
com.facebook.presto.spi.relation.CallExpression; -import com.facebook.presto.spi.relation.ConstantExpression; -import com.facebook.presto.spi.relation.InputReferenceExpression; -import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.parser.ParsingOptions; -import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.expressions.AbstractTestExpressionInterpreter; import com.facebook.presto.sql.planner.ExpressionInterpreter; import com.facebook.presto.sql.planner.RowExpressionInterpreter; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TypeProvider; -import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; @@ -63,122 +38,32 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; import org.intellij.lang.annotations.Language; -import org.joda.time.DateTime; -import org.joda.time.DateTimeZone; -import org.joda.time.LocalDate; -import org.joda.time.LocalTime; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import java.math.BigInteger; import java.util.Map; import java.util.Optional; -import java.util.concurrent.TimeUnit; import java.util.stream.IntStream; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; -import static com.facebook.presto.common.type.BigintType.BIGINT; -import static 
com.facebook.presto.common.type.BooleanType.BOOLEAN; -import static com.facebook.presto.common.type.DateType.DATE; -import static com.facebook.presto.common.type.DecimalType.createDecimalType; -import static com.facebook.presto.common.type.DoubleType.DOUBLE; -import static com.facebook.presto.common.type.IntegerType.INTEGER; -import static com.facebook.presto.common.type.TimeType.TIME; -import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey; -import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; -import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; -import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; -import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; -import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; -import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; -import static com.facebook.presto.sql.ExpressionFormatter.formatExpression; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionInterpreter; import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionOptimizer; import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter; -import static 
com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; -import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; -import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; -import static io.airlift.slice.Slices.utf8Slice; import static java.lang.String.format; import static java.util.Collections.emptyMap; -import static java.util.Locale.ENGLISH; import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertThrows; -import static org.testng.Assert.assertTrue; -import static org.testng.Assert.fail; public class TestExpressionInterpreter + extends AbstractTestExpressionInterpreter { - public static final SqlInvokedFunction SQUARE_UDF_CPP = new SqlInvokedFunction( - QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "square"), - ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), - parseTypeSignature(StandardTypes.BIGINT), - "Integer square", - RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), - "", - notVersioned()); - - public static final SqlInvokedFunction AVG_UDAF_CPP = new SqlInvokedFunction( - QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "avg"), - ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.DOUBLE))), - parseTypeSignature(StandardTypes.DOUBLE), - "Returns mean of doubles", - RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), - "", - notVersioned(), - FunctionKind.AGGREGATE, - Optional.of(new AggregationFunctionMetadata(parseTypeSignature("ROW(double, int)"), false))); - - private static final int TEST_VARCHAR_TYPE_LENGTH = 17; - private static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder() - .put("bound_integer", INTEGER) - .put("bound_long", BIGINT) - .put("bound_string", createVarcharType(TEST_VARCHAR_TYPE_LENGTH)) - .put("bound_varbinary", VarbinaryType.VARBINARY) - .put("bound_double", 
DOUBLE) - .put("bound_boolean", BOOLEAN) - .put("bound_date", DATE) - .put("bound_time", TIME) - .put("bound_timestamp", TIMESTAMP) - .put("bound_pattern", VARCHAR) - .put("bound_null_string", VARCHAR) - .put("bound_decimal_short", createDecimalType(5, 2)) - .put("bound_decimal_long", createDecimalType(23, 3)) - .put("time", BIGINT) // for testing reserved identifiers - .put("unbound_integer", INTEGER) - .put("unbound_long", BIGINT) - .put("unbound_long2", BIGINT) - .put("unbound_long3", BIGINT) - .put("unbound_string", VARCHAR) - .put("unbound_double", DOUBLE) - .put("unbound_boolean", BOOLEAN) - .put("unbound_date", DATE) - .put("unbound_time", TIME) - .put("unbound_array", new ArrayType(BIGINT)) - .put("unbound_timestamp", TIMESTAMP) - .put("unbound_interval", INTERVAL_DAY_TIME) - .put("unbound_pattern", VARCHAR) - .put("unbound_null_string", VARCHAR) - .build()); - - private static final SqlParser SQL_PARSER = new SqlParser(); - private static final Metadata METADATA = MetadataManager.createTestMetadataManager(); - private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); - private static final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(); - @BeforeClass public void setup() { @@ -186,247 +71,6 @@ public void setup() setupJsonFunctionNamespaceManager(METADATA.getFunctionAndTypeManager()); } - @Test - public void testAnd() - { - assertOptimizedEquals("true and false", "false"); - assertOptimizedEquals("false and true", "false"); - assertOptimizedEquals("false and false", "false"); - - assertOptimizedEquals("true and null", "null"); - assertOptimizedEquals("false and null", "false"); - assertOptimizedEquals("null and true", "null"); - assertOptimizedEquals("null and false", "false"); - assertOptimizedEquals("null and null", "null"); - - assertOptimizedEquals("unbound_string='z' and true", "unbound_string='z'"); - assertOptimizedEquals("unbound_string='z' and false", "false"); - 
assertOptimizedEquals("true and unbound_string='z'", "unbound_string='z'"); - assertOptimizedEquals("false and unbound_string='z'", "false"); - - assertOptimizedEquals("bound_string='z' and bound_long=1+1", "bound_string='z' and bound_long=2"); - assertOptimizedEquals("random() > 0 and random() > 0", "random() > 0 and random() > 0"); - } - - @Test - public void testOr() - { - assertOptimizedEquals("true or true", "true"); - assertOptimizedEquals("true or false", "true"); - assertOptimizedEquals("false or true", "true"); - assertOptimizedEquals("false or false", "false"); - - assertOptimizedEquals("true or null", "true"); - assertOptimizedEquals("null or true", "true"); - assertOptimizedEquals("null or null", "null"); - - assertOptimizedEquals("false or null", "null"); - assertOptimizedEquals("null or false", "null"); - - assertOptimizedEquals("bound_string='z' or true", "true"); - assertOptimizedEquals("bound_string='z' or false", "bound_string='z'"); - assertOptimizedEquals("true or bound_string='z'", "true"); - assertOptimizedEquals("false or bound_string='z'", "bound_string='z'"); - - assertOptimizedEquals("bound_string='z' or bound_long=1+1", "bound_string='z' or bound_long=2"); - assertOptimizedEquals("random() > 0 or random() > 0", "random() > 0 or random() > 0"); - } - - @Test - public void testComparison() - { - assertOptimizedEquals("null = null", "null"); - - assertOptimizedEquals("'a' = 'b'", "false"); - assertOptimizedEquals("'a' = 'a'", "true"); - assertOptimizedEquals("'a' = null", "null"); - assertOptimizedEquals("null = 'a'", "null"); - assertOptimizedEquals("bound_integer = 1234", "true"); - assertOptimizedEquals("bound_integer = 12340000000", "false"); - assertOptimizedEquals("bound_long = BIGINT '1234'", "true"); - assertOptimizedEquals("bound_long = 1234", "true"); - assertOptimizedEquals("bound_double = 12.34", "true"); - assertOptimizedEquals("bound_string = 'hello'", "true"); - assertOptimizedEquals("bound_long = unbound_long", "1234 = 
unbound_long"); - - assertOptimizedEquals("10151082135029368 = 10151082135029369", "false"); - - assertOptimizedEquals("bound_varbinary = X'a b'", "true"); - assertOptimizedEquals("bound_varbinary = X'a d'", "false"); - - assertOptimizedEquals("1.1 = 1.1", "true"); - assertOptimizedEquals("9876543210.9874561203 = 9876543210.9874561203", "true"); - assertOptimizedEquals("bound_decimal_short = 123.45", "true"); - assertOptimizedEquals("bound_decimal_long = 12345678901234567890.123", "true"); - } - - @Test - public void testIsDistinctFrom() - { - assertOptimizedEquals("null is distinct from null", "false"); - - assertOptimizedEquals("3 is distinct from 4", "true"); - assertOptimizedEquals("3 is distinct from BIGINT '4'", "true"); - assertOptimizedEquals("3 is distinct from 4000000000", "true"); - assertOptimizedEquals("3 is distinct from 3", "false"); - assertOptimizedEquals("3 is distinct from null", "true"); - assertOptimizedEquals("null is distinct from 3", "true"); - - assertOptimizedEquals("10151082135029368 is distinct from 10151082135029369", "true"); - - assertOptimizedEquals("1.1 is distinct from 1.1", "false"); - assertOptimizedEquals("9876543210.9874561203 is distinct from NULL", "true"); - assertOptimizedEquals("bound_decimal_short is distinct from NULL", "true"); - assertOptimizedEquals("bound_decimal_long is distinct from 12345678901234567890.123", "false"); - } - - @Test - public void testIsNull() - { - assertOptimizedEquals("null is null", "true"); - assertOptimizedEquals("1 is null", "false"); - assertOptimizedEquals("10000000000 is null", "false"); - assertOptimizedEquals("BIGINT '1' is null", "false"); - assertOptimizedEquals("1.0 is null", "false"); - assertOptimizedEquals("'a' is null", "false"); - assertOptimizedEquals("true is null", "false"); - assertOptimizedEquals("null+1 is null", "true"); - assertOptimizedEquals("unbound_string is null", "unbound_string is null"); - assertOptimizedEquals("unbound_long+(1+1) is null", "unbound_long+2 is 
null"); - assertOptimizedEquals("1.1 is null", "false"); - assertOptimizedEquals("9876543210.9874561203 is null", "false"); - assertOptimizedEquals("bound_decimal_short is null", "false"); - assertOptimizedEquals("bound_decimal_long is null", "false"); - } - - @Test - public void testIsNotNull() - { - assertOptimizedEquals("null is not null", "false"); - assertOptimizedEquals("1 is not null", "true"); - assertOptimizedEquals("10000000000 is not null", "true"); - assertOptimizedEquals("BIGINT '1' is not null", "true"); - assertOptimizedEquals("1.0 is not null", "true"); - assertOptimizedEquals("'a' is not null", "true"); - assertOptimizedEquals("true is not null", "true"); - assertOptimizedEquals("null+1 is not null", "false"); - assertOptimizedEquals("unbound_string is not null", "unbound_string is not null"); - assertOptimizedEquals("unbound_long+(1+1) is not null", "unbound_long+2 is not null"); - assertOptimizedEquals("1.1 is not null", "true"); - assertOptimizedEquals("9876543210.9874561203 is not null", "true"); - assertOptimizedEquals("bound_decimal_short is not null", "true"); - assertOptimizedEquals("bound_decimal_long is not null", "true"); - } - - @Test - public void testNullIf() - { - assertOptimizedEquals("nullif(true, true)", "null"); - assertOptimizedEquals("nullif(true, false)", "true"); - assertOptimizedEquals("nullif(null, false)", "null"); - assertOptimizedEquals("nullif(true, null)", "true"); - - assertOptimizedEquals("nullif('a', 'a')", "null"); - assertOptimizedEquals("nullif('a', 'b')", "'a'"); - assertOptimizedEquals("nullif(null, 'b')", "null"); - assertOptimizedEquals("nullif('a', null)", "'a'"); - - assertOptimizedEquals("nullif(1, 1)", "null"); - assertOptimizedEquals("nullif(1, 2)", "1"); - assertOptimizedEquals("nullif(1, BIGINT '2')", "1"); - assertOptimizedEquals("nullif(1, 20000000000)", "1"); - assertOptimizedEquals("nullif(1.0E0, 1)", "null"); - assertOptimizedEquals("nullif(10000000000.0E0, 10000000000)", "null"); - 
assertOptimizedEquals("nullif(1.1E0, 1)", "1.1E0"); - assertOptimizedEquals("nullif(1.1E0, 1.1E0)", "null"); - assertOptimizedEquals("nullif(1, 2-1)", "null"); - assertOptimizedEquals("nullif(null, null)", "null"); - assertOptimizedEquals("nullif(1, null)", "1"); - assertOptimizedEquals("nullif(unbound_long, 1)", "nullif(unbound_long, 1)"); - assertOptimizedEquals("nullif(unbound_long, unbound_long2)", "nullif(unbound_long, unbound_long2)"); - assertOptimizedEquals("nullif(unbound_long, unbound_long2+(1+1))", "nullif(unbound_long, unbound_long2+2)"); - - assertOptimizedEquals("nullif(1.1, 1.2)", "1.1"); - assertOptimizedEquals("nullif(9876543210.9874561203, 9876543210.9874561203)", "null"); - assertOptimizedEquals("nullif(bound_decimal_short, 123.45)", "null"); - assertOptimizedEquals("nullif(bound_decimal_long, 12345678901234567890.123)", "null"); - assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(1 AS BIGINT)]) IS NULL", "true"); - assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false"); - assertOptimizedEquals("nullif(ARRAY[CAST(NULL AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false"); - } - - @Test - public void testNegative() - { - assertOptimizedEquals("-(1)", "-1"); - assertOptimizedEquals("-(BIGINT '1')", "BIGINT '-1'"); - assertOptimizedEquals("-(unbound_long+1)", "-(unbound_long+1)"); - assertOptimizedEquals("-(1+1)", "-2"); - assertOptimizedEquals("-(1+ BIGINT '1')", "BIGINT '-2'"); - assertOptimizedEquals("-(CAST(NULL AS BIGINT))", "null"); - assertOptimizedEquals("-(unbound_long+(1+1))", "-(unbound_long+2)"); - assertOptimizedEquals("-(1.1+1.2)", "-2.3"); - assertOptimizedEquals("-(9876543210.9874561203-9876543210.9874561203)", "CAST(0 AS DECIMAL(20,10))"); - assertOptimizedEquals("-(bound_decimal_short+123.45)", "-246.90"); - assertOptimizedEquals("-(bound_decimal_long-12345678901234567890.123)", "CAST(0 AS DECIMAL(20,10))"); - } - - @Test - public void testNot() - { - 
assertOptimizedEquals("not true", "false"); - assertOptimizedEquals("not false", "true"); - assertOptimizedEquals("not null", "null"); - assertOptimizedEquals("not 1=1", "false"); - assertOptimizedEquals("not 1=BIGINT '1'", "false"); - assertOptimizedEquals("not 1!=1", "true"); - assertOptimizedEquals("not unbound_long=1", "not unbound_long=1"); - assertOptimizedEquals("not unbound_long=(1+1)", "not unbound_long=2"); - } - - @Test - public void testFunctionCall() - { - assertOptimizedEquals("abs(-5)", "5"); - assertOptimizedEquals("abs(-10-5)", "15"); - assertOptimizedEquals("abs(-bound_integer + 1)", "1233"); - assertOptimizedEquals("abs(-bound_long + 1)", "1233"); - assertOptimizedEquals("abs(-bound_long + BIGINT '1')", "1233"); - assertOptimizedEquals("abs(-bound_long)", "1234"); - assertOptimizedEquals("abs(unbound_long)", "abs(unbound_long)"); - assertOptimizedEquals("abs(unbound_long + 1)", "abs(unbound_long + 1)"); - assertOptimizedEquals("cast(json_parse(unbound_string) as map(varchar, varchar))", "cast(json_parse(unbound_string) as map(varchar, varchar))"); - assertOptimizedEquals("cast(json_parse(unbound_string) as array(varchar))", "cast(json_parse(unbound_string) as array(varchar))"); - assertOptimizedEquals("cast(json_parse(unbound_string) as row(bigint, varchar))", "cast(json_parse(unbound_string) as row(bigint, varchar))"); - } - - @Test - public void testNonDeterministicFunctionCall() - { - // optimize should do nothing - assertOptimizedEquals("random()", "random()"); - - // evaluate should execute - Object value = evaluate("random()", false); - assertTrue(value instanceof Double); - double randomValue = (double) value; - assertTrue(0 <= randomValue && randomValue < 1); - } - - @Test - public void testCppFunctionCall() - { - METADATA.getFunctionAndTypeManager().createFunction(SQUARE_UDF_CPP, false); - assertOptimizedEquals("json.test_schema.square(-5)", "json.test_schema.square(-5)"); - } - - @Test - public void testCppAggregateFunctionCall() - { - 
METADATA.getFunctionAndTypeManager().createFunction(AVG_UDAF_CPP, false); - assertOptimizedEquals("json.test_schema.avg(1.0)", "json.test_schema.avg(1.0)"); - } - // Run this method exactly once. private void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) { @@ -439,983 +83,11 @@ private void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAn } @Test - public void testBetween() - { - assertOptimizedEquals("3 between 2 and 4", "true"); - assertOptimizedEquals("2 between 3 and 4", "false"); - assertOptimizedEquals("null between 2 and 4", "null"); - assertOptimizedEquals("3 between null and 4", "null"); - assertOptimizedEquals("3 between 2 and null", "null"); - - assertOptimizedEquals("'cc' between 'b' and 'd'", "true"); - assertOptimizedEquals("'b' between 'cc' and 'd'", "false"); - assertOptimizedEquals("null between 'b' and 'd'", "null"); - assertOptimizedEquals("'cc' between null and 'd'", "null"); - assertOptimizedEquals("'cc' between 'b' and null", "null"); - - assertOptimizedEquals("bound_integer between 1000 and 2000", "true"); - assertOptimizedEquals("bound_integer between 3 and 4", "false"); - assertOptimizedEquals("bound_long between 1000 and 2000", "true"); - assertOptimizedEquals("bound_long between 3 and 4", "false"); - assertOptimizedEquals("bound_long between bound_integer and (bound_long + 1)", "true"); - assertOptimizedEquals("bound_string between 'e' and 'i'", "true"); - assertOptimizedEquals("bound_string between 'a' and 'b'", "false"); - - assertOptimizedEquals("bound_long between unbound_long and 2000 + 1", "1234 between unbound_long and 2001"); - assertOptimizedEquals( - "bound_string between unbound_string and 'bar'", - format("CAST('hello' AS VARCHAR(%s)) between unbound_string and 'bar'", TEST_VARCHAR_TYPE_LENGTH)); - - assertOptimizedEquals("1.15 between 1.1 and 1.2", "true"); - assertOptimizedEquals("9876543210.98745612035 between 9876543210.9874561203 and 9876543210.9874561204", "true"); - 
assertOptimizedEquals("123.455 between bound_decimal_short and 123.46", "true"); - assertOptimizedEquals("12345678901234567890.1235 between bound_decimal_long and 12345678901234567890.123", "false"); - } - - @Test - public void testExtract() - { - DateTime dateTime = new DateTime(2001, 8, 22, 3, 4, 5, 321, getDateTimeZone(TEST_SESSION.getTimeZoneKey())); - double seconds = dateTime.getMillis() / 1000.0; - - assertOptimizedEquals("extract (YEAR from from_unixtime(" + seconds + "))", "2001"); - assertOptimizedEquals("extract (QUARTER from from_unixtime(" + seconds + "))", "3"); - assertOptimizedEquals("extract (MONTH from from_unixtime(" + seconds + "))", "8"); - assertOptimizedEquals("extract (WEEK from from_unixtime(" + seconds + "))", "34"); - assertOptimizedEquals("extract (DOW from from_unixtime(" + seconds + "))", "3"); - assertOptimizedEquals("extract (DOY from from_unixtime(" + seconds + "))", "234"); - assertOptimizedEquals("extract (DAY from from_unixtime(" + seconds + "))", "22"); - assertOptimizedEquals("extract (HOUR from from_unixtime(" + seconds + "))", "3"); - assertOptimizedEquals("extract (MINUTE from from_unixtime(" + seconds + "))", "4"); - assertOptimizedEquals("extract (SECOND from from_unixtime(" + seconds + "))", "5"); - assertOptimizedEquals("extract (TIMEZONE_HOUR from from_unixtime(" + seconds + ", 7, 9))", "7"); - assertOptimizedEquals("extract (TIMEZONE_MINUTE from from_unixtime(" + seconds + ", 7, 9))", "9"); - - assertOptimizedEquals("extract (YEAR from bound_timestamp)", "2001"); - assertOptimizedEquals("extract (QUARTER from bound_timestamp)", "3"); - assertOptimizedEquals("extract (MONTH from bound_timestamp)", "8"); - assertOptimizedEquals("extract (WEEK from bound_timestamp)", "34"); - assertOptimizedEquals("extract (DOW from bound_timestamp)", "2"); - assertOptimizedEquals("extract (DOY from bound_timestamp)", "233"); - assertOptimizedEquals("extract (DAY from bound_timestamp)", "21"); - assertOptimizedEquals("extract (HOUR from 
bound_timestamp)", "16"); - assertOptimizedEquals("extract (MINUTE from bound_timestamp)", "4"); - assertOptimizedEquals("extract (SECOND from bound_timestamp)", "5"); - // todo reenable when cast as timestamp with time zone is implemented - // todo add bound timestamp with time zone - //assertOptimizedEquals("extract (TIMEZONE_HOUR from bound_timestamp)", "0"); - //assertOptimizedEquals("extract (TIMEZONE_MINUTE from bound_timestamp)", "0"); - - assertOptimizedEquals("extract (YEAR from unbound_timestamp)", "extract (YEAR from unbound_timestamp)"); - assertOptimizedEquals("extract (SECOND from bound_timestamp + INTERVAL '3' SECOND)", "8"); - } - - @Test - public void testIn() - { - assertOptimizedEquals("3 in (2, 4, 3, 5)", "true"); - assertOptimizedEquals("3 in (2, 4, 9, 5)", "false"); - assertOptimizedEquals("3 in (2, null, 3, 5)", "true"); - - assertOptimizedEquals("'foo' in ('bar', 'baz', 'foo', 'blah')", "true"); - assertOptimizedEquals("'foo' in ('bar', 'baz', 'buz', 'blah')", "false"); - assertOptimizedEquals("'foo' in ('bar', null, 'foo', 'blah')", "true"); - - assertOptimizedEquals("null in (2, null, 3, 5)", "null"); - assertOptimizedEquals("3 in (2, null)", "null"); - - assertOptimizedEquals("bound_integer in (2, 1234, 3, 5)", "true"); - assertOptimizedEquals("bound_integer in (2, 4, 3, 5)", "false"); - assertOptimizedEquals("1234 in (2, bound_integer, 3, 5)", "true"); - assertOptimizedEquals("99 in (2, bound_integer, 3, 5)", "false"); - assertOptimizedEquals("bound_integer in (2, bound_integer, 3, 5)", "true"); - - assertOptimizedEquals("bound_long in (2, 1234, 3, 5)", "true"); - assertOptimizedEquals("bound_long in (2, 4, 3, 5)", "false"); - assertOptimizedEquals("1234 in (2, bound_long, 3, 5)", "true"); - assertOptimizedEquals("99 in (2, bound_long, 3, 5)", "false"); - assertOptimizedEquals("bound_long in (2, bound_long, 3, 5)", "true"); - - assertOptimizedEquals("bound_string in ('bar', 'hello', 'foo', 'blah')", "true"); - 
assertOptimizedEquals("bound_string in ('bar', 'baz', 'foo', 'blah')", "false"); - assertOptimizedEquals("'hello' in ('bar', bound_string, 'foo', 'blah')", "true"); - assertOptimizedEquals("'baz' in ('bar', bound_string, 'foo', 'blah')", "false"); - - assertOptimizedEquals("bound_long in (2, 1234, unbound_long, 5)", "true"); - assertOptimizedEquals("bound_string in ('bar', 'hello', unbound_string, 'blah')", "true"); - - assertOptimizedEquals("bound_long in (2, 4, unbound_long, unbound_long2, 9)", "1234 in (unbound_long, unbound_long2)"); - assertOptimizedEquals("unbound_long in (2, 4, bound_long, unbound_long2, 5)", "unbound_long in (2, 4, 1234, unbound_long2, 5)"); - - assertOptimizedEquals("1.15 in (1.1, 1.2, 1.3, 1.15)", "true"); - assertOptimizedEquals("9876543210.98745612035 in (9876543210.9874561203, 9876543210.9874561204, 9876543210.98745612035)", "true"); - assertOptimizedEquals("bound_decimal_short in (123.455, 123.46, 123.45)", "true"); - assertOptimizedEquals("bound_decimal_long in (12345678901234567890.123, 9876543210.9874561204, 9876543210.98745612035)", "true"); - assertOptimizedEquals("bound_decimal_long in (9876543210.9874561204, null, 9876543210.98745612035)", "null"); - } - - @Test - public void testInComplexTypes() - { - assertEvaluatedEquals("ARRAY[null] IN (ARRAY[null])", "null"); - assertEvaluatedEquals("ARRAY[1] IN (ARRAY[null])", "null"); - assertEvaluatedEquals("ARRAY[null] IN (ARRAY[1])", "null"); - assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null])", "null"); - assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[2, null])", "false"); - assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null])", "null"); - assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null], ARRAY[1, null])", "null"); - assertEvaluatedEquals("ARRAY[ARRAY[1, 2], ARRAY[3, 4]] in (ARRAY[ARRAY[1, 2], ARRAY[3, NULL]])", "null"); - - assertEvaluatedEquals("ROW(1) IN (ROW(1))", "true"); - assertEvaluatedEquals("ROW(1) IN (ROW(2))", 
"false"); - assertEvaluatedEquals("ROW(1) IN (ROW(2), ROW(1), ROW(2))", "true"); - assertEvaluatedEquals("ROW(1) IN (null)", "null"); - assertEvaluatedEquals("ROW(1) IN (null, ROW(1))", "true"); - assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null), null)", "null"); - assertEvaluatedEquals("ROW(null) IN (ROW(null))", "null"); - assertEvaluatedEquals("ROW(1) IN (ROW(null))", "null"); - assertEvaluatedEquals("ROW(null) IN (ROW(1))", "null"); - assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null))", "null"); - assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null))", "false"); - assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null))", "null"); - assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null), ROW(1, null))", "null"); - - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[1]))", "true"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null)", "null"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null, MAP(ARRAY[1], ARRAY[1]))", "true"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "false"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]), null)", "null"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[1]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN 
(MAP(ARRAY[1, 2], ARRAY[2, null]))", "false"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]), MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); - } - - @Test - public void testCurrentTimestamp() - { - double current = TEST_SESSION.getStartTime() / 1000.0; - assertOptimizedEquals("current_timestamp = from_unixtime(" + current + ")", "true"); - double future = current + TimeUnit.MINUTES.toSeconds(1); - assertOptimizedEquals("current_timestamp > from_unixtime(" + future + ")", "false"); - } - - @Test - public void testCurrentUser() - throws Exception - { - assertOptimizedEquals("current_user", "'" + TEST_SESSION.getUser() + "'"); - } - - @Test - public void testCastToString() - { - // integer - assertOptimizedEquals("cast(123 as VARCHAR(20))", "'123'"); - assertOptimizedEquals("cast(-123 as VARCHAR(20))", "'-123'"); - - // bigint - assertOptimizedEquals("cast(BIGINT '123' as VARCHAR)", "'123'"); - assertOptimizedEquals("cast(12300000000 as VARCHAR)", "'12300000000'"); - assertOptimizedEquals("cast(-12300000000 as VARCHAR)", "'-12300000000'"); - - // double - assertOptimizedEquals("cast(123.0E0 as VARCHAR)", "'123.0'"); - assertOptimizedEquals("cast(-123.0E0 as VARCHAR)", "'-123.0'"); - assertOptimizedEquals("cast(123.456E0 as VARCHAR)", "'123.456'"); - assertOptimizedEquals("cast(-123.456E0 as VARCHAR)", "'-123.456'"); - - // boolean - assertOptimizedEquals("cast(true as VARCHAR)", "'true'"); - assertOptimizedEquals("cast(false as VARCHAR)", "'false'"); - - // string - assertOptimizedEquals("cast('xyz' as VARCHAR)", "'xyz'"); - assertOptimizedEquals("cast(cast('abcxyz' as VARCHAR(3)) as VARCHAR(5))", "'abc'"); - - // null - assertOptimizedEquals("cast(null as VARCHAR)", "null"); - - // decimal - assertOptimizedEquals("cast(1.1 as VARCHAR)", "'1.1'"); - 
// TODO enabled when DECIMAL is default for literal: assertOptimizedEquals("cast(12345678901234567890.123 as VARCHAR)", "'12345678901234567890.123'"); - } - - @Test - public void testCastBigintToBoundedVarchar() - { - assertEvaluatedEquals("CAST(12300000000 AS varchar(11))", "'12300000000'"); - assertEvaluatedEquals("CAST(12300000000 AS varchar(50))", "'12300000000'"); - - try { - evaluate("CAST(12300000000 AS varchar(3))", true); - fail("Expected to throw an INVALID_CAST_ARGUMENT exception"); - } - catch (PrestoException e) { - try { - assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); - assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)"); - } - catch (Throwable failure) { - failure.addSuppressed(e); - throw failure; - } - } - - try { - evaluate("CAST(-12300000000 AS varchar(3))", true); - } - catch (PrestoException e) { - try { - assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); - assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)"); - } - catch (Throwable failure) { - failure.addSuppressed(e); - throw failure; - } - } - } - - @Test - public void testCastToBoolean() - { - // integer - assertOptimizedEquals("cast(123 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(-123 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(0 as BOOLEAN)", "false"); - - // bigint - assertOptimizedEquals("cast(12300000000 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(-12300000000 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(BIGINT '0' as BOOLEAN)", "false"); - - // boolean - assertOptimizedEquals("cast(true as BOOLEAN)", "true"); - assertOptimizedEquals("cast(false as BOOLEAN)", "false"); - - // string - assertOptimizedEquals("cast('true' as BOOLEAN)", "true"); - assertOptimizedEquals("cast('false' as BOOLEAN)", "false"); - assertOptimizedEquals("cast('t' as BOOLEAN)", "true"); - assertOptimizedEquals("cast('f' as BOOLEAN)", "false"); - 
assertOptimizedEquals("cast('1' as BOOLEAN)", "true"); - assertOptimizedEquals("cast('0' as BOOLEAN)", "false"); - - // null - assertOptimizedEquals("cast(null as BOOLEAN)", "null"); - - // double - assertOptimizedEquals("cast(123.45E0 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(-123.45E0 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(0.0E0 as BOOLEAN)", "false"); - - // decimal - assertOptimizedEquals("cast(0.00 as BOOLEAN)", "false"); - assertOptimizedEquals("cast(7.8 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(12345678901234567890.123 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(00000000000000000000.000 as BOOLEAN)", "false"); - } - - @Test - public void testCastToBigint() - { - // integer - assertOptimizedEquals("cast(0 as BIGINT)", "0"); - assertOptimizedEquals("cast(123 as BIGINT)", "123"); - assertOptimizedEquals("cast(-123 as BIGINT)", "-123"); - - // bigint - assertOptimizedEquals("cast(BIGINT '0' as BIGINT)", "0"); - assertOptimizedEquals("cast(BIGINT '123' as BIGINT)", "123"); - assertOptimizedEquals("cast(BIGINT '-123' as BIGINT)", "-123"); - - // double - assertOptimizedEquals("cast(123.0E0 as BIGINT)", "123"); - assertOptimizedEquals("cast(-123.0E0 as BIGINT)", "-123"); - assertOptimizedEquals("cast(123.456E0 as BIGINT)", "123"); - assertOptimizedEquals("cast(-123.456E0 as BIGINT)", "-123"); - - // boolean - assertOptimizedEquals("cast(true as BIGINT)", "1"); - assertOptimizedEquals("cast(false as BIGINT)", "0"); - - // string - assertOptimizedEquals("cast('123' as BIGINT)", "123"); - assertOptimizedEquals("cast('-123' as BIGINT)", "-123"); - - // null - assertOptimizedEquals("cast(null as BIGINT)", "null"); - - // decimal - assertOptimizedEquals("cast(DECIMAL '1.01' as BIGINT)", "1"); - assertOptimizedEquals("cast(DECIMAL '7.8' as BIGINT)", "8"); - assertOptimizedEquals("cast(DECIMAL '1234567890.123' as BIGINT)", "1234567890"); - assertOptimizedEquals("cast(DECIMAL '00000000000000000000.000' as BIGINT)", "0"); - } - - 
@Test - public void testCastToInteger() - { - // integer - assertOptimizedEquals("cast(0 as INTEGER)", "0"); - assertOptimizedEquals("cast(123 as INTEGER)", "123"); - assertOptimizedEquals("cast(-123 as INTEGER)", "-123"); - - // bigint - assertOptimizedEquals("cast(BIGINT '0' as INTEGER)", "0"); - assertOptimizedEquals("cast(BIGINT '123' as INTEGER)", "123"); - assertOptimizedEquals("cast(BIGINT '-123' as INTEGER)", "-123"); - - // double - assertOptimizedEquals("cast(123.0E0 as INTEGER)", "123"); - assertOptimizedEquals("cast(-123.0E0 as INTEGER)", "-123"); - assertOptimizedEquals("cast(123.456E0 as INTEGER)", "123"); - assertOptimizedEquals("cast(-123.456E0 as INTEGER)", "-123"); - - // boolean - assertOptimizedEquals("cast(true as INTEGER)", "1"); - assertOptimizedEquals("cast(false as INTEGER)", "0"); - - // string - assertOptimizedEquals("cast('123' as INTEGER)", "123"); - assertOptimizedEquals("cast('-123' as INTEGER)", "-123"); - - // null - assertOptimizedEquals("cast(null as INTEGER)", "null"); - } - - @Test - public void testCastToDouble() - { - // integer - assertOptimizedEquals("cast(0 as DOUBLE)", "0.0E0"); - assertOptimizedEquals("cast(123 as DOUBLE)", "123.0E0"); - assertOptimizedEquals("cast(-123 as DOUBLE)", "-123.0E0"); - - // bigint - assertOptimizedEquals("cast(BIGINT '0' as DOUBLE)", "0.0E0"); - assertOptimizedEquals("cast(12300000000 as DOUBLE)", "12300000000.0E0"); - assertOptimizedEquals("cast(-12300000000 as DOUBLE)", "-12300000000.0E0"); - - // double - assertOptimizedEquals("cast(123.0E0 as DOUBLE)", "123.0E0"); - assertOptimizedEquals("cast(-123.0E0 as DOUBLE)", "-123.0E0"); - assertOptimizedEquals("cast(123.456E0 as DOUBLE)", "123.456E0"); - assertOptimizedEquals("cast(-123.456E0 as DOUBLE)", "-123.456E0"); - - // string - assertOptimizedEquals("cast('0' as DOUBLE)", "0.0E0"); - assertOptimizedEquals("cast('123' as DOUBLE)", "123.0E0"); - assertOptimizedEquals("cast('-123' as DOUBLE)", "-123.0E0"); - 
assertOptimizedEquals("cast('123.0E0' as DOUBLE)", "123.0E0"); - assertOptimizedEquals("cast('-123.0E0' as DOUBLE)", "-123.0E0"); - assertOptimizedEquals("cast('123.456E0' as DOUBLE)", "123.456E0"); - assertOptimizedEquals("cast('-123.456E0' as DOUBLE)", "-123.456E0"); - - // null - assertOptimizedEquals("cast(null as DOUBLE)", "null"); - - // boolean - assertOptimizedEquals("cast(true as DOUBLE)", "1.0E0"); - assertOptimizedEquals("cast(false as DOUBLE)", "0.0E0"); - - // decimal - assertOptimizedEquals("cast(1.01 as DOUBLE)", "DOUBLE '1.01'"); - assertOptimizedEquals("cast(7.8 as DOUBLE)", "DOUBLE '7.8'"); - assertOptimizedEquals("cast(1234567890.123 as DOUBLE)", "DOUBLE '1234567890.123'"); - assertOptimizedEquals("cast(00000000000000000000.000 as DOUBLE)", "DOUBLE '0.0'"); - } - - @Test - public void testCastToDecimal() - { - // long - assertOptimizedEquals("cast(0 as DECIMAL(1,0))", "DECIMAL '0'"); - assertOptimizedEquals("cast(123 as DECIMAL(3,0))", "DECIMAL '123'"); - assertOptimizedEquals("cast(-123 as DECIMAL(3,0))", "DECIMAL '-123'"); - assertOptimizedEquals("cast(-123 as DECIMAL(20,10))", "cast(-123 as DECIMAL(20,10))"); - - // double - assertOptimizedEquals("cast(0E0 as DECIMAL(1,0))", "DECIMAL '0'"); - assertOptimizedEquals("cast(123.2E0 as DECIMAL(4,1))", "DECIMAL '123.2'"); - assertOptimizedEquals("cast(-123.0E0 as DECIMAL(3,0))", "DECIMAL '-123'"); - assertOptimizedEquals("cast(-123.55E0 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); - - // string - assertOptimizedEquals("cast('0' as DECIMAL(1,0))", "DECIMAL '0'"); - assertOptimizedEquals("cast('123.2' as DECIMAL(4,1))", "DECIMAL '123.2'"); - assertOptimizedEquals("cast('-123.0' as DECIMAL(3,0))", "DECIMAL '-123'"); - assertOptimizedEquals("cast('-123.55' as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); - - // null - assertOptimizedEquals("cast(null as DECIMAL(1,0))", "null"); - assertOptimizedEquals("cast(null as DECIMAL(20,10))", "null"); - - // boolean - 
assertOptimizedEquals("cast(true as DECIMAL(1,0))", "DECIMAL '1'"); - assertOptimizedEquals("cast(false as DECIMAL(4,1))", "DECIMAL '000.0'"); - assertOptimizedEquals("cast(true as DECIMAL(3,0))", "DECIMAL '001'"); - assertOptimizedEquals("cast(false as DECIMAL(20,10))", "cast(0 as DECIMAL(20,10))"); - - // decimal - assertOptimizedEquals("cast(0.0 as DECIMAL(1,0))", "DECIMAL '0'"); - assertOptimizedEquals("cast(123.2 as DECIMAL(4,1))", "DECIMAL '123.2'"); - assertOptimizedEquals("cast(-123.0 as DECIMAL(3,0))", "DECIMAL '-123'"); - assertOptimizedEquals("cast(-123.55 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); - } - - @Test - public void testCastOptimization() - { - assertOptimizedEquals("cast(unbound_string as VARCHAR)", "cast(unbound_string as VARCHAR)"); - assertOptimizedMatches("cast(unbound_string as VARCHAR)", "unbound_string"); - assertOptimizedMatches("cast(unbound_integer as INTEGER)", "unbound_integer"); - assertOptimizedMatches("cast(unbound_string as VARCHAR(10))", "cast(unbound_string as VARCHAR(10))"); - } - - @Test - public void testTryCast() - { - assertOptimizedEquals("try_cast(null as BIGINT)", "null"); - assertOptimizedEquals("try_cast(123 as BIGINT)", "123"); - assertOptimizedEquals("try_cast(null as INTEGER)", "null"); - assertOptimizedEquals("try_cast(123 as INTEGER)", "123"); - assertOptimizedEquals("try_cast('foo' as VARCHAR)", "'foo'"); - assertOptimizedEquals("try_cast('foo' as BIGINT)", "null"); - assertOptimizedEquals("try_cast(unbound_string as BIGINT)", "try_cast(unbound_string as BIGINT)"); - assertOptimizedEquals("try_cast('foo' as DECIMAL(2,1))", "null"); - } - - @Test - public void testReservedWithDoubleQuotes() - { - assertOptimizedEquals("\"time\"", "\"time\""); - } - - @Test - public void testSearchCase() - { - assertOptimizedEquals("case " + - "when true then 33 " + - "end", - "33"); - assertOptimizedEquals("case " + - "when false then 1 " + - "else 33 " + - "end", - "33"); - - assertOptimizedEquals("case " + - 
"when false then 10000000000 " + - "else 33 " + - "end", - "33"); - - assertOptimizedEquals("case " + - "when bound_long = 1234 then 33 " + - "end", - "33"); - assertOptimizedEquals("case " + - "when true then bound_long " + - "end", - "1234"); - assertOptimizedEquals("case " + - "when false then 1 " + - "else bound_long " + - "end", - "1234"); - - assertOptimizedEquals("case " + - "when bound_integer = 1234 then 33 " + - "end", - "33"); - assertOptimizedEquals("case " + - "when true then bound_integer " + - "end", - "1234"); - assertOptimizedEquals("case " + - "when false then 1 " + - "else bound_integer " + - "end", - "1234"); - - assertOptimizedEquals("case " + - "when bound_long = 1234 then 33 " + - "else unbound_long " + - "end", - "33"); - assertOptimizedEquals("case " + - "when true then bound_long " + - "else unbound_long " + - "end", - "1234"); - assertOptimizedEquals("case " + - "when false then unbound_long " + - "else bound_long " + - "end", - "1234"); - - assertOptimizedEquals("case " + - "when bound_integer = 1234 then 33 " + - "else unbound_integer " + - "end", - "33"); - assertOptimizedEquals("case " + - "when true then bound_integer " + - "else unbound_integer " + - "end", - "1234"); - assertOptimizedEquals("case " + - "when false then unbound_integer " + - "else bound_integer " + - "end", - "1234"); - - assertOptimizedEquals("case " + - "when unbound_long = 1234 then 33 " + - "else 1 " + - "end", - "" + - "case " + - "when unbound_long = 1234 then 33 " + - "else 1 " + - "end"); - - assertOptimizedMatches("if(false, 1, 0 / 0)", "cast(fail(8, 'ignored failure message') as integer)"); - - assertOptimizedEquals("case " + - "when false then 2.2 " + - "when true then 2.2 " + - "end", - "2.2"); - - assertOptimizedEquals("case " + - "when false then 1234567890.0987654321 " + - "when true then 3.3 " + - "end", - "CAST(3.3 AS DECIMAL(20,10))"); - - assertOptimizedEquals("case " + - "when false then 1 " + - "when true then 2.2 " + - "end", - "2.2"); - - 
assertOptimizedEquals("case when ARRAY[CAST(1 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); - assertOptimizedEquals("case when ARRAY[CAST(2 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); - assertOptimizedEquals("case when ARRAY[CAST(null AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); - } - - @Test - public void testSimpleCase() - { - assertOptimizedEquals("case 1 " + - "when 1 then 32 + 1 " + - "when 1 then 34 " + - "end", - "33"); - - assertOptimizedEquals("case null " + - "when true then 33 " + - "end", - "null"); - assertOptimizedEquals("case null " + - "when true then 33 " + - "else 33 " + - "end", - "33"); - assertOptimizedEquals("case 33 " + - "when null then 1 " + - "else 33 " + - "end", - "33"); - - assertOptimizedEquals("case null " + - "when true then 3300000000 " + - "end", - "null"); - assertOptimizedEquals("case null " + - "when true then 3300000000 " + - "else 3300000000 " + - "end", - "3300000000"); - assertOptimizedEquals("case 33 " + - "when null then 3300000000 " + - "else 33 " + - "end", - "33"); - - assertOptimizedEquals("case true " + - "when true then 33 " + - "end", - "33"); - assertOptimizedEquals("case true " + - "when false then 1 " + - "else 33 end", - "33"); - - assertOptimizedEquals("case bound_long " + - "when 1234 then 33 " + - "end", - "33"); - assertOptimizedEquals("case 1234 " + - "when bound_long then 33 " + - "end", - "33"); - assertOptimizedEquals("case true " + - "when true then bound_long " + - "end", - "1234"); - assertOptimizedEquals("case true " + - "when false then 1 " + - "else bound_long " + - "end", - "1234"); - - assertOptimizedEquals("case bound_integer " + - "when 1234 then 33 " + - "end", - "33"); - assertOptimizedEquals("case 1234 " + - "when bound_integer then 33 " + - "end", - "33"); - assertOptimizedEquals("case true " + - "when true then bound_integer " + - "end", - 
"1234"); - assertOptimizedEquals("case true " + - "when false then 1 " + - "else bound_integer " + - "end", - "1234"); - - assertOptimizedEquals("case bound_long " + - "when 1234 then 33 " + - "else unbound_long " + - "end", - "33"); - assertOptimizedEquals("case true " + - "when true then bound_long " + - "else unbound_long " + - "end", - "1234"); - assertOptimizedEquals("case true " + - "when false then unbound_long " + - "else bound_long " + - "end", - "1234"); - - assertOptimizedEquals("case unbound_long " + - "when 1234 then 33 " + - "else 1 " + - "end", - "" + - "case unbound_long " + - "when 1234 then 33 " + - "else 1 " + - "end"); - - assertOptimizedEquals("case 33 " + - "when 0 then 0 " + - "when 33 then unbound_long " + - "else 1 " + - "end", - "unbound_long"); - assertOptimizedEquals("case 33 " + - "when 0 then 0 " + - "when 33 then 1 " + - "when unbound_long then 2 " + - "else 1 " + - "end", - "1"); - assertOptimizedEquals("case 33 " + - "when unbound_long then 0 " + - "when 1 then 1 " + - "when 33 then 2 " + - "else 0 " + - "end", - "case 33 " + - "when unbound_long then 0 " + - "else 2 " + - "end"); - assertOptimizedEquals("case 33 " + - "when 0 then 0 " + - "when 1 then 1 " + - "else unbound_long " + - "end", - "unbound_long"); - assertOptimizedEquals("case 33 " + - "when unbound_long then 0 " + - "when 1 then 1 " + - "when unbound_long2 then 2 " + - "else 3 " + - "end", - "case 33 " + - "when unbound_long then 0 " + - "when unbound_long2 then 2 " + - "else 3 " + - "end"); - - assertOptimizedEquals("case true " + - "when unbound_long = 1 then 1 " + - "when 0 / 0 = 0 then 2 " + - "else 33 end", - "" + - "case true " + - "when unbound_long = 1 then 1 " + - "when 0 / 0 = 0 then 2 else 33 " + - "end"); - - assertOptimizedEquals("case bound_long " + - "when 123 * 10 + unbound_long then 1 = 1 " + - "else 1 = 2 " + - "end", - "" + - "case bound_long when 1230 + unbound_long then true " + - "else false " + - "end"); - - assertOptimizedEquals("case bound_long 
" + - "when unbound_long then 2 + 2 " + - "end", - "" + - "case bound_long " + - "when unbound_long then 4 " + - "end"); - - assertOptimizedEquals("case bound_long " + - "when unbound_long then 2 + 2 " + - "when 1 then null " + - "when 2 then null " + - "end", - "" + - "case bound_long " + - "when unbound_long then 4 " + - "end"); - - assertOptimizedMatches("case 1 " + - "when unbound_long then 1 " + - "when 0 / 0 then 2 " + - "else 1 " + - "end", - "" + - "case BIGINT '1' " + - "when unbound_long then 1 " + - "when cast(fail(8, 'ignored failure message') AS integer) then 2 " + - "else 1 " + - "end"); - - assertOptimizedMatches("case 1 " + - "when 0 / 0 then 1 " + - "when 0 / 0 then 2 " + - "else 1 " + - "end", - "" + - "case 1 " + - "when cast(fail(8, 'ignored failure message') as integer) then 1 " + - "when cast(fail(8, 'ignored failure message') as integer) then 2 " + - "else 1 " + - "end"); - - assertOptimizedEquals("case true " + - "when false then 2.2 " + - "when true then 2.2 " + - "end", - "2.2"); - - // TODO enabled when DECIMAL is default for literal: -// assertOptimizedEquals("case true " + -// "when false then 1234567890.0987654321 " + -// "when true then 3.3 " + -// "end", -// "CAST(3.3 AS DECIMAL(20,10))"); - - assertOptimizedEquals("case true " + - "when false then 1 " + - "when true then 2.2 " + - "end", - "2.2"); - - assertOptimizedEquals("case ARRAY[CAST(1 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); - assertOptimizedEquals("case ARRAY[CAST(2 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); - assertOptimizedEquals("case ARRAY[CAST(null AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); - } - - @Test - public void testCoalesce() - { - assertOptimizedEquals("coalesce(null, null)", "coalesce(null, null)"); - assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1 - 1, null)", "coalesce(6 * 
unbound_long, 0)"); - assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_long, 0.5E0)"); - assertOptimizedEquals("coalesce(unbound_long, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_long, 2.0E0, 0.5E0, 12.34E0)"); - assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1 - 1, null)", "coalesce(6 * unbound_integer, 0)"); - assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_integer, 0.5E0)"); - assertOptimizedEquals("coalesce(unbound_integer, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_integer, 2.0E0, 0.5E0, 12.34E0)"); - assertOptimizedMatches("coalesce(0 / 0 > 1, unbound_boolean, 0 / 0 = 0)", - "coalesce(cast(fail(8, 'ignored failure message') as boolean), unbound_boolean)"); - assertOptimizedMatches("coalesce(unbound_long, unbound_long)", "unbound_long"); - assertOptimizedMatches("coalesce(2 * unbound_long, 2 * unbound_long)", "BIGINT '2' * unbound_long"); - assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long)", "coalesce(unbound_long, unbound_long2)"); - assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long, unbound_long3)", "coalesce(unbound_long, unbound_long2, unbound_long3)"); - assertOptimizedEquals("coalesce(6, unbound_long2, unbound_long, unbound_long3)", "6"); - assertOptimizedEquals("coalesce(2 * 3, unbound_long2, unbound_long, unbound_long3)", "6"); - assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); - assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); - assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')"); - assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); - 
assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); - assertOptimizedMatches("coalesce(random(), random(), 5)", "coalesce(random(), random(), 5E0)"); - assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); - assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); - assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')"); - assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); - assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); - } - - @Test - public void testIf() - { - assertOptimizedEquals("IF(2 = 2, 3, 4)", "3"); - assertOptimizedEquals("IF(1 = 2, 3, 4)", "4"); - assertOptimizedEquals("IF(1 = 2, BIGINT '3', 4)", "4"); - assertOptimizedEquals("IF(1 = 2, 3000000000, 4)", "4"); - - assertOptimizedEquals("IF(true, 3, 4)", "3"); - assertOptimizedEquals("IF(false, 3, 4)", "4"); - assertOptimizedEquals("IF(null, 3, 4)", "4"); - - assertOptimizedEquals("IF(true, 3, null)", "3"); - assertOptimizedEquals("IF(false, 3, null)", "null"); - assertOptimizedEquals("IF(true, null, 4)", "null"); - assertOptimizedEquals("IF(false, null, 4)", "4"); - assertOptimizedEquals("IF(true, null, null)", "null"); - assertOptimizedEquals("IF(false, null, null)", "null"); - - assertOptimizedEquals("IF(true, 3.5E0, 4.2E0)", "3.5E0"); - assertOptimizedEquals("IF(false, 3.5E0, 4.2E0)", "4.2E0"); - - assertOptimizedEquals("IF(true, 'foo', 'bar')", "'foo'"); - assertOptimizedEquals("IF(false, 'foo', 'bar')", "'bar'"); - - assertOptimizedEquals("IF(true, 1.01, 1.02)", "1.01"); - assertOptimizedEquals("IF(false, 1.01, 1.02)", "1.02"); - 
assertOptimizedEquals("IF(true, 1234567890.123, 1.02)", "1234567890.123"); - assertOptimizedEquals("IF(false, 1.01, 1234567890.123)", "1234567890.123"); - } - - @Test - public void testLike() - { - assertOptimizedEquals("'a' LIKE 'a'", "true"); - assertOptimizedEquals("'' LIKE 'a'", "false"); - assertOptimizedEquals("'abc' LIKE 'a'", "false"); - - assertOptimizedEquals("'a' LIKE '_'", "true"); - assertOptimizedEquals("'' LIKE '_'", "false"); - assertOptimizedEquals("'abc' LIKE '_'", "false"); - - assertOptimizedEquals("'a' LIKE '%'", "true"); - assertOptimizedEquals("'' LIKE '%'", "true"); - assertOptimizedEquals("'abc' LIKE '%'", "true"); - - assertOptimizedEquals("'abc' LIKE '___'", "true"); - assertOptimizedEquals("'ab' LIKE '___'", "false"); - assertOptimizedEquals("'abcd' LIKE '___'", "false"); - - assertOptimizedEquals("'abc' LIKE 'abc'", "true"); - assertOptimizedEquals("'xyz' LIKE 'abc'", "false"); - assertOptimizedEquals("'abc0' LIKE 'abc'", "false"); - assertOptimizedEquals("'0abc' LIKE 'abc'", "false"); - - assertOptimizedEquals("'abc' LIKE 'abc%'", "true"); - assertOptimizedEquals("'abc0' LIKE 'abc%'", "true"); - assertOptimizedEquals("'0abc' LIKE 'abc%'", "false"); - - assertOptimizedEquals("'abc' LIKE '%abc'", "true"); - assertOptimizedEquals("'0abc' LIKE '%abc'", "true"); - assertOptimizedEquals("'abc0' LIKE '%abc'", "false"); - - assertOptimizedEquals("'abc' LIKE '%abc%'", "true"); - assertOptimizedEquals("'0abc' LIKE '%abc%'", "true"); - assertOptimizedEquals("'abc0' LIKE '%abc%'", "true"); - assertOptimizedEquals("'0abc0' LIKE '%abc%'", "true"); - assertOptimizedEquals("'xyzw' LIKE '%abc%'", "false"); - - assertOptimizedEquals("'abc' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'0abc' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'abc0' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'0abc0' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'ab01c' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'0ab01c' LIKE '%ab%c%'", "true"); - 
assertOptimizedEquals("'ab01c0' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'0ab01c0' LIKE '%ab%c%'", "true"); - - assertOptimizedEquals("'xyzw' LIKE '%ab%c%'", "false"); - - // ensure regex chars are escaped - assertOptimizedEquals("'\' LIKE '\'", "true"); - assertOptimizedEquals("'.*' LIKE '.*'", "true"); - assertOptimizedEquals("'[' LIKE '['", "true"); - assertOptimizedEquals("']' LIKE ']'", "true"); - assertOptimizedEquals("'{' LIKE '{'", "true"); - assertOptimizedEquals("'}' LIKE '}'", "true"); - assertOptimizedEquals("'?' LIKE '?'", "true"); - assertOptimizedEquals("'+' LIKE '+'", "true"); - assertOptimizedEquals("'(' LIKE '('", "true"); - assertOptimizedEquals("')' LIKE ')'", "true"); - assertOptimizedEquals("'|' LIKE '|'", "true"); - assertOptimizedEquals("'^' LIKE '^'", "true"); - assertOptimizedEquals("'$' LIKE '$'", "true"); - - assertOptimizedEquals("null LIKE '%'", "null"); - assertOptimizedEquals("'a' LIKE null", "null"); - assertOptimizedEquals("'a' LIKE '%' ESCAPE null", "null"); - assertOptimizedEquals("'a' LIKE unbound_string ESCAPE null", "null"); - - assertOptimizedEquals("'%' LIKE 'z%' ESCAPE 'z'", "true"); - - assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%' ESCAPE 'z'", "true"); - assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%'", "false"); - } - - @Test - public void testLikeOptimization() - { - assertOptimizedEquals("unbound_string LIKE 'abc'", "unbound_string = CAST('abc' AS VARCHAR)"); - - assertOptimizedEquals("unbound_string LIKE '' ESCAPE '#'", "unbound_string LIKE '' ESCAPE '#'"); - assertOptimizedEquals("unbound_string LIKE 'abc' ESCAPE '#'", "unbound_string = CAST('abc' AS VARCHAR)"); - assertOptimizedEquals("unbound_string LIKE 'a#_b' ESCAPE '#'", "unbound_string = CAST('a_b' AS VARCHAR)"); - assertOptimizedEquals("unbound_string LIKE 'a#%b' ESCAPE '#'", "unbound_string = CAST('a%b' AS VARCHAR)"); - assertOptimizedEquals("unbound_string LIKE 'a#_##b' ESCAPE '#'", "unbound_string = CAST('a_#b' AS VARCHAR)"); - 
assertOptimizedEquals("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'"); - assertOptimizedEquals("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'"); - - assertOptimizedEquals("bound_string LIKE bound_pattern", "true"); - assertOptimizedEquals("'abc' LIKE bound_pattern", "false"); - - assertOptimizedEquals("unbound_string LIKE bound_pattern", "unbound_string LIKE bound_pattern"); - assertDoNotOptimize("unbound_string LIKE 'abc%'", SERIALIZABLE); - - assertOptimizedEquals("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string"); - } - - @Test - public void testInvalidLike() + public void testBind() { - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE ''")); - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE 'bc'")); - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#' ESCAPE '#'")); - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#abc' ESCAPE '#'")); - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'ab#' ESCAPE '#'")); + assertOptimizedEquals("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", "apply(90, \"$internal$bind\"(9, (x, y) -> x + y))"); + evaluate("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", true); + evaluate("apply(900, \"$internal$bind\"(90, 9, (x, y, z) -> x + y + z))", true); } @Test @@ -1435,42 +107,6 @@ public void testLambda() assertEquals(evaluate("reduce(ARRAY[1, 5], 0, (x, y) -> x + y, x -> x)", true), 6L); } - @Test - public void testBind() - { - assertOptimizedEquals("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", "apply(90, \"$internal$bind\"(9, (x, y) -> x + y))"); - evaluate("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", true); - evaluate("apply(900, \"$internal$bind\"(90, 9, (x, y, z) -> x + y + z))", true); - } - - @Test - public void 
testFailedExpressionOptimization() - { - assertOptimizedMatches("CASE unbound_long WHEN 1 THEN 1 WHEN 0 / 0 THEN 2 END", - "CASE unbound_long WHEN BIGINT '1' THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 END"); - - assertOptimizedMatches("CASE unbound_boolean WHEN true THEN 1 ELSE 0 / 0 END", - "CASE unbound_boolean WHEN true THEN 1 ELSE cast(fail(8, 'ignored failure message') as integer) END"); - - assertOptimizedMatches("CASE bound_long WHEN unbound_long THEN 1 WHEN 0 / 0 THEN 2 ELSE 1 END", - "CASE BIGINT '1234' WHEN unbound_long THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 ELSE 1 END"); - - assertOptimizedMatches("case when unbound_boolean then 1 when 0 / 0 = 0 then 2 end", - "case when unbound_boolean then 1 when cast(fail(8, 'ignored failure message') as boolean) then 2 end"); - - assertOptimizedMatches("case when unbound_boolean then 1 else 0 / 0 end", - "case when unbound_boolean then 1 else cast(fail(8, 'ignored failure message') as integer) end"); - - assertOptimizedMatches("case when unbound_boolean then 0 / 0 else 1 end", - "case when unbound_boolean then cast(fail(8, 'ignored failure message') as integer) else 1 end"); - } - - @Test(expectedExceptions = PrestoException.class) - public void testOptimizeDivideByZero() - { - optimize("0 / 0"); - } - @Test public void testMassiveArray() { @@ -1485,90 +121,6 @@ public void testMassiveArray() optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "ARRAY['" + i + "']").iterator()))); } - @Test - public void testArrayConstructor() - { - optimize("ARRAY []"); - assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", - "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); - assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", - "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); - 
assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", - "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); - } - - @Test - public void testRowConstructor() - { - optimize("ROW(NULL)"); - optimize("ROW(1)"); - optimize("ROW(unbound_long + 0)"); - optimize("ROW(unbound_long + unbound_long2, unbound_string, unbound_double)"); - optimize("ROW(unbound_boolean, FALSE, ARRAY[unbound_long, unbound_long2], unbound_null_string, unbound_interval)"); - optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(unbound_string, 0.0E0)]"); - optimize("ARRAY [ROW('string', unbound_double), ROW('string', bound_double)]"); - optimize("ROW(ROW(NULL), ROW(ROW(ROW(ROW('rowception')))))"); - optimize("ROW(unbound_string, bound_string)"); - - optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(CAST(bound_string AS VARCHAR), 0.0E0)]"); - optimize("ARRAY [ROW(CAST(bound_string AS VARCHAR), 0.0E0), ROW(unbound_string, unbound_double)]"); - - optimize("ARRAY [ROW(unbound_string, unbound_double), CAST(NULL AS ROW(VARCHAR, DOUBLE))]"); - optimize("ARRAY [CAST(NULL AS ROW(VARCHAR, DOUBLE)), ROW(unbound_string, unbound_double)]"); - } - - @Test - public void testDereference() - { - optimize("ARRAY []"); - assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", - "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); - assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", - "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); - assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", - "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); - } - - @Test - public void testRowDereference() - { - optimize("CAST(null AS ROW(a VARCHAR, b BIGINT)).a"); - } - - @Test - public void testRowSubscript() - { - assertOptimizedEquals("ROW (1, 'a', true)[3]", "true"); - assertOptimizedEquals("ROW (1, 'a', ROW (2, 'b', 
ROW (3, 'c')))[3][3][2]", "'c'"); - } - - @Test(expectedExceptions = PrestoException.class) - public void testArraySubscriptConstantNegativeIndex() - { - optimize("ARRAY [1, 2, 3][-1]"); - } - - @Test(expectedExceptions = PrestoException.class) - public void testArraySubscriptConstantZeroIndex() - { - optimize("ARRAY [1, 2, 3][0]"); - } - - @Test(expectedExceptions = PrestoException.class) - public void testMapSubscriptMissingKey() - { - optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[-1]"); - } - - @Test - public void testMapSubscriptConstantIndexes() - { - optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[1]"); - optimize("MAP(ARRAY [BIGINT '1', 2], ARRAY [3, 4])[1]"); - optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[2]"); - optimize("MAP(ARRAY [ARRAY[1,1]], ARRAY['a'])[ARRAY[1,1]]"); - } - @Test(timeOut = 60000) public void testLikeInvalidUtf8() { @@ -1577,21 +129,13 @@ public void testLikeInvalidUtf8() } @Test - public void testLiterals() + public void testLikeSerializable() { - optimize("date '2013-04-03' + unbound_interval"); - optimize("time '03:04:05.321' + unbound_interval"); - optimize("time '03:04:05.321 UTC' + unbound_interval"); - optimize("timestamp '2013-04-03 03:04:05.321' + unbound_interval"); - optimize("timestamp '2013-04-03 03:04:05.321 UTC' + unbound_interval"); - - optimize("interval '3' day * unbound_long"); - optimize("interval '3' year * unbound_long"); - - assertEquals(optimize("X'1234'"), Slices.wrappedBuffer((byte) 0x12, (byte) 0x34)); + assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%' ESCAPE 'z'", "true"); + assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%'", "false"); } - private static void assertLike(byte[] value, String pattern, boolean expected) + private void assertLike(byte[] value, String pattern, boolean expected) { Expression predicate = new LikePredicate( rawStringLiteral(Slices.wrappedBuffer(value)), @@ -1612,12 +156,23 @@ public Slice getSlice() }; } - private static void assertOptimizedEquals(@Language("SQL") String actual, 
@Language("SQL") String expected) + @Override + public void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOptimizer.Level optimizationLevel) { - assertEquals(optimize(actual), optimize(expected)); + assertRoundTrip(expression); + Expression translatedExpression = expression(expression); + RowExpression rowExpression = toRowExpression(translatedExpression); + + Object expressionResult = optimize(translatedExpression); + if (expressionResult instanceof Expression) { + expressionResult = toRowExpression((Expression) expressionResult); + } + Object rowExpressionResult = optimize(rowExpression, optimizationLevel); + assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); + assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); } - private static void assertRowExpressionEquals(Level level, @Language("SQL") String actual, @Language("SQL") String expected) + private void assertRowExpressionEquals(ExpressionOptimizer.Level level, @Language("SQL") String actual, @Language("SQL") String expected) { Object actualResult = optimize(toRowExpression(expression(actual)), level); Object expectedResult = optimize(toRowExpression(expression(expected)), level); @@ -1628,7 +183,14 @@ private static void assertRowExpressionEquals(Level level, @Language("SQL") Stri assertEquals(actualResult, expectedResult); } - private static void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + @Override + public void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + assertEquals(optimize(actual), optimize(expected)); + } + + @Override + public void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) { // replaces FunctionCalls to FailureFunction by fail() Object actualOptimized = optimize(actual); @@ -1640,7 +202,8 @@ private static void assertOptimizedMatches(@Language("SQL") String actual, @Lang 
rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected))); } - private static Object optimize(@Language("SQL") String expression) + @Override + public Object optimize(@Language("SQL") String expression) { assertRoundTrip(expression); @@ -1653,186 +216,40 @@ private static Object optimize(@Language("SQL") String expression) return expressionResult; } - private static Expression expression(String expression) - { - return FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - } - - private static RowExpression toRowExpression(Expression expression) + @Override + public void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected) { - return TRANSLATOR.translate(expression, SYMBOL_TYPES); + assertEquals(evaluate(actual, true), evaluate(expected, true)); } - private static Object optimize(Expression expression) + private Object optimize(RowExpression expression, ExpressionOptimizer.Level level) { - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); - ExpressionInterpreter interpreter = expressionOptimizer(expression, METADATA, TEST_SESSION, expressionTypes); - return interpreter.optimize(variable -> { + return new RowExpressionInterpreter(expression, METADATA, TEST_SESSION.toConnectorSession(), level).optimize(variable -> { Symbol symbol = new Symbol(variable.getName()); Object value = symbolConstant(symbol); if (value == null) { - return symbol.toSymbolReference(); + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); } return value; }); } - private static Object optimize(RowExpression expression, Level level) + private Object optimize(Expression expression) { - return new RowExpressionInterpreter(expression, METADATA, TEST_SESSION.toConnectorSession(), level).optimize(variable -> { + Map, Type> expressionTypes = 
getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); + ExpressionInterpreter interpreter = expressionOptimizer(expression, METADATA, TEST_SESSION, expressionTypes); + return interpreter.optimize(variable -> { Symbol symbol = new Symbol(variable.getName()); Object value = symbolConstant(symbol); if (value == null) { - return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + return symbol.toSymbolReference(); } return value; }); } - private static void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) - { - assertRoundTrip(expression); - Expression translatedExpression = expression(expression); - RowExpression rowExpression = toRowExpression(translatedExpression); - - Object expressionResult = optimize(translatedExpression); - if (expressionResult instanceof Expression) { - expressionResult = toRowExpression((Expression) expressionResult); - } - Object rowExpressionResult = optimize(rowExpression, optimizationLevel); - assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); - assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); - } - - private static Object symbolConstant(Symbol symbol) - { - switch (symbol.getName().toLowerCase(ENGLISH)) { - case "bound_integer": - return 1234L; - case "bound_long": - return 1234L; - case "bound_string": - return utf8Slice("hello"); - case "bound_double": - return 12.34; - case "bound_date": - return new LocalDate(2001, 8, 22).toDateMidnight(DateTimeZone.UTC).getMillis(); - case "bound_time": - return new LocalTime(3, 4, 5, 321).toDateTime(new DateTime(0, DateTimeZone.UTC)).getMillis(); - case "bound_timestamp": - return new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis(); - case "bound_pattern": - return utf8Slice("%el%"); - case "bound_timestamp_with_timezone": - return new SqlTimestampWithTimeZone(new 
DateTime(1970, 1, 1, 1, 0, 0, 999, DateTimeZone.UTC).getMillis(), getTimeZoneKey("Z")); - case "bound_varbinary": - return Slices.wrappedBuffer((byte) 0xab); - case "bound_decimal_short": - return 12345L; - case "bound_decimal_long": - return Decimals.encodeUnscaledValue(new BigInteger("12345678901234567890123")); - } - return null; - } - - private static void assertExpressionAndRowExpressionEquals(Object expressionResult, Object rowExpressionResult) - { - if (rowExpressionResult instanceof RowExpression) { - // Cannot be completely evaluated into a constant; compare expressions - assertTrue(expressionResult instanceof Expression); - - // It is tricky to check the equivalence of an expression and a row expression. - // We rely on the optimized translator to fill the gap. - RowExpression translated = TRANSLATOR.translateAndOptimize((Expression) expressionResult, SYMBOL_TYPES); - assertRowExpressionEvaluationEquals(translated, rowExpressionResult); - } - else { - // We have constants; directly compare - assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); - } - } - - /** - * Assert the evaluation result of two row expressions equivalent - * no matter they are constants or remaining row expressions. 
- */ - private static void assertRowExpressionEvaluationEquals(Object left, Object right) - { - if (right instanceof RowExpression) { - assertTrue(left instanceof RowExpression); - // assertEquals(((RowExpression) left).getType(), ((RowExpression) right).getType()); - if (left instanceof ConstantExpression) { - if (isRemovableCast(right)) { - assertRowExpressionEvaluationEquals(left, ((CallExpression) right).getArguments().get(0)); - return; - } - assertTrue(right instanceof ConstantExpression); - assertRowExpressionEvaluationEquals(((ConstantExpression) left).getValue(), ((ConstantExpression) left).getValue()); - } - else if (left instanceof InputReferenceExpression || left instanceof VariableReferenceExpression) { - assertEquals(left, right); - } - else if (left instanceof CallExpression) { - assertTrue(right instanceof CallExpression); - assertEquals(((CallExpression) left).getFunctionHandle(), ((CallExpression) right).getFunctionHandle()); - assertEquals(((CallExpression) left).getArguments().size(), ((CallExpression) right).getArguments().size()); - for (int i = 0; i < ((CallExpression) left).getArguments().size(); i++) { - assertRowExpressionEvaluationEquals(((CallExpression) left).getArguments().get(i), ((CallExpression) right).getArguments().get(i)); - } - } - else if (left instanceof SpecialFormExpression) { - assertTrue(right instanceof SpecialFormExpression); - assertEquals(((SpecialFormExpression) left).getForm(), ((SpecialFormExpression) right).getForm()); - assertEquals(((SpecialFormExpression) left).getArguments().size(), ((SpecialFormExpression) right).getArguments().size()); - for (int i = 0; i < ((SpecialFormExpression) left).getArguments().size(); i++) { - assertRowExpressionEvaluationEquals(((SpecialFormExpression) left).getArguments().get(i), ((SpecialFormExpression) right).getArguments().get(i)); - } - } - else { - assertTrue(left instanceof LambdaDefinitionExpression); - assertTrue(right instanceof LambdaDefinitionExpression); - 
assertEquals(((LambdaDefinitionExpression) left).getArguments(), ((LambdaDefinitionExpression) right).getArguments()); - assertEquals(((LambdaDefinitionExpression) left).getArgumentTypes(), ((LambdaDefinitionExpression) right).getArgumentTypes()); - assertRowExpressionEvaluationEquals(((LambdaDefinitionExpression) left).getBody(), ((LambdaDefinitionExpression) right).getBody()); - } - } - else { - // We have constants; directly compare - if (left instanceof Block) { - assertTrue(right instanceof Block); - assertEquals(blockToSlice((Block) left), blockToSlice((Block) right)); - } - else { - assertEquals(left, right); - } - } - } - - private static boolean isRemovableCast(Object value) - { - if (value instanceof CallExpression && - new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()).isCastFunction(((CallExpression) value).getFunctionHandle())) { - Type targetType = ((CallExpression) value).getType(); - Type sourceType = ((CallExpression) value).getArguments().get(0).getType(); - return METADATA.getFunctionAndTypeManager().canCoerce(sourceType, targetType); - } - return false; - } - - private static Slice blockToSlice(Block block) - { - // This function is strictly for testing use only - SliceOutput sliceOutput = new DynamicSliceOutput(1000); - BlockSerdeUtil.writeBlock(blockEncodingSerde, sliceOutput, block); - return sliceOutput.slice(); - } - - private static void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected) - { - assertEquals(evaluate(actual, true), evaluate(expected, true)); - } - - private static Object evaluate(String expression, boolean deterministic) + @Override + public Object evaluate(@Language("SQL") String expression, boolean deterministic) { assertRoundTrip(expression); @@ -1841,14 +258,7 @@ private static Object evaluate(String expression, boolean deterministic) return evaluate(parsedExpression, deterministic); } - private static void assertRoundTrip(String expression) - { 
- ParsingOptions parsingOptions = createParsingOptions(TEST_SESSION); - assertEquals(SQL_PARSER.createExpression(expression, parsingOptions), - SQL_PARSER.createExpression(formatExpression(SQL_PARSER.createExpression(expression, parsingOptions), Optional.empty()), parsingOptions)); - } - - private static Object evaluate(Expression expression, boolean deterministic) + private Object evaluate(Expression expression, boolean deterministic) { Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); Object expressionResult = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes).evaluate(); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java b/presto-main-base/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java index 485dbfbacdfc8..70216b1a2c294 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/TestRowExpressionSerde.java @@ -17,6 +17,7 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.JsonModule; import com.facebook.airlift.stats.cardinality.HyperLogLog; +import com.facebook.drift.codec.guice.ThriftCodecModule; import com.facebook.presto.block.BlockJsonSerde; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockEncoding; @@ -29,6 +30,7 @@ import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.Metadata; @@ -244,8 +246,10 @@ private JsonCodec getJsonCodec() { Module module = binder -> { binder.install(new JsonModule()); + binder.install(new 
ThriftCodecModule()); binder.install(new HandleJsonModule()); configBinder(binder).bindConfig(FeaturesConfig.class); + binder.bind(ConnectorManager.class).toProvider(() -> null); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/TestSqlFormatter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/TestSqlFormatter.java new file mode 100644 index 0000000000000..26026ad290d00 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/TestSqlFormatter.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql; + +import com.facebook.presto.sql.parser.ParsingOptions; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.tree.Statement; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.sql.SqlFormatterUtil.getFormattedSql; +import static java.lang.String.format; +import static org.testng.Assert.assertEquals; + +public class TestSqlFormatter +{ + @Test + public void testSimpleExpression() + { + assertQuery("SELECT id\nFROM\n public.orders\n"); + assertQuery("SELECT id\nFROM\n \"public\".\"order\"\n"); + assertQuery("SELECT id\nFROM\n \"public\".\"order\"\"2\"\n"); + } + + @Test + public void testQuotedColumnNames() + { + assertQuery("SELECT \"Id\"\nFROM\n public.orders\n"); + assertQuery("SELECT \"order\".\"Name\"\nFROM\n \"public\".\"orders\"\n"); + assertQuery("SELECT \"a\".\"b\".\"C\"\nFROM\n \"schema\".\"table\"\n"); + assertQuery("ALTER TABLE sales.orders RENAME COLUMN \"OrderId\" TO \"OrderID_New\""); + assertQuery("ALTER TABLE sales.orders ADD COLUMN \"Customer\" VARCHAR"); + assertQuery("ALTER TABLE sales.orders DROP COLUMN \"Customer\""); + } + + private void assertQuery(String query) + { + SqlParser parser = new SqlParser(); + Statement statement = parser.createStatement(query, new ParsingOptions()); + String formattedQuery = getFormattedSql(statement, parser, Optional.empty()); + assertEquals(formattedQuery, query); + assertEquals(formattedQuery, query, format("Formatted SQL did not match original for query: %s", query)); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java index 3702dc249a00a..9ea7a163cec3c 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/AbstractAnalyzerTest.java @@ -22,6 +22,24 @@ import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.connector.informationSchema.InformationSchemaConnector; import com.facebook.presto.connector.system.SystemConnector; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DescriptorArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.MonomorphicStaticReturnTypeFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.OnlyPassThroughFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.PassThroughFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.PolymorphicStaticReturnTypeFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.RequiredColumnsFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.SimpleTableFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TableArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TableArgumentRowSemanticsFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoScalarArgumentsFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoTableArgumentsFunction; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.CostCalculatorUsingExchanges; +import com.facebook.presto.cost.CostCalculatorWithEstimatedExchanges; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.TaskCountEstimator; +import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.execution.warnings.WarningCollectorConfig; import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor; @@ -32,12 +50,16 @@ import 
com.facebook.presto.metadata.InMemoryNodeManager; import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.MaterializedViewDefinition; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.AccessControlReferences; import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.analyzer.ViewDefinitionReferences; import com.facebook.presto.spi.connector.Connector; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.facebook.presto.spi.connector.ConnectorSplitManager; @@ -46,13 +68,28 @@ import com.facebook.presto.spi.function.Parameter; import com.facebook.presto.spi.function.RoutineCharacteristics; import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.Procedure; +import com.facebook.presto.spi.procedure.Procedure.Argument; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.spi.security.AllowAllAccessControl; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.spi.transaction.IsolationLevel; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.PartitioningProviderManager; +import com.facebook.presto.sql.planner.PlanFragmenter; 
+import com.facebook.presto.sql.planner.PlanOptimizers; +import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; +import com.facebook.presto.sql.planner.sanity.PlanChecker; import com.facebook.presto.sql.tree.NodeLocation; import com.facebook.presto.sql.tree.Statement; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.TestProcedureRegistry; import com.facebook.presto.testing.TestingAccessControlManager; import com.facebook.presto.testing.TestingMetadata; import com.facebook.presto.testing.TestingWarningCollector; @@ -62,11 +99,16 @@ import com.google.common.collect.ImmutableMap; import org.intellij.lang.annotations.Language; import org.testng.annotations.BeforeClass; +import org.weakref.jmx.MBeanExporter; +import org.weakref.jmx.testing.TestingMBeanServer; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.function.Consumer; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.SystemSessionProperties.CHECK_ACCESS_CONTROL_ON_UTILIZED_COLUMNS_ONLY; import static com.facebook.presto.SystemSessionProperties.CHECK_ACCESS_CONTROL_WITH_SUBFIELDS; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -81,11 +123,14 @@ import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.SQL; import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.SCHEMA; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.TABLE_NAME; import static com.facebook.presto.spi.session.PropertyMetadata.integerProperty; import static com.facebook.presto.spi.session.PropertyMetadata.stringProperty; import static 
com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager; import static com.facebook.presto.transaction.TransactionBuilder.transaction; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; import static java.lang.String.format; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; @@ -134,8 +179,7 @@ public void setup() CatalogManager catalogManager = new CatalogManager(); transactionManager = createTestTransactionManager(catalogManager); accessControl = new TestingAccessControlManager(transactionManager); - - metadata = createTestMetadataManager(transactionManager); + metadata = createTestMetadataManager(transactionManager, new FeaturesConfig(), new FunctionsConfig(), new TestProcedureRegistry()); metadata.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); @@ -150,6 +194,36 @@ public void setup() metadata.getFunctionAndTypeManager().createFunction(SQL_FUNCTION_SQUARE, true); + metadata.getFunctionAndTypeManager().getTableFunctionRegistry().addTableFunctions(TPCH_CONNECTOR_ID, + ImmutableList.of( + new SimpleTableFunction(), + new TwoScalarArgumentsFunction(), + new TableArgumentFunction(), + new DescriptorArgumentFunction(), + new TableArgumentRowSemanticsFunction(), + new TwoTableArgumentsFunction(), + new OnlyPassThroughFunction(), + new MonomorphicStaticReturnTypeFunction(), + new PolymorphicStaticReturnTypeFunction(), + new PassThroughFunction(), + new RequiredColumnsFunction())); + + List arguments = new ArrayList<>(); + arguments.add(new Argument(SCHEMA, StandardTypes.VARCHAR)); + arguments.add(new Argument(TABLE_NAME, StandardTypes.VARCHAR)); + + List distributedArguments = new ArrayList<>(); + distributedArguments.add(new DistributedProcedure.Argument(SCHEMA, StandardTypes.VARCHAR)); + distributedArguments.add(new 
DistributedProcedure.Argument(TABLE_NAME, StandardTypes.VARCHAR)); + List> procedures = new ArrayList<>(); + procedures.add(new Procedure("system", "procedure", arguments)); + procedures.add(new TableDataRewriteDistributedProcedure("system", "distributed_procedure", + distributedArguments, + (session, transactionContext, procedureHandle, fragments, sortOrderIndex) -> null, + (session, transactionContext, procedureHandle, fragments) -> {}, + ignored -> new TestProcedureRegistry.TestProcedureContext())); + metadata.getProcedureRegistry().addProcedures(SECOND_CONNECTOR_ID, procedures); + Catalog tpchTestCatalog = createTestingCatalog(TPCH_CATALOG, TPCH_CONNECTOR_ID); catalogManager.registerCatalog(tpchTestCatalog); metadata.getAnalyzePropertyManager().addProperties(TPCH_CONNECTOR_ID, tpchTestCatalog.getConnector(TPCH_CONNECTOR_ID).getAnalyzeProperties()); @@ -288,6 +362,33 @@ public void setup() ColumnMetadata.builder().setName("z").setType(BIGINT).build())), false)); + // materialized view referencing table in same schema + List baseTables = new ArrayList<>(Collections.singletonList(table2)); + MaterializedViewDefinition.TableColumn baseTableColumns = new MaterializedViewDefinition.TableColumn(table2, "a", true); + + SchemaTableName materializedTable = new SchemaTableName("s1", "mv1"); + MaterializedViewDefinition.TableColumn materializedViewTableColumn = new MaterializedViewDefinition.TableColumn(materializedTable, "a", true); + + List columnMappings = Collections.singletonList( + new MaterializedViewDefinition.ColumnMapping(materializedViewTableColumn, Collections.singletonList(baseTableColumns))); + + MaterializedViewDefinition materializedViewData1 = new MaterializedViewDefinition( + "select a from t2", + "s1", + "mv1", + baseTables, + Optional.of("user"), + Optional.empty(), + columnMappings, + new ArrayList<>(), + Optional.of(new ArrayList<>(Collections.singletonList("a")))); + + ConnectorTableMetadata materializedViewMetadata1 = new ConnectorTableMetadata( + 
materializedTable, ImmutableList.of(ColumnMetadata.builder().setName("a").setType(BIGINT).build())); + + inSetupTransaction(session -> + metadata.createMaterializedView(session, TPCH_CATALOG, materializedViewMetadata1, materializedViewData1, false)); + // valid view referencing table in same schema String viewData1 = JsonCodec.jsonCodec(ViewDefinition.class).toJson( new ViewDefinition( @@ -488,9 +589,11 @@ private void analyze(Session clientSession, WarningCollector warningCollector, @ .readUncommitted() .readOnly() .execute(clientSession, session -> { - Analyzer analyzer = AbstractAnalyzerTest.createAnalyzer(session, metadata, warningCollector, query); + Analyzer analyzer = AbstractAnalyzerTest.createAnalyzer(session, metadata, warningCollector, Optional.empty(), query); Statement statement = SQL_PARSER.createStatement(query); - analyzer.analyze(statement); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + AccessControlReferences accessControlReferences = analysis.getAccessControlReferences(); + checkAccessPermissions(accessControlReferences, analysis.getViewDefinitionReferences(), query, session.getPreparedStatements(), session.getIdentity(), accessControl, session.getAccessControlContext()); }); } @@ -506,7 +609,7 @@ protected void assertFails(SemanticErrorCode error, int line, int column, @Langu protected void assertFails(SemanticErrorCode error, String message, @Language("SQL") String query) { - assertFails(CLIENT_SESSION, error, message, query); + assertFails(CLIENT_SESSION, error, message, query, false); } protected void assertFails(Session session, SemanticErrorCode error, @Language("SQL") String query) @@ -514,6 +617,11 @@ protected void assertFails(Session session, SemanticErrorCode error, @Language(" assertFails(session, error, Optional.empty(), query); } + protected void assertFails(Session session, SemanticErrorCode error, String message, @Language("SQL") String query) + { + assertFails(session, error, message, query, false); + } + 
private void assertFails(Session session, SemanticErrorCode error, Optional location, @Language("SQL") String query) { try { @@ -542,7 +650,7 @@ private void assertFails(Session session, SemanticErrorCode error, Optional queryExplainer, String query) { return new Analyzer( session, metadata, SQL_PARSER, new AllowAllAccessControl(), - Optional.empty(), + queryExplainer, emptyList(), emptyMap(), warningCollector, - query); + query, + new ViewDefinitionReferences()); + } + + protected static QueryExplainer createTestingQueryExplainer(Session session, AccessControl accessControl, Metadata metadata) + { + try (LocalQueryRunner localQueryRunner = new LocalQueryRunner(session)) { + SqlParser sqlParser = new SqlParser(); + FeaturesConfig featuresConfig = new FeaturesConfig(); + TaskCountEstimator taskCountEstimator = new TaskCountEstimator(localQueryRunner::getNodeCount); + CostCalculator costCalculator = new CostCalculatorUsingExchanges(taskCountEstimator); + List optimizers = new PlanOptimizers( + metadata, + sqlParser, + localQueryRunner.getNodeCount() == 1, + new MBeanExporter(new TestingMBeanServer()), + localQueryRunner.getSplitManager(), + localQueryRunner.getPlanOptimizerManager(), + localQueryRunner.getPageSourceManager(), + localQueryRunner.getStatsCalculator(), + costCalculator, + new CostCalculatorWithEstimatedExchanges(costCalculator, taskCountEstimator), + new CostComparator(featuresConfig), + taskCountEstimator, + new PartitioningProviderManager(), + featuresConfig, + new ExpressionOptimizerManager( + new PluginNodeManager(new InMemoryNodeManager()), + localQueryRunner.getMetadata().getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))), + new TaskManagerConfig(), + localQueryRunner.getAccessControl()) + .getPlanningTimeOptimizers(); + return new QueryExplainer( + optimizers, + new PlanFragmenter(metadata, localQueryRunner.getNodePartitioningManager(), new QueryManagerConfig(), featuresConfig, 
localQueryRunner.getPlanCheckerProviderManager()), + metadata, + accessControl, + sqlParser, + localQueryRunner.getStatsCalculator(), + costCalculator, + ImmutableMap.of(), + new PlanChecker(featuresConfig, false, localQueryRunner.getPlanCheckerProviderManager())); + } } private Catalog createTestingCatalog(String catalogName, ConnectorId connectorId) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java index 8ee94a582b353..c3236d3e6627a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestAnalyzer.java @@ -26,12 +26,14 @@ import com.facebook.presto.spi.StandardWarningCode; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spiller.NodeSpillConfig; +import com.facebook.presto.sql.parser.ParsingException; import com.facebook.presto.sql.planner.CompilerConfig; import com.facebook.presto.tracing.TracingConfig; import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; import java.util.List; +import java.util.regex.Pattern; import static com.facebook.presto.metadata.SessionPropertyManager.createTestingSessionPropertyManager; import static com.facebook.presto.spi.StandardWarningCode.PERFORMANCE_WARNING; @@ -44,6 +46,8 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_COLUMN_NAME; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_PROPERTY; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.DUPLICATE_RELATION; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.EXPRESSION_NOT_CONSTANT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_FUNCTION_NAME; import static 
com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_LITERAL; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.INVALID_OFFSET_ROW_COUNT; @@ -72,11 +76,21 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.ORDER_BY_MUST_BE_IN_AGGREGATE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.ORDER_BY_MUST_BE_IN_SELECT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.PROCEDURE_NOT_FOUND; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_AGGREGATION; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.REFERENCE_TO_OUTPUT_ATTRIBUTE_WITHIN_ORDER_BY_GROUPING; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SAMPLE_PERCENTAGE_OUT_OF_RANGE; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.SCHEMA_NOT_SPECIFIED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.STANDALONE_LAMBDA; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_COLUMN_NOT_FOUND; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_DUPLICATE_RANGE_VARIABLE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_IMPLEMENTATION_ERROR; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_ARGUMENTS; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_COLUMN_REFERENCE; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_COPARTITIONING; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TABLE_FUNCTION_MISSING_ARGUMENT; import 
static com.facebook.presto.sql.analyzer.SemanticErrorCode.TOO_MANY_GROUPING_SETS; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.TYPE_MISMATCH; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.VIEW_ANALYSIS_ERROR; @@ -87,6 +101,7 @@ import static com.facebook.presto.sql.analyzer.SemanticErrorCode.WINDOW_REQUIRES_OVER; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; @@ -152,6 +167,32 @@ void testNoORWarning() assertNoWarning(analyzeWithWarnings("SELECT * FROM t1 JOIN t2 ON t1.a = t2.a \n" + "AND (t1.b = t2.b OR t1.b > t2.b)")); } + @Test + public void testMapFilterWarnings() + { + assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> v > 1) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(user_features)")); + + assertHasWarning( + analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"), + PERFORMANCE_WARNING, + "Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. Consider using map_subset"); + + assertHasWarning( + analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k IN (1, 3)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)"), + PERFORMANCE_WARNING, + "Function 'presto.default.map_filter' uses a lambda on large maps which is expensive. 
Consider using map_subset"); + + assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> v IN (20, 30)) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)")); + + assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k + v > 25) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)")); + + assertNoWarning(analyzeWithWarnings("SELECT map_filter(user_features, (k, v) -> k > 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(user_features)")); + + assertNoWarning(analyzeWithWarnings("SELECT transform_values(user_features, (k, v) -> v * 2) FROM (VALUES (map(ARRAY[1,2], ARRAY[2,3]))) AS t(user_features)")); + + assertNoWarning(analyzeWithWarnings("SELECT map_filter(x, (k, v) -> k = 2) FROM (VALUES (map(ARRAY[1,2,3], ARRAY[10,20,30]))) AS t(x)")); + } + @Test public void testIgnoreNullWarning() { @@ -406,6 +447,20 @@ public void testWindowsNotAllowed() assertFails(NESTED_WINDOW, "SELECT 1 FROM (VALUES 1) HAVING count(*) OVER () > 1"); } + @Test + public void testCallProcedure() + { + Session session = testSessionBuilder() + .setCatalog("c2") + .setSchema("t4") + .build(); + assertFails(session, PROCEDURE_NOT_FOUND, "call system.not_exist_procedure('a', 'b')"); + assertFails(session, PROCEDURE_NOT_FOUND, "call system.procedure('a', 'b')"); + assertFails(session, MISSING_SCHEMA, "call system.distributed_procedure('s1', 't4')"); + assertFails(session, MISSING_TABLE, "call system.distributed_procedure('s2', 't9')"); + analyze(session, "call system.distributed_procedure('s2', 't4')"); + } + @Test public void testGrouping() { @@ -1913,4 +1968,403 @@ public void testInvalidTemporaryFunctionName() assertFails(INVALID_FUNCTION_NAME, "CREATE TEMPORARY FUNCTION sum() RETURNS INT RETURN 1"); assertFails(INVALID_FUNCTION_NAME, "CREATE TEMPORARY FUNCTION dev.test.foo() RETURNS INT RETURN 1"); } + + @Test + public void testTableFunctionNotFound() + { + assertFails(FUNCTION_NOT_FOUND, + "line 
1:21: Table function non_existent_table_function not registered", + "SELECT * FROM TABLE(non_existent_table_function())"); + } + + @Test + public void testTableFunctionArguments() + { + assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, "line 1:58: Too many arguments. Expected at most 2 arguments, got 3 arguments", "SELECT * FROM TABLE(system.two_scalar_arguments_function(1, 2, 3))"); + + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo'))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', 1))"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', number => 1))"); + + assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, + "line 1:58: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', number => 1))"); + + assertFails(TABLE_FUNCTION_INVALID_ARGUMENTS, + "line 1:58: All arguments must be passed by name or all must be passed positionally", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', 1))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:73: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', text => 'bar'))"); + + // argument names are resolved in the canonical form + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:73: Duplicate argument name: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', TeXt => 'bar'))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:73: Unexpected argument name: BAR", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'foo', bar => 'bar'))"); + + assertFails(TABLE_FUNCTION_MISSING_ARGUMENT, + "line 1:58: Missing argument: TEXT", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(number => 1))"); + } + + 
@Test + public void testScalarArgument() + { + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('foo', 1))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:71: Invalid argument NUMBER. Expected expression, got descriptor", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => DESCRIPTOR(x integer, y boolean)))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:71: 'descriptor' function is not allowed as a table function argument", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => DESCRIPTOR(1 + 2)))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:71: Invalid argument NUMBER. Expected expression, got table", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => TABLE(t1)))"); + + assertFails(EXPRESSION_NOT_CONSTANT, + "line 1:81: Constant expression cannot contain a subquery", + "SELECT * FROM TABLE(system.two_scalar_arguments_function(text => 'a', number => (SELECT 1)))"); + } + + @Test + public void testTableArgument() + { + // cannot pass a table function as the argument + assertFails(NOT_SUPPORTED, + "line 1:52: Invalid table argument INPUT. 
Table functions are not allowed as table function arguments", + "SELECT * FROM TABLE(system.table_argument_function(input => my_schema.my_table_function(1)))"); + + assertThatThrownBy(() -> analyze("SELECT * FROM TABLE(system.table_argument_function(input => my_schema.my_table_function(arg => 1)))")) + .isInstanceOf(ParsingException.class) + .hasMessageContaining("line 1:93: mismatched input '=>'."); + + // cannot pass a table function as the argument, also preceding nested table function with TABLE is incorrect + assertThatThrownBy(() -> analyze("SELECT * FROM TABLE(system.table_argument_function(input => TABLE(my_schema.my_table_function(1))))")) + .isInstanceOf(ParsingException.class) + .hasMessageContaining("line 1:94: mismatched input '('."); + + // a table passed as the argument must be preceded with TABLE + analyze("SELECT * FROM TABLE(system.table_argument_function(input => TABLE(t1)))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:52: Invalid argument INPUT. Expected table, got expression", + "SELECT * FROM TABLE(system.table_argument_function(input => t1))"); + + // a query passed as the argument must be preceded with TABLE + analyze("SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT * FROM t1)))"); + + assertThatThrownBy(() -> analyze("SELECT * FROM TABLE(system.table_argument_function(input => SELECT * FROM t1))")) + .isInstanceOf(ParsingException.class) + .hasMessageContaining("line 1:61: mismatched input 'SELECT'."); + + // query passed as the argument is correlated + analyze("SELECT * FROM t1 CROSS JOIN LATERAL (SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 WHERE a > 0))))"); + + // wrong argument type + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:52: Invalid argument INPUT. 
Expected table, got expression", + "SELECT * FROM TABLE(system.table_argument_function(input => 'foo'))"); + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:52: Invalid argument INPUT. Expected table, got descriptor", + "SELECT * FROM TABLE(system.table_argument_function(input => DESCRIPTOR(x int, y int)))"); + } + + @Test + public void testTableArgumentProperties() + { + analyze("SELECT * FROM TABLE(system.table_argument_function(input => TABLE(t1) PARTITION BY a KEEP WHEN EMPTY ORDER BY b))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:66: Invalid argument INPUT. Partitioning specified for table argument with row semantics", + "SELECT * FROM TABLE(system.table_argument_row_semantics_function(input => TABLE(t1) PARTITION BY a))"); + + assertFails(TABLE_FUNCTION_COLUMN_NOT_FOUND, + "line 1:92: Column b is not present in the input relation", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 a) PARTITION BY b))"); + + assertFails(TABLE_FUNCTION_INVALID_COLUMN_REFERENCE, + "line 1:88: Expected column reference. Actual: 1", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 a) ORDER BY 1))"); + + assertFails(TYPE_MISMATCH, + "line 1:104: HyperLogLog is not comparable, and therefore cannot be used in PARTITION BY", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT approx_set(1) a) PARTITION BY a))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, "line 1:66: Invalid argument INPUT. 
Ordering specified for table argument with row semantics", + "SELECT * FROM TABLE(system.table_argument_row_semantics_function(input => TABLE(t1) ORDER BY a))"); + + assertFails(TABLE_FUNCTION_COLUMN_NOT_FOUND, + "line 1:88: Column b is not present in the input relation", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 a) ORDER BY b))"); + + assertFails(TABLE_FUNCTION_INVALID_COLUMN_REFERENCE, + "line 1:88: Expected column reference. Actual: 1", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT 1 a) ORDER BY 1))"); + + assertFails(TYPE_MISMATCH, + "line 1:100: HyperLogLog is not orderable, and therefore cannot be used in ORDER BY", + "SELECT * FROM TABLE(system.table_argument_function(input => TABLE(SELECT approx_set(1) a) ORDER BY a))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:85: Invalid argument INPUT. Empty behavior specified for table argument with row semantics", + "SELECT * FROM TABLE(system.table_argument_row_semantics_function(input => TABLE(t1) PRUNE WHEN EMPTY))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:85: Invalid argument INPUT. Empty behavior specified for table argument with row semantics", + "SELECT * FROM TABLE(system.table_argument_row_semantics_function(input => TABLE(t1) KEEP WHEN EMPTY))"); + } + + @Test + public void testDescriptorArgument() + { + analyze("SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(x integer, y boolean)))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + Pattern.quote("line 1:57: Invalid descriptor argument SCHEMA. Descriptors should be formatted as 'DESCRIPTOR(name [type], ...)'"), + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(1 + 2)))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:57: Invalid argument SCHEMA. 
Expected descriptor, got expression", + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => 1))"); + + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:57: Invalid argument SCHEMA. Expected descriptor, got table", + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => TABLE(t1)))"); + + assertFails(TYPE_MISMATCH, + "line 1:78: Unknown type: verybigint", + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => DESCRIPTOR(x verybigint)))"); + } + + @Test + public void testCopartitioning() + { + // TABLE(t1) is matched by fully qualified name: tpch.s1.t1. It matches the second copartition item s1.t1. + // Aliased relation TABLE(SELECT 1, 2) t1(x, y) is matched by unqualified name. It matches the first copartition item t1. + analyze("SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(SELECT 1, 2) t1(x, y) PARTITION BY (x, y)" + + "COPARTITION (t1, s1.t1)))"); + + // Copartition items t1, t2 are first matched to arguments by unqualified names, and when no match is found, by fully qualified names. + // TABLE(tpch.s1.t1) is matched by fully qualified name. It matches the first copartition item t1. + // TABLE(s1.t2) is matched by unqualified name: tpch.s1.t2. It matches the second copartition item t2. 
+ analyze("SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(tpch.s1.t1) PARTITION BY (a, b)," + + "input2 => TABLE(s1.t2) PARTITION BY (a, b)" + + "COPARTITION (t1, t2)))"); + + assertFails(TABLE_FUNCTION_INVALID_COPARTITIONING, + "line 1:153: No table argument found for name: s1.foo", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(t2) PARTITION BY (a, b)" + + "COPARTITION (t1, s1.foo)))"); + + // Both table arguments are matched by fully qualified name: tpch.s1.t1 + assertFails(TABLE_FUNCTION_INVALID_COPARTITIONING, "line 1:149: Ambiguous reference: multiple table arguments found for name: t1", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(t1) PARTITION BY (a, b)" + + "COPARTITION (t1, t2)))"); + + // Both table arguments are matched by unqualified name: t1 + assertFails(TABLE_FUNCTION_INVALID_COPARTITIONING, + "line 1:185: Ambiguous reference: multiple table arguments found for name: t1", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(SELECT 1, 2) t1(a, b) PARTITION BY (a, b)," + + "input2 => TABLE(SELECT 3, 4) t1(c, d) PARTITION BY (c, d)" + + "COPARTITION (t1, t2)))"); + + assertFails(TABLE_FUNCTION_INVALID_COPARTITIONING, + "line 1:153: Multiple references to table argument: t1 in COPARTITION clause", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(t2) PARTITION BY (a, b)" + + "COPARTITION (t1, t1)))"); + } + + @Test + public void testCopartitionColumns() + { + assertFails(TABLE_FUNCTION_INVALID_COPARTITIONING, + "line 1:67: Table tpch.s1.t1 referenced in COPARTITION clause is not partitioned", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1)," + + "input2 => TABLE(t2) PARTITION BY (a, b)" + + "COPARTITION (t1, 
t2)))"); + + assertFails(TABLE_FUNCTION_INVALID_COPARTITIONING, + "line 1:67: No partitioning columns specified for table tpch.s1.t1 referenced in COPARTITION clause", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY ()," + + "input2 => TABLE(t2) PARTITION BY ()" + + "COPARTITION (t1, t2)))"); + + assertFails(TABLE_FUNCTION_INVALID_COPARTITIONING, + "line 1:146: Numbers of partitioning columns in copartitioned tables do not match", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(t1) PARTITION BY (a, b)," + + "input2 => TABLE(t2) PARTITION BY (a)" + + "COPARTITION (t1, t2)))"); + + assertFails(TYPE_MISMATCH, + "line 1:169: Partitioning columns in copartitioned tables have incompatible types", + "SELECT * FROM TABLE(system.two_table_arguments_function(" + + "input1 => TABLE(SELECT 1) t1(a) PARTITION BY (a)," + + "input2 => TABLE(SELECT 'x') t2(b) PARTITION BY (b)" + + "COPARTITION (t1, t2)))"); + } + + @Test + public void testNullArguments() + { + // cannot pass null for table argument + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:52: Invalid argument INPUT. Expected table, got expression", + "SELECT * FROM TABLE(system.table_argument_function(input => null))"); + + // the wrong way to pass null for descriptor + assertFails(TABLE_FUNCTION_INVALID_FUNCTION_ARGUMENT, + "line 1:57: Invalid argument SCHEMA. 
Expected descriptor, got expression", + "SELECT * FROM TABLE(system.descriptor_argument_function(schema => null))"); + + // the right way to pass null for descriptor + analyze("SELECT * FROM TABLE(system.descriptor_argument_function(schema => CAST(null AS DESCRIPTOR)))"); + + // the default value for the argument schema is null + analyze("SELECT * FROM TABLE(system.descriptor_argument_function())"); + + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function(null, null))"); + + // the default value for the second argument is null + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a'))"); + } + + @Test + public void testTableFunctionInvocationContext() + { + // cannot specify relation alias for table function with ONLY PASS THROUGH return type + assertFails(TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION, + "line 1:21: Alias specified for table function with ONLY PASS THROUGH return type", + "SELECT * FROM TABLE(system.only_pass_through_function(TABLE(t1))) f(x)"); + + // per SQL standard, relation alias is required for table function with GENERIC TABLE return type. We don't require it. + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1)) f(x)"); + analyze("SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1))"); + + // per SQL standard, relation alias is required for table function with statically declared return type, only if the function is polymorphic. + // We don't require aliasing polymorphic functions. 
+ analyze("SELECT * FROM TABLE(system.monomorphic_static_return_type_function())"); + analyze("SELECT * FROM TABLE(system.monomorphic_static_return_type_function()) f(x, y)"); + analyze("SELECT * FROM TABLE(system.polymorphic_static_return_type_function(input => TABLE(t1)))"); + analyze("SELECT * FROM TABLE(system.polymorphic_static_return_type_function(input => TABLE(t1))) f(x, y)"); + + // sampled + assertFails(TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION, + "line 1:21: Cannot apply sample to polymorphic table function invocation", + "SELECT * FROM TABLE(system.only_pass_through_function(TABLE(t1))) TABLESAMPLE BERNOULLI (10)"); + + // aliased + sampled + assertFails(TABLE_FUNCTION_INVALID_TABLE_FUNCTION_INVOCATION, + "line 1:15: Cannot apply sample to polymorphic table function invocation", + "SELECT * FROM TABLE(system.two_scalar_arguments_function('a', 1)) f(x) TABLESAMPLE BERNOULLI (10)"); + } + + @Test + public void testTableFunctionAliasing() + { + // case-insensitive name matching + assertFails(TABLE_FUNCTION_DUPLICATE_RANGE_VARIABLE, + "line 1:64: Relation alias: T1 is a duplicate of input table name: tpch.s1.t1", + "SELECT * FROM TABLE(system.table_argument_function(TABLE(t1))) T1(x)"); + + assertFails(TABLE_FUNCTION_DUPLICATE_RANGE_VARIABLE, + "line 1:76: Relation alias: t1 is a duplicate of input table name: t1", + "SELECT * FROM TABLE(system.table_argument_function(TABLE(SELECT 1) T1(a))) t1(x)"); + + analyze("SELECT * FROM TABLE(system.table_argument_function(TABLE(t1) t2)) T1(x)"); + + // the original returned relation type is ("column" : BOOLEAN) + analyze("SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias"); + + analyze("SELECT column_alias FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); + + analyze("SELECT table_alias.column_alias FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); + + assertFails(MISSING_ATTRIBUTE, + "line 1:8: Column 
'column' cannot be resolved", + "SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(column_alias)"); + + assertFails(MISMATCHED_COLUMN_ALIASES, + "line 1:20: Column alias list has 3 entries but table function has 1 proper columns", + "SELECT column FROM TABLE(system.two_scalar_arguments_function('a', 1)) table_alias(col1, col2, col3)"); + + // the original returned relation type is ("a" : BOOLEAN, "b" : INTEGER) + analyze("SELECT column_alias_1, column_alias_2 FROM TABLE(system.monomorphic_static_return_type_function()) table_alias(column_alias_1, column_alias_2)"); + + assertFails(DUPLICATE_COLUMN_NAME, + "line 1:21: Duplicate name of table function proper column: col", + "SELECT * FROM TABLE(system.monomorphic_static_return_type_function()) table_alias(col, col)"); + + // case-insensitive name matching + assertFails(DUPLICATE_COLUMN_NAME, + "line 1:21: Duplicate name of table function proper column: col", + "SELECT * FROM TABLE(system.monomorphic_static_return_type_function()) table_alias(col, COL)"); + + // pass-through columns of an input table must not be aliased, and must be referenced by the original range variables of their corresponding table arguments + // the function pass_through_function has one proper column ("x" : BOOLEAN), and one table argument with pass-through property + // tha alias applies only to the proper column + analyze("SELECT table_alias.x, t1.a, t1.b, t1.c, t1.d FROM TABLE(system.pass_through_function(TABLE(t1))) table_alias"); + + analyze("SELECT table_alias.x, arg_alias.a, arg_alias.b, arg_alias.c, arg_alias.d FROM TABLE(system.pass_through_function(TABLE(t1) arg_alias)) table_alias"); + + assertFails(MISSING_ATTRIBUTE, + "line 1:23: 't1.a' cannot be resolved", + "SELECT table_alias.x, t1.a FROM TABLE(system.pass_through_function(TABLE(t1) arg_alias)) table_alias"); + + assertFails(MISSING_ATTRIBUTE, + "line 1:23: 'table_alias.a' cannot be resolved", + "SELECT table_alias.x, table_alias.a FROM 
TABLE(system.pass_through_function(TABLE(t1))) table_alias"); + } + + @Test + public void testTableFunctionRequiredColumns() + { + // the function required_column_function specifies columns 0 and 1 from table argument "INPUT" as required. + analyze("SELECT * FROM TABLE(system.required_columns_function(input => TABLE(t1)))"); + + analyze("SELECT * FROM TABLE(system.required_columns_function(input => TABLE(SELECT 1, 2, 3)))"); + + assertFails(TABLE_FUNCTION_IMPLEMENTATION_ERROR, + "Invalid index: 1 of required column from table argument INPUT", + "SELECT * FROM TABLE(system.required_columns_function(input => TABLE(SELECT 1)))"); + + // table s1.t5 has two columns. The second column is hidden. Table function cannot require a hidden column. + assertFails(TABLE_FUNCTION_IMPLEMENTATION_ERROR, + "Invalid index: 1 of required column from table argument INPUT", + "SELECT * FROM TABLE(system.required_columns_function(input => TABLE(s1.t5)))"); + } + + @Test + public void testInvalidMerge() + { + assertFails(MISSING_TABLE, "Table tpch.s1.foo does not exist", + "MERGE INTO foo USING bar ON foo.id = bar.id WHEN MATCHED THEN UPDATE SET id = bar.id + 1"); + + assertFails(NOT_SUPPORTED, "line 1:1: Merging into views is not supported", + "MERGE INTO v1 USING t1 ON v1.a = t1.a WHEN MATCHED THEN UPDATE SET id = bar.id + 1"); + + assertFails(NOT_SUPPORTED, "line 1:1: Merging into materialized views is not supported", + "MERGE INTO mv1 USING t1 ON mv1.a = t1.a WHEN MATCHED THEN UPDATE SET id = bar.id + 1"); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestColumnAndSubfieldAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestColumnAndSubfieldAnalyzer.java index df194fc6b4783..e8a8e5a942a00 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestColumnAndSubfieldAnalyzer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestColumnAndSubfieldAnalyzer.java @@ -17,6 +17,7 
@@ import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.Subfield; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.AccessControlReferences; import com.facebook.presto.sql.tree.Statement; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -24,12 +25,14 @@ import org.testng.annotations.Test; import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.facebook.presto.SystemSessionProperties.CHECK_ACCESS_CONTROL_ON_UTILIZED_COLUMNS_ONLY; import static com.facebook.presto.SystemSessionProperties.CHECK_ACCESS_CONTROL_WITH_SUBFIELDS; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.transaction.TransactionBuilder.transaction; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static org.testng.Assert.assertEquals; @@ -214,9 +217,12 @@ private void assertTableColumns(@Language("SQL") String query, Map { - Analyzer analyzer = createAnalyzer(s, metadata, WarningCollector.NOOP, query); + Analyzer analyzer = createAnalyzer(s, metadata, WarningCollector.NOOP, Optional.empty(), query); Statement statement = SQL_PARSER.createStatement(query); - Analysis analysis = analyzer.analyze(statement); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + AccessControlReferences accessControlReferences = analysis.getAccessControlReferences(); + checkAccessPermissions(accessControlReferences, analysis.getViewDefinitionReferences(), query, session.getPreparedStatements(), session.getIdentity(), accessControl, session.getAccessControlContext()); + assertEquals( analysis.getAccessControlReferences().getTableColumnAndSubfieldReferencesForAccessControl() 
.values().stream().findFirst().get().entrySet().stream() diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index f2f7bb9aa8843..c1afed9467d17 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -15,7 +15,10 @@ import com.facebook.airlift.configuration.ConfigurationFactory; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.CompressionCodec; +import com.facebook.presto.spi.MaterializedViewStaleReadBehavior; import com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; @@ -27,14 +30,17 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig.RandomizeOuterJoinNullKeyStrategy; import com.facebook.presto.sql.analyzer.FeaturesConfig.SingleStreamSpillerChoice; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.presto.spi.security.ViewSecurity.DEFINER; +import static com.facebook.presto.spi.security.ViewSecurity.INVOKER; import static 
com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationPartitioningMergingStrategy.LEGACY; import static com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationPartitioningMergingStrategy.TOP_DOWN; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.BROADCAST; @@ -45,11 +51,6 @@ import static com.facebook.presto.sql.analyzer.FeaturesConfig.SPILL_ENABLED; import static com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy.ORDER_BY_CREATE_TIME; import static com.facebook.presto.sql.analyzer.FeaturesConfig.TaskSpillingStrategy.PER_TASK_MEMORY_THRESHOLD; -import static com.facebook.presto.sql.tree.CreateView.Security.DEFINER; -import static com.facebook.presto.sql.tree.CreateView.Security.INVOKER; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; @@ -59,6 +60,7 @@ public class TestFeaturesConfig public void testDefaults() { assertRecordedDefaults(ConfigAssertions.recordDefaults(FeaturesConfig.class) + .setMaxPrefixesCount(100) .setCpuCostWeight(75) .setMemoryCostWeight(10) .setNetworkCostWeight(15) @@ -85,8 +87,10 @@ public void testDefaults() .setHistoryCanonicalPlanNodeLimit(1000) .setHistoryBasedOptimizerTimeout(new Duration(10, SECONDS)) .setHistoryBasedOptimizerPlanCanonicalizationStrategies("IGNORE_SAFE_CONSTANTS") + .setQueryTypesEnabledForHbo("SELECT,INSERT") .setLogPlansUsedInHistoryBasedOptimizer(false) .setEnforceTimeoutForHBOQueryRegistration(false) + .setHistoryBasedOptimizerEstimateSizeUsingVariables(false) .setRedistributeWrites(false) .setScaleWriters(true) .setWriterMinSize(new DataSize(32, MEGABYTE)) @@ -167,7 +171,7 @@ public void testDefaults() .setSkipRedundantSort(true) .setWarnOnNoTableLayoutFilter("") .setInlineSqlFunctions(true) - 
.setCheckAccessControlOnUtilizedColumnsOnly(false) + .setCheckAccessControlOnUtilizedColumnsOnly(true) .setCheckAccessControlWithSubfields(false) .setAllowWindowOrderByLiterals(true) .setEnforceFixedDistributionForOutputOperator(false) @@ -185,6 +189,10 @@ public void testDefaults() .setMaterializedViewDataConsistencyEnabled(true) .setMaterializedViewPartitionFilteringEnabled(true) .setQueryOptimizationWithMaterializedViewEnabled(false) + .setLegacyMaterializedViews(true) + .setAllowLegacyMaterializedViewsToggle(false) + .setMaterializedViewAllowFullRefreshEnabled(false) + .setMaterializedViewStaleReadBehavior(MaterializedViewStaleReadBehavior.USE_VIEW_QUERY) .setVerboseRuntimeStatsEnabled(false) .setAggregationIfToFilterRewriteStrategy(AggregationIfToFilterRewriteStrategy.DISABLED) .setAnalyzerType("BUILTIN") @@ -193,12 +201,16 @@ public void testDefaults() .setMaxStageCountForEagerScheduling(25) .setHyperloglogStandardErrorWarningThreshold(0.004) .setPreferMergeJoinForSortedInputs(false) + .setPreferSortMergeJoin(false) + .setSortedExchangeEnabled(false) .setSegmentedAggregationEnabled(false) .setQueryAnalyzerTimeout(new Duration(3, MINUTES)) .setQuickDistinctLimitEnabled(false) .setPushRemoteExchangeThroughGroupId(false) .setOptimizeMultipleApproxPercentileOnSameFieldEnabled(true) + .setOptimizeMultipleApproxDistinctOnSameTypeEnabled(false) .setNativeExecutionEnabled(false) + .setBuiltInSidecarFunctionsEnabled(false) .setDisableTimeStampWithTimeZoneForNative(false) .setDisableIPAddressForNative(false) .setNativeExecutionExecutablePath("./presto_server") @@ -207,6 +219,7 @@ public void testDefaults() .setNativeEnforceJoinBuildInputPartition(true) .setRandomizeOuterJoinNullKeyEnabled(false) .setRandomizeOuterJoinNullKeyStrategy(RandomizeOuterJoinNullKeyStrategy.DISABLED) + .setRandomizeNullSourceKeyInSemiJoinStrategy(FeaturesConfig.RandomizeNullSourceKeyInSemiJoinStrategy.DISABLED) .setShardedJoinStrategy(FeaturesConfig.ShardedJoinStrategy.DISABLED) 
.setJoinShardCount(100) .setOptimizeConditionalAggregationEnabled(false) @@ -239,6 +252,7 @@ public void testDefaults() .setCteFilterAndProjectionPushdownEnabled(true) .setGenerateDomainFilters(false) .setRewriteExpressionWithConstantVariable(true) + .setOptimizeConditionalApproxDistinct(true) .setDefaultWriterReplicationCoefficient(3.0) .setDefaultViewSecurityMode(DEFINER) .setCteHeuristicReplicationThreshold(4) @@ -252,20 +266,34 @@ public void testDefaults() .setPrestoSparkExecutionEnvironment(false) .setSingleNodeExecutionEnabled(false) .setNativeExecutionScaleWritersThreadsEnabled(false) - .setNativeExecutionTypeRewriteEnabled(false) .setEnhancedCTESchedulingEnabled(true) .setExpressionOptimizerName("default") .setExcludeInvalidWorkerSessionProperties(false) .setAddExchangeBelowPartialAggregationOverGroupId(false) + .setAddDistinctBelowSemiJoinBuild(false) + .setPushdownSubfieldForMapFunctions(true) + .setPushdownSubfieldForCardinality(false) + .setUtilizeUniquePropertyInQueryPlanning(true) + .setExpressionOptimizerUsedInRowExpressionRewrite("") .setInnerJoinPushdownEnabled(false) + .setBroadcastSemiJoinForDelete(true) .setInEqualityJoinPushdownEnabled(false) - .setPrestoSparkExecutionEnvironment(false)); + .setRewriteMinMaxByToTopNEnabled(false) + .setPrestoSparkExecutionEnvironment(false) + .setMaxSerializableObjectSize(1000) + .setTableScanShuffleParallelismThreshold(0.1) + .setTableScanShuffleStrategy(FeaturesConfig.ShuffleForTableScanStrategy.DISABLED) + .setSkipPushdownThroughExchangeForRemoteProjection(false) + .setUseConnectorProvidedSerializationCodecs(false) + .setRemoteFunctionNamesForFixedParallelism("") + .setRemoteFunctionFixedParallelismTaskCount(10)); } @Test public void testExplicitPropertyMappings() { Map properties = new ImmutableMap.Builder() + .put("max-prefixes-count", "1") .put("cpu-cost-weight", "0.4") .put("memory-cost-weight", "0.3") .put("network-cost-weight", "0.2") @@ -304,8 +332,10 @@ public void testExplicitPropertyMappings() 
.put("optimizer.use-perfectly-consistent-histories", "true") .put("optimizer.history-canonical-plan-node-limit", "2") .put("optimizer.history-based-optimizer-plan-canonicalization-strategies", "IGNORE_SAFE_CONSTANTS,IGNORE_SCAN_CONSTANTS") + .put("optimizer.query-types-enabled-for-hbo", "SELECT,INSERT,DELETE") .put("optimizer.log-plans-used-in-history-based-optimizer", "true") .put("optimizer.enforce-timeout-for-hbo-query-registration", "true") + .put("optimizer.history-based-optimizer-estimate-size-using-variables", "true") .put("optimizer.history-based-optimizer-timeout", "1s") .put("redistribute-writes", "true") .put("scale-writers", "false") @@ -373,7 +403,7 @@ public void testExplicitPropertyMappings() .put("optimizer.joins-not-null-inference-strategy", "USE_FUNCTION_METADATA") .put("warn-on-no-table-layout-filter", "ry@nlikestheyankees,ds") .put("inline-sql-functions", "false") - .put("check-access-control-on-utilized-columns-only", "true") + .put("check-access-control-on-utilized-columns-only", "false") .put("check-access-control-with-subfields", "true") .put("optimizer.skip-redundant-sort", "false") .put("is-allow-window-order-by-literals", "false") @@ -392,6 +422,10 @@ public void testExplicitPropertyMappings() .put("materialized-view-data-consistency-enabled", "false") .put("consider-query-filters-for-materialized-view-partitions", "false") .put("query-optimization-with-materialized-view-enabled", "true") + .put("experimental.legacy-materialized-views", "false") + .put("experimental.allow-legacy-materialized-views-toggle", "true") + .put("materialized-view-allow-full-refresh-enabled", "true") + .put("materialized-view-stale-read-behavior", "FAIL") .put("analyzer-type", "CRUX") .put("pre-process-metadata-calls", "true") .put("verbose-runtime-stats-enabled", "true") @@ -400,12 +434,16 @@ public void testExplicitPropertyMappings() .put("execution-policy.max-stage-count-for-eager-scheduling", "123") .put("hyperloglog-standard-error-warning-threshold", "0.02") 
.put("optimizer.prefer-merge-join-for-sorted-inputs", "true") + .put("experimental.optimizer.prefer-sort-merge-join", "true") + .put("experimental.optimizer.sorted-exchange-enabled", "true") .put("optimizer.segmented-aggregation-enabled", "true") .put("planner.query-analyzer-timeout", "10s") .put("optimizer.quick-distinct-limit-enabled", "true") .put("optimizer.push-remote-exchange-through-group-id", "true") .put("optimizer.optimize-multiple-approx-percentile-on-same-field", "false") + .put("optimizer.optimize-multiple-approx-distinct-on-same-type", "true") .put("native-execution-enabled", "true") + .put("built-in-sidecar-functions-enabled", "true") .put("disable-timestamp-with-timezone-for-native-execution", "true") .put("disable-ipaddress-for-native-execution", "true") .put("native-execution-executable-path", "/bin/echo") @@ -414,6 +452,7 @@ public void testExplicitPropertyMappings() .put("native-enforce-join-build-input-partition", "false") .put("optimizer.randomize-outer-join-null-key", "true") .put("optimizer.randomize-outer-join-null-key-strategy", "key_from_outer_join") + .put("optimizer.randomize-null-source-key-in-semi-join-strategy", "always") .put("optimizer.sharded-join-strategy", "cost_based") .put("optimizer.join-shard-count", "200") .put("optimizer.optimize-conditional-aggregation-enabled", "true") @@ -446,6 +485,7 @@ public void testExplicitPropertyMappings() .put("optimizer.skip-hash-generation-for-join-with-table-scan-input", "true") .put("optimizer.generate-domain-filters", "true") .put("optimizer.rewrite-expression-with-constant-variable", "false") + .put("optimizer.optimize-constant-approx-distinct", "false") .put("optimizer.default-writer-replication-coefficient", "5.0") .put("default-view-security-mode", INVOKER.name()) .put("cte-heuristic-replication-threshold", "2") @@ -457,18 +497,32 @@ public void testExplicitPropertyMappings() .put("eager-plan-validation-enabled", "true") .put("eager-plan-validation-thread-pool-size", "2") 
.put("optimizer.inner-join-pushdown-enabled", "true") + .put("optimizer.broadcast-semi-join-for-delete", "false") .put("optimizer.inequality-join-pushdown-enabled", "true") + .put("optimizer.rewrite-minBy-maxBy-to-topN-enabled", "true") .put("presto-spark-execution-environment", "true") .put("single-node-execution-enabled", "true") .put("native-execution-scale-writer-threads-enabled", "true") - .put("native-execution-type-rewrite-enabled", "true") .put("enhanced-cte-scheduling-enabled", "false") .put("expression-optimizer-name", "custom") .put("exclude-invalid-worker-session-properties", "true") + .put("optimizer.add-distinct-below-semi-join-build", "true") + .put("optimizer.pushdown-subfield-for-map-functions", "false") + .put("optimizer.pushdown-subfield-for-cardinality", "true") + .put("optimizer.utilize-unique-property-in-query-planning", "false") + .put("optimizer.expression-optimizer-used-in-expression-rewrite", "custom") .put("optimizer.add-exchange-below-partial-aggregation-over-group-id", "true") + .put("max_serializable_object_size", "50") + .put("optimizer.table-scan-shuffle-parallelism-threshold", "0.3") + .put("optimizer.table-scan-shuffle-strategy", "ALWAYS_ENABLED") + .put("optimizer.skip-pushdown-through-exchange-for-remote-projection", "true") + .put("use-connector-provided-serialization-codecs", "true") + .put("optimizer.remote-function-names-for-fixed-parallelism", "remote_.*") + .put("optimizer.remote-function-fixed-parallelism-task-count", "100") .build(); FeaturesConfig expected = new FeaturesConfig() + .setMaxPrefixesCount(1) .setCpuCostWeight(0.4) .setMemoryCostWeight(0.3) .setNetworkCostWeight(0.2) @@ -507,8 +561,10 @@ public void testExplicitPropertyMappings() .setHistoryCanonicalPlanNodeLimit(2) .setHistoryBasedOptimizerTimeout(new Duration(1, SECONDS)) .setHistoryBasedOptimizerPlanCanonicalizationStrategies("IGNORE_SAFE_CONSTANTS,IGNORE_SCAN_CONSTANTS") + .setQueryTypesEnabledForHbo("SELECT,INSERT,DELETE") 
.setLogPlansUsedInHistoryBasedOptimizer(true) .setEnforceTimeoutForHBOQueryRegistration(true) + .setHistoryBasedOptimizerEstimateSizeUsingVariables(true) .setRedistributeWrites(true) .setScaleWriters(false) .setWriterMinSize(new DataSize(42, GIGABYTE)) @@ -577,7 +633,7 @@ public void testExplicitPropertyMappings() .setSkipRedundantSort(false) .setWarnOnNoTableLayoutFilter("ry@nlikestheyankees,ds") .setInlineSqlFunctions(false) - .setCheckAccessControlOnUtilizedColumnsOnly(true) + .setCheckAccessControlOnUtilizedColumnsOnly(false) .setCheckAccessControlWithSubfields(true) .setSkipRedundantSort(false) .setAllowWindowOrderByLiterals(false) @@ -596,6 +652,10 @@ public void testExplicitPropertyMappings() .setMaterializedViewDataConsistencyEnabled(false) .setMaterializedViewPartitionFilteringEnabled(false) .setQueryOptimizationWithMaterializedViewEnabled(true) + .setLegacyMaterializedViews(false) + .setAllowLegacyMaterializedViewsToggle(true) + .setMaterializedViewAllowFullRefreshEnabled(true) + .setMaterializedViewStaleReadBehavior(MaterializedViewStaleReadBehavior.FAIL) .setVerboseRuntimeStatsEnabled(true) .setAggregationIfToFilterRewriteStrategy(AggregationIfToFilterRewriteStrategy.FILTER_WITH_IF) .setAnalyzerType("CRUX") @@ -604,12 +664,16 @@ public void testExplicitPropertyMappings() .setMaxStageCountForEagerScheduling(123) .setHyperloglogStandardErrorWarningThreshold(0.02) .setPreferMergeJoinForSortedInputs(true) + .setPreferSortMergeJoin(true) + .setSortedExchangeEnabled(true) .setSegmentedAggregationEnabled(true) .setQueryAnalyzerTimeout(new Duration(10, SECONDS)) .setQuickDistinctLimitEnabled(true) .setPushRemoteExchangeThroughGroupId(true) .setOptimizeMultipleApproxPercentileOnSameFieldEnabled(false) + .setOptimizeMultipleApproxDistinctOnSameTypeEnabled(true) .setNativeExecutionEnabled(true) + .setBuiltInSidecarFunctionsEnabled(true) .setDisableTimeStampWithTimeZoneForNative(true) .setDisableIPAddressForNative(true) 
.setNativeExecutionExecutablePath("/bin/echo") @@ -618,6 +682,7 @@ public void testExplicitPropertyMappings() .setNativeEnforceJoinBuildInputPartition(false) .setRandomizeOuterJoinNullKeyEnabled(true) .setRandomizeOuterJoinNullKeyStrategy(RandomizeOuterJoinNullKeyStrategy.KEY_FROM_OUTER_JOIN) + .setRandomizeNullSourceKeyInSemiJoinStrategy(FeaturesConfig.RandomizeNullSourceKeyInSemiJoinStrategy.ALWAYS) .setShardedJoinStrategy(FeaturesConfig.ShardedJoinStrategy.COST_BASED) .setJoinShardCount(200) .setOptimizeConditionalAggregationEnabled(true) @@ -650,6 +715,7 @@ public void testExplicitPropertyMappings() .setCteFilterAndProjectionPushdownEnabled(false) .setGenerateDomainFilters(true) .setRewriteExpressionWithConstantVariable(false) + .setOptimizeConditionalApproxDistinct(false) .setDefaultWriterReplicationCoefficient(5.0) .setDefaultViewSecurityMode(INVOKER) .setCteHeuristicReplicationThreshold(2) @@ -663,14 +729,27 @@ public void testExplicitPropertyMappings() .setPrestoSparkExecutionEnvironment(true) .setSingleNodeExecutionEnabled(true) .setNativeExecutionScaleWritersThreadsEnabled(true) - .setNativeExecutionTypeRewriteEnabled(true) .setEnhancedCTESchedulingEnabled(false) .setExpressionOptimizerName("custom") .setExcludeInvalidWorkerSessionProperties(true) .setAddExchangeBelowPartialAggregationOverGroupId(true) + .setAddDistinctBelowSemiJoinBuild(true) + .setPushdownSubfieldForMapFunctions(false) + .setPushdownSubfieldForCardinality(true) + .setUtilizeUniquePropertyInQueryPlanning(false) + .setExpressionOptimizerUsedInRowExpressionRewrite("custom") .setInEqualityJoinPushdownEnabled(true) + .setBroadcastSemiJoinForDelete(false) + .setRewriteMinMaxByToTopNEnabled(true) .setInnerJoinPushdownEnabled(true) - .setPrestoSparkExecutionEnvironment(true); + .setPrestoSparkExecutionEnvironment(true) + .setMaxSerializableObjectSize(50) + .setTableScanShuffleParallelismThreshold(0.3) + .setTableScanShuffleStrategy(FeaturesConfig.ShuffleForTableScanStrategy.ALWAYS_ENABLED) + 
.setSkipPushdownThroughExchangeForRemoteProjection(true) + .setUseConnectorProvidedSerializationCodecs(true) + .setRemoteFunctionNamesForFixedParallelism("remote_.*") + .setRemoteFunctionFixedParallelismTaskCount(100); assertFullMapping(properties, expected); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestJavaFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestJavaFeaturesConfig.java index f83f6b3eb03ba..91445461149e5 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestJavaFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestJavaFeaturesConfig.java @@ -14,8 +14,8 @@ package com.facebook.presto.sql.analyzer; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.DataSize; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestMaterializedViewQueryOptimizer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestMaterializedViewQueryOptimizer.java index d2837732be856..108ebb4da54e3 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestMaterializedViewQueryOptimizer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestMaterializedViewQueryOptimizer.java @@ -1581,6 +1581,7 @@ private MaterializedViewDefinition createStubConnectorMaterializedViewDefinition viewName, baseTables, Optional.empty(), + Optional.empty(), ImmutableList.of(), ImmutableList.of(), Optional.empty()); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestUtilizedColumnsAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestUtilizedColumnsAnalyzer.java index 0bff1b7961cd7..5d64d7eb66d42 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestUtilizedColumnsAnalyzer.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestUtilizedColumnsAnalyzer.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.analyzer.AccessControlInfo; +import com.facebook.presto.spi.analyzer.AccessControlReferences; import com.facebook.presto.sql.tree.Statement; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -23,10 +24,12 @@ import org.testng.annotations.Test; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import static com.facebook.presto.transaction.TransactionBuilder.transaction; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; import static org.testng.Assert.assertEquals; @Test(singleThreaded = true) @@ -656,9 +659,12 @@ private void assertUtilizedTableColumns(@Language("SQL") String query, Map { - Analyzer analyzer = createAnalyzer(session, metadata, WarningCollector.NOOP, query); + Analyzer analyzer = createAnalyzer(session, metadata, WarningCollector.NOOP, Optional.empty(), query); Statement statement = SQL_PARSER.createStatement(query); - Analysis analysis = analyzer.analyze(statement); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + AccessControlReferences accessControlReferences = analysis.getAccessControlReferences(); + checkAccessPermissions(accessControlReferences, analysis.getViewDefinitionReferences(), query, session.getPreparedStatements(), session.getIdentity(), accessControl, session.getAccessControlContext()); + assertEquals(analysis.getUtilizedTableColumnReferences().entrySet().stream().collect(Collectors.toMap(entry -> extractAccessControlInfo(entry.getKey()), Map.Entry::getValue)), expected); }); } diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestViewDefinitionCollector.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestViewDefinitionCollector.java new file mode 100644 index 0000000000000..a0407e9d61f46 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestViewDefinitionCollector.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.analyzer; + +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.analyzer.AccessControlReferences; +import com.facebook.presto.sql.tree.Statement; +import com.google.common.collect.ImmutableMap; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.facebook.presto.transaction.TransactionBuilder.transaction; +import static com.facebook.presto.util.AnalyzerUtil.checkAccessPermissions; +import static org.testng.Assert.assertEquals; + +@Test(singleThreaded = true) +public class TestViewDefinitionCollector + extends AbstractAnalyzerTest +{ + public void testSelectLeftJoinViews() + { + @Language("SQL") String query = "SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", 
"select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testCreateViewWithNestedViews() + { + @Language("SQL") String query = "CREATE VIEW top_level_view1 AS SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testCreateTableAsSelectWithViews() + { + @Language("SQL") String query = "CREATE TABLE top_level_view1 AS SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainWithViews() + { + @Language("SQL") String query = "EXPLAIN SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainTypeIoWithViews() + { + @Language("SQL") String query = "EXPLAIN (TYPE IO) SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainTypeValidateWithViews() + { + @Language("SQL") String query = "EXPLAIN (TYPE VALIDATE) SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM 
view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainAnalyzeWithViews() + { + @Language("SQL") String query = "EXPLAIN ANALYZE SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainExplainWithViews() + { + @Language("SQL") String query = "EXPLAIN EXPLAIN SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainExplainTypeValidateWithViews() + { + @Language("SQL") String query = "EXPLAIN EXPLAIN (TYPE VALIDATE) SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainTypeValidateExplainWithViews() + { + @Language("SQL") String query = "EXPLAIN (TYPE VALIDATE) EXPLAIN SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), 
ImmutableMap.of()); + } + + public void testExplainTypeValidateExplainTypeValidateWithViews() + { + @Language("SQL") String query = "EXPLAIN (TYPE VALIDATE) EXPLAIN (TYPE VALIDATE) SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainAnalyzeExplainWithViews() + { + @Language("SQL") String query = "EXPLAIN ANALYZE EXPLAIN SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainAnalyzeExplainAnalyzeWithViews() + { + @Language("SQL") String query = "EXPLAIN ANALYZE EXPLAIN ANALYZE SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainExplainAnalyzeWithViews() + { + @Language("SQL") String query = "EXPLAIN EXPLAIN ANALYZE SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainAnalyzeExplainTypeValidateWithViews() + { + @Language("SQL") String query = "EXPLAIN ANALYZE EXPLAIN (TYPE VALIDATE) SELECT view_definer1.a, 
view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + public void testExplainTypeValidateExplainAnalyzeWithViews() + { + @Language("SQL") String query = "EXPLAIN (TYPE VALIDATE) EXPLAIN ANALYZE SELECT view_definer1.a, view_definer1.c, view_invoker2.y FROM view_definer1 left join view_invoker2 on view_invoker2.y = view_definer1.c"; + + assertViewDefinitions(query, ImmutableMap.of( + "tpch.s1.view_invoker2", "select x, y, z from t13", + "tpch.s1.view_definer1", "select a,b,c from t1" + ), ImmutableMap.of()); + } + + private void assertViewDefinitions(@Language("SQL") String query, Map expectedViewDefinitions, Map expectedMaterializedViewDefinitions) + { + transaction(transactionManager, accessControl) + .singleStatement() + .readUncommitted() + .readOnly() + .execute(CLIENT_SESSION, session -> { + Analyzer analyzer = createAnalyzer(session, metadata, WarningCollector.NOOP, Optional.of(createTestingQueryExplainer(session, accessControl, metadata)), query); + Statement statement = SQL_PARSER.createStatement(query); + Analysis analysis = analyzer.analyzeSemantic(statement, false); + AccessControlReferences accessControlReferences = analysis.getAccessControlReferences(); + checkAccessPermissions(accessControlReferences, analysis.getViewDefinitionReferences(), query, session.getPreparedStatements(), session.getIdentity(), accessControl, session.getAccessControlContext()); + + Map viewDefinitionsMap = analysis.getViewDefinitionReferences().getViewDefinitions().entrySet().stream() + .collect(Collectors.toMap( + entry -> entry.getKey().toString(), + entry -> entry.getValue().getOriginalSql())); + Map materializedDefinitionsMap = analysis.getViewDefinitionReferences().getMaterializedViewDefinitions().entrySet().stream() + 
.collect(Collectors.toMap( + entry -> entry.getKey().toString(), + entry -> entry.getValue().getOriginalSql())); + + assertEquals(viewDefinitionsMap, expectedViewDefinitions); + assertEquals(materializedDefinitionsMap, expectedMaterializedViewDefinitions); + }); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/AbstractTestExpressionInterpreter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/AbstractTestExpressionInterpreter.java new file mode 100644 index 0000000000000..3d8240cf134b2 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/AbstractTestExpressionInterpreter.java @@ -0,0 +1,1711 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.expressions; + +import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.block.BlockSerdeUtil; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.SqlTimestampWithTimeZone; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.operator.scalar.FunctionAssertions; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.AggregationFunctionMetadata; +import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.InputReferenceExpression; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.ExpressionFormatter; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.parser.ParsingOptions; +import com.facebook.presto.sql.parser.SqlParser; +import 
com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.tree.EnumLiteral; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceOutput; +import io.airlift.slice.Slices; +import org.intellij.lang.annotations.Language; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.joda.time.LocalDate; +import org.joda.time.LocalTime; +import org.testng.annotations.Test; + +import java.math.BigInteger; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DecimalType.createDecimalType; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import 
static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; +import static com.facebook.presto.sql.ExpressionFormatter.formatExpression; +import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; +import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; +import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; +import static io.airlift.slice.Slices.utf8Slice; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; + +public abstract class AbstractTestExpressionInterpreter +{ + public static final SqlInvokedFunction SQUARE_UDF_CPP = new SqlInvokedFunction( + QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "square"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "Integer square", + RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), + "", + notVersioned()); + + public static final SqlInvokedFunction AVG_UDAF_CPP = new SqlInvokedFunction( + QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "avg"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.DOUBLE))), + parseTypeSignature(StandardTypes.DOUBLE), + "Returns mean of doubles", + RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), + "", + notVersioned(), + FunctionKind.AGGREGATE, + Optional.of(new AggregationFunctionMetadata(parseTypeSignature("ROW(double, int)"), false))); + + public static final int TEST_VARCHAR_TYPE_LENGTH = 17; + public static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder() + 
.put("bound_integer", INTEGER) + .put("bound_long", BIGINT) + .put("bound_string", createVarcharType(TEST_VARCHAR_TYPE_LENGTH)) + .put("bound_varbinary", VarbinaryType.VARBINARY) + .put("bound_double", DOUBLE) + .put("bound_boolean", BOOLEAN) + .put("bound_date", DATE) + .put("bound_time", TIME) + .put("bound_timestamp", TIMESTAMP) + .put("bound_pattern", VARCHAR) + .put("bound_null_string", VARCHAR) + .put("bound_decimal_short", createDecimalType(5, 2)) + .put("bound_decimal_long", createDecimalType(23, 3)) + .put("time", BIGINT) // for testing reserved identifiers + .put("unbound_integer", INTEGER) + .put("unbound_long", BIGINT) + .put("unbound_long2", BIGINT) + .put("unbound_long3", BIGINT) + .put("unbound_string", VARCHAR) + .put("unbound_double", DOUBLE) + .put("unbound_boolean", BOOLEAN) + .put("unbound_date", DATE) + .put("unbound_time", TIME) + .put("unbound_array", new ArrayType(BIGINT)) + .put("unbound_timestamp", TIMESTAMP) + .put("unbound_interval", INTERVAL_DAY_TIME) + .put("unbound_pattern", VARCHAR) + .put("unbound_null_string", VARCHAR) + .build()); + + public static final SqlParser SQL_PARSER = new SqlParser(); + public static final Metadata METADATA = MetadataManager.createTestMetadataManager(); + public static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); + public static final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(); + + @Test + public void testAnd() + { + assertOptimizedEquals("true and false", "false"); + assertOptimizedEquals("false and true", "false"); + assertOptimizedEquals("false and false", "false"); + + assertOptimizedEquals("true and null", "null"); + assertOptimizedEquals("false and null", "false"); + assertOptimizedEquals("null and true", "null"); + assertOptimizedEquals("null and false", "false"); + assertOptimizedEquals("null and null", "null"); + + assertOptimizedEquals("unbound_string='z' and true", "unbound_string='z'"); + 
assertOptimizedEquals("unbound_string='z' and false", "false"); + assertOptimizedEquals("true and unbound_string='z'", "unbound_string='z'"); + assertOptimizedEquals("false and unbound_string='z'", "false"); + + assertOptimizedEquals("bound_string='z' and bound_long=1+1", "bound_string='z' and bound_long=2"); + assertOptimizedEquals("random() > 0 and random() > 0", "random() > 0 and random() > 0"); + } + + @Test + public void testOr() + { + assertOptimizedEquals("true or true", "true"); + assertOptimizedEquals("true or false", "true"); + assertOptimizedEquals("false or true", "true"); + assertOptimizedEquals("false or false", "false"); + + assertOptimizedEquals("true or null", "true"); + assertOptimizedEquals("null or true", "true"); + assertOptimizedEquals("null or null", "null"); + + assertOptimizedEquals("false or null", "null"); + assertOptimizedEquals("null or false", "null"); + + assertOptimizedEquals("bound_string='z' or true", "true"); + assertOptimizedEquals("bound_string='z' or false", "bound_string='z'"); + assertOptimizedEquals("true or bound_string='z'", "true"); + assertOptimizedEquals("false or bound_string='z'", "bound_string='z'"); + + assertOptimizedEquals("bound_string='z' or bound_long=1+1", "bound_string='z' or bound_long=2"); + assertOptimizedEquals("random() > 0 or random() > 0", "random() > 0 or random() > 0"); + } + + @Test + public void testComparison() + { + assertOptimizedEquals("null = null", "null"); + + assertOptimizedEquals("'a' = 'b'", "false"); + assertOptimizedEquals("'a' = 'a'", "true"); + assertOptimizedEquals("'a' = null", "null"); + assertOptimizedEquals("null = 'a'", "null"); + assertOptimizedEquals("bound_integer = 1234", "true"); + assertOptimizedEquals("bound_integer = 12340000000", "false"); + assertOptimizedEquals("bound_long = BIGINT '1234'", "true"); + assertOptimizedEquals("bound_long = 1234", "true"); + assertOptimizedEquals("bound_double = 12.34", "true"); + assertOptimizedEquals("bound_string = 'hello'", "true"); + 
assertOptimizedEquals("bound_long = unbound_long", "1234 = unbound_long"); + + assertOptimizedEquals("10151082135029368 = 10151082135029369", "false"); + + assertOptimizedEquals("bound_varbinary = X'a b'", "true"); + assertOptimizedEquals("bound_varbinary = X'a d'", "false"); + + assertOptimizedEquals("1.1 = 1.1", "true"); + assertOptimizedEquals("9876543210.9874561203 = 9876543210.9874561203", "true"); + assertOptimizedEquals("bound_decimal_short = 123.45", "true"); + assertOptimizedEquals("bound_decimal_long = 12345678901234567890.123", "true"); + } + + @Test + public void testIsDistinctFrom() + { + assertOptimizedEquals("null is distinct from null", "false"); + + assertOptimizedEquals("3 is distinct from 4", "true"); + assertOptimizedEquals("3 is distinct from BIGINT '4'", "true"); + assertOptimizedEquals("3 is distinct from 4000000000", "true"); + assertOptimizedEquals("3 is distinct from 3", "false"); + assertOptimizedEquals("3 is distinct from null", "true"); + assertOptimizedEquals("null is distinct from 3", "true"); + + assertOptimizedEquals("10151082135029368 is distinct from 10151082135029369", "true"); + + assertOptimizedEquals("1.1 is distinct from 1.1", "false"); + assertOptimizedEquals("9876543210.9874561203 is distinct from NULL", "true"); + assertOptimizedEquals("bound_decimal_short is distinct from NULL", "true"); + assertOptimizedEquals("bound_decimal_long is distinct from 12345678901234567890.123", "false"); + } + + @Test + public void testIsNull() + { + assertOptimizedEquals("null is null", "true"); + assertOptimizedEquals("1 is null", "false"); + assertOptimizedEquals("10000000000 is null", "false"); + assertOptimizedEquals("BIGINT '1' is null", "false"); + assertOptimizedEquals("1.0 is null", "false"); + assertOptimizedEquals("'a' is null", "false"); + assertOptimizedEquals("true is null", "false"); + assertOptimizedEquals("null+1 is null", "true"); + assertOptimizedEquals("unbound_string is null", "unbound_string is null"); + 
assertOptimizedEquals("unbound_long+(1+1) is null", "unbound_long+2 is null"); + assertOptimizedEquals("1.1 is null", "false"); + assertOptimizedEquals("9876543210.9874561203 is null", "false"); + assertOptimizedEquals("bound_decimal_short is null", "false"); + assertOptimizedEquals("bound_decimal_long is null", "false"); + } + + @Test + public void testIsNotNull() + { + assertOptimizedEquals("null is not null", "false"); + assertOptimizedEquals("1 is not null", "true"); + assertOptimizedEquals("10000000000 is not null", "true"); + assertOptimizedEquals("BIGINT '1' is not null", "true"); + assertOptimizedEquals("1.0 is not null", "true"); + assertOptimizedEquals("'a' is not null", "true"); + assertOptimizedEquals("true is not null", "true"); + assertOptimizedEquals("null+1 is not null", "false"); + assertOptimizedEquals("unbound_string is not null", "unbound_string is not null"); + assertOptimizedEquals("unbound_long+(1+1) is not null", "unbound_long+2 is not null"); + assertOptimizedEquals("1.1 is not null", "true"); + assertOptimizedEquals("9876543210.9874561203 is not null", "true"); + assertOptimizedEquals("bound_decimal_short is not null", "true"); + assertOptimizedEquals("bound_decimal_long is not null", "true"); + } + + @Test + public void testNullIf() + { + assertOptimizedEquals("nullif(true, true)", "null"); + assertOptimizedEquals("nullif(true, false)", "true"); + assertOptimizedEquals("nullif(null, false)", "null"); + assertOptimizedEquals("nullif(true, null)", "true"); + + assertOptimizedEquals("nullif('a', 'a')", "null"); + assertOptimizedEquals("nullif('a', 'b')", "'a'"); + assertOptimizedEquals("nullif(null, 'b')", "null"); + assertOptimizedEquals("nullif('a', null)", "'a'"); + + assertOptimizedEquals("nullif(1, 1)", "null"); + assertOptimizedEquals("nullif(1, 2)", "1"); + assertOptimizedEquals("nullif(1, BIGINT '2')", "1"); + assertOptimizedEquals("nullif(1, 20000000000)", "1"); + assertOptimizedEquals("nullif(1.0E0, 1)", "null"); + 
assertOptimizedEquals("nullif(10000000000.0E0, 10000000000)", "null"); + assertOptimizedEquals("nullif(1.1E0, 1)", "1.1E0"); + assertOptimizedEquals("nullif(1.1E0, 1.1E0)", "null"); + assertOptimizedEquals("nullif(1, 2-1)", "null"); + assertOptimizedEquals("nullif(null, null)", "null"); + assertOptimizedEquals("nullif(1, null)", "1"); + assertOptimizedEquals("nullif(unbound_long, 1)", "nullif(unbound_long, 1)"); + assertOptimizedEquals("nullif(unbound_long, unbound_long2)", "nullif(unbound_long, unbound_long2)"); + assertOptimizedEquals("nullif(unbound_long, unbound_long2+(1+1))", "nullif(unbound_long, unbound_long2+2)"); + + assertOptimizedEquals("nullif(1.1, 1.2)", "1.1"); + assertOptimizedEquals("nullif(9876543210.9874561203, 9876543210.9874561203)", "null"); + assertOptimizedEquals("nullif(bound_decimal_short, 123.45)", "null"); + assertOptimizedEquals("nullif(bound_decimal_long, 12345678901234567890.123)", "null"); + assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(1 AS BIGINT)]) IS NULL", "true"); + assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false"); + assertOptimizedEquals("nullif(ARRAY[CAST(NULL AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false"); + } + + @Test + public void testNegative() + { + assertOptimizedEquals("-(1)", "-1"); + assertOptimizedEquals("-(BIGINT '1')", "BIGINT '-1'"); + assertOptimizedEquals("-(unbound_long+1)", "-(unbound_long+1)"); + assertOptimizedEquals("-(1+1)", "-2"); + assertOptimizedEquals("-(1+ BIGINT '1')", "BIGINT '-2'"); + assertOptimizedEquals("-(CAST(NULL AS BIGINT))", "null"); + assertOptimizedEquals("-(unbound_long+(1+1))", "-(unbound_long+2)"); + assertOptimizedEquals("-(1.1+1.2)", "-2.3"); + assertOptimizedEquals("-(9876543210.9874561203-9876543210.9874561203)", "CAST(0 AS DECIMAL(20,10))"); + assertOptimizedEquals("-(bound_decimal_short+123.45)", "-246.90"); + assertOptimizedEquals("-(bound_decimal_long-12345678901234567890.123)", "CAST(0 
AS DECIMAL(20,10))"); + } + + @Test + public void testNot() + { + assertOptimizedEquals("not true", "false"); + assertOptimizedEquals("not false", "true"); + assertOptimizedEquals("not null", "null"); + assertOptimizedEquals("not 1=1", "false"); + assertOptimizedEquals("not 1=BIGINT '1'", "false"); + assertOptimizedEquals("not 1!=1", "true"); + assertOptimizedEquals("not unbound_long=1", "not unbound_long=1"); + assertOptimizedEquals("not unbound_long=(1+1)", "not unbound_long=2"); + } + + @Test + public void testFunctionCall() + { + assertOptimizedEquals("abs(-5)", "5"); + assertOptimizedEquals("abs(-10-5)", "15"); + assertOptimizedEquals("abs(-bound_integer + 1)", "1233"); + assertOptimizedEquals("abs(-bound_long + 1)", "1233"); + assertOptimizedEquals("abs(-bound_long + BIGINT '1')", "1233"); + assertOptimizedEquals("abs(-bound_long)", "1234"); + assertOptimizedEquals("abs(unbound_long)", "abs(unbound_long)"); + assertOptimizedEquals("abs(unbound_long + 1)", "abs(unbound_long + 1)"); + assertOptimizedEquals("cast(json_parse(unbound_string) as map(varchar, varchar))", "cast(json_parse(unbound_string) as map(varchar, varchar))"); + assertOptimizedEquals("cast(json_parse(unbound_string) as array(varchar))", "cast(json_parse(unbound_string) as array(varchar))"); + assertOptimizedEquals("cast(json_parse(unbound_string) as row(bigint, varchar))", "cast(json_parse(unbound_string) as row(bigint, varchar))"); + } + + @Test + public void testNonDeterministicFunctionCall() + { + // optimize should do nothing + assertOptimizedEquals("random()", "random()"); + + // evaluate should execute + Object value = evaluate("random()", false); + assertTrue(value instanceof Double); + double randomValue = (double) value; + assertTrue(0 <= randomValue && randomValue < 1); + } + + @Test + public void testCppFunctionCall() + { + METADATA.getFunctionAndTypeManager().createFunction(SQUARE_UDF_CPP, false); + assertOptimizedEquals("json.test_schema.square(-5)", "json.test_schema.square(-5)"); 
+ } + + @Test + public void testCppAggregateFunctionCall() + { + METADATA.getFunctionAndTypeManager().createFunction(AVG_UDAF_CPP, false); + assertOptimizedEquals("json.test_schema.avg(1.0)", "json.test_schema.avg(1.0)"); + } + + @Test + public void testBetween() + { + assertOptimizedEquals("3 between 2 and 4", "true"); + assertOptimizedEquals("2 between 3 and 4", "false"); + assertOptimizedEquals("null between 2 and 4", "null"); + assertOptimizedEquals("3 between null and 4", "null"); + assertOptimizedEquals("3 between 2 and null", "null"); + + assertOptimizedEquals("'cc' between 'b' and 'd'", "true"); + assertOptimizedEquals("'b' between 'cc' and 'd'", "false"); + assertOptimizedEquals("null between 'b' and 'd'", "null"); + assertOptimizedEquals("'cc' between null and 'd'", "null"); + assertOptimizedEquals("'cc' between 'b' and null", "null"); + + assertOptimizedEquals("bound_integer between 1000 and 2000", "true"); + assertOptimizedEquals("bound_integer between 3 and 4", "false"); + assertOptimizedEquals("bound_long between 1000 and 2000", "true"); + assertOptimizedEquals("bound_long between 3 and 4", "false"); + assertOptimizedEquals("bound_long between bound_integer and (bound_long + 1)", "true"); + assertOptimizedEquals("bound_string between 'e' and 'i'", "true"); + assertOptimizedEquals("bound_string between 'a' and 'b'", "false"); + + assertOptimizedEquals("bound_long between unbound_long and 2000 + 1", "1234 between unbound_long and 2001"); + assertOptimizedEquals( + "bound_string between unbound_string and 'bar'", + format("CAST('hello' AS VARCHAR(%s)) between unbound_string and 'bar'", TEST_VARCHAR_TYPE_LENGTH)); + + assertOptimizedEquals("1.15 between 1.1 and 1.2", "true"); + assertOptimizedEquals("9876543210.98745612035 between 9876543210.9874561203 and 9876543210.9874561204", "true"); + assertOptimizedEquals("123.455 between bound_decimal_short and 123.46", "true"); + assertOptimizedEquals("12345678901234567890.1235 between bound_decimal_long and 
12345678901234567890.123", "false"); + } + + @Test + public void testExtract() + { + DateTime dateTime = new DateTime(2001, 8, 22, 3, 4, 5, 321, getDateTimeZone(TEST_SESSION.getTimeZoneKey())); + double seconds = dateTime.getMillis() / 1000.0; + + assertOptimizedEquals("extract (YEAR from from_unixtime(" + seconds + "))", "2001"); + assertOptimizedEquals("extract (QUARTER from from_unixtime(" + seconds + "))", "3"); + assertOptimizedEquals("extract (MONTH from from_unixtime(" + seconds + "))", "8"); + assertOptimizedEquals("extract (WEEK from from_unixtime(" + seconds + "))", "34"); + assertOptimizedEquals("extract (DOW from from_unixtime(" + seconds + "))", "3"); + assertOptimizedEquals("extract (DOY from from_unixtime(" + seconds + "))", "234"); + assertOptimizedEquals("extract (DAY from from_unixtime(" + seconds + "))", "22"); + assertOptimizedEquals("extract (HOUR from from_unixtime(" + seconds + "))", "3"); + assertOptimizedEquals("extract (MINUTE from from_unixtime(" + seconds + "))", "4"); + assertOptimizedEquals("extract (SECOND from from_unixtime(" + seconds + "))", "5"); + assertOptimizedEquals("extract (TIMEZONE_HOUR from from_unixtime(" + seconds + ", 7, 9))", "7"); + assertOptimizedEquals("extract (TIMEZONE_MINUTE from from_unixtime(" + seconds + ", 7, 9))", "9"); + + assertOptimizedEquals("extract (YEAR from bound_timestamp)", "2001"); + assertOptimizedEquals("extract (QUARTER from bound_timestamp)", "3"); + assertOptimizedEquals("extract (MONTH from bound_timestamp)", "8"); + assertOptimizedEquals("extract (WEEK from bound_timestamp)", "34"); + assertOptimizedEquals("extract (DOW from bound_timestamp)", "2"); + assertOptimizedEquals("extract (DOY from bound_timestamp)", "233"); + assertOptimizedEquals("extract (DAY from bound_timestamp)", "21"); + assertOptimizedEquals("extract (HOUR from bound_timestamp)", "16"); + assertOptimizedEquals("extract (MINUTE from bound_timestamp)", "4"); + assertOptimizedEquals("extract (SECOND from bound_timestamp)", 
"5"); + // todo reenable when cast as timestamp with time zone is implemented + // todo add bound timestamp with time zone + //assertOptimizedEquals("extract (TIMEZONE_HOUR from bound_timestamp)", "0"); + //assertOptimizedEquals("extract (TIMEZONE_MINUTE from bound_timestamp)", "0"); + + assertOptimizedEquals("extract (YEAR from unbound_timestamp)", "extract (YEAR from unbound_timestamp)"); + assertOptimizedEquals("extract (SECOND from bound_timestamp + INTERVAL '3' SECOND)", "8"); + } + + @Test + public void testIn() + { + assertOptimizedEquals("3 in (2, 4, 3, 5)", "true"); + assertOptimizedEquals("3 in (2, 4, 9, 5)", "false"); + assertOptimizedEquals("3 in (2, null, 3, 5)", "true"); + + assertOptimizedEquals("'foo' in ('bar', 'baz', 'foo', 'blah')", "true"); + assertOptimizedEquals("'foo' in ('bar', 'baz', 'buz', 'blah')", "false"); + assertOptimizedEquals("'foo' in ('bar', null, 'foo', 'blah')", "true"); + + assertOptimizedEquals("null in (2, null, 3, 5)", "null"); + assertOptimizedEquals("3 in (2, null)", "null"); + + assertOptimizedEquals("bound_integer in (2, 1234, 3, 5)", "true"); + assertOptimizedEquals("bound_integer in (2, 4, 3, 5)", "false"); + assertOptimizedEquals("1234 in (2, bound_integer, 3, 5)", "true"); + assertOptimizedEquals("99 in (2, bound_integer, 3, 5)", "false"); + assertOptimizedEquals("bound_integer in (2, bound_integer, 3, 5)", "true"); + + assertOptimizedEquals("bound_long in (2, 1234, 3, 5)", "true"); + assertOptimizedEquals("bound_long in (2, 4, 3, 5)", "false"); + assertOptimizedEquals("1234 in (2, bound_long, 3, 5)", "true"); + assertOptimizedEquals("99 in (2, bound_long, 3, 5)", "false"); + assertOptimizedEquals("bound_long in (2, bound_long, 3, 5)", "true"); + + assertOptimizedEquals("bound_string in ('bar', 'hello', 'foo', 'blah')", "true"); + assertOptimizedEquals("bound_string in ('bar', 'baz', 'foo', 'blah')", "false"); + assertOptimizedEquals("'hello' in ('bar', bound_string, 'foo', 'blah')", "true"); + 
assertOptimizedEquals("'baz' in ('bar', bound_string, 'foo', 'blah')", "false"); + + assertOptimizedEquals("bound_long in (2, 1234, unbound_long, 5)", "true"); + assertOptimizedEquals("bound_string in ('bar', 'hello', unbound_string, 'blah')", "true"); + + assertOptimizedEquals("bound_long in (2, 4, unbound_long, unbound_long2, 9)", "1234 in (unbound_long, unbound_long2)"); + assertOptimizedEquals("unbound_long in (2, 4, bound_long, unbound_long2, 5)", "unbound_long in (2, 4, 1234, unbound_long2, 5)"); + + assertOptimizedEquals("1.15 in (1.1, 1.2, 1.3, 1.15)", "true"); + assertOptimizedEquals("9876543210.98745612035 in (9876543210.9874561203, 9876543210.9874561204, 9876543210.98745612035)", "true"); + assertOptimizedEquals("bound_decimal_short in (123.455, 123.46, 123.45)", "true"); + assertOptimizedEquals("bound_decimal_long in (12345678901234567890.123, 9876543210.9874561204, 9876543210.98745612035)", "true"); + assertOptimizedEquals("bound_decimal_long in (9876543210.9874561204, null, 9876543210.98745612035)", "null"); + } + + @Test + public void testInComplexTypes() + { + assertEvaluatedEquals("ARRAY[null] IN (ARRAY[null])", "null"); + assertEvaluatedEquals("ARRAY[1] IN (ARRAY[null])", "null"); + assertEvaluatedEquals("ARRAY[null] IN (ARRAY[1])", "null"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null])", "null"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[2, null])", "false"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null])", "null"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null], ARRAY[1, null])", "null"); + assertEvaluatedEquals("ARRAY[ARRAY[1, 2], ARRAY[3, 4]] in (ARRAY[ARRAY[1, 2], ARRAY[3, NULL]])", "null"); + + assertEvaluatedEquals("ROW(1) IN (ROW(1))", "true"); + assertEvaluatedEquals("ROW(1) IN (ROW(2))", "false"); + assertEvaluatedEquals("ROW(1) IN (ROW(2), ROW(1), ROW(2))", "true"); + assertEvaluatedEquals("ROW(1) IN (null)", "null"); + assertEvaluatedEquals("ROW(1) IN (null, 
ROW(1))", "true"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null), null)", "null"); + assertEvaluatedEquals("ROW(null) IN (ROW(null))", "null"); + assertEvaluatedEquals("ROW(1) IN (ROW(null))", "null"); + assertEvaluatedEquals("ROW(null) IN (ROW(1))", "null"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null))", "null"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null))", "false"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null))", "null"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null), ROW(1, null))", "null"); + + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[1]))", "true"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null)", "null"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null, MAP(ARRAY[1], ARRAY[1]))", "true"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]), null)", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[1]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]))", 
"null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]), MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); + } + + @Test + public void testCurrentTimestamp() + { + double current = TEST_SESSION.getStartTime() / 1000.0; + assertOptimizedEquals("current_timestamp = from_unixtime(" + current + ")", "true"); + double future = current + TimeUnit.MINUTES.toSeconds(1); + assertOptimizedEquals("current_timestamp > from_unixtime(" + future + ")", "false"); + } + + @Test + public void testCurrentUser() + throws Exception + { + assertOptimizedEquals("current_user", "'" + TEST_SESSION.getUser() + "'"); + } + + @Test + public void testCastToString() + { + // integer + assertOptimizedEquals("cast(123 as VARCHAR(20))", "'123'"); + assertOptimizedEquals("cast(-123 as VARCHAR(20))", "'-123'"); + + // bigint + assertOptimizedEquals("cast(BIGINT '123' as VARCHAR)", "'123'"); + assertOptimizedEquals("cast(12300000000 as VARCHAR)", "'12300000000'"); + assertOptimizedEquals("cast(-12300000000 as VARCHAR)", "'-12300000000'"); + + // double + assertOptimizedEquals("cast(123.0E0 as VARCHAR)", "'123.0'"); + assertOptimizedEquals("cast(-123.0E0 as VARCHAR)", "'-123.0'"); + assertOptimizedEquals("cast(123.456E0 as VARCHAR)", "'123.456'"); + assertOptimizedEquals("cast(-123.456E0 as VARCHAR)", "'-123.456'"); + + // boolean + assertOptimizedEquals("cast(true as VARCHAR)", "'true'"); + assertOptimizedEquals("cast(false as VARCHAR)", "'false'"); + + // string + assertOptimizedEquals("cast('xyz' as VARCHAR)", "'xyz'"); + assertOptimizedEquals("cast(cast('abcxyz' as VARCHAR(3)) as VARCHAR(5))", "'abc'"); + + // null + assertOptimizedEquals("cast(null as VARCHAR)", "null"); + + // decimal + assertOptimizedEquals("cast(1.1 as VARCHAR)", "'1.1'"); + // TODO enabled when DECIMAL is default for literal: assertOptimizedEquals("cast(12345678901234567890.123 as VARCHAR)", "'12345678901234567890.123'"); + } + + @Test + public void 
testCastBigintToBoundedVarchar() + { + assertEvaluatedEquals("CAST(12300000000 AS varchar(11))", "'12300000000'"); + assertEvaluatedEquals("CAST(12300000000 AS varchar(50))", "'12300000000'"); + + try { + evaluate("CAST(12300000000 AS varchar(3))", true); + fail("Expected to throw an INVALID_CAST_ARGUMENT exception"); + } + catch (PrestoException e) { + try { + assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); + assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)"); + } + catch (Throwable failure) { + failure.addSuppressed(e); + throw failure; + } + } + + try { + evaluate("CAST(-12300000000 AS varchar(3))", true); + } + catch (PrestoException e) { + try { + assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); + assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)"); + } + catch (Throwable failure) { + failure.addSuppressed(e); + throw failure; + } + } + } + + @Test + public void testCastToBoolean() + { + // integer + assertOptimizedEquals("cast(123 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(-123 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(0 as BOOLEAN)", "false"); + + // bigint + assertOptimizedEquals("cast(12300000000 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(-12300000000 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(BIGINT '0' as BOOLEAN)", "false"); + + // boolean + assertOptimizedEquals("cast(true as BOOLEAN)", "true"); + assertOptimizedEquals("cast(false as BOOLEAN)", "false"); + + // string + assertOptimizedEquals("cast('true' as BOOLEAN)", "true"); + assertOptimizedEquals("cast('false' as BOOLEAN)", "false"); + assertOptimizedEquals("cast('t' as BOOLEAN)", "true"); + assertOptimizedEquals("cast('f' as BOOLEAN)", "false"); + assertOptimizedEquals("cast('1' as BOOLEAN)", "true"); + assertOptimizedEquals("cast('0' as BOOLEAN)", "false"); + + // null + assertOptimizedEquals("cast(null as BOOLEAN)", "null"); + + // double + 
assertOptimizedEquals("cast(123.45E0 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(-123.45E0 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(0.0E0 as BOOLEAN)", "false"); + + // decimal + assertOptimizedEquals("cast(0.00 as BOOLEAN)", "false"); + assertOptimizedEquals("cast(7.8 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(12345678901234567890.123 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(00000000000000000000.000 as BOOLEAN)", "false"); + } + + @Test + public void testCastToBigint() + { + // integer + assertOptimizedEquals("cast(0 as BIGINT)", "0"); + assertOptimizedEquals("cast(123 as BIGINT)", "123"); + assertOptimizedEquals("cast(-123 as BIGINT)", "-123"); + + // bigint + assertOptimizedEquals("cast(BIGINT '0' as BIGINT)", "0"); + assertOptimizedEquals("cast(BIGINT '123' as BIGINT)", "123"); + assertOptimizedEquals("cast(BIGINT '-123' as BIGINT)", "-123"); + + // double + assertOptimizedEquals("cast(123.0E0 as BIGINT)", "123"); + assertOptimizedEquals("cast(-123.0E0 as BIGINT)", "-123"); + assertOptimizedEquals("cast(123.456E0 as BIGINT)", "123"); + assertOptimizedEquals("cast(-123.456E0 as BIGINT)", "-123"); + + // boolean + assertOptimizedEquals("cast(true as BIGINT)", "1"); + assertOptimizedEquals("cast(false as BIGINT)", "0"); + + // string + assertOptimizedEquals("cast('123' as BIGINT)", "123"); + assertOptimizedEquals("cast('-123' as BIGINT)", "-123"); + + // null + assertOptimizedEquals("cast(null as BIGINT)", "null"); + + // decimal + assertOptimizedEquals("cast(DECIMAL '1.01' as BIGINT)", "1"); + assertOptimizedEquals("cast(DECIMAL '7.8' as BIGINT)", "8"); + assertOptimizedEquals("cast(DECIMAL '1234567890.123' as BIGINT)", "1234567890"); + assertOptimizedEquals("cast(DECIMAL '00000000000000000000.000' as BIGINT)", "0"); + } + + @Test + public void testCastToInteger() + { + // integer + assertOptimizedEquals("cast(0 as INTEGER)", "0"); + assertOptimizedEquals("cast(123 as INTEGER)", "123"); + assertOptimizedEquals("cast(-123 
as INTEGER)", "-123"); + + // bigint + assertOptimizedEquals("cast(BIGINT '0' as INTEGER)", "0"); + assertOptimizedEquals("cast(BIGINT '123' as INTEGER)", "123"); + assertOptimizedEquals("cast(BIGINT '-123' as INTEGER)", "-123"); + + // double + assertOptimizedEquals("cast(123.0E0 as INTEGER)", "123"); + assertOptimizedEquals("cast(-123.0E0 as INTEGER)", "-123"); + assertOptimizedEquals("cast(123.456E0 as INTEGER)", "123"); + assertOptimizedEquals("cast(-123.456E0 as INTEGER)", "-123"); + + // boolean + assertOptimizedEquals("cast(true as INTEGER)", "1"); + assertOptimizedEquals("cast(false as INTEGER)", "0"); + + // string + assertOptimizedEquals("cast('123' as INTEGER)", "123"); + assertOptimizedEquals("cast('-123' as INTEGER)", "-123"); + + // null + assertOptimizedEquals("cast(null as INTEGER)", "null"); + } + + @Test + public void testCastToDouble() + { + // integer + assertOptimizedEquals("cast(0 as DOUBLE)", "0.0E0"); + assertOptimizedEquals("cast(123 as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast(-123 as DOUBLE)", "-123.0E0"); + + // bigint + assertOptimizedEquals("cast(BIGINT '0' as DOUBLE)", "0.0E0"); + assertOptimizedEquals("cast(12300000000 as DOUBLE)", "12300000000.0E0"); + assertOptimizedEquals("cast(-12300000000 as DOUBLE)", "-12300000000.0E0"); + + // double + assertOptimizedEquals("cast(123.0E0 as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast(-123.0E0 as DOUBLE)", "-123.0E0"); + assertOptimizedEquals("cast(123.456E0 as DOUBLE)", "123.456E0"); + assertOptimizedEquals("cast(-123.456E0 as DOUBLE)", "-123.456E0"); + + // string + assertOptimizedEquals("cast('0' as DOUBLE)", "0.0E0"); + assertOptimizedEquals("cast('123' as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast('-123' as DOUBLE)", "-123.0E0"); + assertOptimizedEquals("cast('123.0E0' as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast('-123.0E0' as DOUBLE)", "-123.0E0"); + assertOptimizedEquals("cast('123.456E0' as DOUBLE)", "123.456E0"); + 
assertOptimizedEquals("cast('-123.456E0' as DOUBLE)", "-123.456E0"); + + // null + assertOptimizedEquals("cast(null as DOUBLE)", "null"); + + // boolean + assertOptimizedEquals("cast(true as DOUBLE)", "1.0E0"); + assertOptimizedEquals("cast(false as DOUBLE)", "0.0E0"); + + // decimal + assertOptimizedEquals("cast(1.01 as DOUBLE)", "DOUBLE '1.01'"); + assertOptimizedEquals("cast(7.8 as DOUBLE)", "DOUBLE '7.8'"); + assertOptimizedEquals("cast(1234567890.123 as DOUBLE)", "DOUBLE '1234567890.123'"); + assertOptimizedEquals("cast(00000000000000000000.000 as DOUBLE)", "DOUBLE '0.0'"); + } + + @Test + public void testCastToDecimal() + { + // long + assertOptimizedEquals("cast(0 as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast(123 as DECIMAL(3,0))", "DECIMAL '123'"); + assertOptimizedEquals("cast(-123 as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast(-123 as DECIMAL(20,10))", "cast(-123 as DECIMAL(20,10))"); + + // double + assertOptimizedEquals("cast(0E0 as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast(123.2E0 as DECIMAL(4,1))", "DECIMAL '123.2'"); + assertOptimizedEquals("cast(-123.0E0 as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast(-123.55E0 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); + + // string + assertOptimizedEquals("cast('0' as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast('123.2' as DECIMAL(4,1))", "DECIMAL '123.2'"); + assertOptimizedEquals("cast('-123.0' as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast('-123.55' as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); + + // null + assertOptimizedEquals("cast(null as DECIMAL(1,0))", "null"); + assertOptimizedEquals("cast(null as DECIMAL(20,10))", "null"); + + // boolean + assertOptimizedEquals("cast(true as DECIMAL(1,0))", "DECIMAL '1'"); + assertOptimizedEquals("cast(false as DECIMAL(4,1))", "DECIMAL '000.0'"); + assertOptimizedEquals("cast(true as DECIMAL(3,0))", "DECIMAL '001'"); + 
assertOptimizedEquals("cast(false as DECIMAL(20,10))", "cast(0 as DECIMAL(20,10))"); + + // decimal + assertOptimizedEquals("cast(0.0 as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast(123.2 as DECIMAL(4,1))", "DECIMAL '123.2'"); + assertOptimizedEquals("cast(-123.0 as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast(-123.55 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); + } + + @Test + public void testCastOptimization() + { + assertOptimizedEquals("cast(unbound_string as VARCHAR)", "cast(unbound_string as VARCHAR)"); + assertOptimizedMatches("cast(unbound_string as VARCHAR)", "unbound_string"); + assertOptimizedMatches("cast(unbound_integer as INTEGER)", "unbound_integer"); + assertOptimizedMatches("cast(unbound_string as VARCHAR(10))", "cast(unbound_string as VARCHAR(10))"); + } + + @Test + public void testTryCast() + { + assertOptimizedEquals("try_cast(null as BIGINT)", "null"); + assertOptimizedEquals("try_cast(123 as BIGINT)", "123"); + assertOptimizedEquals("try_cast(null as INTEGER)", "null"); + assertOptimizedEquals("try_cast(123 as INTEGER)", "123"); + assertOptimizedEquals("try_cast('foo' as VARCHAR)", "'foo'"); + assertOptimizedEquals("try_cast('foo' as BIGINT)", "null"); + assertOptimizedEquals("try_cast(unbound_string as BIGINT)", "try_cast(unbound_string as BIGINT)"); + assertOptimizedEquals("try_cast('foo' as DECIMAL(2,1))", "null"); + } + + @Test + public void testReservedWithDoubleQuotes() + { + assertOptimizedEquals("\"time\"", "\"time\""); + } + + @Test + public void testEnumLiteralFormattingWithTypeAndValue() + { + java.util.function.BiFunction createEnumLiteral = (type, value) -> new EnumLiteral(Optional.empty(), type, value); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("color", "RED"), Optional.empty()), "color: RED"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("level", 1), Optional.empty()), "level: 1"); + 
assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("StatusType", "Active"), Optional.empty()), "StatusType: Active"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("priority", "HIGH PRIORITY"), Optional.empty()), "priority: HIGH PRIORITY"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("lang", "枚举"), Optional.empty()), "lang: 枚举"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("special", "DOLLAR$"), Optional.empty()), "special: DOLLAR$"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("enum_type", "VALUE_1"), Optional.empty()), "enum_type: VALUE_1"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("flag", true), Optional.empty()), "flag: true"); + } + + @Test + public void testSearchCase() + { + assertOptimizedEquals("case " + + "when true then 33 " + + "end", + "33"); + assertOptimizedEquals("case " + + "when false then 1 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case " + + "when false then 10000000000 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case " + + "when bound_long = 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_long " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then 1 " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when bound_integer = 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_integer " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then 1 " + + "else bound_integer " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when bound_long = 1234 then 33 " + + "else unbound_long " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_long " + + "else unbound_long " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then 
unbound_long " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when bound_integer = 1234 then 33 " + + "else unbound_integer " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_integer " + + "else unbound_integer " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then unbound_integer " + + "else bound_integer " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when unbound_long = 1234 then 33 " + + "else 1 " + + "end", + "" + + "case " + + "when unbound_long = 1234 then 33 " + + "else 1 " + + "end"); + + assertOptimizedEquals("case " + + "when false then 2.2 " + + "when true then 2.2 " + + "end", + "2.2"); + + assertOptimizedEquals("case " + + "when false then 1234567890.0987654321 " + + "when true then 3.3 " + + "end", + "CAST(3.3 AS DECIMAL(20,10))"); + + assertOptimizedEquals("case " + + "when false then 1 " + + "when true then 2.2 " + + "end", + "2.2"); + + assertOptimizedEquals("case when ARRAY[CAST(1 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); + assertOptimizedEquals("case when ARRAY[CAST(2 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + assertOptimizedEquals("case when ARRAY[CAST(null AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + } + + @Test + public void testSimpleCase() + { + assertOptimizedEquals("case 1 " + + "when 1 then 32 + 1 " + + "when 1 then 34 " + + "end", + "33"); + + assertOptimizedEquals("case null " + + "when true then 33 " + + "end", + "null"); + assertOptimizedEquals("case null " + + "when true then 33 " + + "else 33 " + + "end", + "33"); + assertOptimizedEquals("case 33 " + + "when null then 1 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case null " + + "when true then 3300000000 " + + "end", + "null"); + assertOptimizedEquals("case null " + + "when true then 3300000000 " + + 
"else 3300000000 " + + "end", + "3300000000"); + assertOptimizedEquals("case 33 " + + "when null then 3300000000 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case true " + + "when true then 33 " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when false then 1 " + + "else 33 end", + "33"); + + assertOptimizedEquals("case bound_long " + + "when 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case 1234 " + + "when bound_long then 33 " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when true then bound_long " + + "end", + "1234"); + assertOptimizedEquals("case true " + + "when false then 1 " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case bound_integer " + + "when 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case 1234 " + + "when bound_integer then 33 " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when true then bound_integer " + + "end", + "1234"); + assertOptimizedEquals("case true " + + "when false then 1 " + + "else bound_integer " + + "end", + "1234"); + + assertOptimizedEquals("case bound_long " + + "when 1234 then 33 " + + "else unbound_long " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when true then bound_long " + + "else unbound_long " + + "end", + "1234"); + assertOptimizedEquals("case true " + + "when false then unbound_long " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case unbound_long " + + "when 1234 then 33 " + + "else 1 " + + "end", + "" + + "case unbound_long " + + "when 1234 then 33 " + + "else 1 " + + "end"); + + assertOptimizedEquals("case 33 " + + "when 0 then 0 " + + "when 33 then unbound_long " + + "else 1 " + + "end", + "unbound_long"); + assertOptimizedEquals("case 33 " + + "when 0 then 0 " + + "when 33 then 1 " + + "when unbound_long then 2 " + + "else 1 " + + "end", + "1"); + assertOptimizedEquals("case 33 " + + "when unbound_long then 0 " + + "when 1 then 1 " + + "when 33 
then 2 " + + "else 0 " + + "end", + "case 33 " + + "when unbound_long then 0 " + + "else 2 " + + "end"); + assertOptimizedEquals("case 33 " + + "when 0 then 0 " + + "when 1 then 1 " + + "else unbound_long " + + "end", + "unbound_long"); + assertOptimizedEquals("case 33 " + + "when unbound_long then 0 " + + "when 1 then 1 " + + "when unbound_long2 then 2 " + + "else 3 " + + "end", + "case 33 " + + "when unbound_long then 0 " + + "when unbound_long2 then 2 " + + "else 3 " + + "end"); + + assertOptimizedEquals("case true " + + "when unbound_long = 1 then 1 " + + "when 0 / 0 = 0 then 2 " + + "else 33 end", + "" + + "case true " + + "when unbound_long = 1 then 1 " + + "when 0 / 0 = 0 then 2 else 33 " + + "end"); + + assertOptimizedEquals("case bound_long " + + "when 123 * 10 + unbound_long then 1 = 1 " + + "else 1 = 2 " + + "end", + "" + + "case bound_long when 1230 + unbound_long then true " + + "else false " + + "end"); + + assertOptimizedEquals("case bound_long " + + "when unbound_long then 2 + 2 " + + "end", + "" + + "case bound_long " + + "when unbound_long then 4 " + + "end"); + + assertOptimizedEquals("case bound_long " + + "when unbound_long then 2 + 2 " + + "when 1 then null " + + "when 2 then null " + + "end", + "" + + "case bound_long " + + "when unbound_long then 4 " + + "end"); + + assertOptimizedEquals("case true " + + "when false then 2.2 " + + "when true then 2.2 " + + "end", + "2.2"); + + // TODO enabled when DECIMAL is default for literal: +// assertOptimizedEquals("case true " + +// "when false then 1234567890.0987654321 " + +// "when true then 3.3 " + +// "end", +// "CAST(3.3 AS DECIMAL(20,10))"); + + assertOptimizedEquals("case true " + + "when false then 1 " + + "when true then 2.2 " + + "end", + "2.2"); + + assertOptimizedEquals("case ARRAY[CAST(1 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); + assertOptimizedEquals("case ARRAY[CAST(2 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 
'not_matched' end", "'not_matched'"); + assertOptimizedEquals("case ARRAY[CAST(null AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + } + + @Test + public void testCoalesce() + { + assertOptimizedEquals("coalesce(null, null)", "coalesce(null, null)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1 - 1, null)", "coalesce(6 * unbound_long, 0)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_long, 0.5E0)"); + assertOptimizedEquals("coalesce(unbound_long, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_long, 2.0E0, 0.5E0, 12.34E0)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1 - 1, null)", "coalesce(6 * unbound_integer, 0)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_integer, 0.5E0)"); + assertOptimizedEquals("coalesce(unbound_integer, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_integer, 2.0E0, 0.5E0, 12.34E0)"); + assertOptimizedMatches("coalesce(unbound_long, unbound_long)", "unbound_long"); + assertOptimizedMatches("coalesce(2 * unbound_long, 2 * unbound_long)", "BIGINT '2' * unbound_long"); + assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long)", "coalesce(unbound_long, unbound_long2)"); + assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long, unbound_long3)", "coalesce(unbound_long, unbound_long2, unbound_long3)"); + assertOptimizedEquals("coalesce(6, unbound_long2, unbound_long, unbound_long3)", "6"); + assertOptimizedEquals("coalesce(2 * 3, unbound_long2, unbound_long, unbound_long3)", "6"); + assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 
1))", "coalesce(unbound_long, BIGINT '2')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); + assertOptimizedMatches("coalesce(random(), random(), 5)", "coalesce(random(), random(), 5E0)"); + assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); + } + + @Test + public void testIf() + { + assertOptimizedEquals("IF(2 = 2, 3, 4)", "3"); + assertOptimizedEquals("IF(1 = 2, 3, 4)", "4"); + assertOptimizedEquals("IF(1 = 2, BIGINT '3', 4)", "4"); + assertOptimizedEquals("IF(1 = 2, 3000000000, 4)", "4"); + + assertOptimizedEquals("IF(true, 3, 4)", "3"); + assertOptimizedEquals("IF(false, 3, 4)", "4"); + assertOptimizedEquals("IF(null, 3, 4)", "4"); + + assertOptimizedEquals("IF(true, 3, null)", "3"); + assertOptimizedEquals("IF(false, 3, null)", "null"); + assertOptimizedEquals("IF(true, null, 4)", "null"); + assertOptimizedEquals("IF(false, null, 4)", "4"); + assertOptimizedEquals("IF(true, null, null)", "null"); + assertOptimizedEquals("IF(false, null, null)", "null"); + + assertOptimizedEquals("IF(true, 3.5E0, 4.2E0)", "3.5E0"); + assertOptimizedEquals("IF(false, 3.5E0, 4.2E0)", "4.2E0"); + + 
assertOptimizedEquals("IF(true, 'foo', 'bar')", "'foo'"); + assertOptimizedEquals("IF(false, 'foo', 'bar')", "'bar'"); + + assertOptimizedEquals("IF(true, 1.01, 1.02)", "1.01"); + assertOptimizedEquals("IF(false, 1.01, 1.02)", "1.02"); + assertOptimizedEquals("IF(true, 1234567890.123, 1.02)", "1234567890.123"); + assertOptimizedEquals("IF(false, 1.01, 1234567890.123)", "1234567890.123"); + } + + @Test + public void testLike() + { + assertOptimizedEquals("'a' LIKE 'a'", "true"); + assertOptimizedEquals("'' LIKE 'a'", "false"); + assertOptimizedEquals("'abc' LIKE 'a'", "false"); + + assertOptimizedEquals("'a' LIKE '_'", "true"); + assertOptimizedEquals("'' LIKE '_'", "false"); + assertOptimizedEquals("'abc' LIKE '_'", "false"); + + assertOptimizedEquals("'a' LIKE '%'", "true"); + assertOptimizedEquals("'' LIKE '%'", "true"); + assertOptimizedEquals("'abc' LIKE '%'", "true"); + + assertOptimizedEquals("'abc' LIKE '___'", "true"); + assertOptimizedEquals("'ab' LIKE '___'", "false"); + assertOptimizedEquals("'abcd' LIKE '___'", "false"); + + assertOptimizedEquals("'abc' LIKE 'abc'", "true"); + assertOptimizedEquals("'xyz' LIKE 'abc'", "false"); + assertOptimizedEquals("'abc0' LIKE 'abc'", "false"); + assertOptimizedEquals("'0abc' LIKE 'abc'", "false"); + + assertOptimizedEquals("'abc' LIKE 'abc%'", "true"); + assertOptimizedEquals("'abc0' LIKE 'abc%'", "true"); + assertOptimizedEquals("'0abc' LIKE 'abc%'", "false"); + + assertOptimizedEquals("'abc' LIKE '%abc'", "true"); + assertOptimizedEquals("'0abc' LIKE '%abc'", "true"); + assertOptimizedEquals("'abc0' LIKE '%abc'", "false"); + + assertOptimizedEquals("'abc' LIKE '%abc%'", "true"); + assertOptimizedEquals("'0abc' LIKE '%abc%'", "true"); + assertOptimizedEquals("'abc0' LIKE '%abc%'", "true"); + assertOptimizedEquals("'0abc0' LIKE '%abc%'", "true"); + assertOptimizedEquals("'xyzw' LIKE '%abc%'", "false"); + + assertOptimizedEquals("'abc' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0abc' LIKE '%ab%c%'", "true"); 
+ assertOptimizedEquals("'abc0' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0abc0' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'ab01c' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0ab01c' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'ab01c0' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0ab01c0' LIKE '%ab%c%'", "true"); + + assertOptimizedEquals("'xyzw' LIKE '%ab%c%'", "false"); + + // ensure regex chars are escaped + assertOptimizedEquals("'\' LIKE '\'", "true"); + assertOptimizedEquals("'.*' LIKE '.*'", "true"); + assertOptimizedEquals("'[' LIKE '['", "true"); + assertOptimizedEquals("']' LIKE ']'", "true"); + assertOptimizedEquals("'{' LIKE '{'", "true"); + assertOptimizedEquals("'}' LIKE '}'", "true"); + assertOptimizedEquals("'?' LIKE '?'", "true"); + assertOptimizedEquals("'+' LIKE '+'", "true"); + assertOptimizedEquals("'(' LIKE '('", "true"); + assertOptimizedEquals("')' LIKE ')'", "true"); + assertOptimizedEquals("'|' LIKE '|'", "true"); + assertOptimizedEquals("'^' LIKE '^'", "true"); + assertOptimizedEquals("'$' LIKE '$'", "true"); + + assertOptimizedEquals("'%' LIKE 'z%' ESCAPE 'z'", "true"); + } + + @Test + public void testLikeNullOptimization() + { + assertOptimizedEquals("null LIKE '%'", "null"); + assertOptimizedEquals("'a' LIKE null", "null"); + assertOptimizedEquals("'a' LIKE '%' ESCAPE null", "null"); + assertOptimizedEquals("'a' LIKE unbound_string ESCAPE null", "null"); + } + + @Test + public void testLikeOptimization() + { + assertOptimizedEquals("unbound_string LIKE 'abc'", "unbound_string = CAST('abc' AS VARCHAR)"); + + assertOptimizedEquals("unbound_string LIKE '' ESCAPE '#'", "unbound_string LIKE '' ESCAPE '#'"); + assertOptimizedEquals("unbound_string LIKE 'abc' ESCAPE '#'", "unbound_string = CAST('abc' AS VARCHAR)"); + assertOptimizedEquals("unbound_string LIKE 'a#_b' ESCAPE '#'", "unbound_string = CAST('a_b' AS VARCHAR)"); + assertOptimizedEquals("unbound_string LIKE 'a#%b' ESCAPE '#'", "unbound_string = 
CAST('a%b' AS VARCHAR)"); + assertOptimizedEquals("unbound_string LIKE 'a#_##b' ESCAPE '#'", "unbound_string = CAST('a_#b' AS VARCHAR)"); + assertOptimizedEquals("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'"); + assertOptimizedEquals("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'"); + + assertOptimizedEquals("bound_string LIKE bound_pattern", "true"); + assertOptimizedEquals("'abc' LIKE bound_pattern", "false"); + + assertOptimizedEquals("unbound_string LIKE bound_pattern", "unbound_string LIKE bound_pattern"); + assertDoNotOptimize("unbound_string LIKE 'abc%'", SERIALIZABLE); + + assertOptimizedEquals("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string"); + } + + @Test + public void testInvalidLike() + { + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE ''")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE 'bc'")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#' ESCAPE '#'")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#abc' ESCAPE '#'")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'ab#' ESCAPE '#'")); + } + + @Test + public void testFailedExpressionOptimization() + { + assertFailedMatches("coalesce(0 / 0 > 1, unbound_boolean, 0 / 0 = 0)", + "coalesce(cast(fail(8, 'ignored failure message') as boolean), unbound_boolean)"); + + assertFailedMatches("if(false, 1, 0 / 0)", "cast(fail(8, 'ignored failure message') as integer)"); + + assertFailedMatches("CASE unbound_long WHEN 1 THEN 1 WHEN 0 / 0 THEN 2 END", + "CASE unbound_long WHEN BIGINT '1' THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 END"); + + assertFailedMatches("CASE unbound_boolean WHEN true THEN 1 ELSE 0 / 0 END", + "CASE unbound_boolean WHEN true THEN 1 ELSE cast(fail(8, 'ignored 
failure message') as integer) END"); + + assertFailedMatches("CASE bound_long WHEN unbound_long THEN 1 WHEN 0 / 0 THEN 2 ELSE 1 END", + "CASE BIGINT '1234' WHEN unbound_long THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 ELSE 1 END"); + + assertFailedMatches("case when unbound_boolean then 1 when 0 / 0 = 0 then 2 end", + "case when unbound_boolean then 1 when cast(fail(8, 'ignored failure message') as boolean) then 2 end"); + + assertFailedMatches("case when unbound_boolean then 1 else 0 / 0 end", + "case when unbound_boolean then 1 else cast(fail(8, 'ignored failure message') as integer) end"); + + assertFailedMatches("case when unbound_boolean then 0 / 0 else 1 end", + "case when unbound_boolean then cast(fail(8, 'ignored failure message') as integer) else 1 end"); + + assertFailedMatches("case true " + + "when unbound_long = 1 then 1 " + + "when 0 / 0 = 0 then 2 " + + "else 33 end", + "case true " + + "when unbound_long = BIGINT '1' then 1 " + + "when CAST(fail(8, 'ignored failure message') AS boolean) then 2 else 33 " + + "end"); + + assertFailedMatches("case 1 " + + "when 0 / 0 then 1 " + + "when 0 / 0 then 2 " + + "else 1 " + + "end", + "case 1 " + + "when cast(fail(8, 'ignored failure message') as integer) then 1 " + + "when cast(fail(8, 'ignored failure message') as integer) then 2 " + + "else 1 " + + "end"); + + assertFailedMatches("case 1 " + + "when unbound_long then 1 " + + "when 0 / 0 then 2 " + + "else 1 " + + "end", + "" + + "case BIGINT '1' " + + "when unbound_long then 1 " + + "when cast(fail(8, 'ignored failure message') AS integer) then 2 " + + "else 1 " + + "end"); + } + + @Test + public void testArrayConstructor() + { + optimize("ARRAY []"); + assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", + "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", + 
"array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", + "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); + } + + @Test + public void testRowConstructor() + { + optimize("ROW(NULL)"); + optimize("ROW(1)"); + optimize("ROW(unbound_long + 0)"); + optimize("ROW(unbound_long + unbound_long2, unbound_string, unbound_double)"); + optimize("ROW(unbound_boolean, FALSE, ARRAY[unbound_long, unbound_long2], unbound_null_string, unbound_interval)"); + optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(unbound_string, 0.0E0)]"); + optimize("ARRAY [ROW('string', unbound_double), ROW('string', bound_double)]"); + optimize("ROW(ROW(NULL), ROW(ROW(ROW(ROW('rowception')))))"); + optimize("ROW(unbound_string, bound_string)"); + + optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(CAST(bound_string AS VARCHAR), 0.0E0)]"); + optimize("ARRAY [ROW(CAST(bound_string AS VARCHAR), 0.0E0), ROW(unbound_string, unbound_double)]"); + + optimize("ARRAY [ROW(unbound_string, unbound_double), CAST(NULL AS ROW(VARCHAR, DOUBLE))]"); + optimize("ARRAY [CAST(NULL AS ROW(VARCHAR, DOUBLE)), ROW(unbound_string, unbound_double)]"); + } + + @Test + public void testDereference() + { + optimize("ARRAY []"); + assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", + "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", + "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", + "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); + } + + @Test + public void testRowDereference() + { + optimize("CAST(null AS ROW(a VARCHAR, b BIGINT)).a"); + } + + @Test + public void testRowSubscript() + { + assertOptimizedEquals("ROW 
(1, 'a', true)[3]", "true"); + assertOptimizedEquals("ROW (1, 'a', ROW (2, 'b', ROW (3, 'c')))[3][3][2]", "'c'"); + } + + @Test + public void testOptimizeDivideByZero() + { + assertThrows(PrestoException.class, () -> optimize("1 / 0")); + } + + @Test + public void testArraySubscriptConstantNegativeIndex() + { + assertThrows(PrestoException.class, () -> optimize("ARRAY [1, 2, 3][-1]")); + } + + @Test + public void testArraySubscriptConstantZeroIndex() + { + assertThrows(PrestoException.class, () -> optimize("ARRAY [1, 2, 3][0]")); + } + + @Test + public void testMapSubscriptMissingKey() + { + assertThrows(PrestoException.class, () -> optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[-1]")); + } + + @Test + public void testMapSubscriptConstantIndexes() + { + optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[1]"); + optimize("MAP(ARRAY [BIGINT '1', 2], ARRAY [3, 4])[1]"); + optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[2]"); + optimize("MAP(ARRAY [ARRAY[1,1]], ARRAY['a'])[ARRAY[1,1]]"); + } + + @Test + public void testLiterals() + { + optimize("date '2013-04-03' + unbound_interval"); + optimize("time '03:04:05.321' + unbound_interval"); + optimize("time '03:04:05.321 UTC' + unbound_interval"); + optimize("timestamp '2013-04-03 03:04:05.321' + unbound_interval"); + optimize("timestamp '2013-04-03 03:04:05.321 UTC' + unbound_interval"); + + optimize("interval '3' day * unbound_long"); + optimize("interval '3' year * unbound_long"); + } + + @Test + public void testVarbinaryLiteral() + { + assertEquals(optimize("X'1234'"), Slices.wrappedBuffer((byte) 0x12, (byte) 0x34)); + } + + public void assertExpressionAndRowExpressionEquals(Object expressionResult, Object rowExpressionResult) + { + if (rowExpressionResult instanceof RowExpression) { + // Cannot be completely evaluated into a constant; compare expressions + assertTrue(expressionResult instanceof Expression); + + // It is tricky to check the equivalence of an expression and a row expression. 
+ // We rely on the optimized translator to fill the gap. + RowExpression translated = TRANSLATOR.translateAndOptimize((Expression) expressionResult, SYMBOL_TYPES); + assertRowExpressionEvaluationEquals(translated, rowExpressionResult); + } + else { + // We have constants; directly compare + assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); + } + } + + public static void assertRowExpressionEvaluationEquals(RowExpression left, RowExpression right) + { + assertTrue(left instanceof RowExpression); + if (left instanceof ConstantExpression) { + if (isRemovableCast(right)) { + assertRowExpressionEvaluationEquals(left, ((CallExpression) right).getArguments().get(0)); + return; + } + assertTrue(right instanceof ConstantExpression); + assertRowExpressionEvaluationEquals(((ConstantExpression) left).getValue(), ((ConstantExpression) left).getValue()); + } + else if (left instanceof InputReferenceExpression || left instanceof VariableReferenceExpression) { + assertEquals(left, right); + } + else if (left instanceof CallExpression) { + assertTrue(right instanceof CallExpression); + assertCallExpressionEvaluationEquals((CallExpression) left, (CallExpression) right); + } + else if (left instanceof SpecialFormExpression) { + assertTrue(right instanceof SpecialFormExpression); + assertSpecialFormExpressionEvaluationEquals((SpecialFormExpression) left, (SpecialFormExpression) right); + } + else { + assertTrue(left instanceof LambdaDefinitionExpression); + assertTrue(right instanceof LambdaDefinitionExpression); + assertLambdaExpressionEvaluationEquals((LambdaDefinitionExpression) left, (LambdaDefinitionExpression) right); + } + } + + /** + * Assert the evaluation result of two row expressions equivalent + * no matter they are constants or remaining row expressions. 
+ */ + public static void assertRowExpressionEvaluationEquals(Object left, Object right) + { + if (right instanceof RowExpression) { + assertRowExpressionEvaluationEquals((RowExpression) left, (RowExpression) right); + } + else { + // We have constants; directly compare + if (left instanceof Block) { + assertTrue(right instanceof Block); + assertEquals(blockToSlice((Block) left), blockToSlice((Block) right)); + } + else { + assertEquals(left, right); + } + } + } + + private static void assertCallExpressionEvaluationEquals(CallExpression left, CallExpression right) + { + assertEquals(left.getFunctionHandle(), right.getFunctionHandle()); + assertEquals(left.getArguments().size(), right.getArguments().size()); + for (int i = 0; i < left.getArguments().size(); i++) { + assertRowExpressionEvaluationEquals(left.getArguments().get(i), right.getArguments().get(i)); + } + } + + private static void assertSpecialFormExpressionEvaluationEquals(SpecialFormExpression left, SpecialFormExpression right) + { + assertEquals(left.getForm(), right.getForm()); + assertEquals(left.getArguments().size(), right.getArguments().size()); + for (int i = 0; i < left.getArguments().size(); i++) { + assertRowExpressionEvaluationEquals(left.getArguments().get(i), right.getArguments().get(i)); + } + } + + private static void assertLambdaExpressionEvaluationEquals(LambdaDefinitionExpression left, LambdaDefinitionExpression right) + { + assertEquals(left.getArguments(), right.getArguments()); + assertEquals(left.getArgumentTypes(), right.getArgumentTypes()); + assertRowExpressionEvaluationEquals(left.getBody(), right.getBody()); + } + + private static boolean isRemovableCast(Object value) + { + if (value instanceof CallExpression && + new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()).isCastFunction(((CallExpression) value).getFunctionHandle())) { + Type targetType = ((CallExpression) value).getType(); + Type sourceType = ((CallExpression) 
value).getArguments().get(0).getType(); + return METADATA.getFunctionAndTypeManager().canCoerce(sourceType, targetType); + } + return false; + } + + public abstract void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected); + + public void assertRoundTrip(String expression) + { + ParsingOptions parsingOptions = createParsingOptions(TEST_SESSION); + assertEquals(SQL_PARSER.createExpression(expression, parsingOptions), + SQL_PARSER.createExpression(formatExpression(SQL_PARSER.createExpression(expression, parsingOptions), Optional.empty()), parsingOptions)); + } + + public static RowExpression toRowExpression(Expression expression) + { + return TRANSLATOR.translate(expression, SYMBOL_TYPES); + } + + public abstract void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected); + + public void assertFailedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + assertOptimizedMatches(actual, expected); + } + + public abstract Object optimize(@Language("SQL") String expression); + + public static Expression expression(String expression) + { + return FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); + } + + public abstract void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOptimizer.Level optimizationLevel); + + public static Object symbolConstant(Symbol symbol) + { + switch (symbol.getName().toLowerCase(ENGLISH)) { + case "bound_integer": + return 1234L; + case "bound_long": + return 1234L; + case "bound_string": + return utf8Slice("hello"); + case "bound_double": + return 12.34; + case "bound_date": + return new LocalDate(2001, 8, 22).toDateMidnight(DateTimeZone.UTC).getMillis(); + case "bound_time": + return new LocalTime(3, 4, 5, 321).toDateTime(new DateTime(0, DateTimeZone.UTC)).getMillis(); + case "bound_timestamp": + return new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis(); + case "bound_pattern": + return 
utf8Slice("%el%"); + case "bound_timestamp_with_timezone": + return new SqlTimestampWithTimeZone(new DateTime(1970, 1, 1, 1, 0, 0, 999, DateTimeZone.UTC).getMillis(), getTimeZoneKey("Z")); + case "bound_varbinary": + return Slices.wrappedBuffer((byte) 0xab); + case "bound_decimal_short": + return 12345L; + case "bound_decimal_long": + return Decimals.encodeUnscaledValue(new BigInteger("12345678901234567890123")); + } + return null; + } + + public static Slice blockToSlice(Block block) + { + // This function is strictly for testing use only + SliceOutput sliceOutput = new DynamicSliceOutput(1000); + BlockSerdeUtil.writeBlock(blockEncodingSerde, sliceOutput, block); + return sliceOutput.slice(); + } + + public abstract void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected); + + public abstract Object evaluate(@Language("SQL") String expression, boolean deterministic); +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java b/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java index 913e0db26adc0..7340e45b42c1d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/TestExpressionOptimizerManager.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.Properties; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.sql.relational.Expressions.constant; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; @@ -65,6 +66,7 @@ public void setUp() manager = new ExpressionOptimizerManager( pluginNodeManager, METADATA.getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)), directory); } diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/sql/gen/TestCursorProcessorCompiler.java b/presto-main-base/src/test/java/com/facebook/presto/sql/gen/TestCursorProcessorCompiler.java index 0e964fdade4e4..fe88d6996de10 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/gen/TestCursorProcessorCompiler.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/gen/TestCursorProcessorCompiler.java @@ -30,11 +30,13 @@ import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.google.common.collect.ImmutableList; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Random; import java.util.function.Supplier; import java.util.stream.IntStream; @@ -63,13 +65,18 @@ import static java.util.Collections.emptyMap; import static java.util.stream.Collectors.toList; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.expectThrows; public class TestCursorProcessorCompiler { private static final Metadata METADATA = createTestMetadataManager(); private static final FunctionAndTypeManager FUNCTION_MANAGER = METADATA.getFunctionAndTypeManager(); + // Constants for testing JVM limits + private static final int CONSTANT_POOL_STRESS_PROJECTION_COUNT = 8000; + private static final CallExpression ADD_X_Y = call( ADD.name(), FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)), @@ -169,6 +176,97 @@ public void testCompilerWithCSE() checkPageEqual(pageFromCSE, pageFromNoCSE); } + @DataProvider(name = "projectionCounts") + public Object[][] projectionCounts() + { + return new Object[][] { + {1}, + {10}, + {1000}, + {1500}, + {5000}, + {6000} + }; + } + + @Test(dataProvider = "projectionCounts") + public void testProjectionBatching(int 
projectionCount) + { + PageFunctionCompiler functionCompiler = new PageFunctionCompiler(METADATA, 0); + ExpressionCompiler expressionCompiler = new ExpressionCompiler(METADATA, functionCompiler); + + List projections = IntStream.range(0, projectionCount) + .mapToObj(i -> field(i % 2, BIGINT)) + .collect(toImmutableList()); + + Supplier cursorProcessorSupplier = expressionCompiler.compileCursorProcessor( + SESSION.getSqlFunctionProperties(), + Optional.empty(), + projections, + "testProjectionBatching_" + projectionCount, + false); + + CursorProcessor processor = cursorProcessorSupplier.get(); + assertNotNull(processor, "CursorProcessor should be created successfully for projectionCount = " + projectionCount); + + Page input = createLongBlockPage(2, 1L, 2L, 3L, 4L, 5L); + List types = ImmutableList.of(BIGINT, BIGINT); + PageBuilder pageBuilder = new PageBuilder(projections.stream().map(RowExpression::getType).collect(toList())); + RecordSet recordSet = new PageRecordSet(types, input); + + processor.process(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), recordSet.cursor(), pageBuilder); + Page result = pageBuilder.build(); + + assertEquals(result.getChannelCount(), projectionCount, "Mismatch in projected column count"); + assertEquals(result.getPositionCount(), input.getPositionCount(), "Mismatch in row count"); + } + + /** + * NEW NEGATIVE TEST CASE: + * This test demonstrates that while we can handle MethodTooLarge exceptions through batching, + * the JVM constant pool size limit remains a constraint when projections contain many unique constants. + * + * This test creates projections with random constants generated via Java code, which fills up + * the constant pool and should still cause compilation failures even with our batching approach. + * This clearly shows the scope of what we're solving (MethodTooLarge) vs. what remains a JVM constraint. 
+ */ + @Test + public void testConstantPoolLimitStillConstrainsLargeProjections() + { + ExpressionCompiler expressionCompiler = new ExpressionCompiler(METADATA, new PageFunctionCompiler(METADATA, 0)); + + List projectionsWithRandomConstants = createProjectionsWithRandomConstants(CONSTANT_POOL_STRESS_PROJECTION_COUNT); + + expectThrows(RuntimeException.class, () -> { + expressionCompiler.compileCursorProcessor( + SESSION.getSqlFunctionProperties(), + Optional.empty(), + projectionsWithRandomConstants, + "testConstantPoolLimit", + false); + }); + } + + /** + * Helper method to create projections with many unique random constants. + * This is designed to stress the JVM constant pool limit. + */ + private List createProjectionsWithRandomConstants(int count) + { + Random random = new Random(42); + return IntStream.range(0, count) + .mapToObj(i -> { + long randomConstant = random.nextLong(); + return call( + ADD.name(), + FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)), + BIGINT, + field(0, BIGINT), + constant(randomConstant, BIGINT)); // Each projection gets a unique constant + }) + .collect(toImmutableList()); + } + private static Page createLongBlockPage(int blockCount, long... 
values) { Block[] blocks = new Block[blockCount]; diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/gen/TestExpressionCompiler.java b/presto-main-base/src/test/java/com/facebook/presto/sql/gen/TestExpressionCompiler.java index eb486ebf02708..c5383b669af60 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/gen/TestExpressionCompiler.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/gen/TestExpressionCompiler.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.log.Logging; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.SqlDecimal; import com.facebook.presto.common.type.SqlTimestampWithTimeZone; @@ -44,7 +45,6 @@ import io.airlift.joni.Regex; import io.airlift.slice.Slice; import io.airlift.slice.Slices; -import io.airlift.units.Duration; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; import org.testng.annotations.AfterClass; diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestDeleteAndInsertMergeProcessor.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestDeleteAndInsertMergeProcessor.java new file mode 100644 index 0000000000000..dac43482b6338 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestDeleteAndInsertMergeProcessor.java @@ -0,0 +1,244 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.ByteArrayBlock; +import com.facebook.presto.common.block.IntArrayBlock; +import com.facebook.presto.common.block.LongArrayBlock; +import com.facebook.presto.common.block.PageBuilderStatus; +import com.facebook.presto.common.block.RowBlock; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.operator.DeleteAndInsertMergeProcessor; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slices; +import org.testng.annotations.Test; + +import java.nio.charset.Charset; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.operator.MergeRowChangeProcessor.DEFAULT_CASE_OPERATION_NUMBER; +import static com.facebook.presto.spi.ConnectorMergeSink.DELETE_OPERATION_NUMBER; +import static com.facebook.presto.spi.ConnectorMergeSink.INSERT_OPERATION_NUMBER; +import static com.facebook.presto.spi.ConnectorMergeSink.UPDATE_OPERATION_NUMBER; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestDeleteAndInsertMergeProcessor +{ + @Test + public void testSimpleDeletedRowMerge() + { + // target: ('Dave', 11, 'Devon'), ('Dave', 11, 'Darbyshire') + // source: ('Dave', 11, 'Darbyshire') + // merge: + // MERGE INTO target t USING source s + // ON t.customer = s.customer" + 
+ // WHEN MATCHED AND t.address <> 'Darbyshire' AND s.purchases * 2 > 20" + + // THEN DELETE + // expected: ('Dave', 11, 'Darbyshire') + DeleteAndInsertMergeProcessor processor = makeMergeProcessor(); + Page inputPage = makePageFromBlocks( + 2, + Optional.empty(), + new Block[] { + makeLongArrayBlock(1, 1), // TransactionId + makeLongArrayBlock(1, 0), // rowId + makeIntArrayBlock(536870912, 536870912)}, // bucket + new Block[] { + makeVarcharArrayBlock("", "Dave"), // customer + makeIntArrayBlock(0, 11), // purchases + makeVarcharArrayBlock("", "Devon"), // address + makeByteArrayBlock(1, 1), // "present" boolean + makeByteArrayBlock(DEFAULT_CASE_OPERATION_NUMBER, DELETE_OPERATION_NUMBER), + makeIntArrayBlock(-1, 0)}); + + Page outputPage = processor.transformPage(inputPage); + assertThat(outputPage.getPositionCount()).isEqualTo(1); + + // The single operation is a delete + assertThat(TINYINT.getLong(outputPage.getBlock(3), 0)).isEqualTo(DELETE_OPERATION_NUMBER); + + // Show that the row to be deleted is rowId 0, e.g. 
('Dave', 11, 'Devon') + Block rowIdRow = outputPage.getBlock(4).getBlock(0); + assertThat(INTEGER.getLong(rowIdRow, 1)).isEqualTo(0); + } + + @Test + public void testUpdateAndDeletedMerge() + { + // target: ('Aaron', 11, 'Arches'), ('Bill', 7, 'Buena'), ('Dave', 11, 'Darbyshire'), ('Dave', 11, 'Devon'), ('Ed', 7, 'Etherville') + // source: ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville') + // merge: + // MERGE INTO target t USING source s + // ON t.customer = s.customer" + + // WHEN MATCHED AND t.address <> 'Darbyshire' AND s.purchases * 2 > 20 + // THEN DELETE" + + // WHEN MATCHED" + + // THEN UPDATE SET purchases = s.purchases + t.purchases, address = concat(t.address, '/', s.address)" + + // WHEN NOT MATCHED" + + // THEN INSERT (customer, purchases, address) VALUES(s.customer, s.purchases, s.address) + // expected: ('Aaron', 17, 'Arches/Arches'), ('Bill', 7, 'Buena'), ('Carol', 9, 'Centreville'), ('Dave', 22, 'Darbyshire/Darbyshire'), ('Ed', 14, 'Etherville/Etherville'), ('Fred', 30, 'Franklin') + DeleteAndInsertMergeProcessor processor = makeMergeProcessor(); + boolean[] rowIdNulls = new boolean[] {false, true, false, false, false}; + Page inputPage = makePageFromBlocks( + 5, + Optional.of(rowIdNulls), + new Block[] { + makeLongArrayBlockWithNulls(rowIdNulls, 5, 2, 1, 2, 2), // TransactionId + makeLongArrayBlockWithNulls(rowIdNulls, 5, 0, 3, 1, 2), // rowId + makeIntArrayBlockWithNulls(rowIdNulls, 5, 536870912, 536870912, 536870912, 536870912)}, // bucket + new Block[] { + // customer + makeVarcharArrayBlock("Aaron", "Carol", "Dave", "Dave", "Ed"), + // purchases + makeIntArrayBlock(17, 9, 11, 22, 14), + // address + makeVarcharArrayBlock("Arches/Arches", "Centreville", "Devon", "Darbyshire/Darbyshire", "Etherville/Etherville"), + // "present" boolean + makeByteArrayBlock(1, 0, 1, 1, 1), + // operation number: update, insert, delete, update + makeByteArrayBlock(UPDATE_OPERATION_NUMBER, 
INSERT_OPERATION_NUMBER, DELETE_OPERATION_NUMBER, UPDATE_OPERATION_NUMBER, UPDATE_OPERATION_NUMBER), + makeIntArrayBlock(0, 1, 2, 0, 0)}); + + Page outputPage = processor.transformPage(inputPage); + assertThat(outputPage.getPositionCount()).isEqualTo(8); + RowBlock rowIdBlock = (RowBlock) outputPage.getBlock(4); + assertThat(rowIdBlock.getPositionCount()).isEqualTo(8); + // Show that the first row has address "Arches" + assertThat(getString(outputPage.getBlock(2), 1)).isEqualTo("Arches/Arches"); + } + + @Test + public void testAnotherMergeCase() + { + /* + inputPage: Page[positions=5 + 0:Row[0:Long[2, 1, 2, 2], 1:Long[0, 3, 1, 2], 2:Int[536870912, 536870912, 536870912, 536870912]], + 1:Row[0:VarWidth["Aaron", "Carol", "Dave", "Dave", "Ed"], 1:Int[17, 9, 11, 22, 14], 2:VarWidth["Arches/Arches", "Centreville", "Devon", "Darbyshire/Darbyshir...", "Etherville/Ethervill..."], 3:Int[1, 2, 0, 1, 1], 4:Int[3, 1, 2, 3, 3]]] +Page[positions=8 0:Dict[VarWidth["Aaron", "Dave", "Dave", "Ed", "Aaron", "Carol", "Dave", "Ed"]], 1:Dict[Int[17, 11, 22, 14, 17, 9, 22, 14]], 2:Dict[VarWidth["Arches/Arches", "Devon", "Darbyshire/Darbyshir...", "Etherville/Ethervill...", "Arches/Arches", "Centreville", "Darbyshire/Darbyshir...", "Etherville/Ethervill..."]], 3:Int[2, 2, 2, 2, 1, 1, 1, 1], 4:Row[0:Dict[Long[2, 1, 2, 2, 2, 2, 2, 2]], 1:Dict[Long[0, 3, 1, 2, 0, 0, 0, 0]], 2:Dict[Int[536870912, 536870912, 536870912, 536870912, 536870912, 536870912, 536870912, 536870912]]]] + Expected row count to be <5>, but was <7>; rows=[[Bill, 7, Buena], [Dave, 11, Devon], [Aaron, 11, Arches], [Aaron, 17, Arches/Arches], [Carol, 9, Centreville], [Dave, 22, Darbyshire/Darbyshire], [Ed, 14, Etherville/Etherville]] + */ + DeleteAndInsertMergeProcessor processor = makeMergeProcessor(); + boolean[] rowIdNulls = new boolean[] {false, true, false, false, false}; + Page inputPage = makePageFromBlocks( + 5, + Optional.of(rowIdNulls), + new Block[] { + makeLongArrayBlockWithNulls(rowIdNulls, 5, 2, 1, 2, 2), // 
TransactionId + makeLongArrayBlockWithNulls(rowIdNulls, 5, 0, 3, 1, 2), // rowId + makeIntArrayBlockWithNulls(rowIdNulls, 5, 536870912, 536870912, 536870912, 536870912)}, // bucket + new Block[] { + // customer + makeVarcharArrayBlock("Aaron", "Carol", "Dave", "Dave", "Ed"), + // purchases + makeIntArrayBlock(17, 9, 11, 22, 14), + // address + makeVarcharArrayBlock("Arches/Arches", "Centreville", "Devon", "Darbyshire/Darbyshire", "Etherville/Etherville"), + // "present" boolean + makeByteArrayBlock(1, 0, 1, 1, 0), + // operation number: update, insert, delete, update, update + makeByteArrayBlock(3, 1, 2, 3, 3), + makeIntArrayBlock(0, -1, 1, 0, 0)}); + + Page outputPage = processor.transformPage(inputPage); + assertThat(outputPage.getPositionCount()).isEqualTo(8); + RowBlock rowIdBlock = (RowBlock) outputPage.getBlock(4); + assertThat(rowIdBlock.getPositionCount()).isEqualTo(8); + // Show that the first row has address "Arches/Arches" + assertThat(getString(outputPage.getBlock(2), 1)).isEqualTo("Arches/Arches"); + } + + private Page makePageFromBlocks(int positionCount, Optional rowIdNulls, Block[] rowIdBlocks, Block[] mergeCaseBlocks) + { + Block[] pageBlocks = new Block[] { + RowBlock.fromFieldBlocks(positionCount, rowIdNulls, rowIdBlocks), + RowBlock.fromFieldBlocks(positionCount, Optional.empty(), mergeCaseBlocks) + }; + return new Page(pageBlocks); + } + + private DeleteAndInsertMergeProcessor makeMergeProcessor() + { + // CREATE TABLE (customer VARCHAR, purchases INTEGER, address VARCHAR) + List types = ImmutableList.of(VARCHAR, INTEGER, VARCHAR); + + RowType rowIdType = RowType.anonymous(ImmutableList.of(BIGINT, BIGINT, INTEGER)); + return new DeleteAndInsertMergeProcessor(types, rowIdType, 0, 1, ImmutableList.of(0, 1, 2)); + } + + private String getString(Block block, int position) + { + return VARBINARY.getSlice(block, position).toString(Charset.defaultCharset()); + } + + private LongArrayBlock makeLongArrayBlock(long... 
elements) + { + return new LongArrayBlock(elements.length, Optional.empty(), elements); + } + + private LongArrayBlock makeLongArrayBlockWithNulls(boolean[] nulls, int positionCount, long... elements) + { + assertThat(countNonNull(nulls) + elements.length).isEqualTo(positionCount); + return new LongArrayBlock(elements.length, Optional.of(nulls), elements); + } + + private IntArrayBlock makeIntArrayBlock(int... elements) + { + return new IntArrayBlock(elements.length, Optional.empty(), elements); + } + + private IntArrayBlock makeIntArrayBlockWithNulls(boolean[] nulls, int positionCount, int... elements) + { + assertThat(countNonNull(nulls) + elements.length).isEqualTo(positionCount); + return new IntArrayBlock(elements.length, Optional.of(nulls), elements); + } + + private int countNonNull(boolean[] nulls) + { + int count = 0; + for (int position = 0; position < nulls.length; position++) { + if (nulls[position]) { + count++; + } + } + return count; + } + + private ByteArrayBlock makeByteArrayBlock(int... elements) + { + byte[] bytes = new byte[elements.length]; + for (int index = 0; index < elements.length; index++) { + bytes[index] = (byte) elements[index]; + } + return new ByteArrayBlock(elements.length, Optional.empty(), bytes); + } + + private Block makeVarcharArrayBlock(String... 
elements) + { + BlockBuilder builder = VARCHAR.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), elements.length); + for (String element : elements) { + VARCHAR.writeSlice(builder, Slices.utf8Slice(element)); + } + return builder.build(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java index 5422304b52935..9789e8bda4652 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLocalExecutionPlanner.java @@ -190,6 +190,7 @@ private LocalExecutionPlan getLocalExecutionPlan(Session session, PlanNode plan, SOURCE_DISTRIBUTION, ImmutableList.of(new PlanNodeId("sourceId")), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of()), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java index 5cffdb0dcd22d..2e8684cb05da0 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.Session; import com.facebook.presto.execution.TaskManagerConfig; import com.facebook.presto.spi.WarningCollector; @@ -23,9 +24,10 @@ import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; import org.testng.annotations.Test; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SystemSessionProperties.ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID; import static com.facebook.presto.SystemSessionProperties.MERGE_AGGREGATIONS_WITH_AND_WITHOUT_FILTER; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -37,8 +39,6 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; -import static io.airlift.units.DataSize.Unit.KILOBYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; public class TestLogicalAddExchangesBelowPartialAggregationOverGroupIdRuleSet extends BasePlanTest diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index 13385eb45b2e5..1ff85caca45a9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -15,22 +15,38 @@ import com.facebook.presto.Session; import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.execution.TestingPageSourceProvider; import com.facebook.presto.functionNamespace.FunctionNamespaceManagerPlugin; import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory; +import com.facebook.presto.spi.ConnectorHandleResolver; import com.facebook.presto.spi.PrestoException; +import 
com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.SemiJoinNode; import com.facebook.presto.spi.plan.SortNode; +import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.procedure.BaseProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure; +import com.facebook.presto.spi.procedure.DistributedProcedure.Argument; +import com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.transaction.IsolationLevel; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.ExpressionMatcher; @@ -39,13 +55,17 @@ import com.facebook.presto.sql.planner.optimizations.AddLocalExchanges; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.plan.ApplyNode; +import 
com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.testing.TestProcedureRegistry; +import com.facebook.presto.testing.TestingHandleResolver; +import com.facebook.presto.testing.TestingMetadata; +import com.facebook.presto.testing.TestingSplitManager; import com.facebook.presto.tests.QueryTemplate; import com.facebook.presto.util.MorePredicates; import com.google.common.collect.ImmutableList; @@ -53,10 +73,15 @@ import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Consumer; import java.util.function.Predicate; +import java.util.stream.Collectors; import static com.facebook.presto.SystemSessionProperties.DISTRIBUTED_SORT; import static com.facebook.presto.SystemSessionProperties.ENFORCE_FIXED_DISTRIBUTION_FOR_OUTPUT_OPERATOR; @@ -66,8 +91,10 @@ import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static com.facebook.presto.SystemSessionProperties.LEAF_NODE_LIMIT_ENABLED; import static com.facebook.presto.SystemSessionProperties.MAX_LEAF_NODES_IN_PLAN; +import static com.facebook.presto.SystemSessionProperties.NATIVE_EXECUTION_ENABLED; import static com.facebook.presto.SystemSessionProperties.OFFSET_CLAUSE_ENABLED; import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION; +import static com.facebook.presto.SystemSessionProperties.PREFER_SORT_MERGE_JOIN; import 
static com.facebook.presto.SystemSessionProperties.PUSH_REMOTE_EXCHANGE_THROUGH_GROUP_ID; import static com.facebook.presto.SystemSessionProperties.REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT; import static com.facebook.presto.SystemSessionProperties.SIMPLIFY_PLAN_WITH_EMPTY_INPUT; @@ -76,6 +103,7 @@ import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; import static com.facebook.presto.common.predicate.Domain.singleValue; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.StandardTypes.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.spi.StandardErrorCode.INVALID_LIMIT_CLAUSE; import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL; @@ -86,6 +114,8 @@ import static com.facebook.presto.spi.plan.JoinType.INNER; import static com.facebook.presto.spi.plan.JoinType.LEFT; import static com.facebook.presto.spi.plan.JoinType.RIGHT; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.SCHEMA; +import static com.facebook.presto.spi.procedure.TableDataRewriteDistributedProcedure.TABLE_NAME; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED; import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED; import static com.facebook.presto.sql.TestExpressionInterpreter.AVG_UDAF_CPP; @@ -108,6 +138,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.limit; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.markDistinct; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.mergeJoin; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; @@ -129,6 +160,7 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE; +import static com.facebook.presto.sql.tree.SortItem.NullOrdering.FIRST; import static com.facebook.presto.sql.tree.SortItem.NullOrdering.LAST; import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; @@ -151,8 +183,107 @@ public class TestLogicalPlanner public void setup() { setupJsonFunctionNamespaceManager(this.getQueryRunner()); + + // Register catalog `test` with a distributed procedure `distributed_fun` + this.getQueryRunner().createCatalog("test", + new ConnectorFactory() + { + @Override + public String getName() + { + return "test"; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new TestingHandleResolver(); + } + + @Override + public Connector create(String catalogName, Map config, ConnectorContext context) + { + List arguments = new ArrayList<>(); + arguments.add(new Argument(SCHEMA, VARCHAR)); + arguments.add(new Argument(TABLE_NAME, VARCHAR)); + Set> procedures = new HashSet<>(); + procedures.add(new TableDataRewriteDistributedProcedure("system", "distributed_fun", + arguments, + (session, transactionContext, procedureHandle, fragments, sortOrderIndex) -> null, + (session, transactionContext, procedureHandle, fragments) -> {}, + ignored -> new TestProcedureRegistry.TestProcedureContext())); + + return new Connector() + { + private final ConnectorMetadata metadata = new TestingMetadata(); + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return new ConnectorTransactionHandle() + {}; + } + + @Override + public ConnectorPageSourceProvider 
getPageSourceProvider() + { + return new TestingPageSourceProvider(); + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transaction) + { + return metadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return new TestingSplitManager(ImmutableList.of()); + } + + @Override + public Set getDistributedProcedures() + { + return procedures.stream().filter(DistributedProcedure.class::isInstance) + .map(DistributedProcedure.class::cast) + .collect(Collectors.toSet()); + } + }; + } + }, ImmutableMap.of()); } + @Test + public void testCallDistributedProcedure() + { + Session session = getQueryRunner().getDefaultSession(); + + // Call non-existed distributed procedure + assertPlanFailedWithException("call test.system.no_fun('a', 'b')", session, + format("Distributed procedure not registered: test.system.no_fun", "test", "system", "no_fun")); + + // Call distributed procedure on non-existed target table + assertPlanFailedWithException("call test.system.distributed_fun('tiny', 'notable')", session, + format("Table %s.%s.%s does not exist", session.getCatalog().get(), "tiny", "notable")); + + // Call distributed procedure on partitioned target table + assertDistributedPlan("call test.system.distributed_fun('tiny', 'orders')", + anyTree(node(TableFinishNode.class, + exchange(REMOTE_STREAMING, GATHER, + node(CallDistributedProcedureNode.class, + exchange(LOCAL, GATHER, + tableScan("orders"))))))); + + // Call distributed procedure on unPartitioned target table + assertDistributedPlan("call test.system.distributed_fun('tiny', 'customer')", + anyTree(node(TableFinishNode.class, + exchange(REMOTE_STREAMING, GATHER, + node(CallDistributedProcedureNode.class, + exchange(LOCAL, GATHER, + exchange(REMOTE_STREAMING, REPARTITION, + tableScan("customer")))))))); + } @Test public void testAnalyze() { @@ -526,6 +657,60 @@ public void testJoinWithOrderBySameKey() tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))))); 
} + @Test + public void testSortMergeJoin() + { + Session preferSortMergeJoin = Session.builder(noJoinReordering()) + .setSystemProperty(NATIVE_EXECUTION_ENABLED, "true") + .setSystemProperty(PREFER_SORT_MERGE_JOIN, "true") + .setSystemProperty(DISTRIBUTED_SORT, "false") + .build(); + + // Both sides are not sorted. + assertPlan("SELECT o.orderkey FROM orders o INNER JOIN lineitem l ON o.custkey = l.partkey", + preferSortMergeJoin, + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("ORDERS_CK", "LINEITEM_PK")), Optional.empty(), + sort( + ImmutableList.of(sort("ORDERS_CK", ASCENDING, FIRST)), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("orders", ImmutableMap.of("ORDERS_CK", "custkey")))), + sort( + ImmutableList.of(sort("LINEITEM_PK", ASCENDING, FIRST)), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("lineitem", ImmutableMap.of("LINEITEM_PK", "partkey"))))))); + + // Left side is sorted. + assertPlan("SELECT o.orderkey FROM orders o INNER JOIN lineitem l ON o.orderkey = l.partkey", + preferSortMergeJoin, + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("ORDERS_OK", "LINEITEM_PK")), Optional.empty(), + tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")), + sort( + ImmutableList.of(sort("LINEITEM_PK", ASCENDING, FIRST)), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("lineitem", ImmutableMap.of("LINEITEM_PK", "partkey"))))))); + + // Right side is sorted. + assertPlan("SELECT o.orderkey FROM orders o INNER JOIN lineitem l ON o.custkey = l.orderkey", + preferSortMergeJoin, + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("ORDERS_CK", "LINEITEM_OK")), Optional.empty(), + sort( + ImmutableList.of(sort("ORDERS_CK", ASCENDING, FIRST)), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("orders", ImmutableMap.of("ORDERS_CK", "custkey")))), + tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))); + + // Both sides are sorted. 
+ assertPlan("SELECT o.orderkey FROM orders o INNER JOIN lineitem l ON o.orderkey = l.orderkey", + preferSortMergeJoin, + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("ORDERS_OK", "LINEITEM_OK")), Optional.empty(), + tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey")), + tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey"))))); + } + @Test public void testUncorrelatedSubqueries() { @@ -1589,6 +1774,45 @@ public void testOffset() .withAlias("row_num", new RowNumberSymbolMatcher()))))); } + @Test + public void testOffsetWithLimit() + { + Session enableOffsetWithConcurrency = Session.builder(this.getQueryRunner().getDefaultSession()) + .setSystemProperty(OFFSET_CLAUSE_ENABLED, "true") + .setSystemProperty("task_concurrency", "2") // task_concurrency > 1 required to add possible local exchanges that fail this test for incorrect AddLocalExchanges + .build(); + + assertPlanWithSession("SELECT totalprice FROM orders ORDER BY totalprice OFFSET 1 LIMIT 512", + enableOffsetWithConcurrency, + false, + any( + strictProject( + ImmutableMap.of("totalprice", new ExpressionMatcher("totalprice")), + limit( + 512, + filter( + "row_num > BIGINT '1'", + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of()), + anyTree( + sort( + ImmutableList.of(sort("totalprice", ASCENDING, LAST)), + any( + tableScan("orders", ImmutableMap.of("totalprice", "totalprice")))))) + .withAlias("row_num", new RowNumberSymbolMatcher())))))); + } + + @Test + public void testRewriteExcludeColumnsFunctionToProjection() + { + assertPlan("SELECT *\n" + + "FROM TABLE(system.builtin.exclude_columns(\n" + + " INPUT => TABLE(orders),\n" + + " COLUMNS => DESCRIPTOR(comment)))\n", + output(tableScan("orders"))); + } + private Session noJoinReordering() { return Session.builder(this.getQueryRunner().getDefaultSession()) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java 
b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java new file mode 100644 index 0000000000000..6d236432e1d15 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java @@ -0,0 +1,272 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.connector.tvf.TestTVFConnectorFactory; +import com.facebook.presto.connector.tvf.TestTVFConnectorPlugin; +import com.facebook.presto.connector.tvf.TestingTableFunctions; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DescriptorArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DifferentArgumentTypesFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TestingTableFunctionHandle; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoScalarArgumentsFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoTableArgumentsFunction; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.Descriptor.Field; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.assertions.RowNumberSymbolMatcher; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import 
com.facebook.presto.sql.tree.LongLiteral; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.sql.Optimizer.PlanStage.CREATED; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictOutput; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunction; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.descriptorArgument; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.nullDescriptor; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.TableArgumentValue.Builder.tableArgument; + +public 
class TestTableFunctionInvocation + extends BasePlanTest +{ + private static final String TESTING_CATALOG = "test"; + + @BeforeClass + public final void setup() + { + getQueryRunner().installPlugin(new TestTVFConnectorPlugin(TestTVFConnectorFactory.builder() + .withTableFunctions(ImmutableSet.of( + new DifferentArgumentTypesFunction(), + new TwoScalarArgumentsFunction(), + new TwoTableArgumentsFunction(), + new DescriptorArgumentFunction(), + new TestingTableFunctions.PassThroughFunction())) + .withApplyTableFunction((session, handle) -> { + if (handle instanceof TestingTableFunctionHandle) { + TestingTableFunctionHandle functionHandle = (TestingTableFunctionHandle) handle; + return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow(() -> new IllegalStateException("Missing columns")))); + } + throw new IllegalStateException("Unsupported table function handle: " + handle.getClass().getSimpleName()); + }) + .build())); + getQueryRunner().createCatalog(TESTING_CATALOG, "testTVF", ImmutableMap.of()); + } + + @Test + public void testTableFunctionInitialPlan() + { + assertPlan( + "SELECT * FROM TABLE(test.system.different_arguments_function(" + + "INPUT_1 => TABLE(SELECT 'a') t1(c1) PARTITION BY c1 ORDER BY c1," + + "INPUT_3 => TABLE(SELECT 'b') t3(c3) PARTITION BY c3," + + "INPUT_2 => TABLE(VALUES 1) t2(c2)," + + "ID => BIGINT '2001'," + + "LAYOUT => DESCRIPTOR (x boolean, y bigint)" + + "COPARTITION (t1, t3))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("different_arguments_function") + .addTableArgument( + "INPUT_1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1"), ImmutableList.of("c1"), ImmutableMap.of("c1", ASC_NULLS_LAST))) + .passThroughVariables(ImmutableSet.of("c1")) + .passThroughColumns()) + .addTableArgument( + "INPUT_3", + tableArgument(2) + .specification(specification(ImmutableList.of("c3"), ImmutableList.of(), 
ImmutableMap.of())) + .pruneWhenEmpty() + .passThroughVariables(ImmutableSet.of("c3"))) + .addTableArgument( + "INPUT_2", + tableArgument(1) + .rowSemantics() + .passThroughVariables(ImmutableSet.of("c2")) + .passThroughColumns()) + .addScalarArgument("ID", 2001L) + .addDescriptorArgument( + "LAYOUT", + descriptorArgument(new Descriptor(ImmutableList.of( + new Field("X", Optional.of(BOOLEAN)), + new Field("Y", Optional.of(BIGINT)))))) + .addCopartitioning(ImmutableList.of("INPUT_1", "INPUT_3")) + .properOutputs(ImmutableList.of("OUTPUT")), + anyTree(project(ImmutableMap.of("c1", expression("'a'")), values(1))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("1"))))), + anyTree(project(ImmutableMap.of("c3", expression("'b'")), values(1)))))); + } + + @Test + public void testTableFunctionInitialPlanWithCoercionForCopartitioning() + { + assertPlan("SELECT * FROM TABLE(test.system.two_table_arguments_function(" + + "INPUT1 => TABLE(VALUES SMALLINT '1') t1(c1) PARTITION BY c1," + + "INPUT2 => TABLE(VALUES INTEGER '2') t2(c2) PARTITION BY c2 " + + "COPARTITION (t1, t2))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_table_arguments_function") + .addTableArgument( + "INPUT1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1_coerced"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c1"))) + .addTableArgument( + "INPUT2", + tableArgument(1) + .specification(specification(ImmutableList.of("c2"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c2"))) + .addCopartitioning(ImmutableList.of("INPUT1", "INPUT2")) + .properOutputs(ImmutableList.of("COLUMN")), + project(ImmutableMap.of("c1_coerced", expression("CAST(c1 AS INTEGER)")), + anyTree(values(ImmutableList.of("c1"), ImmutableList.of(ImmutableList.of(new LongLiteral("1")))))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new 
LongLiteral("2")))))))); + } + + @Test + public void testNullScalarArgument() + { + // the argument NUMBER has null default value + assertPlan( + " SELECT * FROM TABLE(test.system.two_scalar_arguments_function(TEXT => null))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_scalar_arguments_function") + .addScalarArgument("TEXT", null) + .addScalarArgument("NUMBER", null) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testNullDescriptorArgument() + { + assertPlan( + " SELECT * FROM TABLE(test.system.descriptor_argument_function(SCHEMA => CAST(null AS DESCRIPTOR)))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + + // the argument SCHEMA has null default value + assertPlan( + " SELECT * FROM TABLE(test.system.descriptor_argument_function())", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testPruneTableFunctionColumns() + { + // all table function outputs are referenced with SELECT *, no pruning + assertPlan("SELECT * FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true) t(a, b)))", + strictOutput( + ImmutableList.of("x", "a", "b"), + tableFunctionProcessor( + builder -> builder + .name("pass_through_function") + .properOutputs(ImmutableList.of("x")) + .passThroughSymbols( + ImmutableList.of(ImmutableList.of("a", "b"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("a"))) + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())), + project(ImmutableMap.of("a", expression("INTEGER'1'"), "b", expression("BOOLEAN'true'")), values(1))))); + + // no table function outputs are referenced. 
All pass-through symbols are pruned from the TableFunctionProcessorNode. The unused symbol "b" is pruned from the source values node. + assertPlan("SELECT 'constant' c FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true) t(a, b)))", + strictOutput( + ImmutableList.of("c"), + strictProject( + ImmutableMap.of("c", expression("VARCHAR'constant'")), + tableFunctionProcessor( + builder -> builder + .name("pass_through_function") + .properOutputs(ImmutableList.of("x")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("a"))) + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())), + project(ImmutableMap.of("a", expression("INTEGER'1'")), values(1)))))); + } + + @Test + public void testRemoveRedundantTableFunction() + { + assertPlan("SELECT * FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true WHERE false) t(a, b) PRUNE WHEN EMPTY))", + output(values(ImmutableList.of("x", "a", "b")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false) t2(c, d) KEEP WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) PRUNE WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) KEEP WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM 
TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) KEEP WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) KEEP WHEN EMPTY))\n", + output( + node(TableFunctionProcessorNode.class, + values(ImmutableList.of("a", "marker_1", "c", "marker_2", "row_number"))))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) KEEP WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false) t2(c, d) PRUNE WHEN EMPTY))\n", + output( + node(TableFunctionProcessorNode.class, + project( + project( + rowNumber( + builder -> builder.partitionBy(ImmutableList.of()), + project( + ImmutableMap.of("c", expression("INTEGER'2'")), + values(1)) + ).withAlias("input_2_row_number", new RowNumberSymbolMatcher())))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestingWriterTarget.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestingWriterTarget.java index 863047690952b..efb9cfdccd6ec 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestingWriterTarget.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestingWriterTarget.java @@ -14,9 +14,17 @@ package com.facebook.presto.sql.planner; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.SourceColumn; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.facebook.presto.spi.plan.TableWriterNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Optional; public class TestingWriterTarget extends TableWriterNode.WriterTarget @@ -36,6 +44,17 @@ public SchemaTableName getSchemaTableName() return SCHEMA_TABLE_NAME; } + @Override + 
public Optional> getOutputColumns() + { + return Optional.of( + ImmutableList.of( + new OutputColumnMetadata( + "column", "type", + ImmutableSet.of( + new SourceColumn(QualifiedObjectName.valueOf("catalog.schema.table"), "column"))))); + } + @Override public String toString() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java index fa61a6354e26a..19621f3ce1f97 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/AggregationFunctionMatcher.java @@ -117,8 +117,8 @@ private static boolean verifyAggregationOrderBy(OrderingScheme orderingScheme, O private static boolean isEquivalent(Optional expression, Optional rowExpression) { // Function's argument provided by FunctionCallProvider is SymbolReference that already resolved from symbolAliases. 
- if (rowExpression.isPresent() && expression.isPresent()) { - checkArgument(rowExpression.get() instanceof VariableReferenceExpression, "can only process variableReference"); + if (rowExpression.isPresent() && expression.isPresent() && !(expression.get() instanceof AnySymbolReference)) { + checkArgument(rowExpression.get() instanceof VariableReferenceExpression, "can only process variableReference: " + rowExpression.get()); return expression.get().equals(createSymbolReference(((VariableReferenceExpression) rowExpression.get()))); } return rowExpression.isPresent() == expression.isPresent(); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java index 45deffb292fa7..f090c3e1e897a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java @@ -226,11 +226,21 @@ protected void assertDistributedPlan(String sql, PlanMatchPattern pattern) assertDistributedPlan(sql, getQueryRunner().getDefaultSession(), pattern); } + protected void assertNativeDistributedPlan(String sql, PlanMatchPattern pattern) + { + assertNativeDistributedPlan(sql, getQueryRunner().getDefaultSession(), pattern); + } + protected void assertDistributedPlan(String sql, Session session, PlanMatchPattern pattern) { assertPlanWithSession(sql, session, false, pattern); } + protected void assertNativeDistributedPlan(String sql, Session session, PlanMatchPattern pattern) + { + assertPlanWithSession(sql, session, false, true, pattern); + } + protected void assertMinimallyOptimizedPlan(@Language("SQL") String sql, PlanMatchPattern pattern) { List optimizers = ImmutableList.of( @@ -262,9 +272,14 @@ protected void assertMinimallyOptimizedPlanDoesNotMatch(@Language("SQL") String } protected void assertPlanWithSession(@Language("SQL") 
String sql, Session session, boolean noExchange, PlanMatchPattern pattern) + { + assertPlanWithSession(sql, session, noExchange, false, pattern); + } + + protected void assertPlanWithSession(@Language("SQL") String sql, Session session, boolean noExchange, boolean nativeExecutionEnabled, PlanMatchPattern pattern) { queryRunner.inTransaction(session, transactionSession -> { - Plan actualPlan = queryRunner.createPlan(transactionSession, sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, noExchange, WarningCollector.NOOP); + Plan actualPlan = queryRunner.createPlan(transactionSession, sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, noExchange, nativeExecutionEnabled, WarningCollector.NOOP); PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getStatsCalculator(), actualPlan, pattern); return null; }); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java index cea4abb4e0787..a2f535e4d5c2a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/OptimizerAssert.java @@ -23,9 +23,11 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.sql.Optimizer; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.TypeProvider; @@ -49,6 +51,7 @@ import java.util.function.Consumer; import 
java.util.function.Function; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlanDoesNotMatch; import static com.facebook.presto.transaction.TransactionBuilder.transaction; @@ -177,7 +180,8 @@ private List getMinimalOptimizers() metadata, new ExpressionOptimizerManager( new PluginNodeManager(new InMemoryNodeManager()), - queryRunner.getFunctionAndTypeManager())).rules())); + queryRunner.getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class)))).rules())); } private void inTransaction(Function transactionSessionConsumer) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 7ae7d35b94670..7b41ae0b02bf6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -27,6 +27,7 @@ import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinDistributionType; @@ -42,9 +43,11 @@ import com.facebook.presto.spi.plan.SemiJoinNode; import com.facebook.presto.spi.plan.SortNode; import com.facebook.presto.spi.plan.SpatialJoinNode; +import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableWriterNode; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import 
com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.plan.WindowNode.Frame.BoundType; @@ -58,15 +61,14 @@ import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.CallDistributedProcedureNode; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.OffsetNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.SequenceNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; @@ -187,6 +189,22 @@ public static PlanMatchPattern indexSource(String expectedTableName) .with(new IndexSourceMatcher(expectedTableName)); } + public static PlanMatchPattern indexSource(String expectedTableName, Map columnReferences) + { + return node(IndexSourceNode.class) + .with(new IndexSourceMatcher(expectedTableName)) + .addColumnReferences(expectedTableName, columnReferences); + } + + public static PlanMatchPattern strictIndexSource(String expectedTableName, Map columnReferences) + { + return node(IndexSourceNode.class) + .with(new IndexSourceMatcher(expectedTableName)) + .withExactAssignedOutputs(columnReferences.values().stream() + .map(columnName -> columnReference(expectedTableName, columnName)) + .collect(toImmutableList())); + } + public static PlanMatchPattern constrainedIndexSource(String expectedTableName, Map constraint, Map columnReferences) { return 
node(IndexSourceNode.class) @@ -396,6 +414,11 @@ public static PlanMatchPattern strictProject(Map assi .withExactAssignments(assignments.values()); } + public static PlanMatchPattern semiJoin(PlanMatchPattern source, PlanMatchPattern filtering) + { + return node(SemiJoinNode.class, source, filtering); + } + public static PlanMatchPattern semiJoin(String sourceSymbolAlias, String filteringSymbolAlias, String outputAlias, PlanMatchPattern source, PlanMatchPattern filtering) { return semiJoin(sourceSymbolAlias, filteringSymbolAlias, outputAlias, Optional.empty(), source, filtering); @@ -486,13 +509,13 @@ public static PlanMatchPattern spatialJoin(String expectedFilter, PlanMatchPatte public static PlanMatchPattern spatialJoin(String expectedFilter, Optional kdbTree, PlanMatchPattern left, PlanMatchPattern right) { return node(SpatialJoinNode.class, left, right).with( - new SpatialJoinMatcher(SpatialJoinNode.Type.INNER, rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expectedFilter, new ParsingOptions())), kdbTree)); + new SpatialJoinMatcher(SpatialJoinNode.SpatialJoinType.INNER, rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expectedFilter, new ParsingOptions())), kdbTree)); } public static PlanMatchPattern spatialLeftJoin(String expectedFilter, PlanMatchPattern left, PlanMatchPattern right) { return node(SpatialJoinNode.class, left, right).with( - new SpatialJoinMatcher(SpatialJoinNode.Type.LEFT, rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expectedFilter, new ParsingOptions())), Optional.empty())); + new SpatialJoinMatcher(SpatialJoinNode.SpatialJoinType.LEFT, rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expectedFilter, new ParsingOptions())), Optional.empty())); } public static PlanMatchPattern mergeJoin(JoinType joinType, List> expectedEquiCriteria, Optional filter, PlanMatchPattern left, PlanMatchPattern right) @@ -635,6 +658,11 @@ public static PlanMatchPattern 
values(Map aliasToIndex) return values(aliasToIndex, Optional.empty(), Optional.empty()); } + public static PlanMatchPattern values(int rowCount) + { + return values(ImmutableList.of(), nCopies(rowCount, ImmutableList.of())); + } + public static PlanMatchPattern values(String... aliases) { return values(ImmutableList.copyOf(aliases)); @@ -670,6 +698,16 @@ public static PlanMatchPattern enforceSingleRow(PlanMatchPattern source) return node(EnforceSingleRowNode.class, source); } + public static PlanMatchPattern callDistributedProcedure(PlanMatchPattern source) + { + return node(CallDistributedProcedureNode.class, source); + } + + public static PlanMatchPattern tableFinish(PlanMatchPattern source) + { + return node(TableFinishNode.class, source); + } + public static PlanMatchPattern tableWriter(List columns, List columnNames, PlanMatchPattern source) { return node(TableWriterNode.class, source).with(new TableWriterMatcher(columns, columnNames)); @@ -680,6 +718,27 @@ public static PlanMatchPattern remoteSource(List sourceFragmentI return node(RemoteSourceNode.class).with(new RemoteSourceMatcher(sourceFragmentIds, outputSymbolAliases)); } + public static PlanMatchPattern tableFunction(Consumer handler, PlanMatchPattern... 
sources) + { + TableFunctionMatcher.Builder builder = new TableFunctionMatcher.Builder(sources); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler, PlanMatchPattern source) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(source); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(); + handler.accept(builder); + return builder.build(); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java index f950477c06dd7..8901ea70d9d31 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java @@ -493,6 +493,9 @@ else if (expression instanceof TimestampLiteral) { else if (expression instanceof GenericLiteral) { return ((GenericLiteral) expression).getValue(); } + else if (expression instanceof NullLiteral) { + return "null"; + } else { throw new IllegalArgumentException("Unsupported literal expression type: " + expression.getClass().getName()); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/SpatialJoinMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/SpatialJoinMatcher.java index 49424bbaea537..b902454fa6fa3 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/SpatialJoinMatcher.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/SpatialJoinMatcher.java @@ -18,7 +18,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.SpatialJoinNode; -import com.facebook.presto.spi.plan.SpatialJoinNode.Type; +import com.facebook.presto.spi.plan.SpatialJoinNode.SpatialJoinType; import com.facebook.presto.sql.tree.Expression; import java.util.Optional; @@ -32,11 +32,11 @@ public class SpatialJoinMatcher implements Matcher { - private final Type type; + private final SpatialJoinType type; private final Expression filter; private final Optional kdbTree; - public SpatialJoinMatcher(Type type, Expression filter, Optional kdbTree) + public SpatialJoinMatcher(SpatialJoinType type, Expression filter, Optional kdbTree) { this.type = type; this.filter = requireNonNull(filter, "filter can not be null"); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java new file mode 100644 index 0000000000000..c14b68b443867 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java @@ -0,0 +1,412 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReferences; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +public class TableFunctionMatcher + implements Matcher +{ + private final String name; + private 
final Map arguments; + private final List properOutputs; + private final List> copartitioningLists; + + private TableFunctionMatcher( + String name, + Map arguments, + List properOutputs, + List> copartitioningLists) + { + this.name = requireNonNull(name, "name is null"); + this.arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null")); + this.properOutputs = ImmutableList.copyOf(requireNonNull(properOutputs, "properOutputs is null")); + requireNonNull(copartitioningLists, "copartitioningLists is null"); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionNode tableFunctionNode = (TableFunctionNode) node; + + if (!name.equals(tableFunctionNode.getName())) { + return NO_MATCH; + } + + if (arguments.size() != tableFunctionNode.getArguments().size()) { + return NO_MATCH; + } + for (Map.Entry entry : arguments.entrySet()) { + String name = entry.getKey(); + Argument actual = tableFunctionNode.getArguments().get(name); + if (actual == null) { + return NO_MATCH; + } + ArgumentValue expected = entry.getValue(); + switch (expected.getType()) { + case DescriptorArgumentValue.type: + DescriptorArgumentValue expectedDescriptor = (DescriptorArgumentValue) expected; + if (!(actual instanceof DescriptorArgument) || !expectedDescriptor.getDescriptor().equals(((DescriptorArgument) actual).getDescriptor())) { + return NO_MATCH; + } + break; + case ScalarArgumentValue.type: + ScalarArgumentValue expectedScalar = (ScalarArgumentValue) expected; + if (!(actual instanceof 
ScalarArgument) || !Objects.equals(expectedScalar.getValue(), ((ScalarArgument) actual).getValue())) { + return NO_MATCH; + } + break; + default: + if (!(actual instanceof TableArgument) || getMatchResult(symbolAliases, (TableArgumentValue) expected, tableFunctionNode, name).equals(NO_MATCH)) { + return NO_MATCH; + } + } + } + + if (!ImmutableSet.copyOf(copartitioningLists).equals(ImmutableSet.copyOf(tableFunctionNode.getCopartitioningLists()))) { + return NO_MATCH; + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + private MatchResult getMatchResult(SymbolAliases symbolAliases, TableArgumentValue expected, TableFunctionNode tableFunctionNode, String name) + { + TableArgumentValue expectedTableArgument = expected; + TableArgumentProperties argumentProperties = tableFunctionNode.getTableArgumentProperties().get(expectedTableArgument.sourceIndex()); + if (!name.equals(argumentProperties.getArgumentName())) { + return NO_MATCH; + } + if (expectedTableArgument.rowSemantics() != argumentProperties.isRowSemantics() || + expectedTableArgument.pruneWhenEmpty() != argumentProperties.isPruneWhenEmpty() || + expectedTableArgument.passThroughColumns() != argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + return NO_MATCH; + } + + if (expectedTableArgument.specification().isPresent() != argumentProperties.getSpecification().isPresent()) { + return NO_MATCH; + } + if (!expectedTableArgument.specification() + .map(expectedSpecification -> matchSpecification(argumentProperties.getSpecification().get(), expectedSpecification.getExpectedValue(symbolAliases))) + .orElse(true)) { + return NO_MATCH; + } + Set expectedPassThrough = expectedTableArgument.passThroughVariables().stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = toSymbolReferences( + 
argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(Collectors.toList())) + .stream() + .map(SymbolReference.class::cast) + .collect(Collectors.toSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + return match(symbolAliases); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("arguments", arguments) + .add("properOutputs", properOutputs) + .add("copartitioningLists", copartitioningLists) + .toString(); + } + + public static class Builder + { + private final PlanMatchPattern[] sources; + private String name; + private final ImmutableMap.Builder arguments = ImmutableMap.builder(); + private List properOutputs = ImmutableList.of(); + private final ImmutableList.Builder> copartitioningLists = ImmutableList.builder(); + + Builder(PlanMatchPattern... sources) + { + this.sources = Arrays.copyOf(sources, sources.length); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder addDescriptorArgument(String name, DescriptorArgumentValue descriptor) + { + this.arguments.put(name, descriptor); + return this; + } + + public Builder addScalarArgument(String name, Object value) + { + this.arguments.put(name, new ScalarArgumentValue(value)); + return this; + } + + public Builder addTableArgument(String name, TableArgumentValue.Builder tableArgument) + { + this.arguments.put(name, tableArgument.build()); + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder addCopartitioning(List copartitioning) + { + this.copartitioningLists.add(copartitioning); + return this; + } + + public PlanMatchPattern build() + { + return node(TableFunctionNode.class, sources) + .with(new TableFunctionMatcher(name, arguments.buildOrThrow(), properOutputs, 
copartitioningLists.build())); + } + } + + interface ArgumentValue + { + String getType(); + } + + public static class DescriptorArgumentValue + implements ArgumentValue + { + private final Optional descriptor; + public static final String type = "Descriptor"; + + public DescriptorArgumentValue(Optional descriptor) + { + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + public static DescriptorArgumentValue descriptorArgument(Descriptor descriptor) + { + return new DescriptorArgumentValue(Optional.of(requireNonNull(descriptor, "descriptor is null"))); + } + + public static DescriptorArgumentValue nullDescriptor() + { + return new DescriptorArgumentValue(Optional.empty()); + } + + public Optional getDescriptor() + { + return descriptor; + } + + @Override + public String getType() + { + return type; + } + } + + public static class ScalarArgumentValue + implements ArgumentValue + { + private final Object value; + public static final String type = "Scalar"; + + public ScalarArgumentValue(Object value) + { + this.value = value; + } + + public Object getValue() + { + return value; + } + + @Override + public String getType() + { + return type; + } + } + + public static class TableArgumentValue + implements ArgumentValue + { + private final int sourceIndex; + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + private final Optional> specification; + private final Set passThroughVariables; + public static final String type = "Table"; + + public TableArgumentValue(int sourceIndex, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, Optional> specification, Set passThroughVariables) + { + this.sourceIndex = sourceIndex; + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + this.specification = requireNonNull(specification, "specification is null"); + this.passThroughVariables = 
ImmutableSet.copyOf(passThroughVariables); + } + + public int sourceIndex() + { + return sourceIndex; + } + + public boolean rowSemantics() + { + return rowSemantics; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean passThroughColumns() + { + return passThroughColumns; + } + + public Set passThroughVariables() + { + return passThroughVariables; + } + + public Optional> specification() + { + return specification; + } + + @Override + public String getType() + { + return type; + } + + public static class Builder + { + private final int sourceIndex; + private boolean rowSemantics; + private boolean pruneWhenEmpty; + private boolean passThroughColumns; + private Optional> specification = Optional.empty(); + private Set passThroughVariables = ImmutableSet.of(); + + private Builder(int sourceIndex) + { + this.sourceIndex = sourceIndex; + } + + public static Builder tableArgument(int sourceIndex) + { + return new Builder(sourceIndex); + } + + public Builder rowSemantics() + { + this.rowSemantics = true; + this.pruneWhenEmpty = true; + return this; + } + + public Builder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public Builder passThroughColumns() + { + this.passThroughColumns = true; + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + return this; + } + + public Builder passThroughVariables(Set variables) + { + this.passThroughVariables = variables; + return this; + } + + private TableArgumentValue build() + { + return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification, passThroughVariables); + } + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java new file mode 100644 
index 0000000000000..4891c3eb021dd --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java @@ -0,0 +1,239 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.QueryPlanner; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReference; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import 
static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorMatcher + implements Matcher +{ + private final String name; + private final List properOutputs; + private final List> passThroughSymbols; + private final List> requiredSymbols; + private final Optional> markerSymbols; + private final Optional> specification; + private final Optional hashSymbol; + + private TableFunctionProcessorMatcher( + String name, + List properOutputs, + List> passThroughSymbols, + List> requiredSymbols, + Optional> markerSymbols, + Optional> specification, + Optional hashSymbol) + { + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.passThroughSymbols = passThroughSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.requiredSymbols = requiredSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionProcessorNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionProcessorNode tableFunctionProcessorNode = 
(TableFunctionProcessorNode) node; + + if (!name.equals(tableFunctionProcessorNode.getName())) { + return NO_MATCH; + } + + if (properOutputs.size() != tableFunctionProcessorNode.getProperOutputs().size()) { + return NO_MATCH; + } + + List> expectedPassThrough = passThroughSymbols.stream() + .map(list -> list.stream() + .map(symbolAliases::get) + .collect(toImmutableList())) + .collect(toImmutableList()); + List> actualPassThrough = tableFunctionProcessorNode.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::getColumns) + .map(list -> list.stream() + .map(PassThroughColumn::getOutputVariables) + .map(QueryPlanner::toSymbolReference) + .collect(toImmutableList())) + .collect(toImmutableList()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + + if (markerSymbols.isPresent() != tableFunctionProcessorNode.getMarkerVariables().isPresent()) { + return NO_MATCH; + } + if (markerSymbols.isPresent()) { + Map expectedMapping = markerSymbols.get().entrySet().stream() + .collect(toImmutableMap(entry -> symbolAliases.get(entry.getKey()), entry -> symbolAliases.get(entry.getValue()))); + Map actualMapping = tableFunctionProcessorNode.getMarkerVariables().get().entrySet().stream() + .collect(toImmutableMap(entry -> toSymbolReference(entry.getKey()), entry -> toSymbolReference(entry.getValue()))); + if (!expectedMapping.equals(actualMapping)) { + return NO_MATCH; + } + } + + if (specification.isPresent() != tableFunctionProcessorNode.getSpecification().isPresent()) { + return NO_MATCH; + } + if (specification.isPresent()) { + if (!matchSpecification(specification.get().getExpectedValue(symbolAliases), tableFunctionProcessorNode.getSpecification().orElseThrow(NoSuchElementException::new))) { + return NO_MATCH; + } + } + if (hashSymbol.isPresent()) { + if (!hashSymbol.map(symbolAliases::get).equals(tableFunctionProcessorNode.getHashSymbol().map(QueryPlanner::toSymbolReference))) { + return NO_MATCH; + } + } + + 
ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), toSymbolReference(tableFunctionProcessorNode.getProperOutputs().get(i))); + } + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("properOutputs", properOutputs) + .add("passThroughSymbols", passThroughSymbols) + .add("requiredSymbols", requiredSymbols) + .add("markerSymbols", markerSymbols) + .add("specification", specification) + .add("hashSymbol", hashSymbol) + .toString(); + } + + public static class Builder + { + private final Optional source; + private String name; + private List properOutputs = ImmutableList.of(); + private List> passThroughSymbols = ImmutableList.of(); + private List> requiredSymbols = ImmutableList.of(); + private Optional> markerSymbols = Optional.empty(); + private Optional> specification = Optional.empty(); + private Optional hashSymbol = Optional.empty(); + + public Builder() + { + this.source = Optional.empty(); + } + + public Builder(PlanMatchPattern source) + { + this.source = Optional.of(source); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder passThroughSymbols(List> passThroughSymbols) + { + this.passThroughSymbols = passThroughSymbols; + return this; + } + + public Builder requiredSymbols(List> requiredSymbols) + { + this.requiredSymbols = requiredSymbols; + return this; + } + + public Builder markerSymbols(Map markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = 
Optional.of(specification); + return this; + } + + public Builder hashSymbol(String hashSymbol) + { + this.hashSymbol = Optional.of(hashSymbol); + return this; + } + + public PlanMatchPattern build() + { + PlanMatchPattern[] sources = source.map(sourcePattern -> new PlanMatchPattern[] {sourcePattern}).orElse(new PlanMatchPattern[] {}); + return node(TableFunctionProcessorNode.class, sources) + .with(new TableFunctionProcessorMatcher(name, properOutputs, passThroughSymbols, requiredSymbols, markerSymbols, specification, hashSymbol)); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/UnnestMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/UnnestMatcher.java index ff27f0097ee28..8596d60f3b2e5 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/UnnestMatcher.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/UnnestMatcher.java @@ -17,8 +17,8 @@ import com.facebook.presto.cost.StatsProvider; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.google.common.collect.ImmutableMap; import java.util.Collection; diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddDistinctForSemiJoinBuild.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddDistinctForSemiJoinBuild.java new file mode 100644 index 0000000000000..da2ead84c3d8b --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddDistinctForSemiJoinBuild.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; + +public class TestAddDistinctForSemiJoinBuild + extends BaseRuleTest +{ + @Test + public void testTrigger() + { + tester().assertThat(new AddDistinctForSemiJoinBuild()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "true") + .on(p -> + { + VariableReferenceExpression sourceJoinVariable = p.variable("sourceJoinVariable"); + VariableReferenceExpression filteringSourceJoinVariable = p.variable("filteringSourceJoinVariable"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput"); + return p.semiJoin( + sourceJoinVariable, + filteringSourceJoinVariable, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + 
p.values(sourceJoinVariable), + p.values(filteringSourceJoinVariable)); + }).matches( + semiJoin( + "sourceJoinVariable", + "filteringSourceJoinVariable", + "semiJoinOutput", + values("sourceJoinVariable"), + aggregation( + singleGroupingSet("filteringSourceJoinVariable"), + ImmutableMap.of(), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.SINGLE, + values("filteringSourceJoinVariable")))); + } + + @Test + public void testTriggerOverNonQualifiedDistinctAggregation() + { + tester().assertThat(new AddDistinctForSemiJoinBuild()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "true") + .on(p -> + { + VariableReferenceExpression sourceJoinVariable = p.variable("sourceJoinVariable"); + VariableReferenceExpression filteringSourceJoinVariable = p.variable("filteringSourceJoinVariable"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput"); + VariableReferenceExpression col1 = p.variable("col1"); + return p.semiJoin( + sourceJoinVariable, + filteringSourceJoinVariable, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.values(sourceJoinVariable), + p.aggregation((a) -> a + .singleGroupingSet(filteringSourceJoinVariable, col1) + .step(AggregationNode.Step.SINGLE) + .source(p.values(filteringSourceJoinVariable, col1)))); + }).matches( + semiJoin( + "sourceJoinVariable", + "filteringSourceJoinVariable", + "semiJoinOutput", + values("sourceJoinVariable"), + aggregation( + singleGroupingSet("filteringSourceJoinVariable"), + ImmutableMap.of(), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.SINGLE, + aggregation( + singleGroupingSet("filteringSourceJoinVariable", "col1"), + ImmutableMap.of(), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.SINGLE, + values("filteringSourceJoinVariable", "col1"))))); + } + + @Test + public void testNotTriggerOverDistinct() + { + tester().assertThat(new AddDistinctForSemiJoinBuild()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "true") + .on(p 
-> + { + VariableReferenceExpression sourceJoinVariable = p.variable("sourceJoinVariable"); + VariableReferenceExpression filteringSourceJoinVariable = p.variable("filteringSourceJoinVariable"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput"); + return p.semiJoin( + sourceJoinVariable, + filteringSourceJoinVariable, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.values(sourceJoinVariable), + p.aggregation((a) -> a + .singleGroupingSet(filteringSourceJoinVariable) + .step(AggregationNode.Step.SINGLE) + .source(p.values(filteringSourceJoinVariable)))); + }).doesNotFire(); + } + + @Test + public void testNotTriggerOverDistinctUnderProject() + { + tester().assertThat(new AddDistinctForSemiJoinBuild()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "true") + .on(p -> + { + VariableReferenceExpression sourceJoinVariable = p.variable("sourceJoinVariable"); + VariableReferenceExpression filteringSourceJoinVariable = p.variable("filteringSourceJoinVariable"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput"); + VariableReferenceExpression col1 = p.variable("col1"); + return p.semiJoin( + sourceJoinVariable, + filteringSourceJoinVariable, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.values(sourceJoinVariable), + p.project( + assignment(filteringSourceJoinVariable, p.rowExpression("col1")), + p.aggregation((a) -> a + .singleGroupingSet(col1) + .step(AggregationNode.Step.SINGLE) + .source(p.values(col1))))); + }).doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCombineApproxDistinctFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCombineApproxDistinctFunctions.java new file mode 100644 index 0000000000000..3a58fdfe871ab --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCombineApproxDistinctFunctions.java @@ -0,0 
+1,314 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.parser.ParsingOptions; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static java.util.Collections.singletonList; + +public class TestCombineApproxDistinctFunctions + extends BaseRuleTest +{ + @BeforeClass + 
@Override + public void setUp() + { + tester = new RuleTester(singletonList(new SqlInvokedFunctionsPlugin())); + } + + @Test + public void testBasicApproxDistinct() + { + tester().assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .on(p -> p.aggregation(af -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + af.globalGrouping() + .addAggregation(p.variable("approx_distinct_1"), p.rowExpression("approx_distinct(col)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_2"), p.rowExpression("approx_distinct(col)", AS_DOUBLE)) + .source(p.values(col)); + })) + .doesNotFire(); + } + + @Test + public void testTwoDistinctExpressions() + { + tester().assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .on(p -> p.aggregation(af -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", VARCHAR); + af.globalGrouping() + .addAggregation(p.variable("approx_distinct_1"), p.rowExpression("approx_distinct(col1)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_2"), p.rowExpression("approx_distinct(col2)", AS_DOUBLE)) + .source(p.values(col1, col2)); + })) + .matches( + project( + ImmutableMap.of( + "approx_distinct_1", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(transpose_result, 1)))), CAST(0 AS bigint))"), + "approx_distinct_2", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(transpose_result, 2)))), CAST(0 AS bigint))")), + project( + ImmutableMap.of("transpose_result", expression("array_transpose(set_agg_result)")), + aggregation( + ImmutableMap.of("set_agg_result", PlanMatchPattern.functionCall("set_agg", ImmutableList.of("array_expr"))), + project( + ImmutableMap.of("array_expr", expression("array_constructor(col1, col2)")), + values("col1", 
"col2")))))); + } + + @Test + public void testMultipleSameTypeExpressions() + { + tester().assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .on(p -> p.aggregation(af -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", VARCHAR); + VariableReferenceExpression col3 = p.variable("col3", VARCHAR); + af.globalGrouping() + .addAggregation(p.variable("approx_distinct_1"), p.rowExpression("approx_distinct(col1)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_2"), p.rowExpression("approx_distinct(col2)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_3"), p.rowExpression("approx_distinct(col3)", AS_DOUBLE)) + .source(p.values(col1, col2, col3)); + })) + .matches( + project( + ImmutableMap.of( + "approx_distinct_1", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(transpose_result, 1)))), CAST(0 AS bigint))"), + "approx_distinct_2", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(transpose_result, 2)))), CAST(0 AS bigint))"), + "approx_distinct_3", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(transpose_result, 3)))), CAST(0 AS bigint))")), + project( + ImmutableMap.of("transpose_result", expression("array_transpose(set_agg_result)")), + aggregation( + ImmutableMap.of("set_agg_result", PlanMatchPattern.functionCall("set_agg", ImmutableList.of("array_expr"))), + project( + ImmutableMap.of("array_expr", expression("array_constructor(col1, col2, col3)")), + values("col1", "col2", "col3")))))); + } + + @Test + public void testMixedTypes() + { + tester().assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .on(p -> p.aggregation(af -> 
{ + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", VARCHAR); + VariableReferenceExpression col3 = p.variable("col3", BIGINT); + VariableReferenceExpression col4 = p.variable("col4", BIGINT); + af.globalGrouping() + .addAggregation(p.variable("approx_distinct_1"), p.rowExpression("approx_distinct(col1)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_2"), p.rowExpression("approx_distinct(col2)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_3"), p.rowExpression("approx_distinct(col3)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_4"), p.rowExpression("approx_distinct(col4)", AS_DOUBLE)) + .source(p.values(col1, col2, col3, col4)); + })) + .matches( + project( + ImmutableMap.of( + "approx_distinct_1", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(varchar_transpose_result, 1)))), CAST(0 AS bigint))"), + "approx_distinct_2", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(varchar_transpose_result, 2)))), CAST(0 AS bigint))"), + "approx_distinct_3", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(bigint_transpose_result, 1)))), CAST(0 AS bigint))"), + "approx_distinct_4", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(bigint_transpose_result, 2)))), CAST(0 AS bigint))")), + project( + ImmutableMap.of( + "varchar_transpose_result", expression("array_transpose(varchar_set_agg_result)"), + "bigint_transpose_result", expression("array_transpose(bigint_set_agg_result)")), + aggregation( + ImmutableMap.of( + "varchar_set_agg_result", PlanMatchPattern.functionCall("set_agg", ImmutableList.of("varchar_array_expr")), + "bigint_set_agg_result", PlanMatchPattern.functionCall("set_agg", ImmutableList.of("bigint_array_expr"))), + project( + ImmutableMap.of( + "varchar_array_expr", expression("array_constructor(col1, col2)"), + "bigint_array_expr", 
expression("array_constructor(col3, col4)")), + values("col1", "col2", "col3", "col4")))))); + } + + @Test + public void testDoesNotFire() + { + tester().assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .on(p -> p.aggregation(af -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + af.globalGrouping() + .addAggregation(p.variable("approx_distinct_1"), p.rowExpression("approx_distinct(col)", AS_DOUBLE)) + .source(p.values(col)); + })) + .doesNotFire(); + } + + @Test + public void testDifferentTypes() + { + tester().assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .on(p -> p.aggregation(af -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", BIGINT); + af.globalGrouping() + .addAggregation(p.variable("approx_distinct_1"), p.rowExpression("approx_distinct(col1)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_2"), p.rowExpression("approx_distinct(col2)", AS_DOUBLE)) + .source(p.values(col1, col2)); + })) + .doesNotFire(); + } + + @Test + public void testWithDuplicate() + { + tester().assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .on(p -> p.aggregation(af -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", VARCHAR); + af.globalGrouping() + .addAggregation(p.variable("approx_distinct_1"), p.rowExpression("approx_distinct(col1)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_2"), p.rowExpression("approx_distinct(col1)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_3"), p.rowExpression("approx_distinct(col2)", AS_DOUBLE)) + .source(p.values(col1, 
col2)); + })) + .doesNotFire(); + } + + @Test + public void testWithOtherAggregations() + { + tester().assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .on(p -> p.aggregation(af -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", VARCHAR); + VariableReferenceExpression col3 = p.variable("col3", BIGINT); + af.globalGrouping() + .addAggregation(p.variable("approx_distinct_1"), p.rowExpression("approx_distinct(col1)", ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_2"), p.rowExpression("approx_distinct(col2)", ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE)) + .addAggregation(p.variable("count_col3"), p.rowExpression("count(col3)", ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE)) + .source(p.values(col1, col2, col3)); + })) + .matches( + project( + ImmutableMap.of( + "approx_distinct_1", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(transpose_result, 1)))), CAST(0 AS bigint))"), + "approx_distinct_2", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(transpose_result, 2)))), CAST(0 AS bigint))"), + "count_col3", expression("count_col3")), + project( + ImmutableMap.of("transpose_result", expression("array_transpose(set_agg_result)")), + aggregation( + ImmutableMap.of( + "set_agg_result", PlanMatchPattern.functionCall("set_agg", ImmutableList.of("array_expr")), + "count_col3", PlanMatchPattern.functionCall("count", ImmutableList.of("col3"))), + project( + ImmutableMap.of("array_expr", expression("array_constructor(col1, col2)")), + values("col1", "col2", "col3")))))); + } + + @Test + public void testTwoArgumentApproxDistinctWithSameError() + { + tester().assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + 
.setSystemProperty(SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .on(p -> p.aggregation(af -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", VARCHAR); + af.globalGrouping() + .addAggregation(p.variable("approx_distinct_1"), p.rowExpression("approx_distinct(col1, 0.01)", AS_DOUBLE)) + .addAggregation(p.variable("approx_distinct_2"), p.rowExpression("approx_distinct(col2, 0.01)", AS_DOUBLE)) + .source(p.values(col1, col2)); + })) + .matches( + project( + ImmutableMap.of( + "approx_distinct_1", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(transpose_result, 1)))), CAST(0 AS bigint))"), + "approx_distinct_2", expression("coalesce(cardinality(array_distinct(remove_nulls(element_at(transpose_result, 2)))), CAST(0 AS bigint))")), + project( + ImmutableMap.of("transpose_result", expression("array_transpose(set_agg_result)")), + aggregation( + ImmutableMap.of("set_agg_result", PlanMatchPattern.functionCall("set_agg", ImmutableList.of("array_expr"))), + project( + ImmutableMap.of("array_expr", expression("array_constructor(col1, col2)")), + values("col1", "col2")))))); + } + + @Test + public void testTwoArgumentApproxDistinctDifferentTypesWithSameError() + { + tester() + .assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty( + SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .on( + p -> + p.aggregation( + af -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", BIGINT); + af.globalGrouping() + .addAggregation( + p.variable("approx_distinct_1"), + p.rowExpression("approx_distinct(col1, 0.01)", AS_DOUBLE)) + .addAggregation( + p.variable("approx_distinct_2"), + p.rowExpression("approx_distinct(col2, 0.01)", AS_DOUBLE)) + .source(p.values(col1, col2)); + })) + .doesNotFire(); + } 
+ + @Test + public void testTwoArgumentApproxDistinctWithDuplicates() + { + tester() + .assertThat(new CombineApproxDistinctFunctions(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty( + SystemSessionProperties.OPTIMIZE_MULTIPLE_APPROX_DISTINCT_ON_SAME_TYPE, "true") + .on( + p -> + p.aggregation( + af -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", VARCHAR); + af.globalGrouping() + .addAggregation( + p.variable("approx_distinct_1"), + p.rowExpression("approx_distinct(col1, 0.01)", AS_DOUBLE)) + .addAggregation( + p.variable("approx_distinct_2"), + p.rowExpression("approx_distinct(col1, 0.01)", AS_DOUBLE)) + .addAggregation( + p.variable("approx_distinct_3"), + p.rowExpression("approx_distinct(col2, 0.01)", AS_DOUBLE)) + .source(p.values(col1, col2)); + })) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCrossJoinWithArrayNotContainsToAntiJoin.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCrossJoinWithArrayNotContainsToAntiJoin.java index da8920658463d..194cf489d3998 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCrossJoinWithArrayNotContainsToAntiJoin.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestCrossJoinWithArrayNotContainsToAntiJoin.java @@ -14,17 +14,20 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import 
com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import static com.facebook.presto.SystemSessionProperties.REWRITE_CROSS_JOIN_ARRAY_NOT_CONTAINS_TO_ANTI_JOIN; @@ -33,10 +36,18 @@ import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.constantExpressions; +import static java.util.Collections.singletonList; public class TestCrossJoinWithArrayNotContainsToAntiJoin extends BaseRuleTest { + @BeforeClass + @Override + public void setUp() + { + tester = new RuleTester(singletonList(new SqlInvokedFunctionsPlugin())); + } + @Test public void testTriggerForBigInt() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java index 47ec9dbc5030c..35559f41e06ad 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestDetermineJoinDistributionType.java @@ -26,13 +26,13 @@ import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; 
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; -import com.facebook.presto.sql.planner.plan.UnnestNode; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestGroupInnerJoinsByConnectorRuleSet.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestGroupInnerJoinsByConnectorRuleSet.java index 3e3fbf6089d8d..d29241c1531a0 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestGroupInnerJoinsByConnectorRuleSet.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestGroupInnerJoinsByConnectorRuleSet.java @@ -201,7 +201,7 @@ public void testDPartialPushDownTwoDifferentConnectors() .matches( project( filter( - "a1 = b1 and a1 = c1 and true", + "a1 = b1 and a1 = c1 and true", join( JoinTableScanMatcher.tableScan(CATALOG_SUPPORTING_JOIN_PUSHDOWN, tableHandle1, "a1", "a2", "c1", "c2"), JoinTableScanMatcher.tableScan(LOCAL, tableHandle2, "b1", "b2"))))); @@ -277,7 +277,7 @@ public void testJoinPushDownHappenedWithFilters() .matches( project( filter( - "a1 = a2 and a1 > b1 and true", + "a1 = a2 and a1 > b1 and true", JoinTableScanMatcher.tableScan(CATALOG_SUPPORTING_JOIN_PUSHDOWN, tableHandle, "a1", "a2", "b1")))); } @@ -339,11 +339,11 @@ public void testPushDownWithTwoDifferentConnectors() tableScan(CATALOG_SUPPORTING_JOIN_PUSHDOWN, "b1", "b2"), tableScan(OTHER_CATALOG_SUPPORTING_JOIN_PUSHDOWN, "c1", "c2"), new EquiJoinClause(newBigintVariable("b1"), newBigintVariable("c1"))), - new EquiJoinClause(newBigintVariable("c1"), 
newBigintVariable("d1")))) + new EquiJoinClause(newBigintVariable("c1"), newBigintVariable("d1")))) .matches( project( filter( - "((a1 = b1 and a1 = d1) and (b1 = c1 and c1 = d1)) and true", + "((a1 = b1 and a1 = d1) and (b1 = c1 and c1 = d1)) and true", join( JoinTableScanMatcher.tableScan(CATALOG_SUPPORTING_JOIN_PUSHDOWN, tableHandle1, "a1", "b1"), JoinTableScanMatcher.tableScan(OTHER_CATALOG_SUPPORTING_JOIN_PUSHDOWN, tableHandle2, "c1", "d1"))))); @@ -352,9 +352,7 @@ public void testPushDownWithTwoDifferentConnectors() private RuleAssert assertGroupInnerJoinsByConnectorRuleSet() { // For testing, we do not wish to push down pulled up predicates - return tester.assertThat(new GroupInnerJoinsByConnectorRuleSet.OnlyJoinRule(tester.getMetadata(), - (plan, session, types, variableAllocator, idAllocator, warningCollector) -> - PlanOptimizerResult.optimizerResult(plan, false)), + return tester.assertThat(new GroupInnerJoinsByConnectorRuleSet.OnlyJoinRule(tester.getMetadata(), (plan, session, types, variableAllocator, idAllocator, warningCollector) -> PlanOptimizerResult.optimizerResult(plan, false)), ImmutableList.of(CATALOG_SUPPORTING_JOIN_PUSHDOWN, OTHER_CATALOG_SUPPORTING_JOIN_PUSHDOWN)); } @@ -473,8 +471,8 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses ConnectorTableHandle connectorHandle = otherTable.getConnectorHandle(); if (connectorId.equals(otherTable.getConnectorId()) && Objects.equals(otherTable.getConnectorId(), this.tableHandle.getConnectorId()) && - Objects.equals(otherTable.getConnectorHandle(), this.tableHandle.getConnectorHandle()) && - Objects.equals(otherTable.getLayout().isPresent(), this.tableHandle.getLayout().isPresent())) { + Objects.equals(otherTable.getConnectorHandle(), this.tableHandle.getConnectorHandle()) && + Objects.equals(otherTable.getLayout().isPresent(), this.tableHandle.getLayout().isPresent())) { return 
MatchResult.match(SymbolAliases.builder().putAll(Arrays.stream(columns).collect(toMap(identity(), SymbolReference::new))).build()); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java index 657093aec5be2..cdf060f86d390 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLeftJoinWithArrayContainsToEquiJoinCondition.java @@ -14,11 +14,14 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.scalar.sql.SqlInvokedFunctionsPlugin; import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.util.Optional; @@ -32,10 +35,18 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.unnest; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static java.util.Collections.singletonList; public class TestLeftJoinWithArrayContainsToEquiJoinCondition extends BaseRuleTest { + @BeforeClass + @Override + public void setUp() + { + tester = new RuleTester(singletonList(new SqlInvokedFunctionsPlugin())); + } + @Test public void testTriggerForBigIntArrayRightSide() { diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMaterializedViewRewrite.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMaterializedViewRewrite.java new file mode 100644 index 0000000000000..6295e4b13de6a --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMaterializedViewRewrite.java @@ -0,0 +1,605 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.airlift.units.Duration; +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.metadata.AbstractMockMetadata; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.MaterializedViewRefreshType; +import com.facebook.presto.spi.MaterializedViewStaleReadBehavior; +import com.facebook.presto.spi.MaterializedViewStalenessConfig; +import com.facebook.presto.spi.MaterializedViewStatus; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.analyzer.MetadataResolver; +import 
com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.security.AllowAllAccessControl; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.facebook.presto.SystemSessionProperties.MATERIALIZED_VIEW_STALE_READ_BEHAVIOR; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.MaterializedViewStatus.MaterializedViewState.FULLY_MATERIALIZED; +import static com.facebook.presto.spi.MaterializedViewStatus.MaterializedViewState.PARTIALLY_MATERIALIZED; +import static com.facebook.presto.spi.StandardErrorCode.MATERIALIZED_VIEW_STALE; +import static com.facebook.presto.spi.security.ViewSecurity.DEFINER; +import static com.facebook.presto.spi.security.ViewSecurity.INVOKER; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.expectThrows; + +public class TestMaterializedViewRewrite + extends BaseRuleTest +{ + 
@Override + @BeforeClass + public void setUp() + { + FeaturesConfig featuresConfig = new FeaturesConfig() + .setAllowLegacyMaterializedViewsToggle(true) + .setLegacyMaterializedViews(false); + + Session tempSession = testSessionBuilder().setCatalog("local").setSchema("tiny").build(); + LocalQueryRunner queryRunner = new LocalQueryRunner(tempSession, featuresConfig, new FunctionsConfig()); + + Session session = testSessionBuilder(queryRunner.getMetadata().getSessionPropertyManager()).setCatalog("local").setSchema("tiny").build(); + tester = new RuleTester(ImmutableList.of(), session, queryRunner, new TpchConnectorFactory(1)); + } + @Test + public void testUseFreshDataWhenFullyMaterialized() + { + QualifiedObjectName materializedViewName = QualifiedObjectName.valueOf("catalog.schema.mv"); + + Metadata metadata = new TestingMetadataWithMaterializedViewStatus(true); + + tester().assertThat(new MaterializedViewRewrite(metadata, new AllowAllAccessControl())) + .on(planBuilder -> { + VariableReferenceExpression outputA = planBuilder.variable("a", BIGINT); + VariableReferenceExpression dataTableA = planBuilder.variable("data_table_a", BIGINT); + VariableReferenceExpression viewQueryA = planBuilder.variable("view_query_a", BIGINT); + + return planBuilder.materializedViewScan( + materializedViewName, + planBuilder.values(dataTableA), + planBuilder.values(viewQueryA), + ImmutableMap.of(outputA, dataTableA), + ImmutableMap.of(outputA, viewQueryA), + outputA); + }) + .matches( + project( + ImmutableMap.of("a", expression("data_table_a")), + values("data_table_a"))); + } + + @Test + public void testUseViewQueryWhenNotFullyMaterialized() + { + QualifiedObjectName materializedViewName = QualifiedObjectName.valueOf("catalog.schema.mv"); + + Metadata metadata = new TestingMetadataWithMaterializedViewStatus(false); + + tester().assertThat(new MaterializedViewRewrite(metadata, new AllowAllAccessControl())) + .on(planBuilder -> { + VariableReferenceExpression outputA = 
planBuilder.variable("a", BIGINT); + VariableReferenceExpression dataTableA = planBuilder.variable("data_table_a", BIGINT); + VariableReferenceExpression viewQueryA = planBuilder.variable("view_query_a", BIGINT); + + return planBuilder.materializedViewScan( + materializedViewName, + planBuilder.values(dataTableA), + planBuilder.values(viewQueryA), + ImmutableMap.of(outputA, dataTableA), + ImmutableMap.of(outputA, viewQueryA), + outputA); + }) + .matches( + project( + ImmutableMap.of("a", expression("view_query_a")), + values("view_query_a"))); + } + + @Test + public void testMultipleOutputVariables() + { + QualifiedObjectName materializedViewName = QualifiedObjectName.valueOf("catalog.schema.mv"); + + Metadata metadata = new TestingMetadataWithMaterializedViewStatus(true); + + tester().assertThat(new MaterializedViewRewrite(metadata, new AllowAllAccessControl())) + .on(planBuilder -> { + VariableReferenceExpression outputA = planBuilder.variable("a", BIGINT); + VariableReferenceExpression outputB = planBuilder.variable("b", BIGINT); + VariableReferenceExpression dataTableA = planBuilder.variable("data_table_a", BIGINT); + VariableReferenceExpression dataTableB = planBuilder.variable("data_table_b", BIGINT); + VariableReferenceExpression viewQueryA = planBuilder.variable("view_query_a", BIGINT); + VariableReferenceExpression viewQueryB = planBuilder.variable("view_query_b", BIGINT); + + return planBuilder.materializedViewScan( + materializedViewName, + planBuilder.values(dataTableA, dataTableB), + planBuilder.values(viewQueryA, viewQueryB), + ImmutableMap.of(outputA, dataTableA, outputB, dataTableB), + ImmutableMap.of(outputA, viewQueryA, outputB, viewQueryB), + outputA, outputB); + }) + .matches( + project( + ImmutableMap.of( + "a", expression("data_table_a"), + "b", expression("data_table_b")), + values("data_table_a", "data_table_b"))); + } + + @Test + public void testUseViewQueryWhenBaseTableDoesNotExist() + { + QualifiedObjectName materializedViewName = 
QualifiedObjectName.valueOf("catalog.schema.mv"); + + Metadata metadata = new TestingMetadataWithMissingBaseTable(true); + + tester().assertThat(new MaterializedViewRewrite(metadata, new AllowAllAccessControl())) + .on(planBuilder -> { + VariableReferenceExpression outputA = planBuilder.variable("a", BIGINT); + VariableReferenceExpression dataTableA = planBuilder.variable("data_table_a", BIGINT); + VariableReferenceExpression viewQueryA = planBuilder.variable("view_query_a", BIGINT); + VariableReferenceExpression viewQueryB = planBuilder.variable("view_query_b", BIGINT); + + return planBuilder.materializedViewScan( + materializedViewName, + planBuilder.values(dataTableA), + planBuilder.project( + Assignments.builder() + .put(viewQueryA, planBuilder.variable("view_query_b", BIGINT)) + .build(), + planBuilder.values(viewQueryB)), + ImmutableMap.of(outputA, dataTableA), + ImmutableMap.of(outputA, viewQueryA), + outputA); + }) + .matches( + project( + ImmutableMap.of("a", expression("view_query_a")), + project( + ImmutableMap.of("view_query_a", expression("view_query_b")), + values("view_query_b")))); + } + + @Test + public void testFailWhenStaleAndSessionPropertyIsFail() + { + FeaturesConfig featuresConfig = new FeaturesConfig() + .setAllowLegacyMaterializedViewsToggle(true) + .setLegacyMaterializedViews(false) + .setMaterializedViewStaleReadBehavior(MaterializedViewStaleReadBehavior.FAIL); + + Session tempSession = testSessionBuilder().setCatalog("local").setSchema("tiny").build(); + LocalQueryRunner queryRunner = new LocalQueryRunner(tempSession, featuresConfig, new FunctionsConfig()); + + Session sessionWithFail = testSessionBuilder(queryRunner.getMetadata().getSessionPropertyManager()) + .setCatalog("local") + .setSchema("tiny") + .setSystemProperty(MATERIALIZED_VIEW_STALE_READ_BEHAVIOR, "FAIL") + .build(); + + RuleTester testerWithFail = new RuleTester(ImmutableList.of(), sessionWithFail, queryRunner, new TpchConnectorFactory(1)); + + QualifiedObjectName 
materializedViewName = QualifiedObjectName.valueOf("catalog.schema.mv"); + + Metadata metadata = new TestingMetadataWithMaterializedViewStatus(false); + + PrestoException exception = expectThrows(PrestoException.class, () -> + testerWithFail.assertThat(new MaterializedViewRewrite(metadata, new AllowAllAccessControl())) + .on(planBuilder -> { + VariableReferenceExpression outputA = planBuilder.variable("a", BIGINT); + VariableReferenceExpression dataTableA = planBuilder.variable("data_table_a", BIGINT); + VariableReferenceExpression viewQueryA = planBuilder.variable("view_query_a", BIGINT); + + return planBuilder.materializedViewScan( + materializedViewName, + planBuilder.values(dataTableA), + planBuilder.values(viewQueryA), + ImmutableMap.of(outputA, dataTableA), + ImmutableMap.of(outputA, viewQueryA), + outputA); + }) + .matches(values("view_query_a"))); + assertEquals(exception.getErrorCode(), MATERIALIZED_VIEW_STALE.toErrorCode()); + } + + @Test + public void testUseDataTableWhenStalenessWithinTolerance() + { + QualifiedObjectName materializedViewName = QualifiedObjectName.valueOf("catalog.schema.mv"); + + // Staleness config with 1-hour window, last fresh 30 minutes ago (within tolerance) + MaterializedViewStalenessConfig stalenessConfig = new MaterializedViewStalenessConfig( + MaterializedViewStaleReadBehavior.FAIL, + new Duration(1, TimeUnit.HOURS)); + + long lastFreshTime = System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(30); + + // Not fully materialized, but within staleness window - should use data table + Metadata metadata = new TestingMetadataWithStalenessConfig(false, stalenessConfig, Optional.of(lastFreshTime)); + + tester().assertThat(new MaterializedViewRewrite(metadata, new AllowAllAccessControl())) + .on(planBuilder -> { + VariableReferenceExpression outputA = planBuilder.variable("a", BIGINT); + VariableReferenceExpression dataTableA = planBuilder.variable("data_table_a", BIGINT); + VariableReferenceExpression viewQueryA = 
planBuilder.variable("view_query_a", BIGINT); + + return planBuilder.materializedViewScan( + materializedViewName, + planBuilder.values(dataTableA), + planBuilder.values(viewQueryA), + ImmutableMap.of(outputA, dataTableA), + ImmutableMap.of(outputA, viewQueryA), + outputA); + }) + .matches( + project( + ImmutableMap.of("a", expression("data_table_a")), + values("data_table_a"))); + } + + @Test + public void testUseViewQueryWhenStalenessBeyondToleranceWithUseViewQueryBehavior() + { + QualifiedObjectName materializedViewName = QualifiedObjectName.valueOf("catalog.schema.mv"); + + // Staleness config with 1-hour window, USE_VIEW_QUERY behavior + MaterializedViewStalenessConfig stalenessConfig = new MaterializedViewStalenessConfig( + MaterializedViewStaleReadBehavior.USE_VIEW_QUERY, + new Duration(1, TimeUnit.HOURS)); + + // Last fresh 2 hours ago (beyond tolerance) + long lastFreshTime = System.currentTimeMillis() - TimeUnit.HOURS.toMillis(2); + + // Not fully materialized, beyond staleness window - should use view query due to USE_VIEW_QUERY behavior + Metadata metadata = new TestingMetadataWithStalenessConfig(false, stalenessConfig, Optional.of(lastFreshTime)); + + tester().assertThat(new MaterializedViewRewrite(metadata, new AllowAllAccessControl())) + .on(planBuilder -> { + VariableReferenceExpression outputA = planBuilder.variable("a", BIGINT); + VariableReferenceExpression dataTableA = planBuilder.variable("data_table_a", BIGINT); + VariableReferenceExpression viewQueryA = planBuilder.variable("view_query_a", BIGINT); + + return planBuilder.materializedViewScan( + materializedViewName, + planBuilder.values(dataTableA), + planBuilder.values(viewQueryA), + ImmutableMap.of(outputA, dataTableA), + ImmutableMap.of(outputA, viewQueryA), + outputA); + }) + .matches( + project( + ImmutableMap.of("a", expression("view_query_a")), + values("view_query_a"))); + } + + @Test + public void testFailWhenStalenessBeyondToleranceWithFailBehavior() + { + QualifiedObjectName 
materializedViewName = QualifiedObjectName.valueOf("catalog.schema.mv"); + + // Staleness config with 1-hour window, FAIL behavior + MaterializedViewStalenessConfig stalenessConfig = new MaterializedViewStalenessConfig( + MaterializedViewStaleReadBehavior.FAIL, + new Duration(1, TimeUnit.HOURS)); + + // Last fresh 2 hours ago (beyond tolerance) + long lastFreshTime = System.currentTimeMillis() - TimeUnit.HOURS.toMillis(2); + + // Not fully materialized, beyond staleness window - should fail + Metadata metadata = new TestingMetadataWithStalenessConfig(false, stalenessConfig, Optional.of(lastFreshTime)); + + PrestoException exception = expectThrows(PrestoException.class, () -> + tester().assertThat(new MaterializedViewRewrite(metadata, new AllowAllAccessControl())) + .on(planBuilder -> { + VariableReferenceExpression outputA = planBuilder.variable("a", BIGINT); + VariableReferenceExpression dataTableA = planBuilder.variable("data_table_a", BIGINT); + VariableReferenceExpression viewQueryA = planBuilder.variable("view_query_a", BIGINT); + + return planBuilder.materializedViewScan( + materializedViewName, + planBuilder.values(dataTableA), + planBuilder.values(viewQueryA), + ImmutableMap.of(outputA, dataTableA), + ImmutableMap.of(outputA, viewQueryA), + outputA); + }) + .matches(values("view_query_a"))); + assertEquals(exception.getErrorCode(), MATERIALIZED_VIEW_STALE.toErrorCode()); + } + + @Test + public void testFailWhenNeverRefreshedWithStalenessConfig() + { + QualifiedObjectName materializedViewName = QualifiedObjectName.valueOf("catalog.schema.mv"); + + // Staleness config with 1-hour window, FAIL behavior + MaterializedViewStalenessConfig stalenessConfig = new MaterializedViewStalenessConfig( + MaterializedViewStaleReadBehavior.FAIL, + new Duration(1, TimeUnit.HOURS)); + + // Never refreshed (no lastFreshTime) - should fail since staleness is beyond any tolerance + Metadata metadata = new TestingMetadataWithStalenessConfig(false, stalenessConfig, 
Optional.empty()); + + PrestoException exception = expectThrows(PrestoException.class, () -> + tester().assertThat(new MaterializedViewRewrite(metadata, new AllowAllAccessControl())) + .on(planBuilder -> { + VariableReferenceExpression outputA = planBuilder.variable("a", BIGINT); + VariableReferenceExpression dataTableA = planBuilder.variable("data_table_a", BIGINT); + VariableReferenceExpression viewQueryA = planBuilder.variable("view_query_a", BIGINT); + + return planBuilder.materializedViewScan( + materializedViewName, + planBuilder.values(dataTableA), + planBuilder.values(viewQueryA), + ImmutableMap.of(outputA, dataTableA), + ImmutableMap.of(outputA, viewQueryA), + outputA); + }) + .matches(values("view_query_a"))); + assertEquals(exception.getErrorCode(), MATERIALIZED_VIEW_STALE.toErrorCode()); + } + + private static class TestingMetadataWithStalenessConfig + extends AbstractMockMetadata + { + private final boolean isFullyMaterialized; + private final MaterializedViewStalenessConfig stalenessConfig; + private final Optional lastFreshTime; + + public TestingMetadataWithStalenessConfig( + boolean isFullyMaterialized, + MaterializedViewStalenessConfig stalenessConfig, + Optional lastFreshTime) + { + this.isFullyMaterialized = isFullyMaterialized; + this.stalenessConfig = stalenessConfig; + this.lastFreshTime = lastFreshTime; + } + + @Override + public MetadataResolver getMetadataResolver(Session session) + { + return new MaterializedViewTestingMetadataResolverWithStalenessConfig( + super.getMetadataResolver(session), + isFullyMaterialized, + stalenessConfig, + lastFreshTime); + } + } + + private static class MaterializedViewTestingMetadataResolverWithStalenessConfig + implements MetadataResolver + { + private final MetadataResolver delegate; + private final boolean isFullyMaterialized; + private final MaterializedViewStalenessConfig stalenessConfig; + private final Optional lastFreshTime; + + protected MaterializedViewTestingMetadataResolverWithStalenessConfig( + 
MetadataResolver delegate, + boolean isFullyMaterialized, + MaterializedViewStalenessConfig stalenessConfig, + Optional lastFreshTime) + { + this.delegate = delegate; + this.isFullyMaterialized = isFullyMaterialized; + this.stalenessConfig = stalenessConfig; + this.lastFreshTime = lastFreshTime; + } + + @Override + public boolean catalogExists(String catalogName) + { + return delegate.catalogExists(catalogName); + } + + @Override + public boolean schemaExists(com.facebook.presto.common.CatalogSchemaName schemaName) + { + return delegate.schemaExists(schemaName); + } + + @Override + public Optional getTableHandle(QualifiedObjectName tableName) + { + return delegate.getTableHandle(tableName); + } + + @Override + public List getColumns(TableHandle tableHandle) + { + return delegate.getColumns(tableHandle); + } + + @Override + public Map getColumnHandles(TableHandle tableHandle) + { + return delegate.getColumnHandles(tableHandle); + } + + @Override + public Optional getView(QualifiedObjectName viewName) + { + return delegate.getView(viewName); + } + + @Override + public Optional getMaterializedView(QualifiedObjectName viewName) + { + return Optional.of(new MaterializedViewDefinition( + "SELECT * FROM base_table", + "schema", + "mv", + ImmutableList.of(new SchemaTableName("schema", "base_table")), + Optional.of("test_owner"), + Optional.of(DEFINER), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.of(stalenessConfig), + Optional.of(MaterializedViewRefreshType.FULL))); + } + + @Override + public MaterializedViewStatus getMaterializedViewStatus(QualifiedObjectName materializedViewName, TupleDomain baseQueryDomain) + { + return new MaterializedViewStatus( + isFullyMaterialized ? 
FULLY_MATERIALIZED : PARTIALLY_MATERIALIZED, + ImmutableMap.of(), + lastFreshTime); + } + } + + private static class TestingMetadataWithMaterializedViewStatus + extends AbstractMockMetadata + { + private final boolean isFullyMaterialized; + + public TestingMetadataWithMaterializedViewStatus(boolean isFullyMaterialized) + { + this.isFullyMaterialized = isFullyMaterialized; + } + + @Override + public MetadataResolver getMetadataResolver(Session session) + { + return new MaterializedViewTestingMetadataResolver(super.getMetadataResolver(session), isFullyMaterialized, false); + } + } + + private static class TestingMetadataWithMissingBaseTable + extends AbstractMockMetadata + { + private final boolean isFullyMaterialized; + + public TestingMetadataWithMissingBaseTable(boolean isFullyMaterialized) + { + this.isFullyMaterialized = isFullyMaterialized; + } + + @Override + public MetadataResolver getMetadataResolver(Session session) + { + return new MaterializedViewTestingMetadataResolver(super.getMetadataResolver(session), isFullyMaterialized, true); + } + } + + private static class MaterializedViewTestingMetadataResolver + implements MetadataResolver + { + private final MetadataResolver delegate; + private boolean isFullyMaterialized; + private boolean baseTableMissing; + + protected MaterializedViewTestingMetadataResolver(MetadataResolver delegate, boolean isFullyMaterialized, boolean baseTableMissing) + { + this.delegate = delegate; + this.isFullyMaterialized = isFullyMaterialized; + this.baseTableMissing = baseTableMissing; + } + + @Override + public boolean catalogExists(String catalogName) + { + return delegate.catalogExists(catalogName); + } + + @Override + public boolean schemaExists(com.facebook.presto.common.CatalogSchemaName schemaName) + { + return delegate.schemaExists(schemaName); + } + + @Override + public Optional getTableHandle(QualifiedObjectName tableName) + { + if (baseTableMissing) { + return Optional.empty(); + } + return 
delegate.getTableHandle(tableName); + } + + @Override + public List getColumns(TableHandle tableHandle) + { + return delegate.getColumns(tableHandle); + } + + @Override + public Map getColumnHandles(TableHandle tableHandle) + { + return delegate.getColumnHandles(tableHandle); + } + + @Override + public Optional getView(QualifiedObjectName viewName) + { + return delegate.getView(viewName); + } + + @Override + public Optional getMaterializedView(QualifiedObjectName viewName) + { + return Optional.of(new MaterializedViewDefinition( + "SELECT * FROM base_table", + "schema", + "mv", + ImmutableList.of(new SchemaTableName("schema", "base_table")), + Optional.of("test_owner"), + Optional.of(baseTableMissing ? INVOKER : DEFINER), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + Optional.empty())); + } + + @Override + public MaterializedViewStatus getMaterializedViewStatus(QualifiedObjectName materializedViewName, TupleDomain baseQueryDomain) + { + return new MaterializedViewStatus(isFullyMaterialized ? FULLY_MATERIALIZED : PARTIALLY_MATERIALIZED); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMinMaxByToWindowFunction.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMinMaxByToWindowFunction.java new file mode 100644 index 0000000000000..b1d6beda57658 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMinMaxByToWindowFunction.java @@ -0,0 +1,247 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.MapType; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.SystemSessionProperties.REWRITE_MIN_MAX_BY_TO_TOP_N; +import static com.facebook.presto.common.block.MethodHandleUtil.compose; +import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter; +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.topNRowNumber; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.testing.TestingEnvironment.getOperatorMethodHandle; + +public class 
TestMinMaxByToWindowFunction + extends BaseRuleTest +{ + private static final MethodHandle KEY_NATIVE_EQUALS = getOperatorMethodHandle(OperatorType.EQUAL, BIGINT, BIGINT); + private static final MethodHandle KEY_BLOCK_EQUALS = compose(KEY_NATIVE_EQUALS, nativeValueGetter(BIGINT), nativeValueGetter(BIGINT)); + private static final MethodHandle KEY_NATIVE_HASH_CODE = getOperatorMethodHandle(OperatorType.HASH_CODE, BIGINT); + private static final MethodHandle KEY_BLOCK_HASH_CODE = compose(KEY_NATIVE_HASH_CODE, nativeValueGetter(BIGINT)); + + @Test + public void testMaxByOnly() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("max_by(a, ds)")) + .source( + p.values(ds, a, id))); + }) + .matches( + project( + filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", DESC_NULLS_LAST)) + .partial(false), + values("ds", "a", "id"))))); + } + + @Test + public void testMaxAndMaxBy() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), 
p.rowExpression("max_by(a, ds)")) + .addAggregation(p.variable("expr2"), p.rowExpression("max(ds)")) + .source( + p.values(ds, a, id))); + }) + .matches( + project( + filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", DESC_NULLS_LAST)) + .partial(false), + values("ds", "a", "id"))))); + } + + @Test + public void testMinByOnly() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("min_by(a, ds)")) + .source( + p.values(ds, a, id))); + }) + .matches( + project( + filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", ASC_NULLS_LAST)) + .partial(false), + values("ds", "a", "id"))))); + } + + @Test + public void testMinAndMinBy() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("min_by(a, ds)")) + .addAggregation(p.variable("expr2"), p.rowExpression("min(ds)")) + .source( + p.values(ds, a, id))); + }) + .matches( + project( + 
filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", ASC_NULLS_LAST)) + .partial(false), + values("ds", "a", "id"))))); + } + + @Test + public void testMinAndMaxBy() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("max_by(a, ds)")) + .addAggregation(p.variable("expr2"), p.rowExpression("min(ds)")) + .source( + p.values(ds, a, id))); + }).doesNotFire(); + } + + @Test + public void testMaxByOnlyNotOnMap() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", VARCHAR); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("max_by(a, ds)")) + .source( + p.values(ds, a, id))); + }).doesNotFire(); + } + + @Test + public void testMaxByOnBothMapNonMap() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + 
VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("max_by(a, ds)")) + .addAggregation(p.variable("expr2"), p.rowExpression("max_by(b, ds)")) + .source( + p.values(ds, a, b, id))); + }) + .matches( + project( + filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", DESC_NULLS_LAST)) + .partial(false), + values("ds", "a", "b", "id"))))); + } + + @Test + public void testMaxByArray() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new ArrayType(BIGINT)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("max_by(a, ds)")) + .source( + p.values(ds, a, id))); + }) + .matches( + project( + filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", DESC_NULLS_LAST)) + .partial(false), + values("ds", "a", "id"))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMergeSourceColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMergeSourceColumns.java new file mode 100644 index 0000000000000..54a5c35c05845 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneMergeSourceColumns.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.List; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneMergeSourceColumns + extends BaseRuleTest +{ + @Test + public void testPruneInputColumn() + { + tester().assertThat(new PruneMergeSourceColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression mergeRow = p.variable("merge_row"); + VariableReferenceExpression rowId = p.variable("row_id"); + VariableReferenceExpression partialRows = p.variable("partial_rows"); + VariableReferenceExpression fragment = p.variable("fragment"); + List mergeProcessorProjectedSymbols = ImmutableList.of(mergeRow, rowId); + return p.merge( + new SchemaTableName("schema", "table"), + p.values(a, mergeRow, rowId), + mergeProcessorProjectedSymbols, + ImmutableList.of(partialRows, fragment)); + }) + 
.matches( + node( + MergeWriterNode.class, + strictProject( + ImmutableMap.of( + "row_id", PlanMatchPattern.expression("row_id"), + "merge_row", PlanMatchPattern.expression("merge_row")), + values("a", "merge_row", "row_id")))); + } + + @Test + public void testDoNotPruneRowId() + { + tester().assertThat(new PruneMergeSourceColumns()) + .on(p -> { + VariableReferenceExpression mergeRow = p.variable("merge_row"); + VariableReferenceExpression rowId = p.variable("row_id"); + VariableReferenceExpression partialRows = p.variable("partial_rows"); + VariableReferenceExpression fragment = p.variable("fragment"); + List mergeProcessorProjectedSymbols = ImmutableList.of(mergeRow, rowId); + return p.merge( + new SchemaTableName("schema", "table"), + p.values(mergeRow, rowId), + mergeProcessorProjectedSymbols, + ImmutableList.of(partialRows, fragment)); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java new file mode 100644 index 0000000000000..bcae22ae6c623 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneTableFunctionProcessorColumns + extends BaseRuleTest +{ + @Test + public void testDoNotPruneProperOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("p")) + .source(p.values(p.variable("x")))))) + .doesNotFire(); + } + + @Test + public void testPrunePassThroughOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of(), + tableFunctionProcessor(builder -> builder + .name("test_function") + 
.passThroughSymbols(ImmutableList.of(ImmutableList.of())), + values("a", "b")))); + + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of(), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())), + values("a", "b")))); + } + + @Test + public void testReferencedPassThroughOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression x = p.variable("x"); + VariableReferenceExpression y = p.variable("y"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(y, y).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(x, y) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of("y", expression("y"), "b", expression("b")), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("x", "y")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("b"))), + values("a", "b")))); + } + + @Test + public void testAllPassThroughOutputsReferenced() + { + tester().assertThat(new 
PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(a, a).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .doesNotFire(); + + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(a, a).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .doesNotFire(); + } + + @Test + public void testNoSource() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("proper"))))) + .doesNotFire(); + } + + @Test + public void testMultipleTableArguments() + { + // multiple pass-through specifications indicate that the table function has multiple table arguments + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.project( + Assignments.builder().put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + 
.name("test_function") + .properOutputs(p.variable("proper")) + .passThroughSpecifications( + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(a, true))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(b, true))), + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, false)))) + .source(p.values(a, b, c, d)))); + }) + .matches(project( + ImmutableMap.of("b", expression("b")), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of(), ImmutableList.of("b"), ImmutableList.of())), + values("a", "b", "c", "d")))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java new file mode 100644 index 0000000000000..68f56d320e396 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java @@ -0,0 +1,198 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneTableFunctionProcessorSourceColumns + extends BaseRuleTest +{ + @Test + public void testPruneUnreferencedSymbol() + { + // symbols 'a', 'b', 'c', 'd', 'hash', and 'marker' are used by the node. + // symbol 'unreferenced' is pruned out. 
Also, the mapping for this symbol is removed from marker mappings + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + VariableReferenceExpression hash = p.variable("hash"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(a, false)))) + .requiredSymbols(ImmutableList.of(ImmutableList.of(b))) + .markerSymbols(ImmutableMap.of( + a, marker, + b, marker, + c, marker, + d, marker, + unreferenced, marker)) + .specification(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_FIRST)))))) + .hashSymbol(hash) + .source(p.values(a, b, c, d, unreferenced, hash, marker))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("a"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("b"))) + .markerSymbols(ImmutableMap.of( + "a", "marker", + "b", "marker", + "c", "marker", + "d", "marker")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_FIRST))) + .hashSymbol("hash"), + project( + ImmutableMap.of( + "a", expression("a"), + "b", expression("b"), + "c", expression("c"), + "d", expression("d"), + "hash", expression("hash"), + "marker", expression("marker")), + values("a", "b", "c", "d", "unreferenced", "hash", "marker")))); + } + + 
@Test + public void testPruneUnusedMarkerSymbol() + { + // symbol 'unreferenced' is pruned out because the node does not use it. + // also, the mapping for this symbol is removed from marker mappings. + // because the marker symbol 'marker' is no longer used, it is pruned out too. + // note: currently a marker symbol cannot become unused because the function + // must use at least one symbol from each source. it might change in the future. + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .markerSymbols(ImmutableMap.of(unreferenced, marker)) + .source(p.values(unreferenced, marker))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .markerSymbols(ImmutableMap.of()), + project( + ImmutableMap.of(), + values("unreferenced", "marker")))); + } + + @Test + public void testMultipleSources() + { + // multiple pass-through specifications indicate that the table function has multiple table arguments + // the third argument provides symbols 'e', 'f', and 'unreferenced'. 
those symbols are mapped to common marker symbol 'marker3' + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression marker1 = p.variable("marker1"); + VariableReferenceExpression marker2 = p.variable("marker2"); + VariableReferenceExpression marker3 = p.variable("marker3"); + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(a, false))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true)))) + .requiredSymbols(ImmutableList.of( + ImmutableList.of(b), + ImmutableList.of(d), + ImmutableList.of(f))) + .markerSymbols(ImmutableMap.of( + a, marker1, + b, marker1, + c, marker2, + d, marker2, + e, marker3, + f, marker3, + unreferenced, marker3)) + .source(p.values(a, b, c, d, e, f, marker1, marker2, marker3, unreferenced))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .passThroughSymbols(ImmutableList.of(ImmutableList.of("a"), ImmutableList.of("c"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("b"), ImmutableList.of("d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "a", "marker1", + "b", "marker1", + "c", "marker2", + "d", "marker2", + "e", "marker3", + "f", "marker3")), + project( + ImmutableMap.of( + "a", expression("a"), + "b", expression("b"), + "c", expression("c"), + "d", expression("d"), + "e", 
expression("e"), + "f", expression("f"), + "marker1", expression("marker1"), + "marker2", expression("marker2"), + "marker3", expression("marker3")), + values("a", "b", "c", "d", "e", "f", "marker1", "marker2", "marker3", "unreferenced")))); + } + + @Test + public void allSymbolsReferenced() + { + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .requiredSymbols(ImmutableList.of(ImmutableList.of(a))) + .markerSymbols(ImmutableMap.of(a, marker)) + .source(p.values(a, marker))); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java index be66d0b00210b..f49b796fb67d9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughExchange.java @@ -22,8 +22,10 @@ import com.google.common.collect.ImmutableList; import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION; import static com.facebook.presto.common.function.OperatorType.MULTIPLY; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.spi.plan.ProjectNode.Locality.REMOTE; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; @@ -192,4 +194,46 @@ public void 
testOrderingColumnsArePreserved() ).withNumberOfOutputColumns(3) .withExactOutputs("a_times_5", "b_times_5", "h_times_5")); } + + @Test + public void testDoesNotFireWithSkipProjectionPushdownThroughExchangeForRemoteProjection() + { + tester().assertThat(new PushProjectionThroughExchange()) + .setSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression x = p.variable("x"); + return p.project( + p.exchange(e -> e + .addSource(p.values(a)) + .addInputsSet(a) + .singleDistributionPartitioningScheme(a)), + Assignments.builder() + .put(x, p.binaryOperation(MULTIPLY, a, constant(5L, BIGINT))) + .build(), + REMOTE); + }) + .doesNotFire(); + } + + @Test + public void testFiresWithoutSkipProjectionPushdownThroughExchangeForRemoteProjectionWhenProjectionIsNotRemote() + { + tester().assertThat(new PushProjectionThroughExchange()) + .setSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression x = p.variable("x"); + return p.project( + assignment(x, p.binaryOperation(MULTIPLY, a, constant(5L, BIGINT))), + p.exchange(e -> e + .addSource(p.values(a)) + .addInputsSet(a) + .singleDistributionPartitioningScheme(a))); + }) + .matches( + exchange( + project( + values(ImmutableList.of("a"))))); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRandomizeSourceKeyInSemiJoin.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRandomizeSourceKeyInSemiJoin.java new file mode 100644 index 0000000000000..79194af51a813 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRandomizeSourceKeyInSemiJoin.java @@ -0,0 +1,161 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.plan.SemiJoinNode; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; +import static com.facebook.presto.SystemSessionProperties.RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRandomizeSourceKeyInSemiJoin + extends BaseRuleTest +{ + @Test + public void testSemiJoinWithSupportedTypes() + { + tester().assertThat(new RandomizeSourceKeyInSemiJoin(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY, "ALWAYS") + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "PARTITIONED") + .on(p -> { + 
return p.semiJoin( + p.variable("left_k1"), + p.variable("right_k1"), + p.variable("semi_output", BOOLEAN), + Optional.empty(), + Optional.empty(), + p.values(p.variable("left_k1")), + p.values(p.variable("right_k1"))); + }) + .matches( + project( + ImmutableMap.of("left_k1", expression("left_k1"), "semi_output", expression("semi_join_output_randomized OR if(left_k1 IS NULL, CAST(null AS boolean), false)")), + semiJoin( + "randomized_left", "cast_right", "semi_join_output_randomized", + project( + ImmutableMap.of("left_k1", expression("left_k1"), "randomized_left", expression("coalesce(cast(left_k1 as varchar), 'l' || cast(random(100) as varchar))")), + values("left_k1")), + project( + ImmutableMap.of("right_k1", expression("right_k1"), "cast_right", expression("cast(right_k1 as varchar)")), + values("right_k1"))))); + } + + @Test + public void testSemiJoinWithIntegerType() + { + tester().assertThat(new RandomizeSourceKeyInSemiJoin(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY, "ALWAYS") + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "PARTITIONED") + .on(p -> { + return p.semiJoin( + p.variable("left_k1", INTEGER), + p.variable("right_k1", INTEGER), + p.variable("semi_output", BOOLEAN), + Optional.empty(), + Optional.empty(), + p.values(p.variable("left_k1", INTEGER)), + p.values(p.variable("right_k1", INTEGER))); + }) + .matches( + project( + ImmutableMap.of("left_k1", expression("left_k1"), "semi_output", expression("semi_join_output_randomized OR if(left_k1 IS NULL, CAST(null AS boolean), false)")), + semiJoin( + "randomized_left", "cast_right", "semi_join_output_randomized", + project( + ImmutableMap.of("left_k1", expression("left_k1"), "randomized_left", expression("coalesce(cast(left_k1 as varchar), 'l' || cast(random(100) as varchar))")), + values("left_k1")), + project( + ImmutableMap.of("right_k1", expression("right_k1"), "cast_right", expression("cast(right_k1 as varchar)")), + 
values("right_k1"))))); + } + + @Test + public void testSemiJoinWithUnsupportedType() + { + // VARCHAR type is not supported - rule should not fire + tester().assertThat(new RandomizeSourceKeyInSemiJoin(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY, "ALWAYS") + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "PARTITIONED") + .on(p -> { + return p.semiJoin( + p.variable("left_k1", VARCHAR), + p.variable("right_k1", VARCHAR), + p.variable("semi_output", BOOLEAN), + Optional.empty(), + Optional.empty(), + p.values(p.variable("left_k1", VARCHAR)), + p.values(p.variable("right_k1", VARCHAR))); + }) + .doesNotFire(); + } + + @Test + public void testSemiJoinWithReplicatedDistribution() + { + // Replicated distribution should not be optimized + tester().assertThat(new RandomizeSourceKeyInSemiJoin(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY, "ALWAYS") + .on(p -> { + return new SemiJoinNode( + Optional.empty(), + new PlanNodeId("semi"), + Optional.empty(), + p.values(p.variable("left_k1")), + p.values(p.variable("right_k1")), + p.variable("left_k1"), + p.variable("right_k1"), + p.variable("semi_output", BOOLEAN), + Optional.empty(), + Optional.empty(), + Optional.of(SemiJoinNode.DistributionType.REPLICATED), + ImmutableMap.of()); + }) + .doesNotFire(); + } + + @Test + public void testDisabledWhenStrategyNotAlways() + { + // Rule should not fire when strategy is not ALWAYS + tester().assertThat(new RandomizeSourceKeyInSemiJoin(getMetadata().getFunctionAndTypeManager())) + .setSystemProperty(RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY, "DISABLED") + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, "PARTITIONED") + .on(p -> { + p.variable("left_k1", BIGINT); + p.variable("right_k1", BIGINT); + return p.semiJoin( + p.variable("left_k1"), + p.variable("right_k1"), + p.variable("semi_output", BOOLEAN), + Optional.empty(), + Optional.empty(), + 
p.values(p.variable("left_k1")), + p.values(p.variable("right_k1"))); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java index 88b0c4a771f56..e133149c1713d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveMapCastRule.java @@ -72,4 +72,44 @@ public void testElementAtCast() ImmutableMap.of("a", expression("element_at(feature, try_cast(key as integer))")), values("feature", "key"))); } + + @Test + public void testMapSubSet() + { + tester().assertThat( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + .setSystemProperty(REMOVE_MAP_CAST, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", DOUBLE); + VariableReferenceExpression feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); + VariableReferenceExpression key = p.variable("key", BIGINT); + return p.project( + assignment(a, p.rowExpression("map_subset(cast(feature as map), array[key])")), + p.values(feature, key)); + }) + .matches( + project( + ImmutableMap.of("a", expression("cast(map_subset(feature, array[try_cast(key as integer)]) as map)")), + values("feature", "key"))); + } + + @Test + public void testMapSubSetConstantArray() + { + tester().assertThat( + ImmutableSet.>builder().addAll(new SimplifyRowExpressions(getMetadata(), getExpressionManager()).rules()).addAll(new RemoveMapCastRule(getFunctionManager()).rules()).build()) + .setSystemProperty(REMOVE_MAP_CAST, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", DOUBLE); + VariableReferenceExpression 
feature = p.variable("feature", createMapType(getFunctionManager(), INTEGER, DOUBLE)); + VariableReferenceExpression key = p.variable("key", BIGINT); + return p.project( + assignment(a, p.rowExpression("map_subset(cast(feature as map), array[cast(1 as bigint)])")), + p.values(feature, key)); + }) + .matches( + project( + ImmutableMap.of("a", expression("cast(map_subset(feature, array[1]) as map)")), + values("feature", "key"))); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunctionProcessor.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunctionProcessor.java new file mode 100644 index 0000000000000..86b6cc74e74af --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunctionProcessor.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRemoveRedundantTableFunctionProcessor + extends BaseRuleTest +{ + @Test + public void testRemoveTableFunction() + { + tester().assertThat(new RemoveRedundantTableFunctionProcessor()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .pruneWhenEmpty() + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(passThrough))); + }) + .matches(values("proper", "pass_through")); + } + + @Test + public void testDoNotRemoveKeepWhenEmpty() + { + tester().assertThat(new RemoveRedundantTableFunctionProcessor()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(passThrough))); + }) + .doesNotFire(); + } + + @Test + public void testDoNotRemoveNonEmptyInput() + { + tester().assertThat(new RemoveRedundantTableFunctionProcessor()) + .on(p -> { + VariableReferenceExpression passThrough = 
p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .pruneWhenEmpty() + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(5, passThrough))); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReplaceConditionalApproxDistinct.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReplaceConditionalApproxDistinct.java new file mode 100644 index 0000000000000..8cb9014afafa2 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReplaceConditionalApproxDistinct.java @@ -0,0 +1,215 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.parser.ParsingOptions; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestReplaceConditionalApproxDistinct + extends BaseRuleTest +{ + @Test + public void testReplaceConditionalConstant() + { + tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager())) + .on(p -> { + VariableReferenceExpression original = p.variable("original", BOOLEAN); + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + return p.aggregation((builder) -> builder + .addAggregation( + p.variable("output"), + p.rowExpression("approx_distinct(original)")) + .globalGrouping() + .step(AggregationNode.Step.SINGLE) + .source( + p.project( + p.assignment( + original, p.rowExpression("if(a > b, 'constant')")), + p.values(a, b)))); + }) + .matches( + project( + 
ImmutableMap.of( + "output", expression("coalesce(intermediate, 0)")), + aggregation( + ImmutableMap.of("intermediate", + functionCall("arbitrary", ImmutableList.of("expression"))), + SINGLE, + project( + ImmutableMap.of( + "original", expression("if(a > b, 'constant')"), + "expression", expression("if(a > b, 1, NULL)")), + values("a", "b"))))); + } + + @Test + public void testReplaceConditionalErrorBounds() + { + tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager())) + .on(p -> { + VariableReferenceExpression original = p.variable("original", BOOLEAN); + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression bounds = p.variable("bounds", DOUBLE); + return p.aggregation((builder) -> builder + .addAggregation( + p.variable("output"), + p.rowExpression("approx_distinct(original, bounds)")) + .globalGrouping() + .step(AggregationNode.Step.SINGLE) + .source( + p.project( + p.assignment( + original, p.rowExpression("if(a > b, 'constant')"), + bounds, p.rowExpression("0.0040625", ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE)), + p.values(a, b)))); + }) + .matches( + project( + ImmutableMap.of( + "output", expression("coalesce(intermediate, 0)")), + aggregation( + ImmutableMap.of("intermediate", + functionCall("arbitrary", ImmutableList.of("expression"))), + SINGLE, + project( + ImmutableMap.of( + "original", expression("if(a > b, 'constant')"), + "expression", expression("if(a > b, 1, NULL)")), + values("a", "b"))))); + } + + @Test + public void testReplaceMultipleConditionalConstant() + { + tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager())) + .on(p -> { + VariableReferenceExpression original1 = p.variable("original1", BOOLEAN); + VariableReferenceExpression original2 = p.variable("original2", BOOLEAN); + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + 
return p.aggregation((builder) -> builder + .addAggregation( + p.variable("output1"), + p.rowExpression("approx_distinct(original1)")) + .addAggregation( + p.variable("output2"), + p.rowExpression("approx_distinct(original2)")) + .globalGrouping() + .step(AggregationNode.Step.SINGLE) + .source( + p.project( + p.assignment( + original1, p.rowExpression("if(a > b, 'constant')"), + original2, p.rowExpression("if(a < b, NULL, 'constant')")), + p.values(a, b)))); + }) + .matches( + project( + ImmutableMap.of( + "output1", expression("coalesce(intermediate1, 0)"), + "output2", expression("coalesce(intermediate2, 0)")), + aggregation( + ImmutableMap.of( + "intermediate1", functionCall("arbitrary", ImmutableList.of("expression1")), + "intermediate2", functionCall("arbitrary", ImmutableList.of("expression2"))), + SINGLE, + project( + ImmutableMap.of( + "original1", expression("if(a > b, 'constant')"), + "original2", expression("if(a < b, NULL, 'constant')"), + "expression1", expression("if(a > b, 1, NULL)"), + "expression2", expression("if(a < b, NULL, 1)")), + values("a", "b"))))); + } + + @Test + public void testDontReplaceConstant() + { + tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager())) + .on(p -> { + VariableReferenceExpression input = p.variable("input", VARCHAR); + return p.aggregation((builder) -> builder + .addAggregation( + p.variable("output"), + p.rowExpression("approx_distinct(input)")) + .globalGrouping() + .step(AggregationNode.Step.SINGLE) + .source( + p.project( + p.assignment(input, p.rowExpression("'constant'")), + p.values()))); + }).doesNotFire(); + } + + @Test + public void testDontReplaceVariable() + { + tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager())) + .on(p -> { + VariableReferenceExpression input = p.variable("input", VARCHAR); + VariableReferenceExpression nonconstant = p.variable("nonconstant", VARCHAR); + return p.aggregation((builder) -> builder + .addAggregation( + 
p.variable("output"), + p.rowExpression("approx_distinct(input)")) + .globalGrouping() + .step(AggregationNode.Step.SINGLE) + .source( + p.project( + p.assignment(input, p.rowExpression("nonconstant")), + p.values(nonconstant)))); + }).doesNotFire(); + } + + @Test + public void testDontReplaceConditionalVariable() + { + tester().assertThat(new ReplaceConditionalApproxDistinct(getFunctionManager())) + .on(p -> { + VariableReferenceExpression original = p.variable("original", BOOLEAN); + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression nonconstant = p.variable("nonconstant", BIGINT); + return p.aggregation((builder) -> builder + .addAggregation( + p.variable("output"), + p.rowExpression("approx_distinct(original)")) + .globalGrouping() + .step(AggregationNode.Step.SINGLE) + .source( + p.project( + p.assignment( + original, p.rowExpression("if(a > b, nonconstant)")), + p.values(a, b, nonconstant)))); + }).doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteExcludeColumnsFunctionToProjection.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteExcludeColumnsFunctionToProjection.java new file mode 100644 index 0000000000000..bfdcea0b8ca5f --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteExcludeColumnsFunctionToProjection.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.operator.table.ExcludeColumns.ExcludeColumnsFunctionHandle; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.SmallintType.SMALLINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRewriteExcludeColumnsFunctionToProjection + extends BaseRuleTest +{ + @Test + public void rewriteExcludeColumnsFunction() + { + tester().assertThat(new RewriteExcludeColumnsFunctionToProjection()) + .on(p -> { + VariableReferenceExpression a = p.variable("a", BOOLEAN); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression c = p.variable("c", SMALLINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression y = p.variable("y", SMALLINT); + return p.tableFunctionProcessor( + builder -> builder + .name("exclude_columns") + .properOutputs(x, y) + .pruneWhenEmpty() + .requiredSymbols(ImmutableList.of(ImmutableList.of(b, c))) + .connectorHandle(new ExcludeColumnsFunctionHandle()) + .source(p.values(a, b, c))); + }) + .matches(PlanMatchPattern.strictProject( + ImmutableMap.of( + "x", expression("b"), + "y", expression("c")), + values("a", "b", "c"))); + } + + @Test + public void doNotRewriteOtherFunction() + { + 
tester().assertThat(new RewriteExcludeColumnsFunctionToProjection()) + .on(p -> { + VariableReferenceExpression a = p.variable("a", BOOLEAN); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression c = p.variable("c", SMALLINT); + return p.tableFunctionProcessor( + builder -> builder + .name("testing_function") + .requiredSymbols(ImmutableList.of(ImmutableList.of(b, c))) + .source(p.values(a, b, c))); + }).doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteRowExpressions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteRowExpressions.java new file mode 100644 index 0000000000000..bb29b7774a4ac --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRewriteRowExpressions.java @@ -0,0 +1,257 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.nodeManager.PluginNodeManager; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.airlift.json.JsonCodec.jsonCodec; +import static com.facebook.presto.SystemSessionProperties.EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE; +import static com.facebook.presto.common.function.OperatorType.ADD; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.expressions.ExpressionOptimizerManager.DEFAULT_EXPRESSION_OPTIMIZER_NAME; +import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; +import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestRewriteRowExpressions + extends BaseRuleTest +{ + private static final MetadataManager METADATA = createTestMetadataManager(); + + private FunctionAndTypeManager functionAndTypeManager; + private ExpressionOptimizerManager expressionOptimizerManager; + private RewriteRowExpressions optimizer; + private RuleTester ruleTesterWithOptimizer; + + private static RowExpression ifExpression(RowExpression condition, long trueValue, long falseValue) + { + return new SpecialFormExpression(IF, BIGINT, ImmutableList.of(condition, constant(trueValue, BIGINT), constant(falseValue, BIGINT))); + } + + @BeforeClass + @Override + public void setUp() + { + super.setUp(); + functionAndTypeManager = createTestFunctionAndTypeManager(); + InMemoryNodeManager nodeManager = new InMemoryNodeManager(); + expressionOptimizerManager = new ExpressionOptimizerManager( + new PluginNodeManager(nodeManager), + METADATA.getFunctionAndTypeManager(), + new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); + optimizer = new RewriteRowExpressions(expressionOptimizerManager); + + // Create a RuleTester with the session property enabled + ruleTesterWithOptimizer = new RuleTester( + ImmutableList.of(), + 
ImmutableMap.of(EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE, DEFAULT_EXPRESSION_OPTIMIZER_NAME), + Optional.empty()); + } + + @AfterClass(alwaysRun = true) + public void cleanUp() + { + expressionOptimizerManager = null; + optimizer = null; + if (ruleTesterWithOptimizer != null) { + ruleTesterWithOptimizer.close(); + ruleTesterWithOptimizer = null; + } + } + + @Test + public void testIsRewriterEnabledWithEmptyOptimizerName() + { + Session sessionWithEmptyOptimizer = testSessionBuilder() + .setSystemProperty(EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE, "") + .build(); + assertFalse(optimizer.isRewriterEnabled(sessionWithEmptyOptimizer)); + } + + @Test + public void testIsRewriterEnabledWithValidOptimizerName() + { + Session sessionWithOptimizer = testSessionBuilder() + .setSystemProperty(EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE, DEFAULT_EXPRESSION_OPTIMIZER_NAME) + .build(); + assertTrue(optimizer.isRewriterEnabled(sessionWithOptimizer)); + } + + @Test + public void testRewriteWithDefaultOptimizer() + { + Session session = testSessionBuilder() + .setSystemProperty(EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE, DEFAULT_EXPRESSION_OPTIMIZER_NAME) + .build(); + + RowExpression expression = constant(1L, BIGINT); + RowExpression rewritten = RewriteRowExpressions.rewrite(expression, session, expressionOptimizerManager); + assertEquals(rewritten, expression); + } + + @Test + public void testRewriteIfConstantOptimization() + { + Session session = testSessionBuilder() + .setSystemProperty(EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE, DEFAULT_EXPRESSION_OPTIMIZER_NAME) + .build(); + + RowExpression ifExpressionTrue = ifExpression(constant(true, BOOLEAN), 1L, 2L); + RowExpression rewrittenTrue = RewriteRowExpressions.rewrite(ifExpressionTrue, session, expressionOptimizerManager); + assertEquals(rewrittenTrue, constant(1L, BIGINT)); + + RowExpression ifExpressionFalse = ifExpression(constant(false, BOOLEAN), 1L, 2L); + RowExpression rewrittenFalse = 
RewriteRowExpressions.rewrite(ifExpressionFalse, session, expressionOptimizerManager); + assertEquals(rewrittenFalse, constant(2L, BIGINT)); + + RowExpression ifExpressionNull = ifExpression(constant(null, BOOLEAN), 1L, 2L); + RowExpression rewrittenNull = RewriteRowExpressions.rewrite(ifExpressionNull, session, expressionOptimizerManager); + assertEquals(rewrittenNull, constant(2L, BIGINT)); + } + + @Test + public void testRewriteWithArithmeticExpression() + { + Session session = testSessionBuilder() + .setSystemProperty(EXPRESSION_OPTIMIZER_IN_ROW_EXPRESSION_REWRITE, DEFAULT_EXPRESSION_OPTIMIZER_NAME) + .build(); + + FunctionHandle addHandle = functionAndTypeManager.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)); + RowExpression addExpression = new com.facebook.presto.spi.relation.CallExpression( + ADD.name(), + addHandle, + BIGINT, + ImmutableList.of(constant(1L, BIGINT), constant(2L, BIGINT))); + + RowExpression rewritten = RewriteRowExpressions.rewrite(addExpression, session, expressionOptimizerManager); + assertEquals(rewritten, constant(3L, BIGINT)); + } + + @Test + public void testFilterRuleDoesNotFireWhenDisabled() + { + tester().assertThat(optimizer.filterRowExpressionRewriteRule()) + .on(p -> p.filter(p.rowExpression("1 + 1 = 2"), p.values())) + .doesNotFire(); + } + + @Test + public void testProjectRuleDoesNotFireWhenDisabled() + { + tester().assertThat(optimizer.projectRowExpressionRewriteRule()) + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + return p.project( + assignment(a, p.rowExpression("1 + 2")), + p.values()); + }) + .doesNotFire(); + } + + @Test + public void testFilterRuleRewritesConstantExpression() + { + ruleTesterWithOptimizer.assertThat( + new RewriteRowExpressions(ruleTesterWithOptimizer.getExpressionManager()).filterRowExpressionRewriteRule()) + .on(p -> p.filter(p.rowExpression("1 + 1 = 2"), p.values())) + .matches( + filter("true", + values())); + } + + @Test + public void 
testProjectRuleRewritesConstantArithmetic() + { + ruleTesterWithOptimizer.assertThat( + new RewriteRowExpressions(ruleTesterWithOptimizer.getExpressionManager()).projectRowExpressionRewriteRule()) + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + return p.project( + assignment(a, p.rowExpression("1 + 2")), + p.values()); + }) + .matches( + project(ImmutableMap.of("a", expression("BIGINT'3'")), + values())); + } + + @Test + public void testRuleSetReturnsAllRules() + { + Set> rules = optimizer.rules(); + assertNotNull(rules); + // RowExpressionRewriteRuleSet returns 10 rules + assertEquals(rules.size(), 10); + } + + @Test + public void testFilterRuleRewritesIfExpression() + { + ruleTesterWithOptimizer.assertThat( + new RewriteRowExpressions(ruleTesterWithOptimizer.getExpressionManager()).filterRowExpressionRewriteRule()) + .on(p -> p.filter(p.rowExpression("IF(true, true, false)"), p.values())) + .matches( + filter("true", + values())); + } + + @Test + public void testProjectRuleRewritesIfExpression() + { + ruleTesterWithOptimizer.assertThat( + new RewriteRowExpressions(ruleTesterWithOptimizer.getExpressionManager()).projectRowExpressionRewriteRule()) + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + return p.project( + assignment(a, p.rowExpression("IF(true, BIGINT '1', BIGINT '2')")), + p.values()); + }) + .matches( + project(ImmutableMap.of("a", expression("BIGINT'1'")), + values())); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java index 7817b311053df..78ff684386afc 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyRowExpressions.java @@ -25,6 +25,7 @@ import 
com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.tree.Expression; @@ -38,6 +39,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; @@ -185,7 +187,7 @@ private static void assertSimplifies(String expression, String rowExpressionExpe Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression)); InMemoryNodeManager nodeManager = new InMemoryNodeManager(); - ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager), METADATA.getFunctionAndTypeManager()); + ExpressionOptimizerManager expressionOptimizerManager = new ExpressionOptimizerManager(new PluginNodeManager(nodeManager), METADATA.getFunctionAndTypeManager(), new JsonCodecRowExpressionSerde(jsonCodec(RowExpression.class))); TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(METADATA); RowExpression actualRowExpression = translator.translate(actualExpression, TypeProvider.viewOf(TYPES)); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformTableFunctionToTableFunctionProcessor.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformTableFunctionToTableFunctionProcessor.java new file mode 100644 index 0000000000000..b6fb48a904e81 --- /dev/null +++ 
b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestTransformTableFunctionToTableFunctionProcessor.java @@ -0,0 +1,1404 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_FIRST; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; +import static com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; + +public class TestTransformTableFunctionToTableFunctionProcessor + extends BaseRuleTest +{ + @Test + public void testNoSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> p.tableFunction( + "test_function", + ImmutableList.of(p.variable("a")), + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of())) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a")))); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + // no pass-through columns + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + 
.matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"))), + values("c"))); + + // pass-through columns + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(true, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, false))), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"))), + values("c"))); + } + + @Test + public void testSingleSourceWithSetSemantics() + { + // no pass-through columns, no partition by + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c, d), + Optional.of(new 
DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .specification(specification(ImmutableList.of(), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // no pass-through columns, partitioning column present + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // pass-through columns + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { 
+ VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, false))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty())))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"))) + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())), + values("c", "d"))); + } + + @Test + public void testTwoSourcesWithSetSemantics() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + 
false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty())))), + ImmutableList.of()); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), 
ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testThreeSourcesWithSetSemantics() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + VariableReferenceExpression h = p.variable("h"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f), + p.values(g, h)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new 
PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(h), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(h, DESC_NULLS_FIRST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"), ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("f"), ImmutableList.of("h"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2", + "g", "marker_3", + "h", "marker_3")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)")), + join(// join 
nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = input_3_row_number OR " + + "(combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", 
ImmutableList.of())), + values("e", "f"))))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST))) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST)) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("g", "h")))))))); + } + + @Test + public void testTwoCoPartitionedSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + 
.properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), 
ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST)) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testCoPartitionJoinTypes() + { + // both sources are prune when empty, so they are combined using inner join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + 
.passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", 
ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + + // only the left source is prune when empty, so sources are combined using left join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + 
.markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR " + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), 
ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + + // only the right source is prune when empty. the sources are reordered so that the prune when empty source is first. they are combined using left join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + 
.specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), input_2_row_number, input_1_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), input_2_partition_size, input_1_partition_size)"), + "combined_partition_column", expression("COALESCE(d, c)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (d IS DISTINCT FROM c) " + + "AND (" + + " input_2_row_number = input_1_row_number OR" + + " (input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' OR" + + " input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", 
functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c")))))))); + + // neither source is prune when empty, so sources are combined using full join + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker 
symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d)" + + " AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", 
functionCall("row_number", ImmutableList.of())), + values("d")))))))); + } + + @Test + public void testThreeCoPartitionedSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + 
.specification(specification(ImmutableList.of("combined_partition_column_1_2_3"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)"), + "combined_partition_column_1_2_3", expression("COALESCE(combined_partition_column_1_2, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (combined_partition_column_1_2 IS DISTINCT FROM e) " + + "AND (" + + " combined_row_number_1_2 = input_3_row_number OR" + + " (combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR" + + " input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1'))"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, 
input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e")))))))); + } + + @Test + public void testTwoCoPartitionLists() + { + tester().assertThat(new 
TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e), + p.values(f, g)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty()))), + new TableArgumentProperties( + "input_4", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(f, true))), + ImmutableList.of(g), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(f), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(g, DESC_NULLS_FIRST)))))))), + ImmutableList.of( + ImmutableList.of("input_1", "input_2"), + ImmutableList.of("input_3", "input_4"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + 
.passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"), ImmutableList.of("f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"), ImmutableList.of("g"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3", + "f", "marker_4", + "g", "marker_4")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2", "combined_partition_column_3_4"), ImmutableList.of("combined_row_number_1_2_3_4"), ImmutableMap.of("combined_row_number_1_2_3_4", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3_4, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3_4, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3_4, input_3_row_number, null)"), + "marker_4", expression("IF(input_4_row_number = combined_row_number_1_2_3_4, input_4_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3_4", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(combined_row_number_3_4, BIGINT '-1'), combined_row_number_1_2, combined_row_number_3_4)"), + "combined_partition_size_1_2_3_4", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(combined_partition_size_3_4, BIGINT '-1'), combined_partition_size_1_2, combined_partition_size_3_4)")), + join(// join nodes using helper symbols + JoinType.LEFT, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = combined_row_number_3_4 OR " + + "(combined_row_number_1_2 > combined_partition_size_3_4 AND combined_row_number_3_4 = BIGINT '1' OR " + + "combined_row_number_3_4 > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper and partitioning 
symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))))), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_3_4", expression("IF(COALESCE(input_3_row_number, BIGINT '-1') > 
COALESCE(input_4_row_number, BIGINT '-1'), input_3_row_number, input_4_row_number)"), + "combined_partition_size_3_4", expression("IF(COALESCE(input_3_partition_size, BIGINT '-1') > COALESCE(input_4_partition_size, BIGINT '-1'), input_3_partition_size, input_4_partition_size)"), + "combined_partition_column_3_4", expression("COALESCE(e, f)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (e IS DISTINCT FROM f) " + + "AND ( " + + "input_3_row_number = input_4_row_number OR " + + "(input_3_row_number > input_4_partition_size AND input_4_row_number = BIGINT '1' OR " + + "input_4_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e"))), + window(// append helper symbols for source input_4 + builder -> builder + .specification(specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST))) + .addFunction("input_4_partition_size", functionCall("count", ImmutableList.of())), + // input_4 + window(builder -> builder + .specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST)) + .addFunction("input_4_row_number", functionCall("row_number", ImmutableList.of())), + values("f", "g")))))))))); + } + + @Test + public void testCoPartitionedAndNotCoPartitionedSources() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = 
p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_2_3", "c"), ImmutableList.of("combined_row_number_2_3_1"), ImmutableMap.of("combined_row_number_2_3_1", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_2_3_1, input_1_row_number, null)"), + 
"marker_2", expression("IF(input_2_row_number = combined_row_number_2_3_1, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_2_3_1, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_2_3_1", expression("IF(COALESCE(combined_row_number_2_3, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), combined_row_number_2_3, input_1_row_number)"), + "combined_partition_size_2_3_1", expression("IF(COALESCE(combined_partition_size_2_3, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), combined_partition_size_2_3, input_1_partition_size)")), + join(// join nodes using helper symbols + JoinType.INNER, + ImmutableList.of(), + Optional.of("combined_row_number_2_3 = input_1_row_number OR " + + "(combined_row_number_2_3 > input_1_partition_size AND input_1_row_number = BIGINT '1' OR " + + "input_1_row_number > combined_partition_size_2_3 AND combined_row_number_2_3 = BIGINT '1')"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_2_3", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), input_2_row_number, input_3_row_number)"), + "combined_partition_size_2_3", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), input_2_partition_size, input_3_partition_size)"), + "combined_partition_column_2_3", expression("COALESCE(d, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (d IS DISTINCT FROM e) " + + "AND ( " + + "input_2_row_number = input_3_row_number OR " + + "(input_2_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + 
.specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e"))))), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c")))))))); + } + + @Test + public void testCoerceForCopartitioning() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c", TINYINT); + VariableReferenceExpression cCoerced = p.variable("c_coerced", INTEGER); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e", INTEGER); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, 
b), + ImmutableList.of( + // coerce column c for co-partitioning + p.project( + Assignments.builder() + .put(c, p.rowExpression("c")) + .put(d, p.rowExpression("d")) + .put(cCoerced, p.rowExpression("CAST(c AS INTEGER)")) + .build(), + p.values(c, d)), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(cCoerced), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "c_coerced", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + 
"combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c_coerced, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c_coerced IS DISTINCT FROM e) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + project( + ImmutableMap.of("c_coerced", expression("CAST(c AS INTEGER)")), + values("c", "d")))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST)) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testTwoCoPartitioningColumns() + { + tester().assertThat(new 
TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c, d), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e, f), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column_1", "combined_partition_column_2"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + 
"marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1", expression("COALESCE(c, e)"), + "combined_partition_column_2", expression("COALESCE(d, f)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND NOT (d IS DISTINCT FROM f) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", 
functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testTwoSourcesWithRowAndSetSemantics() + { + tester().assertThat(new TransformTableFunctionToTableFunctionProcessor(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + true, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(e), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", 
expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java 
b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index ba17757136a3e..bd8f1e9f113c8 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -14,19 +14,23 @@ package com.facebook.presto.sql.planner.iterative.rule.test; import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.IndexHandle; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.connector.RowChangeParadigm; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; import com.facebook.presto.spi.plan.AggregationNode.Step; @@ -38,6 +42,7 @@ import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.IndexJoinNode; import com.facebook.presto.spi.plan.IndexSourceNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinDistributionType; @@ -45,6 +50,7 @@ import com.facebook.presto.spi.plan.JoinType; import 
com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; +import com.facebook.presto.spi.plan.MaterializedViewScanNode; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.OutputNode; @@ -60,8 +66,11 @@ import com.facebook.presto.spi.plan.TableFinishNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.plan.TableWriterNode; +import com.facebook.presto.spi.plan.TableWriterNode.MergeParadigmAndTypes; +import com.facebook.presto.spi.plan.TableWriterNode.MergeTarget; import com.facebook.presto.spi.plan.TopNNode; import com.facebook.presto.spi.plan.UnionNode; +import com.facebook.presto.spi.plan.UnnestNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.CallExpression; @@ -79,13 +88,14 @@ import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.GroupIdNode; -import com.facebook.presto.sql.planner.plan.IndexJoinNode; import com.facebook.presto.sql.planner.plan.LateralJoinNode; +import com.facebook.presto.sql.planner.plan.MergeWriterNode; import com.facebook.presto.sql.planner.plan.OffsetNode; import com.facebook.presto.sql.planner.plan.RemoteSourceNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.SampleNode; -import com.facebook.presto.sql.planner.plan.UnnestNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator; import com.facebook.presto.sql.tree.Expression; @@ -111,6 +121,7 @@ import static 
com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.spi.plan.ExchangeEncoding.COLUMNAR; @@ -299,6 +310,11 @@ public ProjectNode project(PlanNode source, Assignments assignments) return new ProjectNode(idAllocator.getNextId(), source, assignments); } + public ProjectNode project(PlanNode source, Assignments assignments, ProjectNode.Locality locality) + { + return new ProjectNode(Optional.empty(), idAllocator.getNextId(), source, assignments, locality); + } + public ProjectNode project(Assignments assignments, PlanNode source) { return new ProjectNode(idAllocator.getNextId(), source, assignments); @@ -590,7 +606,7 @@ public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode del deleteSource.getSourceLocation(), idAllocator.getNextId(), deleteSource, - deleteRowId, + Optional.of(deleteRowId), ImmutableList.of(deleteRowId), Optional.empty())) .addInputsSet(deleteRowId) @@ -601,6 +617,34 @@ public TableFinishNode tableDelete(SchemaTableName schemaTableName, PlanNode del Optional.empty(), Optional.empty()); } + public MergeWriterNode merge( + SchemaTableName schemaTableName, + PlanNode mergeSource, + List inputSymbols, + List outputSymbols) + { + return new MergeWriterNode( + mergeSource.getSourceLocation(), + idAllocator.getNextId(), + mergeSource, + mergeTarget(schemaTableName), + inputSymbols, + outputSymbols); + } + + private MergeTarget mergeTarget(SchemaTableName schemaTableName) + { + return new MergeTarget( + new TableHandle( + new ConnectorId("testConnector"), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()), + Optional.empty(), + schemaTableName, + new 
MergeParadigmAndTypes(RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW, ImmutableList.of(), INTEGER)); + } + public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child) { return exchange(builder -> builder.type(ExchangeNode.Type.GATHER) @@ -845,7 +889,16 @@ public JoinNode join( return new JoinNode(Optional.empty(), idAllocator.getNextId(), type, left, right, criteria, outputVariables, filter, leftHashVariable, rightHashVariable, distributionType, dynamicFilters); } - public PlanNode indexJoin(JoinType type, TableScanNode probe, TableScanNode index) + public PlanNode indexJoin(JoinType type, PlanNode probe, PlanNode index) + { + return indexJoin(type, probe, index, emptyList(), Optional.empty()); + } + + public PlanNode indexJoin(JoinType type, + PlanNode probe, + PlanNode index, + List criteria, + Optional filter) { return new IndexJoinNode( Optional.empty(), @@ -853,10 +906,11 @@ public PlanNode indexJoin(JoinType type, TableScanNode probe, TableScanNode inde type, probe, index, - emptyList(), + criteria, + filter, Optional.empty(), Optional.empty(), - Optional.empty()); + index.getOutputVariables()); } public CteProducerNode cteProducerNode(String ctename, @@ -961,6 +1015,32 @@ public WindowNode window(DataOrganizationSpecification specification, Map properOutputs, + List sources, + List tableArgumentProperties, + List> copartitioningLists) + + { + return new TableFunctionNode( + idAllocator.getNextId(), + name, + ImmutableMap.of(), + properOutputs, + sources, + tableArgumentProperties, + copartitioningLists, + new TableFunctionHandle(new ConnectorId("connector_id"), new ConnectorTableFunctionHandle() {}, TestingTransactionHandle.create())); + } + + public TableFunctionProcessorNode tableFunctionProcessor(Consumer consumer) + { + TableFunctionProcessorBuilder tableFunctionProcessorBuilder = new TableFunctionProcessorBuilder(); + consumer.accept(tableFunctionProcessorBuilder); + return tableFunctionProcessorBuilder.build(idAllocator); + } + public 
RowNumberNode rowNumber(List partitionBy, Optional maxRowCountPerPartition, VariableReferenceExpression rowNumberVariable, PlanNode source) { return new RowNumberNode( @@ -1078,4 +1158,23 @@ public GroupIdNode groupId(List> groupingSets, aggregationArguments, groupIdSymbol); } + + public MaterializedViewScanNode materializedViewScan( + QualifiedObjectName materializedViewName, + PlanNode dataTablePlan, + PlanNode viewQueryPlan, + Map dataTableMappings, + Map viewQueryMappings, + VariableReferenceExpression... outputVariables) + { + return new MaterializedViewScanNode( + Optional.empty(), + idAllocator.getNextId(), + dataTablePlan, + viewQueryPlan, + materializedViewName, + dataTableMappings, + viewQueryMappings, + ImmutableList.copyOf(outputVariables)); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java new file mode 100644 index 0000000000000..404831b10f0ef --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java @@ -0,0 +1,140 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule.test; + +import com.facebook.presto.metadata.TableFunctionHandle; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.testing.TestingTransactionHandle; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +public class TableFunctionProcessorBuilder +{ + private String name; + private List properOutputs = ImmutableList.of(); + private Optional source = Optional.empty(); + private boolean pruneWhenEmpty; + private List passThroughSpecifications = ImmutableList.of(); + private List> requiredSymbols = ImmutableList.of(); + private Optional> markerSymbols = Optional.empty(); + private Optional specification = Optional.empty(); + private Set prePartitioned = ImmutableSet.of(); + private int preSorted; + private Optional hashSymbol = Optional.empty(); + private ConnectorTableFunctionHandle connectorHandle = new ConnectorTableFunctionHandle() {}; + + public TableFunctionProcessorBuilder() {} + + public TableFunctionProcessorBuilder name(String name) + { + this.name = name; + return this; + } + + public TableFunctionProcessorBuilder properOutputs(VariableReferenceExpression... 
properOutputs) + { + this.properOutputs = ImmutableList.copyOf(properOutputs); + return this; + } + + public TableFunctionProcessorBuilder source(PlanNode source) + { + this.source = Optional.of(source); + return this; + } + + public TableFunctionProcessorBuilder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public TableFunctionProcessorBuilder passThroughSpecifications(PassThroughSpecification... passThroughSpecifications) + { + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + return this; + } + + public TableFunctionProcessorBuilder requiredSymbols(List> requiredSymbols) + { + this.requiredSymbols = requiredSymbols; + return this; + } + + public TableFunctionProcessorBuilder markerSymbols(Map markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public TableFunctionProcessorBuilder specification(DataOrganizationSpecification specification) + { + this.specification = Optional.of(specification); + return this; + } + + public TableFunctionProcessorBuilder prePartitioned(Set prePartitioned) + { + this.prePartitioned = prePartitioned; + return this; + } + + public TableFunctionProcessorBuilder preSorted(int preSorted) + { + this.preSorted = preSorted; + return this; + } + + public TableFunctionProcessorBuilder hashSymbol(VariableReferenceExpression hashSymbol) + { + this.hashSymbol = Optional.of(hashSymbol); + return this; + } + + public TableFunctionProcessorBuilder connectorHandle(ConnectorTableFunctionHandle connectorHandle) + { + this.connectorHandle = connectorHandle; + return this; + } + + public TableFunctionProcessorNode build(PlanNodeIdAllocator idAllocator) + { + return new TableFunctionProcessorNode( + idAllocator.getNextId(), + name, + properOutputs, + source, + pruneWhenEmpty, + passThroughSpecifications, + requiredSymbols, + markerSymbols, + specification, + prePartitioned, + preSorted, + hashSymbol, + new TableFunctionHandle(new 
ConnectorId("connector_id"), connectorHandle, TestingTransactionHandle.create())); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlans.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlans.java index aca3982dfad43..edb1b3bdcca3f 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlans.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlans.java @@ -38,6 +38,7 @@ import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static com.facebook.presto.SystemSessionProperties.PARTITIONING_PRECISION_STRATEGY; import static com.facebook.presto.SystemSessionProperties.SIMPLIFY_PLAN_WITH_EMPTY_INPUT; +import static com.facebook.presto.SystemSessionProperties.TABLE_SCAN_SHUFFLE_STRATEGY; import static com.facebook.presto.SystemSessionProperties.TASK_CONCURRENCY; import static com.facebook.presto.SystemSessionProperties.USE_STREAMING_EXCHANGE_FOR_MARK_DISTINCT; import static com.facebook.presto.execution.QueryManagerConfig.ExchangeMaterializationStrategy.ALL; @@ -383,8 +384,8 @@ public void testJoinExactlyPartitioned() " AND orders.orderstatus = t.orderstatus", anyTree( join(INNER, ImmutableList.of( - equiJoinClause("ORDERKEY_LEFT", "ORDERKEY_RIGHT"), - equiJoinClause("orderstatus", "ORDERSTATUS_RIGHT")), + equiJoinClause("ORDERKEY_LEFT", "ORDERKEY_RIGHT"), + equiJoinClause("orderstatus", "ORDERSTATUS_RIGHT")), exchange(REMOTE_STREAMING, REPARTITION, anyTree( aggregation( @@ -522,4 +523,22 @@ void assertExactDistributedPlan(String sql, PlanMatchPattern pattern) .build(), pattern); } + + @Test + public void testShuffleAboveTableScanAlwaysEnabled() + { + Session session = TestingSession.testSessionBuilder().setCatalog("local").setSchema("tiny").setSystemProperty(TABLE_SCAN_SHUFFLE_STRATEGY, "ALWAYS_ENABLED").build(); + + // 
When ALWAYS_ENABLED, a round robin exchange should be added above the table scan + assertDistributedPlan("SELECT nationkey FROM nation", session, anyTree(exchange(REMOTE_STREAMING, ExchangeNode.Type.GATHER, exchange(REMOTE_STREAMING, ExchangeNode.Type.REPARTITION, tableScan("nation"))))); + } + + @Test + public void testShuffleAboveTableScanDisabled() + { + Session session = TestingSession.testSessionBuilder().setCatalog("local").setSchema("tiny").setSystemProperty(TABLE_SCAN_SHUFFLE_STRATEGY, "DISABLED").build(); + + // When DISABLED, no extra round robin exchange should be added + assertDistributedPlan("SELECT nationkey FROM nation", session, anyTree(exchange(REMOTE_STREAMING, ExchangeNode.Type.GATHER, tableScan("nation")))); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java new file mode 100644 index 0000000000000..c8c59bd9ebc28 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java @@ -0,0 +1,894 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; +import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor; +import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors; +import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager; +import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.operator.scalar.CombineHashFunction; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.function.FunctionImplementationType; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerContext; +import com.facebook.presto.spi.sql.planner.ExpressionOptimizerFactory; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.facebook.presto.type.BigintOperators; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static 
com.facebook.presto.SystemSessionProperties.REMOTE_FUNCTIONS_ENABLED; +import static com.facebook.presto.SystemSessionProperties.REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM; +import static com.facebook.presto.SystemSessionProperties.SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.operator.scalar.annotations.ScalarFromAnnotationsParser.parseFunctionDefinitions; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.JAVA; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.PYTHON; +import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; +import static com.facebook.presto.spi.plan.JoinType.INNER; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static 
com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +/** + * These are plan tests similar to what we have for other optimizers (e.g. {@link com.facebook.presto.sql.planner.TestPredicatePushdown}) + * They test that the plan for a query after the optimizer runs is as expected. + * These are separate from {@link TestAddExchanges} because those are unit tests for + * how layouts get chosen. + *

    + * Key behavior tested: When CPP functions are used with system tables, the filter containing + * the CPP function is preserved above the exchange (not pushed down) to ensure the filter + * executes in a different fragment from the system table scan. This validates the fragment + * boundary between CPP function evaluation and system table access. + */ +public class TestAddExchangesPlansWithFunctions + extends BasePlanTest +{ + private static final String NO_OP_OPTIMIZER = "no-op-optimizer"; + + public TestAddExchangesPlansWithFunctions() + { + super(TestAddExchangesPlansWithFunctions::createTestQueryRunner); + } + + private static final SqlInvokedFunction CPP_FOO = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "cpp_foo"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "cpp_foo(x)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction CPP_BAZ = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "cpp_baz"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "cpp_baz(x)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction JAVA_BAR = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "java_bar"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "java_bar(x)", + RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction 
JAVA_FEE = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "java_fee"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "java_fee(x)", + RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction NOT = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "not"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BOOLEAN))), + parseTypeSignature(StandardTypes.BOOLEAN), + "not(x)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction CPP_ARRAY_CONSTRUCTOR = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "array_constructor"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT)), new Parameter("y", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature("array(bigint)"), + "array_constructor(x, y)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + // External/Remote functions using PYTHON language mapped to THRIFT implementation type for testing REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM + private static final SqlInvokedFunction REMOTE_FOO = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "remote_foo"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "remote_foo(x)", + RoutineCharacteristics.builder().setLanguage(PYTHON).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final 
SqlInvokedFunction REMOTE_BAR = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "remote_bar"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "remote_bar(x)", + RoutineCharacteristics.builder().setLanguage(PYTHON).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction REMOTE_BAZ = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "remote_baz"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "remote_baz(x)", + RoutineCharacteristics.builder().setLanguage(PYTHON).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static LocalQueryRunner createTestQueryRunner() + { + LocalQueryRunner queryRunner = new LocalQueryRunner(testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .setSystemProperty("expression_optimizer_name", NO_OP_OPTIMIZER) + .build(), + new FeaturesConfig(), + new FunctionsConfig().setDefaultNamespacePrefix("dummy.unittest")); + queryRunner.createCatalog("tpch", new TpchConnectorFactory(), ImmutableMap.of()); + queryRunner.getMetadata().getFunctionAndTypeManager().addFunctionNamespace( + "dummy", + new InMemoryFunctionNamespaceManager( + "dummy", + new SqlFunctionExecutors( + ImmutableMap.of( + CPP, FunctionImplementationType.CPP, + JAVA, FunctionImplementationType.JAVA, + PYTHON, FunctionImplementationType.THRIFT), + new NoopSqlFunctionExecutor()), + new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("cpp,python"))); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_FOO, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_BAZ, true); + 
queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(JAVA_BAR, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(JAVA_FEE, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(NOT, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_ARRAY_CONSTRUCTOR, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(REMOTE_FOO, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(REMOTE_BAR, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(REMOTE_BAZ, true); + parseFunctionDefinitions(BigintOperators.class).stream() + .map(TestAddExchangesPlansWithFunctions::convertToSqlInvokedFunction) + .forEach(function -> queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(function, true)); + parseFunctionDefinitions(CombineHashFunction.class).stream() + .map(TestAddExchangesPlansWithFunctions::convertToSqlInvokedFunction) + .forEach(function -> queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(function, true)); + queryRunner.getExpressionManager().addExpressionOptimizerFactory(new NoOpExpressionOptimizerFactory()); + queryRunner.getExpressionManager().loadExpressionOptimizerFactory(NO_OP_OPTIMIZER, NO_OP_OPTIMIZER, ImmutableMap.of()); + return queryRunner; + } + + public static SqlInvokedFunction convertToSqlInvokedFunction(SqlScalarFunction scalarFunction) + { + QualifiedObjectName functionName = new QualifiedObjectName("dummy", "unittest", scalarFunction.getSignature().getName().getObjectName()); + TypeSignature returnType = scalarFunction.getSignature().getReturnType(); + RoutineCharacteristics characteristics = RoutineCharacteristics.builder() + .setLanguage(RoutineCharacteristics.Language.JAVA) // Assuming JAVA as the language + .setDeterminism(RoutineCharacteristics.Determinism.DETERMINISTIC) + 
.setNullCallClause(RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT) + .build(); + + // Convert scalar function arguments to SqlInvokedFunction parameters + ImmutableList parameters = scalarFunction.getSignature().getArgumentTypes().stream() + .map(type -> new Parameter(type.toString(), TypeSignature.parseTypeSignature(type.toString()))) + .collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf)); + + // Create the SqlInvokedFunction + return new SqlInvokedFunction( + functionName, + parameters, + returnType, + scalarFunction.getSignature().getName().toString(), // Using the function name as the body for simplicity + characteristics, + "", // Empty description + notVersioned()); + } + + @Test + public void testFilterWithCppFunctionDoesNotGetPushedIntoSystemTableScan() + { + // java_fee and java_bar are java functions, they are both pushed down past the exchange + assertNativeDistributedPlan("SELECT java_fee(ordinal_position) FROM information_schema.columns WHERE java_bar(ordinal_position) = 1", + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project(ImmutableMap.of("java_fee", expression("java_fee(ordinal_position)")), + filter("java_bar(ordinal_position) = BIGINT'1'", + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))))); + // cpp_foo is a CPP function, it is not pushed down past the exchange because the source is a system table scan + // The filter is preserved above the exchange to prove that the filter is not in the same fragment as the system table scan + assertNativeDistributedPlan("SELECT cpp_baz(ordinal_position) FROM information_schema.columns WHERE cpp_foo(ordinal_position) = 1", + anyTree( + project(ImmutableMap.of("cpp_baz", expression("cpp_baz(ordinal_position)")), + filter("cpp_foo(ordinal_position) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))))); + } + + @Test + public void 
testJoinWithCppFunctionDoesNotGetPushedIntoSystemTableScan() + { + // java_bar is a java function, it is pushed down past the exchange + assertNativeDistributedPlan( + "SELECT c1.table_name FROM information_schema.columns c1 JOIN information_schema.columns c2 ON c1.ordinal_position = c2.ordinal_position WHERE java_bar(c1.ordinal_position) = 1", + anyTree( + exchange( + join(INNER, ImmutableList.of(equiJoinClause("ordinal_position", "ordinal_position_4")), + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project( + filter("java_bar(ordinal_position) = BIGINT'1'", + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position", + "table_name", "table_name")))))), + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project( + filter("java_bar(ordinal_position_4) = BIGINT'1'", + tableScan("columns", ImmutableMap.of( + "ordinal_position_4", "ordinal_position")))))))))); + + // cpp_foo is a CPP function, it is not pushed down past the exchange because the source is a system table scan + assertNativeDistributedPlan( + "SELECT cpp_baz(c1.ordinal_position) FROM information_schema.columns c1 JOIN information_schema.columns c2 ON c1.ordinal_position = c2.ordinal_position WHERE cpp_foo(c1.ordinal_position) = 1", + output( + exchange( + project(ImmutableMap.of("cpp_baz", expression("cpp_baz(ordinal_position)")), + join(INNER, ImmutableList.of(equiJoinClause("ordinal_position", "ordinal_position_4")), + anyTree( + filter("cpp_foo(ordinal_position) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + project( + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position")))))), + anyTree( + filter("cpp_foo(ordinal_position_4) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + project( + tableScan("columns", ImmutableMap.of( + "ordinal_position_4", "ordinal_position"))))))))))); + } + + @Test + public void testMixedFunctionTypesInComplexPredicates() + { + // Test AND condition with mixed Java and CPP functions + assertNativeDistributedPlan( + 
"SELECT * FROM information_schema.columns WHERE java_bar(ordinal_position) = 1 AND cpp_foo(ordinal_position) > 0", + anyTree( + filter("java_bar(ordinal_position) = BIGINT'1' AND cpp_foo(ordinal_position) > BIGINT'0'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // Test OR condition with mixed functions - entire predicate should be evaluated after exchange + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE java_bar(ordinal_position) = 1 OR cpp_foo(ordinal_position) = 2", + anyTree( + filter("java_bar(ordinal_position) = BIGINT'1' OR cpp_foo(ordinal_position) = BIGINT'2'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testNestedFunctionCalls() + { + // CPP function nested inside Java function - should not push down + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE java_bar(cpp_foo(ordinal_position)) = 1", + anyTree( + filter("java_bar(cpp_foo(ordinal_position)) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // Java function nested inside CPP function - should not push down + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(java_bar(ordinal_position)) = 1", + anyTree( + filter("cpp_foo(java_bar(ordinal_position)) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // Multiple levels of nesting + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(java_bar(cpp_foo(ordinal_position))) = 1", + anyTree( + filter("cpp_foo(java_bar(cpp_foo(ordinal_position))) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", 
"ordinal_position")))))); + } + + @Test + public void testMixedSystemAndRegularTables() + { + // System table with CPP function joined with regular table + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns c JOIN nation n ON c.ordinal_position = n.nationkey WHERE cpp_foo(c.ordinal_position) = 1", + output( + join(INNER, ImmutableList.of(equiJoinClause("ordinal_position", "nationkey")), + filter("cpp_foo(ordinal_position) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + project(ImmutableMap.of("ordinal_position", expression("ordinal_position")), + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))), + anyTree( + project(ImmutableMap.of("nationkey", expression("nationkey")), + filter( + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))))); + + // Regular table with CPP function (should work normally without extra exchange) + assertNativeDistributedPlan( + "SELECT * FROM nation WHERE cpp_foo(nationkey) = 1", + anyTree( + exchange(REMOTE_STREAMING, GATHER, + filter("cpp_foo(nationkey) = BIGINT'1'", + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))); + } + + @Test + public void testAggregationsWithMixedFunctions() + { + // Aggregation with CPP function in GROUP BY + assertNativeDistributedPlan( + "SELECT DISTINCT cpp_foo(ordinal_position) FROM information_schema.columns", + anyTree( + project(ImmutableMap.of("cpp_foo", expression("cpp_foo(ordinal_position)")), + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // Aggregation with Java function in GROUP BY - can be pushed down + assertNativeDistributedPlan( + "SELECT DISTINCT java_bar(ordinal_position) FROM information_schema.columns", + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project(ImmutableMap.of("java_bar", expression("java_bar(ordinal_position)")), + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test 
+ public void testComplexPredicateWithMultipleFunctions() + { + // Complex predicate with multiple CPP and Java functions + // Since the predicate contains CPP functions (cpp_foo, baz), the exchange is inserted before the system table scan + // The RemoveRedundantExchanges rule removes the inner exchange that was added by ExtractIneligiblePredicatesFromSystemTableScans + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE (cpp_foo(ordinal_position) > 0 AND java_bar(ordinal_position) < 100) OR cpp_baz(ordinal_position) = 50", + anyTree( + filter( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testProjectionWithMixedFunctions() + { + // Projection with both Java and CPP functions + assertNativeDistributedPlan( + "SELECT java_bar(ordinal_position) as java_result, cpp_foo(ordinal_position) as cpp_result FROM information_schema.columns", + anyTree( + project(ImmutableMap.of( + "java_result", expression("java_bar(ordinal_position)"), + "cpp_result", expression("cpp_foo(ordinal_position)")), + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testCaseStatementsWithCppFunctions() + { + // CASE statement with CPP function in condition + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT CASE WHEN cpp_foo(ordinal_position) > 0 THEN 'positive' ELSE 'negative' END FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // CASE statement with CPP function in result + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT CASE WHEN ordinal_position > 0 THEN cpp_foo(ordinal_position) ELSE 0 END FROM 
information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testBuiltinFunctionWithExplicitNamespace() + { + // Test that built-in functions with explicit namespace are handled correctly + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE presto.default.length(table_name) > 10", + anyTree( + exchange(REMOTE_STREAMING, GATHER, + filter("length(table_name) > BIGINT'10'", + tableScan("columns", ImmutableMap.of("table_name", "table_name")))))); + } + + @Test(enabled = false) // TODO: Window functions are resolved with namespace which causes issues in tests + public void testWindowFunctionsWithCppFunctions() + { + // Window function with CPP function in partition by + assertNativeDistributedPlan( + "SELECT row_number() OVER (PARTITION BY cpp_foo(ordinal_position)) FROM information_schema.columns", + anyTree( + exchange( + project( + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))))))); + + // Window function with CPP function in order by + assertNativeDistributedPlan( + "SELECT row_number() OVER (ORDER BY cpp_foo(ordinal_position)) FROM information_schema.columns", + anyTree( + exchange( + project( + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))))))); + } + + @Test + public void testMultipleSystemTableJoins() + { + // Multiple system tables with CPP functions + // This test verifies that when joining two system tables with a CPP function comparison, + // an exchange is added between the table scan and the join to ensure CPP functions + // execute in a separate fragment from system table access + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns c1 " + + "JOIN information_schema.columns c2 ON 
cpp_foo(c1.ordinal_position) = cpp_foo(c2.ordinal_position)", + anyTree( + exchange( + join(INNER, ImmutableList.of(equiJoinClause("cpp_foo", "foo_4")), + exchange( + project(ImmutableMap.of("cpp_foo", expression("cpp_foo")), + project(ImmutableMap.of("cpp_foo", expression("cpp_foo(ordinal_position)")), + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))), + anyTree( + exchange( + project(ImmutableMap.of("foo_4", expression("foo_4")), + project(ImmutableMap.of("foo_4", expression("cpp_foo(ordinal_position_4)")), + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position_4", "ordinal_position"))))))))))); + } + + @Test + public void testInPredicateWithCppFunction() + { + // IN predicate with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) IN (1, 2, 3)", + anyTree( + filter("cpp_foo(ordinal_position) IN (BIGINT'1', BIGINT'2', BIGINT'3')", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testBetweenPredicateWithCppFunction() + { + // BETWEEN predicate with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) BETWEEN 1 AND 10", + anyTree( + filter("cpp_foo(ordinal_position) BETWEEN BIGINT'1' AND BIGINT'10'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testNullHandlingWithCppFunctions() + { + // IS NULL check with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) IS NULL", + anyTree( + filter("cpp_foo(ordinal_position) IS NULL", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // 
COALESCE with CPP function + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT COALESCE(cpp_foo(ordinal_position), 0) FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testUnionWithCppFunctions() + { + // UNION ALL with CPP functions from system tables + assertNativeDistributedPlan( + "SELECT cpp_foo(ordinal_position) FROM information_schema.columns " + + "UNION ALL SELECT cpp_foo(nationkey) FROM nation", + output( + exchange( + anyTree( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))), + anyTree( + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))); + } + + @Test + public void testExistsSubqueryWithCppFunction() + { + // EXISTS subquery with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns c WHERE EXISTS (SELECT 1 FROM nation n WHERE cpp_foo(c.ordinal_position) = n.nationkey)", + anyTree( + join( + anyTree( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))), + anyTree( + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))); + } + + @Test + public void testLimitWithCppFunction() + { + // LIMIT with CPP function in ORDER BY + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns ORDER BY cpp_foo(ordinal_position) LIMIT 10", + output( + project( + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))))); + } + + @Test + public void testCastOperationsWithCppFunctions() + { + // CAST operations with CPP functions + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT 
CAST(cpp_foo(ordinal_position) AS VARCHAR) FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testArrayConstructorWithCppFunction() + { + // Array constructor with CPP function + assertNativeDistributedPlan( + "SELECT ARRAY[cpp_foo(ordinal_position), cpp_baz(ordinal_position)] FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testRowConstructorWithCppFunction() + { + // ROW constructor with CPP function + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT ROW(cpp_foo(ordinal_position), table_name) FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position", + "table_name", "table_name")))))); + } + + @Test + public void testIsNotNullWithCppFunction() + { + // IS NOT NULL check with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) IS NOT NULL", + anyTree( + filter("cpp_foo(ordinal_position) IS NOT NULL", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testComplexJoinWithMultipleCppFunctions() + { + // Complex join with multiple CPP functions in different positions + // The filters are pushed into FilterProject nodes and the join happens on the expression cpp_foo(c1.ordinal_position) + assertNativeDistributedPlan( + "SELECT c1.table_name, n.name FROM information_schema.columns c1 " + + "JOIN nation n ON cpp_foo(c1.ordinal_position) = n.nationkey " + + "WHERE cpp_baz(c1.ordinal_position) > 0 AND cpp_foo(n.nationkey) < 
100", + anyTree( + join(INNER, ImmutableList.of(equiJoinClause("cpp_foo", "nationkey")), + project(ImmutableMap.of("table_name", expression("table_name"), "cpp_foo", expression("cpp_foo")), + project(ImmutableMap.of("cpp_foo", expression("cpp_foo(ordinal_position)")), + filter("cpp_baz(ordinal_position) > BIGINT'0' AND cpp_foo(cpp_foo(ordinal_position)) < BIGINT'100'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position", + "table_name", "table_name")))))), + anyTree( + project(ImmutableMap.of("nationkey", expression("nationkey"), + "name", expression("name")), + filter("cpp_foo(nationkey) < BIGINT'100'", + tableScan("nation", ImmutableMap.of( + "nationkey", "nationkey", + "name", "name")))))))); + } + + @Test + public void testSystemTableFilterWithOutputVariableMismatch() + { + assertNativeDistributedPlan( + "SELECT table_name FROM information_schema.columns WHERE cpp_foo(ordinal_position) > 5", + output( + project(ImmutableMap.of("table_name", expression("table_name")), + filter("cpp_foo(ordinal_position) > BIGINT'5'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position", + "table_name", "table_name"))))))); + } + + @Test + public void testSystemTableFilterWithMultipleColumnsAndPartialSelection() + { + assertNativeDistributedPlan( + "SELECT table_schema, table_name FROM information_schema.columns " + + "WHERE cpp_foo(ordinal_position) > 0 AND cpp_baz(ordinal_position) < 100", + output( + project(ImmutableMap.of("table_schema", expression("table_schema"), + "table_name", expression("table_name")), + filter("cpp_foo(ordinal_position) > BIGINT'0' AND cpp_baz(ordinal_position) < BIGINT'100'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position", + "table_schema", "table_schema", + "table_name", "table_name"))))))); + } + + @Test + public void 
testRemoteFunctionNamesForFixedParallelismWithExactMatch() + { + // Test that REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM with exact function name + // causes round-robin exchanges to be added before and after the remote project + // Note: The function must be an external function (isExternalExecution() = true) for this feature to work + assertNativeDistributedPlanWithSession( + "SELECT remote_foo(nationkey) FROM nation", + testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .setSystemProperty(REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM, "dummy.unittest.remote_foo") + .setSystemProperty(REMOTE_FUNCTIONS_ENABLED, "true") + .setSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, "true") + .build(), + anyTree( + exchange(REMOTE_STREAMING, REPARTITION, + project(ImmutableMap.of("remote_foo", expression("remote_foo(nationkey)")), + exchange(REMOTE_STREAMING, REPARTITION, + tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))))))); + } + + @Test + public void testRemoteFunctionNamesForFixedParallelismWithRegexWildcard() + { + // Test that regex pattern with wildcard matches function names + // remote_foo matches the pattern "remote_.*" + assertNativeDistributedPlanWithSession( + "SELECT remote_foo(nationkey) FROM nation", + testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .setSystemProperty(REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM, "dummy.unittest.remote_.*") + .setSystemProperty(REMOTE_FUNCTIONS_ENABLED, "true") + .setSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, "true") + .build(), + anyTree( + exchange(REMOTE_STREAMING, REPARTITION, + project(ImmutableMap.of("remote_foo", expression("remote_foo(nationkey)")), + exchange(REMOTE_STREAMING, REPARTITION, + tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))))))); + } + + @Test + public void testRemoteFunctionNamesForFixedParallelismWithNonMatchingRegex() + { + // Test that when the regex doesn't match the function name, + // no 
extra round-robin exchanges are added + assertNativeDistributedPlanWithSession( + "SELECT remote_foo(nationkey) FROM nation", + testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .setSystemProperty(REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM, "nonmatching_.*") + .setSystemProperty(REMOTE_FUNCTIONS_ENABLED, "true") + .setSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, "true") + .build(), + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project(ImmutableMap.of("remote_foo", expression("remote_foo(nationkey)")), + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))); + } + + @Test + public void testRemoteFunctionNamesForFixedParallelismWithEmptyString() + { + // Test that empty string means the feature is disabled (no extra exchanges) + assertNativeDistributedPlanWithSession( + "SELECT remote_foo(nationkey) FROM nation", + testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .setSystemProperty(REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM, "") + .setSystemProperty(REMOTE_FUNCTIONS_ENABLED, "true") + .setSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, "true") + .build(), + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project(ImmutableMap.of("remote_foo", expression("remote_foo(nationkey)")), + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))); + } + + @Test + public void testRemoteFunctionNamesForFixedParallelismWithMultipleFunctions() + { + // Test regex that matches multiple function names using OR pattern + // Both remote_foo and remote_baz should trigger the exchange insertion + assertNativeDistributedPlanWithSession( + "SELECT remote_foo(nationkey), remote_baz(nationkey) FROM nation", + testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .setSystemProperty(REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM, "dummy.unittest.remote_(foo|baz)") + .setSystemProperty(REMOTE_FUNCTIONS_ENABLED, "true") + 
.setSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, "true") + .build(), + anyTree( + exchange(REMOTE_STREAMING, REPARTITION, + project(ImmutableMap.of( + "remote_foo", expression("remote_foo(nationkey)"), + "remote_baz", expression("remote_baz(nationkey)")), + exchange(REMOTE_STREAMING, REPARTITION, + tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))))))); + } + + @Test + public void testRemoteFunctionNamesForFixedParallelismWithPartialMatch() + { + // Test that regex requires full match (not partial) by using anchored pattern + // "remote_f" should NOT match "remote_foo" because matches() requires full string match + assertNativeDistributedPlanWithSession( + "SELECT remote_foo(nationkey) FROM nation", + testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .setSystemProperty(REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM, "dummy.unittest.remote_f") + .setSystemProperty(REMOTE_FUNCTIONS_ENABLED, "true") + .setSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, "true") + .build(), + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project(ImmutableMap.of("remote_foo", expression("remote_foo(nationkey)")), + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))); + } + + @Test + public void testRemoteFunctionNamesForFixedParallelismWithComplexRegex() + { + // Test complex regex pattern with character classes + assertNativeDistributedPlanWithSession( + "SELECT remote_foo(nationkey) FROM nation", + testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .setSystemProperty(REMOTE_FUNCTION_NAMES_FOR_FIXED_PARALLELISM, "dummy.unittest.remote_[a-z]+") + .setSystemProperty(REMOTE_FUNCTIONS_ENABLED, "true") + .setSystemProperty(SKIP_PUSHDOWN_THROUGH_EXCHANGE_FOR_REMOTE_PROJECTION, "true") + .build(), + anyTree( + exchange(REMOTE_STREAMING, REPARTITION, + project(ImmutableMap.of("remote_foo", expression("remote_foo(nationkey)")), + exchange(REMOTE_STREAMING, REPARTITION, + tableScan("nation", 
ImmutableMap.of("nationkey", "nationkey"))))))); + } + + private void assertNativeDistributedPlanWithSession(String sql, com.facebook.presto.Session session, PlanMatchPattern pattern) + { + assertDistributedPlan(sql, session, pattern); + } + + private static class NoOpExpressionOptimizerFactory + implements ExpressionOptimizerFactory + { + @Override + public ExpressionOptimizer createOptimizer(Map config, ExpressionOptimizerContext context) + { + return new NoOpExpressionOptimizer(); + } + + @Override + public String getName() + { + return NO_OP_OPTIMIZER; + } + } + + private static class NoOpExpressionOptimizer + implements ExpressionOptimizer + { + @Override + public RowExpression optimize(RowExpression expression, Level level, ConnectorSession session, Function variableResolver) + { + return expression; + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestApproxDistinctOptimizer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestApproxDistinctOptimizer.java new file mode 100644 index 0000000000000..c978a895ada42 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestApproxDistinctOptimizer.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anySymbol; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; + +public class TestApproxDistinctOptimizer + extends BasePlanTest +{ + @Test + public void testReplacesConditionalApproxDistinct() + { + assertPlan("SELECT APPROX_DISTINCT(IF(nationkey = 1, 1)) FROM nation", + output( + project( + ImmutableMap.of("output", expression("coalesce(intermediate, 0)")), + aggregation( + ImmutableMap.of("intermediate", functionCall("arbitrary", ImmutableList.of("partial"))), + AggregationNode.Step.FINAL, + anyTree( + aggregation( + ImmutableMap.of("partial", functionCall("arbitrary", false, ImmutableList.of(anySymbol()))), + AggregationNode.Step.PARTIAL, + anyTree( + tableScan("nation")))))))); + } + + @Test + public void testReplacesConditionalApproxDistinctGrouped() + { + assertPlan("SELECT APPROX_DISTINCT(IF(nationkey = nationkey, 1)) FROM nation group by nationkey", + output( + project( + ImmutableMap.of("output", expression("coalesce(intermediate, 0)")), + aggregation( + ImmutableMap.of("intermediate", functionCall("arbitrary", 
ImmutableList.of("partial"))), + AggregationNode.Step.FINAL, + anyTree( + aggregation( + ImmutableMap.of("partial", functionCall("arbitrary", false, ImmutableList.of(anySymbol()))), + AggregationNode.Step.PARTIAL, + anyTree( + tableScan("nation")))))))); + } + + @Test + public void testDontReplaceConstantApproxDistinct() + { + assertPlan("SELECT APPROX_DISTINCT('constant') FROM nation", + output( + aggregation( + ImmutableMap.of("final", functionCall("approx_distinct", ImmutableList.of("partial"))), + AggregationNode.Step.FINAL, + anyTree( + aggregation( + ImmutableMap.of("partial", functionCall("approx_distinct", false, ImmutableList.of(anySymbol()))), + AggregationNode.Step.PARTIAL, + anyTree( + tableScan("nation"))))))); + } + + @Test + public void testDontReplaceVariableApproxDistinct() + { + assertPlan("SELECT APPROX_DISTINCT(nationkey) FROM nation", + output( + aggregation( + ImmutableMap.of("final", functionCall("approx_distinct", ImmutableList.of("partial"))), + AggregationNode.Step.FINAL, + anyTree( + aggregation( + ImmutableMap.of("partial", functionCall("approx_distinct", ImmutableList.of("nationkey"))), + AggregationNode.Step.PARTIAL, + tableScan("nation", ImmutableMap.of("nationkey", "nationkey"))))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java index 62b96f0973778..1f6a7f46d539a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestConnectorOptimization.java @@ -60,6 +60,7 @@ import java.util.Set; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.SystemSessionProperties.ENABLE_EMPTY_CONNECTOR_OPTIMIZER; import static 
com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; import static com.facebook.presto.expressions.LogicalRowExpressions.and; @@ -223,6 +224,346 @@ public void testAddFilterToTableScan() TypeProvider.viewOf(ImmutableMap.of("a", BIGINT, "b", BIGINT))); } + @Test + public void testEmptyConnectorOptimization() + { + PlanNode plan = output(values("a", "b"), "a"); + ConnectorId emptyConnectorId = new ConnectorId("$internal$ApplyConnectorOptimization_EMPTY_CONNECTOR"); + ConnectorPlanOptimizer emptyConnectorOptimizer = createEmptyConnectorOptimizer(emptyConnectorId); + Session session = Session.builder(TEST_SESSION).setSystemProperty(ENABLE_EMPTY_CONNECTOR_OPTIMIZER, "true").build(); + + PlanNode actual = optimize(plan, ImmutableMap.of( + new ConnectorId("tpch"), ImmutableSet.of(emptyConnectorOptimizer)), session); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.filter( + "true", + PlanMatchPattern.values("a", "b")))); + + plan = output( + union( + values("a", "b"), + values("a", "b"), + values("a", "b")), + "a"); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("tpch"), ImmutableSet.of(emptyConnectorOptimizer)), session); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.filter( + "true", + PlanMatchPattern.union( + PlanMatchPattern.values("a", "b"), + PlanMatchPattern.values("a", "b"), + PlanMatchPattern.values("a", "b"))))); + + plan = output( + union( + values("a", "b"), + tableScan("cat1", "a", "b")), + "a"); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("tpch"), ImmutableSet.of(emptyConnectorOptimizer), + new ConnectorId("cat1"), ImmutableSet.of(noop())), session); + + assertEquals(actual, plan); + + plan = output( + filter(values("a", "b"), TRUE_CONSTANT), + "a"); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("tpch"), ImmutableSet.of(emptyConnectorOptimizer)), session); + + 
assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.filter( + "true", + PlanMatchPattern.filter( + "true", + PlanMatchPattern.values("a", "b"))))); + + plan = output( + union( + filter(values("a", "b"), TRUE_CONSTANT), + union( + values("a", "b"), + filter(values("a", "b"), TRUE_CONSTANT))), + "a"); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("tpch"), ImmutableSet.of(emptyConnectorOptimizer)), session); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.filter( + "true", + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + PlanMatchPattern.values("a", "b")), + PlanMatchPattern.union( + PlanMatchPattern.values("a", "b"), + PlanMatchPattern.filter( + "true", + PlanMatchPattern.values("a", "b"))))))); + + plan = output(tableScan("cat1", "a", "b"), "a"); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("tpch"), ImmutableSet.of(emptyConnectorOptimizer), + new ConnectorId("cat1"), ImmutableSet.of(noop())), session); + + assertEquals(actual, plan); + } + + @Test + public void testMultipleConnectorOptimization() + { + PlanNode plan = output( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b")), + "a"); + + ConnectorPlanOptimizer multiConnectorOptimizer = createMultiConnectorOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); + + PlanNode actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(multiConnectorOptimizer))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b"))))); + + ConnectorPlanOptimizer crossConnectorUnionOptimizer = createCrossConnectorUnionOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); + + actual = optimize(plan, ImmutableMap.of( + new 
ConnectorId("cat1"), ImmutableSet.of(crossConnectorUnionOptimizer))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b"))))); + + plan = output( + union( + filter(tableScan("cat1", "a", "b"), TRUE_CONSTANT), + filter(tableScan("cat2", "a", "b"), TRUE_CONSTANT), + filter(tableScan("cat3", "a", "b"), TRUE_CONSTANT)), + "a"); + + ConnectorPlanOptimizer multiConnectorOptimizer12 = createCrossConnectorUnionOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(multiConnectorOptimizer12), + new ConnectorId("cat3"), ImmutableSet.of(filterPushdown()))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b")), + SimpleTableScanMatcher.tableScan("cat3", TRUE_CONSTANT)))); + + plan = output( + union( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b")), + filter(tableScan("cat1", "a", "b"), TRUE_CONSTANT)), + "a"); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(multiConnectorOptimizer12, filterPushdown()))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.union( + SimpleTableScanMatcher.tableScan("cat1", TRUE_CONSTANT), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b"))), + SimpleTableScanMatcher.tableScan("cat1", TRUE_CONSTANT)))); + + plan = output( + union( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b")), // This union only contains supported connectors + 
tableScan("cat4", "a", "b")), // cat4 in separate part of plan + "a"); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(crossConnectorUnionOptimizer))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b"))), + SimpleTableScanMatcher.tableScan("cat4", "a", "b")))); + + plan = output( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b"), + tableScan("cat3", "a", "b")), + "a"); + + ConnectorPlanOptimizer singleConnectorOptimizer1 = addFilterToTableScan(TRUE_CONSTANT); + ConnectorPlanOptimizer singleConnectorOptimizer2 = noop(); + ConnectorPlanOptimizer multiConnectorOptimizer13 = createMultiConnectorOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat3"))); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(singleConnectorOptimizer1, multiConnectorOptimizer13), + new ConnectorId("cat2"), ImmutableSet.of(singleConnectorOptimizer2), + new ConnectorId("cat3"), ImmutableSet.of(singleConnectorOptimizer1))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + SimpleTableScanMatcher.tableScan("cat2", "a", "b"), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat3", "a", "b"))))); + + plan = output( + union( + union( + tableScan("cat1", "a", "b"), + tableScan("cat3", "a", "b")), // This inner union has exactly cat1, cat3 + union( + tableScan("cat2", "a", "b"), + tableScan("cat4", "a", "b"))), // This inner union has cat2, cat4 + "a"); + + ConnectorPlanOptimizer exactMatchOptimizer = createCrossConnectorUnionOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new 
ConnectorId("cat3"))); // Only supports cat1 and cat3 + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(exactMatchOptimizer), + new ConnectorId("cat2"), ImmutableSet.of(filterPushdown()), + new ConnectorId("cat4"), ImmutableSet.of(filterPushdown()))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat3", "a", "b"))), + PlanMatchPattern.union( + SimpleTableScanMatcher.tableScan("cat2", "a", "b"), + SimpleTableScanMatcher.tableScan("cat4", "a", "b"))))); + + plan = output( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b"), + tableScan("cat3", "a", "b")), + "a"); + + ConnectorPlanOptimizer partialCoverageOptimizer = createMultiConnectorOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); // Only supports cat1, cat2 + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(partialCoverageOptimizer), + new ConnectorId("cat3"), ImmutableSet.of(filterPushdown()))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + SimpleTableScanMatcher.tableScan("cat1", "a", "b"), + SimpleTableScanMatcher.tableScan("cat2", "a", "b"), + SimpleTableScanMatcher.tableScan("cat3", "a", "b")))); + + plan = output( + union( + union( + tableScan("cat1", "a", "b"), + tableScan("cat2", "a", "b")), // This inner union has exactly cat1, cat2 + union( + tableScan("cat2", "a", "b"), + tableScan("cat3", "a", "b"))), // This inner union has exactly cat2, cat3 + "a"); + + ConnectorPlanOptimizer multiConnectorOptimizer12v2 = createMultiConnectorOptimizer( + ImmutableList.of(new ConnectorId("cat1"), new ConnectorId("cat2"))); + ConnectorPlanOptimizer multiConnectorOptimizer23 = createCrossConnectorUnionOptimizer( + 
ImmutableList.of(new ConnectorId("cat2"), new ConnectorId("cat3"))); + + actual = optimize(plan, ImmutableMap.of( + new ConnectorId("cat1"), ImmutableSet.of(multiConnectorOptimizer12v2), + new ConnectorId("cat2"), ImmutableSet.of(multiConnectorOptimizer23), + new ConnectorId("cat3"), ImmutableSet.of(noop()))); + + assertPlanMatch( + actual, + PlanMatchPattern.output( + PlanMatchPattern.union( + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat1", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b"))), + PlanMatchPattern.union( + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat2", "a", "b")), + PlanMatchPattern.filter( + "true", + SimpleTableScanMatcher.tableScan("cat3", "a", "b")))))); + } + private TableScanNode tableScan(String connectorName, String... columnNames) { return PLAN_BUILDER.tableScan( @@ -288,6 +629,12 @@ private static PlanNode optimize(PlanNode plan, Map> optimizers, Session session) + { + ApplyConnectorOptimization optimizer = new ApplyConnectorOptimization(() -> optimizers); + return optimizer.optimize(plan, session, TypeProvider.empty(), new VariableAllocator(), new PlanNodeIdAllocator(), WarningCollector.NOOP).getPlanNode(); + } + private static ConnectorPlanOptimizer filterPushdown() { return (maxSubplan, session, variableAllocator, idAllocator) -> maxSubplan.accept(new TestFilterPushdownVisitor(), null); @@ -303,6 +650,60 @@ private static ConnectorPlanOptimizer noop() return (maxSubplan, session, variableAllocator, idAllocator) -> maxSubplan; } + private static ConnectorPlanOptimizer createMultiConnectorOptimizer(java.util.List supportedConnectors) + { + return new ConnectorPlanOptimizer() + { + @Override + public PlanNode optimize(PlanNode maxSubplan, com.facebook.presto.spi.ConnectorSession session, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator) + { + return maxSubplan.accept(new 
TestMultiConnectorOptimizationVisitor(supportedConnectors, idAllocator), null); + } + + @Override + public java.util.List getSupportedConnectorIds() + { + return supportedConnectors; + } + }; + } + + private static ConnectorPlanOptimizer createCrossConnectorUnionOptimizer(java.util.List supportedConnectors) + { + return new ConnectorPlanOptimizer() + { + @Override + public PlanNode optimize(PlanNode maxSubplan, com.facebook.presto.spi.ConnectorSession session, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator) + { + return maxSubplan.accept(new TestCrossConnectorUnionVisitor(supportedConnectors, idAllocator), null); + } + + @Override + public java.util.List getSupportedConnectorIds() + { + return supportedConnectors; + } + }; + } + + private static ConnectorPlanOptimizer createEmptyConnectorOptimizer(ConnectorId emptyConnectorId) + { + return new ConnectorPlanOptimizer() + { + @Override + public PlanNode optimize(PlanNode maxSubplan, com.facebook.presto.spi.ConnectorSession session, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator) + { + return new FilterNode(Optional.empty(), idAllocator.getNextId(), maxSubplan, TRUE_CONSTANT); + } + + @Override + public java.util.List getSupportedConnectorIds() + { + return ImmutableList.of(emptyConnectorId); + } + }; + } + private static class TestPlanOptimizationVisitor extends PlanVisitor { @@ -404,6 +805,88 @@ public PlanNode visitTableScan(TableScanNode node, Void context) } } + /** + * Multi-connector visitor that adds filters to table scans from supported connectors + */ + private static class TestMultiConnectorOptimizationVisitor + extends TestPlanOptimizationVisitor + { + private final java.util.List supportedConnectors; + private final PlanNodeIdAllocator idAllocator; + + TestMultiConnectorOptimizationVisitor(java.util.List supportedConnectors, PlanNodeIdAllocator idAllocator) + { + this.supportedConnectors = supportedConnectors; + this.idAllocator = idAllocator; + } + + @Override 
+ public PlanNode visitTableScan(TableScanNode node, Void context) + { + if (supportedConnectors.contains(node.getTable().getConnectorId())) { + return new FilterNode(Optional.empty(), idAllocator.getNextId(), node, TRUE_CONSTANT); + } + return node; + } + } + + /** + * Multi-connector visitor that optimizes unions across different connectors + */ + private static class TestCrossConnectorUnionVisitor + extends TestPlanOptimizationVisitor + { + private final java.util.List supportedConnectors; + private final PlanNodeIdAllocator idAllocator; + + TestCrossConnectorUnionVisitor(java.util.List supportedConnectors, PlanNodeIdAllocator idAllocator) + { + this.supportedConnectors = supportedConnectors; + this.idAllocator = idAllocator; + } + + @Override + public PlanNode visitUnion(UnionNode node, Void context) + { + Set foundConnectors = new java.util.HashSet<>(); + boolean hasMultipleConnectors = false; + + for (PlanNode source : node.getSources()) { + if (source instanceof TableScanNode) { + ConnectorId connectorId = ((TableScanNode) source).getTable().getConnectorId(); + if (supportedConnectors.contains(connectorId)) { + foundConnectors.add(connectorId); + if (foundConnectors.size() > 1) { + hasMultipleConnectors = true; + break; + } + } + } + } + + if (hasMultipleConnectors) { + ImmutableList.Builder newSources = ImmutableList.builder(); + for (PlanNode source : node.getSources()) { + if (source instanceof TableScanNode) { + TableScanNode tableScan = (TableScanNode) source; + if (supportedConnectors.contains(tableScan.getTable().getConnectorId())) { + newSources.add(new FilterNode(Optional.empty(), idAllocator.getNextId(), tableScan, TRUE_CONSTANT)); + } + else { + newSources.add(source); + } + } + else { + newSources.add(source.accept(this, context)); + } + } + return node.replaceChildren(newSources.build()); + } + + return super.visitUnion(node, context); + } + } + /** * A simplified table scan matcher for multiple-connector support. 
* The goal is to test plan structural matching rather than table scan details diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java index 82e180b0d78e3..0b639975c4815 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestFullOuterJoinWithCoalesce.java @@ -91,7 +91,7 @@ public void testDuplicateJoinClause() ImmutableMap.of("ts", expression("coalesce(t, s)")), join( FULL, - ImmutableList.of(equiJoinClause("t", "s"), equiJoinClause("t", "s")), + ImmutableList.of(equiJoinClause("t", "s")), exchange(REMOTE_STREAMING, REPARTITION, anyTree(values(ImmutableList.of("t")))), exchange(LOCAL, GATHER, anyTree(values(ImmutableList.of("s"))))))), exchange(LOCAL, GATHER, anyTree(values(ImmutableList.of("r")))))))); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLocalProperties.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLocalProperties.java index db0600d2d6380..edbeeb3e0cdc2 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLocalProperties.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestLocalProperties.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.GroupingProperty; import com.facebook.presto.spi.LocalProperty; import com.facebook.presto.spi.SortingProperty; +import com.facebook.presto.spi.UniqueProperty; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.testing.TestingMetadata.TestingColumnHandle; import com.fasterxml.jackson.core.JsonParser; @@ -78,6 +79,10 @@ public void testConstantProcessing() input = 
ImmutableList.of(constant("a"), constant("b")); assertEquals(stripLeadingConstants(input), ImmutableList.of()); assertEquals(extractLeadingConstants(input), ImmutableSet.of("a", "b")); + + input = ImmutableList.of(unique("a")); + assertEquals(stripLeadingConstants(input), ImmutableList.of(unique("a"))); + assertEquals(extractLeadingConstants(input), ImmutableSet.of()); } @Test @@ -138,6 +143,14 @@ public void testTranslate() map = ImmutableMap.of("a", "a1", "b", "b1", "c", "c1"); input = ImmutableList.of(grouped("a"), constant("b"), grouped("c")); assertEquals(LocalProperties.translate(input, translateWithMap(map)), ImmutableList.of(grouped("a1"), constant("b1"), grouped("c1"))); + + map = ImmutableMap.of(); + input = ImmutableList.of(unique("a")); + assertEquals(LocalProperties.translate(input, translateWithMap(map)), ImmutableList.of()); + + map = ImmutableMap.of("a", "a1"); + input = ImmutableList.of(unique("a")); + assertEquals(LocalProperties.translate(input, translateWithMap(map)), ImmutableList.of(unique("a1"))); } private static Function> translateWithMap(Map translateMap) @@ -177,6 +190,35 @@ public void testNormalizeOverlappingSymbol() assertNormalizeAndFlatten( localProperties, grouped("a")); + + localProperties = builder() + .unique("a") + .sorted("a", SortOrder.ASC_NULLS_FIRST) + .constant("a") + .build(); + assertNormalize( + localProperties, + Optional.of(unique("a")), + Optional.empty(), + Optional.empty()); + assertNormalizeAndFlatten( + localProperties, + unique("a")); + + localProperties = builder() + .grouped("a") + .unique("a") + .constant("a") + .build(); + assertNormalize( + localProperties, + Optional.of(grouped("a")), + Optional.of(unique("a")), + Optional.empty()); + assertNormalizeAndFlatten( + localProperties, + grouped("a"), + unique("a")); } @Test @@ -780,6 +822,11 @@ private static GroupingProperty grouped(String... 
columns) return new GroupingProperty<>(Arrays.asList(columns)); } + private static UniqueProperty unique(String column) + { + return new UniqueProperty<>(column); + } + private static SortingProperty sorted(String column, SortOrder order) { return new SortingProperty<>(column, order); @@ -812,6 +859,12 @@ public Builder constant(String column) return this; } + public Builder unique(String column) + { + properties.add(new UniqueProperty<>(column)); + return this; + } + public List> build() { return new ArrayList<>(properties); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java index 7d61a00af47ef..568e3c6ee38a2 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestMergePartialAggregationsWithFilter.java @@ -87,6 +87,37 @@ public void testOptimizationApplied() false); } + @Test + public void testOptimizationAppliedAllHasMask() + { + assertPlan("SELECT partkey, sum(quantity) filter (where orderkey > 10), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey", + enableOptimization(), + anyTree( + aggregation( + singleGroupingSet("partkey"), + ImmutableMap.of(Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum")), + Optional.of("maskFinalSum2"), functionCall("sum", ImmutableList.of("maskPartialSum2"))), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.FINAL, + project( + ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"), + "maskPartialSum2", expression("IF(expr2, partialSum, null)")), + anyTree( + aggregation( + singleGroupingSet("partkey", "expr", "expr2"), + ImmutableMap.of(Optional.of("partialSum"), 
functionCall("sum", ImmutableList.of("quantity"))), + ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")), + Optional.empty(), + AggregationNode.Step.PARTIAL, + project( + ImmutableMap.of("expr_or", expression("expr or expr2")), + project( + ImmutableMap.of("expr", expression("orderkey > 0"), "expr2", expression("orderkey >10")), + tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))), + false); + } + @Test public void testOptimizationDisabled() { @@ -188,6 +219,57 @@ public void testAggregationsMultipleLevel() false); } + @Test + public void testAggregationsMultipleLevelAllAggWithMask() + { + assertPlan("select partkey, avg(sum) filter (where suppkey > 10), avg(sum) filter (where suppkey > 0), avg(filtersum) from (select partkey, suppkey, sum(quantity) filter (where orderkey > 10) sum, sum(quantity) filter (where orderkey > 0) filtersum from lineitem group by partkey, suppkey) t group by partkey", + enableOptimization(), + anyTree( + aggregation( + singleGroupingSet("partkey"), + ImmutableMap.of(Optional.of("finalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg_g10")), Optional.of("maskFinalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg")), + Optional.of("finalFilterAvg"), functionCall("avg", ImmutableList.of("partialFilterAvg"))), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.FINAL, + project( + ImmutableMap.of("maskPartialAvg", expression("IF(expr_2, partialAvg, null)"), + "maskPartialAvg_g10", expression("IF(expr_2_g10, partialAvg, null)")), + anyTree( + aggregation( + singleGroupingSet("partkey", "expr_2", "expr_2_g10"), + ImmutableMap.of(Optional.of("partialAvg"), functionCall("avg", ImmutableList.of("finalSum_g10")), Optional.of("partialFilterAvg"), functionCall("avg", ImmutableList.of("maskFinalSum"))), + ImmutableMap.of(new Symbol("partialAvg"), new Symbol("expr_2_or")), + Optional.empty(), + AggregationNode.Step.PARTIAL, + project( + 
ImmutableMap.of("expr_2_or", expression("expr_2 or expr_2_g10")), + project( + ImmutableMap.of("expr_2", expression("suppkey > 0"), "expr_2_g10", expression("suppkey > 10")), + aggregation( + singleGroupingSet("partkey", "suppkey"), + ImmutableMap.of(Optional.of("finalSum_g10"), functionCall("sum", ImmutableList.of("maskPartialSum_g10")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.FINAL, + project( + ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"), + "maskPartialSum_g10", expression("IF(expr_g10, partialSum, null)")), + anyTree( + aggregation( + singleGroupingSet("partkey", "suppkey", "expr", "expr_g10"), + ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))), + ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")), + Optional.empty(), + AggregationNode.Step.PARTIAL, + project( + ImmutableMap.of("expr_or", expression("expr or expr_g10")), + project( + ImmutableMap.of("expr", expression("orderkey > 0"), "expr_g10", expression("orderkey > 10")), + tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity", "suppkey", "suppkey"))))))))))))))), + false); + } + @Test public void testGlobalOptimization() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java index 6a1eff9da903d..c8db155f677d7 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPruneUnreferencedOutputs.java @@ -14,9 +14,14 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.common.block.SortOrder; +import 
com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.IndexJoinNode; +import com.facebook.presto.spi.plan.JoinType; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.WindowNode; @@ -25,25 +30,34 @@ import com.facebook.presto.sql.planner.assertions.OptimizerAssert; import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.testing.TestingTransactionHandle; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.testng.annotations.Test; import java.util.Optional; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING; import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.RANGE; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.except; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.indexJoin; import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.intersect; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictIndexSource; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictTableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; public class TestPruneUnreferencedOutputs extends BaseRuleTest @@ -140,6 +154,47 @@ public void testExceptNodePruning() values("regionkey_6"))))); } + @Test + public void testIndexJoinNodePruning() + { + assertRuleApplication() + .on(p -> + p.output(ImmutableList.of("totoalprice"), ImmutableList.of(p.variable("totoalprice")), + p.indexJoin(JoinType.LEFT, + p.tableScan( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("lineitem", TINY_SCALE_FACTOR), + TestingTransactionHandle.create(), + Optional.empty()), + ImmutableList.of(p.variable("partkey"), p.variable("suppkey")), + ImmutableMap.of( + p.variable("partkey", BIGINT), new TpchColumnHandle("partkey", BIGINT), + p.variable("suppkey", BIGINT), new TpchColumnHandle("suppkey", BIGINT))), + p.indexSource( + new TableHandle( + new ConnectorId("local"), + new TpchTableHandle("orders", TINY_SCALE_FACTOR), + TestingTransactionHandle.create(), + Optional.empty()), + ImmutableSet.of(p.variable("custkey"), p.variable("orderkey"), p.variable("orderstatus", VARCHAR)), + ImmutableList.of(p.variable("custkey"), p.variable("orderkey"), p.variable("orderstatus", VARCHAR), p.variable("totalprice", DOUBLE)), + ImmutableMap.of( + p.variable("custkey", BIGINT), new TpchColumnHandle("custkey", BIGINT), + p.variable("orderkey", BIGINT), new TpchColumnHandle("orderkey", BIGINT), + 
p.variable("totalprice", DOUBLE), new TpchColumnHandle("totalprice", DOUBLE), + p.variable("orderstatus", VARCHAR), new TpchColumnHandle("orderstatus", VARCHAR)), + TupleDomain.all()), + ImmutableList.of(new IndexJoinNode.EquiJoinClause(p.variable("partkey", BIGINT), p.variable("orderkey", BIGINT))), + Optional.of(p.rowExpression("custkey BETWEEN suppkey AND 20"))))) + .matches( + output( + indexJoin( + strictTableScan("lineitem", ImmutableMap.of("partkey", "partkey", "suppkey", "suppkey")), + strictIndexSource("orders", + ImmutableMap.of("custkey", "custkey", "orderkey", "orderkey", "orderstatus", "orderstatus", "totalprice", "totalprice"))))); + } + private OptimizerAssert assertRuleApplication() { RuleTester tester = tester(); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSortedExchangeRule.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSortedExchangeRule.java new file mode 100644 index 0000000000000..129189f6048d9 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestSortedExchangeRule.java @@ -0,0 +1,393 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.SortNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.SORTED_EXCHANGE_ENABLED; +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST; +import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_LAST; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +/** + * Tests for the SortedExchangeRule optimizer which pushes sort operations + * down to exchange nodes for distributed queries. 
+ */ +public class TestSortedExchangeRule + extends BasePlanTest +{ + @Test + public void testOptimizationDisabled() + { + Session disabledSession = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(SORTED_EXCHANGE_ENABLED, "false") + .build(); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder builder = new PlanBuilder(disabledSession, idAllocator, getQueryRunner().getMetadata()); + VariableReferenceExpression col = builder.variable("col"); + + // Create: Sort -> Exchange -> Values + PlanNode exchange = builder.exchange(ex -> ex + .addSource(builder.values(col)) + .addInputsSet(col) + .fixedHashDistributionPartitioningScheme(ImmutableList.of(col), ImmutableList.of(col)) + .scope(REMOTE_STREAMING)); + + PlanNode plan = builder.sort(ImmutableList.of(col), exchange); + + SortedExchangeRule rule = new SortedExchangeRule(true); // true for testing + VariableAllocator variableAllocator = new VariableAllocator(builder.getTypes().allVariables()); + PlanOptimizerResult result = rule.optimize( + plan, + disabledSession, + builder.getTypes(), + variableAllocator, + idAllocator, + WarningCollector.NOOP); + + // Plan should remain unchanged - still has SortNode + assertTrue(result.getPlanNode() instanceof SortNode); + SortNode sortNode = (SortNode) result.getPlanNode(); + assertTrue(sortNode.getSource() instanceof ExchangeNode); + ExchangeNode exchangeNode = (ExchangeNode) sortNode.getSource(); + assertFalse(exchangeNode.getOrderingScheme().isPresent()); + } + + @Test + public void testPushSortToRemoteRepartitionExchange() + { + Session enabledSession = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(SORTED_EXCHANGE_ENABLED, "true") + .build(); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder builder = new PlanBuilder(enabledSession, idAllocator, getQueryRunner().getMetadata()); + VariableReferenceExpression col = builder.variable("col"); + + // Create: Sort -> 
Exchange -> Values + PlanNode exchange = builder.exchange(ex -> ex + .addSource(builder.values(col)) + .addInputsSet(col) + .fixedHashDistributionPartitioningScheme(ImmutableList.of(col), ImmutableList.of(col)) + .scope(REMOTE_STREAMING)); + + PlanNode plan = builder.sort(ImmutableList.of(col), exchange); + + SortedExchangeRule rule = new SortedExchangeRule(true); // true for testing + VariableAllocator variableAllocator = new VariableAllocator(builder.getTypes().allVariables()); + PlanOptimizerResult result = rule.optimize( + plan, + enabledSession, + builder.getTypes(), + variableAllocator, + idAllocator, + WarningCollector.NOOP); + + // SortNode should be removed and ordering moved to ExchangeNode + assertTrue(result.getPlanNode() instanceof ExchangeNode); + ExchangeNode exchangeNode = (ExchangeNode) result.getPlanNode(); + assertTrue(exchangeNode.getOrderingScheme().isPresent()); + + assertEquals(exchangeNode.getOrderingScheme().get().getOrderByVariables().size(), 1); + assertEquals(exchangeNode.getOrderingScheme().get().getOrderBy().get(0).getSortOrder(), ASC_NULLS_FIRST); + } + + @Test + public void testMultipleOrderingColumns() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(SORTED_EXCHANGE_ENABLED, "true") + .build(); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder builder = new PlanBuilder(session, idAllocator, getQueryRunner().getMetadata()); + VariableReferenceExpression col1 = builder.variable("col1"); + VariableReferenceExpression col2 = builder.variable("col2"); + + // Create: Sort(col1 ASC, col2 DESC) -> Exchange -> Values + PlanNode exchange = builder.exchange(ex -> ex + .addSource(builder.values(col1, col2)) + .addInputsSet(col1, col2) + .fixedHashDistributionPartitioningScheme(ImmutableList.of(col1, col2), ImmutableList.of(col1, col2)) + .scope(REMOTE_STREAMING)); + + // Manually create a SortNode with custom ordering + OrderingScheme orderingScheme = new 
OrderingScheme(ImmutableList.of( + new Ordering(col1, ASC_NULLS_FIRST), + new Ordering(col2, DESC_NULLS_LAST))); + PlanNode plan = new SortNode( + Optional.empty(), + idAllocator.getNextId(), + exchange, + orderingScheme, + false, + ImmutableList.of()); + + SortedExchangeRule rule = new SortedExchangeRule(true); // true for testing + VariableAllocator variableAllocator = new VariableAllocator(builder.getTypes().allVariables()); + PlanOptimizerResult result = rule.optimize( + plan, + session, + builder.getTypes(), + variableAllocator, + idAllocator, + WarningCollector.NOOP); + + // Verify ordering moved to exchange + assertTrue(result.getPlanNode() instanceof ExchangeNode); + ExchangeNode exchangeNode = (ExchangeNode) result.getPlanNode(); + assertTrue(exchangeNode.getOrderingScheme().isPresent()); + + assertEquals(exchangeNode.getOrderingScheme().get().getOrderByVariables().size(), 2); + assertEquals(exchangeNode.getOrderingScheme().get().getOrderBy().get(0).getSortOrder(), ASC_NULLS_FIRST); + assertEquals(exchangeNode.getOrderingScheme().get().getOrderBy().get(1).getSortOrder(), DESC_NULLS_LAST); + } + + @Test + public void testDoesNotOptimizeNonRemoteExchange() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(SORTED_EXCHANGE_ENABLED, "true") + .build(); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder builder = new PlanBuilder(session, idAllocator, getQueryRunner().getMetadata()); + VariableReferenceExpression col = builder.variable("col"); + + // Create: Sort -> LOCAL Exchange -> Values (not REMOTE) + PlanNode exchange = builder.exchange(ex -> ex + .addSource(builder.values(col)) + .addInputsSet(col) + .fixedHashDistributionPartitioningScheme(ImmutableList.of(col), ImmutableList.of(col)) + .scope(ExchangeNode.Scope.LOCAL)); + + PlanNode plan = builder.sort(ImmutableList.of(col), exchange); + + SortedExchangeRule rule = new SortedExchangeRule(true); // true for testing + 
VariableAllocator variableAllocator = new VariableAllocator(builder.getTypes().allVariables()); + PlanOptimizerResult result = rule.optimize( + plan, + session, + builder.getTypes(), + variableAllocator, + idAllocator, + WarningCollector.NOOP); + + // Plan should remain unchanged since exchange is not remote + assertTrue(result.getPlanNode() instanceof SortNode); + } + + @Test + public void testDoesNotOptimizeReplicatedExchange() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(SORTED_EXCHANGE_ENABLED, "true") + .build(); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder builder = new PlanBuilder(session, idAllocator, getQueryRunner().getMetadata()); + VariableReferenceExpression col = builder.variable("col"); + + // Create: Sort -> REPLICATE Exchange -> Values + PlanNode exchange = ExchangeNode.replicatedExchange( + idAllocator.getNextId(), + ExchangeNode.Scope.REMOTE_STREAMING, + builder.values(col)); + + PlanNode plan = builder.sort(ImmutableList.of(col), exchange); + + SortedExchangeRule rule = new SortedExchangeRule(true); // true for testing + VariableAllocator variableAllocator = new VariableAllocator(builder.getTypes().allVariables()); + PlanOptimizerResult result = rule.optimize( + plan, + session, + builder.getTypes(), + variableAllocator, + idAllocator, + WarningCollector.NOOP); + + // Plan should remain unchanged since exchange is REPLICATE type + assertTrue(result.getPlanNode() instanceof SortNode); + SortNode sortNode = (SortNode) result.getPlanNode(); + assertTrue(sortNode.getSource() instanceof ExchangeNode); + ExchangeNode exchangeNode = (ExchangeNode) sortNode.getSource(); + assertEquals(exchangeNode.getType(), ExchangeNode.Type.REPLICATE); + } + + @Test + public void testNestedSortExchangePattern() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(SORTED_EXCHANGE_ENABLED, "true") + .build(); + + PlanNodeIdAllocator 
idAllocator = new PlanNodeIdAllocator(); + PlanBuilder builder = new PlanBuilder(session, idAllocator, getQueryRunner().getMetadata()); + VariableReferenceExpression col1 = builder.variable("col1"); + VariableReferenceExpression col2 = builder.variable("col2"); + + // Create nested pattern: Sort -> Project -> Exchange -> Project -> Sort -> Exchange -> Values + // Bottom layer: Exchange2 -> Values + PlanNode values = builder.values(col1, col2); + PlanNode exchange2 = builder.exchange(ex -> ex + .addSource(values) + .addInputsSet(col1, col2) + .fixedHashDistributionPartitioningScheme(ImmutableList.of(col1, col2), ImmutableList.of(col1, col2)) + .scope(REMOTE_STREAMING)); + + // Second layer: Sort2 on exchange2 + PlanNode sort2 = builder.sort(ImmutableList.of(col1), exchange2); + + // Third layer: Project on sort2 (some other nodes) + PlanNode project1 = builder.project(builder.assignment(col1, col1, col2, col2), sort2); + + // Fourth layer: Exchange1 on project1 + PlanNode exchange1 = builder.exchange(ex -> ex + .addSource(project1) + .addInputsSet(col1, col2) + .fixedHashDistributionPartitioningScheme(ImmutableList.of(col2, col1), ImmutableList.of(col1, col2)) + .scope(REMOTE_STREAMING)); + + // Fifth layer: Project on exchange1 (some other nodes) + PlanNode project2 = builder.project(builder.assignment(col1, col1, col2, col2), exchange1); + + // Top layer: Sort1 on project2 + PlanNode plan = builder.sort(ImmutableList.of(col2), project2); + + // Run the optimizer + // IMPORTANT: Only Sort2 can be optimized (its immediate child is Exchange2) + // Sort1 CANNOT be optimized (its immediate child is Project, not Exchange1) + SortedExchangeRule rule = new SortedExchangeRule(true); // true for testing + VariableAllocator variableAllocator = new VariableAllocator(builder.getTypes().allVariables()); + + PlanOptimizerResult result = rule.optimize( + plan, + session, + builder.getTypes(), + variableAllocator, + idAllocator, + WarningCollector.NOOP); + + // Expected result 
after optimization: + // Sort1 -> Project -> Exchange1 (NO ordering) -> Project -> Exchange2 (WITH ordering) -> Values + // + // Sort1 remains because its immediate child is Project, not Exchange1 + // Sort2 is removed and its ordering is pushed to Exchange2 + + PlanNode finalPlan = result.getPlanNode(); + + // The root SHOULD still be a SortNode (Sort1 cannot be optimized) + assertTrue(finalPlan instanceof SortNode, "Top-level node should still be SortNode because its immediate child is Project, not Exchange"); + SortNode sort1 = (SortNode) finalPlan; + assertEquals(sort1.getOrderingScheme().getOrderByVariables().get(0), col2); + + // Navigate through to find exchanges + boolean foundExchange1 = false; + boolean foundExchange2WithOrdering = false; + + PlanNode current = finalPlan.getSources().get(0); // Skip Sort1, start at Project + while (current != null) { + if (current instanceof ExchangeNode) { + ExchangeNode ex = (ExchangeNode) current; + if (!foundExchange1) { + // First exchange we encounter should be Exchange1 (should NOT have ordering) + foundExchange1 = true; + assertFalse(ex.getOrderingScheme().isPresent(), + "Exchange1 should NOT have ordering (Sort1's immediate child was Project, not Exchange1)"); + } + else if (!foundExchange2WithOrdering && ex.getOrderingScheme().isPresent()) { + // Second exchange with ordering should be Exchange2 + foundExchange2WithOrdering = true; + assertEquals(ex.getOrderingScheme().get().getOrderByVariables().size(), 1); + assertEquals(ex.getOrderingScheme().get().getOrderByVariables().get(0), col1, + "Exchange2 should have ordering from Sort2 on col1"); + } + } + + // Move to next node + if (current.getSources().isEmpty()) { + break; + } + current = current.getSources().get(0); + } + + assertTrue(foundExchange1, "Should have found Exchange1"); + assertTrue(foundExchange2WithOrdering, "Exchange2 should have ordering (Sort2's immediate child was Exchange2)"); + } + + @Test + public void 
testDoesNotOptimizeWhenOrderingVariablesNotInSourceOutput() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(SORTED_EXCHANGE_ENABLED, "true") + .build(); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder builder = new PlanBuilder(session, idAllocator, getQueryRunner().getMetadata()); + + // Source produces: [a, b] + VariableReferenceExpression varA = builder.variable("a"); + VariableReferenceExpression varB = builder.variable("b"); + PlanNode source = builder.values(varA, varB); + + // Exchange outputs: [x, y] (different variables!) + VariableReferenceExpression varX = builder.variable("x"); + VariableReferenceExpression varY = builder.variable("y"); + PlanNode exchange = builder.exchange(ex -> ex + .addSource(source) + .addInputsSet(varA, varB) // Source produces [a, b] + .fixedHashDistributionPartitioningScheme(ImmutableList.of(varX, varY), ImmutableList.of(varX, varY)) // Exchange outputs [x, y] + .scope(REMOTE_STREAMING)); + + // Try to sort by varX, which is NOT in the source output [a, b] + // This should NOT be optimized because the ordering variable doesn't exist in source + PlanNode plan = builder.sort(ImmutableList.of(varX), exchange); + + SortedExchangeRule rule = new SortedExchangeRule(true); // true for testing + VariableAllocator variableAllocator = new VariableAllocator(builder.getTypes().allVariables()); + PlanOptimizerResult result = rule.optimize( + plan, + session, + builder.getTypes(), + variableAllocator, + idAllocator, + WarningCollector.NOOP); + + // The sort should NOT be pushed down because varX is not in source output [a, b] + assertTrue(result.getPlanNode() instanceof SortNode, "Sort should remain because ordering variable not in source output"); + SortNode sortNode = (SortNode) result.getPlanNode(); + assertTrue(sortNode.getSource() instanceof ExchangeNode, "Sort's child should still be Exchange"); + ExchangeNode exchangeNode = (ExchangeNode) 
sortNode.getSource(); + assertFalse(exchangeNode.getOrderingScheme().isPresent(), "Exchange should NOT have ordering scheme"); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java index 5e28a77689e8c..baa687541059c 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticsWriterNode.java @@ -16,8 +16,10 @@ import com.facebook.airlift.bootstrap.Bootstrap; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.JsonModule; +import com.facebook.drift.codec.guice.ThriftCodecModule; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.HandleResolver; @@ -30,6 +32,7 @@ import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.statistics.ColumnStatisticType; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.testing.TestingHandleResolver; import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; @@ -39,6 +42,7 @@ import com.google.inject.Injector; import com.google.inject.Key; import com.google.inject.Module; +import com.google.inject.Scopes; import org.testng.annotations.Test; import java.util.Optional; @@ -127,6 +131,9 @@ private JsonCodec getJsonCodec() FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.install(new JsonModule()); binder.install(new HandleJsonModule()); + 
binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); + binder.install(new ThriftCodecModule()); + binder.bind(FeaturesConfig.class).toInstance(new FeaturesConfig()); binder.bind(SqlParser.class).toInstance(sqlParser); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); newSetBinder(binder, Type.class); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java index fa5a8c6db51dc..f71b0e889649f 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/plan/TestWindowNode.java @@ -16,9 +16,11 @@ import com.facebook.airlift.bootstrap.Bootstrap; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.JsonModule; +import com.facebook.drift.codec.guice.ThriftCodecModule; import com.facebook.presto.common.block.SortOrder; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.server.SliceDeserializer; @@ -159,10 +161,12 @@ private JsonCodec getJsonCodec() SqlParser sqlParser = new SqlParser(); FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); binder.install(new JsonModule()); + binder.install(new ThriftCodecModule()); binder.install(new HandleJsonModule()); + configBinder(binder).bindConfig(FeaturesConfig.class); + binder.bind(ConnectorManager.class).toProvider(() -> null); binder.bind(SqlParser.class).toInstance(sqlParser); binder.bind(TypeManager.class).toInstance(functionAndTypeManager); - configBinder(binder).bindConfig(FeaturesConfig.class); newSetBinder(binder, Type.class); 
jsonBinder(binder).addSerializerBinding(Slice.class).to(SliceSerializer.class); jsonBinder(binder).addDeserializerBinding(Slice.class).to(SliceDeserializer.class); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/planPrinter/TestPlanPrinter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/planPrinter/TestPlanPrinter.java index e6a33c46e5c4b..076de6a3f3252 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/planPrinter/TestPlanPrinter.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/planPrinter/TestPlanPrinter.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.planPrinter; import com.facebook.airlift.stats.Distribution; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.RuntimeMetric; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.predicate.Domain; @@ -47,7 +48,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.List; @@ -101,6 +101,7 @@ private String domainToPrintedScan(VariableReferenceExpression variable, ColumnH SOURCE_DISTRIBUTION, ImmutableList.of(scanNode.getId()), new PartitioningScheme(Partitioning.create(SOURCE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)), + Optional.empty(), StageExecutionDescriptor.ungroupedExecution(), false, Optional.of(StatsAndCosts.empty()), @@ -307,6 +308,14 @@ private static StageExecutionStats createStageStats(int stageId, int stageExecut 0, 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, 0, 0, 0, diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/sanity/TestCheckNoIneligibleFunctionsInCoordinatorFragments.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/sanity/TestCheckNoIneligibleFunctionsInCoordinatorFragments.java new 
file mode 100644 index 0000000000000..106b526a6b75e --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/sanity/TestCheckNoIneligibleFunctionsInCoordinatorFragments.java @@ -0,0 +1,474 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.sanity; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; +import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor; +import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors; +import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.TestingColumnHandle; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.function.FunctionImplementationType; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Partitioning; +import com.facebook.presto.spi.plan.PartitioningScheme; +import 
com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; +import com.facebook.presto.testing.TestingTransactionHandle; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.function.Function; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.JAVA; +import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; +import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class TestCheckNoIneligibleFunctionsInCoordinatorFragments + 
extends BasePlanTest +{ + // CPP function for testing (similar to TestAddExchangesPlansWithFunctions) + private static final SqlInvokedFunction CPP_FUNC = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "cpp_func"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.VARCHAR))), + parseTypeSignature(StandardTypes.VARCHAR), + "cpp_func(x)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + // JAVA function for testing + private static final SqlInvokedFunction JAVA_FUNC = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "java_func"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.VARCHAR))), + parseTypeSignature(StandardTypes.VARCHAR), + "java_func(x)", + RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + public TestCheckNoIneligibleFunctionsInCoordinatorFragments() + { + super(TestCheckNoIneligibleFunctionsInCoordinatorFragments::createTestQueryRunner); + } + + private static LocalQueryRunner createTestQueryRunner() + { + LocalQueryRunner queryRunner = new LocalQueryRunner( + testSessionBuilder() + .setCatalog("local") + .setSchema("tiny") + .build(), + new FeaturesConfig().setNativeExecutionEnabled(true), + new FunctionsConfig().setDefaultNamespacePrefix("dummy.unittest")); + + queryRunner.createCatalog("local", new TpchConnectorFactory(), ImmutableMap.of()); + + // Add function namespace with both CPP and JAVA functions + queryRunner.getMetadata().getFunctionAndTypeManager().addFunctionNamespace( + "dummy", + new InMemoryFunctionNamespaceManager( + "dummy", + new SqlFunctionExecutors( + ImmutableMap.of( + CPP, FunctionImplementationType.CPP, + JAVA, FunctionImplementationType.JAVA), + new NoopSqlFunctionExecutor()), + new 
SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("CPP,JAVA"))); + + // Register the functions + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_FUNC, false); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(JAVA_FUNC, false); + + return queryRunner; + } + + @Test + public void testSystemTableScanWithJavaFunctionPasses() + { + // System table scan with Java function in same fragment should pass + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression result = p.variable("result", VARCHAR); + + // Create a system table scan - using proper system connector ID + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode tableScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // Java function (using our registered java_func) + return p.project( + assignment(result, p.rowExpression("java_func(col)")), + tableScan); + }); + } + + @Test(expectedExceptions = IllegalStateException.class, + expectedExceptionsMessageRegExp = "Fragment contains both system table scan and non-Java functions.*") + public void testSystemTableScanWithCppFunctionInProjectFails() + { + // System table scan with C++ function in same fragment should fail + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression result = p.variable("result", VARCHAR); + + // System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + 
ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // C++ function (using our registered cpp_func) + return p.project( + assignment(result, p.rowExpression("cpp_func(col)")), + systemScan); + }); + } + + @Test(expectedExceptions = IllegalStateException.class, + expectedExceptionsMessageRegExp = "Fragment contains both system table scan and non-Java functions.*") + public void testSystemTableScanWithCppFunctionInFilterFails() + { + // System table scan with C++ function in filter should fail + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + + // System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // Filter with C++ function + return p.filter( + p.rowExpression("cpp_func(col) = 'test'"), + systemScan); + }); + } + + @Test + public void testSystemTableScanWithCppFunctionSeparatedByExchangePasses() + { + // System table scan and C++ function separated by exchange should pass + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression result = p.variable("result", VARCHAR); + + // System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // Exchange creates fragment boundary + PartitioningScheme partitioningScheme = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col)); + + 
ExchangeNode exchange = new ExchangeNode( + Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + ExchangeNode.Scope.LOCAL, + partitioningScheme, + ImmutableList.of(systemScan), + ImmutableList.of(ImmutableList.of(col)), + false, + Optional.empty()); + + // C++ function in different fragment + return p.project( + assignment(result, p.rowExpression("cpp_func(col)")), + exchange); + }); + } + + @Test + public void testRegularTableScanWithCppFunctionPasses() + { + // Regular table scan with C++ function should pass (no system table) + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression result = p.variable("result", VARCHAR); + + // Regular table scan (not system) + TableHandle regularTableHandle = new TableHandle( + new ConnectorId("local"), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode regularScan = p.tableScan( + regularTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // C++ function + return p.project( + assignment(result, p.rowExpression("cpp_func(col)")), + regularScan); + }); + } + + @Test(expectedExceptions = IllegalStateException.class, + expectedExceptionsMessageRegExp = "Fragment contains both system table scan and non-Java functions.*") + public void testMultipleFragmentsWithCppFunctionInSystemFragment() + { + // Complex plan where CPP function is in same fragment as system table scan (should fail) + validatePlan( + p -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", BIGINT); + VariableReferenceExpression col3 = p.variable("col3", BIGINT); + + // Fragment 1: System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + 
PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col1), + ImmutableMap.of(col1, new TestingColumnHandle("col1"))); + + // Convert to numeric for join (using CPP function - this should fail) + PlanNode project1 = p.project( + assignment(col2, p.rowExpression("cast(cpp_func(col1) as bigint)")), + systemScan); + + // Fragment 2: Regular values with computation + PlanNode values = p.values(col3); + + // Exchange to separate fragments + PartitioningScheme partitioningScheme1 = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col2)); + + ExchangeNode exchange1 = new ExchangeNode( + Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + ExchangeNode.Scope.LOCAL, + partitioningScheme1, + ImmutableList.of(project1), + ImmutableList.of(ImmutableList.of(col2)), + false, + Optional.empty()); + + PartitioningScheme partitioningScheme2 = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col3)); + + ExchangeNode exchange2 = new ExchangeNode( + Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + ExchangeNode.Scope.LOCAL, + partitioningScheme2, + ImmutableList.of(values), + ImmutableList.of(ImmutableList.of(col3)), + false, + Optional.empty()); + + // Join the results + return p.join( + JoinType.INNER, + exchange1, + exchange2, + p.rowExpression("col2 = col3")); + }); + } + + @Test + public void testMultipleFragmentsWithExchange() + { + // Complex plan with multiple fragments properly separated (Java function - should pass) + validatePlan( + p -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", BIGINT); + VariableReferenceExpression col3 = p.variable("col3", BIGINT); + + // Fragment 1: System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new 
ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col1), + ImmutableMap.of(col1, new TestingColumnHandle("col1"))); + + // Convert to numeric for join (using Java function) + PlanNode project1 = p.project( + assignment(col2, p.rowExpression("cast(java_func(col1) as bigint)")), + systemScan); + + // Fragment 2: Regular values with computation + PlanNode values = p.values(col3); + + // Exchange to separate fragments + PartitioningScheme partitioningScheme1 = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col2)); + + ExchangeNode exchange1 = new ExchangeNode( + Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + ExchangeNode.Scope.LOCAL, + partitioningScheme1, + ImmutableList.of(project1), + ImmutableList.of(ImmutableList.of(col2)), + false, + Optional.empty()); + + PartitioningScheme partitioningScheme2 = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col3)); + + ExchangeNode exchange2 = new ExchangeNode( + Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + ExchangeNode.Scope.LOCAL, + partitioningScheme2, + ImmutableList.of(values), + ImmutableList.of(ImmutableList.of(col3)), + false, + Optional.empty()); + + // Join the results + return p.join( + JoinType.INNER, + exchange1, + exchange2, + p.rowExpression("col2 = col3")); + }); + } + + @Test + public void testFilterAndProjectWithSystemTable() + { + // Test filter and project both with Java functions on system table + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression len = p.variable("len", BIGINT); + + // System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new 
ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // Filter with Java function + PlanNode filtered = p.filter( + p.rowExpression("java_func(col) = 'test'"), + systemScan); + + // Project with Java function + return p.project( + assignment(len, p.rowExpression("cast(java_func(col) as bigint)")), + filtered); + }); + } + + private void validatePlan(Function planProvider) + { + Session session = testSessionBuilder() + .setCatalog("local") + .setSchema("tiny") + .build(); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + Metadata metadata = getQueryRunner().getMetadata(); + PlanBuilder builder = new PlanBuilder(TEST_SESSION, idAllocator, metadata); + PlanNode planNode = planProvider.apply(builder); + + getQueryRunner().inTransaction(session, transactionSession -> { + new CheckNoIneligibleFunctionsInCoordinatorFragments().validate(planNode, transactionSession, metadata, WarningCollector.NOOP); + return null; + }); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java index ec59272a234fd..40e8878487fe6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/query/QueryAssertions.java @@ -75,9 +75,14 @@ public QueryRunner getQueryRunner() } public void assertFails(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) + { + assertFails(runner.getDefaultSession(), sql, expectedMessageRegExp); + } + + public void assertFails(Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) { try { - runner.execute(runner.getDefaultSession(), sql).toTestTypes(); + 
runner.execute(session, sql).toTestTypes(); fail(format("Expected query to fail: %s", sql)); } catch (RuntimeException exception) { @@ -153,7 +158,7 @@ public void close() runner.close(); } - protected void executeExclusively(Runnable executionBlock) + public void executeExclusively(Runnable executionBlock) { runner.getExclusiveLock().lock(); try { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/relational/TestFunctionResolution.java b/presto-main-base/src/test/java/com/facebook/presto/sql/relational/TestFunctionResolution.java index 25455849be2d7..1074f7ecbca82 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/relational/TestFunctionResolution.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/relational/TestFunctionResolution.java @@ -13,10 +13,21 @@ */ package com.facebook.presto.sql.relational; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; +import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor; +import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors; +import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager; import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.function.FunctionImplementationType; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.function.StandardFunctionResolution; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -25,19 +36,29 @@ import static com.facebook.presto.common.type.BigintType.BIGINT; 
import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.facebook.presto.spi.function.FunctionImplementationType.THRIFT; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.SQL; +import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; public class TestFunctionResolution { + private static final RoutineCharacteristics.Language JAVA = new RoutineCharacteristics.Language("java"); + private FunctionResolution functionResolution; + private FunctionAndTypeManager functionAndTypeManager; @BeforeClass public void setup() { - FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + functionAndTypeManager = createTestFunctionAndTypeManager(); functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); } @@ -72,5 +93,57 @@ public void testStandardFunctionResolution() // BuiltInFunction assertEquals(standardFunctionResolution.notFunction(), standardFunctionResolution.lookupBuiltInFunction("not", ImmutableList.of(BOOLEAN))); assertEquals(standardFunctionResolution.countFunction(), standardFunctionResolution.lookupBuiltInFunction("count", ImmutableList.of())); + + // full qualified name + assertEquals(standardFunctionResolution.notFunction(), 
standardFunctionResolution.lookupFunction("presto", "default", "not", ImmutableList.of(BOOLEAN))); + assertEquals(standardFunctionResolution.countFunction(), standardFunctionResolution.lookupFunction("presto", "default", "count", ImmutableList.of())); + + // lookup cast + assertTrue(standardFunctionResolution.isCastFunction(standardFunctionResolution.lookupCast("CAST", BIGINT, VARCHAR))); + } + + @Test + public void testLookupFunctionWithNonDefaultSchemaAndCatalog() + { + StandardFunctionResolution standardFunctionResolution = functionResolution; + + functionAndTypeManager.addFunctionNamespace( + "custom_catalog", + new InMemoryFunctionNamespaceManager( + "custom_catalog", + new SqlFunctionExecutors( + ImmutableMap.of( + SQL, FunctionImplementationType.SQL, + JAVA, THRIFT), + new NoopSqlFunctionExecutor()), + new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("sql,java"))); + + QualifiedObjectName customNotFunction = QualifiedObjectName.valueOf("custom_catalog", "custom_schema", "custom_not"); + SqlInvokedFunction notFunc = new SqlInvokedFunction( + customNotFunction, + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BOOLEAN))), + parseTypeSignature(StandardTypes.BOOLEAN), + "custom_not(x)", + RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + QualifiedObjectName customCountFunction = QualifiedObjectName.valueOf("custom_catalog", "custom_schema", "custom_count"); + SqlInvokedFunction countFunc = new SqlInvokedFunction( + customCountFunction, + ImmutableList.of(), + parseTypeSignature(StandardTypes.BIGINT), + "custom_count()", + RoutineCharacteristics.builder().setLanguage(JAVA).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + functionAndTypeManager.createFunction(notFunc, true); + functionAndTypeManager.createFunction(countFunc, true); + + 
assertEquals(notFunc.getFunctionId().getFunctionName(), + functionAndTypeManager.getFunctionMetadata(standardFunctionResolution.lookupFunction("custom_catalog", "custom_schema", "custom_not", ImmutableList.of(BOOLEAN))).getName()); + assertEquals(countFunc.getFunctionId().getFunctionName(), + functionAndTypeManager.getFunctionMetadata(standardFunctionResolution.lookupFunction("custom_catalog", "custom_schema", "custom_count", ImmutableList.of())).getName()); } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/tdigest/TestTDigest.java b/presto-main-base/src/test/java/com/facebook/presto/tdigest/TestTDigest.java index 773be051a5df5..04320aaa52fa8 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/tdigest/TestTDigest.java +++ b/presto-main-base/src/test/java/com/facebook/presto/tdigest/TestTDigest.java @@ -397,6 +397,16 @@ public void testGeometricDistribution() } } + @Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = ".*Cannot serialize t-digest with NaN mean value.*") + public void testSerializationWithNaNMean() + { + double[] means = {1.0, Double.NaN, 3.0}; + double[] weights = {1.0, 1.0, 1.0}; + TDigest tDigest = createTDigest(means, weights, STANDARD_COMPRESSION_FACTOR, 1.0, 3.0, 6.0, 3); + + tDigest.serialize(); + } + @Test(enabled = false) public void testPoissonDistribution() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/transaction/TestTransactionManager.java b/presto-main-base/src/test/java/com/facebook/presto/transaction/TestTransactionManager.java index 8fa3f5d3178a1..dccda22b6a342 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/transaction/TestTransactionManager.java +++ b/presto-main-base/src/test/java/com/facebook/presto/transaction/TestTransactionManager.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.transaction; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.transaction.TransactionId; import 
com.facebook.presto.connector.informationSchema.InformationSchemaConnector; import com.facebook.presto.connector.system.SystemConnector; @@ -30,7 +31,6 @@ import com.facebook.presto.tpch.TpchConnectorFactory; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/transaction/TestTransactionManagerConfig.java b/presto-main-base/src/test/java/com/facebook/presto/transaction/TestTransactionManagerConfig.java index b6c3f5a8b1f80..257adaeae1049 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/transaction/TestTransactionManagerConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/transaction/TestTransactionManagerConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.transaction; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-main-base/src/test/java/com/facebook/presto/ttl/TestNodeTtlFetcherManagerConfig.java b/presto-main-base/src/test/java/com/facebook/presto/ttl/TestNodeTtlFetcherManagerConfig.java index 1fe4aeaa004f3..a2e42c0670239 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/ttl/TestNodeTtlFetcherManagerConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/ttl/TestNodeTtlFetcherManagerConfig.java @@ -14,9 +14,9 @@ package com.facebook.presto.ttl; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.facebook.presto.ttl.nodettlfetchermanagers.NodeTtlFetcherManagerConfig; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/type/AbstractTestType.java b/presto-main-base/src/test/java/com/facebook/presto/type/AbstractTestType.java index dfc25091c6ae0..e0727d18c2b35 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/type/AbstractTestType.java +++ b/presto-main-base/src/test/java/com/facebook/presto/type/AbstractTestType.java @@ -24,6 +24,7 @@ import com.facebook.presto.common.type.UnknownType; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.TableFunctionRegistry; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.FunctionsConfig; import com.google.common.collect.ImmutableMap; @@ -67,6 +68,7 @@ public abstract class AbstractTestType private static final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(); protected static final FunctionAndTypeManager functionAndTypeManager = new FunctionAndTypeManager( createTestTransactionManager(), + new TableFunctionRegistry(), blockEncodingSerde, new FeaturesConfig(), new FunctionsConfig(), diff --git a/presto-main-base/src/test/java/com/facebook/presto/type/TestBuiltInTypeRegistry.java b/presto-main-base/src/test/java/com/facebook/presto/type/TestBuiltInTypeRegistry.java index b5a4b8f49f3d5..0c626e5f0fc3d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/type/TestBuiltInTypeRegistry.java +++ b/presto-main-base/src/test/java/com/facebook/presto/type/TestBuiltInTypeRegistry.java @@ -13,11 +13,11 @@ */ package com.facebook.presto.type; -import com.facebook.presto.UnknownTypeException; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.OperatorNotFoundException; +import com.facebook.presto.spi.type.UnknownTypeException; import com.google.common.collect.ImmutableSet; import 
org.testng.annotations.Test; diff --git a/presto-main-base/src/test/java/com/facebook/presto/type/TestMapOperators.java b/presto-main-base/src/test/java/com/facebook/presto/type/TestMapOperators.java index 0690252130ce8..c00ad1b5baf86 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/type/TestMapOperators.java +++ b/presto-main-base/src/test/java/com/facebook/presto/type/TestMapOperators.java @@ -527,16 +527,13 @@ public void testJsonToMap() .put("k8", "[null]") .build()); - // These two tests verifies that partial json cast preserves input order - // The second test should never happen in real life because valid json in presto requires natural key ordering. - // However, it is added to make sure that the order in the first test is not a coincidence. assertFunction("CAST(JSON '{\"k1\": {\"1klmnopq\":1, \"2klmnopq\":2, \"3klmnopq\":3, \"4klmnopq\":4, \"5klmnopq\":5, \"6klmnopq\":6, \"7klmnopq\":7}}' AS MAP)", mapType(VARCHAR, JSON), ImmutableMap.of("k1", "{\"1klmnopq\":1,\"2klmnopq\":2,\"3klmnopq\":3,\"4klmnopq\":4,\"5klmnopq\":5,\"6klmnopq\":6,\"7klmnopq\":7}")); + assertFunction("CAST(unchecked_to_json('{\"k1\": {\"7klmnopq\":7, \"6klmnopq\":6, \"5klmnopq\":5, \"4klmnopq\":4, \"3klmnopq\":3, \"2klmnopq\":2, \"1klmnopq\":1}}') AS MAP)", mapType(VARCHAR, JSON), - ImmutableMap.of("k1", "{\"7klmnopq\":7,\"6klmnopq\":6,\"5klmnopq\":5,\"4klmnopq\":4,\"3klmnopq\":3,\"2klmnopq\":2,\"1klmnopq\":1}")); - + ImmutableMap.of("k1", "{\"1klmnopq\":1,\"2klmnopq\":2,\"3klmnopq\":3,\"4klmnopq\":4,\"5klmnopq\":5,\"6klmnopq\":6,\"7klmnopq\":7}")); // nested array/map assertFunction("CAST(JSON '{\"1\": [1, 2], \"2\": [3, null], \"3\": [], \"5\": [null, null], \"8\": null}' AS MAP>)", mapType(BIGINT, new ArrayType(BIGINT)), diff --git a/presto-main-base/src/test/java/com/facebook/presto/util/MockPowerOfTwo.java b/presto-main-base/src/test/java/com/facebook/presto/util/MockPowerOfTwo.java index 43c8a78e10979..a6f180e0d6504 100644 --- 
a/presto-main-base/src/test/java/com/facebook/presto/util/MockPowerOfTwo.java +++ b/presto-main-base/src/test/java/com/facebook/presto/util/MockPowerOfTwo.java @@ -13,9 +13,8 @@ */ package com.facebook.presto.util; -import io.airlift.units.MinDataSize; - -import javax.validation.Payload; +import com.facebook.airlift.units.MinDataSize; +import jakarta.validation.Payload; import java.lang.annotation.Annotation; diff --git a/presto-main-base/src/test/java/com/facebook/presto/util/RetryAnalyzer.java b/presto-main-base/src/test/java/com/facebook/presto/util/RetryAnalyzer.java new file mode 100644 index 0000000000000..53ea4cd606309 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/util/RetryAnalyzer.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.util; + +import org.testng.IRetryAnalyzer; +import org.testng.ITestResult; + +import java.lang.reflect.Method; + +import static com.facebook.presto.util.RetryCount.DEFAULT_RETRY_COUNT; + +public class RetryAnalyzer + implements IRetryAnalyzer +{ + private int retryCount; + private int maxRetryCount; + + @Override + public boolean retry(ITestResult result) + { + if (maxRetryCount == 0) { + Method method = result.getMethod().getConstructorOrMethod().getMethod(); + RetryCount annotation = method.getAnnotation(RetryCount.class); + maxRetryCount = (annotation != null) ? 
annotation.value() : DEFAULT_RETRY_COUNT; + } + return retryCount++ < maxRetryCount; + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/util/RetryCount.java b/presto-main-base/src/test/java/com/facebook/presto/util/RetryCount.java new file mode 100644 index 0000000000000..47c12ed47402c --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/util/RetryCount.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.util; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +@Retention(RetentionPolicy.RUNTIME) +public @interface RetryCount +{ + int DEFAULT_RETRY_COUNT = 3; + int value() default DEFAULT_RETRY_COUNT; +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java b/presto-main-base/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java index e04f3a67bdabb..c88db53d01a26 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java +++ b/presto-main-base/src/test/java/com/facebook/presto/util/TestGraphvizPrinter.java @@ -200,6 +200,7 @@ private static PlanFragment createTestPlanFragment(int id, PlanNode node) SOURCE_DISTRIBUTION, ImmutableList.of(TEST_TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of()), + Optional.empty(), ungroupedExecution(), false, 
Optional.of(StatsAndCosts.empty()), diff --git a/presto-main-tests/pom.xml b/presto-main-tests/pom.xml new file mode 100644 index 0000000000000..d624bf686fb21 --- /dev/null +++ b/presto-main-tests/pom.xml @@ -0,0 +1,91 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.297-edge10.1-SNAPSHOT + + + presto-main-tests + presto-main-tests + Presto Main Tests + + + ${project.parent.basedir} + true + 4g + + + + + com.google.guava + guava + + + + com.google.inject + guice + test + + + + com.facebook.airlift + log-manager + test + + + + com.facebook.airlift + json + test + + + + com.facebook.airlift + jaxrs + test + + + + com.facebook.airlift + node + test + + + + com.facebook.airlift + http-server + test + + + + com.facebook.airlift + bootstrap + test + + + + com.facebook.airlift + stats + test + + + + com.facebook.presto + presto-common + + + + org.testng + testng + + + + com.facebook.airlift + units + test + + + diff --git a/presto-main-tests/src/main/java/com/facebook/presto/tests/operator/scalar/AbstractTestArrayExcept.java b/presto-main-tests/src/main/java/com/facebook/presto/tests/operator/scalar/AbstractTestArrayExcept.java new file mode 100644 index 0000000000000..c2ccbf8ecddbf --- /dev/null +++ b/presto-main-tests/src/main/java/com/facebook/presto/tests/operator/scalar/AbstractTestArrayExcept.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.tests.operator.scalar; + +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.RowType; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static java.util.Arrays.asList; + +public interface AbstractTestArrayExcept + extends TestFunctions +{ + @Test + default void testBasic() + { + assertFunction("array_except(ARRAY[1, 5, 3], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); + assertFunction("array_except(ARRAY[CAST(1 as BIGINT), 5, 3], ARRAY[5])", new ArrayType(BIGINT), ImmutableList.of(1L, 3L)); + assertFunction("array_except(ARRAY[CAST('x' as VARCHAR), 'y', 'z'], ARRAY['x'])", new ArrayType(VARCHAR), ImmutableList.of("y", "z")); + assertFunction("array_except(ARRAY[true, false, null], ARRAY[true])", new ArrayType(BOOLEAN), asList(false, null)); + assertFunction("array_except(ARRAY[1.1E0, 5.4E0, 3.9E0], ARRAY[5, 5.4E0])", new ArrayType(DOUBLE), ImmutableList.of(1.1, 3.9)); + } + + @Test + default void testDuplicates() + { + assertFunction("array_except(ARRAY[1, 5, 3, 5, 1], ARRAY[3])", new ArrayType(INTEGER), ImmutableList.of(1, 5)); + assertFunction("array_except(ARRAY[CAST(1 as BIGINT), 5, 5, 3, 3, 3, 1], ARRAY[3, 5])", new ArrayType(BIGINT), ImmutableList.of(1L)); + assertFunction("array_except(ARRAY[CAST('x' as VARCHAR), 'x', 'y', 'z'], ARRAY['x', 'y', 'x'])", new ArrayType(VARCHAR), ImmutableList.of("z")); + assertFunction("array_except(ARRAY[true, false, null, true, false, null], ARRAY[true, true, true])", new 
ArrayType(BOOLEAN), asList(false, null)); + } + + @Test + default void testIndeterminateRows() + { + // test unsupported + assertFunction( + "array_except(ARRAY[(123, 'abc'), (123, NULL)], ARRAY[(123, 'abc'), (123, NULL)])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of()); + assertFunction( + "array_except(ARRAY[(NULL, 'abc'), (123, null), (123, 'abc')], ARRAY[(456, 'def'),(NULL, 'abc')])", + new ArrayType(RowType.anonymous(ImmutableList.of(INTEGER, createVarcharType(3)))), + ImmutableList.of(asList(123, null), asList(123, "abc"))); + } + + @Test + default void testIndeterminateArrays() + { + assertFunction( + "array_except(ARRAY[ARRAY[123, 456], ARRAY[123, NULL]], ARRAY[ARRAY[123, 456], ARRAY[123, NULL]])", + new ArrayType(new ArrayType(INTEGER)), + ImmutableList.of()); + assertFunction( + "array_except(ARRAY[ARRAY[NULL, 456], ARRAY[123, null], ARRAY[123, 456]], ARRAY[ARRAY[456, 456],ARRAY[NULL, 456]])", + new ArrayType(new ArrayType(INTEGER)), + ImmutableList.of(asList(123, null), asList(123, 456))); + } +} diff --git a/presto-main-tests/src/main/java/com/facebook/presto/tests/operator/scalar/AbstractTestArraySort.java b/presto-main-tests/src/main/java/com/facebook/presto/tests/operator/scalar/AbstractTestArraySort.java new file mode 100644 index 0000000000000..9a3740d61ed27 --- /dev/null +++ b/presto-main-tests/src/main/java/com/facebook/presto/tests/operator/scalar/AbstractTestArraySort.java @@ -0,0 +1,47 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.tests.operator.scalar; + +import com.facebook.presto.common.type.ArrayType; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static java.util.Arrays.asList; + +public interface AbstractTestArraySort + extends TestFunctions +{ + @Test + default void testArraySort() + { + assertFunction("array_sort(ARRAY [5, 20, null, 5, 3, 50]) ", new ArrayType(INTEGER), + asList(3, 5, 5, 20, 50, null)); + assertFunction("array_sort(sequence(-4, 3))", new ArrayType(BIGINT), + asList(-4L, -3L, -2L, -1L, 0L, 1L, 2L, 3L)); + assertFunction("array_sort(reverse(sequence(-4, 3)))", new ArrayType(BIGINT), + asList(-4L, -3L, -2L, -1L, 0L, 1L, 2L, 3L)); + assertFunction("repeat(1,4)", new ArrayType(INTEGER), asList(1, 1, 1, 1)); + assertFunction("cast(array[] as array)", new ArrayType(INTEGER), asList()); + } + + @Test + default void testArraySortVarchar() + { + assertFunction("array_sort(array['x', 'a', 'a', 'a', 'a', 'm', 'j', 'p'])", + new ArrayType(createVarcharType(1)), ImmutableList.of("a", "a", "a", "a", "j", "m", "p", "x")); + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadataUpdaterProvider.java b/presto-main-tests/src/main/java/com/facebook/presto/tests/operator/scalar/TestFunctions.java similarity index 51% rename from presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadataUpdaterProvider.java rename to presto-main-tests/src/main/java/com/facebook/presto/tests/operator/scalar/TestFunctions.java index e39cf0add1f1c..e72644201ef3b 100644 --- 
a/presto-spi/src/main/java/com/facebook/presto/spi/connector/ConnectorMetadataUpdaterProvider.java +++ b/presto-main-tests/src/main/java/com/facebook/presto/tests/operator/scalar/TestFunctions.java @@ -11,21 +11,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.facebook.presto.spi.connector; +package com.facebook.presto.tests.operator.scalar; -/** - * Every connector that support sending metadata updates to coordinator will - * have a SINGLETON provider instance that implements this interface. - *

    - * This provider will be used to create a new instance of ConnectorMetadataUpdater - * for every table writer operator in a Task. - */ -public interface ConnectorMetadataUpdaterProvider +import com.facebook.presto.common.type.Type; + +public interface TestFunctions { /** - * Create and return metadata updater that handles metadata requests/responses. - * - * @return metadata updater + * Asserts that the projection, representing a SQL expression comprising a scalar function call, returns the + * expected value of the expected type. + */ + void assertFunction(String projection, Type expectedType, Object expected); + + /** + * Asserts that the projection is not supported and that it fails with the expected error message. */ - ConnectorMetadataUpdater getMetadataUpdater(); + void assertNotSupported(String projection, String message); } diff --git a/presto-main/etc/config.properties b/presto-main/etc/config.properties index a120422c4b6f1..0ee18a1e4df0f 100644 --- a/presto-main/etc/config.properties +++ b/presto-main/etc/config.properties @@ -51,7 +51,8 @@ plugin.bundles=\ ../presto-node-ttl-fetchers/pom.xml,\ ../presto-hive-function-namespace/pom.xml,\ ../presto-delta/pom.xml,\ - ../presto-hudi/pom.xml + ../presto-hudi/pom.xml, \ + ../presto-sql-helpers/presto-sql-invoked-functions-plugin/pom.xml presto.version=testversion node-scheduler.include-coordinator=true diff --git a/presto-main/etc/jvm.config b/presto-main/etc/jvm.config index 311a1e242079d..1f7dfb2b10503 100644 --- a/presto-main/etc/jvm.config +++ b/presto-main/etc/jvm.config @@ -6,5 +6,20 @@ # # Required for Java 17 runtime -#--add-opens=java.base/java.lang=ALL-UNNAMED -#--add-opens=java.base/java.lang.reflect=ALL-UNNAMED +--add-opens=java.base/java.io=ALL-UNNAMED +--add-opens=java.base/java.lang=ALL-UNNAMED +--add-opens=java.base/java.lang.ref=ALL-UNNAMED +--add-opens=java.base/java.lang.reflect=ALL-UNNAMED +--add-opens=java.base/java.net=ALL-UNNAMED +--add-opens=java.base/java.nio=ALL-UNNAMED 
+--add-opens=java.base/java.security=ALL-UNNAMED +--add-opens=java.base/javax.security.auth=ALL-UNNAMED +--add-opens=java.base/javax.security.auth.login=ALL-UNNAMED +--add-opens=java.base/java.text=ALL-UNNAMED +--add-opens=java.base/java.util=ALL-UNNAMED +--add-opens=java.base/java.util.concurrent=ALL-UNNAMED +--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED +--add-opens=java.base/java.util.regex=ALL-UNNAMED +--add-opens=java.base/jdk.internal.loader=ALL-UNNAMED +--add-opens=java.base/sun.security.action=ALL-UNNAMED +--add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED diff --git a/presto-main/etc/regex-map.txt b/presto-main/etc/regex-map.txt new file mode 100644 index 0000000000000..b61daa86bcc42 --- /dev/null +++ b/presto-main/etc/regex-map.txt @@ -0,0 +1,3 @@ +user=.* +internal=coordinator +admin=su.* \ No newline at end of file diff --git a/presto-main/pom.xml b/presto-main/pom.xml index afa7c07504b83..893d3743234ff 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-main @@ -13,6 +13,8 @@ ${project.parent.basedir} + 17 + true @@ -21,6 +23,11 @@ presto-main-base + + com.facebook.presto + presto-built-in-worker-function-tools + + com.facebook.airlift jmx-http @@ -51,6 +58,11 @@ presto-common + + com.facebook.presto + presto-function-namespace-managers-common + + io.jsonwebtoken jjwt-api @@ -67,22 +79,22 @@ - javax.ws.rs - javax.ws.rs-api + jakarta.ws.rs + jakarta.ws.rs-api - com.facebook.drift + com.facebook.airlift.drift drift-codec-utils - com.facebook.drift + com.facebook.airlift.drift drift-server - com.facebook.drift + com.facebook.airlift.drift drift-transport-spi @@ -92,8 +104,8 @@ - javax.annotation - javax.annotation-api + jakarta.annotation + jakarta.annotation-api @@ -137,102 +149,130 @@ - com.facebook.airlift.discovery + com.facebook.airlift discovery-server - com.facebook.drift + com.facebook.airlift.drift drift-codec + 
com.google.guava guava + - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations + - com.facebook.drift + com.facebook.airlift.drift drift-client + - com.facebook.drift + com.facebook.airlift.drift drift-transport-netty + + + io.netty + netty-buffer + + - com.facebook.drift + com.facebook.airlift.drift drift-api + com.facebook.airlift http-client + joda-time joda-time + com.facebook.airlift concurrent + io.airlift slice + com.facebook.presto presto-analyzer + - io.airlift + com.facebook.airlift units + com.facebook.presto presto-client + - com.facebook.drift + com.facebook.airlift.drift drift-protocol + com.facebook.airlift stats + com.facebook.presto presto-memory-context + - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api + - javax.servlet - javax.servlet-api + jakarta.servlet + jakarta.servlet-api + com.facebook.presto presto-spi + com.facebook.airlift trace-token + com.facebook.airlift http-server + com.facebook.airlift event + com.facebook.airlift bootstrap + com.facebook.presto presto-parser @@ -244,8 +284,8 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -258,6 +298,105 @@ fastutil + + io.projectreactor.netty + reactor-netty-core + + + + io.projectreactor.netty + reactor-netty-http + + + + io.projectreactor + reactor-core + 3.8.0-M2 + + + + org.reactivestreams + reactive-streams + 1.0.4 + + + + io.netty + netty-codec-http + + + + io.netty + netty-transport-classes-epoll + + + + io.netty + netty-handler + + + + javax.inject + javax.inject + + + + + io.netty + netty-transport-native-epoll + ${dep.netty.version} + linux-x86_64 + runtime + + + + + io.netty + netty-tcnative-boringssl-static + runtime + + + + org.apache.commons + commons-lang3 + + + + com.nimbusds + nimbus-jose-jwt + + + + com.nimbusds + oauth2-oidc-sdk + + + org.ow2.asm + asm + + + + + + com.fasterxml.jackson.core + jackson-core + + + + io.airlift + aircompressor + + + + net.jodah + failsafe + + + + 
net.minidev + json-smart + + com.facebook.presto @@ -311,11 +450,35 @@ io.netty netty-common - + com.squareup.okhttp3 mockwebserver test + + + + org.hamcrest + hamcrest-core + + + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + postgresql + test + + + + com.github.luben + zstd-jni @@ -343,6 +506,24 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + + com.facebook.airlift.drift:drift-protocol + io.netty:netty-buffer + javax.inject:javax.inject + com.squareup.okhttp3:okhttp + com.squareup.okhttp3:okhttp-urlconnection + + + com.squareup.okhttp3:okhttp + com.squareup.okhttp3:okhttp-urlconnection + + + diff --git a/presto-main/src/main/java/com/facebook/presto/dispatcher/LocalDispatchQuery.java b/presto-main/src/main/java/com/facebook/presto/dispatcher/LocalDispatchQuery.java index eb65385df75fa..1e964f597e52d 100644 --- a/presto-main/src/main/java/com/facebook/presto/dispatcher/LocalDispatchQuery.java +++ b/presto-main/src/main/java/com/facebook/presto/dispatcher/LocalDispatchQuery.java @@ -14,6 +14,7 @@ package com.facebook.presto.dispatcher; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.common.ErrorCode; import com.facebook.presto.event.QueryMonitor; @@ -34,7 +35,6 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.Duration; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -253,7 +253,7 @@ public DispatchInfo getDispatchInfo() return DispatchInfo.failed(failureInfo, queryInfo.getQueryStats().getElapsedTime(), queryInfo.getQueryStats().getWaitingForPrerequisitesTime(), queryInfo.getQueryStats().getQueuedTime()); } if (dispatched) { - return DispatchInfo.dispatched(new LocalCoordinatorLocation(), queryInfo.getQueryStats().getElapsedTime(), 
queryInfo.getQueryStats().getWaitingForPrerequisitesTime(), queryInfo.getQueryStats().getQueuedTime()); + return DispatchInfo.dispatched(queryInfo.getQueryStats().getElapsedTime(), queryInfo.getQueryStats().getWaitingForPrerequisitesTime(), queryInfo.getQueryStats().getQueuedTime()); } if (queryInfo.getState() == QUEUED) { return DispatchInfo.queued(queryInfo.getQueryStats().getElapsedTime(), queryInfo.getQueryStats().getWaitingForPrerequisitesTime(), queryInfo.getQueryStats().getQueuedTime()); @@ -279,6 +279,12 @@ public long getCreateTimeInMillis() return stateMachine.getCreateTimeInMillis(); } + @Override + public Duration getQueuedTime() + { + return stateMachine.getQueuedTime(); + } + @Override public long getExecutionStartTimeInMillis() { diff --git a/presto-main/src/main/java/com/facebook/presto/dispatcher/LocalDispatchQueryFactory.java b/presto-main/src/main/java/com/facebook/presto/dispatcher/LocalDispatchQueryFactory.java index b1340d306a82b..a5a4b671b10d4 100644 --- a/presto-main/src/main/java/com/facebook/presto/dispatcher/LocalDispatchQueryFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/dispatcher/LocalDispatchQueryFactory.java @@ -36,8 +36,7 @@ import com.facebook.presto.transaction.TransactionManager; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Optional; import java.util.function.Consumer; diff --git a/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java b/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java index fc28963d525ec..100da1a7bec06 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/MemoryRevokingScheduler.java @@ -29,10 +29,9 @@ import com.google.common.annotations.VisibleForTesting; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.Ordering; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.Collection; diff --git a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManager.java b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManager.java index e0f089adde85c..92f1ba0fdef65 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManager.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/SqlQueryManager.java @@ -15,6 +15,8 @@ import com.facebook.airlift.concurrent.ThreadPoolExecutorMBean; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.ExceededCpuLimitException; import com.facebook.presto.ExceededIntermediateWrittenBytesException; import com.facebook.presto.ExceededOutputSizeLimitException; @@ -37,17 +39,14 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Ordering; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Flatten; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.util.List; import java.util.NoSuchElementException; import java.util.Objects; @@ -59,6 +58,7 @@ import java.util.function.Consumer; import static 
com.facebook.airlift.concurrent.Threads.threadsNamed; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; import static com.facebook.presto.SystemSessionProperties.getQueryMaxCpuTime; import static com.facebook.presto.SystemSessionProperties.getQueryMaxOutputPositions; import static com.facebook.presto.SystemSessionProperties.getQueryMaxOutputSize; @@ -75,7 +75,6 @@ import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Futures.immediateFailedFuture; -import static io.airlift.units.DataSize.Unit.BYTE; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -252,6 +251,13 @@ public QueryInfo getFullQueryInfo(QueryId queryId) return queryTracker.getQuery(queryId).getQueryInfo(); } + @Override + public long getDurationUntilExpirationInMillis(QueryId queryId) + throws NoSuchElementException + { + return queryTracker.getQuery(queryId).getDurationUntilExpirationInMillis(); + } + @Override public Session getQuerySession(QueryId queryId) throws NoSuchElementException diff --git a/presto-main/src/main/java/com/facebook/presto/failureDetector/FailureDetectorConfig.java b/presto-main/src/main/java/com/facebook/presto/failureDetector/FailureDetectorConfig.java index 4b349bdcfb25b..ddcfd14f8a28f 100644 --- a/presto-main/src/main/java/com/facebook/presto/failureDetector/FailureDetectorConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/failureDetector/FailureDetectorConfig.java @@ -15,13 +15,12 @@ import com.facebook.airlift.configuration.Config; import com.facebook.airlift.configuration.ConfigDescription; -import io.airlift.units.Duration; -import io.airlift.units.MinDuration; - -import javax.validation.constraints.DecimalMax; -import javax.validation.constraints.DecimalMin; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; +import 
com.facebook.airlift.units.Duration; +import com.facebook.airlift.units.MinDuration; +import jakarta.validation.constraints.DecimalMax; +import jakarta.validation.constraints.DecimalMin; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotNull; import java.util.concurrent.TimeUnit; diff --git a/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java b/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java index eed134f386bd7..f023979ca13b2 100644 --- a/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java +++ b/presto-main/src/main/java/com/facebook/presto/failureDetector/HeartbeatFailureDetector.java @@ -24,6 +24,7 @@ import com.facebook.airlift.node.NodeInfo; import com.facebook.airlift.stats.DecayCounter; import com.facebook.airlift.stats.ExponentialDecay; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.FailureInfo; import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.spi.HostAddress; @@ -32,17 +33,15 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.Nullable; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.net.ConnectException; import java.net.SocketTimeoutException; import java.net.URI; diff --git 
a/presto-main/src/main/java/com/facebook/presto/memory/ClusterMemoryManager.java b/presto-main/src/main/java/com/facebook/presto/memory/ClusterMemoryManager.java index ebb98a92f22c7..2134cd160e32f 100644 --- a/presto-main/src/main/java/com/facebook/presto/memory/ClusterMemoryManager.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/ClusterMemoryManager.java @@ -18,6 +18,8 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.smile.SmileCodec; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.LocationFactory; import com.facebook.presto.execution.QueryExecution; import com.facebook.presto.execution.QueryIdGenerator; @@ -45,16 +47,13 @@ import com.google.common.collect.Maps; import com.google.common.collect.Streams; import com.google.common.io.Closer; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.JmxException; import org.weakref.jmx.MBeanExporter; import org.weakref.jmx.Managed; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; - import java.io.IOException; import java.util.ArrayList; import java.util.Comparator; @@ -72,6 +71,8 @@ import java.util.function.Supplier; import java.util.stream.Stream; +import static com.facebook.airlift.units.DataSize.succinctBytes; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.ExceededMemoryLimitException.exceededGlobalTotalLimit; import static com.facebook.presto.ExceededMemoryLimitException.exceededGlobalUserLimit; import static com.facebook.presto.SystemSessionProperties.RESOURCE_OVERCOMMIT; @@ -94,8 +95,6 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static 
com.google.common.collect.MoreCollectors.toOptional; import static com.google.common.collect.Sets.difference; -import static io.airlift.units.DataSize.succinctBytes; -import static io.airlift.units.Duration.nanosSince; import static java.lang.Math.min; import static java.lang.String.format; import static java.util.AbstractMap.SimpleEntry; diff --git a/presto-main/src/main/java/com/facebook/presto/memory/HighMemoryTaskKiller.java b/presto-main/src/main/java/com/facebook/presto/memory/HighMemoryTaskKiller.java index 35558380087aa..c382bd7d28ac7 100644 --- a/presto-main/src/main/java/com/facebook/presto/memory/HighMemoryTaskKiller.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/HighMemoryTaskKiller.java @@ -15,6 +15,8 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.GarbageCollectionNotificationInfo; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.SqlTask; import com.facebook.presto.execution.SqlTaskManager; import com.facebook.presto.execution.TaskInfo; @@ -25,12 +27,10 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ticker; import com.google.common.collect.ListMultimap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; import javax.management.JMException; import javax.management.Notification; import javax.management.NotificationListener; diff --git a/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryResource.java b/presto-main/src/main/java/com/facebook/presto/memory/MemoryResource.java similarity index 86% rename from presto-main-base/src/main/java/com/facebook/presto/memory/MemoryResource.java rename to 
presto-main/src/main/java/com/facebook/presto/memory/MemoryResource.java index 3ee64e711a299..8b1b6f27e71ca 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/memory/MemoryResource.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/MemoryResource.java @@ -15,24 +15,23 @@ import com.facebook.presto.execution.TaskManager; import com.facebook.presto.spi.memory.MemoryPoolInfo; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.Consumes; -import javax.ws.rs.GET; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.core.Response; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Response; import static com.facebook.presto.PrestoMediaTypes.APPLICATION_JACKSON_SMILE; import static com.facebook.presto.memory.LocalMemoryManager.GENERAL_POOL; import static com.facebook.presto.memory.LocalMemoryManager.RESERVED_POOL; import static com.facebook.presto.server.security.RoleType.INTERNAL; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; /** * Manages memory pools on this worker node diff --git a/presto-main/src/main/java/com/facebook/presto/memory/RemoteNodeMemory.java b/presto-main/src/main/java/com/facebook/presto/memory/RemoteNodeMemory.java index 7dc7de3f9c2f3..4167aee6f845a 100644 --- a/presto-main/src/main/java/com/facebook/presto/memory/RemoteNodeMemory.java +++ b/presto-main/src/main/java/com/facebook/presto/memory/RemoteNodeMemory.java @@ -22,14 
+22,13 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.smile.SmileCodec; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.server.smile.BaseResponse; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; -import io.airlift.units.Duration; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.Nullable; import java.net.URI; import java.util.Optional; @@ -40,12 +39,12 @@ import static com.facebook.airlift.http.client.HttpStatus.OK; import static com.facebook.airlift.http.client.JsonBodyGenerator.jsonBodyGenerator; import static com.facebook.airlift.http.client.Request.Builder.preparePost; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.server.RequestHelpers.setContentTypeHeaders; import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler; import static com.facebook.presto.server.smile.FullSmileResponseHandler.createFullSmileResponseHandler; import static com.facebook.presto.server.smile.SmileBodyGenerator.smileBodyGenerator; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.Duration.nanosSince; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java index 5174739471910..9797f3f4fdf2e 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/DiscoveryNodeManager.java @@ -27,8 +27,10 @@ import 
com.facebook.presto.server.InternalCommunicationConfig.CommunicationProtocol; import com.facebook.presto.server.thrift.ThriftServerInfoClient; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.NodeLoadMetrics; import com.facebook.presto.spi.NodePoolType; import com.facebook.presto.spi.NodeState; +import com.facebook.presto.spi.NodeStats; import com.facebook.presto.statusservice.NodeStatusService; import com.google.common.base.Splitter; import com.google.common.collect.HashMultimap; @@ -40,14 +42,13 @@ import com.google.common.collect.SetMultimap; import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; - import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; @@ -94,7 +95,7 @@ public final class DiscoveryNodeManager private final FailureDetector failureDetector; private final Optional nodeStatusService; private final NodeVersion expectedNodeVersion; - private final ConcurrentHashMap nodeStates = new ConcurrentHashMap<>(); + private final ConcurrentHashMap nodeStats = new ConcurrentHashMap<>(); private final HttpClient httpClient; private final DriftClient driftClient; private final ScheduledExecutorService nodeStateUpdateExecutor; @@ -103,6 +104,7 @@ public final class DiscoveryNodeManager private final InternalNode currentNode; private final CommunicationProtocol protocol; private final boolean isMemoizeDeadNodesEnabled; + private final InternalCommunicationConfig internalCommunicationConfig; @GuardedBy("this") private 
SetMultimap activeNodesByConnectorId; @@ -154,6 +156,7 @@ public DiscoveryNodeManager( this.nodeStateUpdateExecutor = newSingleThreadScheduledExecutor(threadsNamed("node-state-poller-%s")); this.nodeStateEventExecutor = newCachedThreadPool(threadsNamed("node-state-events-%s")); this.httpsRequired = internalCommunicationConfig.isHttpsRequired(); + this.internalCommunicationConfig = requireNonNull(internalCommunicationConfig, "internalCommunicationConfig is null"); this.currentNode = findCurrentNode( serviceSelector.selectAllServices(), @@ -212,6 +215,7 @@ private static NodePoolType getPoolType(ServiceDescriptor service) @PostConstruct public void startPollingNodeStates() { + long pollingIntervalMillis = internalCommunicationConfig.getNodeDiscoveryPollingIntervalMillis(); nodeStateUpdateExecutor.scheduleWithFixedDelay(() -> { try { pollWorkers(); @@ -219,7 +223,7 @@ public void startPollingNodeStates() catch (Exception e) { log.error(e, "Error polling state of nodes"); } - }, 5, 5, TimeUnit.SECONDS); + }, pollingIntervalMillis, pollingIntervalMillis, TimeUnit.MILLISECONDS); pollWorkers(); } @@ -237,20 +241,20 @@ private void pollWorkers() // Remove nodes that don't exist anymore // Make a copy to materialize the set difference - Set deadNodes = difference(nodeStates.keySet(), aliveNodeIds).immutableCopy(); - nodeStates.keySet().removeAll(deadNodes); + Set deadNodes = difference(nodeStats.keySet(), aliveNodeIds).immutableCopy(); + nodeStats.keySet().removeAll(deadNodes); // Add new nodes for (InternalNode node : aliveNodes) { switch (protocol) { case HTTP: - nodeStates.putIfAbsent(node.getNodeIdentifier(), - new HttpRemoteNodeState(httpClient, uriBuilderFrom(node.getInternalUri()).appendPath("/v1/info/state").build())); + nodeStats.putIfAbsent(node.getNodeIdentifier(), + new HttpRemoteNodeStats(httpClient, uriBuilderFrom(node.getInternalUri()).appendPath("/v1/info/stats").build(), internalCommunicationConfig.getNodeStatsRefreshIntervalMillis())); break; case THRIFT: 
if (node.getThriftPort().isPresent()) { - nodeStates.put(node.getNodeIdentifier(), - new ThriftRemoteNodeState(driftClient, uriBuilderFrom(node.getInternalUri()).scheme("thrift").port(node.getThriftPort().getAsInt()).build())); + nodeStats.put(node.getNodeIdentifier(), + new ThriftRemoteNodeStats(driftClient, uriBuilderFrom(node.getInternalUri()).scheme("thrift").port(node.getThriftPort().getAsInt()).build(), internalCommunicationConfig.getNodeStatsRefreshIntervalMillis())); } else { // thrift port has not yet been populated; ignore the node for now @@ -260,7 +264,7 @@ private void pollWorkers() } // Schedule refresh - nodeStates.values().forEach(RemoteNodeState::asyncRefresh); + nodeStats.values().forEach(RemoteNodeStats::asyncRefresh); // update indexes refreshNodesInternal(); @@ -439,10 +443,19 @@ private NodeState getNodeState(InternalNode node) private boolean isNodeShuttingDown(String nodeId) { - Optional remoteNodeState = nodeStates.containsKey(nodeId) - ? nodeStates.get(nodeId).getNodeState() + Optional remoteNodeStats = nodeStats.containsKey(nodeId) + ? nodeStats.get(nodeId).getNodeStats() + : Optional.empty(); + return remoteNodeStats.isPresent() && remoteNodeStats.get().getNodeState() == SHUTTING_DOWN; + } + + @Override + public Optional getNodeLoadMetrics(String nodeId) + { + Optional remoteNodeStats = nodeStats.containsKey(nodeId) + ? 
nodeStats.get(nodeId).getNodeStats() : Optional.empty(); - return remoteNodeState.isPresent() && remoteNodeState.get() == SHUTTING_DOWN; + return remoteNodeStats.flatMap(NodeStats::getLoadMetrics); } @Override @@ -630,20 +643,16 @@ private static boolean isCoordinatorSidecar(ServiceDescriptor service) * Resource Manager -> All Nodes * Catalog Server -> All Nodes * Worker -> Resource Managers or Catalog Servers + * Sidecar -> Resource Managers or Catalog Servers * * @return Predicate to filter Service Descriptor for Nodes */ private Predicate filterRelevantNodes() { - if (currentNode.isCoordinator() || currentNode.isResourceManager() || currentNode.isCatalogServer() || currentNode.isCoordinatorSidecar()) { - // Allowing coordinator node in the list of services, even if it's not allowed by nodeStatusService with currentNode check - return service -> - !nodeStatusService.isPresent() - || nodeStatusService.get().isAllowed(service.getLocation()) - || isCatalogServer(service) - || isCoordinatorSidecar(service); + if (currentNode.isCoordinator() || currentNode.isResourceManager() || currentNode.isCatalogServer()) { + return service -> !nodeStatusService.isPresent() || nodeStatusService.get().isAllowed(service.getLocation()); } - return service -> isResourceManager(service) || isCatalogServer(service) || isCoordinatorSidecar(service); + return service -> isResourceManager(service) || isCatalogServer(service); } } diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/HttpRemoteNodeState.java b/presto-main/src/main/java/com/facebook/presto/metadata/HttpRemoteNodeState.java index 82eedad9f0a6e..8276fa0976895 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/HttpRemoteNodeState.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/HttpRemoteNodeState.java @@ -18,13 +18,12 @@ import com.facebook.airlift.http.client.HttpClient.HttpResponseFuture; import com.facebook.airlift.http.client.Request; import com.facebook.airlift.log.Logger; 
+import com.facebook.airlift.units.Duration; import com.facebook.presto.spi.NodeState; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; -import io.airlift.units.Duration; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.Nullable; import java.net.URI; import java.util.Optional; @@ -36,12 +35,12 @@ import static com.facebook.airlift.http.client.HttpStatus.OK; import static com.facebook.airlift.http.client.Request.Builder.prepareGet; import static com.facebook.airlift.json.JsonCodec.jsonCodec; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.google.common.net.MediaType.JSON_UTF_8; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.Duration.nanosSince; +import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.SECONDS; -import static javax.ws.rs.core.HttpHeaders.CONTENT_TYPE; @ThreadSafe public class HttpRemoteNodeState diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/HttpRemoteNodeStats.java b/presto-main/src/main/java/com/facebook/presto/metadata/HttpRemoteNodeStats.java new file mode 100644 index 0000000000000..5f99a378c4d64 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/metadata/HttpRemoteNodeStats.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.metadata; + +import com.facebook.airlift.http.client.FullJsonResponseHandler.JsonResponse; +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.HttpClient.HttpResponseFuture; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; +import com.facebook.presto.spi.NodeStats; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.errorprone.annotations.ThreadSafe; +import jakarta.annotation.Nullable; + +import java.net.URI; +import java.util.Optional; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static com.facebook.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; +import static com.facebook.airlift.http.client.HttpStatus.OK; +import static com.facebook.airlift.http.client.Request.Builder.prepareGet; +import static com.facebook.airlift.json.JsonCodec.jsonCodec; +import static com.facebook.airlift.units.Duration.nanosSince; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +@ThreadSafe +public class HttpRemoteNodeStats + implements RemoteNodeStats +{ + private static final Logger log = Logger.get(HttpRemoteNodeStats.class); + + private final HttpClient httpClient; + private final URI stateInfoUri; + private final long refreshIntervalMillis; + private final AtomicReference> nodeStats = new AtomicReference<>(Optional.empty()); + private final AtomicReference> future = 
new AtomicReference<>(); + private final AtomicLong lastUpdateNanos = new AtomicLong(); + private final AtomicLong lastWarningLogged = new AtomicLong(); + + public HttpRemoteNodeStats(HttpClient httpClient, URI stateInfoUri, long refreshIntervalMillis) + { + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.stateInfoUri = requireNonNull(stateInfoUri, "stateInfoUri is null"); + this.refreshIntervalMillis = refreshIntervalMillis; + } + + @Override + public Optional getNodeStats() + { + return nodeStats.get(); + } + + @Override + public synchronized void asyncRefresh() + { + Duration sinceUpdate = nanosSince(lastUpdateNanos.get()); + if (nanosSince(lastWarningLogged.get()).toMillis() > 1_000 && + sinceUpdate.toMillis() > 10_000 && + future.get() != null) { + log.warn("Node state update request to %s has not returned in %s", stateInfoUri, sinceUpdate.toString(SECONDS)); + lastWarningLogged.set(System.nanoTime()); + } + if (sinceUpdate.toMillis() > refreshIntervalMillis && future.get() == null) { + Request request = prepareGet() + .setUri(stateInfoUri) + .setHeader(CONTENT_TYPE, JSON_UTF_8.toString()) + .build(); + HttpResponseFuture> responseFuture = httpClient.executeAsync(request, createFullJsonResponseHandler(jsonCodec(NodeStats.class))); + future.compareAndSet(null, responseFuture); + + Futures.addCallback(responseFuture, new FutureCallback>() + { + @Override + public void onSuccess(@Nullable JsonResponse result) + { + lastUpdateNanos.set(System.nanoTime()); + future.compareAndSet(responseFuture, null); + if (result != null) { + if (result.hasValue()) { + nodeStats.set(Optional.ofNullable(result.getValue())); + } + if (result.getStatusCode() != OK.code()) { + log.warn("Error fetching node stats from %s returned status %d", stateInfoUri, result.getStatusCode()); + return; + } + } + else { + log.warn("Node statistics endpoint %s returned null response, using cached statistics", stateInfoUri); + } + } + + @Override + public void 
onFailure(Throwable t) + { + log.error("Error fetching node stats from %s: %s", stateInfoUri, t.getMessage()); + lastUpdateNanos.set(System.nanoTime()); + future.compareAndSet(responseFuture, null); + } + }, directExecutor()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/operator/HttpRpcShuffleClient.java b/presto-main/src/main/java/com/facebook/presto/operator/HttpRpcShuffleClient.java index 86eda2434c247..0f3f772c3d6f4 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/HttpRpcShuffleClient.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/HttpRpcShuffleClient.java @@ -20,17 +20,16 @@ import com.facebook.airlift.http.client.ResponseHandler; import com.facebook.airlift.http.client.ResponseTooLargeException; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.operator.PageBufferClient.PagesResponse; import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.page.SerializedPage; import com.google.common.collect.ImmutableList; import com.google.common.net.MediaType; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.ThreadSafe; import io.airlift.slice.InputStreamSliceInput; import io.airlift.slice.SliceInput; -import io.airlift.units.DataSize; - -import javax.annotation.concurrent.ThreadSafe; import java.io.BufferedReader; import java.io.IOException; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/ThriftRpcShuffleClient.java b/presto-main/src/main/java/com/facebook/presto/operator/ThriftRpcShuffleClient.java index d06004b3dd891..794171111ee8a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/ThriftRpcShuffleClient.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/ThriftRpcShuffleClient.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator; +import com.facebook.airlift.units.DataSize; import 
com.facebook.drift.client.DriftClient; import com.facebook.drift.transport.client.MessageTooLargeException; import com.facebook.presto.execution.TaskId; @@ -22,9 +23,7 @@ import com.facebook.presto.server.thrift.ThriftTaskClient; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.net.URI; import java.util.Optional; diff --git a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedClusterStatsResource.java b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedClusterStatsResource.java index 6d158b610de20..e4de63b1eea3b 100644 --- a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedClusterStatsResource.java +++ b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedClusterStatsResource.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.resourcemanager; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.QueryState; import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.metadata.InternalNodeManager; @@ -22,15 +23,13 @@ import com.facebook.presto.spi.NodeState; import com.facebook.presto.spi.memory.MemoryPoolId; import com.facebook.presto.spi.memory.MemoryPoolInfo; -import io.airlift.units.Duration; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import java.util.Map; import 
java.util.function.Supplier; @@ -51,6 +50,7 @@ public class DistributedClusterStatsResource private final ResourceManagerClusterStateProvider clusterStateProvider; private final InternalNodeManager internalNodeManager; private final Supplier clusterStatsSupplier; + private final String clusterTag; @Inject public DistributedClusterStatsResource( @@ -62,7 +62,9 @@ public DistributedClusterStatsResource( this.isIncludeCoordinator = requireNonNull(nodeSchedulerConfig, "nodeSchedulerConfig is null").isIncludeCoordinator(); this.clusterStateProvider = requireNonNull(clusterStateProvider, "nodeStateManager is null"); this.internalNodeManager = requireNonNull(internalNodeManager, "internalNodeManager is null"); - Duration expirationDuration = requireNonNull(serverConfig, "serverConfig is null").getClusterStatsExpirationDuration(); + ServerConfig config = requireNonNull(serverConfig, "serverConfig is null"); + this.clusterTag = config.getClusterTag(); + Duration expirationDuration = config.getClusterStatsExpirationDuration(); this.clusterStatsSupplier = expirationDuration.getValue() > 0 ? 
memoizeWithExpiration(this::calculateClusterStats, expirationDuration.toMillis(), MILLISECONDS) : this::calculateClusterStats; } @@ -127,7 +129,8 @@ else if (query.getState() == QueryState.RUNNING) { totalInputRows, totalInputBytes, totalCpuTimeSecs, - clusterStateProvider.getAdjustedQueueSize()); + clusterStateProvider.getAdjustedQueueSize(), + clusterTag); } @GET diff --git a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedQueryInfoResource.java b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedQueryInfoResource.java index 92fea39a12802..7cfe73f592171 100644 --- a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedQueryInfoResource.java +++ b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedQueryInfoResource.java @@ -26,22 +26,21 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import 
jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; import java.io.IOException; import java.net.URI; @@ -56,8 +55,8 @@ import static com.facebook.airlift.http.client.Request.Builder.prepareGet; import static com.facebook.presto.server.security.RoleType.ADMIN; import static com.facebook.presto.server.security.RoleType.USER; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; @Path("/v1/queryState") @RolesAllowed({USER, ADMIN}) diff --git a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedQueryResource.java b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedQueryResource.java index e16df0c404ba8..d6e9d94c87143 100644 --- a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedQueryResource.java +++ b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedQueryResource.java @@ -18,23 +18,22 @@ import com.facebook.presto.spi.QueryId; import com.google.common.collect.ImmutableList; import com.google.common.collect.Ordering; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.DELETE; -import javax.ws.rs.GET; -import javax.ws.rs.PUT; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.QueryParam; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.PUT; 
+import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; import java.util.ArrayList; import java.util.Comparator; @@ -46,11 +45,11 @@ import static com.facebook.presto.server.security.RoleType.ADMIN; import static com.facebook.presto.server.security.RoleType.USER; import static com.google.common.base.MoreObjects.firstNonNull; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON_TYPE; +import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON_TYPE; -import static javax.ws.rs.core.Response.Status.BAD_REQUEST; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; @Path("/v1/query") @RolesAllowed({USER, ADMIN}) diff --git a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedResourceGroupInfoResource.java b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedResourceGroupInfoResource.java index 46ec732b7e1bf..c4f017f2ad5cd 100644 --- a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedResourceGroupInfoResource.java +++ b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedResourceGroupInfoResource.java @@ -27,21 +27,20 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.Encoded; 
-import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.Encoded; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; import java.io.IOException; import java.net.URI; @@ -57,8 +56,8 @@ import static com.facebook.airlift.http.client.Request.Builder.prepareGet; import static com.facebook.presto.server.security.RoleType.ADMIN; import static com.google.common.base.Strings.isNullOrEmpty; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; @Path("/v1/resourceGroupState") @RolesAllowed(ADMIN) diff --git a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedTaskInfoResource.java b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedTaskInfoResource.java index 23bcf0f8b8619..fe9cd102d5a33 100644 --- a/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedTaskInfoResource.java +++ b/presto-main/src/main/java/com/facebook/presto/resourcemanager/DistributedTaskInfoResource.java @@ -16,29 +16,28 @@ import com.facebook.presto.execution.TaskId; import com.facebook.presto.server.BasicQueryInfo; import com.facebook.presto.spi.QueryId; - -import javax.annotation.security.RolesAllowed; 
-import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriBuilder; -import javax.ws.rs.core.UriInfo; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriBuilder; +import jakarta.ws.rs.core.UriInfo; import java.net.URI; import java.util.Optional; import static com.facebook.presto.server.security.RoleType.ADMIN; import static com.facebook.presto.server.security.RoleType.USER; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; @Path("/v1/taskInfo") @RolesAllowed({USER, ADMIN}) diff --git a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStatusSender.java b/presto-main/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStatusSender.java similarity index 98% rename from presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStatusSender.java rename to presto-main/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStatusSender.java index 08b7ab06d82d4..bed088a412cf5 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStatusSender.java +++ b/presto-main/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerClusterStatusSender.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.resourcemanager; +import com.facebook.airlift.units.Duration; import com.facebook.drift.client.DriftClient; import com.facebook.presto.execution.ManagedQueryExecution; import com.facebook.presto.execution.resourceGroups.ResourceGroupManager; @@ -24,11 +25,9 @@ import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.QueryId; import com.facebook.presto.util.PeriodicTaskExecutor; -import io.airlift.units.Duration; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import java.util.List; import java.util.Map; diff --git a/presto-main/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerProxy.java b/presto-main/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerProxy.java index d04197fc395c8..aa915fe1fb998 100644 --- a/presto-main/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerProxy.java +++ b/presto-main/src/main/java/com/facebook/presto/resourcemanager/ResourceManagerProxy.java @@ -17,18 +17,17 @@ import com.facebook.airlift.http.client.HeaderName; import com.facebook.airlift.http.client.HttpClient; import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.units.Duration; import com.facebook.presto.metadata.InternalNodeManager; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ListMultimap; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; -import io.airlift.units.Duration; - -import javax.inject.Inject; 
-import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.core.Response; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.core.Response; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -48,14 +47,14 @@ import static com.google.common.net.HttpHeaders.USER_AGENT; import static com.google.common.net.HttpHeaders.X_FORWARDED_FOR; import static com.google.common.util.concurrent.Futures.transform; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN; +import static jakarta.ws.rs.core.Response.Status.GATEWAY_TIMEOUT; +import static jakarta.ws.rs.core.Response.status; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Collections.list; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.MediaType.TEXT_PLAIN; -import static javax.ws.rs.core.Response.Status.GATEWAY_TIMEOUT; -import static javax.ws.rs.core.Response.status; @SuppressWarnings("UnstableApiUsage") public class ResourceManagerProxy diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/AsyncPageTransportForwardFilter.java b/presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportForwardFilter.java similarity index 90% rename from presto-main-base/src/main/java/com/facebook/presto/server/AsyncPageTransportForwardFilter.java rename to presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportForwardFilter.java index 3e307e91152f8..47249f5c9b090 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/AsyncPageTransportForwardFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportForwardFilter.java @@ -13,13 +13,13 @@ */ package com.facebook.presto.server; -import javax.servlet.Filter; -import javax.servlet.FilterChain; 
-import javax.servlet.FilterConfig; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.FilterConfig; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; import java.io.IOException; import java.util.regex.Pattern; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java b/presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java similarity index 94% rename from presto-main-base/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java rename to presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java index ba47ea184508c..d904d725e4621 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java +++ b/presto-main/src/main/java/com/facebook/presto/server/AsyncPageTransportServlet.java @@ -16,6 +16,8 @@ import com.facebook.airlift.concurrent.BoundedExecutor; import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.TimeStat; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskManager; import com.facebook.presto.execution.buffer.BufferInfo; @@ -27,21 +29,18 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.AsyncEvent; +import 
jakarta.servlet.AsyncListener; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.servlet.AsyncContext; -import javax.servlet.AsyncEvent; -import javax.servlet.AsyncListener; -import javax.servlet.ServletOutputStream; -import javax.servlet.http.HttpServlet; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import java.io.IOException; import java.util.Enumeration; import java.util.List; @@ -64,12 +63,12 @@ import static com.google.common.net.HttpHeaders.CONTENT_TYPE; import static com.google.common.util.concurrent.Futures.addCallback; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static jakarta.servlet.http.HttpServletResponse.SC_BAD_REQUEST; +import static jakarta.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR; +import static jakarta.servlet.http.HttpServletResponse.SC_NO_CONTENT; import static java.lang.Long.parseLong; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static javax.servlet.http.HttpServletResponse.SC_BAD_REQUEST; -import static javax.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR; -import static javax.servlet.http.HttpServletResponse.SC_NO_CONTENT; @RolesAllowed(INTERNAL) public class AsyncPageTransportServlet diff --git a/presto-main/src/main/java/com/facebook/presto/server/CatalogServerModule.java b/presto-main/src/main/java/com/facebook/presto/server/CatalogServerModule.java index 170617d2c00e7..3149443421ec5 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/CatalogServerModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/CatalogServerModule.java @@ -29,8 +29,7 @@ import 
com.google.inject.Binder; import com.google.inject.Provides; import com.google.inject.Scopes; - -import javax.inject.Singleton; +import jakarta.inject.Singleton; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; diff --git a/presto-main/src/main/java/com/facebook/presto/server/ClusterStatsResource.java b/presto-main/src/main/java/com/facebook/presto/server/ClusterStatsResource.java index b9e0d59afd52d..c703212f79c89 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ClusterStatsResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ClusterStatsResource.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.Duration; import com.facebook.drift.annotations.ThriftConstructor; import com.facebook.drift.annotations.ThriftField; import com.facebook.drift.annotations.ThriftStruct; @@ -28,23 +29,21 @@ import com.facebook.presto.ttl.clusterttlprovidermanagers.ClusterTtlProviderManager; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import io.airlift.units.Duration; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.DefaultValue; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import 
jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; import java.net.URI; import java.util.Iterator; @@ -59,10 +58,10 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Suppliers.memoizeWithExpiration; import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; +import static jakarta.ws.rs.core.Response.Status.SERVICE_UNAVAILABLE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; -import static javax.ws.rs.core.Response.Status.SERVICE_UNAVAILABLE; @Path("/v1/cluster") @RolesAllowed({ADMIN, USER}) @@ -77,6 +76,7 @@ public class ClusterStatsResource private final InternalResourceGroupManager internalResourceGroupManager; private final ClusterTtlProviderManager clusterTtlProviderManager; private final Supplier clusterStatsSupplier; + private final String clusterTag; @Inject public ClusterStatsResource( @@ -97,7 +97,9 @@ public ClusterStatsResource( this.proxyHelper = requireNonNull(proxyHelper, "internalNodeManager is null"); this.internalResourceGroupManager = requireNonNull(internalResourceGroupManager, "internalResourceGroupManager is null"); this.clusterTtlProviderManager = requireNonNull(clusterTtlProviderManager, "clusterTtlProvider is null"); - Duration expirationDuration = requireNonNull(serverConfig, "serverConfig is null").getClusterStatsExpirationDuration(); + ServerConfig config = requireNonNull(serverConfig, "serverConfig is null"); + this.clusterTag = config.getClusterTag(); + Duration expirationDuration = config.getClusterStatsExpirationDuration(); this.clusterStatsSupplier = expirationDuration.getValue() > 0 ? 
memoizeWithExpiration(this::calculateClusterStats, expirationDuration.toMillis(), MILLISECONDS) : this::calculateClusterStats; } @@ -171,7 +173,8 @@ else if (query.getState() == QueryState.RUNNING) { totalInputRows, totalInputBytes, totalCpuTimeSecs, - internalResourceGroupManager.getQueriesQueuedOnInternal()); + internalResourceGroupManager.getQueriesQueuedOnInternal(), + clusterTag); } @GET @@ -239,6 +242,8 @@ public static class ClusterStats private final long totalCpuTimeSecs; private final long adjustedQueueSize; + private final String clusterTag; + @JsonCreator @ThriftConstructor public ClusterStats( @@ -252,7 +257,8 @@ public ClusterStats( @JsonProperty("totalInputRows") long totalInputRows, @JsonProperty("totalInputBytes") long totalInputBytes, @JsonProperty("totalCpuTimeSecs") long totalCpuTimeSecs, - @JsonProperty("adjustedQueueSize") long adjustedQueueSize) + @JsonProperty("adjustedQueueSize") long adjustedQueueSize, + @JsonProperty("clusterTag") String clusterTag) { this.runningQueries = runningQueries; this.blockedQueries = blockedQueries; @@ -265,6 +271,7 @@ public ClusterStats( this.totalInputBytes = totalInputBytes; this.totalCpuTimeSecs = totalCpuTimeSecs; this.adjustedQueueSize = adjustedQueueSize; + this.clusterTag = clusterTag; } @JsonProperty @@ -343,5 +350,12 @@ public long getAdjustedQueueSize() { return adjustedQueueSize; } + + @JsonProperty + @ThriftField(12) + public String getClusterTag() + { + return clusterTag; + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java b/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java index a3e50ab1fc588..c201aa36d7771 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/CoordinatorModule.java @@ -16,7 +16,9 @@ import com.facebook.airlift.concurrent.BoundedExecutor; import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; 
import com.facebook.airlift.discovery.server.EmbeddedDiscoveryModule; +import com.facebook.airlift.http.client.HttpClient; import com.facebook.airlift.http.server.HttpServerBinder.HttpResourceBinding; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.QueryResults; import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.cost.CostCalculator.EstimatedExchanges; @@ -82,7 +84,11 @@ import com.facebook.presto.server.protocol.QueryBlockingRateLimiter; import com.facebook.presto.server.protocol.QueuedStatementResource; import com.facebook.presto.server.protocol.RetryCircuitBreaker; +import com.facebook.presto.server.remotetask.HttpClientConnectionPoolStats; +import com.facebook.presto.server.remotetask.HttpClientStats; import com.facebook.presto.server.remotetask.HttpRemoteTaskFactory; +import com.facebook.presto.server.remotetask.ReactorNettyHttpClient; +import com.facebook.presto.server.remotetask.ReactorNettyHttpClientConfig; import com.facebook.presto.server.remotetask.RemoteTaskStats; import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; import com.facebook.presto.spi.security.SelectedRole; @@ -101,11 +107,9 @@ import com.google.inject.Provides; import com.google.inject.Scopes; import com.google.inject.multibindings.MapBinder; -import io.airlift.units.Duration; - -import javax.annotation.PreDestroy; -import javax.inject.Inject; -import javax.inject.Singleton; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; import java.util.List; import java.util.concurrent.ExecutorService; @@ -142,7 +146,7 @@ public class CoordinatorModule { private static final String DEFAULT_WEBUI_CSP = "default-src 'self'; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; " + - "font-src 'self' https://fonts.gstatic.com; frame-ancestors 'self'; img-src http: https: data:"; + "font-src 'self' https://fonts.gstatic.com; frame-ancestors 'self'; img-src 'self' data:; 
form-action 'self'"; public static HttpResourceBinding webUIBinder(Binder binder, String path, String classPathResourceBase) { @@ -208,6 +212,10 @@ protected void setup(Binder binder) binder.bind(QueryBlockingRateLimiter.class).in(Scopes.SINGLETON); newExporter(binder).export(QueryBlockingRateLimiter.class).withGeneratedName(); + // retry configuration + configBinder(binder).bindConfig(RetryConfig.class); + binder.bind(RetryUrlValidator.class).in(Scopes.SINGLETON); + binder.bind(LocalQueryProvider.class).in(Scopes.SINGLETON); binder.bind(ExecutingQueryResponseProvider.class).to(LocalExecutingQueryResponseProvider.class).in(Scopes.SINGLETON); @@ -268,13 +276,24 @@ protected void setup(Binder binder) binder.bind(RemoteTaskStats.class).in(Scopes.SINGLETON); newExporter(binder).export(RemoteTaskStats.class).withGeneratedName(); - httpClientBinder(binder).bindHttpClient("scheduler", ForScheduler.class) - .withTracing() - .withFilter(GenerateTraceTokenRequestFilter.class) - .withConfigDefaults(config -> { - config.setRequestTimeout(new Duration(10, SECONDS)); - config.setMaxConnectionsPerServer(250); - }); + ReactorNettyHttpClientConfig reactorNettyHttpClientConfig = buildConfigObject(ReactorNettyHttpClientConfig.class); + if (reactorNettyHttpClientConfig.isReactorNettyHttpClientEnabled()) { + binder.bind(ReactorNettyHttpClient.class).in(Scopes.SINGLETON); + binder.bind(HttpClientStats.class).in(Scopes.SINGLETON); + newExporter(binder).export(HttpClientStats.class).withGeneratedName(); + binder.bind(HttpClientConnectionPoolStats.class).in(Scopes.SINGLETON); + newExporter(binder).export(HttpClientConnectionPoolStats.class).withGeneratedName(); + binder.bind(HttpClient.class).annotatedWith(ForScheduler.class).to(ReactorNettyHttpClient.class); + } + else { + httpClientBinder(binder).bindHttpClient("scheduler", ForScheduler.class) + .withTracing() + .withFilter(GenerateTraceTokenRequestFilter.class) + .withConfigDefaults(config -> { + config.setRequestTimeout(new 
Duration(10, SECONDS)); + config.setMaxConnectionsPerServer(250); + }); + } binder.bind(ScheduledExecutorService.class).annotatedWith(ForScheduler.class) .toInstance(newSingleThreadScheduledExecutor(threadsNamed("stage-scheduler"))); diff --git a/presto-main/src/main/java/com/facebook/presto/server/HttpRequestSessionContext.java b/presto-main/src/main/java/com/facebook/presto/server/HttpRequestSessionContext.java index 6fc8b9efbd667..2826ff4a4bf41 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/HttpRequestSessionContext.java +++ b/presto-main/src/main/java/com/facebook/presto/server/HttpRequestSessionContext.java @@ -14,6 +14,8 @@ package com.facebook.presto.server; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session.ResourceEstimateBuilder; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.transaction.TransactionId; @@ -37,14 +39,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.Response.Status; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.Response.Status; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; @@ -95,11 +94,13 @@ public final class HttpRequestSessionContext private static final Splitter DOT_SPLITTER = Splitter.on('.'); private static final JsonCodec SQL_FUNCTION_ID_JSON_CODEC = jsonCodec(SqlFunctionId.class); private static final JsonCodec 
SQL_INVOKED_FUNCTION_JSON_CODEC = jsonCodec(SqlInvokedFunction.class); - private static final String X509_ATTRIBUTE = "javax.servlet.request.X509Certificate"; + private static final String X509_ATTRIBUTE = "jakarta.servlet.request.X509Certificate"; private final String catalog; private final String schema; + private final String sqlText; + private final Identity identity; private final Optional authorizedIdentity; private final List certificates; @@ -129,7 +130,7 @@ public final class HttpRequestSessionContext public HttpRequestSessionContext(HttpServletRequest servletRequest, SqlParserOptions sqlParserOptions) { - this(servletRequest, sqlParserOptions, NoopTracerProvider.NOOP_TRACER_PROVIDER, Optional.empty()); + this(servletRequest, sqlParserOptions, NoopTracerProvider.NOOP_TRACER_PROVIDER, Optional.empty(), ""); } /** @@ -139,25 +140,21 @@ public HttpRequestSessionContext(HttpServletRequest servletRequest, SqlParserOpt * @param sessionPropertyManager is used to provide with some default session values. In some scenarios we need * those default values even before session for a query is created. This is how we can get it at this * session context creation stage. 
+ * @param sqlText query string * @throws WebApplicationException */ - public HttpRequestSessionContext(HttpServletRequest servletRequest, SqlParserOptions sqlParserOptions, TracerProvider tracerProvider, Optional sessionPropertyManager) + public HttpRequestSessionContext(HttpServletRequest servletRequest, SqlParserOptions sqlParserOptions, TracerProvider tracerProvider, Optional sessionPropertyManager, String sqlText) throws WebApplicationException { catalog = trimEmptyToNull(servletRequest.getHeader(PRESTO_CATALOG)); schema = trimEmptyToNull(servletRequest.getHeader(PRESTO_SCHEMA)); + this.sqlText = requireNonNull(sqlText, "sqlText is null"); + assertRequest((catalog != null) || (schema == null), "Schema is set but catalog is not"); String user = trimEmptyToNull(servletRequest.getHeader(PRESTO_USER)); assertRequest(user != null, "User must be set"); - identity = new Identity( - user, - Optional.ofNullable(servletRequest.getUserPrincipal()), - parseRoleHeaders(servletRequest), - parseExtraCredentials(servletRequest), - ImmutableMap.of(), - Optional.empty(), - Optional.empty()); + authorizedIdentity = authorizedIdentity(servletRequest); X509Certificate[] certs = (X509Certificate[]) servletRequest.getAttribute(X509_ATTRIBUTE); @@ -168,6 +165,16 @@ public HttpRequestSessionContext(HttpServletRequest servletRequest, SqlParserOpt certificates = ImmutableList.of(); } + identity = new Identity( + user, + Optional.ofNullable(servletRequest.getUserPrincipal()), + parseRoleHeaders(servletRequest), + parseExtraCredentials(servletRequest), + ImmutableMap.of(), + Optional.empty(), + Optional.empty(), + certificates); + source = servletRequest.getHeader(PRESTO_SOURCE); userAgent = servletRequest.getHeader(USER_AGENT); remoteUserAddress = !isNullOrEmpty(servletRequest.getHeader(X_FORWARDED_FOR)) ? 
servletRequest.getHeader(X_FORWARDED_FOR) : servletRequest.getRemoteAddr(); @@ -432,6 +439,12 @@ public String getSchema() return schema; } + @Override + public String getSqlText() + { + return sqlText; + } + @Override public String getSource() { diff --git a/presto-main/src/main/java/com/facebook/presto/server/HttpServerModule.java b/presto-main/src/main/java/com/facebook/presto/server/HttpServerModule.java index fff5e5b2446f4..02a47b2ae0ec6 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/HttpServerModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/HttpServerModule.java @@ -31,9 +31,8 @@ import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; - -import javax.servlet.Filter; -import javax.servlet.Servlet; +import jakarta.servlet.Filter; +import jakarta.servlet.Servlet; import java.util.List; import java.util.Set; diff --git a/presto-main/src/main/java/com/facebook/presto/server/InternalAuthenticationManager.java b/presto-main/src/main/java/com/facebook/presto/server/InternalAuthenticationManager.java index 9bbc42cc4613e..e98a24cf08a39 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/InternalAuthenticationManager.java +++ b/presto-main/src/main/java/com/facebook/presto/server/InternalAuthenticationManager.java @@ -22,9 +22,8 @@ import io.jsonwebtoken.JwtException; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.SignatureAlgorithm; - -import javax.inject.Inject; -import javax.ws.rs.container.ContainerRequestContext; +import jakarta.inject.Inject; +import jakarta.ws.rs.container.ContainerRequestContext; import java.security.Principal; import java.time.ZonedDateTime; diff --git a/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationModule.java b/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationModule.java index c6ac179c2d85f..fc0ef9ea8abe3 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/InternalCommunicationModule.java @@ -16,6 +16,7 @@ import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; import com.facebook.airlift.http.client.HttpClientConfig; import com.facebook.airlift.http.client.spnego.KerberosConfig; +import com.facebook.presto.server.remotetask.ReactorNettyHttpClientConfig; import com.facebook.presto.server.security.InternalAuthenticationFilter; import com.google.inject.Binder; import com.google.inject.Module; @@ -53,6 +54,16 @@ protected void setup(Binder binder) } }); + configBinder(binder).bindConfigGlobalDefaults(ReactorNettyHttpClientConfig.class, config -> { + config.setHttpsEnabled(internalCommunicationConfig.isHttpsRequired()); + config.setKeyStorePath(internalCommunicationConfig.getKeyStorePath()); + config.setKeyStorePassword(internalCommunicationConfig.getKeyStorePassword()); + config.setTrustStorePath(internalCommunicationConfig.getTrustStorePath()); + if (internalCommunicationConfig.getIncludedCipherSuites().isPresent()) { + config.setCipherSuites(internalCommunicationConfig.getIncludedCipherSuites().get()); + } + }); + install(installModuleIf(InternalCommunicationConfig.class, InternalCommunicationConfig::isKerberosEnabled, kerberosInternalCommunicationModule())); binder.bind(InternalAuthenticationManager.class); httpClientBinder(binder).bindGlobalFilter(InternalAuthenticationManager.class); diff --git a/presto-main/src/main/java/com/facebook/presto/server/NodeResource.java b/presto-main/src/main/java/com/facebook/presto/server/NodeResource.java index eb4546ba64d3e..08731cd4c0b1f 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/NodeResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/NodeResource.java @@ -15,11 +15,10 @@ import com.facebook.presto.failureDetector.HeartbeatFailureDetector; import 
com.google.common.collect.Maps; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.GET; -import javax.ws.rs.Path; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; import java.util.Collection; diff --git a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java index fac8bc79e06b5..14c66a6a95483 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java @@ -33,6 +33,7 @@ import com.facebook.drift.transport.netty.server.DriftNettyServerTransport; import com.facebook.presto.ClientRequestFilterManager; import com.facebook.presto.ClientRequestFilterModule; +import com.facebook.presto.builtin.tools.WorkerFunctionRegistryTool; import com.facebook.presto.dispatcher.QueryPrerequisitesManager; import com.facebook.presto.dispatcher.QueryPrerequisitesManagerModule; import com.facebook.presto.eventlistener.EventListenerManager; @@ -48,12 +49,16 @@ import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.metadata.StaticCatalogStore; import com.facebook.presto.metadata.StaticFunctionNamespaceStore; +import com.facebook.presto.metadata.StaticTypeManagerStore; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.security.AccessControlManager; import com.facebook.presto.security.AccessControlModule; import com.facebook.presto.server.security.PasswordAuthenticatorManager; import com.facebook.presto.server.security.PrestoAuthenticatorManager; +import com.facebook.presto.server.security.SecurityConfig; import com.facebook.presto.server.security.ServerSecurityModule; +import com.facebook.presto.server.security.oauth2.OAuth2Client; +import com.facebook.presto.spi.function.SqlFunction; import 
com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; import com.facebook.presto.sql.parser.SqlParserOptions; @@ -85,6 +90,7 @@ import static com.facebook.airlift.json.JsonBinder.jsonBinder; import static com.facebook.presto.server.PrestoSystemRequirements.verifyJvmRequirements; import static com.facebook.presto.server.PrestoSystemRequirements.verifySystemTimeIsReasonable; +import static com.facebook.presto.server.security.SecurityConfig.AuthenticationType.OAUTH2; import static com.google.common.base.Strings.nullToEmpty; import static java.util.Objects.requireNonNull; @@ -175,8 +181,14 @@ public void run() injector.getInstance(DriftServer.class)); injector.getInstance(StaticFunctionNamespaceStore.class).loadFunctionNamespaceManagers(); + injector.getInstance(StaticTypeManagerStore.class).loadTypeManagers(); injector.getInstance(SessionPropertyDefaults.class).loadConfigurationManager(); injector.getInstance(ResourceGroupManager.class).loadConfigurationManager(); + if (injector.getInstance(FeaturesConfig.class).isBuiltInSidecarFunctionsEnabled()) { + List functions = injector.getInstance(WorkerFunctionRegistryTool.class).getWorkerFunctions(); + injector.getInstance(FunctionAndTypeManager.class).registerWorkerFunctions(functions); + } + if (!serverConfig.isResourceManager()) { injector.getInstance(AccessControlManager.class).loadSystemAccessControl(); } @@ -191,7 +203,6 @@ public void run() injector.getInstance(NodeStatusNotificationManager.class).loadNodeStatusNotificationProvider(); injector.getInstance(GracefulShutdownHandler.class).loadNodeStatusNotification(); injector.getInstance(SessionPropertyManager.class).loadSessionPropertyProviders(); - injector.getInstance(FunctionAndTypeManager.class).loadTypeManagers(); PlanCheckerProviderManager planCheckerProviderManager = injector.getInstance(PlanCheckerProviderManager.class); InternalNodeManager nodeManager = 
injector.getInstance(DiscoveryNodeManager.class); NodeInfo nodeInfo = injector.getInstance(NodeInfo.class); @@ -201,8 +212,16 @@ public void run() injector.getInstance(ClientRequestFilterManager.class).loadClientRequestFilters(); injector.getInstance(ExpressionOptimizerManager.class).loadExpressionOptimizerFactories(); + injector.getInstance(FunctionAndTypeManager.class) + .getBuiltInPluginFunctionNamespaceManager().triggerConflictCheckWithBuiltInFunctions(); + startAssociatedProcesses(injector); + SecurityConfig securityConfig = injector.getInstance(SecurityConfig.class); + if (securityConfig.getAuthenticationTypes().contains(OAUTH2)) { + injector.getInstance(OAuth2Client.class).load(); + } + injector.getInstance(Announcer.class).start(); log.info("======== SERVER STARTED ========"); diff --git a/presto-main/src/main/java/com/facebook/presto/server/QueryResource.java b/presto-main/src/main/java/com/facebook/presto/server/QueryResource.java index 7ffdcdb8208a4..c9d6b272ed01e 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/QueryResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/QueryResource.java @@ -25,25 +25,24 @@ import com.facebook.presto.spi.QueryId; import com.google.common.collect.ImmutableList; import com.google.common.collect.Ordering; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.DELETE; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.PUT; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.QueryParam; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.Response.Status; -import javax.ws.rs.core.UriInfo; +import jakarta.annotation.security.RolesAllowed; +import 
jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.PUT; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.Response.Status; +import jakarta.ws.rs.core.UriInfo; import java.net.URI; import java.util.ArrayList; @@ -65,13 +64,13 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; +import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST; +import static jakarta.ws.rs.core.Response.Status.NO_CONTENT; +import static jakarta.ws.rs.core.Response.Status.SERVICE_UNAVAILABLE; import static java.lang.String.format; import static java.util.Comparator.comparing; import static java.util.Comparator.comparingInt; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.Response.Status.BAD_REQUEST; -import static javax.ws.rs.core.Response.Status.NO_CONTENT; -import static javax.ws.rs.core.Response.Status.SERVICE_UNAVAILABLE; /** * Manage queries scheduled on this node diff --git a/presto-main/src/main/java/com/facebook/presto/server/QueryStateInfoResource.java b/presto-main/src/main/java/com/facebook/presto/server/QueryStateInfoResource.java index 128db56a4d89f..6cbf436f8f04b 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/QueryStateInfoResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/QueryStateInfoResource.java @@ -23,24 +23,23 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.google.re2j.Pattern; 
import io.airlift.slice.Slices; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.DefaultValue; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; import java.net.URI; import java.util.Arrays; @@ -59,11 +58,11 @@ import static com.google.common.base.Strings.isNullOrEmpty; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; +import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; +import static jakarta.ws.rs.core.Response.Status.SERVICE_UNAVAILABLE; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.Response.Status.BAD_REQUEST; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; -import static 
javax.ws.rs.core.Response.Status.SERVICE_UNAVAILABLE; @Path("/v1/queryState") @RolesAllowed({ADMIN, USER}) diff --git a/presto-main/src/main/java/com/facebook/presto/server/RequestErrorTracker.java b/presto-main/src/main/java/com/facebook/presto/server/RequestErrorTracker.java index 961a39c209410..a5e40792cdc23 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/RequestErrorTracker.java +++ b/presto-main/src/main/java/com/facebook/presto/server/RequestErrorTracker.java @@ -15,6 +15,7 @@ import com.facebook.airlift.event.client.ServiceUnavailableException; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.TaskId; import com.facebook.presto.server.remotetask.Backoff; import com.facebook.presto.spi.ErrorCodeSupplier; @@ -23,9 +24,7 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFutureTask; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.io.EOFException; import java.net.SocketException; diff --git a/presto-main/src/main/java/com/facebook/presto/server/RequestHelpers.java b/presto-main/src/main/java/com/facebook/presto/server/RequestHelpers.java index 65e1165de99cc..61645ac8c65d3 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/RequestHelpers.java +++ b/presto-main/src/main/java/com/facebook/presto/server/RequestHelpers.java @@ -19,7 +19,7 @@ import static com.facebook.presto.PrestoMediaTypes.APPLICATION_JACKSON_SMILE; import static com.google.common.net.HttpHeaders.ACCEPT; import static com.google.common.net.MediaType.JSON_UTF_8; -import static javax.ws.rs.core.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; public class RequestHelpers { diff --git 
a/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java index 9df49004cd158..892e73be39e0c 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ResourceGroupStateInfoResource.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.server; +import com.facebook.airlift.units.Duration; import com.facebook.presto.execution.resourceGroups.ResourceGroupManager; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.InternalNodeManager; @@ -20,26 +21,24 @@ import com.facebook.presto.spi.resourceGroups.ResourceGroupId; import com.google.common.base.Supplier; import com.google.common.base.Suppliers; -import io.airlift.units.Duration; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.DefaultValue; -import javax.ws.rs.Encoded; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.Encoded; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; +import 
jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; import java.io.UnsupportedEncodingException; import java.net.URI; @@ -60,11 +59,11 @@ import static com.google.common.base.Suppliers.memoizeWithExpiration; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; +import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; +import static jakarta.ws.rs.core.Response.Status.SERVICE_UNAVAILABLE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static javax.ws.rs.core.Response.Status.BAD_REQUEST; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; -import static javax.ws.rs.core.Response.Status.SERVICE_UNAVAILABLE; @Path("/v1/resourceGroupState") @RolesAllowed(ADMIN) diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ResourceManagerHeartbeatResource.java b/presto-main/src/main/java/com/facebook/presto/server/ResourceManagerHeartbeatResource.java similarity index 90% rename from presto-main-base/src/main/java/com/facebook/presto/server/ResourceManagerHeartbeatResource.java rename to presto-main/src/main/java/com/facebook/presto/server/ResourceManagerHeartbeatResource.java index ec34f7d91bada..660429d1dc606 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ResourceManagerHeartbeatResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ResourceManagerHeartbeatResource.java @@ -16,14 +16,13 @@ import com.facebook.presto.resourcemanager.ForResourceManager; import com.facebook.presto.resourcemanager.ResourceManagerClusterStateProvider; import com.google.common.util.concurrent.ListeningExecutorService; +import jakarta.inject.Inject; 
+import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.PUT; +import jakarta.ws.rs.Path; -import javax.inject.Inject; -import javax.ws.rs.Consumes; -import javax.ws.rs.PUT; -import javax.ws.rs.Path; - +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON; @Path("/v1/heartbeat") public class ResourceManagerHeartbeatResource diff --git a/presto-main/src/main/java/com/facebook/presto/server/ResourceManagerModule.java b/presto-main/src/main/java/com/facebook/presto/server/ResourceManagerModule.java index 553e7ea7c7dac..6109c99d70113 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ResourceManagerModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ResourceManagerModule.java @@ -15,6 +15,8 @@ import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; import com.facebook.airlift.discovery.server.EmbeddedDiscoveryModule; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.dispatcher.NoOpQueryManager; import com.facebook.presto.execution.QueryIdGenerator; import com.facebook.presto.execution.QueryInfo; @@ -38,10 +40,7 @@ import com.google.inject.Binder; import com.google.inject.Provides; import com.google.inject.Scopes; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.inject.Singleton; +import jakarta.inject.Singleton; import static com.facebook.airlift.configuration.ConditionalModule.installModuleIf; import static com.facebook.airlift.discovery.client.DiscoveryBinder.discoveryBinder; diff --git a/presto-main/src/main/java/com/facebook/presto/server/RetryUrlValidator.java b/presto-main/src/main/java/com/facebook/presto/server/RetryUrlValidator.java new file mode 100644 index 0000000000000..7b8eb1b9d1485 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/RetryUrlValidator.java @@ -0,0 
+1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import com.facebook.airlift.log.Logger; + +import javax.inject.Inject; + +import java.net.URI; +import java.util.Set; + +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +public class RetryUrlValidator +{ + private static final Logger log = Logger.get(RetryUrlValidator.class); + public static final String RETRY_PATH = "/v1/statement/queued/retry"; + + private final RetryConfig retryConfig; + + @Inject + public RetryUrlValidator(RetryConfig retryConfig) + { + this.retryConfig = requireNonNull(retryConfig, "retryConfig is null"); + } + + public boolean isValidRetryUrl(URI retryUrl, String currentServerHost) + { + requireNonNull(retryUrl, "retryUrl is null"); + if (!retryConfig.isRetryEnabled()) { + return false; + } + + try { + // Check protocol + if (retryConfig.isRequireHttps() && !"https".equalsIgnoreCase(retryUrl.getScheme())) { + log.debug("Retry URL rejected - not HTTPS: %s", retryUrl); + return false; + } + + // Check path + if (!retryUrl.getPath().startsWith(RETRY_PATH)) { + log.debug("Retry URL rejected - invalid path: %s", retryUrl); + return false; + } + + if (retryUrl.getRawQuery() != null) { + log.debug("Retry URL rejected - parameters present: %s", retryUrl); + return false; + } + + // Check domain allowlist + if (!isDomainAllowed(retryUrl.getHost(), currentServerHost)) { + log.debug("Retry URL rejected 
- domain not allowed: %s", retryUrl.getHost()); + return false; + } + + return true; + } + catch (Exception e) { + log.debug(e, "Invalid retry URL: %s", retryUrl); + return false; + } + } + + private boolean isDomainAllowed(String host, String currentServerHost) + { + Set allowedDomains = retryConfig.getAllowedRetryDomains(); + String lowerHost = host.toLowerCase(ENGLISH); + + // If no domains are configured, only allow same domain as current server + if (allowedDomains.isEmpty()) { + if (currentServerHost == null) { + // Fallback to original behavior if current host not provided + log.warn("Current server host not provided, cannot restrict to same domain"); + return false; + } + return lowerHost.equals(currentServerHost.toLowerCase(ENGLISH)); + } + + for (String allowedDomain : allowedDomains) { + if (allowedDomain.startsWith("*.")) { + // Wildcard domain + String suffix = allowedDomain.substring(1); + if (lowerHost.endsWith(suffix)) { + return true; + } + } + else if (lowerHost.equals(allowedDomain)) { + // Exact match + return true; + } + } + + return false; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/SerializedPageWriteListener.java b/presto-main/src/main/java/com/facebook/presto/server/SerializedPageWriteListener.java similarity index 96% rename from presto-main-base/src/main/java/com/facebook/presto/server/SerializedPageWriteListener.java rename to presto-main/src/main/java/com/facebook/presto/server/SerializedPageWriteListener.java index ef57b2ea1b8f7..65cf3233d16f3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/SerializedPageWriteListener.java +++ b/presto-main/src/main/java/com/facebook/presto/server/SerializedPageWriteListener.java @@ -16,10 +16,9 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.spi.page.SerializedPage; import io.airlift.slice.SliceOutput; - -import javax.servlet.AsyncContext; -import javax.servlet.ServletOutputStream; -import javax.servlet.WriteListener; 
+import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.WriteListener; import java.io.IOException; import java.util.ArrayDeque; diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerInfoResource.java b/presto-main/src/main/java/com/facebook/presto/server/ServerInfoResource.java index 6006b49f298b4..8c686288caef8 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerInfoResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerInfoResource.java @@ -19,32 +19,32 @@ import com.facebook.presto.execution.resourceGroups.ResourceGroupManager; import com.facebook.presto.metadata.StaticCatalogStore; import com.facebook.presto.spi.NodeState; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.Consumes; -import javax.ws.rs.GET; -import javax.ws.rs.PUT; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.Response; +import com.facebook.presto.spi.NodeStats; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.PUT; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.Response; import java.util.Optional; import static com.facebook.airlift.http.client.thrift.ThriftRequestUtils.APPLICATION_THRIFT_BINARY; import static com.facebook.airlift.http.client.thrift.ThriftRequestUtils.APPLICATION_THRIFT_COMPACT; import static com.facebook.airlift.http.client.thrift.ThriftRequestUtils.APPLICATION_THRIFT_FB_COMPACT; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.server.security.RoleType.ADMIN; import static com.facebook.presto.spi.NodeState.ACTIVE; import static com.facebook.presto.spi.NodeState.INACTIVE; import static 
com.facebook.presto.spi.NodeState.SHUTTING_DOWN; -import static io.airlift.units.Duration.nanosSince; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN; +import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON; -import static javax.ws.rs.core.MediaType.TEXT_PLAIN; -import static javax.ws.rs.core.Response.Status.BAD_REQUEST; @Path("/v1/info") public class ServerInfoResource @@ -129,6 +129,16 @@ else if (!nodeResourceStatusProvider.hasResources() || !resourceGroupManager.isC } } + @GET + @Path("stats") + @Produces({APPLICATION_JSON, APPLICATION_THRIFT_BINARY, APPLICATION_THRIFT_COMPACT, APPLICATION_THRIFT_FB_COMPACT}) + @RolesAllowed(ADMIN) + public NodeStats getServerStats() + { + NodeStats stats = new NodeStats(getServerState(), null); + return stats; + } + @GET @Path("coordinator") @Produces(TEXT_PLAIN) diff --git a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java index 7c9f8742ccea7..9a5654c408e3e 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ServerMainModule.java @@ -16,11 +16,16 @@ import com.facebook.airlift.concurrent.BoundedExecutor; import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; import com.facebook.airlift.discovery.client.ServiceAnnouncement; +import com.facebook.airlift.http.client.HttpClient; import com.facebook.airlift.http.server.TheServlet; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; import com.facebook.airlift.json.JsonObjectMapperProvider; import com.facebook.airlift.stats.GcMonitor; import com.facebook.airlift.stats.JmxGcMonitor; import 
com.facebook.airlift.stats.PauseMeter; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.drift.client.ExceptionClassification; import com.facebook.drift.client.address.AddressSelector; import com.facebook.drift.codec.utils.DefaultThriftCodecsModule; @@ -30,6 +35,10 @@ import com.facebook.presto.PagesIndexPageSorter; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.block.BlockJsonSerde; +import com.facebook.presto.builtin.tools.ForNativeFunctionRegistryInfo; +import com.facebook.presto.builtin.tools.NativeSidecarFunctionRegistryTool; +import com.facebook.presto.builtin.tools.NativeSidecarRegistryToolConfig; +import com.facebook.presto.builtin.tools.WorkerFunctionRegistryTool; import com.facebook.presto.catalogserver.CatalogServerClient; import com.facebook.presto.catalogserver.RandomCatalogServerAddressSelector; import com.facebook.presto.catalogserver.RemoteMetadataManager; @@ -41,8 +50,8 @@ import com.facebook.presto.common.block.BlockEncodingSerde; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorCodecManager; import com.facebook.presto.connector.ConnectorManager; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; import com.facebook.presto.connector.system.SystemConnectorModule; import com.facebook.presto.cost.FilterStatsCalculator; import com.facebook.presto.cost.HistoryBasedOptimizationConfig; @@ -75,8 +84,11 @@ import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; import com.facebook.presto.execution.scheduler.NodeSchedulerExporter; import com.facebook.presto.execution.scheduler.TableWriteInfo; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterOverloadPolicyModule; +import com.facebook.presto.execution.scheduler.clusterOverload.ClusterResourceChecker; import 
com.facebook.presto.execution.scheduler.nodeSelection.NodeSelectionStats; import com.facebook.presto.execution.scheduler.nodeSelection.SimpleTtlNodeSelectorConfig; +import com.facebook.presto.functionNamespace.JsonBasedUdfFunctionMetadata; import com.facebook.presto.index.IndexManager; import com.facebook.presto.memory.LocalMemoryManager; import com.facebook.presto.memory.LocalMemoryManagerExporter; @@ -87,24 +99,27 @@ import com.facebook.presto.memory.NodeMemoryConfig; import com.facebook.presto.memory.ReservedSystemMemoryConfig; import com.facebook.presto.metadata.AnalyzePropertyManager; +import com.facebook.presto.metadata.BuiltInProcedureRegistry; import com.facebook.presto.metadata.CatalogManager; import com.facebook.presto.metadata.ColumnPropertyManager; -import com.facebook.presto.metadata.ConnectorMetadataUpdaterManager; import com.facebook.presto.metadata.DiscoveryNodeManager; import com.facebook.presto.metadata.ForNodeManager; import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.metadata.MaterializedViewPropertyManager; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.metadata.SchemaPropertyManager; import com.facebook.presto.metadata.SessionPropertyManager; -import com.facebook.presto.metadata.Split; +import com.facebook.presto.metadata.SessionPropertyProviderConfig; import com.facebook.presto.metadata.StaticCatalogStore; import com.facebook.presto.metadata.StaticCatalogStoreConfig; import com.facebook.presto.metadata.StaticFunctionNamespaceStore; import com.facebook.presto.metadata.StaticFunctionNamespaceStoreConfig; +import com.facebook.presto.metadata.StaticTypeManagerStore; +import com.facebook.presto.metadata.StaticTypeManagerStoreConfig; +import 
com.facebook.presto.metadata.TableFunctionRegistry; import com.facebook.presto.metadata.TablePropertyManager; import com.facebook.presto.nodeManager.PluginNodeManager; import com.facebook.presto.operator.ExchangeClientConfig; @@ -139,11 +154,11 @@ import com.facebook.presto.resourcemanager.ResourceManagerConfig; import com.facebook.presto.resourcemanager.ResourceManagerInconsistentException; import com.facebook.presto.resourcemanager.ResourceManagerResourceGroupService; +import com.facebook.presto.server.remotetask.DecompressionFilter; import com.facebook.presto.server.remotetask.HttpLocationFactory; +import com.facebook.presto.server.remotetask.ReactorNettyHttpClientConfig; import com.facebook.presto.server.thrift.FixedAddressSelector; -import com.facebook.presto.server.thrift.MetadataUpdatesCodec; -import com.facebook.presto.server.thrift.SplitCodec; -import com.facebook.presto.server.thrift.TableWriteInfoCodec; +import com.facebook.presto.server.thrift.HandleThriftModule; import com.facebook.presto.server.thrift.ThriftServerInfoClient; import com.facebook.presto.server.thrift.ThriftServerInfoService; import com.facebook.presto.server.thrift.ThriftTaskClient; @@ -151,19 +166,20 @@ import com.facebook.presto.server.thrift.ThriftTaskUpdateRequestBodyReader; import com.facebook.presto.sessionpropertyproviders.JavaWorkerSessionPropertyProvider; import com.facebook.presto.sessionpropertyproviders.NativeWorkerSessionPropertyProvider; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; import com.facebook.presto.spi.ConnectorSplit; -import com.facebook.presto.spi.ConnectorTypeSerde; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PageIndexerFactory; import com.facebook.presto.spi.PageSorter; +import com.facebook.presto.spi.RowExpressionSerde; import com.facebook.presto.spi.analyzer.ViewDefinition; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.plan.SimplePlanFragment; import 
com.facebook.presto.spi.plan.SimplePlanFragmentSerde; +import com.facebook.presto.spi.procedure.ProcedureRegistry; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.DomainTranslator; import com.facebook.presto.spi.relation.PredicateCompiler; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.session.WorkerSessionPropertyProvider; import com.facebook.presto.spiller.FileSingleStreamSpillerFactory; @@ -204,6 +220,7 @@ import com.facebook.presto.sql.analyzer.QueryExplainer; import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager; import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.expressions.JsonCodecRowExpressionSerde; import com.facebook.presto.sql.gen.ExpressionCompiler; import com.facebook.presto.sql.gen.JoinCompiler; import com.facebook.presto.sql.gen.JoinFilterFunctionCompiler; @@ -246,15 +263,13 @@ import com.google.inject.TypeLiteral; import com.google.inject.multibindings.MapBinder; import io.airlift.slice.Slice; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.annotation.PreDestroy; -import javax.inject.Singleton; -import javax.servlet.Filter; -import javax.servlet.Servlet; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Singleton; +import jakarta.servlet.Filter; +import jakarta.servlet.Servlet; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutorService; import java.util.concurrent.LinkedBlockingQueue; @@ -270,8 +285,10 @@ import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder; import static com.facebook.airlift.jaxrs.JaxrsBinder.jaxrsBinder; import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodec.listJsonCodec; import static 
com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; import static com.facebook.airlift.json.smile.SmileCodecBinder.smileCodecBinder; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.drift.client.ExceptionClassification.HostStatus.DOWN; import static com.facebook.drift.client.ExceptionClassification.HostStatus.NORMAL; import static com.facebook.drift.client.guice.DriftClientBinder.driftClientBinder; @@ -286,7 +303,6 @@ import static com.google.inject.multibindings.MapBinder.newMapBinder; import static com.google.inject.multibindings.Multibinder.newSetBinder; import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newFixedThreadPool; @@ -326,6 +342,7 @@ else if (serverConfig.isCoordinator()) { install(new InternalCommunicationModule()); + configBinder(binder).bindConfig(ServerConfig.class); configBinder(binder).bindConfig(FeaturesConfig.class); configBinder(binder).bindConfig(FunctionsConfig.class); configBinder(binder).bindConfig(JavaFeaturesConfig.class); @@ -371,6 +388,7 @@ else if (serverConfig.isCoordinator()) { // expression manager binder.bind(ExpressionOptimizerManager.class).in(Scopes.SINGLETON); + binder.bind(RowExpressionSerde.class).to(JsonCodecRowExpressionSerde.class).in(Scopes.SINGLETON); // schema properties binder.bind(SchemaPropertyManager.class).in(Scopes.SINGLETON); @@ -378,6 +396,9 @@ else if (serverConfig.isCoordinator()) { // table properties binder.bind(TablePropertyManager.class).in(Scopes.SINGLETON); + // materialized view properties + binder.bind(MaterializedViewPropertyManager.class).in(Scopes.SINGLETON); + // column properties binder.bind(ColumnPropertyManager.class).in(Scopes.SINGLETON); @@ -398,6 +419,11 @@ else if (serverConfig.isCoordinator()) { 
.withAddressSelector(((addressSelectorBinder, annotation, prefix) -> addressSelectorBinder.bind(AddressSelector.class).annotatedWith(annotation).to(FixedAddressSelector.class))); + binder.bind(new TypeLiteral>>>() {}) + .toInstance(new JsonCodecFactory().mapJsonCodec(String.class, listJsonCodec(JsonBasedUdfFunctionMetadata.class))); + httpClientBinder(binder).bindHttpClient("native-function-registry", ForNativeFunctionRegistryInfo.class); + configBinder(binder).bindConfig(NativeSidecarRegistryToolConfig.class); + // node scheduler // TODO: remove from NodePartitioningManager and move to CoordinatorModule configBinder(binder).bindConfig(NodeSchedulerConfig.class); @@ -423,6 +449,7 @@ else if (serverConfig.isCoordinator()) { // task execution jaxrsBinder(binder).bind(TaskResource.class); jaxrsBinder(binder).bind(ThriftTaskUpdateRequestBodyReader.class); + jaxrsBinder(binder).bind(DecompressionFilter.class); newExporter(binder).export(TaskResource.class).withGeneratedName(); jaxrsBinder(binder).bind(TaskExecutorResource.class); @@ -430,11 +457,13 @@ else if (serverConfig.isCoordinator()) { binder.bind(TaskManagementExecutor.class).in(Scopes.SINGLETON); install(new DefaultThriftCodecsModule()); + // handle resolve for thrift + binder.install(new HandleThriftModule()); + thriftCodecBinder(binder).bindCustomThriftCodec(SqlInvokedFunctionCodec.class); thriftCodecBinder(binder).bindCustomThriftCodec(SqlFunctionIdCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(MetadataUpdatesCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(SplitCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(TableWriteInfoCodec.class); + + binder.bind(ConnectorCodecManager.class).in(Scopes.SINGLETON); jsonCodecBinder(binder).bindListJsonCodec(TaskMemoryReservationSummary.class); binder.bind(SqlTaskManager.class).in(Scopes.SINGLETON); @@ -547,6 +576,7 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon 
binder.bind(PageFunctionCompiler.class).in(Scopes.SINGLETON); newExporter(binder).export(PageFunctionCompiler.class).withGeneratedName(); configBinder(binder).bindConfig(TaskManagerConfig.class); + configBinder(binder).bindConfig(ReactorNettyHttpClientConfig.class); binder.bind(IndexJoinLookupStats.class).in(Scopes.SINGLETON); newExporter(binder).export(IndexJoinLookupStats.class).withGeneratedName(); binder.bind(AsyncHttpExecutionMBean.class).in(Scopes.SINGLETON); @@ -568,8 +598,8 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon jsonCodecBinder(binder).bindJsonCodec(TableCommitContext.class); jsonCodecBinder(binder).bindJsonCodec(SqlInvokedFunction.class); jsonCodecBinder(binder).bindJsonCodec(TaskSource.class); - jsonCodecBinder(binder).bindJsonCodec(Split.class); jsonCodecBinder(binder).bindJsonCodec(TableWriteInfo.class); + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); smileCodecBinder(binder).bindSmileCodec(TaskStatus.class); smileCodecBinder(binder).bindSmileCodec(TaskInfo.class); thriftCodecBinder(binder).bindThriftCodec(TaskStatus.class); @@ -623,18 +653,6 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon binder.bind(PageSourceManager.class).in(Scopes.SINGLETON); binder.bind(PageSourceProvider.class).to(PageSourceManager.class).in(Scopes.SINGLETON); - // connector distributed metadata manager - binder.bind(ConnectorMetadataUpdaterManager.class).in(Scopes.SINGLETON); - - // connector metadata update handle serde manager - binder.bind(ConnectorTypeSerdeManager.class).in(Scopes.SINGLETON); - - // connector metadata update handle json serde - binder.bind(new TypeLiteral>() {}) - .annotatedWith(ForJsonMetadataUpdateHandle.class) - .to(ConnectorMetadataUpdateHandleJsonSerde.class) - .in(Scopes.SINGLETON); - // page sink provider binder.bind(PageSinkManager.class).in(Scopes.SINGLETON); binder.bind(PageSinkProvider.class).to(PageSinkManager.class).in(Scopes.SINGLETON); @@ -644,8 
+662,14 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon configBinder(binder).bindConfig(StaticCatalogStoreConfig.class); binder.bind(StaticFunctionNamespaceStore.class).in(Scopes.SINGLETON); configBinder(binder).bindConfig(StaticFunctionNamespaceStoreConfig.class); + binder.bind(StaticTypeManagerStore.class).in(Scopes.SINGLETON); + configBinder(binder).bindConfig(StaticTypeManagerStoreConfig.class); + configBinder(binder).bindConfig(SessionPropertyProviderConfig.class); binder.bind(FunctionAndTypeManager.class).in(Scopes.SINGLETON); + binder.bind(TableFunctionRegistry.class).in(Scopes.SINGLETON); binder.bind(MetadataManager.class).in(Scopes.SINGLETON); + binder.bind(BuiltInProcedureRegistry.class).in(Scopes.SINGLETON); + binder.bind(ProcedureRegistry.class).to(BuiltInProcedureRegistry.class).in(Scopes.SINGLETON); if (serverConfig.isCatalogServerEnabled() && serverConfig.isCoordinator()) { binder.bind(RemoteMetadataManager.class).in(Scopes.SINGLETON); @@ -703,6 +727,10 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon // system connector binder.install(new SystemConnectorModule()); + // ClusterOverload policy module + binder.install(new ClusterOverloadPolicyModule()); + newExporter(binder).export(ClusterResourceChecker.class).withGeneratedName(); + // splits jsonCodecBinder(binder).bindJsonCodec(TaskUpdateRequest.class); jsonCodecBinder(binder).bindJsonCodec(ConnectorSplit.class); @@ -717,10 +745,6 @@ public ListeningExecutorService createResourceManagerExecutor(ResourceManagerCon jsonBinder(binder).addDeserializerBinding(FunctionCall.class).to(FunctionCallDeserializer.class); thriftCodecBinder(binder).bindThriftCodec(TaskUpdateRequest.class); - // metadata updates - jsonCodecBinder(binder).bindJsonCodec(MetadataUpdates.class); - smileCodecBinder(binder).bindSmileCodec(MetadataUpdates.class); - // split monitor binder.bind(SplitMonitor.class).in(Scopes.SINGLETON); @@ -905,6 +929,22 @@ public 
static FragmentResultCacheManager createFragmentResultCacheManager(FileFr return new NoOpFragmentResultCacheManager(); } + @Provides + @Singleton + public WorkerFunctionRegistryTool provideWorkerFunctionRegistryTool( + NativeSidecarRegistryToolConfig config, + @ForNativeFunctionRegistryInfo HttpClient httpClient, + JsonCodec>> nativeFunctionSignatureMapJsonCodec, + NodeManager nodeManager) + { + return new NativeSidecarFunctionRegistryTool( + httpClient, + nativeFunctionSignatureMapJsonCodec, + nodeManager, + config.getNativeSidecarRegistryToolNumRetries(), + config.getNativeSidecarRegistryToolRetryDelayMs()); + } + public static class ExecutorCleanup { private final List executors; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/StageResource.java b/presto-main/src/main/java/com/facebook/presto/server/StageResource.java similarity index 88% rename from presto-main-base/src/main/java/com/facebook/presto/server/StageResource.java rename to presto-main/src/main/java/com/facebook/presto/server/StageResource.java index 2463f50fd1a3f..0944593bd189f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/StageResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/StageResource.java @@ -15,12 +15,11 @@ import com.facebook.presto.execution.QueryManager; import com.facebook.presto.execution.StageId; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.DELETE; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; import static com.facebook.presto.server.security.RoleType.USER; import static java.util.Objects.requireNonNull; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/StatusResource.java b/presto-main/src/main/java/com/facebook/presto/server/StatusResource.java similarity index 
90% rename from presto-main-base/src/main/java/com/facebook/presto/server/StatusResource.java rename to presto-main/src/main/java/com/facebook/presto/server/StatusResource.java index 241cfabaf09b2..98f17b1b281e9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/StatusResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/StatusResource.java @@ -17,22 +17,21 @@ import com.facebook.presto.client.NodeVersion; import com.facebook.presto.memory.LocalMemoryManager; import com.sun.management.OperatingSystemMXBean; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.GET; -import javax.ws.rs.HEAD; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.core.Response; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HEAD; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Response; import java.lang.management.ManagementFactory; import java.lang.management.MemoryMXBean; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.server.security.RoleType.INTERNAL; -import static io.airlift.units.Duration.nanosSince; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON; @Path("/v1/status") @RolesAllowed(INTERNAL) diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/TaskExecutorResource.java b/presto-main/src/main/java/com/facebook/presto/server/TaskExecutorResource.java similarity index 86% rename from presto-main-base/src/main/java/com/facebook/presto/server/TaskExecutorResource.java rename to presto-main/src/main/java/com/facebook/presto/server/TaskExecutorResource.java index c82bc19941f3e..bac43cf00596a 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/server/TaskExecutorResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/TaskExecutorResource.java @@ -14,13 +14,12 @@ package com.facebook.presto.server; import com.facebook.presto.execution.executor.TaskExecutor; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.core.MediaType; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.MediaType; import static com.facebook.presto.server.security.RoleType.ADMIN; import static java.util.Objects.requireNonNull; diff --git a/presto-main/src/main/java/com/facebook/presto/server/TaskInfoResource.java b/presto-main/src/main/java/com/facebook/presto/server/TaskInfoResource.java index ff6ca3990979c..7ee90d4386a08 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/TaskInfoResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/TaskInfoResource.java @@ -23,32 +23,31 @@ import com.facebook.presto.resourcemanager.ResourceManagerProxy; import com.facebook.presto.spi.QueryId; import com.google.inject.Inject; - -import javax.annotation.security.RolesAllowed; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.DefaultValue; -import javax.ws.rs.GET; -import javax.ws.rs.NotFoundException; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriBuilder; -import javax.ws.rs.core.UriInfo; +import 
jakarta.annotation.security.RolesAllowed; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.NotFoundException; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriBuilder; +import jakarta.ws.rs.core.UriInfo; import java.net.URI; import java.net.UnknownHostException; import java.util.Optional; import static com.google.common.base.Preconditions.checkState; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; @Path("v1/taskInfo") @RolesAllowed("ADMIN") diff --git a/presto-main/src/main/java/com/facebook/presto/server/TaskResource.java b/presto-main/src/main/java/com/facebook/presto/server/TaskResource.java index dc0c288d4cb37..f98d3424e500a 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/TaskResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/TaskResource.java @@ -16,8 +16,8 @@ import com.facebook.airlift.concurrent.BoundedExecutor; import com.facebook.airlift.json.Codec; import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; import com.facebook.presto.execution.TaskId; import com.facebook.presto.execution.TaskInfo; import com.facebook.presto.execution.TaskManager; @@ -25,34 +25,30 @@ import com.facebook.presto.execution.TaskStatus; import com.facebook.presto.execution.buffer.OutputBufferInfo; import 
com.facebook.presto.execution.buffer.OutputBuffers.OutputBufferId; -import com.facebook.presto.metadata.HandleResolver; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.sql.planner.PlanFragment; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.Consumes; -import javax.ws.rs.DELETE; -import javax.ws.rs.DefaultValue; -import javax.ws.rs.GET; -import javax.ws.rs.HEAD; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.HttpHeaders; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HEAD; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; import java.util.List; import java.util.concurrent.Executor; @@ -68,16 +64,14 @@ import static com.facebook.presto.client.PrestoHeaders.PRESTO_BUFFER_REMAINING_BYTES; import static 
com.facebook.presto.client.PrestoHeaders.PRESTO_CURRENT_STATE; import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT; -import static com.facebook.presto.server.TaskResourceUtils.convertToThriftTaskInfo; -import static com.facebook.presto.server.TaskResourceUtils.isThriftAcceptable; import static com.facebook.presto.server.security.RoleType.INTERNAL; import static com.facebook.presto.util.TaskUtils.randomizeWaitTime; import static com.google.common.collect.Iterables.transform; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON; /** * Manages tasks on this worker node @@ -93,8 +87,6 @@ public class TaskResource private final Executor responseExecutor; private final ScheduledExecutorService timeoutExecutor; private final Codec planFragmentCodec; - private final HandleResolver handleResolver; - private final ConnectorTypeSerdeManager connectorTypeSerdeManager; @Inject public TaskResource( @@ -102,17 +94,13 @@ public TaskResource( SessionPropertyManager sessionPropertyManager, @ForAsyncRpc BoundedExecutor responseExecutor, @ForAsyncRpc ScheduledExecutorService timeoutExecutor, - JsonCodec planFragmentJsonCodec, - HandleResolver handleResolver, - ConnectorTypeSerdeManager connectorTypeSerdeManager) + JsonCodec planFragmentJsonCodec) { this.taskManager = requireNonNull(taskManager, "taskManager is null"); this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null"); this.responseExecutor = requireNonNull(responseExecutor, "responseExecutor is null"); this.timeoutExecutor = requireNonNull(timeoutExecutor, "timeoutExecutor is null"); this.planFragmentCodec = planFragmentJsonCodec; - this.handleResolver = 
requireNonNull(handleResolver, "handleResolver is null"); - this.connectorTypeSerdeManager = requireNonNull(connectorTypeSerdeManager, "connectorTypeSerdeManager is null"); } @GET @@ -160,22 +148,16 @@ public void getTaskInfo( @HeaderParam(PRESTO_CURRENT_STATE) TaskState currentState, @HeaderParam(PRESTO_MAX_WAIT) Duration maxWait, @Context UriInfo uriInfo, - @Context HttpHeaders httpHeaders, @Suspended AsyncResponse asyncResponse) { requireNonNull(taskId, "taskId is null"); - boolean isThriftRequest = isThriftAcceptable(httpHeaders); - if (currentState == null || maxWait == null) { TaskInfo taskInfo = taskManager.getTaskInfo(taskId); if (shouldSummarize(uriInfo)) { taskInfo = taskInfo.summarize(); } - if (isThriftRequest) { - taskInfo = convertToThriftTaskInfo(taskInfo, connectorTypeSerdeManager, handleResolver); - } asyncResponse.resume(taskInfo); return; } @@ -191,13 +173,6 @@ public void getTaskInfo( futureTaskInfo = Futures.transform(futureTaskInfo, TaskInfo::summarize, directExecutor()); } - if (isThriftRequest) { - futureTaskInfo = Futures.transform( - futureTaskInfo, - taskInfo -> convertToThriftTaskInfo(taskInfo, connectorTypeSerdeManager, handleResolver), - directExecutor()); - } - // For hard timeout, add an additional time to max wait for thread scheduling contention and GC Duration timeout = new Duration(waitTime.toMillis() + ADDITIONAL_WAIT_TIME.toMillis(), MILLISECONDS); bindAsyncResponse(asyncResponse, futureTaskInfo, responseExecutor) @@ -239,16 +214,6 @@ public void getTaskStatus( .withTimeout(timeout); } - @POST - @Path("{taskId}/metadataresults") - @Consumes({APPLICATION_JSON, APPLICATION_JACKSON_SMILE}) - public Response updateMetadataResults(@PathParam("taskId") TaskId taskId, MetadataUpdates metadataUpdates, @Context UriInfo uriInfo) - { - requireNonNull(metadataUpdates, "metadataUpdates is null"); - taskManager.updateMetadataResults(taskId, metadataUpdates); - return Response.ok().build(); - } - @DELETE @Path("{taskId}") 
@Consumes({APPLICATION_JSON, APPLICATION_JACKSON_SMILE, APPLICATION_THRIFT_BINARY, APPLICATION_THRIFT_COMPACT, APPLICATION_THRIFT_FB_COMPACT}) @@ -273,10 +238,6 @@ public TaskInfo deleteTask( taskInfo = taskInfo.summarize(); } - if (isThriftAcceptable(httpHeaders)) { - taskInfo = convertToThriftTaskInfo(taskInfo, connectorTypeSerdeManager, handleResolver); - } - return taskInfo; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ThreadResource.java b/presto-main/src/main/java/com/facebook/presto/server/ThreadResource.java similarity index 97% rename from presto-main-base/src/main/java/com/facebook/presto/server/ThreadResource.java rename to presto-main/src/main/java/com/facebook/presto/server/ThreadResource.java index 38dd30ff0ff6e..8de7423dbbaa5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ThreadResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ThreadResource.java @@ -20,12 +20,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.Ordering; - -import javax.annotation.security.RolesAllowed; -import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.Produces; -import javax.ws.rs.core.MediaType; +import jakarta.annotation.security.RolesAllowed; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.MediaType; import java.lang.management.ManagementFactory; import java.lang.management.ThreadInfo; diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ThrowableMapper.java b/presto-main/src/main/java/com/facebook/presto/server/ThrowableMapper.java similarity index 81% rename from presto-main-base/src/main/java/com/facebook/presto/server/ThrowableMapper.java rename to presto-main/src/main/java/com/facebook/presto/server/ThrowableMapper.java index 02103191cab01..ba9c15f52ae1b 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/server/ThrowableMapper.java +++ b/presto-main/src/main/java/com/facebook/presto/server/ThrowableMapper.java @@ -15,17 +15,16 @@ import com.facebook.airlift.log.Logger; import com.google.common.base.Throwables; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.Response.ResponseBuilder; +import jakarta.ws.rs.ext.ExceptionMapper; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.Response.ResponseBuilder; -import javax.ws.rs.ext.ExceptionMapper; - -import static javax.ws.rs.core.HttpHeaders.CONTENT_TYPE; -import static javax.ws.rs.core.MediaType.TEXT_PLAIN; +import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN; public class ThrowableMapper implements ExceptionMapper diff --git a/presto-main/src/main/java/com/facebook/presto/server/WebUiResource.java b/presto-main/src/main/java/com/facebook/presto/server/WebUiResource.java new file mode 100644 index 0000000000000..a2e908d47aea7 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/WebUiResource.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server; + +import com.facebook.presto.server.security.oauth2.OAuthWebUiCookie; +import jakarta.annotation.security.RolesAllowed; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; + +import java.util.Optional; + +import static com.facebook.presto.server.security.RoleType.ADMIN; +import static com.facebook.presto.server.security.oauth2.OAuth2Utils.getLastURLParameter; +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; + +@Path("/") +@RolesAllowed(ADMIN) +public class WebUiResource +{ + public static final String UI_ENDPOINT = "/"; + + @GET + public Response redirectIndexHtml( + @HeaderParam(X_FORWARDED_PROTO) String proto, + @Context UriInfo uriInfo) + { + if (isNullOrEmpty(proto)) { + proto = uriInfo.getRequestUri().getScheme(); + } + Optional lastURL = getLastURLParameter(uriInfo.getQueryParameters()); + if (lastURL.isPresent()) { + return Response + .seeOther(uriInfo.getRequestUriBuilder().scheme(proto).uri(lastURL.get()).build()) + .build(); + } + + return Response + .temporaryRedirect(uriInfo.getRequestUriBuilder().scheme(proto).path("/ui/").replaceQuery("").build()) + .build(); + } + + @GET + @Path("/logout") + public Response logout( + @HeaderParam(X_FORWARDED_PROTO) String proto, + @Context UriInfo uriInfo) + { + if (isNullOrEmpty(proto)) { + proto = uriInfo.getRequestUri().getScheme(); + } + return Response + .temporaryRedirect(uriInfo.getBaseUriBuilder().scheme(proto).path("/ui/logout.html").build()) + .cookie(OAuthWebUiCookie.delete()) + .build(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/WorkerModule.java 
b/presto-main/src/main/java/com/facebook/presto/server/WorkerModule.java index 1227ebe411251..8b007302291a1 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/WorkerModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/WorkerModule.java @@ -26,8 +26,7 @@ import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; - -import javax.inject.Singleton; +import jakarta.inject.Singleton; import static com.google.common.reflect.Reflection.newProxy; diff --git a/presto-main/src/main/java/com/facebook/presto/server/WorkerResource.java b/presto-main/src/main/java/com/facebook/presto/server/WorkerResource.java index 410150698ec01..63cfec0e0cbac 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/WorkerResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/WorkerResource.java @@ -19,14 +19,13 @@ import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.spi.NodeState; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.GET; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.Response; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.Response; import java.io.IOException; import java.io.InputStream; @@ -36,10 +35,10 @@ import static com.facebook.airlift.http.client.Request.Builder.prepareGet; import static com.facebook.presto.server.security.RoleType.ADMIN; import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON_TYPE; +import static 
jakarta.ws.rs.core.Response.Status.NOT_FOUND; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON_TYPE; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; @Path("/v1/worker") @RolesAllowed(ADMIN) diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/ExecutingQueryResponseProvider.java b/presto-main/src/main/java/com/facebook/presto/server/protocol/ExecutingQueryResponseProvider.java similarity index 80% rename from presto-main-base/src/main/java/com/facebook/presto/server/protocol/ExecutingQueryResponseProvider.java rename to presto-main/src/main/java/com/facebook/presto/server/protocol/ExecutingQueryResponseProvider.java index e9de9c13b2727..40f1bf49e2484 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/ExecutingQueryResponseProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/server/protocol/ExecutingQueryResponseProvider.java @@ -13,16 +13,17 @@ */ package com.facebook.presto.server.protocol; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.dispatcher.DispatchInfo; import com.facebook.presto.spi.QueryId; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; +import java.net.URI; import java.util.Optional; +import java.util.OptionalLong; public interface ExecutingQueryResponseProvider { @@ -47,6 +48,9 @@ public interface ExecutingQueryResponseProvider * @param compressionEnabled enable compression * @param nestedDataSerializationEnabled enable nested data serialization * @param binaryResults generate results in binary format, rather than JSON + * @param retryUrl optional retry URL for 
cross-cluster retry + * @param retryExpirationEpochTime optional retry expiration time + * @param isRetryQuery true if this query is already a retry query * @return the ExecutingStatement's Response, if available */ Optional> waitForExecutingResponse( @@ -60,5 +64,9 @@ Optional> waitForExecutingResponse( DataSize targetResultSize, boolean compressionEnabled, boolean nestedDataSerializationEnabled, - boolean binaryResults); + boolean binaryResults, + long durationUntilExpirationMs, + Optional retryUrl, + OptionalLong retryExpirationEpochTime, + boolean isRetryQuery); } diff --git a/presto-main/src/main/java/com/facebook/presto/server/protocol/ExecutingStatementResource.java b/presto-main/src/main/java/com/facebook/presto/server/protocol/ExecutingStatementResource.java index 950769a1a7904..0973ed96f63e0 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/protocol/ExecutingStatementResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/protocol/ExecutingStatementResource.java @@ -15,35 +15,36 @@ import com.facebook.airlift.concurrent.BoundedExecutor; import com.facebook.airlift.stats.TimeStat; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.QueryResults; +import com.facebook.presto.execution.QueryManager; import com.facebook.presto.server.ForStatementResource; import com.facebook.presto.server.ServerConfig; import com.facebook.presto.spi.QueryId; import com.google.common.collect.Ordering; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import 
jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.DELETE; -import javax.ws.rs.DefaultValue; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; - import static com.facebook.airlift.http.server.AsyncResponseHandler.bindAsyncResponse; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.client.PrestoHeaders.PRESTO_PREFIX_URL; import static com.facebook.presto.server.protocol.QueryResourceUtil.abortIfPrefixUrlInvalid; import static com.facebook.presto.server.protocol.QueryResourceUtil.toResponse; @@ -53,7 +54,6 @@ import static com.google.common.util.concurrent.Futures.transform; import static com.google.common.util.concurrent.Futures.transformAsync; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.SECONDS; @@ -68,6 +68,7 @@ public class ExecutingStatementResource private final BoundedExecutor responseExecutor; private final LocalQueryProvider queryProvider; + private final QueryManager queryManager; private final boolean compressionEnabled; private final boolean nestedDataSerializationEnabled; private final QueryBlockingRateLimiter queryRateLimiter; @@ -76,11 +77,13 @@ 
public class ExecutingStatementResource public ExecutingStatementResource( @ForStatementResource BoundedExecutor responseExecutor, LocalQueryProvider queryProvider, + QueryManager queryManager, ServerConfig serverConfig, QueryBlockingRateLimiter queryRateLimiter) { this.responseExecutor = requireNonNull(responseExecutor, "responseExecutor is null"); this.queryProvider = requireNonNull(queryProvider, "queryProvider is null"); + this.queryManager = requireNonNull(queryManager, "queryManager is null"); this.compressionEnabled = requireNonNull(serverConfig, "serverConfig is null").isQueryResultsCompressionEnabled(); this.nestedDataSerializationEnabled = requireNonNull(serverConfig, "serverConfig is null").isNestedDataSerializationEnabled(); this.queryRateLimiter = requireNonNull(queryRateLimiter, "queryRateLimiter is null"); @@ -132,9 +135,10 @@ public void getQueryResults( return query.waitForResults(token, uriInfo, effectiveFinalProto, wait, effectiveFinalTargetResultSize, binaryResults); }, responseExecutor); + long durationUntilExpirationMs = queryManager.getDurationUntilExpirationInMillis(queryId); ListenableFuture queryResultsFuture = transform( waitForResultsAsync, - results -> toResponse(query, results, xPrestoPrefixUrl, compressionEnabled, nestedDataSerializationEnabled), + results -> toResponse(query, results, xPrestoPrefixUrl, compressionEnabled, nestedDataSerializationEnabled, durationUntilExpirationMs), directExecutor()); bindAsyncResponse(asyncResponse, queryResultsFuture, responseExecutor); } diff --git a/presto-main/src/main/java/com/facebook/presto/server/protocol/LocalExecutingQueryResponseProvider.java b/presto-main/src/main/java/com/facebook/presto/server/protocol/LocalExecutingQueryResponseProvider.java index e0d7dc0c8d855..bca006f05c9ff 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/protocol/LocalExecutingQueryResponseProvider.java +++ 
b/presto-main/src/main/java/com/facebook/presto/server/protocol/LocalExecutingQueryResponseProvider.java @@ -13,19 +13,21 @@ */ package com.facebook.presto.server.protocol; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.dispatcher.DispatchInfo; import com.facebook.presto.spi.QueryId; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.inject.Inject; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriInfo; +import jakarta.inject.Inject; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; +import java.net.URI; import java.util.Optional; +import java.util.OptionalLong; +import static com.facebook.presto.server.protocol.QueryResourceUtil.toResponse; import static com.google.common.util.concurrent.Futures.transform; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static java.util.Objects.requireNonNull; @@ -53,18 +55,22 @@ public Optional> waitForExecutingResponse( DataSize targetResultSize, boolean compressionEnabled, boolean nestedDataSerializationEnabled, - boolean binaryResults) + boolean binaryResults, + long durationUntilExpirationMs, + Optional retryUrl, + OptionalLong retryExpirationEpochTime, + boolean isRetryQuery) { Query query; try { - query = queryProvider.getQuery(queryId, slug); + query = queryProvider.getQuery(queryId, slug, retryUrl, retryExpirationEpochTime, isRetryQuery); } catch (WebApplicationException e) { return Optional.empty(); } return Optional.of(transform( query.waitForResults(0, uriInfo, scheme, maxWait, targetResultSize, binaryResults), - results -> QueryResourceUtil.toResponse(query, results, xPrestoPrefixUrl, compressionEnabled, nestedDataSerializationEnabled), + results -> toResponse(query, results, xPrestoPrefixUrl, 
compressionEnabled, nestedDataSerializationEnabled, durationUntilExpirationMs), directExecutor())); } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/protocol/LocalQueryProvider.java b/presto-main/src/main/java/com/facebook/presto/server/protocol/LocalQueryProvider.java index 1bbfbb6e7c12e..e0c7fb38c1e2a 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/protocol/LocalQueryProvider.java +++ b/presto-main/src/main/java/com/facebook/presto/server/protocol/LocalQueryProvider.java @@ -22,28 +22,31 @@ import com.facebook.presto.operator.ExchangeClient; import com.facebook.presto.operator.ExchangeClientSupplier; import com.facebook.presto.server.ForStatementResource; +import com.facebook.presto.server.RetryConfig; import com.facebook.presto.spi.QueryId; import com.facebook.presto.transaction.TransactionManager; - -import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import javax.inject.Inject; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.Response.Status; - +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.Response.Status; + +import java.net.URI; import java.util.Map.Entry; import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.OptionalLong; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; import static com.facebook.airlift.concurrent.Threads.threadsNamed; import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; import static 
java.util.concurrent.TimeUnit.MILLISECONDS; -import static javax.ws.rs.core.MediaType.TEXT_PLAIN_TYPE; public class LocalQueryProvider { @@ -56,6 +59,7 @@ public class LocalQueryProvider private final BoundedExecutor responseExecutor; private final ScheduledExecutorService timeoutExecutor; private final RetryCircuitBreaker retryCircuitBreaker; + private final RetryConfig retryConfig; private final ConcurrentMap queries = new ConcurrentHashMap<>(); private final ScheduledExecutorService queryPurger = newSingleThreadScheduledExecutor(threadsNamed("execution-query-purger")); @@ -68,7 +72,8 @@ public LocalQueryProvider( BlockEncodingSerde blockEncodingSerde, @ForStatementResource BoundedExecutor responseExecutor, @ForStatementResource ScheduledExecutorService timeoutExecutor, - RetryCircuitBreaker retryCircuitBreaker) + RetryCircuitBreaker retryCircuitBreaker, + RetryConfig retryConfig) { this.queryManager = requireNonNull(queryManager, "queryManager is null"); this.transactionManager = requireNonNull(transactionManager, "transactionManager is null"); @@ -77,6 +82,7 @@ public LocalQueryProvider( this.responseExecutor = requireNonNull(responseExecutor, "responseExecutor is null"); this.timeoutExecutor = requireNonNull(timeoutExecutor, "timeoutExecutor is null"); this.retryCircuitBreaker = requireNonNull(retryCircuitBreaker, "retryCircuitBreaker is null"); + this.retryConfig = requireNonNull(retryConfig, "retryConfig is null"); } @PostConstruct @@ -112,6 +118,16 @@ public void stop() } public Query getQuery(QueryId queryId, String slug) + { + return getQuery(queryId, slug, Optional.empty(), OptionalLong.empty(), false); + } + + public Query getQuery(QueryId queryId, String slug, Optional retryUrl, OptionalLong retryExpirationEpochTime) + { + return getQuery(queryId, slug, retryUrl, retryExpirationEpochTime, false); + } + + public Query getQuery(QueryId queryId, String slug, Optional retryUrl, OptionalLong retryExpirationEpochTime, boolean isRetryQuery) { Query query = 
queries.get(queryId); if (query != null) { @@ -144,7 +160,11 @@ public Query getQuery(QueryId queryId, String slug) responseExecutor, timeoutExecutor, blockEncodingSerde, - retryCircuitBreaker); + retryCircuitBreaker, + retryConfig, + retryUrl, + retryExpirationEpochTime, + isRetryQuery); }); return query; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/Query.java b/presto-main/src/main/java/com/facebook/presto/server/protocol/Query.java similarity index 87% rename from presto-main-base/src/main/java/com/facebook/presto/server/protocol/Query.java rename to presto-main/src/main/java/com/facebook/presto/server/protocol/Query.java index d9c398636df14..b0545547e9a0f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/Query.java +++ b/presto-main/src/main/java/com/facebook/presto/server/protocol/Query.java @@ -14,6 +14,8 @@ package com.facebook.presto.server.protocol; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.client.Column; import com.facebook.presto.client.FailureInfo; @@ -34,6 +36,7 @@ import com.facebook.presto.execution.StageInfo; import com.facebook.presto.execution.buffer.PagesSerdeFactory; import com.facebook.presto.operator.ExchangeClient; +import com.facebook.presto.server.RetryConfig; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.function.SqlFunctionId; @@ -49,18 +52,16 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; +import com.google.common.io.BaseEncoding; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import 
io.airlift.slice.DynamicSliceOutput; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriBuilder; -import javax.ws.rs.core.UriInfo; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriBuilder; +import jakarta.ws.rs.core.UriInfo; import java.net.URI; import java.util.Base64; @@ -85,16 +86,17 @@ import static com.facebook.presto.SystemSessionProperties.useHistoryBasedPlanStatisticsEnabled; import static com.facebook.presto.execution.QueryState.FAILED; import static com.facebook.presto.execution.QueryState.WAITING_FOR_PREREQUISITES; -import static com.facebook.presto.server.protocol.QueryResourceUtil.toStatementStats; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static com.facebook.presto.spi.page.PagesSerdeUtil.writeSerializedPage; import static com.facebook.presto.util.Failures.toFailure; +import static com.facebook.presto.util.QueryInfoUtils.toStatementStats; import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static java.lang.String.format; +import static java.lang.System.currentTimeMillis; import static java.util.Objects.requireNonNull; @ThreadSafe @@ -111,6 +113,9 @@ class Query private final QueryId queryId; private final Session session; private final String slug; + private final Optional retryUrl; + private final OptionalLong retryExpirationEpochTime; + private final boolean isRetryQuery; @GuardedBy("this") private final ExchangeClient exchangeClient; @@ 
-120,6 +125,7 @@ class Query private final PagesSerde serde; private final RetryCircuitBreaker retryCircuitBreaker; + private final RetryConfig retryConfig; @GuardedBy("this") private OptionalLong nextToken = OptionalLong.of(0); @@ -184,9 +190,26 @@ public static Query create( Executor dataProcessorExecutor, ScheduledExecutorService timeoutExecutor, BlockEncodingSerde blockEncodingSerde, - RetryCircuitBreaker retryCircuitBreaker) - { - Query result = new Query(session, slug, queryManager, transactionManager, exchangeClient, dataProcessorExecutor, timeoutExecutor, blockEncodingSerde, retryCircuitBreaker); + RetryCircuitBreaker retryCircuitBreaker, + RetryConfig retryConfig, + Optional retryUrl, + OptionalLong retryExpirationEpochTime, + boolean isRetryQuery) + { + Query result = new Query( + session, + slug, + retryUrl, + retryExpirationEpochTime, + isRetryQuery, + queryManager, + transactionManager, + exchangeClient, + dataProcessorExecutor, + timeoutExecutor, + blockEncodingSerde, + retryCircuitBreaker, + retryConfig); result.queryManager.addOutputInfoListener(result.getQueryId(), result::setQueryOutputInfo); @@ -203,16 +226,22 @@ public static Query create( private Query( Session session, String slug, + Optional retryUrl, + OptionalLong retryExpirationEpochTime, + boolean isRetryQuery, QueryManager queryManager, TransactionManager transactionManager, ExchangeClient exchangeClient, Executor resultsProcessorExecutor, ScheduledExecutorService timeoutExecutor, BlockEncodingSerde blockEncodingSerde, - RetryCircuitBreaker retryCircuitBreaker) + RetryCircuitBreaker retryCircuitBreaker, + RetryConfig retryConfig) { requireNonNull(session, "session is null"); requireNonNull(slug, "slug is null"); + requireNonNull(retryUrl, "retryUrl is null"); + requireNonNull(retryExpirationEpochTime, "retryExpirationEpochTime is null"); requireNonNull(queryManager, "queryManager is null"); requireNonNull(transactionManager, "transactionManager is null"); requireNonNull(exchangeClient, 
"exchangeClient is null"); @@ -220,6 +249,7 @@ private Query( requireNonNull(timeoutExecutor, "timeoutExecutor is null"); requireNonNull(blockEncodingSerde, "serde is null"); requireNonNull(retryCircuitBreaker, "retryCircuitBreaker is null"); + requireNonNull(retryConfig, "retryConfig is null"); this.queryManager = queryManager; this.transactionManager = transactionManager; @@ -227,12 +257,16 @@ private Query( this.queryId = session.getQueryId(); this.session = session; this.slug = slug; + this.retryUrl = retryUrl; + this.retryExpirationEpochTime = retryExpirationEpochTime; + this.isRetryQuery = isRetryQuery; this.exchangeClient = exchangeClient; this.resultsProcessorExecutor = resultsProcessorExecutor; this.timeoutExecutor = timeoutExecutor; this.serde = new PagesSerdeFactory(blockEncodingSerde, getExchangeCompressionCodec(session), isExchangeChecksumEnabled(session)).createPagesSerde(); this.retryCircuitBreaker = retryCircuitBreaker; + this.retryConfig = retryConfig; } public void cancel() @@ -427,11 +461,13 @@ else if (queryManager.getQueryRetryCount(queryId) == 1 && retryQueryWithHistoryB // build a new query with next uri // we expect failed nodes have been removed from discovery server upon query failure + URI nextUri = createRetryUri(scheme, uriInfo); + return new QueryResults( queryId.toString(), queryResults.getInfoUri(), queryResults.getPartialCancelUri(), - createRetryUri(scheme, uriInfo), + nextUri, queryResults.getColumns(), null, null, @@ -486,7 +522,8 @@ private synchronized QueryResults getNextResult(long token, UriInfo uriInfo, Str DynamicSliceOutput sliceOutput = new DynamicSliceOutput(1000); writeSerializedPage(sliceOutput, serializedPage); - String encodedPage = BASE64_ENCODER.encodeToString(sliceOutput.slice().byteArray()); + byte[] binaryResultArray = sliceOutput.slice().byteArray(); + String encodedPage = BaseEncoding.base64().encode(binaryResultArray, 0, sliceOutput.size()); pages.add(encodedPage); } if (rows > 0) { @@ -526,7 +563,7 @@ 
private synchronized QueryResults getNextResult(long token, UriInfo uriInfo, Str // TODO: figure out a better way to do this // grab the update count for non-queries - if ((data != null) && (queryInfo.getUpdateType() != null) && (updateCount == null) && + if ((data != null) && (queryInfo.getUpdateInfo() != null) && (updateCount == null) && (columns.size() == 1) && (columns.get(0).getType().equals(StandardTypes.BIGINT))) { Iterator> iterator = data.iterator(); if (iterator.hasNext()) { @@ -597,7 +634,7 @@ private synchronized QueryResults getNextResult(long token, UriInfo uriInfo, Str toStatementStats(queryInfo), toQueryError(queryInfo), queryInfo.getWarnings(), - queryInfo.getUpdateType(), + queryInfo.getUpdateInfo() != null ? queryInfo.getUpdateInfo().getUpdateType() : null, updateCount); // cache the new result @@ -669,6 +706,20 @@ private synchronized URI createNextResultsUri(String scheme, UriInfo uriInfo, lo private synchronized URI createRetryUri(String scheme, UriInfo uriInfo) { + // Check if we have external retry URL information + if (retryUrl.isPresent()) { + // Check if the retry URL has not expired + long currentTime = currentTimeMillis(); + if (currentTime < retryExpirationEpochTime.getAsLong()) { + return retryUrl.get(); + } + else { + log.warn("Retry URL for query %s has expired. 
Current time: %d, Expiration: %d", + queryId, currentTime, retryExpirationEpochTime.getAsLong()); + } + } + + // Use the default retry mechanism UriBuilder uri = uriInfo.getBaseUriBuilder() .scheme(scheme) .replacePath("/v1/statement/queued/retry") @@ -714,8 +765,28 @@ private boolean retryConditionsMet(QueryResults queryResults) } if (!retryQueryWithHistoryBasedOptimizationEnabled(session)) { - if (!queryResults.getError().isRetriable()) { - return false; + // Check if cross-cluster retry is attempted + if (retryUrl.isPresent()) { + // Check if this is already a retry query - prevent cross-cluster retry chains + if (isRetryQuery) { + log.debug("Query %s is already a retry query, preventing cross-cluster retry chain", queryId); + return false; + } + + // For cross-cluster retry, only allow specific error codes + int errorCode = queryResults.getError().getErrorCode(); + if (!retryConfig.getCrossClusterRetryErrorCodes().contains(errorCode)) { + log.debug("Query %s error code %d is not allowed for cross-cluster retry. 
Allowed codes: %s", + queryId, errorCode, retryConfig.getCrossClusterRetryErrorCodes()); + return false; + } + } + else { + // For same-cluster retry, use the normal retriable flag + if (!queryResults.getError().isRetriable()) { + log.debug("Query %s error code %s is not retriable", queryId, queryResults.getError().getErrorName()); + return false; + } } // check if we have exceeded the global limit diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/QueryResourceUtil.java b/presto-main/src/main/java/com/facebook/presto/server/protocol/QueryResourceUtil.java similarity index 74% rename from presto-main-base/src/main/java/com/facebook/presto/server/protocol/QueryResourceUtil.java rename to presto-main/src/main/java/com/facebook/presto/server/protocol/QueryResourceUtil.java index adb47d813e8ef..5cca337690a3f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/protocol/QueryResourceUtil.java +++ b/presto-main/src/main/java/com/facebook/presto/server/protocol/QueryResourceUtil.java @@ -15,34 +15,26 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.Column; import com.facebook.presto.client.QueryError; import com.facebook.presto.client.QueryResults; -import com.facebook.presto.client.StageStats; import com.facebook.presto.client.StatementStats; import com.facebook.presto.common.type.NamedTypeSignature; import com.facebook.presto.common.type.ParameterKind; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.common.type.TypeSignatureParameter; -import com.facebook.presto.execution.QueryInfo; import com.facebook.presto.execution.QueryState; -import com.facebook.presto.execution.QueryStats; -import com.facebook.presto.execution.StageExecutionInfo; -import com.facebook.presto.execution.StageExecutionStats; -import com.facebook.presto.execution.StageInfo; -import 
com.facebook.presto.execution.TaskInfo; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Sets; -import io.airlift.units.Duration; - -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriBuilder; -import javax.ws.rs.core.UriInfo; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.CacheControl; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriBuilder; +import jakarta.ws.rs.core.UriInfo; import java.io.UnsupportedEncodingException; import java.net.URI; @@ -52,7 +44,6 @@ import java.util.ArrayList; import java.util.Base64; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -82,12 +73,12 @@ import static com.facebook.presto.execution.QueryState.WAITING_FOR_PREREQUISITES; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Strings.isNullOrEmpty; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE; import static java.lang.String.format; import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.stream.Collectors.toList; -import static javax.ws.rs.core.MediaType.TEXT_PLAIN_TYPE; public final class QueryResourceUtil { @@ -100,7 +91,7 @@ public final class QueryResourceUtil private QueryResourceUtil() {} - public static Response toResponse(Query query, QueryResults queryResults, boolean compressionEnabled) + public static Response toResponse(Query query, QueryResults queryResults, boolean compressionEnabled, long durationUntilExpirationMs) { Response.ResponseBuilder 
response = Response.ok(queryResults); @@ -158,10 +149,18 @@ public static Response toResponse(Query query, QueryResults queryResults, boolea response.header(PRESTO_REMOVED_SESSION_FUNCTION, urlEncode(SQL_FUNCTION_ID_JSON_CODEC.toJson(signature))); } + response.cacheControl(getCacheControlMaxAge(durationUntilExpirationMs)); + return response.build(); } - public static Response toResponse(Query query, QueryResults queryResults, String xPrestoPrefixUri, boolean compressionEnabled, boolean nestedDataSerializationEnabled) + public static Response toResponse( + Query query, + QueryResults queryResults, + String xPrestoPrefixUri, + boolean compressionEnabled, + boolean nestedDataSerializationEnabled, + long durationUntilExpirationMs) { Iterable> queryResultsData = queryResults.getData(); if (nestedDataSerializationEnabled) { @@ -181,7 +180,12 @@ public static Response toResponse(Query query, QueryResults queryResults, String queryResults.getUpdateType(), queryResults.getUpdateCount()); - return toResponse(query, resultsClone, compressionEnabled); + return toResponse(query, resultsClone, compressionEnabled, durationUntilExpirationMs); + } + + public static CacheControl getCacheControlMaxAge(long durationUntilExpirationMs) + { + return CacheControl.valueOf("max-age=" + MILLISECONDS.toSeconds(durationUntilExpirationMs)); } public static void abortIfPrefixUrlInvalid(String xPrestoPrefixUrl) @@ -216,40 +220,6 @@ public static URI prependUri(URI backendUri, String xPrestoPrefixUrl) return backendUri; } - public static StatementStats toStatementStats(QueryInfo queryInfo) - { - QueryStats queryStats = queryInfo.getQueryStats(); - StageInfo outputStage = queryInfo.getOutputStage().orElse(null); - - Set globalUniqueNodes = new HashSet<>(); - StageStats rootStageStats = toStageStats(outputStage, globalUniqueNodes); - - return StatementStats.builder() - .setState(queryInfo.getState().toString()) - .setWaitingForPrerequisites(queryInfo.getState() == 
QueryState.WAITING_FOR_PREREQUISITES) - .setQueued(queryInfo.getState() == QueryState.QUEUED) - .setScheduled(queryInfo.isScheduled()) - .setNodes(globalUniqueNodes.size()) - .setTotalSplits(queryStats.getTotalDrivers()) - .setQueuedSplits(queryStats.getQueuedDrivers()) - .setRunningSplits(queryStats.getRunningDrivers() + queryStats.getBlockedDrivers()) - .setCompletedSplits(queryStats.getCompletedDrivers()) - .setCpuTimeMillis(queryStats.getTotalCpuTime().toMillis()) - .setWallTimeMillis(queryStats.getTotalScheduledTime().toMillis()) - .setWaitingForPrerequisitesTimeMillis(queryStats.getWaitingForPrerequisitesTime().toMillis()) - .setQueuedTimeMillis(queryStats.getQueuedTime().toMillis()) - .setElapsedTimeMillis(queryStats.getElapsedTime().toMillis()) - .setProcessedRows(queryStats.getRawInputPositions()) - .setProcessedBytes(queryStats.getRawInputDataSize().toBytes()) - .setPeakMemoryBytes(queryStats.getPeakUserMemoryReservation().toBytes()) - .setPeakTotalMemoryBytes(queryStats.getPeakTotalMemoryReservation().toBytes()) - .setPeakTaskTotalMemoryBytes(queryStats.getPeakTaskTotalMemory().toBytes()) - .setSpilledBytes(queryStats.getSpilledDataSize().toBytes()) - .setRootStage(rootStageStats) - .setRuntimeStats(queryStats.getRuntimeStats()) - .build(); - } - private static String urlEncode(String value) { try { @@ -260,57 +230,6 @@ private static String urlEncode(String value) } } - private static StageStats toStageStats(StageInfo stageInfo, Set globalUniqueNodeIds) - { - if (stageInfo == null) { - return null; - } - - StageExecutionInfo currentStageExecutionInfo = stageInfo.getLatestAttemptExecutionInfo(); - StageExecutionStats stageExecutionStats = currentStageExecutionInfo.getStats(); - - // Store current stage details into a builder - StageStats.Builder builder = StageStats.builder() - .setStageId(String.valueOf(stageInfo.getStageId().getId())) - .setState(currentStageExecutionInfo.getState().toString()) - .setDone(currentStageExecutionInfo.getState().isDone()) 
- .setTotalSplits(stageExecutionStats.getTotalDrivers()) - .setQueuedSplits(stageExecutionStats.getQueuedDrivers()) - .setRunningSplits(stageExecutionStats.getRunningDrivers() + stageExecutionStats.getBlockedDrivers()) - .setCompletedSplits(stageExecutionStats.getCompletedDrivers()) - .setCpuTimeMillis(stageExecutionStats.getTotalCpuTime().toMillis()) - .setWallTimeMillis(stageExecutionStats.getTotalScheduledTime().toMillis()) - .setProcessedRows(stageExecutionStats.getRawInputPositions()) - .setProcessedBytes(stageExecutionStats.getRawInputDataSizeInBytes()) - .setNodes(countStageAndAddGlobalUniqueNodes(currentStageExecutionInfo.getTasks(), globalUniqueNodeIds)); - - // Recurse into child stages to create their StageStats - List subStages = stageInfo.getSubStages(); - if (subStages.isEmpty()) { - builder.setSubStages(ImmutableList.of()); - } - else { - ImmutableList.Builder subStagesBuilder = ImmutableList.builderWithExpectedSize(subStages.size()); - for (StageInfo subStage : subStages) { - subStagesBuilder.add(toStageStats(subStage, globalUniqueNodeIds)); - } - builder.setSubStages(subStagesBuilder.build()); - } - - return builder.build(); - } - - private static int countStageAndAddGlobalUniqueNodes(List tasks, Set globalUniqueNodes) - { - Set stageUniqueNodes = Sets.newHashSetWithExpectedSize(tasks.size()); - for (TaskInfo task : tasks) { - String nodeId = task.getNodeId(); - stageUniqueNodes.add(nodeId); - globalUniqueNodes.add(nodeId); - } - return stageUniqueNodes.size(); - } - /** * Problem: As the type of data defined in QueryResult is `Iterable>`, * when jackson serialize a data with nested data structure in the response, the nested object won't diff --git a/presto-main/src/main/java/com/facebook/presto/server/protocol/QueuedStatementResource.java b/presto-main/src/main/java/com/facebook/presto/server/protocol/QueuedStatementResource.java index 82b4b55cbb459..8eb1baf8d18cd 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/server/protocol/QueuedStatementResource.java +++ b/presto-main/src/main/java/com/facebook/presto/server/protocol/QueuedStatementResource.java @@ -15,6 +15,8 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.TimeStat; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.QueryError; import com.facebook.presto.client.QueryResults; import com.facebook.presto.common.ErrorCode; @@ -24,6 +26,7 @@ import com.facebook.presto.execution.ExecutionFailureInfo; import com.facebook.presto.metadata.SessionPropertyManager; import com.facebook.presto.server.HttpRequestSessionContext; +import com.facebook.presto.server.RetryUrlValidator; import com.facebook.presto.server.ServerConfig; import com.facebook.presto.server.SessionContext; import com.facebook.presto.spi.PrestoException; @@ -33,38 +36,39 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Ordering; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.PreDestroy; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.PUT; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.CacheControl; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.Response.Status; 
+import jakarta.ws.rs.core.UriInfo; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.PreDestroy; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.DELETE; -import javax.ws.rs.DefaultValue; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.POST; -import javax.ws.rs.PUT; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.container.AsyncResponse; -import javax.ws.rs.container.Suspended; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.Response.Status; -import javax.ws.rs.core.UriInfo; - +import java.io.UnsupportedEncodingException; import java.net.URI; +import java.net.URLDecoder; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import java.util.OptionalLong; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; @@ -73,10 +77,13 @@ import static com.facebook.airlift.concurrent.MoreFutures.addTimeout; import static com.facebook.airlift.concurrent.Threads.threadsNamed; import static com.facebook.airlift.http.server.AsyncResponseHandler.bindAsyncResponse; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.client.PrestoHeaders.PRESTO_PREFIX_URL; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_RETRY_QUERY; import static com.facebook.presto.server.protocol.QueryResourceUtil.NO_DURATION; import static com.facebook.presto.server.protocol.QueryResourceUtil.abortIfPrefixUrlInvalid; import static com.facebook.presto.server.protocol.QueryResourceUtil.createQueuedQueryResults; +import static 
com.facebook.presto.server.protocol.QueryResourceUtil.getCacheControlMaxAge; import static com.facebook.presto.server.protocol.QueryResourceUtil.getQueuedUri; import static com.facebook.presto.server.protocol.QueryResourceUtil.getScheme; import static com.facebook.presto.server.security.RoleType.USER; @@ -90,18 +97,21 @@ import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.common.util.concurrent.Futures.transformAsync; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE; +import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST; +import static jakarta.ws.rs.core.Response.Status.CONFLICT; +import static jakarta.ws.rs.core.Response.Status.NOT_FOUND; +import static java.lang.Boolean.parseBoolean; +import static java.lang.String.format; +import static java.lang.System.currentTimeMillis; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.UUID.randomUUID; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON; -import static javax.ws.rs.core.MediaType.TEXT_PLAIN_TYPE; -import static javax.ws.rs.core.Response.Status.BAD_REQUEST; -import static javax.ws.rs.core.Response.Status.CONFLICT; -import static javax.ws.rs.core.Response.Status.NOT_FOUND; @Path("/") @RolesAllowed(USER) @@ -129,6 +139,7 @@ public class QueuedStatementResource private final SessionPropertyManager sessionPropertyManager; // We may need some system default session property values at early query stage even before session is created. 
private final QueryBlockingRateLimiter queryRateLimiter; + private final RetryUrlValidator retryUrlValidator; @Inject public QueuedStatementResource( @@ -139,7 +150,8 @@ public QueuedStatementResource( ServerConfig serverConfig, TracerProviderManager tracerProviderManager, SessionPropertyManager sessionPropertyManager, - QueryBlockingRateLimiter queryRateLimiter) + QueryBlockingRateLimiter queryRateLimiter, + RetryUrlValidator retryUrlValidator) { this.dispatchManager = requireNonNull(dispatchManager, "dispatchManager is null"); this.executingQueryResponseProvider = requireNonNull(executingQueryResponseProvider, "executingQueryResponseProvider is null"); @@ -153,6 +165,7 @@ public QueuedStatementResource( this.sessionPropertyManager = sessionPropertyManager; this.queryRateLimiter = requireNonNull(queryRateLimiter, "queryRateLimiter is null"); + this.retryUrlValidator = requireNonNull(retryUrlValidator, "retryUrlValidator is null"); queryPurger.scheduleWithFixedDelay( () -> { @@ -191,8 +204,8 @@ public void stop() * @param statement The statement or sql query string submitted * @param xForwardedProto Forwarded protocol (http or https) * @param servletRequest The http request - * @param uriInfo {@link javax.ws.rs.core.UriInfo} - * @return {@link javax.ws.rs.core.Response} HTTP response code + * @param uriInfo {@link jakarta.ws.rs.core.UriInfo} + * @return {@link jakarta.ws.rs.core.Response} HTTP response code */ @POST @Path("/v1/statement") @@ -200,6 +213,9 @@ public void stop() public Response postStatement( String statement, @DefaultValue("false") @QueryParam("binaryResults") boolean binaryResults, + @QueryParam("retryUrl") String retryUrlString, + @QueryParam("retryExpirationInSeconds") Long retryExpirationInSeconds, + @HeaderParam(PRESTO_RETRY_QUERY) String isRetryQueryHeader, @HeaderParam(X_FORWARDED_PROTO) String xForwardedProto, @HeaderParam(PRESTO_PREFIX_URL) String xPrestoPrefixUrl, @Context HttpServletRequest servletRequest, @@ -211,17 +227,61 @@ public 
Response postStatement( abortIfPrefixUrlInvalid(xPrestoPrefixUrl); + // Parse retry query header + boolean isRetryQuery = parseBoolean(isRetryQueryHeader); + + // Validate retry URL if provided + Optional retryUrl = Optional.empty(); + OptionalLong retryExpirationEpochTime = OptionalLong.empty(); + if ((retryUrlString != null && !retryUrlString.isEmpty()) || retryExpirationInSeconds != null) { + if (retryUrlString == null || retryUrlString.isEmpty() || retryExpirationInSeconds == null || retryExpirationInSeconds < 1) { + throw badRequest(BAD_REQUEST, format("Invalid retry parameters: retryUrl=%s, retryExpiration=%s", retryUrlString, retryExpirationInSeconds)); + } + retryUrl = Optional.of(getRetryUrl(retryUrlString)); + retryExpirationEpochTime = OptionalLong.of(currentTimeMillis() + SECONDS.toMillis(retryExpirationInSeconds)); + String currentHost = uriInfo.getBaseUri().getHost(); + if (!retryUrlValidator.isValidRetryUrl(retryUrl.get(), currentHost)) { + throw badRequest(BAD_REQUEST, "Invalid retry URL"); + } + } + // TODO: For future cases we may want to start tracing from client. Then continuation of tracing // will be needed instead of creating a new trace here. 
SessionContext sessionContext = new HttpRequestSessionContext( servletRequest, sqlParserOptions, tracerProviderManager.getTracerProvider(), - Optional.of(sessionPropertyManager)); - Query query = new Query(statement, sessionContext, dispatchManager, executingQueryResponseProvider, 0); + Optional.of(sessionPropertyManager), + statement); + QueryId newQueryId = dispatchManager.createQueryId(); + Query query = new Query( + statement, + sessionContext, + dispatchManager, + executingQueryResponseProvider, + 0, + newQueryId, + createSlug(), + isRetryQuery, + retryUrl, + retryExpirationEpochTime); + queries.put(query.getQueryId(), query); - return withCompressionConfiguration(Response.ok(query.getInitialQueryResults(uriInfo, xForwardedProto, xPrestoPrefixUrl, binaryResults)), compressionEnabled).build(); + return withCompressionConfiguration(Response.ok(query.getInitialQueryResults(uriInfo, xForwardedProto, xPrestoPrefixUrl, binaryResults)), compressionEnabled) + .cacheControl(query.getDefaultCacheControl()) + .build(); + } + + private static URI getRetryUrl(String urlEncodedUrl) + { + try { + String decodedUrl = URLDecoder.decode(urlEncodedUrl, UTF_8.toString()); + return URI.create(decodedUrl); + } + catch (UnsupportedEncodingException | IllegalArgumentException e) { + throw badRequest(BAD_REQUEST, "Retry URL invalid"); + } } /** @@ -237,8 +297,8 @@ public Response postStatement( * @param slug Pre-minted slug to protect this query * @param xForwardedProto Forwarded protocol (http or https) * @param servletRequest The http request - * @param uriInfo {@link javax.ws.rs.core.UriInfo} - * @return {@link javax.ws.rs.core.Response} HTTP response code + * @param uriInfo {@link jakarta.ws.rs.core.UriInfo} + * @return {@link jakarta.ws.rs.core.Response} HTTP response code */ @PUT @Path("/v1/statement/{queryId}") @@ -265,7 +325,8 @@ public Response putStatement( servletRequest, sqlParserOptions, tracerProviderManager.getTracerProvider(), - Optional.of(sessionPropertyManager)); + 
Optional.of(sessionPropertyManager), + statement); Query attemptedQuery = new Query(statement, sessionContext, dispatchManager, executingQueryResponseProvider, 0, queryId, slug); Query query = queries.computeIfAbsent(queryId, unused -> attemptedQuery); @@ -273,15 +334,17 @@ public Response putStatement( throw badRequest(CONFLICT, "Query already exists"); } - return withCompressionConfiguration(Response.ok(query.getInitialQueryResults(uriInfo, xForwardedProto, xPrestoPrefixUrl, binaryResults)), compressionEnabled).build(); + return withCompressionConfiguration(Response.ok(query.getInitialQueryResults(uriInfo, xForwardedProto, xPrestoPrefixUrl, binaryResults)), compressionEnabled) + .cacheControl(query.getDefaultCacheControl()) + .build(); } /** * HTTP endpoint for re-processing a failed query * @param queryId Query Identifier of the query to be retried * @param xForwardedProto Forwarded protocol (http or https) - * @param uriInfo {@link javax.ws.rs.core.UriInfo} - * @return {@link javax.ws.rs.core.Response} HTTP response code + * @param uriInfo {@link jakarta.ws.rs.core.UriInfo} + * @return {@link jakarta.ws.rs.core.Response} HTTP response code */ @GET @Path("/v1/statement/queued/retry/{queryId}") @@ -302,6 +365,12 @@ public Response retryFailedQuery( throw new PrestoException(RETRY_QUERY_NOT_FOUND, "failed to find the query to retry with ID " + queryId); } + if (dispatchManager.isQueryPresent(queryId) && + dispatchManager.getQueryInfo(queryId).getFailureInfo() == null && + !failedQuery.isRetryQuery()) { + throw badRequest(CONFLICT, "Query with ID " + queryId + " has not failed and cannot be retried"); + } + int retryCount = failedQuery.getRetryCount() + 1; Query query = new Query( "-- retry query " + queryId + "; attempt: " + retryCount + "\n" + failedQuery.getQuery(), @@ -312,6 +381,11 @@ public Response retryFailedQuery( retriedQueries.putIfAbsent(queryId, query); synchronized (retriedQueries.get(queryId)) { + // Retry queries should never be processed except 
through this endpoint + if (failedQuery.isRetryQuery() && failedQuery.getLastToken() != 0) { + throw badRequest(CONFLICT, "Query with ID " + queryId + " has already been processed and cannot be retried"); + } + if (retriedQueries.get(queryId).getQueryId().equals(query.getQueryId())) { queries.put(query.getQueryId(), query); } @@ -322,7 +396,9 @@ public Response retryFailedQuery( } } - return withCompressionConfiguration(Response.ok(query.getInitialQueryResults(uriInfo, xForwardedProto, xPrestoPrefixUrl, binaryResults)), compressionEnabled).build(); + return withCompressionConfiguration(Response.ok(query.getInitialQueryResults(uriInfo, xForwardedProto, xPrestoPrefixUrl, binaryResults)), compressionEnabled) + .cacheControl(query.getDefaultCacheControl()) + .build(); } /** @@ -332,7 +408,7 @@ public Response retryFailedQuery( * @param slug Unique security token generated for each query that controls access to that query's results * @param maxWait Time to wait for the query to be dispatched * @param xForwardedProto Forwarded protocol (http or https) - * @param uriInfo {@link javax.ws.rs.core.UriInfo} + * @param uriInfo {@link jakarta.ws.rs.core.UriInfo} * @param asyncResponse */ @GET @@ -352,6 +428,14 @@ public void getStatus( abortIfPrefixUrlInvalid(xPrestoPrefixUrl); Query query = getQuery(queryId, slug); + + if (query.isRetryQuery()) { + throw badRequest( + CONFLICT, + format("Query with ID %s is a retry query and cannot be polled directly. 
Use /v1/statement/queued/retry/%s endpoint to retry it.", + queryId, queryId)); + } + ListenableFuture acquirePermitAsync = queryRateLimiter.acquire(queryId); ListenableFuture waitForDispatchedAsync = transformAsync( acquirePermitAsync, @@ -379,7 +463,7 @@ public void getStatus( * @param queryId Query Identifier of query to be canceled * @param token Monotonically increasing token that identifies the next batch of query results * @param slug Unique security token generated for each query that controls access to that query's results - * @return {@link javax.ws.rs.core.Response} HTTP response code + * @return {@link jakarta.ws.rs.core.Response} HTTP response code */ @DELETE @Path("/v1/statement/queued/{queryId}/{token}") @@ -433,8 +517,14 @@ private static Response.ResponseBuilder withCompressionConfiguration(Response.Re return builder; } + private static String createSlug() + { + return "x" + randomUUID().toString().toLowerCase(ENGLISH).replace("-", ""); + } + private static final class Query { + private static final int CACHE_CONTROL_MAX_AGE_SEC = 60; private final String query; private final SessionContext sessionContext; private final DispatchManager dispatchManager; @@ -443,13 +533,17 @@ private static final class Query private final String slug; private final AtomicLong lastToken = new AtomicLong(); private final int retryCount; + private final long expirationTime; + private final boolean isRetryQuery; + private final Optional retryUrl; + private final OptionalLong retryExpirationEpochTime; @GuardedBy("this") private ListenableFuture querySubmissionFuture; public Query(String query, SessionContext sessionContext, DispatchManager dispatchManager, ExecutingQueryResponseProvider executingQueryResponseProvider, int retryCount) { - this(query, sessionContext, dispatchManager, executingQueryResponseProvider, retryCount, dispatchManager.createQueryId(), createSlug()); + this(query, sessionContext, dispatchManager, executingQueryResponseProvider, retryCount, 
dispatchManager.createQueryId(), createSlug(), false, Optional.empty(), OptionalLong.empty()); } public Query( @@ -460,6 +554,21 @@ public Query( int retryCount, QueryId queryId, String slug) + { + this(query, sessionContext, dispatchManager, executingQueryResponseProvider, retryCount, queryId, slug, false, Optional.empty(), OptionalLong.empty()); + } + + public Query( + String query, + SessionContext sessionContext, + DispatchManager dispatchManager, + ExecutingQueryResponseProvider executingQueryResponseProvider, + int retryCount, + QueryId queryId, + String slug, + boolean isRetryQuery, + Optional retryUrl, + OptionalLong retryExpirationEpochTime) { this.query = requireNonNull(query, "query is null"); this.sessionContext = requireNonNull(sessionContext, "sessionContext is null"); @@ -468,6 +577,10 @@ public Query( this.retryCount = retryCount; this.queryId = requireNonNull(queryId, "queryId is null"); this.slug = requireNonNull(slug, "slug is null"); + this.expirationTime = currentTimeMillis() + SECONDS.toMillis(CACHE_CONTROL_MAX_AGE_SEC); + this.isRetryQuery = isRetryQuery; + this.retryUrl = requireNonNull(retryUrl, "retryUrl is null"); + this.retryExpirationEpochTime = requireNonNull(retryExpirationEpochTime, "retryExpirationEpochTime is null"); } /** @@ -510,6 +623,14 @@ public long getLastToken() return lastToken.get(); } + /** + * Returns whether or not this query was stored as a retry query + */ + public boolean isRetryQuery() + { + return isRetryQuery; + } + /** * Returns the retry attempt of the query */ @@ -518,6 +639,15 @@ public int getRetryCount() return retryCount; } + /** + * Returns a cache control with the default max age value + */ + public CacheControl getDefaultCacheControl() + { + long maxAgeMillis = Math.max(0, expirationTime - currentTimeMillis()); + return CacheControl.valueOf("max-age=" + MILLISECONDS.toSeconds(maxAgeMillis)); + } + /** * Checks whether the query has been processed by the dispatchManager */ @@ -546,7 +676,7 @@ private 
ListenableFuture waitForDispatched() /** * Returns a placeholder for query results for the client to poll - * @param uriInfo {@link javax.ws.rs.core.UriInfo} + * @param uriInfo {@link jakarta.ws.rs.core.UriInfo} * @param xForwardedProto Forwarded protocol (http or https) * @return {@link com.facebook.presto.client.QueryResults} */ @@ -591,7 +721,10 @@ public ListenableFuture toResponse( xPrestoPrefixUrl, DispatchInfo.waitingForPrerequisites(NO_DURATION, NO_DURATION), binaryResults); - return immediateFuture(withCompressionConfiguration(Response.ok(queryResults), compressionEnabled).build()); + + return immediateFuture(withCompressionConfiguration(Response.ok(queryResults), compressionEnabled) + .cacheControl(getDefaultCacheControl()) + .build()); } } @@ -602,6 +735,7 @@ public ListenableFuture toResponse( .status(NOT_FOUND) .build())); } + long durationUntilExpirationMs = dispatchManager.getDurationUntilExpirationInMillis(queryId); if (waitForDispatched().isDone()) { Optional> executingQueryResponse = executingQueryResponseProvider.waitForExecutingResponse( @@ -615,7 +749,11 @@ public ListenableFuture toResponse( TARGET_RESULT_SIZE, compressionEnabled, nestedDataSerializationEnabled, - binaryResults); + binaryResults, + durationUntilExpirationMs, + retryUrl, + retryExpirationEpochTime, + isRetryQuery); if (executingQueryResponse.isPresent()) { return executingQueryResponse.get(); @@ -624,6 +762,7 @@ public ListenableFuture toResponse( return immediateFuture(withCompressionConfiguration(Response.ok( createQueryResults(token + 1, uriInfo, xForwardedProto, xPrestoPrefixUrl, dispatchInfo.get(), binaryResults)), compressionEnabled) + .cacheControl(getCacheControlMaxAge(durationUntilExpirationMs)) .build()); } diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/ContinuousTaskStatusFetcher.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/ContinuousTaskStatusFetcher.java deleted file mode 100644 index 
dea0a2b4839ef..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/ContinuousTaskStatusFetcher.java +++ /dev/null @@ -1,313 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.server.remotetask; - -import com.facebook.airlift.concurrent.SetThreadName; -import com.facebook.airlift.http.client.HttpClient; -import com.facebook.airlift.http.client.Request; -import com.facebook.airlift.http.client.ResponseHandler; -import com.facebook.airlift.http.client.thrift.ThriftRequestUtils; -import com.facebook.airlift.http.client.thrift.ThriftResponseHandler; -import com.facebook.airlift.json.Codec; -import com.facebook.airlift.json.JsonCodec; -import com.facebook.airlift.json.smile.SmileCodec; -import com.facebook.airlift.log.Logger; -import com.facebook.drift.transport.netty.codec.Protocol; -import com.facebook.presto.execution.StateMachine; -import com.facebook.presto.execution.TaskId; -import com.facebook.presto.execution.TaskStatus; -import com.facebook.presto.server.RequestErrorTracker; -import com.facebook.presto.server.SimpleHttpResponseCallback; -import com.facebook.presto.server.SimpleHttpResponseHandler; -import com.facebook.presto.server.smile.BaseResponse; -import com.facebook.presto.server.thrift.ThriftHttpResponseHandler; -import com.facebook.presto.spi.HostAddress; -import com.facebook.presto.spi.PrestoException; -import com.google.common.util.concurrent.FutureCallback; -import 
com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; - -import java.util.concurrent.Executor; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Consumer; - -import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static com.facebook.airlift.http.client.Request.Builder.prepareGet; -import static com.facebook.presto.client.PrestoHeaders.PRESTO_CURRENT_STATE; -import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT; -import static com.facebook.presto.server.RequestErrorTracker.taskRequestErrorTracker; -import static com.facebook.presto.server.RequestHelpers.getBinaryTransportBuilder; -import static com.facebook.presto.server.RequestHelpers.getJsonTransportBuilder; -import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler; -import static com.facebook.presto.server.smile.FullSmileResponseHandler.createFullSmileResponseHandler; -import static com.facebook.presto.server.thrift.ThriftCodecWrapper.unwrapThriftCodec; -import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR; -import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_MISMATCH; -import static com.facebook.presto.util.Failures.REMOTE_TASK_MISMATCH_ERROR; -import static io.airlift.units.Duration.nanosSince; -import static java.lang.String.format; -import static java.util.Objects.requireNonNull; - -class ContinuousTaskStatusFetcher - implements SimpleHttpResponseCallback -{ - private static final Logger log = Logger.get(ContinuousTaskStatusFetcher.class); - - private final TaskId taskId; - private final Consumer onFail; - private final StateMachine taskStatus; - private final Codec taskStatusCodec; - - private final Duration 
refreshMaxWait; - private final Executor executor; - private final HttpClient httpClient; - private final RequestErrorTracker errorTracker; - private final RemoteTaskStats stats; - private final boolean binaryTransportEnabled; - private final boolean thriftTransportEnabled; - private final Protocol thriftProtocol; - - private final AtomicLong currentRequestStartNanos = new AtomicLong(); - - @GuardedBy("this") - private boolean running; - - @GuardedBy("this") - private ListenableFuture> future; - - public ContinuousTaskStatusFetcher( - Consumer onFail, - TaskId taskId, - TaskStatus initialTaskStatus, - Duration refreshMaxWait, - Codec taskStatusCodec, - Executor executor, - HttpClient httpClient, - Duration maxErrorDuration, - ScheduledExecutorService errorScheduledExecutor, - RemoteTaskStats stats, - boolean binaryTransportEnabled, - boolean thriftTransportEnabled, - Protocol thriftProtocol) - { - requireNonNull(initialTaskStatus, "initialTaskStatus is null"); - - this.taskId = requireNonNull(taskId, "taskId is null"); - this.onFail = requireNonNull(onFail, "onFail is null"); - this.taskStatus = new StateMachine<>("task-" + taskId, executor, initialTaskStatus); - - this.refreshMaxWait = requireNonNull(refreshMaxWait, "refreshMaxWait is null"); - this.taskStatusCodec = requireNonNull(taskStatusCodec, "taskStatusCodec is null"); - - this.executor = requireNonNull(executor, "executor is null"); - this.httpClient = requireNonNull(httpClient, "httpClient is null"); - - this.errorTracker = taskRequestErrorTracker(taskId, initialTaskStatus.getSelf(), maxErrorDuration, errorScheduledExecutor, "getting task status"); - this.stats = requireNonNull(stats, "stats is null"); - this.binaryTransportEnabled = binaryTransportEnabled; - this.thriftTransportEnabled = thriftTransportEnabled; - this.thriftProtocol = requireNonNull(thriftProtocol, "thriftProtocol is null"); - } - - public synchronized void start() - { - if (running) { - // already running - return; - } - running = true; 
- scheduleNextRequest(); - } - - public synchronized void stop() - { - running = false; - if (future != null) { - // do not terminate if the request is already running to avoid closing pooled connections - future.cancel(false); - future = null; - } - } - - private synchronized void scheduleNextRequest() - { - // stopped or done? - TaskStatus taskStatus = getTaskStatus(); - if (!running || taskStatus.getState().isDone()) { - return; - } - - // outstanding request? - if (future != null && !future.isDone()) { - // this should never happen - log.error("Can not reschedule update because an update is already running"); - return; - } - - // if throttled due to error, asynchronously wait for timeout and try again - ListenableFuture errorRateLimit = errorTracker.acquireRequestPermit(); - if (!errorRateLimit.isDone()) { - errorRateLimit.addListener(this::scheduleNextRequest, executor); - return; - } - - Request.Builder requestBuilder; - ResponseHandler responseHandler; - if (thriftTransportEnabled) { - requestBuilder = ThriftRequestUtils.prepareThriftGet(thriftProtocol); - responseHandler = new ThriftResponseHandler(unwrapThriftCodec(taskStatusCodec)); - } - else if (binaryTransportEnabled) { - requestBuilder = getBinaryTransportBuilder(prepareGet()); - responseHandler = createFullSmileResponseHandler((SmileCodec) taskStatusCodec); - } - else { - requestBuilder = getJsonTransportBuilder(prepareGet()); - responseHandler = createAdaptingJsonResponseHandler((JsonCodec) taskStatusCodec); - } - - Request request = requestBuilder.setUri(uriBuilderFrom(taskStatus.getSelf()).appendPath("status").build()) - .setHeader(PRESTO_CURRENT_STATE, taskStatus.getState().toString()) - .setHeader(PRESTO_MAX_WAIT, refreshMaxWait.toString()) - .build(); - - errorTracker.startRequest(); - future = httpClient.executeAsync(request, responseHandler); - currentRequestStartNanos.set(System.nanoTime()); - FutureCallback callback; - if (thriftTransportEnabled) { - callback = new 
ThriftHttpResponseHandler(this, request.getUri(), stats.getHttpResponseStats(), REMOTE_TASK_ERROR); - } - else { - callback = new SimpleHttpResponseHandler<>(this, request.getUri(), stats.getHttpResponseStats(), REMOTE_TASK_ERROR); - } - - Futures.addCallback( - future, - callback, - executor); - } - - TaskStatus getTaskStatus() - { - return taskStatus.get(); - } - - @Override - public void success(TaskStatus value) - { - try (SetThreadName ignored = new SetThreadName("ContinuousTaskStatusFetcher-%s", taskId)) { - updateStats(currentRequestStartNanos.get()); - try { - updateTaskStatus(value); - errorTracker.requestSucceeded(); - } - finally { - scheduleNextRequest(); - } - } - } - - @Override - public void failed(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("ContinuousTaskStatusFetcher-%s", taskId)) { - updateStats(currentRequestStartNanos.get()); - try { - // if task not already done, record error - TaskStatus taskStatus = getTaskStatus(); - if (!taskStatus.getState().isDone()) { - errorTracker.requestFailed(cause); - } - } - catch (Error e) { - onFail.accept(e); - throw e; - } - catch (RuntimeException e) { - onFail.accept(e); - } - finally { - scheduleNextRequest(); - } - } - } - - @Override - public void fatal(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("ContinuousTaskStatusFetcher-%s", taskId)) { - updateStats(currentRequestStartNanos.get()); - onFail.accept(cause); - } - } - - void updateTaskStatus(TaskStatus newValue) - { - // change to new value if old value is not changed and new value has a newer version - AtomicBoolean taskMismatch = new AtomicBoolean(); - taskStatus.setIf(newValue, oldValue -> { - // did the task instance id change - boolean isEmpty = oldValue.getTaskInstanceIdLeastSignificantBits() == 0 && oldValue.getTaskInstanceIdMostSignificantBits() == 0; - if (!isEmpty && - !(oldValue.getTaskInstanceIdLeastSignificantBits() == newValue.getTaskInstanceIdLeastSignificantBits() && - 
oldValue.getTaskInstanceIdMostSignificantBits() == newValue.getTaskInstanceIdMostSignificantBits())) { - taskMismatch.set(true); - return false; - } - - if (oldValue.getState().isDone()) { - // never update if the task has reached a terminal state - return false; - } - if (newValue.getVersion() < oldValue.getVersion()) { - // don't update to an older version (same version is ok) - return false; - } - return true; - }); - - if (taskMismatch.get()) { - // This will also set the task status to FAILED state directly. - // Additionally, this will issue a DELETE for the task to the worker. - // While sending the DELETE is not required, it is preferred because a task was created by the previous request. - onFail.accept(new PrestoException(REMOTE_TASK_MISMATCH, format("%s (%s)", REMOTE_TASK_MISMATCH_ERROR, HostAddress.fromUri(getTaskStatus().getSelf())))); - } - } - - public synchronized boolean isRunning() - { - return running; - } - - /** - * Listener is always notified asynchronously using a dedicated notification thread pool so, care should - * be taken to avoid leaking {@code this} when adding a listener in a constructor. Additionally, it is - * possible notifications are observed out of order due to the asynchronous execution. 
- */ - public void addStateChangeListener(StateMachine.StateChangeListener stateChangeListener) - { - taskStatus.addStateChangeListener(stateChangeListener); - } - - private void updateStats(long currentRequestStartNanos) - { - stats.statusRoundTripMillis(nanosSince(currentRequestStartNanos).toMillis()); - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/ContinuousTaskStatusFetcherWithEventLoop.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/ContinuousTaskStatusFetcherWithEventLoop.java index c7d1a3658dc4d..5207ee7a1850a 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/ContinuousTaskStatusFetcherWithEventLoop.java +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/ContinuousTaskStatusFetcherWithEventLoop.java @@ -22,6 +22,7 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.smile.SmileCodec; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; import com.facebook.drift.transport.netty.codec.Protocol; import com.facebook.presto.execution.StateMachine; import com.facebook.presto.execution.TaskId; @@ -36,7 +37,6 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import io.netty.channel.EventLoop; import java.util.concurrent.atomic.AtomicBoolean; @@ -44,6 +44,7 @@ import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static com.facebook.airlift.http.client.Request.Builder.prepareGet; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CURRENT_STATE; import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT; import static com.facebook.presto.server.RequestErrorTracker.taskRequestErrorTracker; @@ -56,7 +57,6 @@ import static 
com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_MISMATCH; import static com.facebook.presto.util.Failures.REMOTE_TASK_MISMATCH_ERROR; import static com.google.common.base.Verify.verify; -import static io.airlift.units.Duration.nanosSince; import static java.lang.String.format; import static java.util.Objects.requireNonNull; diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/DecompressionFilter.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/DecompressionFilter.java new file mode 100644 index 0000000000000..54ee336358776 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/DecompressionFilter.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.remotetask; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.spi.PrestoException; +import com.github.luben.zstd.ZstdInputStream; +import jakarta.annotation.Priority; +import jakarta.ws.rs.Priorities; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerRequestFilter; +import jakarta.ws.rs.ext.Provider; + +import java.io.IOException; +import java.io.InputStream; + +import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static java.lang.String.format; + +@Provider +@Priority(Priorities.ENTITY_CODER) +public class DecompressionFilter + implements ContainerRequestFilter +{ + private static final Logger log = Logger.get(DecompressionFilter.class); + + @Override + public void filter(ContainerRequestContext containerRequestContext) + throws IOException + { + String contentEncoding = containerRequestContext.getHeaderString("Content-Encoding"); + + if (contentEncoding != null && !contentEncoding.equalsIgnoreCase("identity")) { + InputStream originalStream = containerRequestContext.getEntityStream(); + InputStream decompressedStream; + + if (contentEncoding.equalsIgnoreCase("zstd")) { + decompressedStream = new ZstdInputStream(originalStream); + } + else { + throw new PrestoException(NOT_SUPPORTED, format("Unsupported Content-Encoding: '%s'. 
Only zstd compression is supported.", contentEncoding)); + } + + containerRequestContext.setEntityStream(decompressedStream); + containerRequestContext.getHeaders().remove("Content-Encoding"); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpClientConnectionPoolStats.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpClientConnectionPoolStats.java new file mode 100644 index 0000000000000..91513980c3ba4 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpClientConnectionPoolStats.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.remotetask; + +import com.facebook.airlift.stats.DistributionStat; +import com.google.inject.Inject; +import com.google.inject.Singleton; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; +import reactor.netty.resources.ConnectionPoolMetrics; +import reactor.netty.resources.ConnectionProvider; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +@Singleton +public class HttpClientConnectionPoolStats + implements ConnectionProvider.MeterRegistrar +{ + private final ConcurrentHashMap poolMetrics = new ConcurrentHashMap<>(); + + private final DistributionStat activeConnections = new DistributionStat(); + private final DistributionStat totalConnections = new DistributionStat(); + private final DistributionStat idleConnections = new DistributionStat(); + private final DistributionStat pendingAcquires = new DistributionStat(); + private final DistributionStat maxConnections = new DistributionStat(); + private final DistributionStat maxPendingAcquires = new DistributionStat(); + + @Inject + public HttpClientConnectionPoolStats() + { + scheduleStatsExport(); + } + + @Override + public void registerMetrics(String poolName, String id, SocketAddress remoteAddress, ConnectionPoolMetrics metrics) + { + poolMetrics.put(createPoolKey(poolName, remoteAddress), metrics); + } + + private static String createPoolKey(String poolName, SocketAddress remoteAddress) + { + return poolName + ":" + formatSocketAddress(remoteAddress); + } + + private static String formatSocketAddress(SocketAddress socketAddress) + { + if (socketAddress != null) { + if (socketAddress instanceof InetSocketAddress) { + InetSocketAddress address = (InetSocketAddress) socketAddress; + return address.getHostString().replace(".", "_"); + } + else { + return socketAddress.toString().replace(".", "_"); + } + } + return 
"UNKNOWN"; + } + + private void scheduleStatsExport() + { + Executors.newSingleThreadScheduledExecutor() + .scheduleAtFixedRate( + () -> { + for (ConnectionPoolMetrics metrics : poolMetrics.values()) { + activeConnections.add(metrics.acquiredSize()); + totalConnections.add(metrics.allocatedSize()); + idleConnections.add(metrics.idleSize()); + pendingAcquires.add(metrics.pendingAcquireSize()); + maxConnections.add(metrics.maxAllocatedSize()); + maxPendingAcquires.add(metrics.maxPendingAcquireSize()); + } + }, + 0, + 1, + TimeUnit.SECONDS); + } + + @Managed + @Nested + public DistributionStat getActiveConnections() + { + return activeConnections; + } + + @Managed + @Nested + public DistributionStat getTotalConnections() + { + return totalConnections; + } + + @Managed + @Nested + public DistributionStat getIdleConnections() + { + return idleConnections; + } + + @Managed + @Nested + public DistributionStat getPendingAcquires() + { + return pendingAcquires; + } + + @Managed + @Nested + public DistributionStat getMaxConnections() + { + return maxConnections; + } + + @Managed + @Nested + public DistributionStat getMaxPendingAcquires() + { + return maxPendingAcquires; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpClientStats.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpClientStats.java new file mode 100644 index 0000000000000..6d5af54e80636 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpClientStats.java @@ -0,0 +1,382 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.remotetask; + +import com.facebook.airlift.stats.CounterStat; +import com.facebook.airlift.stats.DistributionStat; +import com.facebook.airlift.stats.TimeStat; +import com.google.inject.Singleton; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; +import reactor.netty.http.client.ContextAwareHttpClientMetricsRecorder; +import reactor.util.context.ContextView; + +import java.net.SocketAddress; +import java.time.Duration; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +@Singleton +public class HttpClientStats + extends ContextAwareHttpClientMetricsRecorder +{ + // HTTP level metrics + private final TimeStat responseTime = new TimeStat(); + private final TimeStat dataReceivedTime = new TimeStat(); + private final TimeStat dataSentTime = new TimeStat(); + private final CounterStat errorsCount = new CounterStat(); + private final CounterStat bytesReceived = new CounterStat(); + private final CounterStat bytesSent = new CounterStat(); + private final DistributionStat payloadSize = new DistributionStat(); + + // Channel level metrics + private final TimeStat connectTime = new TimeStat(); + private final TimeStat tlsHandshakeTime = new TimeStat(); + private final TimeStat resolveAddressTime = new TimeStat(); + private final CounterStat channelErrorsCount = new CounterStat(); + private final CounterStat channelBytesReceived = new CounterStat(); + private final CounterStat channelBytesSent = new CounterStat(); + private final CounterStat connectionsOpened = new CounterStat(); + private 
final CounterStat connectionsClosed = new CounterStat(); + + // HTTP level metrics recording + + /** + * Records the time that is spent in consuming incoming data + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux + * @param remoteAddress The remote peer + * @param uri The requested URI + * @param method The HTTP method + * @param status The HTTP status + * @param time The time in nanoseconds that is spent in consuming incoming data + */ + @Override + public void recordDataReceivedTime( + ContextView contextView, + SocketAddress remoteAddress, + String uri, + String method, + String status, + Duration time) + { + dataReceivedTime.add(time.toMillis(), MILLISECONDS); + } + + /** + * Records the time that is spent in sending outgoing data + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux + * @param remoteAddress The remote peer + * @param uri The requested URI + * @param method The HTTP method + * @param time The time in nanoseconds that is spent in sending outgoing data + */ + @Override + public void recordDataSentTime( + ContextView contextView, + SocketAddress remoteAddress, + String uri, + String method, + Duration time) + { + dataSentTime.add(time.toMillis(), MILLISECONDS); + } + + /** + * Records the total time for the request/response + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux + * @param remoteAddress The remote peer + * @param uri The requested URI + * @param method The HTTP method + * @param status The HTTP status + * @param time The total time in nanoseconds for the request/response + */ + @Override + public void recordResponseTime( + ContextView contextView, + SocketAddress remoteAddress, + String uri, + String method, + String status, + Duration time) + { + responseTime.add(time.toMillis(), MILLISECONDS); + } + + /** + * Increments the number of the errors that are occurred + * + * @param contextView The current {@link ContextView} 
associated with the Mono/Flux + * @param remoteAddress The remote peer + * @param uri The requested URI + */ + @Override + public void incrementErrorsCount(ContextView contextView, SocketAddress remoteAddress, String uri) + { + errorsCount.update(1); + } + + /** + * Records the amount of the data that is received, in bytes + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux + * @param remoteAddress The remote peer + * @param uri The requested URI + * @param bytes The amount of the data that is received, in bytes + */ + @Override + public void recordDataReceived(ContextView contextView, SocketAddress remoteAddress, String uri, long bytes) + { + bytesReceived.update(bytes); + } + + /** + * Records the amount of the data that is sent, in bytes + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux + * @param remoteAddress The remote peer + * @param uri The requested URI + * @param bytes The amount of the data that is sent, in bytes + */ + @Override + public void recordDataSent(ContextView contextView, SocketAddress remoteAddress, String uri, long bytes) + { + bytesSent.update(bytes); + payloadSize.add(bytes); + } + + // Channel level metrics recording + + /** + * Increments the number of the errors that are occurred + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux pipeline + * @param remoteAddress The remote peer + */ + @Override + public void incrementErrorsCount(ContextView contextView, SocketAddress remoteAddress) + { + channelErrorsCount.update(1); + } + + /** + * Records the time that is spent for connecting to the remote address Relevant only when on the + * client + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux pipeline + * @param remoteAddress The remote peer + * @param time The time in nanoseconds that is spent for connecting to the remote address + * @param status The status of the operation + */ + @Override 
+ public void recordConnectTime( + ContextView contextView, + SocketAddress remoteAddress, + Duration time, + String status) + { + connectTime.add(time.toMillis(), MILLISECONDS); + } + + /** + * Records the amount of the data that is received, in bytes + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux pipeline + * @param remoteAddress The remote peer + * @param bytes The amount of the data that is received, in bytes + */ + @Override + public void recordDataReceived(ContextView contextView, SocketAddress remoteAddress, long bytes) + { + channelBytesReceived.update(bytes); + } + + /** + * Records the amount of the data that is sent, in bytes + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux pipeline + * @param remoteAddress The remote peer + * @param bytes The amount of the data that is sent, in bytes + */ + @Override + public void recordDataSent(ContextView contextView, SocketAddress remoteAddress, long bytes) + { + channelBytesSent.update(bytes); + } + + /** + * Records the time that is spent for TLS handshake + * + * @param contextView The current {@link ContextView} associated with the Mono/Flux pipeline + * @param remoteAddress The remote peer + * @param time The time in nanoseconds that is spent for TLS handshake + * @param status The status of the operation + */ + @Override + public void recordTlsHandshakeTime( + ContextView contextView, + SocketAddress remoteAddress, + Duration time, + String status) + { + tlsHandshakeTime.add(time.toMillis(), MILLISECONDS); + } + + /** + * Records the time that is spent for resolving the remote address Relevant only when on the + * client + * + * @param remoteAddress The remote peer + * @param time the time in nanoseconds that is spent for resolving to the remote address + * @param status the status of the operation + */ + @Override + public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) + { + 
resolveAddressTime.add(time.toMillis(), MILLISECONDS); + } + + /** + * Records a just accepted server connection + * + * @param localAddress the server local address + * @since 1.0.15 + */ + @Override + public void recordServerConnectionOpened(SocketAddress localAddress) + { + connectionsOpened.update(1); + } + + /** + * Records a just disconnected server connection + * + * @param localAddress the server local address + * @since 1.0.15 + */ + @Override + public void recordServerConnectionClosed(SocketAddress localAddress) + { + connectionsClosed.update(1); + } + + // JMX exposed metrics + + @Managed + @Nested + public TimeStat getResponseTime() + { + return responseTime; + } + + @Managed + @Nested + public TimeStat getDataReceivedTime() + { + return dataReceivedTime; + } + + @Managed + @Nested + public TimeStat getDataSentTime() + { + return dataSentTime; + } + + @Managed + @Nested + public CounterStat getErrorsCount() + { + return errorsCount; + } + + @Managed + @Nested + public CounterStat getBytesReceived() + { + return bytesReceived; + } + + @Managed + @Nested + public CounterStat getBytesSent() + { + return bytesSent; + } + + @Managed + @Nested + public DistributionStat getPayloadSize() + { + return payloadSize; + } + + @Managed + @Nested + public TimeStat getConnectTime() + { + return connectTime; + } + + @Managed + @Nested + public TimeStat getTlsHandshakeTime() + { + return tlsHandshakeTime; + } + + @Managed + @Nested + public TimeStat getResolveAddressTime() + { + return resolveAddressTime; + } + + @Managed + @Nested + public CounterStat getChannelErrorsCount() + { + return channelErrorsCount; + } + + @Managed + @Nested + public CounterStat getChannelBytesReceived() + { + return channelBytesReceived; + } + + @Managed + @Nested + public CounterStat getChannelBytesSent() + { + return channelBytesSent; + } + + @Managed + @Nested + public CounterStat getConnectionsOpened() + { + return connectionsOpened; + } + + @Managed + @Nested + public CounterStat 
getConnectionsClosed() + { + return connectionsClosed; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java index ff80ac7a260b2..3efc20bfd99bb 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpLocationFactory.java @@ -23,8 +23,7 @@ import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.server.InternalCommunicationConfig.CommunicationProtocol; import com.facebook.presto.spi.QueryId; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.net.URI; import java.util.OptionalInt; diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTask.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTask.java deleted file mode 100644 index 7a1518e1f468c..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTask.java +++ /dev/null @@ -1,1315 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.server.remotetask; - -import com.facebook.airlift.concurrent.SetThreadName; -import com.facebook.airlift.http.client.HttpClient; -import com.facebook.airlift.http.client.HttpUriBuilder; -import com.facebook.airlift.http.client.Request; -import com.facebook.airlift.http.client.ResponseHandler; -import com.facebook.airlift.http.client.StatusResponseHandler.StatusResponse; -import com.facebook.airlift.http.client.thrift.ThriftRequestUtils; -import com.facebook.airlift.http.client.thrift.ThriftResponse; -import com.facebook.airlift.http.client.thrift.ThriftResponseHandler; -import com.facebook.airlift.json.Codec; -import com.facebook.airlift.json.JsonCodec; -import com.facebook.airlift.json.smile.SmileCodec; -import com.facebook.airlift.log.Logger; -import com.facebook.airlift.stats.DecayCounter; -import com.facebook.drift.transport.netty.codec.Protocol; -import com.facebook.presto.Session; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; -import com.facebook.presto.execution.FutureStateChange; -import com.facebook.presto.execution.Lifespan; -import com.facebook.presto.execution.NodeTaskMap.NodeStatsTracker; -import com.facebook.presto.execution.PartitionedSplitsInfo; -import com.facebook.presto.execution.QueryManager; -import com.facebook.presto.execution.RemoteTask; -import com.facebook.presto.execution.ScheduledSplit; -import com.facebook.presto.execution.SchedulerStatsTracker; -import com.facebook.presto.execution.StateMachine.StateChangeListener; -import com.facebook.presto.execution.TaskId; -import com.facebook.presto.execution.TaskInfo; -import com.facebook.presto.execution.TaskSource; -import com.facebook.presto.execution.TaskState; -import com.facebook.presto.execution.TaskStatus; -import com.facebook.presto.execution.buffer.BufferInfo; -import com.facebook.presto.execution.buffer.OutputBuffers; -import com.facebook.presto.execution.buffer.PageBufferInfo; -import 
com.facebook.presto.execution.scheduler.TableWriteInfo; -import com.facebook.presto.metadata.HandleResolver; -import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.MetadataUpdates; -import com.facebook.presto.metadata.Split; -import com.facebook.presto.operator.TaskStats; -import com.facebook.presto.server.RequestErrorTracker; -import com.facebook.presto.server.SimpleHttpResponseCallback; -import com.facebook.presto.server.SimpleHttpResponseHandler; -import com.facebook.presto.server.TaskUpdateRequest; -import com.facebook.presto.server.smile.BaseResponse; -import com.facebook.presto.server.thrift.ThriftHttpResponseHandler; -import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.SplitWeight; -import com.facebook.presto.spi.plan.PlanNode; -import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.sql.planner.PlanFragment; -import com.google.common.base.Ticker; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Multimap; -import com.google.common.collect.ObjectArrays; -import com.google.common.collect.SetMultimap; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; -import com.sun.management.ThreadMXBean; -import io.airlift.units.Duration; -import it.unimi.dsi.fastutil.longs.LongArrayList; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; - -import java.lang.management.ManagementFactory; -import java.net.URI; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Objects; -import java.util.Optional; -import java.util.OptionalLong; -import java.util.Set; -import 
java.util.concurrent.Executor; -import java.util.concurrent.Future; -import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadLocalRandom; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Stream; - -import static com.facebook.airlift.http.client.HttpStatus.NO_CONTENT; -import static com.facebook.airlift.http.client.HttpStatus.OK; -import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static com.facebook.airlift.http.client.Request.Builder.prepareDelete; -import static com.facebook.airlift.http.client.Request.Builder.preparePost; -import static com.facebook.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; -import static com.facebook.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; -import static com.facebook.presto.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask; -import static com.facebook.presto.execution.TaskInfo.createInitialTask; -import static com.facebook.presto.execution.TaskState.ABORTED; -import static com.facebook.presto.execution.TaskState.FAILED; -import static com.facebook.presto.execution.TaskStatus.failWith; -import static com.facebook.presto.server.RequestErrorTracker.isExpectedError; -import static com.facebook.presto.server.RequestErrorTracker.taskRequestErrorTracker; -import static com.facebook.presto.server.RequestHelpers.setContentTypeHeaders; -import static com.facebook.presto.server.RequestHelpers.setTaskInfoAcceptTypeHeaders; -import static com.facebook.presto.server.RequestHelpers.setTaskUpdateRequestContentTypeHeaders; -import static com.facebook.presto.server.TaskResourceUtils.convertFromThriftTaskInfo; -import static 
com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler; -import static com.facebook.presto.server.smile.FullSmileResponseHandler.createFullSmileResponseHandler; -import static com.facebook.presto.server.thrift.ThriftCodecWrapper.unwrapThriftCodec; -import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_TASK_UPDATE_SIZE_LIMIT; -import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR; -import static com.facebook.presto.util.Failures.toFailure; -import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.util.concurrent.Futures.addCallback; -import static com.google.common.util.concurrent.Futures.immediateFuture; -import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static java.lang.Math.addExact; -import static java.lang.String.format; -import static java.lang.System.currentTimeMillis; -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.NANOSECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; - -public final class HttpRemoteTask - implements RemoteTask -{ - private static final Logger log = Logger.get(HttpRemoteTask.class); - private static final double UPDATE_WITHOUT_PLAN_STATS_SAMPLE_RATE = 0.01; - private static final ThreadMXBean THREAD_MX_BEAN = (ThreadMXBean) ManagementFactory.getThreadMXBean(); - - private final TaskId taskId; - private final URI taskLocation; - private final URI remoteTaskLocation; - - private final Session session; - private final String nodeId; - 
private final PlanFragment planFragment; - - private final Set tableScanPlanNodeIds; - private final Set remoteSourcePlanNodeIds; - - private final AtomicLong nextSplitId = new AtomicLong(); - - private final Duration maxErrorDuration; - private final RemoteTaskStats stats; - private final TaskInfoFetcher taskInfoFetcher; - private final ContinuousTaskStatusFetcher taskStatusFetcher; - - @GuardedBy("this") - private final LongArrayList taskUpdateTimeline = new LongArrayList(); - @GuardedBy("this") - private Future currentRequest; - @GuardedBy("this") - private long currentRequestStartNanos; - @GuardedBy("this") - private long currentRequestLastTaskUpdate; - - @GuardedBy("this") - private final SetMultimap pendingSplits = HashMultimap.create(); - @GuardedBy("this") - private volatile int pendingSourceSplitCount; - @GuardedBy("this") - private volatile long pendingSourceSplitsWeight; - @GuardedBy("this") - private final SetMultimap pendingNoMoreSplitsForLifespan = HashMultimap.create(); - @GuardedBy("this") - // The keys of this map represent all plan nodes that have "no more splits". - // The boolean value of each entry represents whether the "no more splits" notification is pending delivery to workers. 
- private final Map noMoreSplits = new HashMap<>(); - @GuardedBy("this") - private final AtomicReference outputBuffers = new AtomicReference<>(); - private final FutureStateChange whenSplitQueueHasSpace = new FutureStateChange<>(); - @GuardedBy("this") - private boolean splitQueueHasSpace = true; - @GuardedBy("this") - private OptionalLong whenSplitQueueHasSpaceThreshold = OptionalLong.empty(); - - private final boolean summarizeTaskInfo; - - private final HttpClient httpClient; - private final Executor executor; - private final ScheduledExecutorService errorScheduledExecutor; - - private final Codec taskInfoCodec; - //Json codec required for TaskUpdateRequest endpoint which uses JSON and returns a TaskInfo - private final Codec taskInfoJsonCodec; - private final Codec taskUpdateRequestCodec; - private final Codec taskInfoResponseCodec; - private final Codec planFragmentCodec; - - private final RequestErrorTracker updateErrorTracker; - - private final AtomicBoolean needsUpdate = new AtomicBoolean(true); - private final AtomicBoolean sendPlan = new AtomicBoolean(true); - - private final NodeStatsTracker nodeStatsTracker; - - private final AtomicBoolean started = new AtomicBoolean(false); - private final AtomicBoolean aborting = new AtomicBoolean(false); - - private final boolean binaryTransportEnabled; - private final boolean thriftTransportEnabled; - private final boolean taskInfoThriftTransportEnabled; - private final boolean taskUpdateRequestThriftSerdeEnabled; - private final boolean taskInfoResponseThriftSerdeEnabled; - private final Protocol thriftProtocol; - private final ConnectorTypeSerdeManager connectorTypeSerdeManager; - private final HandleResolver handleResolver; - private final long maxTaskUpdateSizeInBytes; - private final int maxUnacknowledgedSplits; - - private final TableWriteInfo tableWriteInfo; - - private final DecayCounter taskUpdateRequestSize; - private final boolean taskUpdateSizeTrackingEnabled; - private final SchedulerStatsTracker 
schedulerStatsTracker; - - public HttpRemoteTask( - Session session, - TaskId taskId, - String nodeId, - URI location, - URI remoteLocation, - PlanFragment planFragment, - Multimap initialSplits, - OutputBuffers outputBuffers, - HttpClient httpClient, - Executor executor, - ScheduledExecutorService updateScheduledExecutor, - ScheduledExecutorService errorScheduledExecutor, - Duration maxErrorDuration, - Duration taskStatusRefreshMaxWait, - Duration taskInfoRefreshMaxWait, - Duration taskInfoUpdateInterval, - boolean summarizeTaskInfo, - Codec taskStatusCodec, - Codec taskInfoCodec, - Codec taskInfoJsonCodec, - Codec taskUpdateRequestCodec, - Codec taskInfoResponseCodec, - Codec planFragmentCodec, - Codec metadataUpdatesCodec, - NodeStatsTracker nodeStatsTracker, - RemoteTaskStats stats, - boolean binaryTransportEnabled, - boolean thriftTransportEnabled, - boolean taskInfoThriftTransportEnabled, - boolean taskUpdateRequestThriftSerdeEnabled, - boolean taskInfoResponseThriftSerdeEnabled, - Protocol thriftProtocol, - TableWriteInfo tableWriteInfo, - long maxTaskUpdateSizeInBytes, - MetadataManager metadataManager, - QueryManager queryManager, - DecayCounter taskUpdateRequestSize, - boolean taskUpdateSizeTrackingEnabled, - HandleResolver handleResolver, - ConnectorTypeSerdeManager connectorTypeSerdeManager, - SchedulerStatsTracker schedulerStatsTracker) - { - requireNonNull(session, "session is null"); - requireNonNull(taskId, "taskId is null"); - requireNonNull(nodeId, "nodeId is null"); - requireNonNull(location, "location is null"); - requireNonNull(remoteLocation, "remoteLocation is null"); - requireNonNull(planFragment, "planFragment is null"); - requireNonNull(outputBuffers, "outputBuffers is null"); - requireNonNull(httpClient, "httpClient is null"); - requireNonNull(executor, "executor is null"); - requireNonNull(taskStatusCodec, "taskStatusCodec is null"); - requireNonNull(taskInfoCodec, "taskInfoCodec is null"); - requireNonNull(taskUpdateRequestCodec, 
"taskUpdateRequestCodec is null"); - requireNonNull(planFragmentCodec, "planFragmentCodec is null"); - requireNonNull(nodeStatsTracker, "nodeStatsTracker is null"); - requireNonNull(maxErrorDuration, "maxErrorDuration is null"); - requireNonNull(stats, "stats is null"); - requireNonNull(taskInfoRefreshMaxWait, "taskInfoRefreshMaxWait is null"); - requireNonNull(tableWriteInfo, "tableWriteInfo is null"); - requireNonNull(metadataManager, "metadataManager is null"); - requireNonNull(queryManager, "queryManager is null"); - requireNonNull(thriftProtocol, "thriftProtocol is null"); - requireNonNull(handleResolver, "handleResolver is null"); - requireNonNull(connectorTypeSerdeManager, "connectorTypeSerdeManager is null"); - requireNonNull(taskUpdateRequestSize, "taskUpdateRequestSize cannot be null"); - requireNonNull(schedulerStatsTracker, "schedulerStatsTracker is null"); - - try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { - this.taskId = taskId; - this.taskLocation = location; - this.remoteTaskLocation = remoteLocation; - this.session = session; - this.nodeId = nodeId; - this.planFragment = planFragment; - this.outputBuffers.set(outputBuffers); - this.httpClient = httpClient; - this.executor = executor; - this.errorScheduledExecutor = errorScheduledExecutor; - this.summarizeTaskInfo = summarizeTaskInfo; - this.taskInfoCodec = taskInfoCodec; - this.taskInfoJsonCodec = taskInfoJsonCodec; - this.taskUpdateRequestCodec = taskUpdateRequestCodec; - this.taskInfoResponseCodec = taskInfoResponseCodec; - this.planFragmentCodec = planFragmentCodec; - this.updateErrorTracker = taskRequestErrorTracker(taskId, location, maxErrorDuration, errorScheduledExecutor, "updating task"); - this.nodeStatsTracker = requireNonNull(nodeStatsTracker, "nodeStatsTracker is null"); - this.maxErrorDuration = maxErrorDuration; - this.stats = stats; - this.binaryTransportEnabled = binaryTransportEnabled; - this.thriftTransportEnabled = thriftTransportEnabled; - 
this.taskInfoThriftTransportEnabled = taskInfoThriftTransportEnabled; - this.taskUpdateRequestThriftSerdeEnabled = taskUpdateRequestThriftSerdeEnabled; - this.taskInfoResponseThriftSerdeEnabled = taskInfoResponseThriftSerdeEnabled; - this.thriftProtocol = thriftProtocol; - this.connectorTypeSerdeManager = connectorTypeSerdeManager; - this.handleResolver = handleResolver; - this.tableWriteInfo = tableWriteInfo; - this.maxTaskUpdateSizeInBytes = maxTaskUpdateSizeInBytes; - this.maxUnacknowledgedSplits = getMaxUnacknowledgedSplitsPerTask(session); - checkArgument(maxUnacknowledgedSplits > 0, "maxUnacknowledgedSplits must be > 0, found: %s", maxUnacknowledgedSplits); - - this.tableScanPlanNodeIds = ImmutableSet.copyOf(planFragment.getTableScanSchedulingOrder()); - this.remoteSourcePlanNodeIds = planFragment.getRemoteSourceNodes().stream() - .map(PlanNode::getId) - .collect(toImmutableSet()); - this.taskUpdateRequestSize = taskUpdateRequestSize; - this.taskUpdateSizeTrackingEnabled = taskUpdateSizeTrackingEnabled; - this.schedulerStatsTracker = schedulerStatsTracker; - - for (Entry entry : requireNonNull(initialSplits, "initialSplits is null").entries()) { - ScheduledSplit scheduledSplit = new ScheduledSplit(nextSplitId.getAndIncrement(), entry.getKey(), entry.getValue()); - pendingSplits.put(entry.getKey(), scheduledSplit); - } - int pendingSourceSplitCount = 0; - long pendingSourceSplitsWeight = 0; - for (PlanNodeId planNodeId : planFragment.getTableScanSchedulingOrder()) { - Collection tableScanSplits = initialSplits.get(planNodeId); - if (tableScanSplits != null && !tableScanSplits.isEmpty()) { - pendingSourceSplitCount += tableScanSplits.size(); - pendingSourceSplitsWeight = addExact(pendingSourceSplitsWeight, SplitWeight.rawValueSum(tableScanSplits, Split::getSplitWeight)); - } - } - this.pendingSourceSplitCount = pendingSourceSplitCount; - this.pendingSourceSplitsWeight = pendingSourceSplitsWeight; - - List bufferStates = outputBuffers.getBuffers() - 
.keySet().stream() - .map(outputId -> new BufferInfo(outputId, false, 0, 0, PageBufferInfo.empty())) - .collect(toImmutableList()); - - TaskInfo initialTask = createInitialTask(taskId, location, bufferStates, new TaskStats(currentTimeMillis(), 0), nodeId); - - this.taskStatusFetcher = new ContinuousTaskStatusFetcher( - this::failTask, - taskId, - initialTask.getTaskStatus(), - taskStatusRefreshMaxWait, - taskStatusCodec, - executor, - httpClient, - maxErrorDuration, - errorScheduledExecutor, - stats, - binaryTransportEnabled, - thriftTransportEnabled, - thriftProtocol); - - this.taskInfoFetcher = new TaskInfoFetcher( - this::failTask, - initialTask, - httpClient, - taskInfoUpdateInterval, - taskInfoRefreshMaxWait, - taskInfoCodec, - metadataUpdatesCodec, - maxErrorDuration, - summarizeTaskInfo, - executor, - updateScheduledExecutor, - errorScheduledExecutor, - stats, - binaryTransportEnabled, - taskInfoThriftTransportEnabled, - session, - metadataManager, - queryManager, - handleResolver, - connectorTypeSerdeManager, - thriftProtocol); - - taskStatusFetcher.addStateChangeListener(newStatus -> { - TaskState state = newStatus.getState(); - if (state.isDone()) { - cleanUpTask(); - } - else { - updateTaskStats(); - updateSplitQueueSpace(); - } - }); - - updateTaskStats(); - updateSplitQueueSpace(); - } - } - - @Override - public PlanFragment getPlanFragment() - { - return planFragment; - } - - @Override - public TaskId getTaskId() - { - return taskId; - } - - @Override - public String getNodeId() - { - return nodeId; - } - - @Override - public TaskInfo getTaskInfo() - { - return taskInfoFetcher.getTaskInfo(); - } - - @Override - public TaskStatus getTaskStatus() - { - return taskStatusFetcher.getTaskStatus(); - } - - @Override - public URI getRemoteTaskLocation() - { - return remoteTaskLocation; - } - - @Override - public void start() - { - try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { - // to start we just need to trigger an update - 
started.set(true); - scheduleUpdate(); - - taskStatusFetcher.start(); - taskInfoFetcher.start(); - } - } - - @Override - public synchronized void addSplits(Multimap splitsBySource) - { - requireNonNull(splitsBySource, "splitsBySource is null"); - - // only add pending split if not done - if (getTaskStatus().getState().isDone()) { - return; - } - - boolean needsUpdate = false; - for (Entry> entry : splitsBySource.asMap().entrySet()) { - PlanNodeId sourceId = entry.getKey(); - Collection splits = entry.getValue(); - boolean isTableScanSource = tableScanPlanNodeIds.contains(sourceId); - - checkState(!noMoreSplits.containsKey(sourceId), "noMoreSplits has already been set for %s", sourceId); - int added = 0; - long addedWeight = 0; - for (Split split : splits) { - if (pendingSplits.put(sourceId, new ScheduledSplit(nextSplitId.getAndIncrement(), sourceId, split))) { - if (isTableScanSource) { - added++; - addedWeight = addExact(addedWeight, split.getSplitWeight().getRawValue()); - } - } - } - if (isTableScanSource) { - pendingSourceSplitCount += added; - pendingSourceSplitsWeight = addExact(pendingSourceSplitsWeight, addedWeight); - updateTaskStats(); - } - needsUpdate = true; - } - updateSplitQueueSpace(); - - if (needsUpdate) { - this.needsUpdate.set(true); - scheduleUpdate(); - } - } - - @Override - public synchronized void noMoreSplits(PlanNodeId sourceId) - { - if (noMoreSplits.containsKey(sourceId)) { - return; - } - - noMoreSplits.put(sourceId, true); - needsUpdate.set(true); - scheduleUpdate(); - } - - @Override - public synchronized void noMoreSplits(PlanNodeId sourceId, Lifespan lifespan) - { - if (pendingNoMoreSplitsForLifespan.put(sourceId, lifespan)) { - needsUpdate.set(true); - scheduleUpdate(); - } - } - - @Override - public synchronized void setOutputBuffers(OutputBuffers newOutputBuffers) - { - if (getTaskStatus().getState().isDone()) { - return; - } - - if (newOutputBuffers.getVersion() > outputBuffers.get().getVersion()) { - 
outputBuffers.set(newOutputBuffers); - needsUpdate.set(true); - scheduleUpdate(); - } - } - - @Override - public ListenableFuture removeRemoteSource(TaskId remoteSourceTaskId) - { - URI remoteSourceUri = uriBuilderFrom(taskLocation) - .appendPath("remote-source") - .appendPath(remoteSourceTaskId.toString()) - .build(); - - Request request = prepareDelete() - .setUri(remoteSourceUri) - .build(); - RequestErrorTracker errorTracker = taskRequestErrorTracker( - taskId, - remoteSourceUri, - maxErrorDuration, - errorScheduledExecutor, - "Remove exchange remote source"); - - SettableFuture future = SettableFuture.create(); - doRemoveRemoteSource(errorTracker, request, future); - return future; - } - - /// This method may call itself recursively when retrying for failures - private void doRemoveRemoteSource(RequestErrorTracker errorTracker, Request request, SettableFuture future) - { - errorTracker.startRequest(); - - FutureCallback callback = new FutureCallback() - { - @Override - public void onSuccess(@Nullable StatusResponse response) - { - if (response == null) { - throw new PrestoException(GENERIC_INTERNAL_ERROR, "Request failed with null response"); - } - if (response.getStatusCode() != OK.code() && response.getStatusCode() != NO_CONTENT.code()) { - throw new PrestoException(GENERIC_INTERNAL_ERROR, "Request failed with HTTP status " + response.getStatusCode()); - } - future.set(null); - } - - @Override - public void onFailure(Throwable failedReason) - { - if (failedReason instanceof RejectedExecutionException && httpClient.isClosed()) { - log.error("Unable to destroy exchange source at %s. 
HTTP client is closed", request.getUri()); - future.setException(failedReason); - return; - } - // record failure - try { - errorTracker.requestFailed(failedReason); - } - catch (PrestoException e) { - future.setException(e); - return; - } - // if throttled due to error, asynchronously wait for timeout and try again - ListenableFuture errorRateLimit = errorTracker.acquireRequestPermit(); - if (errorRateLimit.isDone()) { - doRemoveRemoteSource(errorTracker, request, future); - } - else { - errorRateLimit.addListener(() -> doRemoveRemoteSource(errorTracker, request, future), errorScheduledExecutor); - } - } - }; - - addCallback(httpClient.executeAsync(request, createStatusResponseHandler()), callback, directExecutor()); - } - - @Override - public PartitionedSplitsInfo getPartitionedSplitsInfo() - { - TaskStatus taskStatus = getTaskStatus(); - if (taskStatus.getState().isDone()) { - return PartitionedSplitsInfo.forZeroSplits(); - } - PartitionedSplitsInfo unacknowledgedSplitsInfo = getUnacknowledgedPartitionedSplitsInfo(); - int count = unacknowledgedSplitsInfo.getCount() + taskStatus.getQueuedPartitionedDrivers() + taskStatus.getRunningPartitionedDrivers(); - long weight = unacknowledgedSplitsInfo.getWeightSum() + taskStatus.getQueuedPartitionedSplitsWeight() + taskStatus.getRunningPartitionedSplitsWeight(); - return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight); - } - - @SuppressWarnings("FieldAccessNotGuarded") - public PartitionedSplitsInfo getUnacknowledgedPartitionedSplitsInfo() - { - int count = pendingSourceSplitCount; - long weight = pendingSourceSplitsWeight; - return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight); - } - - @Override - public PartitionedSplitsInfo getQueuedPartitionedSplitsInfo() - { - TaskStatus taskStatus = getTaskStatus(); - if (taskStatus.getState().isDone()) { - return PartitionedSplitsInfo.forZeroSplits(); - } - PartitionedSplitsInfo unacknowledgedSplitsInfo = getUnacknowledgedPartitionedSplitsInfo(); - 
int count = unacknowledgedSplitsInfo.getCount() + taskStatus.getQueuedPartitionedDrivers(); - long weight = unacknowledgedSplitsInfo.getWeightSum() + taskStatus.getQueuedPartitionedSplitsWeight(); - return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight); - } - - @Override - public int getUnacknowledgedPartitionedSplitCount() - { - return getPendingSourceSplitCount(); - } - - @SuppressWarnings("FieldAccessNotGuarded") - private int getPendingSourceSplitCount() - { - return pendingSourceSplitCount; - } - - private long getQueuedPartitionedSplitsWeight() - { - TaskStatus taskStatus = getTaskStatus(); - if (taskStatus.getState().isDone()) { - return 0; - } - return getPendingSourceSplitsWeight() + taskStatus.getQueuedPartitionedSplitsWeight(); - } - - @SuppressWarnings("FieldAccessNotGuarded") - private long getPendingSourceSplitsWeight() - { - return pendingSourceSplitsWeight; - } - - @Override - public void addStateChangeListener(StateChangeListener stateChangeListener) - { - try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { - taskStatusFetcher.addStateChangeListener(stateChangeListener); - } - } - - @Override - public void addFinalTaskInfoListener(StateChangeListener stateChangeListener) - { - taskInfoFetcher.addFinalTaskInfoListener(stateChangeListener); - } - - @Override - public synchronized ListenableFuture whenSplitQueueHasSpace(long weightThreshold) - { - if (whenSplitQueueHasSpaceThreshold.isPresent()) { - checkArgument(weightThreshold == whenSplitQueueHasSpaceThreshold.getAsLong(), "Multiple split queue space notification thresholds not supported"); - } - else { - whenSplitQueueHasSpaceThreshold = OptionalLong.of(weightThreshold); - updateSplitQueueSpace(); - } - if (splitQueueHasSpace) { - return immediateFuture(null); - } - return whenSplitQueueHasSpace.createNewListener(); - } - - private synchronized void updateSplitQueueSpace() - { - // Must check whether the unacknowledged split count threshold is reached 
even without listeners registered yet - splitQueueHasSpace = getUnacknowledgedPartitionedSplitCount() < maxUnacknowledgedSplits && - (!whenSplitQueueHasSpaceThreshold.isPresent() || getQueuedPartitionedSplitsWeight() < whenSplitQueueHasSpaceThreshold.getAsLong()); - // Only trigger notifications if a listener might be registered - if (splitQueueHasSpace && whenSplitQueueHasSpaceThreshold.isPresent()) { - whenSplitQueueHasSpace.complete(null, executor); - } - } - - private void updateTaskStats() - { - TaskStatus taskStatus = getTaskStatus(); - if (taskStatus.getState().isDone()) { - nodeStatsTracker.setPartitionedSplits(PartitionedSplitsInfo.forZeroSplits()); - nodeStatsTracker.setMemoryUsage(0); - nodeStatsTracker.setCpuUsage(taskStatus.getTaskAgeInMillis(), 0); - } - else { - nodeStatsTracker.setPartitionedSplits(getPartitionedSplitsInfo()); - nodeStatsTracker.setMemoryUsage(taskStatus.getMemoryReservationInBytes() + taskStatus.getSystemMemoryReservationInBytes()); - nodeStatsTracker.setCpuUsage(taskStatus.getTaskAgeInMillis(), taskStatus.getTotalCpuTimeInNanos()); - } - } - - private synchronized void processTaskUpdate(TaskInfo newValue, List sources) - { - //Setting the flag as false since TaskUpdateRequest is not on thrift yet. - //Once it is converted to thrift we can use the isThrift enabled flag here. 
- updateTaskInfo(newValue, false); - - // remove acknowledged splits, which frees memory - for (TaskSource source : sources) { - PlanNodeId planNodeId = source.getPlanNodeId(); - boolean isTableScanSource = tableScanPlanNodeIds.contains(planNodeId); - int removed = 0; - long removedWeight = 0; - for (ScheduledSplit split : source.getSplits()) { - if (pendingSplits.remove(planNodeId, split)) { - if (isTableScanSource) { - removed++; - removedWeight = addExact(removedWeight, split.getSplit().getSplitWeight().getRawValue()); - } - } - } - if (source.isNoMoreSplits()) { - noMoreSplits.put(planNodeId, false); - } - for (Lifespan lifespan : source.getNoMoreSplitsForLifespan()) { - pendingNoMoreSplitsForLifespan.remove(planNodeId, lifespan); - } - if (isTableScanSource) { - pendingSourceSplitCount -= removed; - pendingSourceSplitsWeight -= removedWeight; - } - } - // Update stats before split queue space to ensure node stats are up to date before waking up the scheduler - updateTaskStats(); - updateSplitQueueSpace(); - } - - private void onSuccessTaskInfo(TaskInfo result) - { - try { - updateTaskInfo(result, taskInfoThriftTransportEnabled); - } - finally { - if (!getTaskInfo().getTaskStatus().getState().isDone()) { - cleanUpLocally(); - } - } - } - - private void updateTaskInfo(TaskInfo taskInfo, boolean isTaskInfoThriftTransportEnabled) - { - taskStatusFetcher.updateTaskStatus(taskInfo.getTaskStatus()); - if (isTaskInfoThriftTransportEnabled) { - taskInfo = convertFromThriftTaskInfo(taskInfo, connectorTypeSerdeManager, handleResolver); - } - taskInfoFetcher.updateTaskInfo(taskInfo); - } - - private void cleanUpLocally() - { - // Update the taskInfo with the new taskStatus. - - // Generally, we send a cleanup request to the worker, and update the TaskInfo on - // the coordinator based on what we fetched from the worker. 
If we somehow cannot - // get the cleanup request to the worker, the TaskInfo that we fetch for the worker - // likely will not say the task is done however many times we try. In this case, - // we have to set the local query info directly so that we stop trying to fetch - // updated TaskInfo from the worker. This way, the task on the worker eventually - // expires due to lack of activity. - - // This is required because the query state machine depends on TaskInfo (instead of task status) - // to transition its own state. - // TODO: Update the query state machine and stage state machine to depend on TaskStatus instead - - // Since this TaskInfo is updated in the client the "complete" flag will not be set, - // indicating that the stats may not reflect the final stats on the worker. - updateTaskInfo(getTaskInfo().withTaskStatus(getTaskStatus()), taskInfoThriftTransportEnabled); - } - - private void onFailureTaskInfo( - Throwable t, - String action, - Request request, - Backoff cleanupBackoff) - { - if (t instanceof RejectedExecutionException && httpClient.isClosed()) { - logError(t, "Unable to %s task at %s. HTTP client is closed.", action, request.getUri()); - cleanUpLocally(); - return; - } - - // record failure - if (cleanupBackoff.failure()) { - logError(t, "Unable to %s task at %s. 
Back off depleted.", action, request.getUri()); - cleanUpLocally(); - return; - } - - // reschedule - long delayNanos = cleanupBackoff.getBackoffDelayNanos(); - if (delayNanos == 0) { - doScheduleAsyncCleanupRequest(cleanupBackoff, request, action); - } - else { - errorScheduledExecutor.schedule(() -> doScheduleAsyncCleanupRequest(cleanupBackoff, request, action), delayNanos, NANOSECONDS); - } - } - - private synchronized void scheduleUpdate() - { - taskUpdateTimeline.add(System.nanoTime()); - executor.execute(this::sendUpdate); - } - - private synchronized void sendUpdate() - { - TaskStatus taskStatus = getTaskStatus(); - // don't update if the task hasn't been started yet or if it is already finished - if (!started.get() || !needsUpdate.get() || taskStatus.getState().isDone()) { - return; - } - - // if there is a request already running, wait for it to complete - if (this.currentRequest != null && !this.currentRequest.isDone()) { - return; - } - - // if throttled due to error, asynchronously wait for timeout and try again - ListenableFuture errorRateLimit = updateErrorTracker.acquireRequestPermit(); - if (!errorRateLimit.isDone()) { - errorRateLimit.addListener(this::sendUpdate, executor); - return; - } - - List sources = getSources(); - - Optional fragment = Optional.empty(); - if (sendPlan.get()) { - long start = THREAD_MX_BEAN.getCurrentThreadCpuTime(); - fragment = Optional.of(planFragment.bytesForTaskSerialization(planFragmentCodec)); - schedulerStatsTracker.recordTaskPlanSerializedCpuTime(THREAD_MX_BEAN.getCurrentThreadCpuTime() - start); - } - Optional writeInfo = sendPlan.get() ? 
Optional.of(tableWriteInfo) : Optional.empty(); - TaskUpdateRequest updateRequest = new TaskUpdateRequest( - session.toSessionRepresentation(), - session.getIdentity().getExtraCredentials(), - fragment, - sources, - outputBuffers.get(), - writeInfo); - long serializeStartCpuTimeNanos = THREAD_MX_BEAN.getCurrentThreadCpuTime(); - - Request.Builder requestBuilder; - HttpUriBuilder uriBuilder = getHttpUriBuilder(taskStatus); - - byte[] taskUpdateRequestBytes = taskUpdateRequestCodec.toBytes(updateRequest); - schedulerStatsTracker.recordTaskUpdateSerializedCpuTime(THREAD_MX_BEAN.getCurrentThreadCpuTime() - serializeStartCpuTimeNanos); - - if (taskUpdateRequestBytes.length > maxTaskUpdateSizeInBytes) { - failTask(new PrestoException(EXCEEDED_TASK_UPDATE_SIZE_LIMIT, getExceededTaskUpdateSizeMessage(taskUpdateRequestBytes))); - } - - if (taskUpdateSizeTrackingEnabled) { - taskUpdateRequestSize.add(taskUpdateRequestBytes.length); - - if (fragment.isPresent()) { - stats.updateWithPlanSize(taskUpdateRequestBytes.length); - } - else { - if (ThreadLocalRandom.current().nextDouble() < UPDATE_WITHOUT_PLAN_STATS_SAMPLE_RATE) { - // This is to keep track of the task update size even when the plan fragment is NOT present - stats.updateWithoutPlanSize(taskUpdateRequestBytes.length); - } - } - } - requestBuilder = setTaskUpdateRequestContentTypeHeaders(taskUpdateRequestThriftSerdeEnabled, binaryTransportEnabled, preparePost()); - requestBuilder = setTaskInfoAcceptTypeHeaders(taskInfoResponseThriftSerdeEnabled, binaryTransportEnabled, requestBuilder); - Request request = requestBuilder - .setUri(uriBuilder.build()) - .setBodyGenerator(createStaticBodyGenerator(taskUpdateRequestBytes)) - .build(); - - ResponseHandler responseHandler; - if (taskInfoResponseThriftSerdeEnabled) { - responseHandler = new ThriftResponseHandler(unwrapThriftCodec(taskInfoResponseCodec)); - } - else if (binaryTransportEnabled) { - responseHandler = createFullSmileResponseHandler((SmileCodec) 
taskInfoResponseCodec); - } - else { - responseHandler = createAdaptingJsonResponseHandler((JsonCodec) taskInfoResponseCodec); - } - - updateErrorTracker.startRequest(); - - ListenableFuture future = httpClient.executeAsync(request, responseHandler); - currentRequest = future; - currentRequestStartNanos = System.nanoTime(); - if (!taskUpdateTimeline.isEmpty()) { - currentRequestLastTaskUpdate = taskUpdateTimeline.getLong(taskUpdateTimeline.size() - 1); - } - - // The needsUpdate flag needs to be set to false BEFORE adding the Future callback since callback might change the flag value - // and does so without grabbing the instance lock. - needsUpdate.set(false); - - if (taskInfoResponseThriftSerdeEnabled) { - Futures.addCallback( - (ListenableFuture>) future, - new ThriftHttpResponseHandler<>(new UpdateResponseHandler(sources), request.getUri(), stats.getHttpResponseStats(), REMOTE_TASK_ERROR), - executor); - } - else { - Futures.addCallback( - (ListenableFuture>) future, - new SimpleHttpResponseHandler<>(new UpdateResponseHandler(sources), request.getUri(), stats.getHttpResponseStats(), REMOTE_TASK_ERROR), - executor); - } - } - - private String getExceededTaskUpdateSizeMessage(byte[] taskUpdateRequestJson) - { - return format("TaskUpdate size of %s bytes has exceeded the limit of %s bytes", taskUpdateRequestJson.length, this.maxTaskUpdateSizeInBytes); - } - - private synchronized List getSources() - { - return Stream.concat(tableScanPlanNodeIds.stream(), remoteSourcePlanNodeIds.stream()) - .map(this::getSource) - .filter(Objects::nonNull) - .collect(toImmutableList()); - } - - private synchronized TaskSource getSource(PlanNodeId planNodeId) - { - Set splits = pendingSplits.get(planNodeId); - boolean pendingNoMoreSplits = Boolean.TRUE.equals(this.noMoreSplits.get(planNodeId)); - boolean noMoreSplits = this.noMoreSplits.containsKey(planNodeId); - Set noMoreSplitsForLifespan = pendingNoMoreSplitsForLifespan.get(planNodeId); - - TaskSource element = null; - if 
(!splits.isEmpty() || !noMoreSplitsForLifespan.isEmpty() || pendingNoMoreSplits) { - element = new TaskSource(planNodeId, splits, noMoreSplitsForLifespan, noMoreSplits); - } - return element; - } - - @Override - public synchronized void cancel() - { - try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { - TaskStatus taskStatus = getTaskStatus(); - if (taskStatus.getState().isDone()) { - return; - } - - // send cancel to task and ignore response - HttpUriBuilder uriBuilder = getHttpUriBuilder(taskStatus).addParameter("abort", "false"); - Request.Builder builder = setContentTypeHeaders(binaryTransportEnabled, prepareDelete()); - if (taskInfoThriftTransportEnabled) { - builder = ThriftRequestUtils.prepareThriftDelete(thriftProtocol); - } - Request request = builder.setUri(uriBuilder.build()) - .build(); - scheduleAsyncCleanupRequest(createCleanupBackoff(), request, "cancel"); - } - } - - private synchronized void cleanUpTask() - { - checkState(getTaskStatus().getState().isDone(), "attempt to clean up a task that is not done yet"); - - // clear pending splits to free memory - pendingSplits.clear(); - pendingSourceSplitCount = 0; - pendingSourceSplitsWeight = 0; - updateTaskStats(); - splitQueueHasSpace = true; - whenSplitQueueHasSpace.complete(null, executor); - - // cancel pending request - if (currentRequest != null) { - // do not terminate if the request is already running to avoid closing pooled connections - currentRequest.cancel(false); - currentRequest = null; - currentRequestStartNanos = 0; - } - - taskStatusFetcher.stop(); - - // The remote task is likely to get a delete from the PageBufferClient first. 
- // We send an additional delete anyway to get the final TaskInfo - HttpUriBuilder uriBuilder = getHttpUriBuilder(getTaskStatus()); - Request.Builder requestBuilder = setContentTypeHeaders(binaryTransportEnabled, prepareDelete()); - if (taskInfoThriftTransportEnabled) { - requestBuilder = ThriftRequestUtils.prepareThriftDelete(Protocol.BINARY); - } - Request request = requestBuilder - .setUri(uriBuilder.build()) - .build(); - - scheduleAsyncCleanupRequest(createCleanupBackoff(), request, "cleanup"); - } - - @Override - public synchronized void abort() - { - if (getTaskStatus().getState().isDone()) { - return; - } - - abort(failWith(getTaskStatus(), ABORTED, ImmutableList.of())); - } - - private synchronized void abort(TaskStatus status) - { - checkState(status.getState().isDone(), "cannot abort task with an incomplete status"); - - try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { - taskStatusFetcher.updateTaskStatus(status); - - // send abort to task - HttpUriBuilder uriBuilder = getHttpUriBuilder(getTaskStatus()); - Request.Builder builder = setContentTypeHeaders(binaryTransportEnabled, prepareDelete()); - if (taskInfoThriftTransportEnabled) { - builder = ThriftRequestUtils.prepareThriftDelete(thriftProtocol); - } - - Request request = builder.setUri(uriBuilder.build()) - .build(); - scheduleAsyncCleanupRequest(createCleanupBackoff(), request, "abort"); - } - } - - private void scheduleAsyncCleanupRequest(Backoff cleanupBackoff, Request request, String action) - { - if (!aborting.compareAndSet(false, true)) { - // Do not initiate another round of cleanup requests if one had been initiated. - // Otherwise, we can get into an asynchronous recursion here. For example, when aborting a task after REMOTE_TASK_MISMATCH. 
- return; - } - doScheduleAsyncCleanupRequest(cleanupBackoff, request, action); - } - - private void doScheduleAsyncCleanupRequest(Backoff cleanupBackoff, Request request, String action) - { - ResponseHandler responseHandler; - if (taskInfoThriftTransportEnabled) { - responseHandler = new ThriftResponseHandler(unwrapThriftCodec(taskInfoCodec)); - Futures.addCallback(httpClient.executeAsync(request, responseHandler), - new ThriftResponseFutureCallback(action, request, cleanupBackoff), - executor); - } - else if (binaryTransportEnabled) { - responseHandler = createFullSmileResponseHandler((SmileCodec) taskInfoCodec); - Futures.addCallback(httpClient.executeAsync(request, responseHandler), - new BaseResponseFutureCallback(action, request, cleanupBackoff), - executor); - } - else { - responseHandler = createAdaptingJsonResponseHandler((JsonCodec) taskInfoCodec); - Futures.addCallback(httpClient.executeAsync(request, responseHandler), - new BaseResponseFutureCallback(action, request, cleanupBackoff), - executor); - } - } - - /** - * Move the task directly to the failed state if there was a failure in this task - */ - private void failTask(Throwable cause) - { - TaskStatus taskStatus = getTaskStatus(); - if (!taskStatus.getState().isDone()) { - log.debug(cause, "Remote task %s failed with %s", taskStatus.getSelf(), cause); - } - - TaskStatus failedTaskStatus = failWith(getTaskStatus(), FAILED, ImmutableList.of(toFailure(cause))); - // Transition task to failed state without waiting for the final task info returned by the abort request. - // The abort request is very likely not to succeed, leaving the task and the stage in the limbo state for - // the entire duration of abort retries. If the task is failed, it is not that important to actually - // record the final statistics and the final information about a failed task. 
- taskInfoFetcher.updateTaskInfo(getTaskInfo().withTaskStatus(failedTaskStatus)); - - // Initiate abort request - abort(failedTaskStatus); - } - - private HttpUriBuilder getHttpUriBuilder(TaskStatus taskStatus) - { - HttpUriBuilder uriBuilder = uriBuilderFrom(taskStatus.getSelf()); - if (summarizeTaskInfo) { - uriBuilder.addParameter("summarize"); - } - return uriBuilder; - } - - private static Backoff createCleanupBackoff() - { - return new Backoff(10, new Duration(10, TimeUnit.MINUTES), Ticker.systemTicker(), ImmutableList.builder() - .add(new Duration(0, MILLISECONDS)) - .add(new Duration(100, MILLISECONDS)) - .add(new Duration(500, MILLISECONDS)) - .add(new Duration(1, SECONDS)) - .add(new Duration(10, SECONDS)) - .build()); - } - - @Override - public String toString() - { - return toStringHelper(this) - .addValue(getTaskInfo()) - .toString(); - } - - private class UpdateResponseHandler - implements SimpleHttpResponseCallback - { - private final List sources; - - private UpdateResponseHandler(List sources) - { - this.sources = ImmutableList.copyOf(requireNonNull(sources, "sources is null")); - } - - @Override - public void success(TaskInfo value) - { - try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { - try { - long oldestTaskUpdateTime = 0; - long currentRequestStartNanos; - synchronized (HttpRemoteTask.this) { - currentRequest = null; - sendPlan.set(value.isNeedsPlan()); - currentRequestStartNanos = HttpRemoteTask.this.currentRequestStartNanos; - if (!taskUpdateTimeline.isEmpty()) { - oldestTaskUpdateTime = taskUpdateTimeline.getLong(0); - } - int deliveredUpdates = taskUpdateTimeline.size(); - while (deliveredUpdates > 0 && taskUpdateTimeline.getLong(deliveredUpdates - 1) > currentRequestLastTaskUpdate) { - deliveredUpdates--; - } - taskUpdateTimeline.removeElements(0, deliveredUpdates); - } - updateStats(currentRequestStartNanos); - processTaskUpdate(value, sources); - updateErrorTracker.requestSucceeded(); - if 
(oldestTaskUpdateTime != 0) { - schedulerStatsTracker.recordTaskUpdateDeliveredTime(System.nanoTime() - oldestTaskUpdateTime); - } - } - finally { - sendUpdate(); - } - } - } - - @Override - public void failed(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { - try { - long currentRequestStartNanos; - synchronized (HttpRemoteTask.this) { - currentRequest = null; - currentRequestStartNanos = HttpRemoteTask.this.currentRequestStartNanos; - } - updateStats(currentRequestStartNanos); - - // on failure assume we need to update again - needsUpdate.set(true); - - // if task not already done, record error - TaskStatus taskStatus = getTaskStatus(); - if (!taskStatus.getState().isDone()) { - updateErrorTracker.requestFailed(cause); - } - } - catch (Error e) { - failTask(e); - throw e; - } - catch (RuntimeException e) { - failTask(e); - } - finally { - sendUpdate(); - } - } - } - - @Override - public void fatal(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { - failTask(cause); - } - } - - private void updateStats(long currentRequestStartNanos) - { - Duration requestRoundTrip = Duration.nanosSince(currentRequestStartNanos); - stats.updateRoundTripMillis(requestRoundTrip.toMillis()); - } - } - - private static void logError(Throwable t, String format, Object... 
args) - { - if (isExpectedError(t)) { - log.error(format + ": %s", ObjectArrays.concat(args, t)); - } - else { - log.error(t, format, args); - } - } - - private class ThriftResponseFutureCallback - implements FutureCallback> - { - private final String action; - private final Request request; - private final Backoff cleanupBackoff; - - public ThriftResponseFutureCallback(String action, Request request, Backoff cleanupBackoff) - { - this.action = action; - this.request = request; - this.cleanupBackoff = cleanupBackoff; - } - - @Override - public void onSuccess(ThriftResponse result) - { - onSuccessTaskInfo(result.getValue()); - } - - @Override - public void onFailure(Throwable throwable) - { - onFailureTaskInfo(throwable, this.action, this.request, this.cleanupBackoff); - } - } - - private class BaseResponseFutureCallback - implements FutureCallback> - { - private final String action; - private final Request request; - private final Backoff cleanupBackoff; - - public BaseResponseFutureCallback(String action, Request request, Backoff cleanupBackoff) - { - this.action = action; - this.request = request; - this.cleanupBackoff = cleanupBackoff; - } - - @Override - public void onSuccess(BaseResponse result) - { - onSuccessTaskInfo(result.getValue()); - } - - @Override - public void onFailure(Throwable throwable) - { - onFailureTaskInfo(throwable, this.action, this.request, this.cleanupBackoff); - } - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTaskFactory.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTaskFactory.java index 63b6627062aa2..294f8d27cf068 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTaskFactory.java +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTaskFactory.java @@ -21,10 +21,10 @@ import com.facebook.airlift.json.smile.SmileCodec; import com.facebook.airlift.stats.DecayCounter; import 
com.facebook.airlift.stats.ExponentialDecay; +import com.facebook.airlift.units.Duration; import com.facebook.drift.codec.ThriftCodec; import com.facebook.drift.transport.netty.codec.Protocol; import com.facebook.presto.Session; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; import com.facebook.presto.execution.LocationFactory; import com.facebook.presto.execution.NodeTaskMap; import com.facebook.presto.execution.QueryManager; @@ -42,7 +42,6 @@ import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.metadata.Split; import com.facebook.presto.operator.ForScheduler; import com.facebook.presto.server.InternalCommunicationConfig; @@ -51,15 +50,13 @@ import com.facebook.presto.sql.planner.PlanFragment; import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import io.airlift.units.Duration; import io.netty.channel.EventLoop; import io.netty.util.concurrent.AbstractEventExecutorGroup; +import jakarta.annotation.PreDestroy; +import jakarta.inject.Inject; import org.weakref.jmx.Managed; import org.weakref.jmx.Nested; -import javax.annotation.PreDestroy; -import javax.inject.Inject; - import java.util.Optional; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; @@ -86,12 +83,10 @@ public class HttpRemoteTaskFactory private final Codec taskUpdateRequestCodec; private final Codec taskInfoResponseCodec; private final Codec planFragmentCodec; - private final Codec metadataUpdatesCodec; private final Duration maxErrorDuration; private final Duration taskStatusRefreshMaxWait; private final Duration taskInfoRefreshMaxWait; private final HandleResolver handleResolver; - private final ConnectorTypeSerdeManager connectorTypeSerdeManager; private final Duration taskInfoUpdateInterval; private final 
ExecutorService coreExecutor; @@ -130,14 +125,11 @@ public HttpRemoteTaskFactory( ThriftCodec taskUpdateRequestThriftCodec, JsonCodec planFragmentJsonCodec, SmileCodec planFragmentSmileCodec, - JsonCodec metadataUpdatesJsonCodec, - SmileCodec metadataUpdatesSmileCodec, RemoteTaskStats stats, InternalCommunicationConfig communicationConfig, MetadataManager metadataManager, QueryManager queryManager, - HandleResolver handleResolver, - ConnectorTypeSerdeManager connectorTypeSerdeManager) + HandleResolver handleResolver) { this.httpClient = httpClient; this.locationFactory = locationFactory; @@ -146,7 +138,6 @@ public HttpRemoteTaskFactory( this.taskInfoUpdateInterval = taskConfig.getInfoUpdateInterval(); this.taskInfoRefreshMaxWait = taskConfig.getInfoRefreshMaxWait(); this.handleResolver = handleResolver; - this.connectorTypeSerdeManager = connectorTypeSerdeManager; this.coreExecutor = newCachedThreadPool(daemonThreadsNamed("remote-task-callback-%s")); this.executor = new BoundedExecutor(coreExecutor, config.getRemoteTaskMaxCallbackThreads()); @@ -203,12 +194,6 @@ else if (binaryTransportEnabled) { } this.taskInfoJsonCodec = taskInfoJsonCodec; - if (binaryTransportEnabled) { - this.metadataUpdatesCodec = metadataUpdatesSmileCodec; - } - else { - this.metadataUpdatesCodec = metadataUpdatesJsonCodec; - } this.planFragmentCodec = planFragmentJsonCodec; this.metadataManager = metadataManager; @@ -219,7 +204,7 @@ else if (binaryTransportEnabled) { this.taskUpdateRequestSize = new DecayCounter(ExponentialDecay.oneMinute()); this.taskUpdateSizeTrackingEnabled = taskConfig.isTaskUpdateSizeTrackingEnabled(); - this.eventLoopGroup = taskConfig.isEventLoopEnabled() ? 
Optional.of(new SafeEventLoopGroup(config.getRemoteTaskMaxCallbackThreads(), + this.eventLoopGroup = Optional.of(new SafeEventLoopGroup(config.getRemoteTaskMaxCallbackThreads(), new ThreadFactoryBuilder().setNameFormat("task-event-loop-%s").setDaemon(true).build(), taskConfig.getSlowMethodThresholdOnEventLoop()) { @Override @@ -227,7 +212,7 @@ protected EventLoop newChild(Executor executor, Object... args) { return new SafeEventLoop(this, executor); } - }) : Optional.empty(); + }); } @Managed @@ -266,51 +251,7 @@ public RemoteTask createRemoteTask( TableWriteInfo tableWriteInfo, SchedulerStatsTracker schedulerStatsTracker) { - if (eventLoopGroup.isPresent()) { - // Use event loop based HttpRemoteTask - return createHttpRemoteTaskWithEventLoop( - session, - taskId, - node.getNodeIdentifier(), - locationFactory.createLegacyTaskLocation(node, taskId), - locationFactory.createTaskLocation(node, taskId), - fragment, - initialSplits, - outputBuffers, - httpClient, - maxErrorDuration, - taskStatusRefreshMaxWait, - taskInfoRefreshMaxWait, - taskInfoUpdateInterval, - summarizeTaskInfo, - taskStatusCodec, - taskInfoCodec, - taskInfoJsonCodec, - taskUpdateRequestCodec, - taskInfoResponseCodec, - planFragmentCodec, - metadataUpdatesCodec, - nodeStatsTracker, - stats, - binaryTransportEnabled, - thriftTransportEnabled, - taskInfoThriftTransportEnabled, - taskUpdateRequestThriftSerdeEnabled, - taskInfoResponseThriftSerdeEnabled, - thriftProtocol, - tableWriteInfo, - maxTaskUpdateSizeInBytes, - metadataManager, - queryManager, - taskUpdateRequestSize, - taskUpdateSizeTrackingEnabled, - handleResolver, - connectorTypeSerdeManager, - schedulerStatsTracker, - (SafeEventLoopGroup.SafeEventLoop) eventLoopGroup.get().next()); - } - // Use default executor based HttpRemoteTask - return new HttpRemoteTask( + return createHttpRemoteTaskWithEventLoop( session, taskId, node.getNodeIdentifier(), @@ -320,9 +261,6 @@ public RemoteTask createRemoteTask( initialSplits, outputBuffers, httpClient, 
- executor, - updateScheduledExecutor, - errorScheduledExecutor, maxErrorDuration, taskStatusRefreshMaxWait, taskInfoRefreshMaxWait, @@ -334,7 +272,6 @@ public RemoteTask createRemoteTask( taskUpdateRequestCodec, taskInfoResponseCodec, planFragmentCodec, - metadataUpdatesCodec, nodeStatsTracker, stats, binaryTransportEnabled, @@ -350,7 +287,7 @@ public RemoteTask createRemoteTask( taskUpdateRequestSize, taskUpdateSizeTrackingEnabled, handleResolver, - connectorTypeSerdeManager, - schedulerStatsTracker); + schedulerStatsTracker, + (SafeEventLoopGroup.SafeEventLoop) eventLoopGroup.get().next()); } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTaskWithEventLoop.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTaskWithEventLoop.java index 713ab1302023f..638d6cf3ac60f 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTaskWithEventLoop.java +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/HttpRemoteTaskWithEventLoop.java @@ -26,9 +26,10 @@ import com.facebook.airlift.json.smile.SmileCodec; import com.facebook.airlift.log.Logger; import com.facebook.airlift.stats.DecayCounter; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.drift.transport.netty.codec.Protocol; import com.facebook.presto.Session; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; import com.facebook.presto.execution.FutureStateChange; import com.facebook.presto.execution.Lifespan; import com.facebook.presto.execution.NodeTaskMap.NodeStatsTracker; @@ -50,7 +51,6 @@ import com.facebook.presto.execution.scheduler.TableWriteInfo; import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.metadata.Split; import com.facebook.presto.operator.TaskStats; import 
com.facebook.presto.server.RequestErrorTracker; @@ -76,11 +76,8 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.sun.management.ThreadMXBean; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import it.unimi.dsi.fastutil.longs.LongArrayList; - -import javax.annotation.Nullable; +import jakarta.annotation.Nullable; import java.lang.management.ManagementFactory; import java.net.URI; @@ -91,12 +88,13 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.Optional; -import java.util.OptionalLong; import java.util.Set; import java.util.concurrent.Future; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Stream; import static com.facebook.airlift.http.client.HttpStatus.NO_CONTENT; @@ -116,7 +114,6 @@ import static com.facebook.presto.server.RequestHelpers.setContentTypeHeaders; import static com.facebook.presto.server.RequestHelpers.setTaskInfoAcceptTypeHeaders; import static com.facebook.presto.server.RequestHelpers.setTaskUpdateRequestContentTypeHeaders; -import static com.facebook.presto.server.TaskResourceUtils.convertFromThriftTaskInfo; import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler; import static com.facebook.presto.server.smile.FullSmileResponseHandler.createFullSmileResponseHandler; import static com.facebook.presto.server.thrift.ThriftCodecWrapper.unwrapThriftCodec; @@ -187,16 +184,15 @@ public final class HttpRemoteTaskWithEventLoop private long currentRequestLastTaskUpdate; private final SetMultimap pendingSplits = HashMultimap.create(); - private volatile int pendingSourceSplitCount; - private volatile long pendingSourceSplitsWeight; + private final AtomicInteger 
pendingSourceSplitCount = new AtomicInteger(); + private final AtomicLong pendingSourceSplitsWeight = new AtomicLong(); private final SetMultimap pendingNoMoreSplitsForLifespan = HashMultimap.create(); // The keys of this map represent all plan nodes that have "no more splits". // The boolean value of each entry represents whether the "no more splits" notification is pending delivery to workers. private final Map noMoreSplits = new HashMap<>(); private OutputBuffers outputBuffers; private final FutureStateChange whenSplitQueueHasSpace = new FutureStateChange<>(); - private volatile boolean splitQueueHasSpace; - private OptionalLong whenSplitQueueHasSpaceThreshold = OptionalLong.empty(); + private volatile long whenSplitQueueWeightThreshold = Long.MAX_VALUE; private final boolean summarizeTaskInfo; @@ -225,7 +221,6 @@ public final class HttpRemoteTaskWithEventLoop private final boolean taskUpdateRequestThriftSerdeEnabled; private final boolean taskInfoResponseThriftSerdeEnabled; private final Protocol thriftProtocol; - private final ConnectorTypeSerdeManager connectorTypeSerdeManager; private final HandleResolver handleResolver; private final int maxTaskUpdateSizeInBytes; private final int maxUnacknowledgedSplits; @@ -240,6 +235,9 @@ public final class HttpRemoteTaskWithEventLoop private final SafeEventLoopGroup.SafeEventLoop taskEventLoop; private final String loggingPrefix; + private long startTime; + private long startedTime; + public static HttpRemoteTaskWithEventLoop createHttpRemoteTaskWithEventLoop( Session session, TaskId taskId, @@ -261,7 +259,6 @@ public static HttpRemoteTaskWithEventLoop createHttpRemoteTaskWithEventLoop( Codec taskUpdateRequestCodec, Codec taskInfoResponseCodec, Codec planFragmentCodec, - Codec metadataUpdatesCodec, NodeStatsTracker nodeStatsTracker, RemoteTaskStats stats, boolean binaryTransportEnabled, @@ -277,7 +274,6 @@ public static HttpRemoteTaskWithEventLoop createHttpRemoteTaskWithEventLoop( DecayCounter taskUpdateRequestSize, 
boolean taskUpdateSizeTrackingEnabled, HandleResolver handleResolver, - ConnectorTypeSerdeManager connectorTypeSerdeManager, SchedulerStatsTracker schedulerStatsTracker, SafeEventLoopGroup.SafeEventLoop taskEventLoop) { @@ -301,7 +297,6 @@ public static HttpRemoteTaskWithEventLoop createHttpRemoteTaskWithEventLoop( taskUpdateRequestCodec, taskInfoResponseCodec, planFragmentCodec, - metadataUpdatesCodec, nodeStatsTracker, stats, binaryTransportEnabled, @@ -317,7 +312,6 @@ public static HttpRemoteTaskWithEventLoop createHttpRemoteTaskWithEventLoop( taskUpdateRequestSize, taskUpdateSizeTrackingEnabled, handleResolver, - connectorTypeSerdeManager, schedulerStatsTracker, taskEventLoop); task.initialize(); @@ -344,7 +338,6 @@ private HttpRemoteTaskWithEventLoop(Session session, Codec taskUpdateRequestCodec, Codec taskInfoResponseCodec, Codec planFragmentCodec, - Codec metadataUpdatesCodec, NodeStatsTracker nodeStatsTracker, RemoteTaskStats stats, boolean binaryTransportEnabled, @@ -360,7 +353,6 @@ private HttpRemoteTaskWithEventLoop(Session session, DecayCounter taskUpdateRequestSize, boolean taskUpdateSizeTrackingEnabled, HandleResolver handleResolver, - ConnectorTypeSerdeManager connectorTypeSerdeManager, SchedulerStatsTracker schedulerStatsTracker, SafeEventLoopGroup.SafeEventLoop taskEventLoop) { @@ -385,7 +377,6 @@ private HttpRemoteTaskWithEventLoop(Session session, requireNonNull(queryManager, "queryManager is null"); requireNonNull(thriftProtocol, "thriftProtocol is null"); requireNonNull(handleResolver, "handleResolver is null"); - requireNonNull(connectorTypeSerdeManager, "connectorTypeSerdeManager is null"); requireNonNull(taskUpdateRequestSize, "taskUpdateRequestSize cannot be null"); requireNonNull(schedulerStatsTracker, "schedulerStatsTracker is null"); requireNonNull(taskEventLoop, "taskEventLoop is null"); @@ -415,7 +406,6 @@ private HttpRemoteTaskWithEventLoop(Session session, this.taskUpdateRequestThriftSerdeEnabled = 
taskUpdateRequestThriftSerdeEnabled; this.taskInfoResponseThriftSerdeEnabled = taskInfoResponseThriftSerdeEnabled; this.thriftProtocol = thriftProtocol; - this.connectorTypeSerdeManager = connectorTypeSerdeManager; this.handleResolver = handleResolver; this.tableWriteInfo = tableWriteInfo; this.maxTaskUpdateSizeInBytes = maxTaskUpdateSizeInBytes; @@ -444,8 +434,8 @@ private HttpRemoteTaskWithEventLoop(Session session, pendingSourceSplitsWeight = addExact(pendingSourceSplitsWeight, SplitWeight.rawValueSum(tableScanSplits, Split::getSplitWeight)); } } - this.pendingSourceSplitCount = pendingSourceSplitCount; - this.pendingSourceSplitsWeight = pendingSourceSplitsWeight; + this.pendingSourceSplitCount.set(pendingSourceSplitCount); + this.pendingSourceSplitsWeight.set(pendingSourceSplitsWeight); List bufferStates = outputBuffers.getBuffers() .keySet().stream() @@ -475,7 +465,6 @@ private HttpRemoteTaskWithEventLoop(Session session, taskInfoUpdateInterval, taskInfoRefreshMaxWait, taskInfoCodec, - metadataUpdatesCodec, maxErrorDuration, summarizeTaskInfo, taskEventLoop, @@ -486,7 +475,6 @@ private HttpRemoteTaskWithEventLoop(Session session, metadataManager, queryManager, handleResolver, - connectorTypeSerdeManager, thriftProtocol); this.loggingPrefix = format("Query: %s, Task: %s", session.getQueryId(), taskId); } @@ -550,9 +538,12 @@ public URI getRemoteTaskLocation() @Override public void start() { + startTime = System.nanoTime(); safeExecuteOnEventLoop(() -> { // to start we just need to trigger an update started = true; + startedTime = System.nanoTime(); + schedulerStatsTracker.recordStartWaitForEventLoop(startedTime - startTime); scheduleUpdate(); taskStatusFetcher.start(); @@ -570,32 +561,37 @@ public void addSplits(Multimap splitsBySource) return; } + int count = 0; + long weight = 0; + for (Entry> entry : splitsBySource.asMap().entrySet()) { + PlanNodeId sourceId = entry.getKey(); + Collection splits = entry.getValue(); + + if 
(tableScanPlanNodeIds.contains(sourceId)) { + count += splits.size(); + weight += splits.stream().map(Split::getSplitWeight) + .mapToLong(SplitWeight::getRawValue) + .sum(); + } + } + if (count != 0) { + pendingSourceSplitCount.addAndGet(count); + pendingSourceSplitsWeight.addAndGet(weight); + updateTaskStats(); + } + safeExecuteOnEventLoop(() -> { boolean updateNeeded = false; for (Entry> entry : splitsBySource.asMap().entrySet()) { PlanNodeId sourceId = entry.getKey(); Collection splits = entry.getValue(); - boolean isTableScanSource = tableScanPlanNodeIds.contains(sourceId); checkState(!noMoreSplits.containsKey(sourceId), "noMoreSplits has already been set for %s", sourceId); - int added = 0; - long addedWeight = 0; for (Split split : splits) { - if (pendingSplits.put(sourceId, new ScheduledSplit(nextSplitId++, sourceId, split))) { - if (isTableScanSource) { - added++; - addedWeight = addExact(addedWeight, split.getSplitWeight().getRawValue()); - } - } - } - if (isTableScanSource) { - pendingSourceSplitCount += added; - pendingSourceSplitsWeight = addExact(pendingSourceSplitsWeight, addedWeight); - updateTaskStats(); + pendingSplits.put(sourceId, new ScheduledSplit(nextSplitId++, sourceId, split)); } updateNeeded = true; } - updateSplitQueueSpace(); if (updateNeeded) { needsUpdate = true; @@ -737,9 +733,7 @@ public PartitionedSplitsInfo getPartitionedSplitsInfo() @SuppressWarnings("FieldAccessNotGuarded") public PartitionedSplitsInfo getUnacknowledgedPartitionedSplitsInfo() { - int count = pendingSourceSplitCount; - long weight = pendingSourceSplitsWeight; - return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight); + return PartitionedSplitsInfo.forSplitCountAndWeightSum(pendingSourceSplitCount.get(), pendingSourceSplitsWeight.get()); } @Override @@ -764,7 +758,7 @@ public int getUnacknowledgedPartitionedSplitCount() @SuppressWarnings("FieldAccessNotGuarded") private int getPendingSourceSplitCount() { - return pendingSourceSplitCount; + return 
pendingSourceSplitCount.get(); } private long getQueuedPartitionedSplitsWeight() @@ -779,7 +773,7 @@ private long getQueuedPartitionedSplitsWeight() @SuppressWarnings("FieldAccessNotGuarded") private long getPendingSourceSplitsWeight() { - return pendingSourceSplitsWeight; + return pendingSourceSplitsWeight.get(); } @Override @@ -797,35 +791,45 @@ public void addFinalTaskInfoListener(StateChangeListener stateChangeLi @Override public ListenableFuture whenSplitQueueHasSpace(long weightThreshold) { - if (splitQueueHasSpace) { + setSplitQueueWeightThreshold(weightThreshold); + + if (splitQueueHasSpace()) { return immediateFuture(null); } SettableFuture future = SettableFuture.create(); safeExecuteOnEventLoop(() -> { - if (whenSplitQueueHasSpaceThreshold.isPresent()) { - checkArgument(weightThreshold == whenSplitQueueHasSpaceThreshold.getAsLong(), "Multiple split queue space notification thresholds not supported"); + if (splitQueueHasSpace()) { + future.set(null); } else { - whenSplitQueueHasSpaceThreshold = OptionalLong.of(weightThreshold); - updateSplitQueueSpace(); - } - if (splitQueueHasSpace) { - future.set(null); + whenSplitQueueHasSpace.createNewListener().addListener(() -> future.set(null), taskEventLoop); } - whenSplitQueueHasSpace.createNewListener().addListener(() -> future.set(null), taskEventLoop); }, "whenSplitQueueHasSpace"); return future; } + private void setSplitQueueWeightThreshold(long weightThreshold) + { + long currentValue = whenSplitQueueWeightThreshold; + if (currentValue != Long.MAX_VALUE) { + checkArgument(weightThreshold == currentValue, "Multiple split queue space notification thresholds not supported"); + } + else { + whenSplitQueueWeightThreshold = weightThreshold; + } + } + + private boolean splitQueueHasSpace() + { + return getUnacknowledgedPartitionedSplitCount() < maxUnacknowledgedSplits && + getQueuedPartitionedSplitsWeight() < whenSplitQueueWeightThreshold; + } + private void updateSplitQueueSpace() { 
verify(taskEventLoop.inEventLoop()); - - // Must check whether the unacknowledged split count threshold is reached even without listeners registered yet - splitQueueHasSpace = getUnacknowledgedPartitionedSplitCount() < maxUnacknowledgedSplits && - (!whenSplitQueueHasSpaceThreshold.isPresent() || getQueuedPartitionedSplitsWeight() < whenSplitQueueHasSpaceThreshold.getAsLong()); // Only trigger notifications if a listener might be registered - if (splitQueueHasSpace && whenSplitQueueHasSpaceThreshold.isPresent()) { + if (splitQueueHasSpace()) { whenSplitQueueHasSpace.complete(null, taskEventLoop); } } @@ -851,14 +855,15 @@ private void processTaskUpdate(TaskInfo newValue, List sources) //Setting the flag as false since TaskUpdateRequest is not on thrift yet. //Once it is converted to thrift we can use the isThrift enabled flag here. - updateTaskInfo(newValue, false); + updateTaskInfo(newValue); + + int removed = 0; + long removedWeight = 0; // remove acknowledged splits, which frees memory for (TaskSource source : sources) { PlanNodeId planNodeId = source.getPlanNodeId(); boolean isTableScanSource = tableScanPlanNodeIds.contains(planNodeId); - int removed = 0; - long removedWeight = 0; for (ScheduledSplit split : source.getSplits()) { if (pendingSplits.remove(planNodeId, split)) { if (isTableScanSource) { @@ -873,14 +878,14 @@ private void processTaskUpdate(TaskInfo newValue, List sources) for (Lifespan lifespan : source.getNoMoreSplitsForLifespan()) { pendingNoMoreSplitsForLifespan.remove(planNodeId, lifespan); } - if (isTableScanSource) { - pendingSourceSplitCount -= removed; - pendingSourceSplitsWeight -= removedWeight; - } } // Update stats before split queue space to ensure node stats are up to date before waking up the scheduler - updateTaskStats(); - updateSplitQueueSpace(); + if (removed != 0) { + pendingSourceSplitCount.addAndGet(-removed); + pendingSourceSplitsWeight.addAndGet(-removedWeight); + updateTaskStats(); + updateSplitQueueSpace(); + } } private 
void onSuccessTaskInfo(TaskInfo result) @@ -888,7 +893,7 @@ private void onSuccessTaskInfo(TaskInfo result) verify(taskEventLoop.inEventLoop()); try { - updateTaskInfo(result, taskInfoThriftTransportEnabled); + updateTaskInfo(result); } finally { if (!getTaskInfo().getTaskStatus().getState().isDone()) { @@ -897,14 +902,11 @@ private void onSuccessTaskInfo(TaskInfo result) } } - private void updateTaskInfo(TaskInfo taskInfo, boolean isTaskInfoThriftTransportEnabled) + private void updateTaskInfo(TaskInfo taskInfo) { verify(taskEventLoop.inEventLoop()); taskStatusFetcher.updateTaskStatus(taskInfo.getTaskStatus()); - if (isTaskInfoThriftTransportEnabled) { - taskInfo = convertFromThriftTaskInfo(taskInfo, connectorTypeSerdeManager, handleResolver); - } taskInfoFetcher.updateTaskInfo(taskInfo); } @@ -927,7 +929,7 @@ private void cleanUpLocally() // Since this TaskInfo is updated in the client the "complete" flag will not be set, // indicating that the stats may not reflect the final stats on the worker. 
- updateTaskInfo(getTaskInfo().withTaskStatus(getTaskStatus()), taskInfoThriftTransportEnabled); + updateTaskInfo(getTaskInfo().withTaskStatus(getTaskStatus())); } private void onFailureTaskInfo( @@ -1133,10 +1135,9 @@ private void cleanUpTask() // clear pending splits to free memory pendingSplits.clear(); - pendingSourceSplitCount = 0; - pendingSourceSplitsWeight = 0; + pendingSourceSplitCount.set(0); + pendingSourceSplitsWeight.set(0); updateTaskStats(); - splitQueueHasSpace = true; whenSplitQueueHasSpace.complete(null, taskEventLoop); // cancel pending request @@ -1154,7 +1155,7 @@ private void cleanUpTask() HttpUriBuilder uriBuilder = getHttpUriBuilder(getTaskStatus()); Request.Builder requestBuilder = setContentTypeHeaders(binaryTransportEnabled, prepareDelete()); if (taskInfoThriftTransportEnabled) { - requestBuilder = ThriftRequestUtils.prepareThriftDelete(Protocol.BINARY); + requestBuilder = ThriftRequestUtils.prepareThriftDelete(thriftProtocol); } Request request = requestBuilder .setUri(uriBuilder.build()) @@ -1315,6 +1316,7 @@ public void success(TaskInfo value) processTaskUpdate(value, sources); updateErrorTracker.requestSucceeded(); if (oldestTaskUpdateTime != 0) { + schedulerStatsTracker.recordDeliveredUpdates(deliveredUpdates); schedulerStatsTracker.recordTaskUpdateDeliveredTime(System.nanoTime() - oldestTaskUpdateTime); } } @@ -1368,6 +1370,7 @@ private void updateStats(long currentRequestStartNanos) verify(taskEventLoop.inEventLoop()); Duration requestRoundTrip = Duration.nanosSince(currentRequestStartNanos); stats.updateRoundTripMillis(requestRoundTrip.toMillis()); + schedulerStatsTracker.recordRoundTripTime(requestRoundTrip.toMillis() * 1000000); } } @@ -1399,7 +1402,17 @@ public ThriftResponseFutureCallback(String action, Request request, Backoff clea public void onSuccess(ThriftResponse result) { verify(taskEventLoop.inEventLoop()); - onSuccessTaskInfo(result.getValue()); + if (result.getException() != null) { + 
onFailure(result.getException()); + return; + } + + TaskInfo taskInfo = result.getValue(); + if (taskInfo == null) { + onFailure(new RuntimeException("TaskInfo is null")); + return; + } + onSuccessTaskInfo(taskInfo); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/ReactorNettyHttpClient.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/ReactorNettyHttpClient.java new file mode 100644 index 0000000000000..c9ec4ef2bfa93 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/ReactorNettyHttpClient.java @@ -0,0 +1,489 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.remotetask; + +import com.facebook.airlift.http.client.HeaderName; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.http.client.RequestStats; +import com.facebook.airlift.http.client.Response; +import com.facebook.airlift.http.client.ResponseHandler; +import com.facebook.airlift.http.client.StaticBodyGenerator; +import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; +import com.github.luben.zstd.ZstdInputStream; +import com.github.luben.zstd.ZstdOutputStreamNoFinalizer; +import com.google.common.base.Splitter; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.util.concurrent.SettableFuture; +import com.google.inject.Inject; +import io.netty.channel.ChannelOption; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.epoll.Epoll; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.ssl.ApplicationProtocolConfig; +import io.netty.handler.ssl.OpenSsl; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.netty.ByteBufFlux; +import reactor.netty.http.HttpProtocol; +import reactor.netty.http.client.Http2AllocationStrategy; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.client.HttpClientResponse; +import reactor.netty.resources.ConnectionProvider; +import reactor.netty.resources.LoopResources; + +import java.io.ByteArrayOutputStream; +import java.io.Closeable; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.nio.file.Files; +import java.security.GeneralSecurityException; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Locale; +import 
java.util.Map; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Function; +import java.util.zip.GZIPInputStream; + +import static com.facebook.airlift.security.pem.PemReader.loadPrivateKey; +import static com.facebook.airlift.security.pem.PemReader.readCertificateChain; +import static io.netty.handler.ssl.ApplicationProtocolConfig.Protocol.ALPN; +import static io.netty.handler.ssl.ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT; +import static io.netty.handler.ssl.ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE; +import static io.netty.handler.ssl.ApplicationProtocolNames.HTTP_1_1; +import static io.netty.handler.ssl.ApplicationProtocolNames.HTTP_2; +import static io.netty.handler.ssl.SslProtocols.TLS_v1_2; +import static io.netty.handler.ssl.SslProtocols.TLS_v1_3; +import static io.netty.handler.ssl.SslProvider.JDK; +import static io.netty.handler.ssl.SslProvider.OPENSSL; +import static io.netty.handler.ssl.SslProvider.isAlpnSupported; +import static java.lang.String.format; +import static java.time.temporal.ChronoUnit.MILLIS; + +public class ReactorNettyHttpClient + implements com.facebook.airlift.http.client.HttpClient, Closeable +{ + private static final Logger log = Logger.get(ReactorNettyHttpClient.class); + private static final HeaderName CONTENT_TYPE_HEADER_NAME = HeaderName.of("Content-Type"); + private static final HeaderName CONTENT_LENGTH_HEADER_NAME = HeaderName.of("Content-Length"); + private static final HeaderName CONTENT_ENCODING_HEADER_NAME = HeaderName.of("Content-Encoding"); + private static final HeaderName ACCEPT_ENCODING_HEADER_NAME = HeaderName.of("Accept-Encoding"); + + private final Duration requestTimeout; + private HttpClient httpClient; + private final HttpClientConnectionPoolStats connectionPoolStats; + private final 
HttpClientStats httpClientStats; + private final boolean isHttp2CompressionEnabled; + private final int payloadSizeThreshold; + private final double compressionSavingThreshold; + + @Inject + public ReactorNettyHttpClient(ReactorNettyHttpClientConfig config, HttpClientConnectionPoolStats connectionPoolStats, HttpClientStats httpClientStats) + { + this.connectionPoolStats = connectionPoolStats; + this.httpClientStats = httpClientStats; + this.isHttp2CompressionEnabled = config.isHttp2CompressionEnabled(); + this.payloadSizeThreshold = config.getPayloadSizeThreshold(); + this.compressionSavingThreshold = config.getCompressionSavingThreshold(); + SslContext sslContext = null; + if (config.isHttpsEnabled()) { + try { + File keyFile = new File(config.getKeyStorePath()); + File trustCertificateFile = new File(config.getTrustStorePath()); + if (!Files.exists(keyFile.toPath()) || !Files.isReadable(keyFile.toPath())) { + throw new IllegalArgumentException("KeyStore file path is unreadable or doesn't exist"); + } + if (!Files.exists(trustCertificateFile.toPath()) || !Files.isReadable(trustCertificateFile.toPath())) { + throw new IllegalArgumentException("TrustStore file path is unreadable or doesn't exist"); + } + PrivateKey privateKey = loadPrivateKey(keyFile, Optional.of(config.getKeyStorePassword())); + X509Certificate[] certificateChain = readCertificateChain(keyFile).toArray(new X509Certificate[0]); + X509Certificate[] trustChain = readCertificateChain(trustCertificateFile).toArray(new X509Certificate[0]); + + String os = System.getProperty("os.name"); + if (os.toLowerCase(Locale.ENGLISH).contains("linux")) { + // Make sure Open ssl is available for linux deployments + if (!OpenSsl.isAvailable()) { + throw new UnsupportedOperationException(format("OpenSsl is not available. 
Stacktrace: %s", Arrays.toString(OpenSsl.unavailabilityCause().getStackTrace()).replace(',', '\n'))); + } + // Make sure epoll threads are used for linux deployments + if (!Epoll.isAvailable()) { + throw new UnsupportedOperationException(format("Epoll is not available. Stacktrace: %s", Arrays.toString(Epoll.unavailabilityCause().getStackTrace()).replace(',', '\n'))); + } + } + + SslProvider provider = isAlpnSupported(OPENSSL) ? OPENSSL : JDK; + SslContextBuilder sslContextBuilder = SslContextBuilder.forClient() + .sslProvider(provider) + .protocols(TLS_v1_3, TLS_v1_2) + .keyManager(privateKey, certificateChain) + .trustManager(trustChain) + .applicationProtocolConfig(new ApplicationProtocolConfig(ALPN, NO_ADVERTISE, ACCEPT, HTTP_2, HTTP_1_1)); + if (config.getCipherSuites().isPresent()) { + sslContextBuilder.ciphers(Splitter + .on(',') + .trimResults() + .omitEmptyStrings() + .splitToList(config.getCipherSuites().get())); + } + + sslContext = sslContextBuilder.build(); + } + catch (IOException | GeneralSecurityException e) { + throw new RuntimeException("Failed to configure SSL context", e); + } + } + + /* + * This is like wrapper and underlying there is a separate pool of connections for http1 and http2 protocols. Basically different pools for different protocols. + * Reactor Netty's HttpConnectionProvider will wrap this connection provider and handle protocol routing in the acquire() call. It examines + * the configured protocols and routes requests appropriately. So the http2 allocation strategy defined here will only be used for http2 connections. 
+ */ + ConnectionProvider.Builder poolBuilder = ConnectionProvider.builder("shared-pool") + .maxConnections(config.getMaxConnections()) + .fifo(); + + if (config.getMaxIdleTime().toMillis() > 0) { + poolBuilder.maxIdleTime(java.time.Duration.of(config.getMaxIdleTime().toMillis(), MILLIS)); + } + if (config.getEvictBackgroundTime().toMillis() > 0) { + poolBuilder.evictInBackground(java.time.Duration.of(config.getEvictBackgroundTime().toMillis(), MILLIS)); + } + if (config.getPendingAcquireTimeout().toMillis() > 0) { + poolBuilder.pendingAcquireTimeout(java.time.Duration.of(config.getPendingAcquireTimeout().toMillis(), MILLIS)); + } + + poolBuilder.metrics(config.isHttp2ConnectionPoolStatsTrackingEnabled(), () -> connectionPoolStats) + .allocationStrategy((Http2AllocationStrategy.builder() + .maxConnections(config.getMaxConnections()) + .maxConcurrentStreams(config.getMaxStreamPerChannel()) + .minConnections(config.getMinConnections()).build())) + .build(); + + LoopResources loopResources = LoopResources.create("event-loop", config.getSelectorThreadCount(), config.getEventLoopThreadCount(), true, false); + + // Create HTTP/2 client + SslContext finalSslContext = sslContext; + + this.httpClient = HttpClient + .create(poolBuilder.build()) // The custom pool is wrapped with a HttpConnectionProvider over here + .compress(false) // we will enable response compression manually + .protocol(HttpProtocol.H2, HttpProtocol.HTTP11) + .runOn(loopResources, true) + .http2Settings(settings -> { + settings.maxConcurrentStreams(config.getMaxStreamPerChannel()); + if (config.getMaxInitialWindowSize().toBytes() > 0) { + settings.initialWindowSize((int) (config.getMaxInitialWindowSize().toBytes())); + } + if (config.getMaxFrameSize().toBytes() > 0) { + settings.maxFrameSize((int) (config.getMaxFrameSize().toBytes())); + } + }) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) config.getConnectTimeout().toMillis()) + .option(ChannelOption.SO_SNDBUF, config.getTcpBufferSize()) + 
.option(ChannelOption.SO_RCVBUF, config.getTcpBufferSize()) + .option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(config.getWriteBufferWaterMarkLow(), config.getWriteBufferWaterMarkHigh())); + if (config.isChannelOptionSoKeepAliveEnabled()) { + httpClient = httpClient.option(ChannelOption.SO_KEEPALIVE, config.isChannelOptionSoKeepAliveEnabled()); + } + if (config.isChannelOptionTcpNoDelayEnabled()) { + httpClient = httpClient.option(ChannelOption.TCP_NODELAY, config.isChannelOptionTcpNoDelayEnabled()); + } + if (config.isHttp2ClientStatsTrackingEnabled()) { + httpClient = httpClient.metrics(config.isHttp2ClientStatsTrackingEnabled(), () -> httpClientStats, Function.identity()); + } + + if (config.isHttpsEnabled()) { + if (finalSslContext == null) { + throw new IllegalStateException("SSL context must be configured for HTTPS"); + } + httpClient = httpClient.secure(spec -> spec.sslContext(finalSslContext)); + } + + this.requestTimeout = config.getRequestTimeout(); + } + + @Override + public T execute(Request request, ResponseHandler responseHandler) + throws E + { + throw new UnsupportedOperationException(); + } + + public HttpResponseFuture executeAsync(Request airliftRequest, ResponseHandler responseHandler) + { + SettableFuture listenableFuture = SettableFuture.create(); + + // Set the request headers + HttpClient client = this.httpClient.headers(hdr -> { + for (Map.Entry entry : airliftRequest.getHeaders().entries()) { + hdr.set(entry.getKey(), entry.getValue()); + } + + if (isHttp2CompressionEnabled) { + hdr.set(ACCEPT_ENCODING_HEADER_NAME.toString(), "zstd, gzip"); + } + }); + + URI uri = airliftRequest.getUri(); + Disposable disposable; + switch (airliftRequest.getMethod()) { + case "GET": + disposable = client.get() + .uri(uri) + .responseSingle((response, bytes) -> bytes.asInputStream().zipWith(Mono.just(response))) + // Request timeout + .timeout(java.time.Duration.of(requestTimeout.toMillis(), MILLIS)) + .subscribe(t -> 
onSuccess(responseHandler, t.getT1(), t.getT2(), listenableFuture), e -> onError(listenableFuture, e), () -> onComplete(listenableFuture)); + break; + case "POST": + byte[] postBytes = ((StaticBodyGenerator) airliftRequest.getBodyGenerator()).getBody(); + byte[] bodyToSend = postBytes; + HttpClient postClient = client; + // We manually do compression for request, use zstd + if (isHttp2CompressionEnabled && postBytes.length >= payloadSizeThreshold) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(postBytes.length / 2); + try (ZstdOutputStreamNoFinalizer zstdOutput = new ZstdOutputStreamNoFinalizer(baos)) { + zstdOutput.write(postBytes); + } + + byte[] compressedBytes = baos.toByteArray(); + double compressionRatio = (double) (postBytes.length - compressedBytes.length) / postBytes.length; + if (compressionRatio >= compressionSavingThreshold) { + bodyToSend = compressedBytes; + postClient = client.headers(h -> h.set(CONTENT_ENCODING_HEADER_NAME.toString(), "zstd")); + } + } + catch (IOException e) { + onError(listenableFuture, e); + disposable = () -> {}; + break; + } + } + + disposable = postClient.post() + .uri(uri) + .send(ByteBufFlux.fromInbound(Mono.just(bodyToSend))) + .responseSingle((response, bytes) -> bytes.asInputStream().zipWith(Mono.just(response))) + // Request timeout + .timeout(java.time.Duration.of(requestTimeout.toMillis(), MILLIS)) + .subscribe(t -> onSuccess(responseHandler, t.getT1(), t.getT2(), listenableFuture), e -> onError(listenableFuture, e), () -> onComplete(listenableFuture)); + break; + case "DELETE": + disposable = client.delete() + .uri(uri) + .responseSingle((response, bytes) -> bytes.asInputStream().zipWith(Mono.just(response))) + // Request timeout + .timeout(java.time.Duration.of(requestTimeout.toMillis(), MILLIS)) + .subscribe(t -> onSuccess(responseHandler, t.getT1(), t.getT2(), listenableFuture), e -> onError(listenableFuture, e), () -> onComplete(listenableFuture)); + break; + default: + throw new 
UnsupportedOperationException("Unexpected request: " + airliftRequest); + } + + return new HttpResponseFuture() + { + @Override + public boolean cancel(boolean mayInterruptIfRunning) + { + disposable.dispose(); + return listenableFuture.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() + { + return listenableFuture.isCancelled(); + } + + @Override + public boolean isDone() + { + return listenableFuture.isDone(); + } + + @Override + public Object get() + throws InterruptedException, ExecutionException + { + return listenableFuture.get(); + } + + @Override + public Object get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException + { + return listenableFuture.get(timeout, unit); + } + + @Override + public void addListener(Runnable listener, Executor executor) + { + listenableFuture.addListener(listener, executor); + } + + @Override + public String getState() + { + return ""; + } + }; + } + + public void onSuccess(ResponseHandler responseHandler, InputStream inputStream, HttpClientResponse response, SettableFuture listenableFuture) + { + ListMultimap responseHeaders = ArrayListMultimap.create(); + HttpHeaders headers = response.responseHeaders(); + int status = response.status().code(); + if (status != 200 && status != 204) { + listenableFuture.setException(new RuntimeException("Invalid response status: " + status)); + return; + } + + long contentLength = 0; + String contentEncoding = null; + // Iterate over the headers + for (String name : headers.names()) { + if (name.equalsIgnoreCase(CONTENT_LENGTH_HEADER_NAME.toString())) { + String val = headers.get(name); + contentLength = Integer.parseInt(val); + responseHeaders.put(CONTENT_LENGTH_HEADER_NAME, val); + } + else if (name.equalsIgnoreCase(CONTENT_TYPE_HEADER_NAME.toString())) { + responseHeaders.put(CONTENT_TYPE_HEADER_NAME, headers.get(name)); + } + else if (name.equalsIgnoreCase(CONTENT_ENCODING_HEADER_NAME.toString())) { + contentEncoding = 
headers.get(name); + } + else { + responseHeaders.put(HeaderName.of(name), headers.get(name)); + } + } + + if (!responseHeaders.containsKey(CONTENT_TYPE_HEADER_NAME) || responseHeaders.get(CONTENT_TYPE_HEADER_NAME).size() != 1) { + listenableFuture.setException(new RuntimeException("Expected ContentType header: " + responseHeaders)); + return; + } + + final InputStream[] streamHolder = new InputStream[1]; + streamHolder[0] = inputStream; + try { + if (contentEncoding != null && !contentEncoding.equalsIgnoreCase("identity")) { + if (contentEncoding.equalsIgnoreCase("zstd")) { + streamHolder[0] = new ZstdInputStream(inputStream); + } + else if (contentEncoding.equalsIgnoreCase("gzip")) { + streamHolder[0] = new GZIPInputStream(inputStream); + } + else { + throw new RuntimeException(format("Unsupported Content-Encoding: %s. Supported: zstd, gzip.", contentEncoding)); + } + } + + long finalContentLength = contentLength; + Object a = responseHandler.handle(null, new Response() + { + @Override + public int getStatusCode() + { + return status; + } + + @Override + public ListMultimap getHeaders() + { + return responseHeaders; + } + + @Override + public long getBytesRead() + { + return finalContentLength; + } + + @Override + public InputStream getInputStream() + throws IOException + { + return streamHolder[0]; + } + }); + // closing it here to prevent memory leak of bytebuf + if (streamHolder[0] != null) { + streamHolder[0].close(); + } + listenableFuture.set(a); + } + catch (Exception e) { + listenableFuture.setException(e); + } + finally { + try { + streamHolder[0].close(); + } + catch (IOException e) { + log.warn(e, "Failed to close input stream"); + } + } + } + + public void onError(SettableFuture listenableFuture, Throwable t) + { + listenableFuture.setException(t); + } + + public void onComplete(SettableFuture listenableFuture) + { + if (!listenableFuture.isDone()) { + listenableFuture.setException(new RuntimeException("completed without success or failure")); + } + } 
+ + @Override + public RequestStats getStats() + { + return null; + } + + @Override + public long getMaxContentLength() + { + return 0; + } + + @Override + public void close() + { + // void + } + + @Override + public boolean isClosed() + { + return false; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/ReactorNettyHttpClientConfig.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/ReactorNettyHttpClientConfig.java new file mode 100644 index 0000000000000..bbea1c537414a --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/ReactorNettyHttpClientConfig.java @@ -0,0 +1,414 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.remotetask; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; + +import java.util.Optional; + +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class ReactorNettyHttpClientConfig +{ + private boolean reactorNettyHttpClientEnabled; + private boolean httpsEnabled; + private int minConnections = 50; + private int maxConnections = 100; + private int maxStreamPerChannel = 100; + private int selectorThreadCount = Runtime.getRuntime().availableProcessors(); + private int eventLoopThreadCount = Runtime.getRuntime().availableProcessors(); + private Duration connectTimeout = new Duration(10, SECONDS); + private Duration requestTimeout = new Duration(10, SECONDS); + private Duration maxIdleTime = new Duration(0, SECONDS); + private Duration evictBackgroundTime = new Duration(0, SECONDS); + private Duration pendingAcquireTimeout = new Duration(0, SECONDS); + private DataSize maxInitialWindowSize = new DataSize(0, MEGABYTE); + private DataSize maxFrameSize = new DataSize(0, MEGABYTE); + private String keyStorePath; + private String keyStorePassword; + private String trustStorePath; + private Optional cipherSuites = Optional.empty(); + + private boolean http2CompressionEnabled; + private DataSize payloadSizeThreshold = new DataSize(50, KILOBYTE); + private double compressionSavingThreshold = 0.1; + private DataSize tcpBufferSize = new DataSize(512, KILOBYTE); + private DataSize writeBufferWaterMarkLow = new DataSize(256, KILOBYTE); + private DataSize writeBufferWaterMarkHigh = new DataSize(512, KILOBYTE); + + private boolean 
isHttp2ConnectionPoolStatsTrackingEnabled; + private boolean isHttp2ClientStatsTrackingEnabled; + private boolean isChannelOptionSoKeepAliveEnabled = true; + private boolean isChannelOptionTcpNoDelayEnabled = true; + + public boolean isHttp2ClientStatsTrackingEnabled() + { + return isHttp2ClientStatsTrackingEnabled; + } + + @Config("reactor.enable-http2-client-stats-tracking") + public ReactorNettyHttpClientConfig setHttp2ClientStatsTrackingEnabled(boolean isHttp2ClientStatsTrackingEnabled) + { + this.isHttp2ClientStatsTrackingEnabled = isHttp2ClientStatsTrackingEnabled; + return this; + } + + public boolean isHttp2ConnectionPoolStatsTrackingEnabled() + { + return isHttp2ConnectionPoolStatsTrackingEnabled; + } + + @Config("reactor.enable-http2-connection-pool-stats-tracking") + public ReactorNettyHttpClientConfig setHttp2ConnectionPoolStatsTrackingEnabled(boolean isHttp2ConnectionPoolStatsTrackingEnabled) + { + this.isHttp2ConnectionPoolStatsTrackingEnabled = isHttp2ConnectionPoolStatsTrackingEnabled; + return this; + } + + public boolean isChannelOptionSoKeepAliveEnabled() + { + return isChannelOptionSoKeepAliveEnabled; + } + + @Config("reactor.channel-option-so-keep-alive") + public ReactorNettyHttpClientConfig setChannelOptionSoKeepAliveEnabled(boolean isChannelOptionSoKeepAliveEnabled) + { + this.isChannelOptionSoKeepAliveEnabled = isChannelOptionSoKeepAliveEnabled; + return this; + } + + public boolean isChannelOptionTcpNoDelayEnabled() + { + return isChannelOptionTcpNoDelayEnabled; + } + + @Config("reactor.channel-option-tcp-no-delay") + public ReactorNettyHttpClientConfig setChannelOptionTcpNoDelayEnabled(boolean isChannelOptionTcpNoDelayEnabled) + { + this.isChannelOptionTcpNoDelayEnabled = isChannelOptionTcpNoDelayEnabled; + return this; + } + + @Config("reactor.enable-http2-compression") + public ReactorNettyHttpClientConfig setHttp2CompressionEnabled(boolean http2CompressionEnabled) + { + this.http2CompressionEnabled = http2CompressionEnabled; + return 
this; + } + + public boolean isHttp2CompressionEnabled() + { + return http2CompressionEnabled; + } + + public double getCompressionSavingThreshold() + { + return compressionSavingThreshold; + } + + @Config("reactor.compression-ratio-threshold") + @ConfigDescription("Use compressed data if the compression ratio is above the threshold") + public ReactorNettyHttpClientConfig setCompressionSavingThreshold(double compressionSavingThreshold) + { + this.compressionSavingThreshold = compressionSavingThreshold; + return this; + } + + @Min(1024) + @Max(1024 * 1024) + public int getTcpBufferSize() + { + return (int) tcpBufferSize.toBytes(); + } + + @Config("reactor.tcp-buffer-size") + public ReactorNettyHttpClientConfig setTcpBufferSize(DataSize tcpBufferSize) + { + this.tcpBufferSize = tcpBufferSize; + return this; + } + + @Min(1024) + @Max(1024 * 1024) + public int getWriteBufferWaterMarkLow() + { + return (int) writeBufferWaterMarkLow.toBytes(); + } + + @Config("reactor.tcp-write-buffer-water-mark-low") + public ReactorNettyHttpClientConfig setWriteBufferWaterMarkLow(DataSize writeBufferWaterMarkLow) + { + this.writeBufferWaterMarkLow = writeBufferWaterMarkLow; + return this; + } + + @Min(1024) + @Max(1024 * 1024) + public int getWriteBufferWaterMarkHigh() + { + return (int) writeBufferWaterMarkHigh.toBytes(); + } + + @Config("reactor.tcp-write-buffer-water-mark-high") + public ReactorNettyHttpClientConfig setWriteBufferWaterMarkHigh(DataSize writeBufferWaterMarkHigh) + { + this.writeBufferWaterMarkHigh = writeBufferWaterMarkHigh; + return this; + } + + @Min(1024) + @Max(512 * 1024) + public int getPayloadSizeThreshold() + { + return (int) payloadSizeThreshold.toBytes(); + } + + @Config("reactor.payload-compression-threshold") + public ReactorNettyHttpClientConfig setPayloadSizeThreshold(DataSize payloadSizeThreshold) + { + this.payloadSizeThreshold = payloadSizeThreshold; + return this; + } + + public boolean isReactorNettyHttpClientEnabled() + { + return 
reactorNettyHttpClientEnabled; + } + + @Config("reactor.netty-http-client-enabled") + @ConfigDescription("Enable reactor netty client for http communication between coordinator and worker") + public ReactorNettyHttpClientConfig setReactorNettyHttpClientEnabled(boolean reactorNettyHttpClientEnabled) + { + this.reactorNettyHttpClientEnabled = reactorNettyHttpClientEnabled; + return this; + } + + public boolean isHttpsEnabled() + { + return httpsEnabled; + } + + @Config("reactor.https-enabled") + public ReactorNettyHttpClientConfig setHttpsEnabled(boolean httpsEnabled) + { + this.httpsEnabled = httpsEnabled; + return this; + } + + public int getMinConnections() + { + return minConnections; + } + + @Min(10) + @Config("reactor.min-connections") + @ConfigDescription("Min number of connections in the pool used by the netty http2 client to talk to the workers") + public ReactorNettyHttpClientConfig setMinConnections(int minConnections) + { + this.minConnections = minConnections; + return this; + } + + public int getMaxConnections() + { + return maxConnections; + } + + @Min(10) + @Config("reactor.max-connections") + @ConfigDescription("Max total number of connections in the pool used by the netty client to talk to the workers") + public ReactorNettyHttpClientConfig setMaxConnections(int maxConnections) + { + this.maxConnections = maxConnections; + return this; + } + + public int getMaxStreamPerChannel() + { + return maxStreamPerChannel; + } + + @Config("reactor.max-stream-per-channel") + @ConfigDescription("Max number of streams per single tcp connection between coordinator and worker") + public ReactorNettyHttpClientConfig setMaxStreamPerChannel(int maxStreamPerChannel) + { + this.maxStreamPerChannel = maxStreamPerChannel; + return this; + } + + public int getSelectorThreadCount() + { + return selectorThreadCount; + } + + @Config("reactor.selector-thread-count") + @ConfigDescription("Number of select threads used by netty to handle the http messages") + public 
ReactorNettyHttpClientConfig setSelectorThreadCount(int selectorThreadCount) + { + this.selectorThreadCount = selectorThreadCount; + return this; + } + + public int getEventLoopThreadCount() + { + return eventLoopThreadCount; + } + + @Config("reactor.event-loop-thread-count") + @ConfigDescription("Number of event loop threads used by netty to handle the http messages") + public ReactorNettyHttpClientConfig setEventLoopThreadCount(int eventLoopThreadCount) + { + this.eventLoopThreadCount = eventLoopThreadCount; + return this; + } + + public Duration getConnectTimeout() + { + return connectTimeout; + } + + @Config("reactor.connect-timeout") + public ReactorNettyHttpClientConfig setConnectTimeout(Duration connectTimeout) + { + this.connectTimeout = connectTimeout; + return this; + } + + public Duration getRequestTimeout() + { + return requestTimeout; + } + + @Config("reactor.request-timeout") + public ReactorNettyHttpClientConfig setRequestTimeout(Duration requestTimeout) + { + this.requestTimeout = requestTimeout; + return this; + } + + public Duration getMaxIdleTime() + { + return maxIdleTime; + } + + @Config("reactor.max-idle-time") + public ReactorNettyHttpClientConfig setMaxIdleTime(Duration maxIdleTime) + { + this.maxIdleTime = maxIdleTime; + return this; + } + + public Duration getEvictBackgroundTime() + { + return evictBackgroundTime; + } + + @Config("reactor.evict-background-time") + public ReactorNettyHttpClientConfig setEvictBackgroundTime(Duration evictBackgroundTime) + { + this.evictBackgroundTime = evictBackgroundTime; + return this; + } + + public Duration getPendingAcquireTimeout() + { + return pendingAcquireTimeout; + } + + @Config("reactor.pending-acquire-timeout") + public ReactorNettyHttpClientConfig setPendingAcquireTimeout(Duration pendingAcquireTimeout) + { + this.pendingAcquireTimeout = pendingAcquireTimeout; + return this; + } + + public DataSize getMaxInitialWindowSize() + { + return maxInitialWindowSize; + } + + 
@Config("reactor.max-initial-window-size") + public ReactorNettyHttpClientConfig setMaxInitialWindowSize(DataSize maxInitialWindowSize) + { + this.maxInitialWindowSize = maxInitialWindowSize; + return this; + } + + public DataSize getMaxFrameSize() + { + return maxFrameSize; + } + + @Config("reactor.max-frame-size") + public ReactorNettyHttpClientConfig setMaxFrameSize(DataSize maxFrameSize) + { + this.maxFrameSize = maxFrameSize; + return this; + } + + public String getKeyStorePath() + { + return keyStorePath; + } + + @Config("reactor.keystore-path") + public ReactorNettyHttpClientConfig setKeyStorePath(String keyStorePath) + { + this.keyStorePath = keyStorePath; + return this; + } + + public String getKeyStorePassword() + { + return keyStorePassword; + } + + @Config("reactor.keystore-password") + public ReactorNettyHttpClientConfig setKeyStorePassword(String keyStorePassword) + { + this.keyStorePassword = keyStorePassword; + return this; + } + + public String getTrustStorePath() + { + return trustStorePath; + } + + @Config("reactor.truststore-path") + public ReactorNettyHttpClientConfig setTrustStorePath(String trustStorePath) + { + this.trustStorePath = trustStorePath; + return this; + } + + public Optional getCipherSuites() + { + return cipherSuites; + } + + @Config("reactor.cipher-suites") + public ReactorNettyHttpClientConfig setCipherSuites(String cipherSuites) + { + this.cipherSuites = Optional.ofNullable(cipherSuites); + return this; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/TaskInfoFetcher.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/TaskInfoFetcher.java deleted file mode 100644 index ed5e5b21cfa1c..0000000000000 --- a/presto-main/src/main/java/com/facebook/presto/server/remotetask/TaskInfoFetcher.java +++ /dev/null @@ -1,441 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.server.remotetask; - -import com.facebook.airlift.concurrent.SetThreadName; -import com.facebook.airlift.http.client.HttpClient; -import com.facebook.airlift.http.client.HttpUriBuilder; -import com.facebook.airlift.http.client.Request; -import com.facebook.airlift.http.client.Response; -import com.facebook.airlift.http.client.ResponseHandler; -import com.facebook.airlift.http.client.thrift.ThriftRequestUtils; -import com.facebook.airlift.http.client.thrift.ThriftResponseHandler; -import com.facebook.airlift.json.Codec; -import com.facebook.airlift.json.JsonCodec; -import com.facebook.airlift.json.smile.SmileCodec; -import com.facebook.drift.transport.netty.codec.Protocol; -import com.facebook.presto.Session; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; -import com.facebook.presto.execution.QueryManager; -import com.facebook.presto.execution.StateMachine; -import com.facebook.presto.execution.StateMachine.StateChangeListener; -import com.facebook.presto.execution.TaskId; -import com.facebook.presto.execution.TaskInfo; -import com.facebook.presto.execution.TaskStatus; -import com.facebook.presto.metadata.HandleResolver; -import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.MetadataUpdates; -import com.facebook.presto.server.RequestErrorTracker; -import com.facebook.presto.server.SimpleHttpResponseCallback; -import com.facebook.presto.server.SimpleHttpResponseHandler; -import com.facebook.presto.server.smile.BaseResponse; -import 
com.facebook.presto.server.thrift.ThriftHttpResponseHandler; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; - -import javax.annotation.concurrent.GuardedBy; - -import java.net.URI; -import java.util.Optional; -import java.util.concurrent.Executor; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Consumer; - -import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom; -import static com.facebook.airlift.http.client.Request.Builder.prepareGet; -import static com.facebook.airlift.http.client.Request.Builder.preparePost; -import static com.facebook.airlift.http.client.ResponseHandlerUtils.propagate; -import static com.facebook.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; -import static com.facebook.presto.client.PrestoHeaders.PRESTO_CURRENT_STATE; -import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT; -import static com.facebook.presto.server.RequestErrorTracker.taskRequestErrorTracker; -import static com.facebook.presto.server.RequestHelpers.setContentTypeHeaders; -import static com.facebook.presto.server.TaskResourceUtils.convertFromThriftTaskInfo; -import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler; -import static com.facebook.presto.server.smile.FullSmileResponseHandler.createFullSmileResponseHandler; -import static com.facebook.presto.server.thrift.ThriftCodecWrapper.unwrapThriftCodec; -import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR; -import static io.airlift.units.Duration.nanosSince; -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.MILLISECONDS; - 
-public class TaskInfoFetcher - implements SimpleHttpResponseCallback -{ - private final TaskId taskId; - private final Consumer onFail; - private final StateMachine taskInfo; - private final StateMachine> finalTaskInfo; - private final Codec taskInfoCodec; - private final Codec metadataUpdatesCodec; - - private final long updateIntervalMillis; - private final Duration taskInfoRefreshMaxWait; - private final AtomicLong lastUpdateNanos = new AtomicLong(); - - private final ScheduledExecutorService updateScheduledExecutor; - - private final Executor executor; - private final HttpClient httpClient; - private final RequestErrorTracker errorTracker; - - private final boolean summarizeTaskInfo; - - @GuardedBy("this") - private final AtomicLong currentRequestStartNanos = new AtomicLong(); - - private final RemoteTaskStats stats; - - @GuardedBy("this") - private boolean running; - - @GuardedBy("this") - private ScheduledFuture scheduledFuture; - - @GuardedBy("this") - private ListenableFuture> future; - - @GuardedBy("this") - private ListenableFuture metadataUpdateFuture; - - private final boolean isBinaryTransportEnabled; - private final boolean isThriftTransportEnabled; - private final Session session; - private final MetadataManager metadataManager; - private final QueryManager queryManager; - private final HandleResolver handleResolver; - private final ConnectorTypeSerdeManager connectorTypeSerdeManager; - private final Protocol thriftProtocol; - - public TaskInfoFetcher( - Consumer onFail, - TaskInfo initialTask, - HttpClient httpClient, - Duration updateInterval, - Duration taskInfoRefreshMaxWait, - Codec taskInfoCodec, - Codec metadataUpdatesCodec, - Duration maxErrorDuration, - boolean summarizeTaskInfo, - Executor executor, - ScheduledExecutorService updateScheduledExecutor, - ScheduledExecutorService errorScheduledExecutor, - RemoteTaskStats stats, - boolean isBinaryTransportEnabled, - boolean isThriftTransportEnabled, - Session session, - MetadataManager 
metadataManager, - QueryManager queryManager, - HandleResolver handleResolver, - ConnectorTypeSerdeManager connectorTypeSerdeManager, - Protocol thriftProtocol) - { - requireNonNull(initialTask, "initialTask is null"); - requireNonNull(errorScheduledExecutor, "errorScheduledExecutor is null"); - - this.taskId = initialTask.getTaskId(); - this.onFail = requireNonNull(onFail, "onFail is null"); - this.taskInfo = new StateMachine<>("task " + taskId, executor, initialTask); - this.finalTaskInfo = new StateMachine<>("task-" + taskId, executor, Optional.empty()); - this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null"); - - this.metadataUpdatesCodec = requireNonNull(metadataUpdatesCodec, "metadataUpdatesCodec is null"); - - this.updateIntervalMillis = requireNonNull(updateInterval, "updateInterval is null").toMillis(); - this.taskInfoRefreshMaxWait = requireNonNull(taskInfoRefreshMaxWait, "taskInfoRefreshMaxWait is null"); - this.updateScheduledExecutor = requireNonNull(updateScheduledExecutor, "updateScheduledExecutor is null"); - this.errorTracker = taskRequestErrorTracker(taskId, initialTask.getTaskStatus().getSelf(), maxErrorDuration, errorScheduledExecutor, "getting info for task"); - - this.summarizeTaskInfo = summarizeTaskInfo; - - this.executor = requireNonNull(executor, "executor is null"); - this.httpClient = requireNonNull(httpClient, "httpClient is null"); - this.stats = requireNonNull(stats, "stats is null"); - this.isBinaryTransportEnabled = isBinaryTransportEnabled; - this.isThriftTransportEnabled = isThriftTransportEnabled; - this.session = requireNonNull(session, "session is null"); - this.metadataManager = requireNonNull(metadataManager, "metadataManager is null"); - this.queryManager = requireNonNull(queryManager, "queryManager is null"); - this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); - this.connectorTypeSerdeManager = requireNonNull(connectorTypeSerdeManager, "connectorTypeSerdeManager is null"); 
- this.thriftProtocol = requireNonNull(thriftProtocol, "thriftProtocol is null"); - } - - public TaskInfo getTaskInfo() - { - return taskInfo.get(); - } - - public synchronized void start() - { - if (running) { - // already running - return; - } - running = true; - scheduleUpdate(); - } - - private synchronized void stop() - { - running = false; - if (future != null) { - // do not terminate if the request is already running to avoid closing pooled connections - future.cancel(false); - future = null; - } - if (scheduledFuture != null) { - scheduledFuture.cancel(true); - } - } - - /** - * Add a listener for the final task info. This notification is guaranteed to be fired only once. - * Listener is always notified asynchronously using a dedicated notification thread pool so, care should - * be taken to avoid leaking {@code this} when adding a listener in a constructor. Additionally, it is - * possible notifications are observed out of order due to the asynchronous execution. - */ - public void addFinalTaskInfoListener(StateChangeListener stateChangeListener) - { - AtomicBoolean done = new AtomicBoolean(); - StateChangeListener> fireOnceStateChangeListener = finalTaskInfo -> { - if (finalTaskInfo.isPresent() && done.compareAndSet(false, true)) { - stateChangeListener.stateChanged(finalTaskInfo.get()); - } - }; - finalTaskInfo.addStateChangeListener(fireOnceStateChangeListener); - fireOnceStateChangeListener.stateChanged(finalTaskInfo.get()); - } - - private synchronized void scheduleUpdate() - { - scheduledFuture = updateScheduledExecutor.scheduleWithFixedDelay(() -> { - try { - synchronized (this) { - // if the previous request still running, don't schedule a new request - if (future != null && !future.isDone()) { - return; - } - } - if (nanosSince(lastUpdateNanos.get()).toMillis() >= updateIntervalMillis) { - sendNextRequest(); - } - } - catch (Throwable t) { - fatal(t); - throw t; - } - }, 0, 100, MILLISECONDS); - } - - private synchronized void sendNextRequest() - 
{ - TaskInfo taskInfo = getTaskInfo(); - TaskStatus taskStatus = taskInfo.getTaskStatus(); - - if (!running) { - return; - } - - // we already have the final task info - if (isDone(getTaskInfo())) { - stop(); - return; - } - - // if we have an outstanding request - if (future != null && !future.isDone()) { - return; - } - - // if throttled due to error, asynchronously wait for timeout and try again - ListenableFuture errorRateLimit = errorTracker.acquireRequestPermit(); - if (!errorRateLimit.isDone()) { - errorRateLimit.addListener(this::sendNextRequest, executor); - return; - } - - MetadataUpdates metadataUpdateRequests = taskInfo.getMetadataUpdates(); - if (!metadataUpdateRequests.getMetadataUpdates().isEmpty()) { - scheduleMetadataUpdates(metadataUpdateRequests); - } - - HttpUriBuilder httpUriBuilder = uriBuilderFrom(taskStatus.getSelf()); - URI uri = summarizeTaskInfo ? httpUriBuilder.addParameter("summarize").build() : httpUriBuilder.build(); - Request.Builder requestBuilder = setContentTypeHeaders(isBinaryTransportEnabled, prepareGet()); - - ResponseHandler responseHandler; - if (isThriftTransportEnabled) { - requestBuilder = ThriftRequestUtils.prepareThriftGet(thriftProtocol); - responseHandler = new ThriftResponseHandler(unwrapThriftCodec(taskInfoCodec)); - } - else if (isBinaryTransportEnabled) { - responseHandler = createFullSmileResponseHandler((SmileCodec) taskInfoCodec); - } - else { - responseHandler = createAdaptingJsonResponseHandler((JsonCodec) taskInfoCodec); - } - - if (taskInfoRefreshMaxWait.toMillis() != 0L) { - requestBuilder.setHeader(PRESTO_CURRENT_STATE, taskStatus.getState().toString()) - .setHeader(PRESTO_MAX_WAIT, taskInfoRefreshMaxWait.toString()); - } - - Request request = requestBuilder.setUri(uri).build(); - errorTracker.startRequest(); - future = httpClient.executeAsync(request, responseHandler); - currentRequestStartNanos.set(System.nanoTime()); - FutureCallback callback; - if (isThriftTransportEnabled) { - callback = new 
ThriftHttpResponseHandler(this, request.getUri(), stats.getHttpResponseStats(), REMOTE_TASK_ERROR); - } - else { - callback = new SimpleHttpResponseHandler<>(this, request.getUri(), stats.getHttpResponseStats(), REMOTE_TASK_ERROR); - } - - Futures.addCallback( - future, - callback, - executor); - } - - synchronized void updateTaskInfo(TaskInfo newValue) - { - boolean updated = taskInfo.setIf(newValue, oldValue -> { - TaskStatus oldTaskStatus = oldValue.getTaskStatus(); - TaskStatus newTaskStatus = newValue.getTaskStatus(); - if (oldTaskStatus.getState().isDone()) { - // never update if the task has reached a terminal state - return false; - } - // don't update to an older version (same version is ok) - return newTaskStatus.getVersion() >= oldTaskStatus.getVersion(); - }); - - if (updated && newValue.getTaskStatus().getState().isDone()) { - finalTaskInfo.compareAndSet(Optional.empty(), Optional.of(newValue)); - stop(); - } - } - - @Override - public void success(TaskInfo newValue) - { - try (SetThreadName ignored = new SetThreadName("TaskInfoFetcher-%s", taskId)) { - lastUpdateNanos.set(System.nanoTime()); - - long startNanos; - synchronized (this) { - startNanos = this.currentRequestStartNanos.get(); - } - updateStats(startNanos); - errorTracker.requestSucceeded(); - if (isThriftTransportEnabled) { - newValue = convertFromThriftTaskInfo(newValue, connectorTypeSerdeManager, handleResolver); - } - updateTaskInfo(newValue); - } - } - - @Override - public void failed(Throwable cause) - { - try (SetThreadName ignored = new SetThreadName("TaskInfoFetcher-%s", taskId)) { - lastUpdateNanos.set(System.nanoTime()); - - try { - // if task not already done, record error - if (!isDone(getTaskInfo())) { - errorTracker.requestFailed(cause); - } - } - catch (Error e) { - onFail.accept(e); - throw e; - } - catch (RuntimeException e) { - onFail.accept(e); - } - } - } - - @Override - public void fatal(Throwable cause) - { - try (SetThreadName ignored = new 
SetThreadName("TaskInfoFetcher-%s", taskId)) { - onFail.accept(cause); - } - } - - private void updateStats(long currentRequestStartNanos) - { - stats.infoRoundTripMillis(nanosSince(currentRequestStartNanos).toMillis()); - } - - private static boolean isDone(TaskInfo taskInfo) - { - return taskInfo.getTaskStatus().getState().isDone(); - } - - private void scheduleMetadataUpdates(MetadataUpdates metadataUpdateRequests) - { - MetadataUpdates results = metadataManager.getMetadataUpdateResults(session, queryManager, metadataUpdateRequests, taskId.getQueryId()); - executor.execute(() -> sendMetadataUpdates(results)); - } - - private synchronized void sendMetadataUpdates(MetadataUpdates results) - { - TaskStatus taskStatus = getTaskInfo().getTaskStatus(); - - // we already have the final task info - if (isDone(getTaskInfo())) { - stop(); - return; - } - - // outstanding request? - if (metadataUpdateFuture != null && !metadataUpdateFuture.isDone()) { - // this should never happen - return; - } - - byte[] metadataUpdatesJson = metadataUpdatesCodec.toBytes(results); - Request request = setContentTypeHeaders(isBinaryTransportEnabled, preparePost()) - .setUri(uriBuilderFrom(taskStatus.getSelf()).appendPath("metadataresults").build()) - .setBodyGenerator(createStaticBodyGenerator(metadataUpdatesJson)) - .build(); - - errorTracker.startRequest(); - metadataUpdateFuture = httpClient.executeAsync(request, new ResponseHandler() - { - @Override - public Response handleException(Request request, Exception exception) - { - throw propagate(request, exception); - } - - @Override - public Response handle(Request request, Response response) - { - return response; - } - }); - currentRequestStartNanos.set(System.nanoTime()); - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/server/remotetask/TaskInfoFetcherWithEventLoop.java b/presto-main/src/main/java/com/facebook/presto/server/remotetask/TaskInfoFetcherWithEventLoop.java index c0d3d426f1064..df56d783976eb 100644 --- 
a/presto-main/src/main/java/com/facebook/presto/server/remotetask/TaskInfoFetcherWithEventLoop.java +++ b/presto-main/src/main/java/com/facebook/presto/server/remotetask/TaskInfoFetcherWithEventLoop.java @@ -16,16 +16,15 @@ import com.facebook.airlift.http.client.HttpClient; import com.facebook.airlift.http.client.HttpUriBuilder; import com.facebook.airlift.http.client.Request; -import com.facebook.airlift.http.client.Response; import com.facebook.airlift.http.client.ResponseHandler; import com.facebook.airlift.http.client.thrift.ThriftRequestUtils; import com.facebook.airlift.http.client.thrift.ThriftResponseHandler; import com.facebook.airlift.json.Codec; import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.smile.SmileCodec; +import com.facebook.airlift.units.Duration; import com.facebook.drift.transport.netty.codec.Protocol; import com.facebook.presto.Session; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; import com.facebook.presto.execution.QueryManager; import com.facebook.presto.execution.StateMachine; import com.facebook.presto.execution.StateMachine.StateChangeListener; @@ -34,7 +33,6 @@ import com.facebook.presto.execution.TaskStatus; import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.MetadataManager; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.server.RequestErrorTracker; import com.facebook.presto.server.SimpleHttpResponseCallback; import com.facebook.presto.server.SimpleHttpResponseHandler; @@ -43,7 +41,6 @@ import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.Duration; import io.netty.channel.EventLoop; import java.net.URI; @@ -54,20 +51,16 @@ import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static com.facebook.airlift.http.client.Request.Builder.prepareGet; -import 
static com.facebook.airlift.http.client.Request.Builder.preparePost; -import static com.facebook.airlift.http.client.ResponseHandlerUtils.propagate; -import static com.facebook.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; +import static com.facebook.airlift.units.Duration.nanosSince; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CURRENT_STATE; import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT; import static com.facebook.presto.server.RequestErrorTracker.taskRequestErrorTracker; import static com.facebook.presto.server.RequestHelpers.setContentTypeHeaders; -import static com.facebook.presto.server.TaskResourceUtils.convertFromThriftTaskInfo; import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler; import static com.facebook.presto.server.smile.FullSmileResponseHandler.createFullSmileResponseHandler; import static com.facebook.presto.server.thrift.ThriftCodecWrapper.unwrapThriftCodec; import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR; import static com.google.common.base.Verify.verify; -import static io.airlift.units.Duration.nanosSince; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -79,7 +72,6 @@ public class TaskInfoFetcherWithEventLoop private final StateMachine taskInfo; private final StateMachine> finalTaskInfo; private final Codec taskInfoCodec; - private final Codec metadataUpdatesCodec; private final long updateIntervalMillis; private final Duration taskInfoRefreshMaxWait; @@ -97,7 +89,6 @@ public class TaskInfoFetcherWithEventLoop private ScheduledFuture scheduledFuture; private ListenableFuture> future; - private ListenableFuture metadataUpdateFuture; private final boolean isBinaryTransportEnabled; private final boolean isThriftTransportEnabled; @@ -105,7 +96,6 @@ public class TaskInfoFetcherWithEventLoop private final MetadataManager metadataManager; private 
final QueryManager queryManager; private final HandleResolver handleResolver; - private final ConnectorTypeSerdeManager connectorTypeSerdeManager; private final Protocol thriftProtocol; public TaskInfoFetcherWithEventLoop( @@ -115,7 +105,6 @@ public TaskInfoFetcherWithEventLoop( Duration updateInterval, Duration taskInfoRefreshMaxWait, Codec taskInfoCodec, - Codec metadataUpdatesCodec, Duration maxErrorDuration, boolean summarizeTaskInfo, EventLoop taskEventLoop, @@ -126,7 +115,6 @@ public TaskInfoFetcherWithEventLoop( MetadataManager metadataManager, QueryManager queryManager, HandleResolver handleResolver, - ConnectorTypeSerdeManager connectorTypeSerdeManager, Protocol thriftProtocol) { requireNonNull(initialTask, "initialTask is null"); @@ -137,8 +125,6 @@ public TaskInfoFetcherWithEventLoop( this.finalTaskInfo = new StateMachine<>("task-" + taskId, taskEventLoop, Optional.empty()); this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null"); - this.metadataUpdatesCodec = requireNonNull(metadataUpdatesCodec, "metadataUpdatesCodec is null"); - this.updateIntervalMillis = requireNonNull(updateInterval, "updateInterval is null").toMillis(); this.taskInfoRefreshMaxWait = requireNonNull(taskInfoRefreshMaxWait, "taskInfoRefreshMaxWait is null"); this.errorTracker = taskRequestErrorTracker(taskId, initialTask.getTaskStatus().getSelf(), maxErrorDuration, taskEventLoop, "getting info for task"); @@ -154,7 +140,6 @@ public TaskInfoFetcherWithEventLoop( this.metadataManager = requireNonNull(metadataManager, "metadataManager is null"); this.queryManager = requireNonNull(queryManager, "queryManager is null"); this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); - this.connectorTypeSerdeManager = requireNonNull(connectorTypeSerdeManager, "connectorTypeSerdeManager is null"); this.thriftProtocol = requireNonNull(thriftProtocol, "thriftProtocol is null"); } @@ -259,11 +244,6 @@ private void sendNextRequest() return; } - MetadataUpdates 
metadataUpdateRequests = taskInfo.getMetadataUpdates(); - if (!metadataUpdateRequests.getMetadataUpdates().isEmpty()) { - scheduleMetadataUpdates(metadataUpdateRequests); - } - HttpUriBuilder httpUriBuilder = uriBuilderFrom(taskStatus.getSelf()); URI uri = summarizeTaskInfo ? httpUriBuilder.addParameter("summarize").build() : httpUriBuilder.build(); Request.Builder requestBuilder = setContentTypeHeaders(isBinaryTransportEnabled, prepareGet()); @@ -335,9 +315,6 @@ public void success(TaskInfo newValue) startNanos = this.currentRequestStartNanos; updateStats(startNanos); errorTracker.requestSucceeded(); - if (isThriftTransportEnabled) { - newValue = convertFromThriftTaskInfo(newValue, connectorTypeSerdeManager, handleResolver); - } updateTaskInfo(newValue); } @@ -380,51 +357,4 @@ private static boolean isDone(TaskInfo taskInfo) { return taskInfo.getTaskStatus().getState().isDone(); } - - private void scheduleMetadataUpdates(MetadataUpdates metadataUpdateRequests) - { - MetadataUpdates results = metadataManager.getMetadataUpdateResults(session, queryManager, metadataUpdateRequests, taskId.getQueryId()); - taskEventLoop.execute(() -> sendMetadataUpdates(results)); - } - - private void sendMetadataUpdates(MetadataUpdates results) - { - verify(taskEventLoop.inEventLoop()); - TaskStatus taskStatus = getTaskInfo().getTaskStatus(); - - // we already have the final task info - if (isDone(getTaskInfo())) { - stop(); - return; - } - - // outstanding request? 
- if (metadataUpdateFuture != null && !metadataUpdateFuture.isDone()) { - // this should never happen - return; - } - - byte[] metadataUpdatesJson = metadataUpdatesCodec.toBytes(results); - Request request = setContentTypeHeaders(isBinaryTransportEnabled, preparePost()) - .setUri(uriBuilderFrom(taskStatus.getSelf()).appendPath("metadataresults").build()) - .setBodyGenerator(createStaticBodyGenerator(metadataUpdatesJson)) - .build(); - - errorTracker.startRequest(); - metadataUpdateFuture = httpClient.executeAsync(request, new ResponseHandler() - { - @Override - public Response handleException(Request request, Exception exception) - { - throw propagate(request, exception); - } - - @Override - public Response handle(Request request, Response response) - { - return response; - } - }); - currentRequestStartNanos = System.nanoTime(); - } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index 6914f8d7b03f5..17d462cb1b822 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -16,6 +16,7 @@ import com.facebook.airlift.http.server.AuthenticationException; import com.facebook.airlift.http.server.Authenticator; import com.facebook.presto.ClientRequestFilterManager; +import com.facebook.presto.server.security.oauth2.OAuth2Authenticator; import com.facebook.presto.spi.ClientRequestFilter; import com.facebook.presto.spi.PrestoException; import com.google.common.base.Joiner; @@ -24,17 +25,16 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.net.HttpHeaders; - -import javax.inject.Inject; -import javax.servlet.Filter; -import javax.servlet.FilterChain; -import javax.servlet.FilterConfig; -import javax.servlet.ServletException; 
-import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletRequestWrapper; -import javax.servlet.http.HttpServletResponse; +import jakarta.inject.Inject; +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.FilterConfig; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; +import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.InputStream; @@ -46,16 +46,20 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; +import static com.facebook.presto.server.WebUiResource.UI_ENDPOINT; +import static com.facebook.presto.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT; +import static com.facebook.presto.server.security.oauth2.OAuth2TokenExchangeResource.TOKEN_ENDPOINT; import static com.facebook.presto.spi.StandardErrorCode.HEADER_MODIFICATION_ATTEMPT; import static com.google.common.io.ByteStreams.copy; import static com.google.common.io.ByteStreams.nullOutputStream; import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE; import static com.google.common.net.MediaType.PLAIN_TEXT_UTF_8; +import static jakarta.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; import static java.util.Collections.enumeration; import static java.util.Collections.list; import static java.util.Objects.requireNonNull; -import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; public class AuthenticationFilter implements Filter @@ -65,15 +69,39 @@ public class AuthenticationFilter private final boolean allowForwardedHttps; private final ClientRequestFilterManager clientRequestFilterManager; private final List headersBlockList = 
ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token"); + private final WebUiAuthenticationManager webUiAuthenticationManager; + private final boolean isOauth2Enabled; @Inject - public AuthenticationFilter(List authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager) + public AuthenticationFilter(List authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager, WebUiAuthenticationManager webUiAuthenticationManager) { this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null")); + this.webUiAuthenticationManager = requireNonNull(webUiAuthenticationManager, "webUiAuthenticationManager is null"); + this.isOauth2Enabled = this.authenticators.stream() + .anyMatch(a -> a.getClass().equals(OAuth2Authenticator.class)); this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps(); this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null"); } + public static ServletRequest withPrincipal(HttpServletRequest request, Principal principal) + { + requireNonNull(principal, "principal is null"); + return new HttpServletRequestWrapper(request) + { + @Override + public Principal getUserPrincipal() + { + return principal; + } + }; + } + + private static boolean isRequestToOAuthEndpoint(HttpServletRequest request) + { + return request.getPathInfo().startsWith(TOKEN_ENDPOINT) + || request.getPathInfo().startsWith(CALLBACK_ENDPOINT); + } + @Override public void init(FilterConfig filterConfig) {} @@ -87,6 +115,13 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo HttpServletRequest request = (HttpServletRequest) servletRequest; HttpServletResponse response = (HttpServletResponse) servletResponse; + // Check if it's a request going to the web UI side. 
+ if (isWebUiRequest(request) && isOauth2Enabled) { + // call web authenticator + this.webUiAuthenticationManager.handleRequest(request, response, nextFilter); + return; + } + // skip authentication if non-secure or not configured if (!doesRequestSupportAuthentication(request)) { nextFilter.doFilter(request, response); @@ -109,6 +144,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo e.getAuthenticateHeader().ifPresent(authenticateHeaders::add); continue; } + // authentication succeeded HttpServletRequest wrappedRequest = mergeExtraHeaders(request, principal); nextFilter.doFilter(withPrincipal(wrappedRequest, principal), response); @@ -118,6 +154,11 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo // authentication failed skipRequestBody(request); + // Browsers have special handling for the BASIC challenge authenticate header so we need to filter them out if the WebUI Oauth Token is present. + if (isOauth2Enabled && OAuth2Authenticator.extractTokenFromCookie(request).isPresent()) { + authenticateHeaders = authenticateHeaders.stream().filter(value -> value.contains("x_token_server")).collect(Collectors.toSet()); + } + for (String value : authenticateHeaders) { response.addHeader(WWW_AUTHENTICATE, value); } @@ -133,7 +174,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo // Clients should use the response body rather than the HTTP status // message (which does not exist with HTTP/2), but the status message // still needs to be sent for compatibility with existing clients. 
- response.setStatus(SC_UNAUTHORIZED, error); + response.setStatus(SC_UNAUTHORIZED); response.setContentType(PLAIN_TEXT_UTF_8.toString()); try (PrintWriter writer = response.getWriter()) { writer.write(error); @@ -183,6 +224,9 @@ public HttpServletRequest mergeExtraHeaders(HttpServletRequest request, Principa private boolean doesRequestSupportAuthentication(HttpServletRequest request) { + if (isRequestToOAuthEndpoint(request)) { + return false; + } if (authenticators.isEmpty()) { return false; } @@ -195,19 +239,6 @@ private boolean doesRequestSupportAuthentication(HttpServletRequest request) return false; } - private static ServletRequest withPrincipal(HttpServletRequest request, Principal principal) - { - requireNonNull(principal, "principal is null"); - return new HttpServletRequestWrapper(request) - { - @Override - public Principal getUserPrincipal() - { - return principal; - } - }; - } - private static void skipRequestBody(HttpServletRequest request) throws IOException { @@ -222,6 +253,12 @@ private static void skipRequestBody(HttpServletRequest request) } } + private boolean isWebUiRequest(HttpServletRequest request) + { + String pathInfo = request.getPathInfo(); + return pathInfo == null || pathInfo.equals(UI_ENDPOINT) || pathInfo.startsWith("/ui"); + } + public static class ModifiedHttpServletRequest extends HttpServletRequestWrapper { diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/CustomPrestoAuthenticator.java b/presto-main/src/main/java/com/facebook/presto/server/security/CustomPrestoAuthenticator.java index ddaad5459974b..6a6b192520e6d 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/CustomPrestoAuthenticator.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/CustomPrestoAuthenticator.java @@ -15,10 +15,11 @@ import com.facebook.airlift.http.server.AuthenticationException; import com.facebook.airlift.http.server.Authenticator; +import com.facebook.airlift.log.Logger; import 
com.facebook.presto.spi.security.AccessDeniedException; - -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; +import com.facebook.presto.spi.security.AuthenticatorNotApplicableException; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; import java.security.Principal; import java.util.List; @@ -31,6 +32,8 @@ public class CustomPrestoAuthenticator implements Authenticator { + private static final Logger log = Logger.get(CustomPrestoAuthenticator.class); + private PrestoAuthenticatorManager authenticatorManager; @Inject @@ -50,11 +53,21 @@ public Principal authenticate(HttpServletRequest request) // Passing the header map to the authenticator (instead of HttpServletRequest) return authenticatorManager.getAuthenticator().createAuthenticatedPrincipal(headers); } + catch (AuthenticatorNotApplicableException e) { + // Presto will gracefully handle this exception and will not propagate it back to the client + log.debug(e, e.getMessage()); + throw needAuthentication(); + } catch (AccessDeniedException e) { throw new AuthenticationException(e.getMessage()); } } + private static AuthenticationException needAuthentication() + { + return new AuthenticationException(null); + } + // Utility method to extract headers from HttpServletRequest private Map> getHeadersMap(HttpServletRequest request) { diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/DefaultWebUiAuthenticationManager.java b/presto-main/src/main/java/com/facebook/presto/server/security/DefaultWebUiAuthenticationManager.java new file mode 100644 index 0000000000000..75ff4d072d711 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/DefaultWebUiAuthenticationManager.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; + +public class DefaultWebUiAuthenticationManager + implements WebUiAuthenticationManager +{ + @Override + public void handleRequest(HttpServletRequest request, HttpServletResponse response, FilterChain nextFilter) + throws IOException, ServletException + { + nextFilter.doFilter(request, response); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/InternalAuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/InternalAuthenticationFilter.java index fadb1b405a0fb..fa5d218f0e3cd 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/InternalAuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/InternalAuthenticationFilter.java @@ -14,15 +14,14 @@ package com.facebook.presto.server.security; import com.facebook.presto.server.InternalAuthenticationManager; - -import javax.annotation.security.RolesAllowed; -import javax.inject.Inject; -import javax.ws.rs.container.ContainerRequestContext; -import javax.ws.rs.container.ContainerRequestFilter; -import javax.ws.rs.container.ResourceInfo; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.Response; -import javax.ws.rs.ext.Provider; +import jakarta.annotation.security.RolesAllowed; +import jakarta.inject.Inject; 
+import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerRequestFilter; +import jakarta.ws.rs.container.ResourceInfo; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.ext.Provider; import java.lang.reflect.Method; import java.security.Principal; @@ -30,9 +29,9 @@ import java.util.Optional; import static com.facebook.presto.server.security.RoleType.INTERNAL; +import static jakarta.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; +import static jakarta.ws.rs.core.Response.ResponseBuilder; import static java.util.Objects.requireNonNull; -import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; -import static javax.ws.rs.core.Response.ResponseBuilder; @Provider public class InternalAuthenticationFilter diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenAuthenticator.java b/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenAuthenticator.java index a8a209762f81d..310e9f76fa684 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenAuthenticator.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenAuthenticator.java @@ -31,10 +31,10 @@ import io.jsonwebtoken.SigningKeyResolver; import io.jsonwebtoken.UnsupportedJwtException; import io.jsonwebtoken.jackson.io.JacksonDeserializer; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; import javax.crypto.spec.SecretKeySpec; -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; import java.io.File; import java.io.IOException; diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenConfig.java b/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenConfig.java index 1eaddacf1a1cb..1c6694e8a67e8 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenConfig.java +++ 
b/presto-main/src/main/java/com/facebook/presto/server/security/JsonWebTokenConfig.java @@ -14,8 +14,7 @@ package com.facebook.presto.server.security; import com.facebook.airlift.configuration.Config; - -import javax.validation.constraints.NotNull; +import jakarta.validation.constraints.NotNull; public class JsonWebTokenConfig { diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/PasswordAuthenticator.java b/presto-main/src/main/java/com/facebook/presto/server/security/PasswordAuthenticator.java index eef467b893e2f..db680542a1408 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/PasswordAuthenticator.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/PasswordAuthenticator.java @@ -17,9 +17,8 @@ import com.facebook.airlift.http.server.Authenticator; import com.facebook.presto.spi.security.AccessDeniedException; import com.google.common.base.Splitter; - -import javax.inject.Inject; -import javax.servlet.http.HttpServletRequest; +import jakarta.inject.Inject; +import jakarta.servlet.http.HttpServletRequest; import java.security.Principal; import java.util.Base64; diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/ServerSecurityModule.java b/presto-main/src/main/java/com/facebook/presto/server/security/ServerSecurityModule.java index cd7943e40b313..6fe27fa7b7e96 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/ServerSecurityModule.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/ServerSecurityModule.java @@ -15,13 +15,22 @@ import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; import com.facebook.airlift.http.server.Authenticator; +import com.facebook.airlift.http.server.Authorizer; import com.facebook.airlift.http.server.CertificateAuthenticator; +import com.facebook.airlift.http.server.ConfigurationBasedAuthorizer; +import com.facebook.airlift.http.server.ConfigurationBasedAuthorizerConfig; import 
com.facebook.airlift.http.server.KerberosAuthenticator; import com.facebook.airlift.http.server.KerberosConfig; +import com.facebook.airlift.http.server.TheServlet; import com.facebook.presto.server.security.SecurityConfig.AuthenticationType; +import com.facebook.presto.server.security.oauth2.OAuth2AuthenticationSupportModule; +import com.facebook.presto.server.security.oauth2.OAuth2Authenticator; +import com.facebook.presto.server.security.oauth2.OAuth2Config; +import com.facebook.presto.server.security.oauth2.OAuth2WebUiAuthenticationManager; import com.google.inject.Binder; import com.google.inject.Scopes; import com.google.inject.multibindings.Multibinder; +import jakarta.servlet.Filter; import java.util.List; @@ -30,8 +39,11 @@ import static com.facebook.presto.server.security.SecurityConfig.AuthenticationType.CUSTOM; import static com.facebook.presto.server.security.SecurityConfig.AuthenticationType.JWT; import static com.facebook.presto.server.security.SecurityConfig.AuthenticationType.KERBEROS; +import static com.facebook.presto.server.security.SecurityConfig.AuthenticationType.OAUTH2; import static com.facebook.presto.server.security.SecurityConfig.AuthenticationType.PASSWORD; +import static com.facebook.presto.server.security.SecurityConfig.AuthenticationType.TEST_EXTERNAL; import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; public class ServerSecurityModule extends AbstractConfigurationAwareModule @@ -39,6 +51,10 @@ public class ServerSecurityModule @Override protected void setup(Binder binder) { + newOptionalBinder(binder, WebUiAuthenticationManager.class).setDefault().to(DefaultWebUiAuthenticationManager.class).in(Scopes.SINGLETON); + newSetBinder(binder, Filter.class, TheServlet.class).addBinding() + .to(AuthenticationFilter.class).in(Scopes.SINGLETON); + binder.bind(PasswordAuthenticatorManager.class).in(Scopes.SINGLETON); 
binder.bind(PrestoAuthenticatorManager.class).in(Scopes.SINGLETON); @@ -60,11 +76,25 @@ else if (authType == JWT) { configBinder(binder).bindConfig(JsonWebTokenConfig.class); authBinder.addBinding().to(JsonWebTokenAuthenticator.class).in(Scopes.SINGLETON); } + else if (authType == OAUTH2) { + newOptionalBinder(binder, WebUiAuthenticationManager.class).setBinding().to(OAuth2WebUiAuthenticationManager.class).in(Scopes.SINGLETON); + install(new OAuth2AuthenticationSupportModule()); + binder.bind(OAuth2Authenticator.class).in(Scopes.SINGLETON); + configBinder(binder).bindConfig(OAuth2Config.class); + authBinder.addBinding().to(OAuth2Authenticator.class).in(Scopes.SINGLETON); + + configBinder(binder).bindConfig(ConfigurationBasedAuthorizerConfig.class); + binder.bind(Authorizer.class).to(ConfigurationBasedAuthorizer.class).in(Scopes.SINGLETON); + } else if (authType == CUSTOM) { authBinder.addBinding().to(CustomPrestoAuthenticator.class).in(Scopes.SINGLETON); } else { - throw new AssertionError("Unhandled auth type: " + authType); + // TEST_EXTERNAL is an authentication type used for testing the external auth flow for the JDBC driver. + // This is here as a guard since it's not a real authenticator, but if it were excluded from the checks then the error would be thrown. 
+ if (authType != TEST_EXTERNAL) { + throw new AssertionError("Unhandled auth type: " + authType); + } } } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/ServletSecurityUtils.java b/presto-main/src/main/java/com/facebook/presto/server/security/ServletSecurityUtils.java index 950f5a19d1ccc..f6c8ed6af3ae5 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/ServletSecurityUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/ServletSecurityUtils.java @@ -14,8 +14,7 @@ package com.facebook.presto.server.security; import com.facebook.presto.spi.security.AuthorizedIdentity; - -import javax.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequest; import java.util.Optional; diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/WebUiAuthenticationManager.java b/presto-main/src/main/java/com/facebook/presto/server/security/WebUiAuthenticationManager.java new file mode 100644 index 0000000000000..322d416168942 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/WebUiAuthenticationManager.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; + +public interface WebUiAuthenticationManager +{ + void handleRequest(HttpServletRequest request, HttpServletResponse response, FilterChain nextFilter) + throws IOException, ServletException; +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ChallengeFailedException.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ChallengeFailedException.java new file mode 100644 index 0000000000000..f72a7c7c537ea --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ChallengeFailedException.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +public class ChallengeFailedException + extends Exception +{ + public ChallengeFailedException(String message) + { + super(message); + } + + public ChallengeFailedException(String message, Throwable cause) + { + super(message, cause); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/ForJsonMetadataUpdateHandle.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ForOAuth2.java similarity index 86% rename from presto-main-base/src/main/java/com/facebook/presto/server/ForJsonMetadataUpdateHandle.java rename to presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ForOAuth2.java index 4306bc7f94198..ee66543301512 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/ForJsonMetadataUpdateHandle.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ForOAuth2.java @@ -11,9 +11,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.server; +package com.facebook.presto.server.security.oauth2; -import javax.inject.Qualifier; +import com.google.inject.BindingAnnotation; import java.lang.annotation.Retention; import java.lang.annotation.Target; @@ -25,7 +25,7 @@ @Retention(RUNTIME) @Target({FIELD, PARAMETER, METHOD}) -@Qualifier -public @interface ForJsonMetadataUpdateHandle +@BindingAnnotation +public @interface ForOAuth2 { } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ForRefreshTokens.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ForRefreshTokens.java new file mode 100644 index 0000000000000..d4d3939c61320 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ForRefreshTokens.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({FIELD, PARAMETER, METHOD}) +@BindingAnnotation +public @interface ForRefreshTokens +{ +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java new file mode 100644 index 0000000000000..45ded8e0ed84c --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java @@ -0,0 +1,195 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.units.Duration; +import com.nimbusds.jose.EncryptionMethod; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWEAlgorithm; +import com.nimbusds.jose.JWEDecrypter; +import com.nimbusds.jose.JWEEncrypter; +import com.nimbusds.jose.JWEHeader; +import com.nimbusds.jose.JWEObject; +import com.nimbusds.jose.KeyLengthException; +import com.nimbusds.jose.Payload; +import com.nimbusds.jose.crypto.AESDecrypter; +import com.nimbusds.jose.crypto.AESEncrypter; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.CompressionCodec; +import io.jsonwebtoken.CompressionException; +import io.jsonwebtoken.Header; +import io.jsonwebtoken.JwtBuilder; +import io.jsonwebtoken.JwtParser; + +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; + +import java.security.NoSuchAlgorithmException; +import java.text.ParseException; +import java.time.Clock; +import java.util.Date; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.server.security.oauth2.JwtUtil.newJwtBuilder; +import static com.facebook.presto.server.security.oauth2.JwtUtil.newJwtParserBuilder; +import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class JweTokenSerializer + implements TokenPairSerializer +{ + private final JWEHeader encryptionHeader; + + private static final CompressionCodec COMPRESSION_CODEC = new ZstdCodec(); + private static final String ACCESS_TOKEN_KEY = "access_token"; + private static final String EXPIRATION_TIME_KEY = "expiration_time"; + private static final String REFRESH_TOKEN_KEY = "refresh_token"; + private final OAuth2Client client; + private final Clock clock; + private final String issuer; + private final String audience; + private final Duration tokenExpiration; + private final JwtParser parser; + private final JWEEncrypter 
jweEncrypter; + private final JWEDecrypter jweDecrypter; + private final String principalField; + + public JweTokenSerializer( + RefreshTokensConfig config, + OAuth2Client client, + String issuer, + String audience, + String principalField, + Clock clock, + Duration tokenExpiration) + throws KeyLengthException, NoSuchAlgorithmException + { + SecretKey secretKey = createKey(requireNonNull(config, "config is null")); + this.jweEncrypter = new AESEncrypter(secretKey); + this.jweDecrypter = new AESDecrypter(secretKey); + this.client = requireNonNull(client, "client is null"); + this.issuer = requireNonNull(issuer, "issuer is null"); + this.principalField = requireNonNull(principalField, "principalField is null"); + this.audience = requireNonNull(audience, "issuer is null"); + this.clock = requireNonNull(clock, "clock is null"); + this.tokenExpiration = requireNonNull(tokenExpiration, "tokenExpiration is null"); + this.encryptionHeader = createEncryptionHeader(secretKey); + + this.parser = newJwtParserBuilder() + .setClock(() -> Date.from(clock.instant())) + .requireIssuer(this.issuer) + .requireAudience(this.audience) + .setCompressionCodecResolver(JweTokenSerializer::resolveCompressionCodec) + .build(); + } + + private JWEHeader createEncryptionHeader(SecretKey key) + { + int keyLength = key.getEncoded().length; + switch (keyLength) { + case 16: + return new JWEHeader(JWEAlgorithm.A128GCMKW, EncryptionMethod.A128GCM); + case 24: + return new JWEHeader(JWEAlgorithm.A192GCMKW, EncryptionMethod.A192GCM); + case 32: + return new JWEHeader(JWEAlgorithm.A256GCMKW, EncryptionMethod.A256GCM); + default: + throw new IllegalArgumentException( + String.format("Secret key size must be either 16, 24 or 32 bytes but was %d", keyLength)); + } + } + + @Override + public TokenPair deserialize(String token) + { + requireNonNull(token, "token is null"); + try { + JWEObject jwe = JWEObject.parse(token); + jwe.decrypt(jweDecrypter); + Claims claims = 
parser.parseClaimsJwt(jwe.getPayload().toString()).getBody(); + return TokenPair.accessAndRefreshTokens( + claims.get(ACCESS_TOKEN_KEY, String.class), + claims.get(EXPIRATION_TIME_KEY, Date.class), + claims.get(REFRESH_TOKEN_KEY, String.class)); + } + catch (ParseException ex) { + throw new IllegalArgumentException("Malformed jwt token", ex); + } + catch (JOSEException ex) { + throw new IllegalArgumentException("Decryption failed", ex); + } + } + + @Override + public String serialize(TokenPair tokenPair) + { + requireNonNull(tokenPair, "tokenPair is null"); + + Optional> accessTokenClaims = client.getClaims(tokenPair.getAccessToken()); + if (!accessTokenClaims.isPresent()) { + throw new IllegalArgumentException("Claims are missing"); + } + Map claims = accessTokenClaims.get(); + if (!claims.containsKey(principalField)) { + throw new IllegalArgumentException(format("%s field is missing", principalField)); + } + JwtBuilder jwt = newJwtBuilder() + .setExpiration(Date.from(clock.instant().plusMillis(tokenExpiration.toMillis()))) + .claim(principalField, claims.get(principalField).toString()) + .setAudience(audience) + .setIssuer(issuer) + .claim(ACCESS_TOKEN_KEY, tokenPair.getAccessToken()) + .claim(EXPIRATION_TIME_KEY, tokenPair.getExpiration()) + .claim(REFRESH_TOKEN_KEY, tokenPair.getRefreshToken().orElseThrow(JweTokenSerializer::throwExceptionForNonExistingRefreshToken)) + .compressWith(COMPRESSION_CODEC); + + try { + JWEObject jwe = new JWEObject(encryptionHeader, new Payload(jwt.compact())); + jwe.encrypt(jweEncrypter); + return jwe.serialize(); + } + catch (JOSEException ex) { + throw new IllegalStateException("Encryption failed", ex); + } + } + + private static SecretKey createKey(RefreshTokensConfig config) + throws NoSuchAlgorithmException + { + SecretKey signingKey = config.getSecretKey(); + if (signingKey == null) { + KeyGenerator generator = KeyGenerator.getInstance("AES"); + generator.init(256); + return generator.generateKey(); + } + return signingKey; + 
} + + private static RuntimeException throwExceptionForNonExistingRefreshToken() + { + throw new IllegalStateException("Expected refresh token to be present. Please check your identity provider setup, or disable refresh tokens"); + } + + private static CompressionCodec resolveCompressionCodec(Header header) + throws CompressionException + { + if (header.getCompressionAlgorithm() != null) { + checkState(header.getCompressionAlgorithm().equals(ZstdCodec.CODEC_NAME), "Unknown codec '%s' used for token compression", header.getCompressionAlgorithm()); + return COMPRESSION_CODEC; + } + return null; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializerModule.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializerModule.java new file mode 100644 index 0000000000000..1819f79a33137 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializerModule.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; +import com.facebook.presto.client.NodeVersion; +import com.google.inject.Binder; +import com.google.inject.Inject; +import com.google.inject.Key; +import com.google.inject.Provides; +import com.google.inject.Singleton; +import com.nimbusds.jose.KeyLengthException; + +import java.security.NoSuchAlgorithmException; +import java.time.Clock; +import java.time.Duration; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; + +public class JweTokenSerializerModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + configBinder(binder).bindConfig(RefreshTokensConfig.class); + RefreshTokensConfig config = buildConfigObject(RefreshTokensConfig.class); + newOptionalBinder(binder, Key.get(Duration.class, ForRefreshTokens.class)).setBinding().toInstance(Duration.ofMillis(config.getTokenExpiration().toMillis())); + } + + @Provides + @Singleton + @Inject + public TokenPairSerializer getTokenPairSerializer( + OAuth2Client client, + NodeVersion nodeVersion, + RefreshTokensConfig config, + OAuth2Config oAuth2Config) + throws KeyLengthException, NoSuchAlgorithmException + { + return new JweTokenSerializer( + config, + client, + config.getIssuer() + "_" + nodeVersion.getVersion(), + config.getAudience(), + oAuth2Config.getPrincipalField(), + Clock.systemUTC(), + config.getTokenExpiration()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JwtUtil.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JwtUtil.java new file mode 100644 index 0000000000000..c2ea7114fd322 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JwtUtil.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import io.jsonwebtoken.JwtBuilder; +import io.jsonwebtoken.JwtParserBuilder; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.io.Deserializer; +import io.jsonwebtoken.io.Serializer; +import io.jsonwebtoken.jackson.io.JacksonDeserializer; +import io.jsonwebtoken.jackson.io.JacksonSerializer; + +import java.util.Map; + +// avoid reflection and services lookup +public final class JwtUtil +{ + private static final Serializer> JWT_SERIALIZER = new JacksonSerializer<>(); + private static final Deserializer> JWT_DESERIALIZER = new JacksonDeserializer<>(); + + private JwtUtil() {} + + public static JwtBuilder newJwtBuilder() + { + return Jwts.builder() + .serializeToJsonWith(JWT_SERIALIZER); + } + + public static JwtParserBuilder newJwtParserBuilder() + { + return Jwts.parserBuilder() + .deserializeJsonWith(JWT_DESERIALIZER); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusAirliftHttpClient.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusAirliftHttpClient.java new file mode 100644 index 0000000000000..3aeeb4046f216 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusAirliftHttpClient.java @@ -0,0 +1,137 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.http.client.HttpClient; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.http.client.Response; +import com.facebook.airlift.http.client.ResponseHandler; +import com.facebook.airlift.http.client.ResponseHandlerUtils; +import com.facebook.airlift.http.client.StringResponseHandler; +import com.google.common.collect.ImmutableMultimap; +import com.nimbusds.jose.util.Resource; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.http.HTTPRequest; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; +import jakarta.ws.rs.core.UriBuilder; + +import javax.inject.Inject; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.net.URL; + +import static com.facebook.airlift.http.client.Request.Builder.prepareGet; +import static com.facebook.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; +import static com.facebook.airlift.http.client.StringResponseHandler.createStringResponseHandler; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.DELETE; +import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.GET; +import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.POST; +import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.PUT; +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; + +public class 
NimbusAirliftHttpClient + implements NimbusHttpClient +{ + private final HttpClient httpClient; + + @Inject + public NimbusAirliftHttpClient(@ForOAuth2 HttpClient httpClient) + { + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + } + + @Override + public Resource retrieveResource(URL url) + throws IOException + { + try { + StringResponseHandler.StringResponse response = httpClient.execute( + prepareGet().setUri(url.toURI()).build(), + createStringResponseHandler()); + return new Resource(response.getBody(), response.getHeader(CONTENT_TYPE)); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Override + public T execute(com.nimbusds.oauth2.sdk.Request nimbusRequest, Parser parser) + { + HTTPRequest httpRequest = nimbusRequest.toHTTPRequest(); + HTTPRequest.Method method = httpRequest.getMethod(); + + Request.Builder request = new Request.Builder() + .setMethod(method.name()) + .setFollowRedirects(httpRequest.getFollowRedirects()); + + UriBuilder url = UriBuilder.fromUri(httpRequest.getURI()); + if (method.equals(GET) || method.equals(DELETE)) { + httpRequest.getQueryParameters().forEach((key, value) -> url.queryParam(key, value.toArray())); + } + + url.fragment(httpRequest.getFragment()); + + request.setUri(url.build()); + + ImmutableMultimap.Builder headers = ImmutableMultimap.builder(); + httpRequest.getHeaderMap().forEach(headers::putAll); + request.addHeaders(headers.build()); + + if (method.equals(POST) || method.equals(PUT)) { + String query = httpRequest.getQuery(); + if (query != null) { + request.setBodyGenerator(createStaticBodyGenerator(httpRequest.getQuery(), UTF_8)); + } + } + return httpClient.execute(request.build(), new NimbusResponseHandler<>(parser)); + } + + public static class NimbusResponseHandler + implements ResponseHandler + { + private final StringResponseHandler handler = createStringResponseHandler(); + private final Parser parser; + + public NimbusResponseHandler(Parser parser) + { + 
this.parser = requireNonNull(parser, "parser is null"); + } + + @Override + public T handleException(Request request, Exception exception) + { + throw ResponseHandlerUtils.propagate(request, exception); + } + + @Override + public T handle(Request request, Response response) + { + StringResponseHandler.StringResponse stringResponse = handler.handle(request, response); + HTTPResponse nimbusResponse = new HTTPResponse(response.getStatusCode()); + response.getHeaders().asMap().forEach((name, values) -> nimbusResponse.setHeader(name.toString(), values.toArray(new String[0]))); + nimbusResponse.setContent(stringResponse.getBody()); + try { + return parser.parse(nimbusResponse); + } + catch (ParseException e) { + throw new RuntimeException(format("Unable to parse response status=[%d], body=[%s]", stringResponse.getStatusCode(), stringResponse.getBody()), e); + } + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusHttpClient.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusHttpClient.java new file mode 100644 index 0000000000000..b8c0e491dec4b --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusHttpClient.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.nimbusds.jose.util.ResourceRetriever; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.Request; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; + +public interface NimbusHttpClient + extends ResourceRetriever +{ + T execute(Request request, Parser parser); + + interface Parser + { + T parse(HTTPResponse response) + throws ParseException; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusOAuth2Client.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusOAuth2Client.java new file mode 100644 index 0000000000000..86b2fb96b7d8f --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NimbusOAuth2Client.java @@ -0,0 +1,489 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.Duration; +import com.facebook.presto.server.security.oauth2.OAuth2ServerConfigProvider.OAuth2ServerConfig; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Ordering; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.JWSHeader; +import com.nimbusds.jose.jwk.source.RemoteJWKSet; +import com.nimbusds.jose.proc.BadJOSEException; +import com.nimbusds.jose.proc.JWSKeySelector; +import com.nimbusds.jose.proc.JWSVerificationKeySelector; +import com.nimbusds.jose.proc.SecurityContext; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier; +import com.nimbusds.jwt.proc.DefaultJWTProcessor; +import com.nimbusds.jwt.proc.JWTProcessor; +import com.nimbusds.oauth2.sdk.AccessTokenResponse; +import com.nimbusds.oauth2.sdk.AuthorizationCode; +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; +import com.nimbusds.oauth2.sdk.AuthorizationGrant; +import com.nimbusds.oauth2.sdk.AuthorizationRequest; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.RefreshTokenGrant; +import com.nimbusds.oauth2.sdk.Scope; +import com.nimbusds.oauth2.sdk.TokenRequest; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; +import com.nimbusds.oauth2.sdk.id.ClientID; +import com.nimbusds.oauth2.sdk.id.Issuer; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.oauth2.sdk.token.AccessToken; +import com.nimbusds.oauth2.sdk.token.BearerAccessToken; +import com.nimbusds.oauth2.sdk.token.RefreshToken; +import com.nimbusds.oauth2.sdk.token.Tokens; +import com.nimbusds.openid.connect.sdk.AuthenticationRequest; +import com.nimbusds.openid.connect.sdk.Nonce; +import 
com.nimbusds.openid.connect.sdk.OIDCTokenResponse; +import com.nimbusds.openid.connect.sdk.UserInfoErrorResponse; +import com.nimbusds.openid.connect.sdk.UserInfoRequest; +import com.nimbusds.openid.connect.sdk.UserInfoResponse; +import com.nimbusds.openid.connect.sdk.UserInfoSuccessResponse; +import com.nimbusds.openid.connect.sdk.claims.AccessTokenHash; +import com.nimbusds.openid.connect.sdk.claims.IDTokenClaimsSet; +import com.nimbusds.openid.connect.sdk.token.OIDCTokens; +import com.nimbusds.openid.connect.sdk.validators.AccessTokenValidator; +import com.nimbusds.openid.connect.sdk.validators.IDTokenValidator; +import com.nimbusds.openid.connect.sdk.validators.InvalidHashException; +import net.minidev.json.JSONObject; + +import javax.inject.Inject; + +import java.net.MalformedURLException; +import java.net.URI; +import java.time.Instant; +import java.util.Collections; +import java.util.Date; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.hash.Hashing.sha256; +import static com.nimbusds.oauth2.sdk.ResponseType.CODE; +import static com.nimbusds.openid.connect.sdk.OIDCScopeValue.OPENID; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class NimbusOAuth2Client + implements OAuth2Client +{ + private static final Logger LOG = Logger.get(NimbusAirliftHttpClient.class); + + private final Issuer issuer; + private final ClientID clientId; + private final ClientSecretBasic clientAuth; + private final Scope scope; + private final String 
principalField; + private final Set accessTokenAudiences; + private final Duration maxClockSkew; + private final NimbusHttpClient httpClient; + private final OAuth2ServerConfigProvider serverConfigurationProvider; + private volatile boolean loaded; + private URI authUrl; + private URI tokenUrl; + private Optional userinfoUrl; + private JWSKeySelector jwsKeySelector; + private JWTProcessor accessTokenProcessor; + private AuthorizationCodeFlow flow; + + @Inject + public NimbusOAuth2Client(OAuth2Config oauthConfig, OAuth2ServerConfigProvider serverConfigurationProvider, NimbusHttpClient httpClient) + { + requireNonNull(oauthConfig, "oauthConfig is null"); + issuer = new Issuer(oauthConfig.getIssuer()); + clientId = new ClientID(oauthConfig.getClientId()); + clientAuth = new ClientSecretBasic(clientId, new Secret(oauthConfig.getClientSecret())); + scope = Scope.parse(oauthConfig.getScopes()); + principalField = oauthConfig.getPrincipalField(); + maxClockSkew = oauthConfig.getMaxClockSkew(); + + accessTokenAudiences = new HashSet<>(oauthConfig.getAdditionalAudiences()); + accessTokenAudiences.add(clientId.getValue()); + accessTokenAudiences.add(null); // A null value in the set allows JWTs with no audience + + this.serverConfigurationProvider = requireNonNull(serverConfigurationProvider, "serverConfigurationProvider is null"); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + } + + @Override + public void load() + { + OAuth2ServerConfig config = serverConfigurationProvider.get(); + this.authUrl = config.getAuthUrl(); + this.tokenUrl = config.getTokenUrl(); + this.userinfoUrl = config.getUserinfoUrl(); + try { + jwsKeySelector = new JWSVerificationKeySelector<>( + Stream.concat(JWSAlgorithm.Family.RSA.stream(), JWSAlgorithm.Family.EC.stream()).collect(toImmutableSet()), + new RemoteJWKSet<>(config.getJwksUrl().toURL(), httpClient)); + } + catch (MalformedURLException e) { + throw new RuntimeException(e); + } + + DefaultJWTProcessor processor = new 
DefaultJWTProcessor<>(); + processor.setJWSKeySelector(jwsKeySelector); + DefaultJWTClaimsVerifier accessTokenVerifier = new DefaultJWTClaimsVerifier<>( + accessTokenAudiences, + new JWTClaimsSet.Builder() + .issuer(config.getAccessTokenIssuer().orElse(issuer.getValue())) + .build(), + ImmutableSet.of(principalField), + ImmutableSet.of()); + accessTokenVerifier.setMaxClockSkew((int) maxClockSkew.roundTo(SECONDS)); + processor.setJWTClaimsSetVerifier(accessTokenVerifier); + accessTokenProcessor = processor; + flow = scope.contains(OPENID) ? new OAuth2WithOidcExtensionsCodeFlow() : new OAuth2AuthorizationCodeFlow(); + loaded = true; + } + + @Override + public Request createAuthorizationRequest(String state, URI callbackUri) + { + checkState(loaded, "OAuth2 client not initialized"); + return flow.createAuthorizationRequest(state, callbackUri); + } + + @Override + public Response getOAuth2Response(String code, URI callbackUri, Optional nonce) + throws ChallengeFailedException + { + checkState(loaded, "OAuth2 client not initialized"); + return flow.getOAuth2Response(code, callbackUri, nonce); + } + + @Override + public Optional> getClaims(String accessToken) + { + checkState(loaded, "OAuth2 client not initialized"); + return getJWTClaimsSet(accessToken).map(JWTClaimsSet::getClaims); + } + + @Override + public Response refreshTokens(String refreshToken) + throws ChallengeFailedException + { + checkState(loaded, "OAuth2 client not initialized"); + return flow.refreshTokens(refreshToken); + } + + private interface AuthorizationCodeFlow + { + Request createAuthorizationRequest(String state, URI callbackUri); + + Response getOAuth2Response(String code, URI callbackUri, Optional nonce) + throws ChallengeFailedException; + + Response refreshTokens(String refreshToken) + throws ChallengeFailedException; + } + + private class OAuth2AuthorizationCodeFlow + implements AuthorizationCodeFlow + { + @Override + public Request createAuthorizationRequest(String state, URI callbackUri) + 
{ + return new Request( + new AuthorizationRequest.Builder(CODE, clientId) + .redirectionURI(callbackUri) + .scope(scope) + .endpointURI(authUrl) + .state(new State(state)) + .build() + .toURI(), + Optional.empty()); + } + + @Override + public Response getOAuth2Response(String code, URI callbackUri, Optional nonce) + throws ChallengeFailedException + { + checkArgument(!nonce.isPresent(), "Unexpected nonce provided"); + AccessTokenResponse tokenResponse = getTokenResponse(code, callbackUri, AccessTokenResponse::parse); + Tokens tokens = tokenResponse.toSuccessResponse().getTokens(); + return toResponse(tokens, Optional.empty()); + } + + @Override + public Response refreshTokens(String refreshToken) + throws ChallengeFailedException + { + requireNonNull(refreshToken, "refreshToken is null"); + AccessTokenResponse tokenResponse = getTokenResponse(refreshToken, AccessTokenResponse::parse); + return toResponse(tokenResponse.toSuccessResponse().getTokens(), Optional.of(refreshToken)); + } + + private Response toResponse(Tokens tokens, Optional existingRefreshToken) + throws ChallengeFailedException + { + AccessToken accessToken = tokens.getAccessToken(); + RefreshToken refreshToken = tokens.getRefreshToken(); + JWTClaimsSet claims = getJWTClaimsSet(accessToken.getValue()).orElseThrow(() -> new ChallengeFailedException("invalid access token")); + return new Response( + accessToken.getValue(), + determineExpiration(getExpiration(accessToken), claims.getExpirationTime()), + buildRefreshToken(refreshToken, existingRefreshToken)); + } + } + + private class OAuth2WithOidcExtensionsCodeFlow + implements AuthorizationCodeFlow + { + private final IDTokenValidator idTokenValidator; + + public OAuth2WithOidcExtensionsCodeFlow() + { + idTokenValidator = new IDTokenValidator(issuer, clientId, jwsKeySelector, null); + idTokenValidator.setMaxClockSkew((int) maxClockSkew.roundTo(SECONDS)); + } + + @Override + public Request createAuthorizationRequest(String state, URI callbackUri) + { + 
String nonce = new Nonce().getValue(); + return new Request( + new AuthenticationRequest.Builder(CODE, scope, clientId, callbackUri) + .endpointURI(authUrl) + .state(new State(state)) + .nonce(new Nonce(hashNonce(nonce))) + .build() + .toURI(), + Optional.of(nonce)); + } + + @Override + public Response getOAuth2Response(String code, URI callbackUri, Optional nonce) + throws ChallengeFailedException + { + if (!nonce.isPresent()) { + throw new ChallengeFailedException("Missing nonce"); + } + + OIDCTokenResponse tokenResponse = getTokenResponse(code, callbackUri, OIDCTokenResponse::parse); + OIDCTokens tokens = tokenResponse.getOIDCTokens(); + validateTokens(tokens, nonce); + return toResponse(tokens, Optional.empty()); + } + + @Override + public Response refreshTokens(String refreshToken) + throws ChallengeFailedException + { + OIDCTokenResponse tokenResponse = getTokenResponse(refreshToken, OIDCTokenResponse::parse); + OIDCTokens tokens = tokenResponse.getOIDCTokens(); + validateTokens(tokens); + return toResponse(tokens, Optional.of(refreshToken)); + } + + private Response toResponse(OIDCTokens tokens, Optional existingRefreshToken) + throws ChallengeFailedException + { + AccessToken accessToken = tokens.getAccessToken(); + RefreshToken refreshToken = tokens.getRefreshToken(); + JWTClaimsSet claims = getJWTClaimsSet(accessToken.getValue()).orElseThrow(() -> new ChallengeFailedException("invalid access token")); + return new Response( + accessToken.getValue(), + determineExpiration(getExpiration(accessToken), claims.getExpirationTime()), + buildRefreshToken(refreshToken, existingRefreshToken)); + } + + private void validateTokens(OIDCTokens tokens, Optional nonce) + throws ChallengeFailedException + { + try { + IDTokenClaimsSet idToken = idTokenValidator.validate( + tokens.getIDToken(), + nonce.map(this::hashNonce) + .map(Nonce::new) + .orElse(null)); + AccessTokenHash accessTokenHash = idToken.getAccessTokenHash(); + if (accessTokenHash != null) { + 
AccessTokenValidator.validate(tokens.getAccessToken(), ((JWSHeader) tokens.getIDToken().getHeader()).getAlgorithm(), accessTokenHash); + } + } + catch (BadJOSEException | JOSEException | InvalidHashException e) { + throw new ChallengeFailedException("Cannot validate tokens", e); + } + } + + private void validateTokens(OIDCTokens tokens) + throws ChallengeFailedException + { + validateTokens(tokens, Optional.empty()); + } + + private String hashNonce(String nonce) + { + return sha256() + .hashString(nonce, UTF_8) + .toString(); + } + } + + private T getTokenResponse(String code, URI callbackUri, NimbusAirliftHttpClient.Parser parser) + throws ChallengeFailedException + { + return getTokenResponse(new AuthorizationCodeGrant(new AuthorizationCode(code), callbackUri), parser); + } + + private T getTokenResponse(String refreshToken, NimbusAirliftHttpClient.Parser parser) + throws ChallengeFailedException + { + return getTokenResponse(new RefreshTokenGrant(new RefreshToken(refreshToken)), parser); + } + + private T getTokenResponse(AuthorizationGrant authorizationGrant, NimbusAirliftHttpClient.Parser parser) + throws ChallengeFailedException + { + T tokenResponse = httpClient.execute(new TokenRequest(tokenUrl, clientAuth, authorizationGrant, scope), parser); + if (!tokenResponse.indicatesSuccess()) { + throw new ChallengeFailedException("Error while fetching access token: " + tokenResponse.toErrorResponse().toJSONObject()); + } + return tokenResponse; + } + + private Optional getJWTClaimsSet(String accessToken) + { + if (userinfoUrl.isPresent()) { + return queryUserInfo(accessToken); + } + return parseAccessToken(accessToken); + } + + private Optional queryUserInfo(String accessToken) + { + try { + UserInfoResponse response = httpClient.execute(new UserInfoRequest(userinfoUrl.get(), new BearerAccessToken(accessToken)), this::parse); + if (!response.indicatesSuccess()) { + LOG.error("Received bad response from userinfo endpoint: " + 
response.toErrorResponse().getErrorObject()); + return Optional.empty(); + } + return Optional.of(response.toSuccessResponse().getUserInfo().toJWTClaimsSet()); + } + catch (ParseException | RuntimeException e) { + LOG.error(e, "Received bad response from userinfo endpoint"); + return Optional.empty(); + } + } + + // Using this parsing method for our /userinfo response from the IdP in order to allow for different principal + // fields as defined, and in the absence of the `sub` claim. This is a "hack" solution to alter the claims + // present in the response before calling the parser provided by the oidc sdk, which fails hard if the + // `sub` claim is missing. Note we also have to offload audience verification to this method since it + // is not handled in the library + public UserInfoResponse parse(HTTPResponse httpResponse) + throws ParseException + { + JSONObject body = httpResponse.getContentAsJSONObject(); + + String principal = (String) body.get(principalField); + if (principal == null) { + throw new ParseException(String.format("/userinfo response missing principal field %s", principalField)); + } + + if (!principalField.equals("sub") && body.get("sub") == null) { + body.put("sub", principal); + httpResponse.setBody(body.toJSONString()); + } + + Object audClaim = body.get("aud"); + List audiences; + + if (audClaim instanceof String) { + audiences = List.of((String) audClaim); + } + else if (audClaim instanceof List) { + audiences = ((List) audClaim).stream() + .filter(String.class::isInstance) + .map(String.class::cast) + .collect(toImmutableList()); + } + else { + throw new ParseException("Unsupported or missing 'aud' claim type in /userinfo response"); + } + + if (!(audiences.contains(clientId.getValue()) || !Collections.disjoint(audiences, accessTokenAudiences))) { + throw new ParseException("Invalid audience in /userinfo response"); + } + + return (httpResponse.getStatusCode() == 200) + ? 
UserInfoSuccessResponse.parse(httpResponse) + : UserInfoErrorResponse.parse(httpResponse); + } + + private Optional parseAccessToken(String accessToken) + { + try { + return Optional.of(accessTokenProcessor.process(accessToken, null)); + } + catch (java.text.ParseException | BadJOSEException | JOSEException e) { + LOG.error(e, "Failed to parse JWT access token"); + return Optional.empty(); + } + } + + private static Instant determineExpiration(Optional validUntil, Date expiration) + throws ChallengeFailedException + { + if (validUntil.isPresent()) { + if (expiration != null) { + return Ordering.natural().min(validUntil.get(), expiration.toInstant()); + } + + return validUntil.get(); + } + + if (expiration != null) { + return expiration.toInstant(); + } + + throw new ChallengeFailedException("no valid expiration date"); + } + + private Optional buildRefreshToken(RefreshToken refreshToken, Optional existingRefreshToken) + { + Optional firstOption = Optional.ofNullable(refreshToken) + .map(RefreshToken::getValue); + + if (firstOption.isPresent()) { + return firstOption; + } + else if (existingRefreshToken.isPresent()) { + return existingRefreshToken; + } + else { + return Optional.empty(); + } + } + + private static Optional getExpiration(AccessToken accessToken) + { + return accessToken.getLifetime() != 0 ? Optional.of(Instant.now().plusSeconds(accessToken.getLifetime())) : Optional.empty(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NonceCookie.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NonceCookie.java new file mode 100644 index 0000000000000..d80e2cacc6941 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/NonceCookie.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import jakarta.ws.rs.core.Cookie; +import jakarta.ws.rs.core.NewCookie; +import org.apache.commons.lang3.StringUtils; + +import java.time.Instant; +import java.util.Date; +import java.util.Optional; + +import static com.facebook.presto.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT; +import static com.google.common.base.Predicates.not; +import static jakarta.ws.rs.core.Cookie.DEFAULT_VERSION; +import static jakarta.ws.rs.core.NewCookie.DEFAULT_MAX_AGE; + +public final class NonceCookie +{ + // prefix according to: https://tools.ietf.org/html/draft-ietf-httpbis-rfc6265bis-05#section-4.1.3.1 + public static final String NONCE_COOKIE = "__Secure-Presto-Nonce"; + + private NonceCookie() {} + + public static NewCookie create(String nonce, Instant tokenExpiration) + { + return new NewCookie( + NONCE_COOKIE, + nonce, + CALLBACK_ENDPOINT, + null, + DEFAULT_VERSION, + null, + DEFAULT_MAX_AGE, + Date.from(tokenExpiration), + true, + true); + } + + public static jakarta.servlet.http.Cookie createServletCookie(String nonce, Instant tokenExpiration) + { + return toServletCookie(create(nonce, tokenExpiration)); + } + + public static jakarta.servlet.http.Cookie toServletCookie(NewCookie cookie) + { + jakarta.servlet.http.Cookie servletCookie = new jakarta.servlet.http.Cookie(cookie.getName(), cookie.getValue()); + servletCookie.setPath(cookie.getPath()); + servletCookie.setMaxAge(cookie.getMaxAge()); + servletCookie.setSecure(cookie.isSecure()); + 
servletCookie.setHttpOnly(cookie.isHttpOnly()); + + return servletCookie; + } + + public static Optional read(Cookie cookie) + { + return Optional.ofNullable(cookie) + .map(Cookie::getValue) + .filter(not(StringUtils::isBlank)); + } + + public static NewCookie delete() + { + return new NewCookie( + NONCE_COOKIE, + "delete", + CALLBACK_ENDPOINT, + null, + DEFAULT_VERSION, + null, + 0, + null, + true, + true); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2AuthenticationSupportModule.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2AuthenticationSupportModule.java new file mode 100644 index 0000000000000..88cd2b5e0f23d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2AuthenticationSupportModule.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; +import com.google.inject.Binder; +import com.google.inject.Scopes; + +import static com.facebook.airlift.jaxrs.JaxrsBinder.jaxrsBinder; + +public class OAuth2AuthenticationSupportModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + binder.bind(OAuth2TokenExchange.class).in(Scopes.SINGLETON); + binder.bind(OAuth2TokenHandler.class).to(OAuth2TokenExchange.class).in(Scopes.SINGLETON); + jaxrsBinder(binder).bind(OAuth2TokenExchangeResource.class); + install(new OAuth2ServiceModule()); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Authenticator.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Authenticator.java new file mode 100644 index 0000000000000..83c2de302c287 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Authenticator.java @@ -0,0 +1,146 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package com.facebook.presto.server.security.oauth2;

import com.facebook.airlift.http.server.AuthenticationException;
import com.facebook.airlift.http.server.Authenticator;
import com.facebook.airlift.http.server.BasicPrincipal;
import com.facebook.airlift.log.Logger;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.lang3.StringUtils;

import javax.inject.Inject;

import java.net.URI;
import java.security.Principal;
import java.time.Instant;
import java.util.Arrays;
import java.util.Date;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

import static com.facebook.presto.server.security.oauth2.OAuth2TokenExchangeResource.getInitiateUri;
import static com.facebook.presto.server.security.oauth2.OAuth2TokenExchangeResource.getTokenUri;
import static com.facebook.presto.server.security.oauth2.OAuth2Utils.getSchemeUriBuilder;
import static com.facebook.presto.server.security.oauth2.OAuthWebUiCookie.OAUTH2_COOKIE;
import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * {@link Authenticator} that authenticates HTTP requests with an OAuth2 token
 * taken from either the {@code Authorization: Bearer ...} header or the OAuth2
 * web UI cookie. When no valid token is present, the thrown
 * {@link AuthenticationException} carries the challenge URIs that let the
 * client (re)start the OAuth2 flow.
 */
public class OAuth2Authenticator
        implements Authenticator
{
    private static final Logger LOG = Logger.get(OAuth2Authenticator.class);

    // Name of the claim used as the authenticated principal (e.g. "sub")
    private final String principalField;
    private final OAuth2Client client;
    private final TokenPairSerializer tokenPairSerializer;
    private final TokenRefresher tokenRefresher;

    @Inject
    public OAuth2Authenticator(OAuth2Client client, OAuth2Config config, TokenRefresher tokenRefresher, TokenPairSerializer tokenPairSerializer)
    {
        this.client = requireNonNull(client, "client is null");
        // Check config before dereferencing it so a null yields a clear message
        requireNonNull(config, "config is null");
        this.principalField = config.getPrincipalField();
        this.tokenRefresher = requireNonNull(tokenRefresher, "tokenRefresher is null");
        this.tokenPairSerializer = requireNonNull(tokenPairSerializer, "tokenPairSerializer is null");
    }

    /**
     * Authenticates the request and returns the principal extracted from the
     * token's claims.
     *
     * @throws AuthenticationException when the token is missing, malformed,
     *         expired, or its claims cannot be obtained or lack the principal
     */
    public Principal authenticate(HttpServletRequest request)
            throws AuthenticationException
    {
        String token = extractToken(request);
        TokenPairSerializer.TokenPair tokenPair;
        try {
            tokenPair = tokenPairSerializer.deserialize(token);
        }
        catch (IllegalArgumentException e) {
            LOG.error(e, "Failed to deserialize the OAuth token");
            throw needAuthentication(request, Optional.empty(), "Invalid Credentials");
        }

        if (tokenPair.getExpiration().before(Date.from(Instant.now()))) {
            // Pass the expired token along so a refresh can be attempted
            throw needAuthentication(request, Optional.of(token), "Invalid Credentials");
        }

        Optional<Map<String, Object>> claims = client.getClaims(tokenPair.getAccessToken());
        if (!claims.isPresent()) {
            throw needAuthentication(request, Optional.ofNullable(token), "Invalid Credentials");
        }

        String principal = (String) claims.get().get(principalField);
        if (StringUtils.isEmpty(principal)) {
            LOG.warn("The subject is not present we need to authenticate");
            // Bug fix: the exception was previously constructed but never thrown,
            // letting a request with a missing principal claim fall through to
            // BasicPrincipal with an empty/null principal.
            throw needAuthentication(request, Optional.empty(), "Invalid Credentials");
        }

        return new BasicPrincipal(principal);
    }

    /**
     * Returns the serialized token, preferring the web UI cookie over the
     * Authorization header.
     *
     * @throws AuthenticationException when neither source carries a token
     */
    public String extractToken(HttpServletRequest request)
            throws AuthenticationException
    {
        Optional<String> cookieToken = extractTokenFromCookie(request);
        Optional<String> headerToken = extractTokenFromHeader(request);

        if (!cookieToken.isPresent() && !headerToken.isPresent()) {
            throw needAuthentication(request, Optional.empty(), "Invalid Credentials");
        }

        return cookieToken.orElseGet(headerToken::get);
    }

    /**
     * Extracts a non-empty token from an {@code Authorization: Bearer ...}
     * header, if present.
     */
    public Optional<String> extractTokenFromHeader(HttpServletRequest request)
    {
        String authHeader = nullToEmpty(request.getHeader(AUTHORIZATION));
        int space = authHeader.indexOf(' ');
        if ((space < 0) || !authHeader.substring(0, space).equalsIgnoreCase("bearer")) {
            return Optional.empty();
        }

        return Optional.of(authHeader.substring(space + 1).trim())
                .filter(token -> !token.isEmpty());
    }

    /**
     * Extracts the token stored in the OAuth2 web UI cookie, if present.
     */
    public static Optional<String> extractTokenFromCookie(HttpServletRequest request)
    {
        // getCookies() returns null when the request carries no cookies
        Cookie[] cookies = Optional.ofNullable(request.getCookies()).orElse(new Cookie[0]);
        return Arrays.stream(cookies)
                .filter(cookie -> OAUTH2_COOKIE.equals(cookie.getName()))
                .findFirst()
                .map(Cookie::getValue);
    }

    /**
     * Builds the challenge exception. If the current token can be refreshed,
     * the challenge only carries the token-exchange URI; otherwise a full
     * redirect + token-exchange challenge is issued.
     */
    private AuthenticationException needAuthentication(HttpServletRequest request, Optional<String> currentToken, String message)
    {
        URI baseUri = getSchemeUriBuilder(request).build();
        return currentToken
                .map(tokenPairSerializer::deserialize)
                .flatMap(tokenRefresher::refreshToken)
                .map(refreshId -> baseUri.resolve(getTokenUri(refreshId)))
                .map(tokenUri -> new AuthenticationException(message, format("Bearer x_token_server=\"%s\"", tokenUri)))
                .orElseGet(() -> buildNeedAuthentication(request, message));
    }

    // Builds a fresh challenge with a new auth id for both redirect and token endpoints
    private AuthenticationException buildNeedAuthentication(HttpServletRequest request, String message)
    {
        UUID authId = UUID.randomUUID();
        URI baseUri = getSchemeUriBuilder(request).build();
        URI initiateUri = baseUri.resolve(getInitiateUri(authId));
        URI tokenUri = baseUri.resolve(getTokenUri(authId));

        return new AuthenticationException(message, format("Bearer x_redirect_server=\"%s\", x_token_server=\"%s\"", initiateUri, tokenUri));
    }
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.log.Logger; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.CookieParam; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Cookie; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriBuilder; + +import javax.inject.Inject; + +import static com.facebook.presto.server.security.oauth2.NonceCookie.NONCE_COOKIE; +import static com.facebook.presto.server.security.oauth2.OAuth2Utils.getSchemeUriBuilder; +import static jakarta.ws.rs.core.MediaType.TEXT_HTML; +import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST; +import static java.util.Objects.requireNonNull; + +@Path(OAuth2CallbackResource.CALLBACK_ENDPOINT) +public class OAuth2CallbackResource +{ + private static final Logger LOG = Logger.get(OAuth2CallbackResource.class); + + public static final String CALLBACK_ENDPOINT = "/oauth2/callback"; + + private final OAuth2Service service; + + @Inject + public OAuth2CallbackResource(OAuth2Service service) + { + this.service = requireNonNull(service, "service is null"); + } + + @GET + @Produces(TEXT_HTML) + public Response callback( + @QueryParam("state") String state, + @QueryParam("code") String code, + @QueryParam("error") String error, + @QueryParam("error_description") String errorDescription, + @QueryParam("error_uri") String errorUri, + @CookieParam(NONCE_COOKIE) Cookie nonce, + @Context 
HttpServletRequest request) + { + if (error != null) { + return service.handleOAuth2Error(state, error, errorDescription, errorUri); + } + + try { + requireNonNull(state, "state is null"); + requireNonNull(code, "code is null"); + UriBuilder builder = getSchemeUriBuilder(request); + return service.finishOAuth2Challenge(state, code, builder.build().resolve(CALLBACK_ENDPOINT), NonceCookie.read(nonce), request); + } + catch (RuntimeException e) { + LOG.error(e, "Authentication response could not be verified: state=%s", state); + return Response.status(BAD_REQUEST) + .cookie(NonceCookie.delete()) + .entity(service.getInternalFailureHtml("Authentication response could not be verified")) + .build(); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Client.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Client.java new file mode 100644 index 0000000000000..fd68a2a5a2e06 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Client.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import java.net.URI; +import java.time.Instant; +import java.util.Map; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public interface OAuth2Client +{ + void load(); + + Request createAuthorizationRequest(String state, URI callbackUri); + + Response getOAuth2Response(String code, URI callbackUri, Optional nonce) + throws ChallengeFailedException; + + Optional> getClaims(String accessToken); + + Response refreshTokens(String refreshToken) + throws ChallengeFailedException; + + class Request + { + private final URI authorizationUri; + private final Optional nonce; + + public Request(URI authorizationUri, Optional nonce) + { + this.authorizationUri = requireNonNull(authorizationUri, "authorizationUri is null"); + this.nonce = requireNonNull(nonce, "nonce is null"); + } + + public URI getAuthorizationUri() + { + return authorizationUri; + } + + public Optional getNonce() + { + return nonce; + } + } + + class Response + { + private final String accessToken; + private final Instant expiration; + + private final Optional refreshToken; + + public Response(String accessToken, Instant expiration, Optional refreshToken) + { + this.accessToken = requireNonNull(accessToken, "accessToken is null"); + this.expiration = requireNonNull(expiration, "expiration is null"); + this.refreshToken = requireNonNull(refreshToken, "refreshToken is null"); + } + + public String getAccessToken() + { + return accessToken; + } + + public Instant getExpiration() + { + return expiration; + } + + public Optional getRefreshToken() + { + return refreshToken; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Config.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Config.java new file mode 100644 index 0000000000000..b1b0f9513b2f7 --- /dev/null +++ 
package com.facebook.presto.server.security.oauth2;

import com.facebook.airlift.configuration.Config;
import com.facebook.airlift.configuration.ConfigDescription;
import com.facebook.airlift.configuration.ConfigSecuritySensitive;
import com.facebook.airlift.units.Duration;
import com.facebook.airlift.units.MinDuration;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import jakarta.validation.constraints.NotNull;

import java.io.File;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import static com.facebook.presto.server.security.oauth2.OAuth2Service.OPENID_SCOPE;
import static com.google.common.base.Strings.nullToEmpty;

/**
 * Configuration for OAuth2 authentication
 * ({@code http-server.authentication.oauth2.*} properties).
 */
public class OAuth2Config
{
    private Optional<String> stateKey = Optional.empty();
    private String issuer;
    private String clientId;
    private String clientSecret;
    // Default scope requests only OpenID Connect authentication
    private Set<String> scopes = ImmutableSet.of(OPENID_SCOPE);
    private String principalField = "sub";
    private Optional<String> groupsField = Optional.empty();
    private List<String> additionalAudiences = Collections.emptyList();
    private Duration challengeTimeout = new Duration(15, TimeUnit.MINUTES);
    private Duration maxClockSkew = new Duration(1, TimeUnit.MINUTES);
    private Optional<String> userMappingPattern = Optional.empty();
    private Optional<File> userMappingFile = Optional.empty();
    private boolean enableRefreshTokens;
    private boolean enableDiscovery = true;

    public Optional<String> getStateKey()
    {
        return stateKey;
    }

    @Config("http-server.authentication.oauth2.state-key")
    @ConfigDescription("A secret key used by HMAC algorithm to sign the state parameter")
    public OAuth2Config setStateKey(String stateKey)
    {
        this.stateKey = Optional.ofNullable(stateKey);
        return this;
    }

    @NotNull
    public String getIssuer()
    {
        return issuer;
    }

    @Config("http-server.authentication.oauth2.issuer")
    @ConfigDescription("The required issuer of a token")
    public OAuth2Config setIssuer(String issuer)
    {
        this.issuer = issuer;
        return this;
    }

    @NotNull
    public String getClientId()
    {
        return clientId;
    }

    @Config("http-server.authentication.oauth2.client-id")
    @ConfigDescription("Client ID")
    public OAuth2Config setClientId(String clientId)
    {
        this.clientId = clientId;
        return this;
    }

    @NotNull
    public String getClientSecret()
    {
        return clientSecret;
    }

    @Config("http-server.authentication.oauth2.client-secret")
    @ConfigSecuritySensitive
    @ConfigDescription("Client secret")
    public OAuth2Config setClientSecret(String clientSecret)
    {
        this.clientSecret = clientSecret;
        return this;
    }

    @NotNull
    public List<String> getAdditionalAudiences()
    {
        return additionalAudiences;
    }

    // Programmatic setter; the @Config-annotated String overload below is the
    // one used by the configuration framework
    public OAuth2Config setAdditionalAudiences(List<String> additionalAudiences)
    {
        this.additionalAudiences = ImmutableList.copyOf(additionalAudiences);
        return this;
    }

    @Config("http-server.authentication.oauth2.additional-audiences")
    @ConfigDescription("Additional audiences to trust in addition to the Client ID")
    public OAuth2Config setAdditionalAudiences(String additionalAudiences)
    {
        Splitter splitter = Splitter.on(',').trimResults().omitEmptyStrings();
        this.additionalAudiences = ImmutableList.copyOf(splitter.split(nullToEmpty(additionalAudiences)));
        return this;
    }

    @NotNull
    public Set<String> getScopes()
    {
        return scopes;
    }

    @Config("http-server.authentication.oauth2.scopes")
    @ConfigDescription("Scopes requested by the server during OAuth2 authorization challenge")
    public OAuth2Config setScopes(String scopes)
    {
        Splitter splitter = Splitter.on(',').trimResults().omitEmptyStrings();
        this.scopes = ImmutableSet.copyOf(splitter.split(nullToEmpty(scopes)));
        return this;
    }

    @NotNull
    public String getPrincipalField()
    {
        return principalField;
    }

    @Config("http-server.authentication.oauth2.principal-field")
    @ConfigDescription("The claim to use as the principal")
    public OAuth2Config setPrincipalField(String principalField)
    {
        this.principalField = principalField;
        return this;
    }

    public Optional<String> getGroupsField()
    {
        return groupsField;
    }

    @Config("http-server.authentication.oauth2.groups-field")
    @ConfigDescription("Groups field in the claim")
    public OAuth2Config setGroupsField(String groupsField)
    {
        this.groupsField = Optional.ofNullable(groupsField);
        return this;
    }

    @MinDuration("1ms")
    @NotNull
    public Duration getChallengeTimeout()
    {
        return challengeTimeout;
    }

    @Config("http-server.authentication.oauth2.challenge-timeout")
    @ConfigDescription("Maximum duration of OAuth2 authorization challenge")
    public OAuth2Config setChallengeTimeout(Duration challengeTimeout)
    {
        this.challengeTimeout = challengeTimeout;
        return this;
    }

    @MinDuration("0s")
    @NotNull
    public Duration getMaxClockSkew()
    {
        return maxClockSkew;
    }

    @Config("http-server.authentication.oauth2.max-clock-skew")
    @ConfigDescription("Max clock skew between the Authorization Server and the coordinator")
    public OAuth2Config setMaxClockSkew(Duration maxClockSkew)
    {
        this.maxClockSkew = maxClockSkew;
        return this;
    }

    public Optional<String> getUserMappingPattern()
    {
        return userMappingPattern;
    }

    @Config("http-server.authentication.oauth2.user-mapping.pattern")
    @ConfigDescription("Regex to match against user name")
    public OAuth2Config setUserMappingPattern(String userMappingPattern)
    {
        this.userMappingPattern = Optional.ofNullable(userMappingPattern);
        return this;
    }

    public Optional<File> getUserMappingFile()
    {
        return userMappingFile;
    }

    @Config("http-server.authentication.oauth2.user-mapping.file")
    @ConfigDescription("File containing rules for mapping user")
    public OAuth2Config setUserMappingFile(File userMappingFile)
    {
        this.userMappingFile = Optional.ofNullable(userMappingFile);
        return this;
    }

    public boolean isEnableRefreshTokens()
    {
        return enableRefreshTokens;
    }

    @Config("http-server.authentication.oauth2.refresh-tokens")
    @ConfigDescription("Enables OpenID refresh tokens usage")
    public OAuth2Config setEnableRefreshTokens(boolean enableRefreshTokens)
    {
        this.enableRefreshTokens = enableRefreshTokens;
        return this;
    }

    public boolean isEnableDiscovery()
    {
        return enableDiscovery;
    }

    @Config("http-server.authentication.oauth2.oidc.discovery")
    @ConfigDescription("Enable OpenID Provider Issuer discovery")
    public OAuth2Config setEnableDiscovery(boolean enableDiscovery)
    {
        this.enableDiscovery = enableDiscovery;
        return this;
    }
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import java.util.Arrays; + +public enum OAuth2ErrorCode +{ + ACCESS_DENIED("access_denied", "OAuth2 server denied the login"), + UNAUTHORIZED_CLIENT("unauthorized_client", "OAuth2 server does not allow request from this Presto server"), + SERVER_ERROR("server_error", "OAuth2 server had a failure"), + TEMPORARILY_UNAVAILABLE("temporarily_unavailable", "OAuth2 server is temporarily unavailable"); + + private final String code; + private final String message; + + OAuth2ErrorCode(String code, String message) + { + this.code = code; + this.message = message; + } + + public static OAuth2ErrorCode fromString(String codeStr) + { + return Arrays.stream(OAuth2ErrorCode.values()) + .filter(value -> codeStr.equalsIgnoreCase(value.code)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("No enum constant " + OAuth2ErrorCode.class.getCanonicalName() + "." + codeStr)); + } + + public String getMessage() + { + return this.message; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2ServerConfigProvider.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2ServerConfigProvider.java new file mode 100644 index 0000000000000..9455e5622605d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2ServerConfigProvider.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import java.net.URI; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public interface OAuth2ServerConfigProvider +{ + OAuth2ServerConfig get(); + + class OAuth2ServerConfig + { + private final Optional accessTokenIssuer; + private final URI authUrl; + private final URI tokenUrl; + private final URI jwksUrl; + private final Optional userinfoUrl; + + public OAuth2ServerConfig(Optional accessTokenIssuer, URI authUrl, URI tokenUrl, URI jwksUrl, Optional userinfoUrl) + { + this.accessTokenIssuer = requireNonNull(accessTokenIssuer, "accessTokenIssuer is null"); + this.authUrl = requireNonNull(authUrl, "authUrl is null"); + this.tokenUrl = requireNonNull(tokenUrl, "tokenUrl is null"); + this.jwksUrl = requireNonNull(jwksUrl, "jwksUrl is null"); + this.userinfoUrl = requireNonNull(userinfoUrl, "userinfoUrl is null"); + } + + public Optional getAccessTokenIssuer() + { + return accessTokenIssuer; + } + + public URI getAuthUrl() + { + return authUrl; + } + + public URI getTokenUrl() + { + return tokenUrl; + } + + public URI getJwksUrl() + { + return jwksUrl; + } + + public Optional getUserinfoUrl() + { + return userinfoUrl; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Service.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Service.java new file mode 100644 index 0000000000000..78616a7d8cad5 --- /dev/null +++ 
package com.facebook.presto.server.security.oauth2;

import com.facebook.airlift.log.Logger;
import com.google.common.io.Resources;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtParser;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.UriBuilder;

import javax.inject.Inject;

import java.io.IOException;
import java.net.URI;
import java.security.Key;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Date;
import java.util.Optional;
import java.util.Random;

import static com.facebook.presto.server.security.oauth2.JwtUtil.newJwtBuilder;
import static com.facebook.presto.server.security.oauth2.JwtUtil.newJwtParserBuilder;
import static com.facebook.presto.server.security.oauth2.OAuth2Utils.getSchemeUriBuilder;
import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.fromOAuth2Response;
import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.base.Verify.verify;
import static com.google.common.hash.Hashing.sha256;
import static io.jsonwebtoken.security.Keys.hmacShaKeyFor;
import static jakarta.ws.rs.core.Response.Status.BAD_REQUEST;
import static jakarta.ws.rs.core.Response.Status.FORBIDDEN;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.time.Instant.now;
import static java.util.Objects.requireNonNull;

/**
 * Orchestrates the OAuth2 authorization-code flow: starts challenges, handles
 * the callback (success and error), and renders the success/failure pages.
 * The {@code state} parameter is a short-lived HMAC-signed JWT so the callback
 * can be validated without server-side session storage.
 */
public class OAuth2Service
{
    private static final Logger logger = Logger.get(OAuth2Service.class);

    public static final String OPENID_SCOPE = "openid";

    private static final String STATE_AUDIENCE_UI = "presto_oauth_ui";
    // NOTE(review): reconstructed placeholder. The value had been garbled to an
    // empty string, which makes the verify() in the constructor trivially true
    // and failureHtml.replace("", msg) interleave msg between every character.
    // "<!-- ERROR_MESSAGE -->" is the marker conventionally embedded in
    // failure.html — confirm against webapp/oauth2/failure.html.
    private static final String FAILURE_REPLACEMENT_TEXT = "<!-- ERROR_MESSAGE -->";
    private static final Random SECURE_RANDOM = new SecureRandom();
    public static final String HANDLER_STATE_CLAIM = "handler_state";

    private final OAuth2Client client;
    // When present, overrides the IdP-supplied expiration for the web UI cookie
    private final Optional<Duration> tokenExpiration;
    private final TokenPairSerializer tokenPairSerializer;

    private final String successHtml;
    private final String failureHtml;

    private final TemporalAmount challengeTimeout;
    private final Key stateHmac;
    private final JwtParser jwtParser;

    private final OAuth2TokenHandler tokenHandler;

    @Inject
    public OAuth2Service(
            OAuth2Client client,
            OAuth2Config oauth2Config,
            OAuth2TokenHandler tokenHandler,
            TokenPairSerializer tokenPairSerializer,
            @ForRefreshTokens Optional<Duration> tokenExpiration)
            throws IOException
    {
        this.client = requireNonNull(client, "client is null");
        requireNonNull(oauth2Config, "oauth2Config is null");

        this.successHtml = Resources.toString(Resources.getResource(getClass(), "/webapp/oauth2/success.html"), UTF_8);
        this.failureHtml = Resources.toString(Resources.getResource(getClass(), "/webapp/oauth2/failure.html"), UTF_8);
        verify(failureHtml.contains(FAILURE_REPLACEMENT_TEXT), "failure.html does not contain the replacement text");

        this.challengeTimeout = Duration.ofMillis(oauth2Config.getChallengeTimeout().toMillis());
        // State-signing key: derived from the configured secret, or random per
        // server start (which invalidates in-flight challenges on restart)
        this.stateHmac = hmacShaKeyFor(oauth2Config.getStateKey()
                .map(key -> sha256().hashString(key, UTF_8).asBytes())
                .orElseGet(() -> secureRandomBytes(32)));
        this.jwtParser = newJwtParserBuilder()
                .setSigningKey(stateHmac)
                .requireAudience(STATE_AUDIENCE_UI)
                .build();

        this.tokenHandler = requireNonNull(tokenHandler, "tokenHandler is null");
        this.tokenPairSerializer = requireNonNull(tokenPairSerializer, "tokenPairSerializer is null");

        this.tokenExpiration = requireNonNull(tokenExpiration, "tokenExpiration is null");
    }

    /**
     * Starts an OAuth2 challenge and returns a JAX-RS 303 redirect to the
     * authorization server, setting the nonce cookie when one was generated.
     */
    public Response startOAuth2Challenge(URI callbackUri, Optional<String> handlerState)
    {
        Instant challengeExpiration = now().plus(challengeTimeout);
        // Deduplicated: delegate state creation to startChallenge instead of
        // repeating the JWT-building code here
        OAuth2Client.Request request = startChallenge(callbackUri, handlerState);
        Response.ResponseBuilder response = Response.seeOther(request.getAuthorizationUri());
        request.getNonce().ifPresent(nonce -> response.cookie(NonceCookie.create(nonce, challengeExpiration)));
        return response.build();
    }

    /**
     * Servlet-level variant of {@link #startOAuth2Challenge(URI, Optional)}:
     * writes the redirect and nonce cookie directly to the servlet response.
     */
    public void startOAuth2Challenge(URI callbackUri, Optional<String> handlerState, HttpServletResponse servletResponse)
            throws IOException
    {
        Instant challengeExpiration = now().plus(challengeTimeout);

        OAuth2Client.Request challengeRequest = startChallenge(callbackUri, handlerState);
        challengeRequest.getNonce().ifPresent(nonce -> servletResponse.addCookie(NonceCookie.createServletCookie(nonce, challengeExpiration)));
        servletResponseSeeOther(challengeRequest.getAuthorizationUri().toString(), servletResponse);
    }

    public void servletResponseSeeOther(String location, HttpServletResponse servletResponse)
            throws IOException
    {
        // 303 is preferred over a 302 when this response is received by a POST/PUT/DELETE and the redirect should be done via a GET instead of original method
        servletResponse.addHeader(HttpHeaders.LOCATION, location);
        servletResponse.sendError(HttpServletResponse.SC_SEE_OTHER);
    }

    // Builds the signed state JWT and the authorization request for a new challenge
    private OAuth2Client.Request startChallenge(URI callbackUri, Optional<String> handlerState)
    {
        Instant challengeExpiration = now().plus(challengeTimeout);
        String state = newJwtBuilder()
                .signWith(stateHmac)
                .setAudience(STATE_AUDIENCE_UI)
                .claim(HANDLER_STATE_CLAIM, handlerState.orElse(null))
                .setExpiration(Date.from(challengeExpiration))
                .compact();

        return client.createAuthorizationRequest(state, callbackUri);
    }

    /**
     * Handles an error callback from the authorization server: propagates the
     * failure to any waiting token-exchange handler and renders the failure page.
     */
    public Response handleOAuth2Error(String state, String error, String errorDescription, String errorUri)
    {
        try {
            Claims stateClaims = parseState(state);
            Optional.ofNullable(stateClaims.get(HANDLER_STATE_CLAIM, String.class))
                    .ifPresent(value ->
                            tokenHandler.setTokenExchangeError(value,
                                    format("Authentication response could not be verified: error=%s, errorDescription=%s, errorUri=%s",
                                            // Bug fix: errorUri was previously passed as errorDescription twice
                                            error, errorDescription, errorUri)));
        }
        catch (ChallengeFailedException | RuntimeException e) {
            logger.error(e, "Authentication response could not be verified invalid state: state=%s", state);
            return Response.status(FORBIDDEN)
                    .entity(getInternalFailureHtml("Authentication response could not be verified"))
                    .cookie(NonceCookie.delete())
                    .build();
        }

        logger.error("OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s", error, errorDescription, errorUri, state);
        return Response.ok()
                .entity(getCallbackErrorHtml(error))
                .cookie(NonceCookie.delete())
                .build();
    }

    /**
     * Completes the challenge after a successful callback: exchanges the code
     * for tokens, hands the token to any waiting handler, and sets the web UI
     * cookie.
     */
    public Response finishOAuth2Challenge(String state, String code, URI callbackUri, Optional<String> nonce, HttpServletRequest request)
    {
        Optional<String> handlerState;
        try {
            Claims stateClaims = parseState(state);
            handlerState = Optional.ofNullable(stateClaims.get(HANDLER_STATE_CLAIM, String.class));
        }
        catch (ChallengeFailedException | RuntimeException e) {
            logger.error(e, "Authentication response could not be verified invalid state: state=%s", state);
            return Response.status(BAD_REQUEST)
                    .entity(getInternalFailureHtml("Authentication response could not be verified"))
                    .cookie(NonceCookie.delete())
                    .build();
        }

        // Note: the Web UI may be disabled, so REST requests can not redirect to a success or error page inside the Web UI
        try {
            // fetch access token
            OAuth2Client.Response oauth2Response = client.getOAuth2Response(code, callbackUri, nonce);

            // Hoisted: previously the token pair was serialized (and the
            // expiration computed) up to three times
            String serializedTokenPair = tokenPairSerializer.serialize(fromOAuth2Response(oauth2Response));
            Instant cookieExpiration = tokenExpiration
                    .map(expiration -> Instant.now().plus(expiration))
                    .orElse(oauth2Response.getExpiration());

            if (!handlerState.isPresent()) {
                UriBuilder uriBuilder = getSchemeUriBuilder(request);
                return Response
                        .seeOther(uriBuilder.build().resolve("/ui/"))
                        .cookie(
                                OAuthWebUiCookie.create(serializedTokenPair, cookieExpiration),
                                NonceCookie.delete())
                        .build();
            }

            tokenHandler.setAccessToken(handlerState.get(), serializedTokenPair);

            Response.ResponseBuilder builder = Response.ok(getSuccessHtml());
            builder.cookie(OAuthWebUiCookie.create(serializedTokenPair, cookieExpiration));

            return builder.cookie(NonceCookie.delete()).build();
        }
        catch (ChallengeFailedException | RuntimeException e) {
            logger.error(e, "Authentication response could not be verified: state=%s", state);

            handlerState.ifPresent(value ->
                    tokenHandler.setTokenExchangeError(value, format("Authentication response could not be verified: state=%s", value)));
            return Response.status(BAD_REQUEST)
                    .cookie(NonceCookie.delete())
                    .entity(getInternalFailureHtml("Authentication response could not be verified"))
                    .build();
        }
    }

    // Verifies the HMAC signature, audience, and expiration of the state JWT
    private Claims parseState(String state)
            throws ChallengeFailedException
    {
        try {
            return jwtParser
                    .parseClaimsJws(state)
                    .getBody();
        }
        catch (RuntimeException e) {
            throw new ChallengeFailedException("State validation failed", e);
        }
    }

    public String getSuccessHtml()
    {
        return successHtml;
    }

    public String getCallbackErrorHtml(String errorCode)
    {
        return failureHtml.replace(FAILURE_REPLACEMENT_TEXT, getOAuth2ErrorMessage(errorCode));
    }

    public String getInternalFailureHtml(String errorMessage)
    {
        return failureHtml.replace(FAILURE_REPLACEMENT_TEXT, nullToEmpty(errorMessage));
    }

    private static byte[] secureRandomBytes(int count)
    {
        byte[] bytes = new byte[count];
        SECURE_RANDOM.nextBytes(bytes);
        return bytes;
    }

    // Maps a known OAuth2 error code to its message; falls back to a generic text
    private static String getOAuth2ErrorMessage(String errorCode)
    {
        try {
            OAuth2ErrorCode code = OAuth2ErrorCode.fromString(errorCode);
            return code.getMessage();
        }
        catch (IllegalArgumentException e) {
            logger.error(e, "Unknown error code received code=%s", errorCode);
            return "OAuth2 unknown error code: " + errorCode;
        }
    }
}
package com.facebook.presto.server.security.oauth2;

import com.facebook.airlift.configuration.AbstractConfigurationAwareModule;
import com.google.inject.Binder;
import com.google.inject.Inject;
import com.google.inject.Key;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.Singleton;

import java.time.Duration;

import static com.facebook.airlift.configuration.ConditionalModule.installModuleIf;
import static com.facebook.airlift.configuration.ConfigBinder.configBinder;
import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder;
import static com.facebook.airlift.jaxrs.JaxrsBinder.jaxrsBinder;
import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.ACCESS_TOKEN_ONLY_SERIALIZER;
import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;

/**
 * Guice module wiring the OAuth2 authentication machinery: the callback REST
 * resource, the service, the token handler, and the OAuth2 client, with
 * conditional sub-modules for OIDC discovery and refresh-token support driven
 * by {@link OAuth2Config}.
 */
public class OAuth2ServiceModule
        extends AbstractConfigurationAwareModule
{
    @Override
    protected void setup(Binder binder)
    {
        // REST endpoint receiving the authorization-code callback from the IdP
        jaxrsBinder(binder).bind(OAuth2CallbackResource.class);
        configBinder(binder).bindConfig(OAuth2Config.class);
        binder.bind(OAuth2Service.class).in(Scopes.SINGLETON);
        binder.bind(OAuth2TokenHandler.class).to(OAuth2TokenExchange.class).in(Scopes.SINGLETON);
        binder.bind(NimbusHttpClient.class).to(NimbusAirliftHttpClient.class).in(Scopes.SINGLETON);
        // Default OAuth2 client implementation; other modules may override it
        // through this OptionalBinder
        newOptionalBinder(binder, OAuth2Client.class)
                .setDefault()
                .to(NimbusOAuth2Client.class)
                .in(Scopes.SINGLETON);
        // Choose endpoint configuration source: OIDC discovery vs. static config
        install(installModuleIf(OAuth2Config.class, OAuth2Config::isEnableDiscovery, this::bindOidcDiscovery, this::bindStaticConfiguration));
        // Choose token-pair serialization: encrypted pair (refresh tokens) vs. access-token-only
        install(installModuleIf(OAuth2Config.class, OAuth2Config::isEnableRefreshTokens, this::enableRefreshTokens, this::disableRefreshTokens));
        httpClientBinder(binder)
                .bindHttpClient("oauth2-jwk", ForOAuth2.class)
                // Reset to defaults to override InternalCommunicationModule changes to this client default configuration.
                // Setting a keystore and/or a truststore for internal communication changes the default SSL configuration
                // for all clients in this guice context. This does not make sense for this client which will very rarely
                // use the same SSL configuration, so using the system default truststore makes more sense.
                .withConfigDefaults(config -> config
                        .setKeyStorePath(null)
                        .setKeyStorePassword(null)
                        .setTrustStorePath(null)
                        .setTrustStorePassword(null));
    }

    // Refresh tokens enabled: the JWE serializer module binds TokenPairSerializer
    // and the @ForRefreshTokens Duration
    private void enableRefreshTokens(Binder binder)
    {
        install(new JweTokenSerializerModule());
    }

    // Refresh tokens disabled: serialize only the access token and leave the
    // optional @ForRefreshTokens Duration empty
    private void disableRefreshTokens(Binder binder)
    {
        binder.bind(TokenPairSerializer.class).toInstance(ACCESS_TOKEN_ONLY_SERIALIZER);
        newOptionalBinder(binder, Key.get(Duration.class, ForRefreshTokens.class));
    }

    // NOTE(review): @Inject is redundant on a @Provides method — Guice injects
    // the parameters of provider methods regardless; consider removing.
    @Singleton
    @Provides
    @Inject
    public TokenRefresher getTokenRefresher(TokenPairSerializer tokenAssembler, OAuth2TokenHandler tokenHandler, OAuth2Client oAuth2Client)
    {
        return new TokenRefresher(tokenAssembler, tokenHandler, oAuth2Client);
    }

    // Static endpoint URLs supplied directly via configuration properties
    private void bindStaticConfiguration(Binder binder)
    {
        configBinder(binder).bindConfig(StaticOAuth2ServerConfiguration.class);
        binder.bind(OAuth2ServerConfigProvider.class).to(StaticConfigurationProvider.class).in(Scopes.SINGLETON);
    }

    // Endpoint URLs resolved at runtime from the issuer's discovery document
    private void bindOidcDiscovery(Binder binder)
    {
        configBinder(binder).bindConfig(OidcDiscoveryConfig.class);
        binder.bind(OAuth2ServerConfigProvider.class).to(OidcDiscovery.class).in(Scopes.SINGLETON);
    }

    // Modules are deduplicated by equality when installed more than once
    @Override
    public int hashCode()
    {
        return OAuth2ServiceModule.class.hashCode();
    }

    @Override
    public boolean equals(Object obj)
    {
        return obj instanceof OAuth2ServiceModule;
    }
}
b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2TokenExchange.java @@ -0,0 +1,146 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.units.Duration; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.hash.Hashing; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import jakarta.annotation.PreDestroy; + +import javax.inject.Inject; + +import java.nio.charset.StandardCharsets; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; + +import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class OAuth2TokenExchange + implements OAuth2TokenHandler +{ + public static final Duration MAX_POLL_TIME = new Duration(10, SECONDS); + private static final TokenPoll TOKEN_POLL_TIMED_OUT = TokenPoll.error("Authentication has timed out"); + 
private static final TokenPoll TOKEN_POLL_DROPPED = TokenPoll.error("Authentication has been finished by the client"); + + private final LoadingCache> cache; + private final ScheduledExecutorService executor = newSingleThreadScheduledExecutor(daemonThreadsNamed("oauth2-token-exchange")); + + @Inject + public OAuth2TokenExchange(OAuth2Config config) + { + long challengeTimeoutMillis = config.getChallengeTimeout().toMillis(); + this.cache = buildUnsafeCache( + CacheBuilder.newBuilder() + .expireAfterWrite(challengeTimeoutMillis + (MAX_POLL_TIME.toMillis() * 10), MILLISECONDS) + .removalListener(notification -> notification.getValue().set(TOKEN_POLL_TIMED_OUT)), + new CacheLoader>() + { + @Override + public SettableFuture load(String authIdHash) + { + SettableFuture future = SettableFuture.create(); + Future timeout = executor.schedule(() -> future.set(TOKEN_POLL_TIMED_OUT), challengeTimeoutMillis, MILLISECONDS); + future.addListener(() -> timeout.cancel(true), executor); + return future; + } + }); + } + + private static LoadingCache buildUnsafeCache(CacheBuilder cacheBuilder, CacheLoader cacheLoader) + { + return cacheBuilder.build(cacheLoader); + } + + @PreDestroy + public void stop() + { + executor.shutdownNow(); + } + + @Override + public void setAccessToken(String authIdHash, String accessToken) + { + cache.getUnchecked(authIdHash).set(TokenPoll.token(accessToken)); + } + + @Override + public void setTokenExchangeError(String authIdHash, String message) + { + cache.getUnchecked(authIdHash).set(TokenPoll.error(message)); + } + + public ListenableFuture getTokenPoll(UUID authId) + { + return nonCancellationPropagating(cache.getUnchecked(hashAuthId(authId))); + } + + public void dropToken(UUID authId) + { + cache.getUnchecked(hashAuthId(authId)).set(TOKEN_POLL_DROPPED); + } + + public static String hashAuthId(UUID authId) + { + return Hashing.sha256() + .hashString(authId.toString(), StandardCharsets.UTF_8) + .toString(); + } + + public static class TokenPoll + { + 
private final Optional token; + private final Optional error; + + private TokenPoll(String token, String error) + { + this.token = Optional.ofNullable(token); + this.error = Optional.ofNullable(error); + } + + static TokenPoll token(String token) + { + requireNonNull(token, "token is null"); + + return new TokenPoll(token, null); + } + + static TokenPoll error(String error) + { + requireNonNull(error, "error is null"); + + return new TokenPoll(null, error); + } + + public Optional getToken() + { + return token; + } + + public Optional getError() + { + return error; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2TokenExchangeResource.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2TokenExchangeResource.java new file mode 100644 index 0000000000000..5335712abe75e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2TokenExchangeResource.java @@ -0,0 +1,141 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.presto.dispatcher.DispatchExecutor; +import com.facebook.presto.server.security.oauth2.OAuth2TokenExchange.TokenPoll; +import com.google.common.base.VerifyException; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.BadRequestException; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.container.AsyncResponse; +import jakarta.ws.rs.container.Suspended; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriBuilder; + +import javax.inject.Inject; + +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +import static com.facebook.airlift.http.server.AsyncResponseHandler.bindAsyncResponse; +import static com.facebook.presto.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT; +import static com.facebook.presto.server.security.oauth2.OAuth2TokenExchange.MAX_POLL_TIME; +import static com.facebook.presto.server.security.oauth2.OAuth2TokenExchange.hashAuthId; +import static com.facebook.presto.server.security.oauth2.OAuth2Utils.getSchemeUriBuilder; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON_TYPE; +import static java.util.Objects.requireNonNull; + +@Path(OAuth2TokenExchangeResource.TOKEN_ENDPOINT) +public class OAuth2TokenExchangeResource +{ + public static final String TOKEN_ENDPOINT = "/oauth2/token/"; + private static final JsonCodec> MAP_CODEC = new JsonCodecFactory().mapJsonCodec(String.class, 
Object.class); + private final OAuth2TokenExchange tokenExchange; + private final OAuth2Service service; + private final ListeningExecutorService responseExecutor; + + @Inject + public OAuth2TokenExchangeResource(OAuth2TokenExchange tokenExchange, OAuth2Service service, DispatchExecutor executor) + { + this.tokenExchange = requireNonNull(tokenExchange, "tokenExchange is null"); + this.service = requireNonNull(service, "service is null"); + this.responseExecutor = requireNonNull(executor, "executor is null").getExecutor(); + } + + @Path("initiate/{authIdHash}") + @GET + @Produces(MediaType.APPLICATION_JSON) + public Response initiateTokenExchange(@PathParam("authIdHash") String authIdHash, @Context HttpServletRequest request) + { + UriBuilder builder = getSchemeUriBuilder(request); + return service.startOAuth2Challenge(builder.build().resolve(CALLBACK_ENDPOINT), Optional.ofNullable(authIdHash)); + } + + @Path("{authId}") + @GET + @Produces(MediaType.APPLICATION_JSON) + public void getAuthenticationToken(@PathParam("authId") UUID authId, @Suspended AsyncResponse asyncResponse, @Context HttpServletRequest request) + { + if (authId == null) { + throw new BadRequestException(); + } + + // Do not drop the response from the cache on failure, as this would result in a + // hang if the client retries the request. The response will timeout eventually. 
+ ListenableFuture tokenFuture = tokenExchange.getTokenPoll(authId); + ListenableFuture responseFuture = Futures.transform(tokenFuture, OAuth2TokenExchangeResource::toResponse, responseExecutor); + bindAsyncResponse(asyncResponse, responseFuture, responseExecutor) + .withTimeout(MAX_POLL_TIME, pendingResponse(request)); + } + + private static Response toResponse(TokenPoll poll) + { + if (poll.getError().isPresent()) { + return Response.ok(jsonMap("error", poll.getError().get()), APPLICATION_JSON_TYPE).build(); + } + if (poll.getToken().isPresent()) { + return Response.ok(jsonMap("token", poll.getToken().get()), APPLICATION_JSON_TYPE).build(); + } + throw new VerifyException("invalid TokenPoll state"); + } + + private static Response pendingResponse(HttpServletRequest request) + { + UriBuilder builder = getSchemeUriBuilder(request); + return Response.ok(jsonMap("nextUri", builder.build()), APPLICATION_JSON_TYPE).build(); + } + + @DELETE + @Path("{authId}") + public Response deleteAuthenticationToken(@PathParam("authId") UUID authId) + { + if (authId == null) { + throw new BadRequestException(); + } + + tokenExchange.dropToken(authId); + return Response + .ok() + .build(); + } + + public static String getTokenUri(UUID authId) + { + return TOKEN_ENDPOINT + authId; + } + + public static String getInitiateUri(UUID authId) + { + return TOKEN_ENDPOINT + "initiate/" + hashAuthId(authId); + } + + private static String jsonMap(String key, Object value) + { + return MAP_CODEC.toJson(ImmutableMap.of(key, value)); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2TokenHandler.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2TokenHandler.java new file mode 100644 index 0000000000000..027d3c27d90f4 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2TokenHandler.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * 
you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +public interface OAuth2TokenHandler +{ + void setAccessToken(String hashedState, String accessToken); + + void setTokenExchangeError(String hashedState, String errorMessage); +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Utils.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Utils.java new file mode 100644 index 0000000000000..0b5f8f62db545 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2Utils.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.core.MultivaluedMap; +import jakarta.ws.rs.core.UriBuilder; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; + +public final class OAuth2Utils +{ + private OAuth2Utils() {} + + /** + * Returns a UriBuilder with the scheme set based upon the X_FORWARDED_PROTO header. + * If the header exists on the request we set the scheme to what is in that header. i.e. https. + * If the header is not set then we use the scheme on the request. + * + * Ex: If you are using a load balancer to handle ssl forwarding for Presto. You must set the + * X_FORWARDED_PROTO header in the load balancer to 'https'. For any callback or redirect url's + * for the OAUTH2 Login flow must use the scheme of https. + * + * @param request HttpServletRequest + * @return a new instance of UriBuilder with the scheme set. + */ + public static UriBuilder getSchemeUriBuilder(HttpServletRequest request) + { + Optional forwardedProto = Optional.ofNullable(request.getHeader(X_FORWARDED_PROTO)); + + UriBuilder builder = UriBuilder.fromUri(getFullRequestURL(request)); + if (forwardedProto.isPresent()) { + builder.scheme(forwardedProto.get()); + } + else { + builder.scheme(request.getScheme()); + } + + return builder; + } + + /** + * Finds the lastURL query parameter in the request. 
+ * + * @return Optional the value of the lastURL parameter + */ + public static Optional getLastURLParameter(MultivaluedMap queryParams) + { + Optional>> lastUrl = queryParams.entrySet().stream().filter(qp -> qp.getKey().equals("lastURL")).findFirst(); + if (lastUrl.isPresent() && lastUrl.get().getValue().size() > 0) { + return Optional.ofNullable(lastUrl.get().getValue().get(0)); + } + + return Optional.empty(); + } + + public static String getFullRequestURL(HttpServletRequest request) + { + StringBuilder requestURL = new StringBuilder(request.getRequestURL()); + String queryString = request.getQueryString(); + + return queryString == null ? requestURL.toString() : requestURL.append("?").append(queryString).toString(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2WebUiAuthenticationManager.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2WebUiAuthenticationManager.java new file mode 100644 index 0000000000000..925019f5ac442 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuth2WebUiAuthenticationManager.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.http.server.AuthenticationException; +import com.facebook.airlift.log.Logger; +import com.facebook.presto.server.security.WebUiAuthenticationManager; +import com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.ws.rs.core.UriBuilder; + +import javax.inject.Inject; + +import java.io.IOException; +import java.security.Principal; +import java.time.Duration; +import java.time.Instant; +import java.util.Optional; + +import static com.facebook.presto.server.security.AuthenticationFilter.withPrincipal; +import static com.facebook.presto.server.security.oauth2.OAuth2Authenticator.extractTokenFromCookie; +import static com.facebook.presto.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT; +import static com.facebook.presto.server.security.oauth2.OAuth2Utils.getSchemeUriBuilder; +import static java.util.Objects.requireNonNull; + +public class OAuth2WebUiAuthenticationManager + implements WebUiAuthenticationManager +{ + private static final Logger logger = Logger.get(OAuth2WebUiAuthenticationManager.class); + private final OAuth2Service oAuth2Service; + private final OAuth2Authenticator oAuth2Authenticator; + private final TokenPairSerializer tokenPairSerializer; + private final OAuth2Client client; + private final Optional tokenExpiration; + + @Inject + public OAuth2WebUiAuthenticationManager(OAuth2Service oAuth2Service, OAuth2Authenticator oAuth2Authenticator, TokenPairSerializer tokenPairSerializer, OAuth2Client client, @ForRefreshTokens Optional tokenExpiration) + { + this.oAuth2Service = requireNonNull(oAuth2Service, "oauth2Service is null"); + this.oAuth2Authenticator = requireNonNull(oAuth2Authenticator, 
"oauth2Authenticator is null"); + this.tokenPairSerializer = requireNonNull(tokenPairSerializer, "tokenPairSerializer is null"); + this.client = requireNonNull(client, "oauth2Client is null"); + this.tokenExpiration = requireNonNull(tokenExpiration, "tokenExpiration is null"); + } + + public void handleRequest(HttpServletRequest request, HttpServletResponse response, FilterChain nextFilter) + throws IOException, ServletException + { + try { + Principal principal = this.oAuth2Authenticator.authenticate(request); + nextFilter.doFilter(withPrincipal(request, principal), response); + } + catch (AuthenticationException e) { + needAuthentication(request, response); + } + } + + private Optional getTokenPair(HttpServletRequest request) + { + try { + Optional token = extractTokenFromCookie(request); + if (token.isPresent()) { + return Optional.ofNullable(tokenPairSerializer.deserialize(token.get())); + } + else { + return Optional.empty(); + } + } + catch (Exception e) { + logger.error(e, "Exception occurred during token pair deserialization"); + return Optional.empty(); + } + } + + private void needAuthentication(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + Optional tokenPair = getTokenPair(request); + Optional refreshToken = tokenPair.flatMap(TokenPair::getRefreshToken); + if (refreshToken.isPresent()) { + try { + OAuth2Client.Response refreshRes = client.refreshTokens(refreshToken.get()); + String serializeToken = tokenPairSerializer.serialize(TokenPair.fromOAuth2Response(refreshRes)); + UriBuilder builder = getSchemeUriBuilder(request); + Cookie newCookie = NonceCookie.toServletCookie(OAuthWebUiCookie.create(serializeToken, tokenExpiration.map(expiration -> Instant.now().plus(expiration)).orElse(refreshRes.getExpiration()))); + response.addCookie(newCookie); + response.sendRedirect(builder.build().toString()); + } + catch (ChallengeFailedException e) { + logger.error(e, "Token refresh challenge has failed"); + 
this.startOauth2Challenge(request, response); + } + } + else { + this.startOauth2Challenge(request, response); + } + } + + private void startOauth2Challenge(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + UriBuilder builder = getSchemeUriBuilder(request); + this.oAuth2Service.startOAuth2Challenge(builder.build().resolve(CALLBACK_ENDPOINT), Optional.empty(), response); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuthWebUiCookie.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuthWebUiCookie.java new file mode 100644 index 0000000000000..867e7b60c3dd8 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OAuthWebUiCookie.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import jakarta.ws.rs.core.NewCookie; + +import java.time.Instant; +import java.util.Date; + +import static jakarta.ws.rs.core.Cookie.DEFAULT_VERSION; +import static jakarta.ws.rs.core.NewCookie.DEFAULT_MAX_AGE; + +public final class OAuthWebUiCookie +{ + // prefix according to: https://tools.ietf.org/html/draft-ietf-httpbis-rfc6265bis-05#section-4.1.3.1 + public static final String OAUTH2_COOKIE = "__Secure-Presto-OAuth2-Token"; + + public static final String API_PATH = "/"; + + private OAuthWebUiCookie() {} + + public static NewCookie create(String token, Instant tokenExpiration) + { + return new NewCookie( + OAUTH2_COOKIE, + token, + API_PATH, + null, + DEFAULT_VERSION, + null, + DEFAULT_MAX_AGE, + Date.from(tokenExpiration), + true, + true); + } + public static NewCookie delete() + { + return new NewCookie( + OAUTH2_COOKIE, + "delete", + API_PATH, + null, + DEFAULT_VERSION, + null, + 0, + null, + true, + true); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OidcDiscovery.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OidcDiscovery.java new file mode 100644 index 0000000000000..1b6bae30cc8c5 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OidcDiscovery.java @@ -0,0 +1,159 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.airlift.log.Logger; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; +import com.nimbusds.oauth2.sdk.id.Issuer; +import com.nimbusds.openid.connect.sdk.op.OIDCProviderConfigurationRequest; +import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; +import net.jodah.failsafe.Failsafe; +import net.jodah.failsafe.RetryPolicy; + +import javax.inject.Inject; + +import java.net.URI; +import java.time.Duration; +import java.util.Optional; + +import static com.facebook.airlift.http.client.HttpStatus.OK; +import static com.facebook.airlift.http.client.HttpStatus.REQUEST_TIMEOUT; +import static com.facebook.airlift.http.client.HttpStatus.TOO_MANY_REQUESTS; +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.ACCESS_TOKEN_ISSUER; +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.AUTH_URL; +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.JWKS_URL; +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.TOKEN_URL; +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.USERINFO_URL; +import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class OidcDiscovery + implements OAuth2ServerConfigProvider +{ + private static final Logger LOG = Logger.get(OidcDiscovery.class); + + private static final ObjectMapper OBJECT_MAPPER = new JsonObjectMapperProvider().get(); + private final Issuer issuer; + private final Duration discoveryTimeout; + 
private final boolean userinfoEndpointEnabled; + private final Optional accessTokenIssuer; + private final Optional authUrl; + private final Optional tokenUrl; + private final Optional jwksUrl; + private final Optional userinfoUrl; + private final NimbusHttpClient httpClient; + + @Inject + public OidcDiscovery(OAuth2Config oauthConfig, OidcDiscoveryConfig oidcConfig, NimbusHttpClient httpClient) + { + requireNonNull(oauthConfig, "oauthConfig is null"); + issuer = new Issuer(requireNonNull(oauthConfig.getIssuer(), "issuer is null")); + requireNonNull(oidcConfig, "oidcConfig is null"); + userinfoEndpointEnabled = oidcConfig.isUserinfoEndpointEnabled(); + discoveryTimeout = Duration.ofMillis(requireNonNull(oidcConfig.getDiscoveryTimeout(), "discoveryTimeout is null").toMillis()); + accessTokenIssuer = requireNonNull(oidcConfig.getAccessTokenIssuer(), "accessTokenIssuer is null"); + authUrl = requireNonNull(oidcConfig.getAuthUrl(), "authUrl is null"); + tokenUrl = requireNonNull(oidcConfig.getTokenUrl(), "tokenUrl is null"); + jwksUrl = requireNonNull(oidcConfig.getJwksUrl(), "jwksUrl is null"); + userinfoUrl = requireNonNull(oidcConfig.getUserinfoUrl(), "userinfoUrl is null"); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + } + + @Override + public OAuth2ServerConfig get() + { + return Failsafe.with(new RetryPolicy<>() + .withMaxAttempts(-1) + .withMaxDuration(discoveryTimeout) + .withDelay(Duration.ofSeconds(1)) + .abortOn(IllegalStateException.class) + .onFailedAttempt(attempt -> LOG.debug("OpenID Connect Metadata read failed: %s", attempt.getLastFailure()))) + .get(() -> httpClient.execute(new OIDCProviderConfigurationRequest(issuer), this::parseConfigurationResponse)); + } + + private OAuth2ServerConfig parseConfigurationResponse(HTTPResponse response) + throws ParseException + { + int statusCode = response.getStatusCode(); + if (statusCode != OK.code()) { + // stop on any client errors other than REQUEST_TIMEOUT and TOO_MANY_REQUESTS + if 
(statusCode < 400 || statusCode >= 500 || statusCode == REQUEST_TIMEOUT.code() || statusCode == TOO_MANY_REQUESTS.code()) { + throw new RuntimeException("Invalid response from OpenID Metadata endpoint: " + statusCode); + } + else { + throw new IllegalStateException(format("Invalid response from OpenID Metadata endpoint. Expected response code to be %s, but was %s", OK.code(), statusCode)); + } + } + return readConfiguration(response.getContent()); + } + + private OAuth2ServerConfig readConfiguration(String body) + throws ParseException + { + OIDCProviderMetadata metadata = OIDCProviderMetadata.parse(body); + checkMetadataState(issuer.equals(metadata.getIssuer()), "The value of the \"issuer\" claim in the Metadata document is different from the Issuer URL used for the Configuration Request."); + try { + JsonNode metadataJson = OBJECT_MAPPER.readTree(body); + Optional userinfoEndpoint; + if (userinfoEndpointEnabled) { + userinfoEndpoint = getOptionalField("userinfo_endpoint", Optional.ofNullable(metadata.getUserInfoEndpointURI()).map(URI::toString), USERINFO_URL, userinfoUrl); + } + else { + userinfoEndpoint = Optional.empty(); + } + return new OAuth2ServerConfig( + // AD FS server can include "access_token_issuer" field in OpenID Provider Metadata. + // It's not a part of the OIDC standard, thus it has to be handled separately. 
+ // see: https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-oidce/f629647a-4825-465b-80bb-32c7e9cec2c8 + getOptionalField("access_token_issuer", Optional.ofNullable(metadataJson.get("access_token_issuer")).map(JsonNode::textValue), ACCESS_TOKEN_ISSUER, accessTokenIssuer), + getRequiredField("authorization_endpoint", metadata.getAuthorizationEndpointURI(), AUTH_URL, authUrl), + getRequiredField("token_endpoint", metadata.getTokenEndpointURI(), TOKEN_URL, tokenUrl), + getRequiredField("jwks_uri", metadata.getJWKSetURI(), JWKS_URL, jwksUrl), + userinfoEndpoint.map(URI::create)); + } + catch (JsonProcessingException e) { + throw new ParseException("Invalid JSON value", e); + } + } + + private static URI getRequiredField(String metadataField, URI metadataValue, String configurationField, Optional configurationValue) + { + Optional uri = getOptionalField(metadataField, Optional.ofNullable(metadataValue).map(URI::toString), configurationField, configurationValue); + checkMetadataState(uri.isPresent(), "Missing required \"%s\" property.", metadataField); + return URI.create(uri.get()); + } + + private static Optional getOptionalField(String metadataField, Optional metadataValue, String configurationField, Optional configurationValue) + { + if (configurationValue.isPresent()) { + if (!configurationValue.equals(metadataValue)) { + LOG.warn("Overriding \"%s=%s\" from OpenID metadata document with value \"%s=%s\" defined in configuration", + metadataField, metadataValue.orElse(""), configurationField, configurationValue.orElse("")); + } + return configurationValue; + } + return metadataValue; + } + + private static void checkMetadataState(boolean expression, String additionalMessage, String... additionalMessageArgs) + { + checkState(expression, "Invalid response from OpenID Metadata endpoint. 
" + additionalMessage, (Object[]) additionalMessageArgs); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OidcDiscoveryConfig.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OidcDiscoveryConfig.java new file mode 100644 index 0000000000000..d5d29589fc5a4 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/OidcDiscoveryConfig.java @@ -0,0 +1,148 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.units.Duration; +import jakarta.validation.constraints.NotNull; + +import java.util.Optional; + +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.ACCESS_TOKEN_ISSUER; +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.AUTH_URL; +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.JWKS_URL; +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.TOKEN_URL; +import static com.facebook.presto.server.security.oauth2.StaticOAuth2ServerConfiguration.USERINFO_URL; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class OidcDiscoveryConfig +{ + private Duration discoveryTimeout = new Duration(30, SECONDS); + private boolean userinfoEndpointEnabled = true; + + //TODO Left for backward compatibility, remove after the next release/a couple of releases + private Optional accessTokenIssuer = Optional.empty(); + private Optional authUrl = Optional.empty(); + private Optional tokenUrl = Optional.empty(); + private Optional jwksUrl = Optional.empty(); + private Optional userinfoUrl = Optional.empty(); + + @NotNull + public Duration getDiscoveryTimeout() + { + return discoveryTimeout; + } + + @Config("http-server.authentication.oauth2.oidc.discovery.timeout") + @ConfigDescription("OpenID Connect discovery timeout") + public OidcDiscoveryConfig setDiscoveryTimeout(Duration discoveryTimeout) + { + this.discoveryTimeout = discoveryTimeout; + return this; + } + + public boolean isUserinfoEndpointEnabled() + { + return userinfoEndpointEnabled; + } + + @Config("http-server.authentication.oauth2.oidc.use-userinfo-endpoint") + @ConfigDescription("Use userinfo endpoint from OpenID connect metadata document") + public 
OidcDiscoveryConfig setUserinfoEndpointEnabled(boolean userinfoEndpointEnabled) + { + this.userinfoEndpointEnabled = userinfoEndpointEnabled; + return this; + } + + @NotNull + @Deprecated + public Optional getAccessTokenIssuer() + { + return accessTokenIssuer; + } + + @Config(ACCESS_TOKEN_ISSUER) + @ConfigDescription("The required issuer for access tokens") + @Deprecated + public OidcDiscoveryConfig setAccessTokenIssuer(String accessTokenIssuer) + { + this.accessTokenIssuer = Optional.ofNullable(accessTokenIssuer); + return this; + } + + @NotNull + @Deprecated + public Optional getAuthUrl() + { + return authUrl; + } + + @Config(AUTH_URL) + @ConfigDescription("URL of the authorization server's authorization endpoint") + @Deprecated + public OidcDiscoveryConfig setAuthUrl(String authUrl) + { + this.authUrl = Optional.ofNullable(authUrl); + return this; + } + + @NotNull + @Deprecated + public Optional getTokenUrl() + { + return tokenUrl; + } + + @Config(TOKEN_URL) + @ConfigDescription("URL of the authorization server's token endpoint") + @Deprecated + public OidcDiscoveryConfig setTokenUrl(String tokenUrl) + { + this.tokenUrl = Optional.ofNullable(tokenUrl); + return this; + } + + @NotNull + @Deprecated + public Optional getJwksUrl() + { + return jwksUrl; + } + + @Config(JWKS_URL) + @ConfigDescription("URL of the authorization server's JWKS (JSON Web Key Set) endpoint") + @Deprecated + public OidcDiscoveryConfig setJwksUrl(String jwksUrl) + { + this.jwksUrl = Optional.ofNullable(jwksUrl); + return this; + } + + @NotNull + @Deprecated + public Optional getUserinfoUrl() + { + return userinfoUrl; + } + + @Config(USERINFO_URL) + @ConfigDescription("URL of the userinfo endpoint") + @Deprecated + public OidcDiscoveryConfig setUserinfoUrl(String userinfoUrl) + { + this.userinfoUrl = Optional.ofNullable(userinfoUrl); + return this; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/RefreshTokensConfig.java 
b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/RefreshTokensConfig.java new file mode 100644 index 0000000000000..d4e1783c89128 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/RefreshTokensConfig.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigDescription; +import com.facebook.airlift.configuration.ConfigSecuritySensitive; +import com.facebook.airlift.units.Duration; +import io.jsonwebtoken.io.Decoders; +import jakarta.validation.constraints.NotEmpty; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; + +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.util.concurrent.TimeUnit.HOURS; + +public class RefreshTokensConfig +{ + private Duration tokenExpiration = Duration.succinctDuration(1, HOURS); + private static final String coordinator = "Presto_coordinator"; + private String issuer = coordinator; + private String audience = coordinator; + + private SecretKey secretKey; + + public Duration getTokenExpiration() + { + return tokenExpiration; + } + + @Config("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout") + @ConfigDescription("Expiration time for issued token. 
It needs to be equal or lower than duration of refresh token issued by IdP") + public RefreshTokensConfig setTokenExpiration(Duration tokenExpiration) + { + this.tokenExpiration = tokenExpiration; + return this; + } + + @NotEmpty + public String getIssuer() + { + return issuer; + } + + @Config("http-server.authentication.oauth2.refresh-tokens.issued-token.issuer") + @ConfigDescription("Issuer representing this coordinator instance, that will be used in issued token. In addition current Version will be added to it") + public RefreshTokensConfig setIssuer(String issuer) + { + this.issuer = issuer; + return this; + } + + @NotEmpty + public String getAudience() + { + return audience; + } + + @Config("http-server.authentication.oauth2.refresh-tokens.issued-token.audience") + @ConfigDescription("Audience representing this coordinator instance, that will be used in issued token") + public RefreshTokensConfig setAudience(String audience) + { + this.audience = audience; + return this; + } + + @Config("http-server.authentication.oauth2.refresh-tokens.secret-key") + @ConfigDescription("Base64 encoded secret key used to encrypt generated token") + @ConfigSecuritySensitive + public RefreshTokensConfig setSecretKey(String key) + { + if (isNullOrEmpty(key)) { + return this; + } + + secretKey = new SecretKeySpec(Decoders.BASE64.decode(key), "AES"); + return this; + } + + public SecretKey getSecretKey() + { + return secretKey; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/StaticConfigurationProvider.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/StaticConfigurationProvider.java new file mode 100644 index 0000000000000..627c42b3a5814 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/StaticConfigurationProvider.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import javax.inject.Inject; + +import java.net.URI; + +import static java.util.Objects.requireNonNull; + +public class StaticConfigurationProvider + implements OAuth2ServerConfigProvider +{ + private final OAuth2ServerConfig config; + + @Inject + StaticConfigurationProvider(StaticOAuth2ServerConfiguration config) + { + requireNonNull(config, "config is null"); + this.config = new OAuth2ServerConfig( + config.getAccessTokenIssuer(), + URI.create(config.getAuthUrl()), + URI.create(config.getTokenUrl()), + URI.create(config.getJwksUrl()), + config.getUserinfoUrl().map(URI::create)); + } + + @Override + public OAuth2ServerConfig get() + { + return config; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/StaticOAuth2ServerConfiguration.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/StaticOAuth2ServerConfiguration.java new file mode 100644 index 0000000000000..51f116abe09b5 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/StaticOAuth2ServerConfiguration.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigDescription; +import jakarta.validation.constraints.NotNull; + +import java.util.Optional; + +public class StaticOAuth2ServerConfiguration +{ + public static final String ACCESS_TOKEN_ISSUER = "http-server.authentication.oauth2.access-token-issuer"; + public static final String AUTH_URL = "http-server.authentication.oauth2.auth-url"; + public static final String TOKEN_URL = "http-server.authentication.oauth2.token-url"; + public static final String JWKS_URL = "http-server.authentication.oauth2.jwks-url"; + public static final String USERINFO_URL = "http-server.authentication.oauth2.userinfo-url"; + + private Optional accessTokenIssuer = Optional.empty(); + private String authUrl; + private String tokenUrl; + private String jwksUrl; + private Optional userinfoUrl = Optional.empty(); + + @NotNull + public Optional getAccessTokenIssuer() + { + return accessTokenIssuer; + } + + @Config(ACCESS_TOKEN_ISSUER) + @ConfigDescription("The required issuer for access tokens") + public StaticOAuth2ServerConfiguration setAccessTokenIssuer(String accessTokenIssuer) + { + this.accessTokenIssuer = Optional.ofNullable(accessTokenIssuer); + return this; + } + + @NotNull + public String getAuthUrl() + { + return authUrl; + } + + @Config(AUTH_URL) + @ConfigDescription("URL of the authorization server's authorization endpoint") + public StaticOAuth2ServerConfiguration setAuthUrl(String authUrl) + { + this.authUrl 
= authUrl; + return this; + } + + @NotNull + public String getTokenUrl() + { + return tokenUrl; + } + + @Config(TOKEN_URL) + @ConfigDescription("URL of the authorization server's token endpoint") + public StaticOAuth2ServerConfiguration setTokenUrl(String tokenUrl) + { + this.tokenUrl = tokenUrl; + return this; + } + + @NotNull + public String getJwksUrl() + { + return jwksUrl; + } + + @Config(JWKS_URL) + @ConfigDescription("URL of the authorization server's JWKS (JSON Web Key Set) endpoint") + public StaticOAuth2ServerConfiguration setJwksUrl(String jwksUrl) + { + this.jwksUrl = jwksUrl; + return this; + } + + public Optional getUserinfoUrl() + { + return userinfoUrl; + } + + @Config(USERINFO_URL) + @ConfigDescription("URL of the userinfo endpoint") + public StaticOAuth2ServerConfiguration setUserinfoUrl(String userinfoUrl) + { + this.userinfoUrl = Optional.ofNullable(userinfoUrl); + return this; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java new file mode 100644 index 0000000000000..a92e7fdf00e06 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java @@ -0,0 +1,96 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.server.security.oauth2; + +import com.facebook.presto.server.security.oauth2.OAuth2Client.Response; +import jakarta.annotation.Nullable; + +import java.util.Date; +import java.util.Optional; + +import static java.lang.Long.MAX_VALUE; +import static java.util.Objects.requireNonNull; + +public interface TokenPairSerializer +{ + TokenPairSerializer ACCESS_TOKEN_ONLY_SERIALIZER = new TokenPairSerializer() + { + @Override + public TokenPair deserialize(String token) + { + return TokenPair.accessToken(token); + } + + @Override + public String serialize(TokenPair tokenPair) + { + return tokenPair.getAccessToken(); + } + }; + + TokenPair deserialize(String token); + + String serialize(TokenPair tokenPair); + + class TokenPair + { + private final String accessToken; + private final Date expiration; + private final Optional refreshToken; + + private TokenPair(String accessToken, Date expiration, Optional refreshToken) + { + this.accessToken = requireNonNull(accessToken, "accessToken is null"); + this.expiration = requireNonNull(expiration, "expiration is null"); + this.refreshToken = requireNonNull(refreshToken, "refreshToken is null"); + } + + public static TokenPair accessToken(String accessToken) + { + return new TokenPair(accessToken, new Date(MAX_VALUE), Optional.empty()); + } + + public static TokenPair fromOAuth2Response(Response tokens) + { + requireNonNull(tokens, "tokens is null"); + return new TokenPair(tokens.getAccessToken(), Date.from(tokens.getExpiration()), tokens.getRefreshToken()); + } + + public static TokenPair accessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken) + { + return new TokenPair(accessToken, expiration, Optional.ofNullable(refreshToken)); + } + + public String getAccessToken() + { + return accessToken; + } + + public Date getExpiration() + { + return expiration; + } + + public Optional getRefreshToken() + { + return refreshToken; + } + + public static TokenPair 
withAccessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken) + { + return new TokenPair(accessToken, expiration, Optional.ofNullable(refreshToken)); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenRefresher.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenRefresher.java new file mode 100644 index 0000000000000..411103f6f8724 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenRefresher.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.presto.server.security.oauth2.OAuth2Client.Response; +import com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair; + +import java.util.Optional; +import java.util.UUID; + +import static com.facebook.presto.server.security.oauth2.OAuth2TokenExchange.hashAuthId; +import static java.util.Objects.requireNonNull; + +public class TokenRefresher +{ + private final TokenPairSerializer tokenAssembler; + private final OAuth2TokenHandler tokenHandler; + private final OAuth2Client client; + + public TokenRefresher(TokenPairSerializer tokenAssembler, OAuth2TokenHandler tokenHandler, OAuth2Client client) + { + this.tokenAssembler = requireNonNull(tokenAssembler, "tokenAssembler is null"); + this.tokenHandler = requireNonNull(tokenHandler, "tokenHandler is null"); + this.client = requireNonNull(client, "oAuth2Client is null"); + } + + public Optional refreshToken(TokenPair tokenPair) + { + requireNonNull(tokenPair, "tokenPair is null"); + + Optional refreshToken = tokenPair.getRefreshToken(); + if (refreshToken.isPresent()) { + UUID refreshingId = UUID.randomUUID(); + try { + refreshToken(refreshToken.get(), refreshingId); + return Optional.of(refreshingId); + } + // If Refresh token has expired then restart the flow + catch (RuntimeException exception) { + return Optional.empty(); + } + } + return Optional.empty(); + } + + private void refreshToken(String refreshToken, UUID refreshingId) + { + try { + Response response = client.refreshTokens(refreshToken); + String serializedToken = tokenAssembler.serialize(TokenPair.fromOAuth2Response(response)); + tokenHandler.setAccessToken(hashAuthId(refreshingId), serializedToken); + } + catch (ChallengeFailedException e) { + tokenHandler.setTokenExchangeError(hashAuthId(refreshingId), "Token refreshing has failed: " + e.getMessage()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ZstdCodec.java 
b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ZstdCodec.java new file mode 100644 index 0000000000000..1065d90416240 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/ZstdCodec.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import io.airlift.compress.zstd.ZstdCompressor; +import io.airlift.compress.zstd.ZstdDecompressor; +import io.jsonwebtoken.CompressionCodec; +import io.jsonwebtoken.CompressionException; + +import static java.lang.Math.toIntExact; +import static java.util.Arrays.copyOfRange; + +public class ZstdCodec + implements CompressionCodec +{ + public static final String CODEC_NAME = "ZSTD"; + + @Override + public String getAlgorithmName() + { + return CODEC_NAME; + } + + @Override + public byte[] compress(byte[] bytes) + throws CompressionException + { + ZstdCompressor compressor = new ZstdCompressor(); + byte[] compressed = new byte[compressor.maxCompressedLength(bytes.length)]; + int outputSize = compressor.compress(bytes, 0, bytes.length, compressed, 0, compressed.length); + return copyOfRange(compressed, 0, outputSize); + } + + @Override + public byte[] decompress(byte[] bytes) + throws CompressionException + { + byte[] output = new byte[toIntExact(ZstdDecompressor.getDecompressedSize(bytes, 0, bytes.length))]; + new ZstdDecompressor().decompress(bytes, 0, bytes.length, output, 0, output.length); 
+ return output; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 0765fbbfd3836..debae20a1aa3c 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -34,11 +34,11 @@ import com.facebook.drift.transport.netty.server.DriftNettyServerTransport; import com.facebook.presto.ClientRequestFilterManager; import com.facebook.presto.ClientRequestFilterModule; +import com.facebook.presto.builtin.tools.WorkerFunctionRegistryTool; import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.dispatcher.DispatchManager; import com.facebook.presto.dispatcher.QueryPrerequisitesManagerModule; -import com.facebook.presto.eventlistener.EventListenerConfig; import com.facebook.presto.eventlistener.EventListenerManager; import com.facebook.presto.execution.QueryInfo; import com.facebook.presto.execution.QueryManager; @@ -51,6 +51,7 @@ import com.facebook.presto.memory.LocalMemoryManager; import com.facebook.presto.metadata.AllNodes; import com.facebook.presto.metadata.CatalogManager; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.metadata.InternalNode; import com.facebook.presto.metadata.InternalNodeManager; import com.facebook.presto.metadata.Metadata; @@ -70,8 +71,8 @@ import com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.QueryId; import com.facebook.presto.spi.eventlistener.EventListener; +import com.facebook.presto.spi.function.SqlFunction; import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; -import com.facebook.presto.spi.security.AccessControl; import com.facebook.presto.split.PageSourceManager; import com.facebook.presto.split.SplitManager; import 
com.facebook.presto.sql.analyzer.FeaturesConfig; @@ -82,11 +83,10 @@ import com.facebook.presto.sql.planner.NodePartitioningManager; import com.facebook.presto.sql.planner.Plan; import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager; -import com.facebook.presto.storage.TempStorageManager; import com.facebook.presto.testing.ProcedureTester; import com.facebook.presto.testing.TestingAccessControlManager; import com.facebook.presto.testing.TestingEventListenerManager; -import com.facebook.presto.testing.TestingTempStorageManager; +import com.facebook.presto.testing.TestingPrestoServerModule; import com.facebook.presto.testing.TestingWarningCollectorModule; import com.facebook.presto.transaction.TransactionManager; import com.facebook.presto.ttl.clusterttlprovidermanagers.ClusterTtlProviderManagerModule; @@ -97,20 +97,19 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.net.HostAndPort; +import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.inject.Injector; import com.google.inject.Key; import com.google.inject.Module; import com.google.inject.Scopes; +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.FilterConfig; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; import org.weakref.jmx.guice.MBeanModule; -import javax.annotation.concurrent.GuardedBy; -import javax.servlet.Filter; -import javax.servlet.FilterChain; -import javax.servlet.FilterConfig; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; - import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; @@ -148,6 +147,8 @@ public class TestingPrestoServer private final boolean preserveData; private final LifeCycleManager lifeCycleManager; private final PluginManager pluginManager; + 
private final FunctionAndTypeManager functionAndTypeManager; + private final WorkerFunctionRegistryTool workerFunctionRegistryTool; private final ConnectorManager connectorManager; private final TestingHttpServer server; private final CatalogManager catalogManager; @@ -225,6 +226,11 @@ public TestingPrestoServer(List additionalModules) this(true, ImmutableMap.of(), null, null, new SqlParserOptions(), additionalModules); } + public TestingPrestoServer(Map properties) throws Exception + { + this(true, properties, null, null, new SqlParserOptions(), ImmutableList.of()); + } + public TestingPrestoServer( boolean coordinator, Map properties, @@ -280,6 +286,42 @@ public TestingPrestoServer( List additionalModules, Optional dataDirectory) throws Exception + { + this( + resourceManager, + resourceManagerEnabled, + catalogServer, + catalogServerEnabled, + coordinatorSidecar, + coordinatorSidecarEnabled, + coordinator, + skipLoadingResourceGroupConfigurationManager, + true, + properties, + environment, + discoveryUri, + parserOptions, + additionalModules, + dataDirectory); + } + + public TestingPrestoServer( + boolean resourceManager, + boolean resourceManagerEnabled, + boolean catalogServer, + boolean catalogServerEnabled, + boolean coordinatorSidecar, + boolean coordinatorSidecarEnabled, + boolean coordinator, + boolean skipLoadingResourceGroupConfigurationManager, + boolean loadDefaultSystemAccessControl, + Map properties, + String environment, + URI discoveryUri, + SqlParserOptions parserOptions, + List additionalModules, + Optional dataDirectory) + throws Exception { this.resourceManager = resourceManager; this.catalogServer = catalogServer; @@ -319,18 +361,9 @@ public TestingPrestoServer( .add(new NodeTtlFetcherManagerModule()) .add(new ClusterTtlProviderManagerModule()) .add(new ClientRequestFilterModule()) + .add(new TestingPrestoServerModule(loadDefaultSystemAccessControl)) .add(binder -> { - binder.bind(TestingAccessControlManager.class).in(Scopes.SINGLETON); - 
binder.bind(TestingEventListenerManager.class).in(Scopes.SINGLETON); - binder.bind(TestingTempStorageManager.class).in(Scopes.SINGLETON); - binder.bind(AccessControlManager.class).to(TestingAccessControlManager.class).in(Scopes.SINGLETON); - binder.bind(EventListenerManager.class).to(TestingEventListenerManager.class).in(Scopes.SINGLETON); - binder.bind(EventListenerConfig.class).in(Scopes.SINGLETON); - binder.bind(TempStorageManager.class).to(TestingTempStorageManager.class).in(Scopes.SINGLETON); - binder.bind(AccessControl.class).to(AccessControlManager.class).in(Scopes.SINGLETON); binder.bind(ShutdownAction.class).to(TestShutdownAction.class).in(Scopes.SINGLETON); - binder.bind(GracefulShutdownHandler.class).in(Scopes.SINGLETON); - binder.bind(ProcedureTester.class).in(Scopes.SINGLETON); binder.bind(RequestBlocker.class).in(Scopes.SINGLETON); newSetBinder(binder, Filter.class, TheServlet.class).addBinding() .to(RequestBlocker.class).in(Scopes.SINGLETON); @@ -368,12 +401,15 @@ public TestingPrestoServer( connectorManager = injector.getInstance(ConnectorManager.class); + functionAndTypeManager = injector.getInstance(FunctionAndTypeManager.class); + workerFunctionRegistryTool = injector.getInstance(WorkerFunctionRegistryTool.class); + server = injector.getInstance(TestingHttpServer.class); catalogManager = injector.getInstance(CatalogManager.class); transactionManager = injector.getInstance(TransactionManager.class); sqlParser = injector.getInstance(SqlParser.class); metadata = injector.getInstance(Metadata.class); - accessControl = injector.getInstance(TestingAccessControlManager.class); + accessControl = (TestingAccessControlManager) injector.getInstance(AccessControlManager.class); procedureTester = injector.getInstance(ProcedureTester.class); splitManager = injector.getInstance(SplitManager.class); pageSourceManager = injector.getInstance(PageSourceManager.class); @@ -502,6 +538,11 @@ private Map getServerProperties( return 
ImmutableMap.copyOf(serverProperties); } + public void registerWorkerFunctions() + { + functionAndTypeManager.registerWorkerFunctions(workerFunctionRegistryTool.getWorkerFunctions()); + } + @Override public void close() throws IOException @@ -532,11 +573,22 @@ public void installPlugin(Plugin plugin) pluginManager.installPlugin(plugin); } + public void registerWorkerAggregateFunctions(List aggregateFunctions) + { + functionAndTypeManager.registerWorkerAggregateFunctions(aggregateFunctions); + } + public void installCoordinatorPlugin(CoordinatorPlugin plugin) { pluginManager.installCoordinatorPlugin(plugin); } + public void triggerConflictCheckWithBuiltInFunctions() + { + metadata.getFunctionAndTypeManager() + .getBuiltInPluginFunctionNamespaceManager().triggerConflictCheckWithBuiltInFunctions(); + } + public DispatchManager getDispatchManager() { return dispatchManager; @@ -602,6 +654,11 @@ public HostAndPort getHttpsAddress() return HostAndPort.fromParts(httpsUri.getHost(), httpsUri.getPort()); } + public URI getHttpBaseUrl() + { + return server.getHttpServerInfo().getHttpUri(); + } + public CatalogManager getCatalogManager() { return catalogManager; @@ -669,6 +726,11 @@ public NodeManager getPluginNodeManager() return pluginNodeManager; } + public FunctionAndTypeManager getFunctionAndTypeManager() + { + return functionAndTypeManager; + } + public NodePartitioningManager getNodePartitioningManager() { return nodePartitioningManager; diff --git a/presto-main/src/main/java/com/facebook/presto/server/thrift/ThriftTaskUpdateRequestBodyReader.java b/presto-main/src/main/java/com/facebook/presto/server/thrift/ThriftTaskUpdateRequestBodyReader.java index 9bfe517a4feef..142ef0349bc9e 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/thrift/ThriftTaskUpdateRequestBodyReader.java +++ b/presto-main/src/main/java/com/facebook/presto/server/thrift/ThriftTaskUpdateRequestBodyReader.java @@ -17,14 +17,13 @@ import com.facebook.drift.codec.ThriftCodec; import 
com.facebook.presto.server.TaskUpdateRequest; import com.google.common.io.ByteStreams; - -import javax.inject.Inject; -import javax.ws.rs.Consumes; -import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.MultivaluedMap; -import javax.ws.rs.ext.MessageBodyReader; -import javax.ws.rs.ext.Provider; +import jakarta.inject.Inject; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.MultivaluedMap; +import jakarta.ws.rs.ext.MessageBodyReader; +import jakarta.ws.rs.ext.Provider; import java.io.IOException; import java.io.InputStream; diff --git a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java index 0f7fcf7501963..eb295a80fbade 100644 --- a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java +++ b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java @@ -16,6 +16,7 @@ import com.facebook.airlift.http.server.Authenticator; import com.facebook.presto.server.MockHttpServletRequest; import com.facebook.presto.server.security.AuthenticationFilter; +import com.facebook.presto.server.security.DefaultWebUiAuthenticationManager; import com.facebook.presto.server.security.SecurityConfig; import com.facebook.presto.server.testing.TestingPrestoServer; import com.facebook.presto.spi.ClientRequestFilter; @@ -24,10 +25,9 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import jakarta.servlet.http.HttpServletRequest; import org.testng.annotations.Test; -import javax.servlet.http.HttpServletRequest; - import java.security.Principal; import java.util.Collections; import java.util.List; @@ -112,7 +112,7 @@ private AuthenticationFilter setupAuthenticationFilter(List authenticators 
= createAuthenticators(); SecurityConfig securityConfig = createSecurityConfig(); - return new AuthenticationFilter(authenticators, securityConfig, clientRequestFilterManager); + return new AuthenticationFilter(authenticators, securityConfig, clientRequestFilterManager, new DefaultWebUiAuthenticationManager()); } } diff --git a/presto-main/src/test/java/com/facebook/presto/dispatcher/TestLocalDispatchQuery.java b/presto-main/src/test/java/com/facebook/presto/dispatcher/TestLocalDispatchQuery.java index 1b54dc1dea4f9..70b85cbcf36e7 100644 --- a/presto-main/src/test/java/com/facebook/presto/dispatcher/TestLocalDispatchQuery.java +++ b/presto-main/src/test/java/com/facebook/presto/dispatcher/TestLocalDispatchQuery.java @@ -14,6 +14,7 @@ package com.facebook.presto.dispatcher; import com.facebook.airlift.node.NodeInfo; +import com.facebook.airlift.units.Duration; import com.facebook.presto.Session; import com.facebook.presto.client.NodeVersion; import com.facebook.presto.cost.HistoryBasedOptimizationConfig; @@ -51,7 +52,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-main/src/test/java/com/facebook/presto/eventlistener/TestEventListenerManager.java b/presto-main/src/test/java/com/facebook/presto/eventlistener/TestEventListenerManager.java index cccb271166fb7..f6431d1e4cb40 100644 --- a/presto-main/src/test/java/com/facebook/presto/eventlistener/TestEventListenerManager.java +++ b/presto-main/src/test/java/com/facebook/presto/eventlistener/TestEventListenerManager.java @@ -15,15 +15,20 @@ package com.facebook.presto.eventlistener; import com.facebook.airlift.log.Logger; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.RuntimeStats; import com.facebook.presto.common.plan.PlanCanonicalizationStrategy; import 
com.facebook.presto.common.resourceGroups.QueryType; import com.facebook.presto.spi.PrestoWarning; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.analyzer.UpdateInfo; +import com.facebook.presto.spi.connector.ConnectorCommitHandle; import com.facebook.presto.spi.eventlistener.CTEInformation; import com.facebook.presto.spi.eventlistener.Column; import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.spi.eventlistener.EventListenerFactory; import com.facebook.presto.spi.eventlistener.OperatorStatistics; +import com.facebook.presto.spi.eventlistener.OutputColumnMetadata; import com.facebook.presto.spi.eventlistener.PlanOptimizerInformation; import com.facebook.presto.spi.eventlistener.QueryCompletedEvent; import com.facebook.presto.spi.eventlistener.QueryContext; @@ -45,7 +50,6 @@ import com.facebook.presto.spi.session.ResourceEstimates; import com.facebook.presto.spi.statistics.PlanStatisticsWithSourceInfo; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.io.File; @@ -185,7 +189,7 @@ public static QueryCompletedEvent createDummyQueryCompletedEvent() Optional prestoSparkExecutionContext = Optional.empty(); Map hboPlanHash = new HashMap<>(); Optional> planIdNodeMap = Optional.ofNullable(new HashMap<>()); - + UpdateInfo updateInfo = new UpdateInfo("dummy-type", "dummy-object"); return new QueryCompletedEvent( metadata, statistics, @@ -213,7 +217,8 @@ public static QueryCompletedEvent createDummyQueryCompletedEvent() windowFunctions, prestoSparkExecutionContext, hboPlanHash, - planIdNodeMap); + planIdNodeMap, + Optional.of(updateInfo.getUpdateObject())); } public static QueryStatistics createDummyQueryStatistics() @@ -221,6 +226,7 @@ public static QueryStatistics createDummyQueryStatistics() Duration cpuTime = Duration.ofMillis(1000); Duration retriedCpuTime = Duration.ofMillis(500); Duration wallTime = 
Duration.ofMillis(2000); + Duration totalScheduledTime = Duration.ofMillis(2500); Duration waitingForPrerequisitesTime = Duration.ofMillis(300); Duration queuedTime = Duration.ofMillis(1500); Duration waitingForResourcesTime = Duration.ofMillis(600); @@ -256,6 +262,7 @@ public static QueryStatistics createDummyQueryStatistics() cpuTime, retriedCpuTime, wallTime, + totalScheduledTime, waitingForPrerequisitesTime, queuedTime, waitingForResourcesTime, @@ -346,10 +353,10 @@ private static QueryContext createDummyQueryContext() sessionProperties.put("property2", "value2"); ResourceEstimates resourceEstimates = new ResourceEstimates( - Optional.of(new io.airlift.units.Duration(1200, TimeUnit.SECONDS)), - Optional.of(new io.airlift.units.Duration(1200, TimeUnit.SECONDS)), - Optional.of(new io.airlift.units.DataSize(2, DataSize.Unit.GIGABYTE)), - Optional.of(new io.airlift.units.DataSize(2, DataSize.Unit.GIGABYTE))); + Optional.of(new com.facebook.airlift.units.Duration(1200, TimeUnit.SECONDS)), + Optional.of(new com.facebook.airlift.units.Duration(1200, TimeUnit.SECONDS)), + Optional.of(new com.facebook.airlift.units.DataSize(2, DataSize.Unit.GIGABYTE)), + Optional.of(new com.facebook.airlift.units.DataSize(2, DataSize.Unit.GIGABYTE))); return new QueryContext( user, principal, @@ -374,18 +381,21 @@ private static QueryIOMetadata createDummyQueryIoMetadata() List inputs = new ArrayList<>(); QueryInputMetadata queryInputMetadata = getQueryInputMetadata(); inputs.add(queryInputMetadata); - Column column1 = new Column("column1", "int"); - Column column2 = new Column("column2", "varchar"); - Column column3 = new Column("column3", "varchar"); - List columns = Arrays.asList(column1, column2, column3); + OutputColumnMetadata column1 = new OutputColumnMetadata("column1", "int", new HashSet<>()); + OutputColumnMetadata column2 = new OutputColumnMetadata("column2", "varchar", new HashSet<>()); + OutputColumnMetadata column3 = new OutputColumnMetadata("column3", "varchar", new 
HashSet<>()); + List columns = new ArrayList<>(); + columns.add(column1); + columns.add(column2); + columns.add(column3); QueryOutputMetadata outputMetadata = new QueryOutputMetadata( "dummyCatalog", "dummySchema", "dummyTable", Optional.of("dummyConnectorMetadata"), Optional.of(true), - "dummySerializedCommitOutput", - Optional.of(columns)); + Optional.of(columns), + Optional.of(new TestCommitHandle("", "dummySerializedCommitOutput"))); return new QueryIOMetadata(inputs, Optional.of(outputMetadata)); } @@ -407,7 +417,7 @@ private static QueryInputMetadata getQueryInputMetadata() columns, connectorInfo, Optional.empty(), - serializedCommitOutput); + Optional.of(new TestCommitHandle(serializedCommitOutput, ""))); } private static SplitCompletedEvent createDummySplitCompletedEvent() @@ -467,6 +477,31 @@ private static void tryDeleteFile(Path path) } } + private static class TestCommitHandle + implements ConnectorCommitHandle + { + private final String readOutput; + private final String writeOutput; + + public TestCommitHandle(String readOutput, String writeOutput) + { + this.readOutput = requireNonNull(readOutput, "readOutput is null"); + this.writeOutput = requireNonNull(writeOutput, "writeOutput is null"); + } + + @Override + public String getSerializedCommitOutputForRead(SchemaTableName table) + { + return readOutput; + } + + @Override + public String getSerializedCommitOutputForWrite(SchemaTableName table) + { + return writeOutput; + } + } + private static class TestEventListenerFactory implements EventListenerFactory { diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestSqlTask.java b/presto-main/src/test/java/com/facebook/presto/execution/TestSqlTask.java index 08c5992dd1b7a..63630b7df2d01 100644 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestSqlTask.java +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestSqlTask.java @@ -15,6 +15,7 @@ import com.facebook.airlift.stats.CounterStat; import 
com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.execution.TestSqlTaskManager.MockExchangeClientSupplier; import com.facebook.presto.execution.buffer.BufferInfo; @@ -40,7 +41,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.units.DataSize; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; @@ -52,6 +52,8 @@ import static com.facebook.airlift.concurrent.Threads.threadsNamed; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.execution.SqlTask.createSqlTask; import static com.facebook.presto.execution.TaskManagerConfig.TaskPriorityTracking.TASK_FAIR; @@ -65,8 +67,6 @@ import static com.facebook.presto.execution.buffer.OutputBuffers.BufferType.PARTITIONED; import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.SECONDS; import static org.testng.Assert.assertEquals; diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestTaskWithConnectorTypeSerde.java b/presto-main/src/test/java/com/facebook/presto/execution/TestTaskWithConnectorTypeSerde.java deleted file mode 100644 index 335c0b6d1c7f5..0000000000000 --- 
a/presto-main/src/test/java/com/facebook/presto/execution/TestTaskWithConnectorTypeSerde.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.execution; - -import com.facebook.drift.codec.ThriftCodec; -import com.facebook.drift.codec.ThriftCodecManager; -import com.facebook.drift.codec.internal.compiler.CompilerThriftCodecFactory; -import com.facebook.drift.codec.internal.reflection.ReflectionThriftCodecFactory; -import com.facebook.drift.protocol.TBinaryProtocol; -import com.facebook.drift.protocol.TMemoryBuffer; -import com.facebook.drift.protocol.TProtocol; -import com.facebook.drift.protocol.TTransport; -import com.facebook.drift.transport.netty.codec.Protocol; -import com.facebook.presto.metadata.HandleResolver; -import com.facebook.presto.server.ConnectorMetadataUpdateHandleJsonSerde; -import com.facebook.presto.server.thrift.Any; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.ConnectorTypeSerde; -import com.facebook.presto.testing.TestingHandleResolver; -import com.facebook.presto.testing.TestingMetadataUpdateHandle; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; - -import java.util.function.Function; - -import static org.testng.Assert.assertEquals; - -@Test(singleThreaded = true) -public class TestTaskWithConnectorTypeSerde -{ - private static final 
ThriftCodecManager COMPILER_CODEC_MANAGER = new ThriftCodecManager(new CompilerThriftCodecFactory(false)); - private static final ThriftCodecManager REFLECTION_CODEC_MANAGER = new ThriftCodecManager(new ReflectionThriftCodecFactory()); - private static final TMemoryBuffer transport = new TMemoryBuffer(100 * 1024); - private static HandleResolver handleResolver; - private ConnectorMetadataUpdateHandleJsonSerde connectorMetadataUpdateHandleJsonSerde; - - @BeforeMethod - public void setUp() - { - handleResolver = getHandleResolver(); - connectorMetadataUpdateHandleJsonSerde = new ConnectorMetadataUpdateHandleJsonSerde(); - } - - @DataProvider - public Object[][] codecCombinations() - { - return new Object[][] { - {COMPILER_CODEC_MANAGER, COMPILER_CODEC_MANAGER}, - {COMPILER_CODEC_MANAGER, REFLECTION_CODEC_MANAGER}, - {REFLECTION_CODEC_MANAGER, COMPILER_CODEC_MANAGER}, - {REFLECTION_CODEC_MANAGER, REFLECTION_CODEC_MANAGER} - }; - } - - @Test(dataProvider = "codecCombinations") - public void testRoundTripSerializeBinaryProtocol(ThriftCodecManager readCodecManager, ThriftCodecManager writeCodecManager) - throws Exception - { - TaskWithConnectorType taskDummy = getRoundTripSerialize(readCodecManager, writeCodecManager, TBinaryProtocol::new); - assertSerde(taskDummy, readCodecManager); - } - - @Test(dataProvider = "codecCombinations") - public void testRoundTripSerializeTCompactProtocol(ThriftCodecManager readCodecManager, ThriftCodecManager writeCodecManager) - throws Exception - { - TaskWithConnectorType taskDummy = getRoundTripSerialize(readCodecManager, writeCodecManager, TBinaryProtocol::new); - assertSerde(taskDummy, readCodecManager); - } - - @Test(dataProvider = "codecCombinations") - public void testRoundTripSerializeTFacebookCompactProtocol(ThriftCodecManager readCodecManager, ThriftCodecManager writeCodecManager) - throws Exception - { - TaskWithConnectorType taskDummy = getRoundTripSerialize(readCodecManager, writeCodecManager, TBinaryProtocol::new); - 
assertSerde(taskDummy, readCodecManager); - } - - @Test - public void testJsonSerdeRoundTrip() - { - TestingMetadataUpdateHandle metadataUpdateHandle = new TestingMetadataUpdateHandle(200); - byte[] serialized = connectorMetadataUpdateHandleJsonSerde.serialize(metadataUpdateHandle); - TestingMetadataUpdateHandle roundTripMetadataUpdateHandle = (TestingMetadataUpdateHandle) connectorMetadataUpdateHandleJsonSerde.deserialize(TestingMetadataUpdateHandle.class, serialized); - assertEquals(roundTripMetadataUpdateHandle.getValue(), metadataUpdateHandle.getValue()); - } - - private void assertSerde(TaskWithConnectorType taskWithConnectorType, ThriftCodecManager readCodecManager) - { - assertEquals(100, taskWithConnectorType.getValue()); - Any connectorMetadataUpdateHandleAny = taskWithConnectorType.getConnectorMetadataUpdateHandleAny(); - TestingMetadataUpdateHandle connectorMetadataUpdateHandle = (TestingMetadataUpdateHandle) getConnectorSerde(readCodecManager) - .deserialize(handleResolver.getMetadataUpdateHandleClass(connectorMetadataUpdateHandleAny.getId()), - connectorMetadataUpdateHandleAny.getBytes()); - assertEquals(200, connectorMetadataUpdateHandle.getValue()); - } - - private TaskWithConnectorType getRoundTripSerialize(ThriftCodecManager readCodecManager, ThriftCodecManager writeCodecManager, Function protocolFactory) - throws Exception - { - TProtocol protocol = protocolFactory.apply(transport); - ThriftCodec writeCodec = writeCodecManager.getCodec(TaskWithConnectorType.class); - writeCodec.write(getTaskDummy(writeCodecManager), protocol); - ThriftCodec readCodec = readCodecManager.getCodec(TaskWithConnectorType.class); - return readCodec.read(protocol); - } - - private TaskWithConnectorType getTaskDummy(ThriftCodecManager thriftCodecManager) - { - //Connector specific type - TestingMetadataUpdateHandle metadataUpdateHandle = new TestingMetadataUpdateHandle(200); - byte[] serialized = getConnectorSerde(thriftCodecManager).serialize(metadataUpdateHandle); - 
String id = handleResolver.getId(metadataUpdateHandle); - Any any = new Any(id, serialized); - return new TaskWithConnectorType(100, any); - } - - private static ConnectorTypeSerde getConnectorSerde(ThriftCodecManager thriftCodecManager) - { - return new TestingMetadataUpdateHandleSerde(thriftCodecManager, Protocol.BINARY, 128); - } - - private HandleResolver getHandleResolver() - { - HandleResolver handleResolver = new HandleResolver(); - //Register Connector - handleResolver.addConnectorName("test", new TestingHandleResolver()); - return handleResolver; - } -} diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestingMetadataUpdateHandleSerde.java b/presto-main/src/test/java/com/facebook/presto/execution/TestingMetadataUpdateHandleSerde.java deleted file mode 100644 index a43f9160bfc14..0000000000000 --- a/presto-main/src/test/java/com/facebook/presto/execution/TestingMetadataUpdateHandleSerde.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.execution; - -import com.facebook.airlift.http.client.thrift.ThriftProtocolException; -import com.facebook.airlift.http.client.thrift.ThriftProtocolUtils; -import com.facebook.drift.codec.ThriftCodec; -import com.facebook.drift.codec.ThriftCodecManager; -import com.facebook.drift.transport.netty.codec.Protocol; -import com.facebook.presto.spi.ConnectorMetadataUpdateHandle; -import com.facebook.presto.spi.ConnectorTypeSerde; -import io.airlift.slice.DynamicSliceOutput; -import io.airlift.slice.SliceOutput; -import io.airlift.slice.Slices; - -import static java.lang.String.format; -import static java.util.Objects.requireNonNull; - -public class TestingMetadataUpdateHandleSerde - implements ConnectorTypeSerde -{ - private final ThriftCodecManager thriftCodecManager; - private final Protocol thriftProtocol; - private final int bufferSize; - - public TestingMetadataUpdateHandleSerde( - ThriftCodecManager thriftCodecManager, - Protocol thriftProtocol, - int bufferSize) - { - this.thriftCodecManager = requireNonNull(thriftCodecManager, "thriftCodecManager is null"); - this.thriftProtocol = requireNonNull(thriftProtocol, "thriftProtocol is null"); - this.bufferSize = bufferSize; - } - - @Override - public byte[] serialize(ConnectorMetadataUpdateHandle value) - { - ThriftCodec codec = thriftCodecManager.getCodec(value.getClass()); - SliceOutput dynamicSliceOutput = new DynamicSliceOutput(bufferSize); - try { - ThriftProtocolUtils.write(value, codec, thriftProtocol, dynamicSliceOutput); - return dynamicSliceOutput.slice().getBytes(); - } - catch (ThriftProtocolException e) { - throw new IllegalArgumentException(format("%s could not be converted to Thrift", value.getClass().getName()), e); - } - } - - @Override - public ConnectorMetadataUpdateHandle deserialize(Class connectorTypeClass, byte[] bytes) - { - try { - ThriftCodec codec = thriftCodecManager.getCodec(connectorTypeClass); - return ThriftProtocolUtils.read(codec, thriftProtocol, 
Slices.wrappedBuffer(bytes).getInput()); - } - catch (ThriftProtocolException e) { - throw new IllegalArgumentException(format("Invalid Thrift bytes for %s", connectorTypeClass), e); - } - } -} diff --git a/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java b/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java index 3c07fc83cd5c5..88814b2bc2363 100644 --- a/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java +++ b/presto-main/src/test/java/com/facebook/presto/failureDetector/TestHeartbeatFailureDetector.java @@ -32,11 +32,10 @@ import com.google.inject.Injector; import com.google.inject.Key; import com.google.inject.Module; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; import org.testng.annotations.Test; -import javax.ws.rs.GET; -import javax.ws.rs.Path; - import java.net.SocketTimeoutException; import java.net.URI; diff --git a/presto-main/src/test/java/com/facebook/presto/memory/TestHighMemoryTaskKiller.java b/presto-main/src/test/java/com/facebook/presto/memory/TestHighMemoryTaskKiller.java index c809c5de26445..36fc388e00ace 100644 --- a/presto-main/src/test/java/com/facebook/presto/memory/TestHighMemoryTaskKiller.java +++ b/presto-main/src/test/java/com/facebook/presto/memory/TestHighMemoryTaskKiller.java @@ -15,6 +15,7 @@ import com.facebook.airlift.stats.CounterStat; import com.facebook.airlift.stats.TestingGcMonitor; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.execution.SqlTask; import com.facebook.presto.execution.SqlTaskExecutionFactory; @@ -40,7 +41,6 @@ import com.google.common.base.Ticker; import com.google.common.collect.ImmutableList; import com.google.common.collect.ListMultimap; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.net.URI; @@ -52,6 +52,8 @@ import static 
com.facebook.airlift.concurrent.Threads.threadsNamed; import static com.facebook.airlift.json.JsonCodec.listJsonCodec; +import static com.facebook.airlift.units.DataSize.Unit.GIGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.execution.SqlTask.createSqlTask; import static com.facebook.presto.execution.TaskManagerConfig.TaskPriorityTracking.TASK_FAIR; @@ -62,8 +64,6 @@ import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; diff --git a/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java b/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java index b2ae47bdba249..84c9b438e0e88 100644 --- a/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java +++ b/presto-main/src/test/java/com/facebook/presto/metadata/TestDiscoveryNodeManager.java @@ -28,11 +28,10 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import javax.annotation.concurrent.GuardedBy; - import java.net.URI; import java.util.List; import java.util.Optional; @@ -170,7 +169,7 @@ public void testGetAllNodesForWorkerNode() AllNodes allNodes = manager.getAllNodes(); Set activeNodes = 
allNodes.getActiveNodes(); - assertEqualsIgnoreOrder(activeNodes, ImmutableSet.of(resourceManager, catalogServer, coordinatorSidecar)); + assertEqualsIgnoreOrder(activeNodes, ImmutableSet.of(resourceManager, catalogServer)); for (InternalNode actual : activeNodes) { for (InternalNode expected : this.activeNodes) { @@ -181,7 +180,7 @@ public void testGetAllNodesForWorkerNode() assertEqualsIgnoreOrder(activeNodes, manager.getNodes(ACTIVE)); Set inactiveNodes = allNodes.getInactiveNodes(); - assertEqualsIgnoreOrder(inactiveNodes, ImmutableSet.of(inActiveResourceManager, inActiveCatalogServer, inActiveCoordinatorSidecar)); + assertEqualsIgnoreOrder(inactiveNodes, ImmutableSet.of(inActiveResourceManager, inActiveCatalogServer)); for (InternalNode actual : inactiveNodes) { for (InternalNode expected : this.inactiveNodes) { @@ -272,7 +271,7 @@ public void testGetAllNodesForCoordinatorSidecar() AllNodes allNodes = manager.getAllNodes(); Set activeNodes = allNodes.getActiveNodes(); - assertEqualsIgnoreOrder(activeNodes, this.activeNodes); + assertEqualsIgnoreOrder(activeNodes, ImmutableSet.of(resourceManager, catalogServer)); for (InternalNode actual : activeNodes) { for (InternalNode expected : this.activeNodes) { @@ -283,7 +282,7 @@ public void testGetAllNodesForCoordinatorSidecar() assertEqualsIgnoreOrder(activeNodes, manager.getNodes(ACTIVE)); Set inactiveNodes = allNodes.getInactiveNodes(); - assertEqualsIgnoreOrder(inactiveNodes, this.inactiveNodes); + assertEqualsIgnoreOrder(inactiveNodes, ImmutableSet.of(inActiveResourceManager, inActiveCatalogServer)); for (InternalNode actual : inactiveNodes) { for (InternalNode expected : this.inactiveNodes) { @@ -299,11 +298,15 @@ public void testGetAllNodesForCoordinatorSidecar() } @Test - public void testGetCurrentNode() + public void testNodesVisibleToWorkerNode() { DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, workerNodeInfo, new NoOpFailureDetector(), Optional.empty(), expectedVersion, testHttpClient, 
new TestingDriftClient<>(), internalCommunicationConfig); try { assertEquals(manager.getCurrentNode(), workerNode1); + assertEquals(manager.getCatalogServers(), ImmutableSet.of(catalogServer)); + assertEquals(manager.getResourceManagers(), ImmutableSet.of(resourceManager)); + assertEquals(manager.getCoordinatorSidecars(), ImmutableSet.of()); + assertEquals(manager.getCoordinators(), ImmutableSet.of()); } finally { manager.stop(); @@ -347,11 +350,14 @@ public void testGetCatalogServers() } @Test - public void testGetCoordinatorSidecar() + public void testNodesVisibleToCoordinatorSidecar() { DiscoveryNodeManager manager = new DiscoveryNodeManager(selector, coordinatorSidecarNodeInfo, new NoOpFailureDetector(), Optional.of(host -> false), expectedVersion, testHttpClient, new TestingDriftClient<>(), internalCommunicationConfig); try { - assertEquals(manager.getCoordinatorSidecars(), ImmutableSet.of(coordinatorSidecar)); + assertEquals(manager.getCatalogServers(), ImmutableSet.of(catalogServer)); + assertEquals(manager.getResourceManagers(), ImmutableSet.of(resourceManager)); + assertEquals(manager.getCoordinatorSidecars(), ImmutableSet.of()); + assertEquals(manager.getCoordinators(), ImmutableSet.of()); } finally { manager.stop(); diff --git a/presto-main/src/test/java/com/facebook/presto/operator/MockExchangeRequestProcessor.java b/presto-main/src/test/java/com/facebook/presto/operator/MockExchangeRequestProcessor.java index 2b1eb3daecb13..d9a42d0967f40 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/MockExchangeRequestProcessor.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/MockExchangeRequestProcessor.java @@ -18,6 +18,7 @@ import com.facebook.airlift.http.client.Response; import com.facebook.airlift.http.client.testing.TestingHttpClient; import com.facebook.airlift.http.client.testing.TestingResponse; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.client.PrestoHeaders; import 
com.facebook.presto.common.Page; import com.facebook.presto.execution.buffer.BufferResult; @@ -29,7 +30,6 @@ import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableListMultimap; import io.airlift.slice.DynamicSliceOutput; -import io.airlift.units.DataSize; import java.net.URI; import java.util.ArrayList; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestExchangeClient.java b/presto-main/src/test/java/com/facebook/presto/operator/TestExchangeClient.java index ea201f2d9a451..c6d23fc85a323 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestExchangeClient.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestExchangeClient.java @@ -16,6 +16,8 @@ import com.facebook.airlift.http.client.Request; import com.facebook.airlift.http.client.Response; import com.facebook.airlift.http.client.testing.TestingHttpClient; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.block.BlockAssertions; import com.facebook.presto.common.Page; import com.facebook.presto.execution.TaskId; @@ -27,8 +29,6 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.UncheckedTimeoutException; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -48,6 +48,8 @@ import static com.facebook.airlift.concurrent.MoreFutures.tryGetFutureValue; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; import static com.facebook.airlift.testing.Assertions.assertLessThan; +import static com.facebook.airlift.units.DataSize.Unit.BYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.common.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static 
com.facebook.presto.execution.buffer.TestingPagesSerdeFactory.testingPagesSerde; import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; @@ -55,8 +57,6 @@ import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static com.google.common.util.concurrent.Uninterruptibles.awaitUninterruptibly; import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.MILLISECONDS; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestExchangeOperator.java b/presto-main/src/test/java/com/facebook/presto/operator/TestExchangeOperator.java index e17171967246f..458d2cb19311e 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestExchangeOperator.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestExchangeOperator.java @@ -15,6 +15,8 @@ import com.facebook.airlift.http.client.HttpClient; import com.facebook.airlift.http.client.testing.TestingHttpClient; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.CompressionCodec; import com.facebook.presto.common.Page; import com.facebook.presto.common.type.Type; @@ -32,8 +34,6 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.BeforeMethod; @@ -47,13 +47,13 @@ import java.util.concurrent.TimeUnit; import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; 
+import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; import static com.facebook.presto.operator.PageAssertions.assertPageEquals; import static com.facebook.presto.operator.TestingTaskBuffer.PAGE; import static com.facebook.presto.testing.TestingTaskContext.createTaskContext; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static java.util.concurrent.Executors.newScheduledThreadPool; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/TestPageBufferClient.java b/presto-main/src/test/java/com/facebook/presto/operator/TestPageBufferClient.java index 8a4dc445561c8..c0a3f829cc147 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestPageBufferClient.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestPageBufferClient.java @@ -19,15 +19,15 @@ import com.facebook.airlift.http.client.testing.TestingHttpClient; import com.facebook.airlift.http.client.testing.TestingResponse; import com.facebook.airlift.testing.TestingTicker; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.DataSize.Unit; +import com.facebook.airlift.units.Duration; import com.facebook.presto.common.Page; import com.facebook.presto.operator.PageBufferClient.ClientCallback; import com.facebook.presto.spi.HostAddress; import com.facebook.presto.spi.page.PagesSerde; import com.facebook.presto.spi.page.SerializedPage; import com.google.common.collect.ImmutableListMultimap; -import io.airlift.units.DataSize; -import io.airlift.units.DataSize.Unit; -import io.airlift.units.Duration; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; diff 
--git a/presto-main/src/test/java/com/facebook/presto/operator/TestingExchangeHttpClientHandler.java b/presto-main/src/test/java/com/facebook/presto/operator/TestingExchangeHttpClientHandler.java index 3636a68f77b26..5e10d115a3cef 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/TestingExchangeHttpClientHandler.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/TestingExchangeHttpClientHandler.java @@ -33,8 +33,8 @@ import static com.facebook.presto.client.PrestoHeaders.PRESTO_PAGE_TOKEN; import static com.facebook.presto.client.PrestoHeaders.PRESTO_TASK_INSTANCE_ID; import static com.facebook.presto.execution.buffer.TestingPagesSerdeFactory.testingPagesSerde; +import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; import static java.util.Objects.requireNonNull; -import static javax.ws.rs.core.HttpHeaders.CONTENT_TYPE; import static org.testng.Assert.assertEquals; public class TestingExchangeHttpClientHandler diff --git a/presto-main/src/test/java/com/facebook/presto/remotetask/TestReactorNettyHttpClient.java b/presto-main/src/test/java/com/facebook/presto/remotetask/TestReactorNettyHttpClient.java new file mode 100644 index 0000000000000..061615cf37fcd --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/remotetask/TestReactorNettyHttpClient.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.remotetask; + +import com.facebook.airlift.http.client.HttpClient.HttpResponseFuture; +import com.facebook.airlift.http.client.Request; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.airlift.units.Duration; +import com.facebook.presto.execution.TaskStatus; +import com.facebook.presto.server.remotetask.HttpClientConnectionPoolStats; +import com.facebook.presto.server.remotetask.HttpClientStats; +import com.facebook.presto.server.remotetask.ReactorNettyHttpClient; +import com.facebook.presto.server.remotetask.ReactorNettyHttpClientConfig; +import com.facebook.presto.server.smile.AdaptingJsonResponseHandler; +import com.facebook.presto.server.smile.JsonResponseWrapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.HttpProtocol; +import reactor.netty.http.server.HttpServer; + +import java.net.URI; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import static com.facebook.airlift.http.client.HttpUriBuilder.uriBuilderFrom; +import static com.facebook.airlift.http.client.Request.Builder.prepareGet; +import static com.facebook.presto.execution.TaskState.PLANNED; +import static com.facebook.presto.server.RequestHelpers.getJsonTransportBuilder; +import static com.facebook.presto.server.smile.AdaptingJsonResponseHandler.createAdaptingJsonResponseHandler; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; + +public class TestReactorNettyHttpClient +{ + private static DisposableServer server; + private static ReactorNettyHttpClient reactorNettyHttpClient; + private static 
JsonCodec taskStatusCodec; + + private static final TaskStatus TEST_TASK_STATUS_HTTP11 = new TaskStatus( + 11111L, + 99999L, + 123, + PLANNED, + URI.create("http://localhost:8080/v1/task/1234"), + ImmutableSet.of(), + ImmutableList.of(), + 0, + 0, + 0.0, + false, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0L, + 0L); + + @BeforeClass + public static void setUp() + { + taskStatusCodec = new JsonCodecFactory(new JsonObjectMapperProvider()).jsonCodec(TaskStatus.class); + server = HttpServer.create() + .port(8080) + .protocol(HttpProtocol.HTTP11) + .route(routes -> + routes.get("/v1/task/1234/status", (request, response) -> + response.header("Content-Type", "application/json").sendString(Mono.just(taskStatusCodec.toJson(TEST_TASK_STATUS_HTTP11))))) + .bindNow(); + ReactorNettyHttpClientConfig reactorNettyHttpClientConfig = new ReactorNettyHttpClientConfig() + .setRequestTimeout(new Duration(30, TimeUnit.SECONDS)) + .setConnectTimeout(new Duration(30, TimeUnit.SECONDS)); + reactorNettyHttpClient = new ReactorNettyHttpClient(reactorNettyHttpClientConfig, new HttpClientConnectionPoolStats(), new HttpClientStats()); + } + + @AfterClass + public static void tearDown() + { + server.disposeNow(); + } + + @Test + public void testGetTaskStatus() + throws ExecutionException, InterruptedException + { + AdaptingJsonResponseHandler responseHandler = createAdaptingJsonResponseHandler(taskStatusCodec); + Request request = getJsonTransportBuilder(prepareGet()).setUri(uriBuilderFrom(TEST_TASK_STATUS_HTTP11.getSelf()).appendPath("status").build()).build(); + + HttpResponseFuture response = reactorNettyHttpClient.executeAsync(request, responseHandler); + TaskStatus a = (TaskStatus) ((JsonResponseWrapper) response.get()).getValue(); + + assertEquals(a.getState(), TEST_TASK_STATUS_HTTP11.getState()); + assertEquals(a.getTaskInstanceIdLeastSignificantBits(), TEST_TASK_STATUS_HTTP11.getTaskInstanceIdLeastSignificantBits()); + assertEquals(a.getTaskInstanceIdMostSignificantBits(), 
TEST_TASK_STATUS_HTTP11.getTaskInstanceIdMostSignificantBits()); + assertEquals(a.getVersion(), TEST_TASK_STATUS_HTTP11.getVersion()); + assertEquals(a.getSelf(), TEST_TASK_STATUS_HTTP11.getSelf()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/remotetask/TestReactorNettyHttpClientConfig.java b/presto-main/src/test/java/com/facebook/presto/remotetask/TestReactorNettyHttpClientConfig.java new file mode 100644 index 0000000000000..7ba1311d25c11 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/remotetask/TestReactorNettyHttpClientConfig.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.remotetask; + +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; +import com.facebook.presto.server.remotetask.ReactorNettyHttpClientConfig; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static com.facebook.airlift.units.DataSize.Unit.KILOBYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class TestReactorNettyHttpClientConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(ReactorNettyHttpClientConfig.class) + .setReactorNettyHttpClientEnabled(false) + .setHttpsEnabled(false) + .setMinConnections(50) + .setMaxConnections(100) + .setMaxStreamPerChannel(100) + .setSelectorThreadCount(Runtime.getRuntime().availableProcessors()) + .setEventLoopThreadCount(Runtime.getRuntime().availableProcessors()) + .setConnectTimeout(new Duration(10, SECONDS)) + .setRequestTimeout(new Duration(10, SECONDS)) + .setMaxIdleTime(new Duration(0, SECONDS)) + .setEvictBackgroundTime(new Duration(0, SECONDS)) + .setPendingAcquireTimeout(new Duration(0, SECONDS)) + .setMaxInitialWindowSize(new DataSize(0, MEGABYTE)) + .setMaxFrameSize(new DataSize(0, MEGABYTE)) + .setKeyStorePath(null) + .setKeyStorePassword(null) + .setTrustStorePath(null) + .setCipherSuites(null) + .setHttp2CompressionEnabled(false) + .setPayloadSizeThreshold(new DataSize(50, KILOBYTE)) + .setCompressionSavingThreshold(0.1) + .setTcpBufferSize(new DataSize(512, KILOBYTE)) + .setWriteBufferWaterMarkHigh(new DataSize(512, KILOBYTE)) + .setWriteBufferWaterMarkLow(new DataSize(256, 
KILOBYTE)) + .setHttp2ConnectionPoolStatsTrackingEnabled(false) + .setHttp2ClientStatsTrackingEnabled(false) + .setChannelOptionSoKeepAliveEnabled(true) + .setChannelOptionTcpNoDelayEnabled(true)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("reactor.netty-http-client-enabled", "true") + .put("reactor.https-enabled", "true") + .put("reactor.min-connections", "100") + .put("reactor.max-connections", "500") + .put("reactor.max-stream-per-channel", "300") + .put("reactor.selector-thread-count", "50") + .put("reactor.event-loop-thread-count", "150") + .put("reactor.connect-timeout", "2s") + .put("reactor.request-timeout", "1s") + .put("reactor.max-idle-time", "120s") + .put("reactor.evict-background-time", "120s") + .put("reactor.pending-acquire-timeout", "10s") + .put("reactor.max-initial-window-size", "10MB") + .put("reactor.max-frame-size", "4MB") + .put("reactor.keystore-path", "/var/abc/def/presto.jks") + .put("reactor.truststore-path", "/var/abc/def/presto.jks") + .put("reactor.keystore-password", "password") + .put("reactor.cipher-suites", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") + .put("reactor.enable-http2-compression", "true") + .put("reactor.payload-compression-threshold", "10kB") + .put("reactor.compression-ratio-threshold", "0.2") + .put("reactor.tcp-buffer-size", "256kB") + .put("reactor.tcp-write-buffer-water-mark-high", "256kB") + .put("reactor.tcp-write-buffer-water-mark-low", "128kB") + .put("reactor.enable-http2-connection-pool-stats-tracking", "true") + .put("reactor.enable-http2-client-stats-tracking", "true") + .put("reactor.channel-option-so-keep-alive", "false") + .put("reactor.channel-option-tcp-no-delay", "false") + .build(); + + ReactorNettyHttpClientConfig expected = new ReactorNettyHttpClientConfig() + .setReactorNettyHttpClientEnabled(true) + .setHttpsEnabled(true) + .setMinConnections(100) + .setMaxConnections(500) + 
.setMaxStreamPerChannel(300) + .setSelectorThreadCount(50) + .setEventLoopThreadCount(150) + .setConnectTimeout(new Duration(2, SECONDS)) + .setRequestTimeout(new Duration(1, SECONDS)) + .setMaxIdleTime(new Duration(120, SECONDS)) + .setEvictBackgroundTime(new Duration(120, SECONDS)) + .setPendingAcquireTimeout(new Duration(10, SECONDS)) + .setMaxInitialWindowSize(new DataSize(10, MEGABYTE)) // 10MB + .setMaxFrameSize(new DataSize(4, MEGABYTE)) // 4MB + .setKeyStorePath("/var/abc/def/presto.jks") + .setTrustStorePath("/var/abc/def/presto.jks") + .setKeyStorePassword("password") + .setCipherSuites("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") + .setHttp2CompressionEnabled(true) + .setPayloadSizeThreshold(new DataSize(10, KILOBYTE)) + .setCompressionSavingThreshold(0.2) + .setTcpBufferSize(new DataSize(256, KILOBYTE)) + .setWriteBufferWaterMarkHigh(new DataSize(256, KILOBYTE)) + .setWriteBufferWaterMarkLow(new DataSize(128, KILOBYTE)) + .setHttp2ConnectionPoolStatsTrackingEnabled(true) + .setHttp2ClientStatsTrackingEnabled(true) + .setChannelOptionSoKeepAliveEnabled(false) + .setChannelOptionTcpNoDelayEnabled(false); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStatusSender.java b/presto-main/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStatusSender.java similarity index 97% rename from presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStatusSender.java rename to presto-main/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStatusSender.java index ab28b1bcc895a..d0ba55a09151d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStatusSender.java +++ b/presto-main/src/test/java/com/facebook/presto/resourcemanager/TestResourceManagerClusterStatusSender.java @@ -13,6 +13,8 
@@ */ package com.facebook.presto.resourcemanager; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.presto.client.NodeVersion; import com.facebook.presto.execution.MockManagedQueryExecution; import com.facebook.presto.execution.resourceGroups.NoOpResourceGroupManager; @@ -23,8 +25,6 @@ import com.facebook.presto.server.ServerConfig; import com.facebook.presto.spi.ConnectorId; import com.google.common.collect.ImmutableMap; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; @@ -32,7 +32,7 @@ import java.net.URI; import java.util.OptionalInt; -import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static java.lang.String.format; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; import static java.util.concurrent.TimeUnit.MILLISECONDS; diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/MockContainerRequestContext.java b/presto-main/src/test/java/com/facebook/presto/server/MockContainerRequestContext.java similarity index 79% rename from presto-main-base/src/test/java/com/facebook/presto/server/MockContainerRequestContext.java rename to presto-main/src/test/java/com/facebook/presto/server/MockContainerRequestContext.java index 1d0fee4e5a8ea..600dfea77e704 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/MockContainerRequestContext.java +++ b/presto-main/src/test/java/com/facebook/presto/server/MockContainerRequestContext.java @@ -14,15 +14,14 @@ package com.facebook.presto.server; import com.google.common.collect.ListMultimap; - -import javax.ws.rs.container.ContainerRequestContext; -import javax.ws.rs.core.Cookie; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.MultivaluedMap; -import javax.ws.rs.core.Request; -import 
javax.ws.rs.core.Response; -import javax.ws.rs.core.SecurityContext; -import javax.ws.rs.core.UriInfo; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.core.Cookie; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.MultivaluedMap; +import jakarta.ws.rs.core.Request; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.SecurityContext; +import jakarta.ws.rs.core.UriInfo; import java.io.InputStream; import java.net.URI; @@ -31,6 +30,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.Predicate; public class MockContainerRequestContext implements ContainerRequestContext @@ -50,6 +50,12 @@ public Object getProperty(String name) return null; } + @Override + public boolean hasProperty(String name) + { + return ContainerRequestContext.super.hasProperty(name); + } + @Override public Collection getPropertyNames() { @@ -105,6 +111,18 @@ public String getHeaderString(String name) return null; } + @Override + public boolean containsHeaderString(String s, String s1, Predicate predicate) + { + return headers.containsKey(s) && headers.get(s).stream().anyMatch(predicate); + } + + @Override + public boolean containsHeaderString(String name, Predicate valuePredicate) + { + return ContainerRequestContext.super.containsHeaderString(name, valuePredicate); + } + @Override public Date getDate() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java b/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java similarity index 84% rename from presto-main-base/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java rename to presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java index 000eda9b853e0..2033b62f88262 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java +++ 
b/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java @@ -16,22 +16,24 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ListMultimap; - -import javax.servlet.AsyncContext; -import javax.servlet.DispatcherType; -import javax.servlet.RequestDispatcher; -import javax.servlet.ServletContext; -import javax.servlet.ServletInputStream; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; -import javax.servlet.http.HttpUpgradeHandler; -import javax.servlet.http.Part; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.DispatcherType; +import jakarta.servlet.RequestDispatcher; +import jakarta.servlet.ServletConnection; +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.servlet.http.HttpSession; +import jakarta.servlet.http.HttpUpgradeHandler; +import jakarta.servlet.http.Part; +import jakarta.ws.rs.core.UriBuilder; import java.io.BufferedReader; +import java.net.URI; import java.security.Principal; import java.util.Collection; import java.util.Enumeration; @@ -50,12 +52,14 @@ public class MockHttpServletRequest private final ListMultimap headers; private final String remoteAddress; private final Map attributes; + private final String requestUrl; public MockHttpServletRequest(ListMultimap headers, String remoteAddress, Map attributes) { this.headers = ImmutableListMultimap.copyOf(requireNonNull(headers, "headers is null")); this.remoteAddress = requireNonNull(remoteAddress, "remoteAddress is 
null"); this.attributes = new HashMap<>(requireNonNull(attributes, "attributes is null")); + this.requestUrl = null; } public MockHttpServletRequest(ListMultimap headers) @@ -64,6 +68,14 @@ public MockHttpServletRequest(ListMultimap headers) this(headers, DEFAULT_ADDRESS, ImmutableMap.of()); } + public MockHttpServletRequest(ListMultimap headers, String remoteAddress, String requestUrl) + { + this.headers = ImmutableListMultimap.copyOf(requireNonNull(headers, "headers is null")); + this.remoteAddress = requireNonNull(remoteAddress, "remoteAddress is null"); + this.requestUrl = requireNonNull(requestUrl, "requestUrl is null"); + this.attributes = ImmutableMap.of(); + } + @Override public String getAuthType() { @@ -145,7 +157,12 @@ public String getContextPath() @Override public String getQueryString() { - throw new UnsupportedOperationException(); + if (this.requestUrl == null) { + throw new UnsupportedOperationException(); + } + URI uri = UriBuilder.fromUri(this.requestUrl).build(); + + return uri.getQuery(); } @Override @@ -181,7 +198,10 @@ public String getRequestURI() @Override public StringBuffer getRequestURL() { - throw new UnsupportedOperationException(); + if (this.requestUrl == null) { + throw new UnsupportedOperationException(); + } + return new StringBuffer(this.requestUrl); } @Override @@ -226,12 +246,6 @@ public boolean isRequestedSessionIdFromURL() throw new UnsupportedOperationException(); } - @Override - public boolean isRequestedSessionIdFromUrl() - { - throw new UnsupportedOperationException(); - } - @Override public boolean authenticate(HttpServletResponse response) { @@ -257,7 +271,7 @@ public Collection getParts() } @Override - public Part getPart(String name) + public Part getPart(String s) { throw new UnsupportedOperationException(); } @@ -343,7 +357,12 @@ public String getProtocol() @Override public String getScheme() { - throw new UnsupportedOperationException(); + if (this.requestUrl == null) { + throw new UnsupportedOperationException(); 
+ } + URI uri = UriBuilder.fromUri(this.requestUrl).build(); + + return uri.getScheme(); } @Override @@ -412,12 +431,6 @@ public RequestDispatcher getRequestDispatcher(String path) throw new UnsupportedOperationException(); } - @Override - public String getRealPath(String path) - { - throw new UnsupportedOperationException(); - } - @Override public int getRemotePort() { @@ -485,4 +498,22 @@ public DispatcherType getDispatcherType() { throw new UnsupportedOperationException(); } + + @Override + public String getRequestId() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getProtocolRequestId() + { + throw new UnsupportedOperationException(); + } + + @Override + public ServletConnection getServletConnection() + { + throw new UnsupportedOperationException(); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestFailureDetectorConfig.java b/presto-main/src/test/java/com/facebook/presto/server/TestFailureDetectorConfig.java index 1bc543a0f1cb2..eee32c75ad13b 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestFailureDetectorConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestFailureDetectorConfig.java @@ -14,9 +14,9 @@ package com.facebook.presto.server; import com.facebook.airlift.configuration.testing.ConfigAssertions; +import com.facebook.airlift.units.Duration; import com.facebook.presto.failureDetector.FailureDetectorConfig; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestGenerateTokenFilter.java b/presto-main/src/test/java/com/facebook/presto/server/TestGenerateTokenFilter.java index efe96e18b25f7..a5c1098221ee2 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestGenerateTokenFilter.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestGenerateTokenFilter.java @@ -23,15 +23,14 
@@ import com.google.inject.Binder; import com.google.inject.Key; import com.google.inject.Module; +import jakarta.inject.Qualifier; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.Path; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import javax.inject.Qualifier; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.Path; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.Target; @@ -44,8 +43,8 @@ import static com.facebook.airlift.jaxrs.JaxrsBinder.jaxrsBinder; import static com.facebook.airlift.testing.Assertions.assertInstanceOf; import static com.facebook.airlift.testing.Closeables.closeQuietly; +import static jakarta.servlet.http.HttpServletResponse.SC_OK; import static java.lang.annotation.RetentionPolicy.RUNTIME; -import static javax.servlet.http.HttpServletResponse.SC_OK; import static org.testng.Assert.assertEquals; @Test(singleThreaded = true) diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionContext.java b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionContext.java index 221950a2de9d3..b771190025c8f 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionContext.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionContext.java @@ -28,11 +28,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.WebApplicationException; import org.testng.annotations.Test; -import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.WebApplicationException; - import java.io.UnsupportedEncodingException; import java.net.URLEncoder; import java.util.EnumSet; diff --git 
a/presto-main-base/src/test/java/com/facebook/presto/server/TestParseToJsonData.java b/presto-main/src/test/java/com/facebook/presto/server/TestParseToJsonData.java similarity index 100% rename from presto-main-base/src/test/java/com/facebook/presto/server/TestParseToJsonData.java rename to presto-main/src/test/java/com/facebook/presto/server/TestParseToJsonData.java diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestQuerySessionSupplier.java b/presto-main/src/test/java/com/facebook/presto/server/TestQuerySessionSupplier.java index 48f61396de1a9..bcde6328970ff 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestQuerySessionSupplier.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestQuerySessionSupplier.java @@ -29,10 +29,9 @@ import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import jakarta.servlet.http.HttpServletRequest; import org.testng.annotations.Test; -import javax.servlet.http.HttpServletRequest; - import java.util.Locale; import static com.facebook.airlift.json.JsonCodec.jsonCodec; diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestServer.java b/presto-main/src/test/java/com/facebook/presto/server/TestServer.java index c2ccda8fd46b5..b841fb3b8e09a 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestServer.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestServer.java @@ -65,6 +65,7 @@ import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO; import static com.facebook.presto.client.PrestoHeaders.PRESTO_PREPARED_STATEMENT; +import static com.facebook.presto.client.PrestoHeaders.PRESTO_RETRY_QUERY; import static com.facebook.presto.client.PrestoHeaders.PRESTO_SCHEMA; import static com.facebook.presto.client.PrestoHeaders.PRESTO_SESSION; import static 
com.facebook.presto.client.PrestoHeaders.PRESTO_SESSION_FUNCTION; @@ -78,15 +79,20 @@ import static com.facebook.presto.server.TestHttpRequestSessionContext.createSqlFunctionIdAdd; import static com.facebook.presto.server.TestHttpRequestSessionContext.urlEncode; import static com.facebook.presto.spi.StandardErrorCode.INCOMPATIBLE_CLIENT; +import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR; import static com.facebook.presto.spi.page.PagesSerdeUtil.readSerializedPage; +import static jakarta.ws.rs.core.HttpHeaders.CACHE_CONTROL; +import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static jakarta.ws.rs.core.Response.Status.OK; +import static java.lang.Integer.parseInt; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; -import static javax.ws.rs.core.HttpHeaders.CONTENT_TYPE; -import static javax.ws.rs.core.MediaType.APPLICATION_JSON; -import static javax.ws.rs.core.Response.Status.OK; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; import static org.testng.Assert.fail; @Test(singleThreaded = true) @@ -155,10 +161,7 @@ public void testServerStarts() public void testBinaryResults() { // start query - URI uri = HttpUriBuilder.uriBuilderFrom(server.getBaseUrl()) - .replacePath("/v1/statement") - .replaceParameter("binaryResults", "true") - .build(); + URI uri = buildStatementUri(true); Request request = preparePost() .setUri(uri) .setBodyGenerator(createStaticBodyGenerator("show catalogs", UTF_8)) @@ -425,6 +428,327 @@ public void testStatusPing() assertEquals(response.getHeader(CONTENT_TYPE), APPLICATION_JSON, "Content Type"); } + @Test + public void testCacheControlHeaderExists() + { + Request request = preparePost() + 
.setUri(uriFor("/v1/statement")) + .setBodyGenerator(createStaticBodyGenerator("show catalogs", UTF_8)) + .setHeader(PRESTO_USER, "user") + .build(); + + JsonResponse initResponse = client.execute(request, createFullJsonResponseHandler(QUERY_RESULTS_CODEC)); + + String initHeader = initResponse.getHeader(CACHE_CONTROL); + assertNotNull(initHeader); + assertTrue(initHeader.contains("max-age")); + + int initAge = parseInt(initHeader.substring(initHeader.indexOf("=") + 1)); + assertTrue(initAge >= 0); + + JsonResponse queryResults = initResponse; + while (queryResults.getValue().getNextUri() != null) { + URI nextUri = queryResults.getValue().getNextUri(); + queryResults = client.execute(prepareGet().setUri(nextUri).build(), createFullJsonResponseHandler(QUERY_RESULTS_CODEC)); + + String header = queryResults.getHeader(CACHE_CONTROL); + assertNotNull(header); + assertTrue(header.contains("max-age")); + + int maxAge = parseInt(header.substring(header.indexOf("=") + 1)); + assertTrue(maxAge >= 0); + } + } + + @Test + public void testQueryWithRetryUrl() + { + String retryUrl = format("%s/v1/statement/queued/retry/abc123", server.getBaseUrl()); + int expirationSeconds = 3600; + + // start query with retry URL + URI uri = buildStatementUri(retryUrl, expirationSeconds, false); + + Request request = preparePost() + .setUri(uri) + .setBodyGenerator(createStaticBodyGenerator("SELECT 123", UTF_8)) + .setHeader(PRESTO_USER, "user") + .setHeader(PRESTO_SOURCE, "source") + .setHeader(PRESTO_CATALOG, "catalog") + .setHeader(PRESTO_SCHEMA, "schema") + .build(); + + QueryResults queryResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_CODEC)); + assertNotNull(queryResults); + assertNotNull(queryResults.getId()); + + // Verify query executes normally with retry parameters + while (queryResults.getNextUri() != null) { + queryResults = client.execute(prepareGet().setUri(queryResults.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_CODEC)); + } + 
assertNull(queryResults.getError()); + } + + @Test + public void testQueryWithRetryFlag() + { + // submit a retry query (marked with retry flag) + URI uri = buildStatementUri(); + + String retryQuery = "-- retry query 20240115_120000_00000_xxxxx; attempt: 1\nSELECT 456"; + Request request = preparePost() + .setUri(uri) + .setBodyGenerator(createStaticBodyGenerator(retryQuery, UTF_8)) + .setHeader(PRESTO_USER, "user") + .setHeader(PRESTO_SOURCE, "source") + .setHeader(PRESTO_CATALOG, "catalog") + .setHeader(PRESTO_SCHEMA, "schema") + .setHeader(PRESTO_RETRY_QUERY, "true") + .build(); + + QueryResults queryResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_CODEC)); + assertNotNull(queryResults); + assertNotNull(queryResults.getId()); + + // Verify the usual endpoint fails + try { + client.execute(prepareGet().setUri(queryResults.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_CODEC)); + } + catch (UnexpectedResponseException e) { + // Expected failure, retry query should not be fetched from the usual endpoint + assertEquals(e.getStatusCode(), 409, "Expected 409 Conflict for retry query"); + } + + request = prepareGet() + .setUri(HttpUriBuilder.uriBuilderFrom(server.getBaseUrl()) + .replacePath(format("/v1/statement/queued/retry/%s", queryResults.getId())) + .build()) + .build(); + // Fetch the retry query results from the special retry endpoint + queryResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_CODEC)); + + while (queryResults.getNextUri() != null) { + queryResults = client.execute(prepareGet().setUri(queryResults.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_CODEC)); + } + + assertNull(queryResults.getError()); + } + + @Test + public void testQueryWithInvalidRetryParameters() + { + // Test with missing expiration when retry URL is provided + String retryUrl = format("%s/v1/statement/queued/retry", server.getBaseUrl()); + URI uri = buildStatementUri(retryUrl, null, false); + + 
Request request = preparePost() + .setUri(uri) + .setBodyGenerator(createStaticBodyGenerator("SELECT 1", UTF_8)) + .setHeader(PRESTO_USER, "user") + .setHeader(PRESTO_SOURCE, "source") + .setHeader(PRESTO_CATALOG, "catalog") + .setHeader(PRESTO_SCHEMA, "schema") + .build(); + + JsonResponse response = client.execute(request, createFullJsonResponseHandler(QUERY_RESULTS_CODEC)); + assertEquals(response.getStatusCode(), 400); + } + + @Test + public void testQueryWithInvalidRetryUrl() + { + // Test with invalid remote retry URL + String invalidRetryUrl = "http://insecure-cluster.example.com/v1/statement"; + URI uri = buildStatementUri(invalidRetryUrl, 3600, false); + + Request request = preparePost() + .setUri(uri) + .setBodyGenerator(createStaticBodyGenerator("SELECT 1", UTF_8)) + .setHeader(PRESTO_USER, "user") + .setHeader(PRESTO_SOURCE, "source") + .setHeader(PRESTO_CATALOG, "catalog") + .setHeader(PRESTO_SCHEMA, "schema") + .build(); + + JsonResponse response = client.execute(request, createFullJsonResponseHandler(QUERY_RESULTS_CODEC)); + assertEquals(response.getStatusCode(), 400); + } + + @Test + public void testQueryWithExpiredRetryTime() + { + String retryUrl = "https://backup-cluster.example.com/v1/statement"; + // Use 0 seconds expiration (immediately expired) + URI uri = buildStatementUri(retryUrl, 0, false); + + Request request = preparePost() + .setUri(uri) + .setBodyGenerator(createStaticBodyGenerator("SELECT 1", UTF_8)) + .setHeader(PRESTO_USER, "user") + .setHeader(PRESTO_SOURCE, "source") + .setHeader(PRESTO_CATALOG, "catalog") + .setHeader(PRESTO_SCHEMA, "schema") + .build(); + + JsonResponse response = client.execute(request, createFullJsonResponseHandler(QUERY_RESULTS_CODEC)); + assertEquals(response.getStatusCode(), 400); + } + + @Test + public void testQueryWithRetryUrlAndSessionProperties() + { + String retryUrl = format("%s/v1/statement/queued/retry", server.getBaseUrl()); + URI uri = buildStatementUri(retryUrl, 3600, false); + + Request 
request = preparePost() + .setUri(uri) + .setBodyGenerator(createStaticBodyGenerator("SELECT 789", UTF_8)) + .setHeader(PRESTO_USER, "user") + .setHeader(PRESTO_SOURCE, "source") + .setHeader(PRESTO_CATALOG, "catalog") + .setHeader(PRESTO_SCHEMA, "schema") + .setHeader(PRESTO_SESSION, QUERY_MAX_MEMORY + "=1GB") + .setHeader(PRESTO_SESSION, JOIN_DISTRIBUTION_TYPE + "=BROADCAST") + .build(); + + QueryResults queryResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_CODEC)); + assertNotNull(queryResults); + assertNotNull(queryResults.getId()); + + // Verify query executes with both retry parameters and session properties + while (queryResults.getNextUri() != null) { + queryResults = client.execute(prepareGet().setUri(queryResults.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_CODEC)); + } + assertNull(queryResults.getError()); + } + + @Test + public void testCrossClusterRetryWithRetryChainPrevention() + throws Exception + { + // Create two test servers to simulate two clusters + TestingPrestoServer backupServer = null; + HttpClient backupClient = null; + try { + backupServer = new TestingPrestoServer(); + backupClient = new JettyHttpClient(); + + // Step 1: Simulate router POSTing query to backup server with retry header + URI backupUri = HttpUriBuilder.uriBuilderFrom(backupServer.getBaseUrl()) + .replacePath("/v1/statement") + .build(); + + String failQuery = format("SELECT fail(%s, 'Simulated remote task error')", REMOTE_TASK_ERROR.toErrorCode().getCode()); + Request backupRequest = preparePost() + .setUri(backupUri) + .setBodyGenerator(createStaticBodyGenerator(failQuery, UTF_8)) + .setHeader(PRESTO_USER, "user") + .setHeader(PRESTO_SOURCE, "source") + .setHeader(PRESTO_CATALOG, "catalog") + .setHeader(PRESTO_SCHEMA, "schema") + .setHeader(PRESTO_SESSION, "query_retry_limit=1") + .setHeader(PRESTO_RETRY_QUERY, "true") // Router marks this as a retry query + .build(); + + // Router POSTs to backup server and gets the query ID for 
the retry query + QueryResults backupResults = backupClient.execute(backupRequest, createJsonResponseHandler(QUERY_RESULTS_CODEC)); + assertNotNull(backupResults); + String retryQueryId = backupResults.getId(); + + // Step 2: Configure retry URL to point to the backup cluster's retry endpoint + String retryUrl = backupServer.getBaseUrl() + "/v1/statement/queued/retry/" + retryQueryId; + + // Step 3: Router redirects client to primary server with retry parameters + URI uri = buildStatementUri(retryUrl, 3600, false); + + Request request = preparePost() + .setUri(uri) + .setBodyGenerator(createStaticBodyGenerator(failQuery, UTF_8)) + .setHeader(PRESTO_USER, "user") + .setHeader(PRESTO_SOURCE, "source") + .setHeader(PRESTO_CATALOG, "catalog") + .setHeader(PRESTO_SCHEMA, "schema") + .setHeader(PRESTO_SESSION, "query_retry_limit=1") + .build(); + + // Execute the query on the primary cluster + QueryResults firstResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_CODEC)); + assertNotNull(firstResults); + + // Poll for results until we get the error or nextUri changes to backup cluster + int iterations = 0; + while (firstResults.getNextUri() != null && iterations < 20) { + // Check if nextUri has changed from primary to backup cluster + URI nextUri = firstResults.getNextUri(); + URI primaryUri = URI.create(server.getBaseUrl().toString()); + URI backupClusterUri = URI.create(backupServer.getBaseUrl().toString()); + + // Check both host and port to determine if we've switched clusters + boolean isPrimaryCluster = nextUri.getHost().equals(primaryUri.getHost()) && + nextUri.getPort() == primaryUri.getPort(); + boolean isBackupCluster = nextUri.getHost().equals(backupClusterUri.getHost()) && + nextUri.getPort() == backupClusterUri.getPort(); + + if (!isPrimaryCluster && isBackupCluster) { + // NextUri has changed to backup cluster - this is expected + break; + } + + // If we have an error, check if it includes the retry URL + if (firstResults.getError() != 
null) { + // Even with an error, there should be a nextUri pointing to backup cluster + assertNotNull(firstResults.getNextUri(), "Error response should include retry URL"); + assertTrue(firstResults.getNextUri().toString().contains(backupServer.getBaseUrl().toString()), + "Error response should have nextUri pointing to backup cluster"); + break; + } + + firstResults = client.execute(prepareGet().setUri(firstResults.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_CODEC)); + iterations++; + } + + // Verify we didn't timeout + assertTrue(iterations < 20, "Query did not complete or redirect within 20 iterations"); + + // Verify the query failure is not exposed yet + assertNull(firstResults.getError()); + + // The nextUri should now point to the backup cluster + assertNotNull(firstResults.getNextUri()); + assertTrue(firstResults.getNextUri().toString().contains(backupServer.getBaseUrl().toString())); + + // Step 4: Client follows the redirect link from primary server to backup server + // The retry endpoint will return the results of the retry query that was already created + Request retryRequest = prepareGet() + .setUri(firstResults.getNextUri()) + .build(); + + QueryResults retryResults = backupClient.execute(retryRequest, createJsonResponseHandler(QUERY_RESULTS_CODEC)); + assertNotNull(retryResults); + + // Verify the retry query ID does not match what we created earlier (because the retry query is a placeholder, + // and when it is retried it generates a new ID) + assertNotEquals(retryResults.getId(), retryQueryId); + + // The retry query should also fail (since we're using the same fail() function) + while (retryResults.getNextUri() != null) { + retryResults = backupClient.execute(prepareGet().setUri(retryResults.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_CODEC)); + } + + // Verify the retry query also failed + assertNotNull(retryResults.getError()); + assertEquals(retryResults.getError().getErrorName(), "REMOTE_TASK_ERROR"); + + 
// IMPORTANT: The retry query should NOT have a nextUri for another retry + // This prevents retry chains (retry of a retry) + assertNull(retryResults.getNextUri(), "Retry query should not have nextUri to prevent retry chains"); + } + finally { + Closeables.closeQuietly(backupClient); + Closeables.closeQuietly(backupServer); + } + } + public URI uriFor(String path) { return HttpUriBuilder.uriBuilderFrom(server.getBaseUrl()).replacePath(path).build(); @@ -438,4 +762,34 @@ public URI uriFor(String path, QueryId queryId, String slug) .addParameter("slug", slug) .build(); } + + private URI buildStatementUri() + { + return buildStatementUri(null, null, false); + } + + private URI buildStatementUri(boolean binaryResults) + { + return buildStatementUri(null, null, binaryResults); + } + + private URI buildStatementUri(String retryUrl, Integer retryExpirationInSeconds, boolean binaryResults) + { + HttpUriBuilder builder = HttpUriBuilder.uriBuilderFrom(server.getBaseUrl()) + .replacePath("/v1/statement"); + + if (retryUrl != null) { + builder.addParameter("retryUrl", urlEncode(retryUrl)); + } + + if (retryExpirationInSeconds != null) { + builder.addParameter("retryExpirationInSeconds", String.valueOf(retryExpirationInSeconds)); + } + + if (binaryResults) { + builder.replaceParameter("binaryResults", "true"); + } + + return builder.build(); + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/TestThreadResource.java b/presto-main/src/test/java/com/facebook/presto/server/TestThreadResource.java similarity index 100% rename from presto-main-base/src/test/java/com/facebook/presto/server/TestThreadResource.java rename to presto-main/src/test/java/com/facebook/presto/server/TestThreadResource.java diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestThriftClusterStats.java b/presto-main/src/test/java/com/facebook/presto/server/TestThriftClusterStats.java index c396f083da00b..4a8d6d944067f 100644 --- 
a/presto-main/src/test/java/com/facebook/presto/server/TestThriftClusterStats.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestThriftClusterStats.java @@ -46,6 +46,7 @@ public class TestThriftClusterStats public static final long TOTAL_INPUT_BYTES = 1003; public static final long TOTAL_CPU_TIME_SECS = 1004; public static final long ADJUSTED_QUEUE_SIZE = 1005; + public static final String CLUSTER_TAG = "test-cluster"; private static final ThriftCodecManager COMPILER_READ_CODEC_MANAGER = new ThriftCodecManager(new CompilerThriftCodecFactory(false)); private static final ThriftCodecManager COMPILER_WRITE_CODEC_MANAGER = new ThriftCodecManager(new CompilerThriftCodecFactory(false)); private static final ThriftCodec COMPILER_READ_CODEC = COMPILER_READ_CODEC_MANAGER.getCodec(ClusterStats.class); @@ -111,6 +112,7 @@ private void assertSerde(ClusterStats clusterStats) assertEquals(clusterStats.getTotalInputBytes(), TOTAL_INPUT_BYTES); assertEquals(clusterStats.getTotalCpuTimeSecs(), TOTAL_CPU_TIME_SECS); assertEquals(clusterStats.getAdjustedQueueSize(), ADJUSTED_QUEUE_SIZE); + assertEquals(clusterStats.getClusterTag(), CLUSTER_TAG); } private ClusterStats getRoundTripSerialize(ThriftCodec readCodec, ThriftCodec writeCodec, Function protocolFactory) @@ -134,6 +136,7 @@ private ClusterStats getClusterStats() TOTAL_INPUT_ROWS, TOTAL_INPUT_BYTES, TOTAL_CPU_TIME_SECS, - ADJUSTED_QUEUE_SIZE); + ADJUSTED_QUEUE_SIZE, + CLUSTER_TAG); } } diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestThriftServerInfoIntegration.java b/presto-main/src/test/java/com/facebook/presto/server/TestThriftServerInfoIntegration.java index bbeb70b1d9e3a..462667c59d8b6 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestThriftServerInfoIntegration.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestThriftServerInfoIntegration.java @@ -40,7 +40,6 @@ import com.facebook.presto.execution.buffer.OutputBuffers.OutputBufferId; import 
com.facebook.presto.execution.scheduler.TableWriteInfo; import com.facebook.presto.memory.MemoryPoolAssignmentsRequest; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.server.testing.TestingPrestoServer; import com.facebook.presto.server.thrift.ThriftServerInfoClient; import com.facebook.presto.server.thrift.ThriftServerInfoService; @@ -55,12 +54,11 @@ import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; +import jakarta.inject.Singleton; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import javax.inject.Singleton; - import java.util.List; import java.util.Optional; @@ -69,6 +67,7 @@ import static com.facebook.drift.server.guice.DriftServerBinder.driftServerBinder; import static com.facebook.drift.transport.netty.client.DriftNettyMethodInvokerFactory.createStaticDriftNettyMethodInvokerFactory; import static com.facebook.presto.spi.NodeState.ACTIVE; +import static io.netty.buffer.ByteBufAllocator.DEFAULT; import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; @@ -109,7 +108,7 @@ public void testServer() AddressSelector addressSelector = new SimpleAddressSelector( ImmutableSet.of(HostAndPort.fromParts("localhost", thriftServerPort)), true); - try (DriftNettyMethodInvokerFactory invokerFactory = createStaticDriftNettyMethodInvokerFactory(new DriftNettyClientConfig())) { + try (DriftNettyMethodInvokerFactory invokerFactory = createStaticDriftNettyMethodInvokerFactory(new DriftNettyClientConfig(), DEFAULT)) { DriftClientFactory clientFactory = new DriftClientFactory(new ThriftCodecManager(), invokerFactory, addressSelector, NORMAL_RESULT); ThriftServerInfoClient client = clientFactory.createDriftClient(ThriftServerInfoClient.class).get(); @@ -245,12 +244,6 @@ public void removeRemoteSource(TaskId taskId, TaskId remoteSourceTaskId) { throw new UnsupportedOperationException(); } - - 
@Override - public void updateMetadataResults(TaskId taskId, MetadataUpdates metadataUpdates) - { - throw new UnsupportedOperationException(); - } }; } } diff --git a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java deleted file mode 100644 index b703ec779f460..0000000000000 --- a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTask.java +++ /dev/null @@ -1,677 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.presto.server.remotetask; - -import com.facebook.airlift.bootstrap.Bootstrap; -import com.facebook.airlift.http.client.testing.TestingHttpClient; -import com.facebook.airlift.jaxrs.JsonMapper; -import com.facebook.airlift.jaxrs.testing.JaxrsTestingHttpProcessor; -import com.facebook.airlift.jaxrs.thrift.ThriftMapper; -import com.facebook.airlift.json.JsonCodec; -import com.facebook.airlift.json.JsonModule; -import com.facebook.airlift.json.smile.SmileCodec; -import com.facebook.airlift.json.smile.SmileModule; -import com.facebook.drift.codec.ThriftCodec; -import com.facebook.drift.codec.guice.ThriftCodecModule; -import com.facebook.drift.codec.utils.DataSizeToBytesThriftCodec; -import com.facebook.drift.codec.utils.DurationToMillisThriftCodec; -import com.facebook.drift.codec.utils.JodaDateTimeToEpochMillisThriftCodec; -import com.facebook.drift.codec.utils.LocaleToLanguageTagCodec; -import com.facebook.presto.SessionTestUtils; -import com.facebook.presto.client.NodeVersion; -import com.facebook.presto.common.ErrorCode; -import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.TypeManager; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; -import com.facebook.presto.execution.Lifespan; -import com.facebook.presto.execution.NodeTaskMap; -import com.facebook.presto.execution.QueryManagerConfig; -import com.facebook.presto.execution.RemoteTask; -import com.facebook.presto.execution.SchedulerStatsTracker; -import com.facebook.presto.execution.TaskId; -import com.facebook.presto.execution.TaskInfo; -import com.facebook.presto.execution.TaskManagerConfig; -import com.facebook.presto.execution.TaskSource; -import com.facebook.presto.execution.TaskState; -import com.facebook.presto.execution.TaskStatus; -import com.facebook.presto.execution.TaskTestUtils; -import com.facebook.presto.execution.TestQueryManager; -import com.facebook.presto.execution.TestSqlTaskManager; -import 
com.facebook.presto.execution.buffer.OutputBuffers; -import com.facebook.presto.execution.scheduler.TableWriteInfo; -import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.HandleJsonModule; -import com.facebook.presto.metadata.HandleResolver; -import com.facebook.presto.metadata.InternalNode; -import com.facebook.presto.metadata.MetadataUpdates; -import com.facebook.presto.metadata.Split; -import com.facebook.presto.server.ConnectorMetadataUpdateHandleJsonSerde; -import com.facebook.presto.server.InternalCommunicationConfig; -import com.facebook.presto.server.TaskUpdateRequest; -import com.facebook.presto.server.thrift.MetadataUpdatesCodec; -import com.facebook.presto.server.thrift.SplitCodec; -import com.facebook.presto.server.thrift.TableWriteInfoCodec; -import com.facebook.presto.spi.ConnectorId; -import com.facebook.presto.spi.plan.PlanNodeId; -import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.Serialization; -import com.facebook.presto.sql.analyzer.FeaturesConfig; -import com.facebook.presto.sql.planner.PlanFragment; -import com.facebook.presto.testing.TestingHandleResolver; -import com.facebook.presto.testing.TestingSplit; -import com.facebook.presto.testing.TestingTransactionHandle; -import com.facebook.presto.type.TypeDeserializer; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.ImmutableSet; -import com.google.inject.Binder; -import com.google.inject.Injector; -import com.google.inject.Module; -import com.google.inject.Provides; -import io.airlift.units.DataSize; -import io.airlift.units.Duration; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; - -import javax.ws.rs.Consumes; -import javax.ws.rs.DELETE; -import javax.ws.rs.DefaultValue; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import 
javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.UriInfo; - -import java.lang.reflect.Method; -import java.net.URI; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BooleanSupplier; - -import static com.facebook.airlift.configuration.ConfigBinder.configBinder; -import static com.facebook.airlift.http.client.thrift.ThriftRequestUtils.APPLICATION_THRIFT_BINARY; -import static com.facebook.airlift.http.client.thrift.ThriftRequestUtils.APPLICATION_THRIFT_COMPACT; -import static com.facebook.airlift.http.client.thrift.ThriftRequestUtils.APPLICATION_THRIFT_FB_COMPACT; -import static com.facebook.airlift.json.JsonBinder.jsonBinder; -import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; -import static com.facebook.airlift.json.smile.SmileCodecBinder.smileCodecBinder; -import static com.facebook.drift.codec.guice.ThriftCodecBinder.thriftCodecBinder; -import static com.facebook.presto.client.PrestoHeaders.PRESTO_CURRENT_STATE; -import static com.facebook.presto.client.PrestoHeaders.PRESTO_MAX_WAIT; -import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; -import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; -import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; -import static com.facebook.presto.spi.SplitContext.NON_CACHEABLE; -import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_ERROR; -import static com.facebook.presto.spi.StandardErrorCode.REMOTE_TASK_MISMATCH; -import static com.facebook.presto.testing.assertions.Assert.assertEquals; -import static 
com.google.common.collect.Iterables.getOnlyElement; -import static com.google.inject.multibindings.Multibinder.newSetBinder; -import static java.lang.Math.min; -import static java.lang.String.format; -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.NANOSECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.assertj.core.api.Assertions.assertThat; -import static org.testng.Assert.assertTrue; - -public class TestHttpRemoteTask -{ - // This 30 sec per-test timeout should never be reached because the test should fail and do proper cleanup after 20 sec. - private static final Duration POLL_TIMEOUT = new Duration(100, MILLISECONDS); - private static final Duration IDLE_TIMEOUT = new Duration(3, SECONDS); - private static final Duration FAIL_TIMEOUT = new Duration(40, SECONDS); - private static final TaskManagerConfig TASK_MANAGER_CONFIG = new TaskManagerConfig() - // Shorten status refresh wait and info update interval so that we can have a shorter test timeout - .setStatusRefreshMaxWait(new Duration(IDLE_TIMEOUT.roundTo(MILLISECONDS) / 100, MILLISECONDS)) - .setInfoUpdateInterval(new Duration(IDLE_TIMEOUT.roundTo(MILLISECONDS) / 10, MILLISECONDS)); - - private static final boolean TRACE_HTTP = false; - - @DataProvider - public Object[][] thriftEncodingToggle() - { - return new Object[][] {{true}, {false}}; - } - - @Test(timeOut = 50000, dataProvider = "thriftEncodingToggle") - public void testRemoteTaskMismatch(boolean useThriftEncoding) - throws Exception - { - runTest(FailureScenario.TASK_MISMATCH, useThriftEncoding); - } - - @Test(timeOut = 50000, dataProvider = "thriftEncodingToggle") - public void testRejectedExecutionWhenVersionIsHigh(boolean useThriftEncoding) - throws Exception - { - runTest(FailureScenario.TASK_MISMATCH_WHEN_VERSION_IS_HIGH, useThriftEncoding); - } - - @Test(timeOut = 40000, dataProvider = "thriftEncodingToggle") - 
public void testRejectedExecution(boolean useThriftEncoding) - throws Exception - { - runTest(FailureScenario.REJECTED_EXECUTION, useThriftEncoding); - } - - @Test(timeOut = 60000, dataProvider = "thriftEncodingToggle") - public void testRegular(boolean useThriftEncoding) - throws Exception - { - AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); - TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, FailureScenario.NO_FAILURE); - - HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource, useThriftEncoding); - - RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory); - - testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo()); - remoteTask.start(); - - Lifespan lifespan = Lifespan.driverGroup(3); - remoteTask.addSplits(ImmutableMultimap.of(TaskTestUtils.TABLE_SCAN_NODE_ID, new Split(new ConnectorId("test"), TestingTransactionHandle.create(), TestingSplit.createLocalSplit(), lifespan, NON_CACHEABLE))); - poll(() -> testingTaskResource.getTaskSource(TaskTestUtils.TABLE_SCAN_NODE_ID) != null); - poll(() -> testingTaskResource.getTaskSource(TaskTestUtils.TABLE_SCAN_NODE_ID).getSplits().size() == 1); - - remoteTask.noMoreSplits(TaskTestUtils.TABLE_SCAN_NODE_ID, lifespan); - poll(() -> testingTaskResource.getTaskSource(TaskTestUtils.TABLE_SCAN_NODE_ID).getNoMoreSplitsForLifespan().size() == 1); - - remoteTask.noMoreSplits(TaskTestUtils.TABLE_SCAN_NODE_ID); - poll(() -> testingTaskResource.getTaskSource(TaskTestUtils.TABLE_SCAN_NODE_ID).isNoMoreSplits()); - - remoteTask.cancel(); - poll(() -> remoteTask.getTaskStatus().getState().isDone()); - poll(() -> remoteTask.getTaskInfo().getTaskStatus().getState().isDone()); - - httpRemoteTaskFactory.stop(); - } - - @Test(timeOut = 50000) - public void testHTTPRemoteTaskSize() - throws Exception - { - AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); - TestingTaskResource testingTaskResource = new 
TestingTaskResource(lastActivityNanos, FailureScenario.NO_FAILURE); - - HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource, false); - - RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory); - - testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo()); - remoteTask.start(); - // just need to run a TaskUpdateRequest to increment the decay counter - remoteTask.cancel(); - httpRemoteTaskFactory.stop(); - - assertTrue(httpRemoteTaskFactory.getTaskUpdateRequestSize() > 0); - } - - @Test(timeOut = 50000) - public void testHTTPRemoteBadTaskSize() - throws Exception - { - AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); - TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, FailureScenario.NO_FAILURE); - boolean useThriftEncoding = false; - DataSize maxDataSize = DataSize.succinctBytes(1024); - InternalCommunicationConfig internalCommunicationConfig = new InternalCommunicationConfig() - .setThriftTransportEnabled(useThriftEncoding) - .setMaxTaskUpdateSize(maxDataSize); - - HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource, useThriftEncoding, internalCommunicationConfig); - - RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory); - testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo()); - remoteTask.start(); - waitUntilIdle(lastActivityNanos); - httpRemoteTaskFactory.stop(); - - assertTrue(remoteTask.getTaskStatus().getState().isDone(), format("TaskStatus is not in a done state: %s", remoteTask.getTaskStatus())); - assertThat(getOnlyElement(remoteTask.getTaskStatus().getFailures()).getMessage()) - .matches("TaskUpdate size of .+? 
has exceeded the limit of 1024 bytes"); - } - - @Test(dataProvider = "getUpdateSize") - public void testGetExceededTaskUpdateSizeListMessage(int updateSizeInBytes, int maxDataSizeInBytes, - String expectedMessage) - throws Exception - { - AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); - TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, FailureScenario.NO_FAILURE); - boolean useThriftEncoding = false; - DataSize maxDataSize = DataSize.succinctBytes(maxDataSizeInBytes); - InternalCommunicationConfig internalCommunicationConfig = new InternalCommunicationConfig() - .setThriftTransportEnabled(useThriftEncoding) - .setMaxTaskUpdateSize(maxDataSize); - HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource, useThriftEncoding, internalCommunicationConfig); - RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory); - - Method targetMethod = HttpRemoteTask.class.getDeclaredMethod("getExceededTaskUpdateSizeMessage", new Class[] {byte[].class}); - targetMethod.setAccessible(true); - byte[] taskUpdateRequestJson = new byte[updateSizeInBytes]; - String message = (String) targetMethod.invoke(remoteTask, new Object[] {taskUpdateRequestJson}); - assertEquals(message, expectedMessage); - } - - @DataProvider(name = "getUpdateSize") - protected Object[][] getUpdateSize() - { - return new Object[][] { - {2000, 1000, "TaskUpdate size of 2000 bytes has exceeded the limit of 1000 bytes"}, - {2000, 1024, "TaskUpdate size of 2000 bytes has exceeded the limit of 1024 bytes"}, - {5000, 4 * 1024, "TaskUpdate size of 5000 bytes has exceeded the limit of 4096 bytes"}, - {2 * 1024, 1024, "TaskUpdate size of 2048 bytes has exceeded the limit of 1024 bytes"}, - {1024 * 1024, 512 * 1024, "TaskUpdate size of 1048576 bytes has exceeded the limit of 524288 bytes"}, - {16 * 1024 * 1024, 8 * 1024 * 1024, "TaskUpdate size of 16777216 bytes has exceeded the limit of 8388608 bytes"}, - {485 * 1000 * 1000, 
1024 * 1024 * 512, "TaskUpdate size of 485000000 bytes has exceeded the limit of 536870912 bytes"}, - {1024 * 1024 * 1024, 1024 * 1024 * 512, "TaskUpdate size of 1073741824 bytes has exceeded the limit of 536870912 bytes"}, - {860492511, 524288000, "TaskUpdate size of 860492511 bytes has exceeded the limit of 524288000 bytes"}}; - } - - private void runTest(FailureScenario failureScenario, boolean useThriftEncoding) - throws Exception - { - AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); - TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, failureScenario); - - HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource, useThriftEncoding); - RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory); - - testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo()); - remoteTask.start(); - - waitUntilIdle(lastActivityNanos); - - httpRemoteTaskFactory.stop(); - assertTrue(remoteTask.getTaskStatus().getState().isDone(), format("TaskStatus is not in a done state: %s", remoteTask.getTaskStatus())); - - ErrorCode actualErrorCode = getOnlyElement(remoteTask.getTaskStatus().getFailures()).getErrorCode(); - switch (failureScenario) { - case TASK_MISMATCH: - case TASK_MISMATCH_WHEN_VERSION_IS_HIGH: - assertTrue(remoteTask.getTaskInfo().getTaskStatus().getState().isDone(), format("TaskInfo is not in a done state: %s", remoteTask.getTaskInfo())); - assertEquals(actualErrorCode, REMOTE_TASK_MISMATCH.toErrorCode()); - break; - case REJECTED_EXECUTION: - // for a rejection to occur, the http client must be shutdown, which means we will not be able to ge the final task info - assertEquals(actualErrorCode, REMOTE_TASK_ERROR.toErrorCode()); - break; - default: - throw new UnsupportedOperationException(); - } - } - - private RemoteTask createRemoteTask(HttpRemoteTaskFactory httpRemoteTaskFactory) - { - return httpRemoteTaskFactory.createRemoteTask( - SessionTestUtils.TEST_SESSION, - new 
TaskId("test", 1, 0, 2, 0), - new InternalNode("node-id", URI.create("http://fake.invalid/"), new NodeVersion("version"), false), - TaskTestUtils.createPlanFragment(), - ImmutableMultimap.of(), - createInitialEmptyOutputBuffers(OutputBuffers.BufferType.BROADCAST), - new NodeTaskMap.NodeStatsTracker(i -> {}, i -> {}, (age, i) -> {}), - true, - new TableWriteInfo(Optional.empty(), Optional.empty()), - SchedulerStatsTracker.NOOP); - } - - private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskResource testingTaskResource, boolean useThriftEncoding) - throws Exception - { - InternalCommunicationConfig internalCommunicationConfig = new InternalCommunicationConfig().setThriftTransportEnabled(useThriftEncoding); - return createHttpRemoteTaskFactory(testingTaskResource, useThriftEncoding, internalCommunicationConfig); - } - - private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskResource testingTaskResource, boolean useThriftEncoding, InternalCommunicationConfig internalCommunicationConfig) - throws Exception - { - Bootstrap app = new Bootstrap( - new JsonModule(), - new SmileModule(), - new ThriftCodecModule(), - new HandleJsonModule(), - new Module() - { - @Override - public void configure(Binder binder) - { - binder.bind(JsonMapper.class); - binder.bind(ThriftMapper.class); - configBinder(binder).bindConfig(FeaturesConfig.class); - FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); - binder.bind(TypeManager.class).toInstance(functionAndTypeManager); - jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); - newSetBinder(binder, Type.class); - smileCodecBinder(binder).bindSmileCodec(TaskStatus.class); - smileCodecBinder(binder).bindSmileCodec(TaskInfo.class); - smileCodecBinder(binder).bindSmileCodec(TaskUpdateRequest.class); - smileCodecBinder(binder).bindSmileCodec(PlanFragment.class); - smileCodecBinder(binder).bindSmileCodec(MetadataUpdates.class); - 
jsonCodecBinder(binder).bindJsonCodec(TaskStatus.class); - jsonCodecBinder(binder).bindJsonCodec(TaskInfo.class); - jsonCodecBinder(binder).bindJsonCodec(TaskUpdateRequest.class); - jsonCodecBinder(binder).bindJsonCodec(PlanFragment.class); - jsonCodecBinder(binder).bindJsonCodec(MetadataUpdates.class); - jsonCodecBinder(binder).bindJsonCodec(TableWriteInfo.class); - jsonCodecBinder(binder).bindJsonCodec(Split.class); - jsonBinder(binder).addKeySerializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionSerializer.class); - jsonBinder(binder).addKeyDeserializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionDeserializer.class); - thriftCodecBinder(binder).bindThriftCodec(TaskStatus.class); - thriftCodecBinder(binder).bindThriftCodec(TaskInfo.class); - thriftCodecBinder(binder).bindThriftCodec(TaskUpdateRequest.class); - thriftCodecBinder(binder).bindCustomThriftCodec(MetadataUpdatesCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(SplitCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(TableWriteInfoCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(LocaleToLanguageTagCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(JodaDateTimeToEpochMillisThriftCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(DurationToMillisThriftCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(DataSizeToBytesThriftCodec.class); - } - - @Provides - private HttpRemoteTaskFactory createHttpRemoteTaskFactory( - JsonMapper jsonMapper, - ThriftMapper thriftMapper, - JsonCodec taskStatusJsonCodec, - SmileCodec taskStatusSmileCodec, - ThriftCodec taskStatusThriftCodec, - JsonCodec taskInfoJsonCodec, - ThriftCodec taskInfoThriftCodec, - SmileCodec taskInfoSmileCodec, - JsonCodec taskUpdateRequestJsonCodec, - SmileCodec taskUpdateRequestSmileCodec, - ThriftCodec taskUpdateRequestThriftCodec, - JsonCodec planFragmentJsonCodec, - SmileCodec 
planFragmentSmileCodec, - JsonCodec metadataUpdatesJsonCodec, - SmileCodec metadataUpdatesSmileCodec) - { - JaxrsTestingHttpProcessor jaxrsTestingHttpProcessor = new JaxrsTestingHttpProcessor(URI.create("http://fake.invalid/"), testingTaskResource, jsonMapper, thriftMapper); - TestingHttpClient testingHttpClient = new TestingHttpClient(jaxrsTestingHttpProcessor.setTrace(TRACE_HTTP)); - testingTaskResource.setHttpClient(testingHttpClient); - return new HttpRemoteTaskFactory( - new QueryManagerConfig(), - TASK_MANAGER_CONFIG, - testingHttpClient, - new TestSqlTaskManager.MockLocationFactory(), - taskStatusJsonCodec, - taskStatusSmileCodec, - taskStatusThriftCodec, - taskInfoJsonCodec, - taskInfoSmileCodec, - taskInfoThriftCodec, - taskUpdateRequestJsonCodec, - taskUpdateRequestSmileCodec, - taskUpdateRequestThriftCodec, - planFragmentJsonCodec, - planFragmentSmileCodec, - metadataUpdatesJsonCodec, - metadataUpdatesSmileCodec, - new RemoteTaskStats(), - internalCommunicationConfig, - createTestMetadataManager(), - new TestQueryManager(), - new HandleResolver(), - new ConnectorTypeSerdeManager(new ConnectorMetadataUpdateHandleJsonSerde())); - } - }); - Injector injector = app - .doNotInitializeLogging() - .quiet() - .initialize(); - HandleResolver handleResolver = injector.getInstance(HandleResolver.class); - handleResolver.addConnectorName("test", new TestingHandleResolver()); - return injector.getInstance(HttpRemoteTaskFactory.class); - } - - private static void poll(BooleanSupplier success) - throws InterruptedException - { - long failAt = System.nanoTime() + FAIL_TIMEOUT.roundTo(NANOSECONDS); - - while (!success.getAsBoolean()) { - long millisUntilFail = (failAt - System.nanoTime()) / 1_000_000; - if (millisUntilFail <= 0) { - throw new AssertionError(format("Timeout of %s reached", FAIL_TIMEOUT)); - } - Thread.sleep(min(POLL_TIMEOUT.toMillis(), millisUntilFail)); - } - } - - private static void waitUntilIdle(AtomicLong lastActivityNanos) - throws 
InterruptedException - { - long startTimeNanos = System.nanoTime(); - - while (true) { - long millisSinceLastActivity = (System.nanoTime() - lastActivityNanos.get()) / 1_000_000L; - long millisSinceStart = (System.nanoTime() - startTimeNanos) / 1_000_000L; - long millisToIdleTarget = IDLE_TIMEOUT.toMillis() - millisSinceLastActivity; - long millisToFailTarget = FAIL_TIMEOUT.toMillis() - millisSinceStart; - if (millisToFailTarget < millisToIdleTarget) { - throw new AssertionError(format("Activity doesn't stop after %s", FAIL_TIMEOUT)); - } - if (millisToIdleTarget < 0) { - return; - } - Thread.sleep(millisToIdleTarget); - } - } - - private enum FailureScenario - { - NO_FAILURE, - TASK_MISMATCH, - TASK_MISMATCH_WHEN_VERSION_IS_HIGH, - REJECTED_EXECUTION, - } - - @Path("/task/{nodeId}") - public static class TestingTaskResource - { - private static final UUID INITIAL_TASK_INSTANCE_ID = UUID.randomUUID(); - private static final UUID NEW_TASK_INSTANCE_ID = UUID.randomUUID(); - - private final AtomicLong lastActivityNanos; - private final FailureScenario failureScenario; - - private AtomicReference httpClient = new AtomicReference<>(); - - private TaskInfo initialTaskInfo; - private TaskStatus initialTaskStatus; - private long version; - private TaskState taskState; - private long taskInstanceIdLeastSignificantBits = INITIAL_TASK_INSTANCE_ID.getLeastSignificantBits(); - private long taskInstanceIdMostSignificantBits = INITIAL_TASK_INSTANCE_ID.getMostSignificantBits(); - - private long statusFetchCounter; - - public TestingTaskResource(AtomicLong lastActivityNanos, FailureScenario failureScenario) - { - this.lastActivityNanos = requireNonNull(lastActivityNanos, "lastActivityNanos is null"); - this.failureScenario = requireNonNull(failureScenario, "failureScenario is null"); - } - - public void setHttpClient(TestingHttpClient newValue) - { - httpClient.set(newValue); - } - - @GET - @Path("{taskId}") - @Produces(MediaType.APPLICATION_JSON) - public synchronized TaskInfo 
getTaskInfo( - @PathParam("taskId") final TaskId taskId, - @HeaderParam(PRESTO_CURRENT_STATE) TaskState currentState, - @HeaderParam(PRESTO_MAX_WAIT) Duration maxWait, - @Context UriInfo uriInfo) - { - lastActivityNanos.set(System.nanoTime()); - return buildTaskInfo(); - } - - Map taskSourceMap = new HashMap<>(); - - @POST - @Path("{taskId}") - @Consumes(MediaType.APPLICATION_JSON) - @Produces(MediaType.APPLICATION_JSON) - public synchronized TaskInfo createOrUpdateTask( - @PathParam("taskId") TaskId taskId, - TaskUpdateRequest taskUpdateRequest, - @Context UriInfo uriInfo) - { - for (TaskSource source : taskUpdateRequest.getSources()) { - taskSourceMap.compute(source.getPlanNodeId(), (planNodeId, taskSource) -> taskSource == null ? source : taskSource.update(source)); - } - lastActivityNanos.set(System.nanoTime()); - return buildTaskInfo(); - } - - public synchronized TaskSource getTaskSource(PlanNodeId planNodeId) - { - TaskSource source = taskSourceMap.get(planNodeId); - if (source == null) { - return null; - } - return new TaskSource(source.getPlanNodeId(), source.getSplits(), source.getNoMoreSplitsForLifespan(), source.isNoMoreSplits()); - } - - @GET - @Path("{taskId}/status") - @Produces({MediaType.APPLICATION_JSON, APPLICATION_THRIFT_BINARY, APPLICATION_THRIFT_COMPACT, APPLICATION_THRIFT_FB_COMPACT}) - public synchronized TaskStatus getTaskStatus( - @PathParam("taskId") TaskId taskId, - @HeaderParam(PRESTO_CURRENT_STATE) TaskState currentState, - @HeaderParam(PRESTO_MAX_WAIT) Duration maxWait, - @Context UriInfo uriInfo) - throws InterruptedException - { - lastActivityNanos.set(System.nanoTime()); - - wait(maxWait.roundTo(MILLISECONDS)); - return buildTaskStatus(); - } - - @DELETE - @Path("{taskId}") - @Produces(MediaType.APPLICATION_JSON) - public synchronized TaskInfo deleteTask( - @PathParam("taskId") TaskId taskId, - @QueryParam("abort") @DefaultValue("true") boolean abort, - @Context UriInfo uriInfo) - { - lastActivityNanos.set(System.nanoTime()); - - 
taskState = abort ? TaskState.ABORTED : TaskState.CANCELED; - return buildTaskInfo(); - } - - public void setInitialTaskInfo(TaskInfo initialTaskInfo) - { - this.initialTaskInfo = initialTaskInfo; - this.initialTaskStatus = initialTaskInfo.getTaskStatus(); - this.taskState = initialTaskStatus.getState(); - this.version = initialTaskStatus.getVersion(); - switch (failureScenario) { - case TASK_MISMATCH_WHEN_VERSION_IS_HIGH: - // Make the initial version large enough. - // This way, the version number can't be reached if it is reset to 0. - version = 1_000_000; - break; - case TASK_MISMATCH: - case REJECTED_EXECUTION: - case NO_FAILURE: - break; // do nothing - default: - throw new UnsupportedOperationException(); - } - } - - private TaskInfo buildTaskInfo() - { - return new TaskInfo( - initialTaskInfo.getTaskId(), - buildTaskStatus(), - initialTaskInfo.getLastHeartbeatInMillis(), - initialTaskInfo.getOutputBuffers(), - initialTaskInfo.getNoMoreSplits(), - initialTaskInfo.getStats(), - initialTaskInfo.isNeedsPlan(), - initialTaskInfo.getMetadataUpdates(), - initialTaskInfo.getNodeId()); - } - - private TaskStatus buildTaskStatus() - { - statusFetchCounter++; - // Change the task instance id after 10th fetch to simulate worker restart - switch (failureScenario) { - case TASK_MISMATCH: - case TASK_MISMATCH_WHEN_VERSION_IS_HIGH: - if (statusFetchCounter == 10) { - taskInstanceIdLeastSignificantBits = NEW_TASK_INSTANCE_ID.getLeastSignificantBits(); - taskInstanceIdMostSignificantBits = NEW_TASK_INSTANCE_ID.getMostSignificantBits(); - version = 0; - } - break; - case REJECTED_EXECUTION: - if (statusFetchCounter >= 10) { - httpClient.get().close(); - throw new RejectedExecutionException(); - } - break; - case NO_FAILURE: - break; - default: - throw new UnsupportedOperationException(); - } - - return new TaskStatus( - taskInstanceIdLeastSignificantBits, - taskInstanceIdMostSignificantBits, - ++version, - taskState, - initialTaskStatus.getSelf(), - ImmutableSet.of(), - 
initialTaskStatus.getFailures(), - initialTaskStatus.getQueuedPartitionedDrivers(), - initialTaskStatus.getRunningPartitionedDrivers(), - initialTaskStatus.getOutputBufferUtilization(), - initialTaskStatus.isOutputBufferOverutilized(), - initialTaskStatus.getPhysicalWrittenDataSizeInBytes(), - initialTaskStatus.getMemoryReservationInBytes(), - initialTaskStatus.getSystemMemoryReservationInBytes(), - initialTaskStatus.getPeakNodeTotalMemoryReservationInBytes(), - initialTaskStatus.getFullGcCount(), - initialTaskStatus.getFullGcTimeInMillis(), - initialTaskStatus.getTotalCpuTimeInNanos(), - initialTaskStatus.getTaskAgeInMillis(), - initialTaskStatus.getQueuedPartitionedSplitsWeight(), - initialTaskStatus.getRunningPartitionedSplitsWeight()); - } - } -} diff --git a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskConnectorCodec.java b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskConnectorCodec.java new file mode 100644 index 0000000000000..66ab57b155e33 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskConnectorCodec.java @@ -0,0 +1,1489 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.remotetask; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.http.client.testing.TestingHttpClient; +import com.facebook.airlift.jaxrs.JsonMapper; +import com.facebook.airlift.jaxrs.testing.JaxrsTestingHttpProcessor; +import com.facebook.airlift.jaxrs.thrift.ThriftMapper; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonModule; +import com.facebook.airlift.json.smile.SmileCodec; +import com.facebook.airlift.json.smile.SmileModule; +import com.facebook.drift.codec.ThriftCodec; +import com.facebook.drift.codec.guice.ThriftCodecModule; +import com.facebook.drift.codec.utils.DataSizeToBytesThriftCodec; +import com.facebook.drift.codec.utils.DurationToMillisThriftCodec; +import com.facebook.drift.codec.utils.JodaDateTimeToEpochMillisThriftCodec; +import com.facebook.drift.codec.utils.LocaleToLanguageTagCodec; +import com.facebook.presto.SessionTestUtils; +import com.facebook.presto.client.NodeVersion; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.cost.StatsAndCosts; +import com.facebook.presto.execution.Lifespan; +import com.facebook.presto.execution.NodeTaskMap; +import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.execution.RemoteTask; +import com.facebook.presto.execution.ScheduledSplit; +import com.facebook.presto.execution.SchedulerStatsTracker; +import com.facebook.presto.execution.TaskId; +import com.facebook.presto.execution.TaskInfo; +import com.facebook.presto.execution.TaskManagerConfig; +import com.facebook.presto.execution.TaskSource; +import com.facebook.presto.execution.TaskStatus; +import com.facebook.presto.execution.TaskTestUtils; +import com.facebook.presto.execution.TestQueryManager; +import 
com.facebook.presto.execution.TestSqlTaskManager; +import com.facebook.presto.execution.buffer.OutputBuffers; +import com.facebook.presto.execution.scheduler.ExecutionWriterTarget; +import com.facebook.presto.execution.scheduler.TableWriteInfo; +import com.facebook.presto.metadata.ColumnHandleJacksonModule; +import com.facebook.presto.metadata.DeleteTableHandle; +import com.facebook.presto.metadata.DeleteTableHandleJacksonModule; +import com.facebook.presto.metadata.DistributedProcedureHandleJacksonModule; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.FunctionHandleJacksonModule; +import com.facebook.presto.metadata.HandleResolver; +import com.facebook.presto.metadata.InsertTableHandle; +import com.facebook.presto.metadata.InsertTableHandleJacksonModule; +import com.facebook.presto.metadata.InternalNode; +import com.facebook.presto.metadata.MergeTableHandleJacksonModule; +import com.facebook.presto.metadata.OutputTableHandle; +import com.facebook.presto.metadata.OutputTableHandleJacksonModule; +import com.facebook.presto.metadata.PartitioningHandleJacksonModule; +import com.facebook.presto.metadata.Split; +import com.facebook.presto.metadata.SplitJacksonModule; +import com.facebook.presto.metadata.TableHandleJacksonModule; +import com.facebook.presto.metadata.TableLayoutHandleJacksonModule; +import com.facebook.presto.metadata.TransactionHandleJacksonModule; +import com.facebook.presto.server.InternalCommunicationConfig; +import com.facebook.presto.server.TaskUpdateRequest; +import com.facebook.presto.server.thrift.ConnectorSplitThriftCodec; +import com.facebook.presto.server.thrift.DeleteTableHandleThriftCodec; +import com.facebook.presto.server.thrift.InsertTableHandleThriftCodec; +import com.facebook.presto.server.thrift.MergeTableHandleThriftCodec; +import com.facebook.presto.server.thrift.OutputTableHandleThriftCodec; +import com.facebook.presto.server.thrift.TableHandleThriftCodec; +import 
com.facebook.presto.server.thrift.TableLayoutHandleThriftCodec; +import com.facebook.presto.server.thrift.TransactionHandleThriftCodec; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorCodec; +import com.facebook.presto.spi.ConnectorDeleteTableHandle; +import com.facebook.presto.spi.ConnectorDistributedProcedureHandle; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.ConnectorIndexHandle; +import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; +import com.facebook.presto.spi.ConnectorOutputTableHandle; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.connector.ConnectorCodecProvider; +import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.plan.Partitioning; +import com.facebook.presto.spi.plan.PartitioningScheme; +import com.facebook.presto.spi.plan.PlanFragmentId; +import com.facebook.presto.spi.plan.StageExecutionDescriptor; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.facebook.presto.sql.Serialization; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.planner.PlanFragment; +import com.facebook.presto.testing.TestingTransactionHandle; +import com.facebook.presto.type.TypeDeserializer; +import com.fasterxml.jackson.annotation.JsonCreator; +import 
com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.inject.Binder; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Module; +import com.google.inject.Provides; +import com.google.inject.Scopes; +import com.google.inject.TypeLiteral; +import org.testng.annotations.Test; + +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; + +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.facebook.airlift.json.smile.SmileCodecBinder.smileCodecBinder; +import static com.facebook.drift.codec.guice.ThriftCodecBinder.thriftCodecBinder; +import static com.facebook.presto.execution.Lifespan.driverGroup; +import static com.facebook.presto.execution.TaskTestUtils.createPlanFragment; +import static com.facebook.presto.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.server.remotetask.TestHttpRemoteTaskWithEventLoop.TestingTaskResource; +import static com.facebook.presto.spi.SplitContext.NON_CACHEABLE; +import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static 
com.facebook.presto.testing.assertions.Assert.assertEquals; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class TestHttpRemoteTaskConnectorCodec +{ + private static final TaskManagerConfig TASK_MANAGER_CONFIG = new TaskManagerConfig(); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + @Test(timeOut = 50000) + public void testConnectorSplitBinarySerialization() + throws Exception + { + AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); + TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, TestHttpRemoteTaskWithEventLoop.FailureScenario.NO_FAILURE); + + String connectorName = "test-codec-split"; + Injector injector = createInjectorWithCodec(connectorName, testingTaskResource); + HttpRemoteTaskFactory httpRemoteTaskFactory = injector.getInstance(HttpRemoteTaskFactory.class); + JsonCodec jsonCodec = injector.getInstance(Key.get(new TypeLiteral<>() {})); + + RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory); + try { + testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo()); + remoteTask.start(); + + Lifespan lifespan = driverGroup(1); + TestConnectorWithCodecSplit codecSplit = new TestConnectorWithCodecSplit("test-data", 42); + remoteTask.addSplits(ImmutableMultimap.of( + TaskTestUtils.TABLE_SCAN_NODE_ID, + new Split(new ConnectorId(connectorName), TestingTransactionHandle.create(), codecSplit, lifespan, NON_CACHEABLE))); + + TestHttpRemoteTaskWithEventLoop.poll(() -> testingTaskResource.getTaskSource(TaskTestUtils.TABLE_SCAN_NODE_ID) != null); + TestHttpRemoteTaskWithEventLoop.poll(() -> 
testingTaskResource.getTaskSource(TaskTestUtils.TABLE_SCAN_NODE_ID).getSplits().size() == 1); + + TaskUpdateRequest taskUpdateRequest = testingTaskResource.getLastTaskUpdateRequest(); + assertNotNull(taskUpdateRequest, "TaskUpdateRequest should not be null"); + + String json = jsonCodec.toJson(taskUpdateRequest); + JsonNode root = OBJECT_MAPPER.readTree(json); + JsonNode splitNode = root.at("/sources/0/splits/0/split/connectorSplit"); + assertTrue(splitNode.has("customSerializedValue"), + "Split should have customSerializedValue for binary serialization"); + assertFalse(splitNode.has("data"), + "Split should not have inline data field"); + + TaskUpdateRequest deserializedRequest = jsonCodec.fromJson(json); + TaskSource deserializedSource = deserializedRequest.getSources().get(0); + Split deserializedSplit = getOnlyElement(deserializedSource.getSplits()).getSplit(); + ConnectorSplit deserializedConnectorSplit = deserializedSplit.getConnectorSplit(); + assertEquals(deserializedConnectorSplit, codecSplit, "Expected deserialized split to match original"); + } + finally { + remoteTask.cancel(); + httpRemoteTaskFactory.stop(); + } + } + + @Test(timeOut = 50000) + public void testOutputTableHandleBinarySerialization() + throws Exception + { + AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); + TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, TestHttpRemoteTaskWithEventLoop.FailureScenario.NO_FAILURE); + + String connectorName = "test-codec-output"; + Injector injector = createInjectorWithCodec(connectorName, testingTaskResource); + HttpRemoteTaskFactory httpRemoteTaskFactory = injector.getInstance(HttpRemoteTaskFactory.class); + JsonCodec taskUpdateRequestCodec = injector.getInstance(Key.get(new TypeLiteral<>() {})); + + ConnectorId connectorId = new ConnectorId(connectorName); + TestConnectorOutputTableHandle outputHandle = new TestConnectorOutputTableHandle("output_table"); + TableWriteInfo outputTableWriteInfo = new 
TableWriteInfo( + Optional.of(new ExecutionWriterTarget.CreateHandle( + new OutputTableHandle(connectorId, com.facebook.presto.testing.TestingTransactionHandle.create(), outputHandle), + new SchemaTableName("test_schema", "output_table"))), + Optional.empty()); + + RemoteTask outputTask = createRemoteTask(httpRemoteTaskFactory, createPlanFragment(), outputTableWriteInfo); + try { + testingTaskResource.setInitialTaskInfo(outputTask.getTaskInfo()); + outputTask.start(); + + TestHttpRemoteTaskWithEventLoop.poll(() -> testingTaskResource.getLastTaskUpdateRequest() != null); + TaskUpdateRequest outputRequest = testingTaskResource.getLastTaskUpdateRequest(); + String outputJson = taskUpdateRequestCodec.toJson(outputRequest); + + JsonNode root = OBJECT_MAPPER.readTree(outputJson); + JsonNode outputTableHandleNode = root.at("/tableWriteInfo/writerTarget/handle/connectorHandle"); + assertTrue(outputTableHandleNode.has("customSerializedValue"), + "OutputTableHandle should have customSerializedValue for binary serialization"); + assertFalse(outputTableHandleNode.has("tableName"), + "OutputTableHandle should not have inline tableName field"); + + ExecutionWriterTarget.CreateHandle createHandle = (ExecutionWriterTarget.CreateHandle) outputRequest.getTableWriteInfo().get().getWriterTarget().get(); + TestConnectorOutputTableHandle receivedHandle = (TestConnectorOutputTableHandle) createHandle.getHandle().getConnectorHandle(); + assertEquals(receivedHandle.getTableName(), outputHandle.getTableName(), "OutputTableHandle should match after round-trip"); + } + finally { + outputTask.cancel(); + httpRemoteTaskFactory.stop(); + } + } + + @Test(timeOut = 50000) + public void testInsertTableHandleBinarySerialization() + throws Exception + { + AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); + TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, TestHttpRemoteTaskWithEventLoop.FailureScenario.NO_FAILURE); + + String connectorName = 
"test-codec-insert"; + Injector injector = createInjectorWithCodec(connectorName, testingTaskResource); + HttpRemoteTaskFactory httpRemoteTaskFactory = injector.getInstance(HttpRemoteTaskFactory.class); + JsonCodec taskUpdateRequestCodec = injector.getInstance(Key.get(new TypeLiteral<>() {})); + + ConnectorId connectorId = new ConnectorId(connectorName); + TestConnectorInsertTableHandle insertHandle = new TestConnectorInsertTableHandle("insert_table"); + TableWriteInfo insertTableWriteInfo = new TableWriteInfo( + Optional.of(new ExecutionWriterTarget.InsertHandle( + new InsertTableHandle(connectorId, com.facebook.presto.testing.TestingTransactionHandle.create(), insertHandle), + new SchemaTableName("test_schema", "insert_table"))), + Optional.empty()); + + RemoteTask insertTask = createRemoteTask(httpRemoteTaskFactory, createPlanFragment(), insertTableWriteInfo); + try { + testingTaskResource.setInitialTaskInfo(insertTask.getTaskInfo()); + insertTask.start(); + + TestHttpRemoteTaskWithEventLoop.poll(() -> testingTaskResource.getLastTaskUpdateRequest() != null); + TaskUpdateRequest insertRequest = testingTaskResource.getLastTaskUpdateRequest(); + String insertJson = taskUpdateRequestCodec.toJson(insertRequest); + + JsonNode root = OBJECT_MAPPER.readTree(insertJson); + JsonNode insertTableHandleNode = root.at("/tableWriteInfo/writerTarget/handle/connectorHandle"); + assertTrue(insertTableHandleNode.has("customSerializedValue"), + "InsertTableHandle should have customSerializedValue for binary serialization"); + assertFalse(insertTableHandleNode.has("tableName"), + "InsertTableHandle should not have inline tableName field"); + + ExecutionWriterTarget.InsertHandle deserializedInsertHandle = (ExecutionWriterTarget.InsertHandle) insertRequest.getTableWriteInfo().get().getWriterTarget().get(); + TestConnectorInsertTableHandle receivedHandle = (TestConnectorInsertTableHandle) deserializedInsertHandle.getHandle().getConnectorHandle(); + 
assertEquals(receivedHandle.getTableName(), insertHandle.getTableName(), "InsertTableHandle should match after round-trip"); + } + finally { + insertTask.cancel(); + httpRemoteTaskFactory.stop(); + } + } + + @Test(timeOut = 50000) + public void testDeleteTableHandleBinarySerialization() + throws Exception + { + AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); + TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, TestHttpRemoteTaskWithEventLoop.FailureScenario.NO_FAILURE); + + String connectorName = "test-codec-delete"; + Injector injector = createInjectorWithCodec(connectorName, testingTaskResource); + HttpRemoteTaskFactory httpRemoteTaskFactory = injector.getInstance(HttpRemoteTaskFactory.class); + JsonCodec taskUpdateRequestCodec = injector.getInstance(Key.get(new TypeLiteral<>() {})); + + ConnectorId connectorId = new ConnectorId(connectorName); + TestConnectorDeleteTableHandle deleteHandle = new TestConnectorDeleteTableHandle("delete_table"); + TableWriteInfo deleteTableWriteInfo = new TableWriteInfo( + Optional.of(new ExecutionWriterTarget.DeleteHandle( + new DeleteTableHandle(connectorId, com.facebook.presto.testing.TestingTransactionHandle.create(), deleteHandle), + new SchemaTableName("test_schema", "delete_table"))), + Optional.empty()); + + RemoteTask deleteTask = createRemoteTask(httpRemoteTaskFactory, createPlanFragment(), deleteTableWriteInfo); + try { + testingTaskResource.setInitialTaskInfo(deleteTask.getTaskInfo()); + deleteTask.start(); + + TestHttpRemoteTaskWithEventLoop.poll(() -> testingTaskResource.getLastTaskUpdateRequest() != null); + TaskUpdateRequest deleteRequest = testingTaskResource.getLastTaskUpdateRequest(); + String deleteJson = taskUpdateRequestCodec.toJson(deleteRequest); + + JsonNode root = OBJECT_MAPPER.readTree(deleteJson); + JsonNode deleteTableHandleNode = root.at("/tableWriteInfo/writerTarget/handle/connectorHandle"); + 
assertTrue(deleteTableHandleNode.has("customSerializedValue"), + "DeleteTableHandle should have customSerializedValue for binary serialization"); + assertFalse(deleteTableHandleNode.has("tableName"), + "DeleteTableHandle should not have inline tableName field"); + + ExecutionWriterTarget.DeleteHandle deserializedDeleteHandle = (ExecutionWriterTarget.DeleteHandle) deleteRequest.getTableWriteInfo().get().getWriterTarget().get(); + TestConnectorDeleteTableHandle receivedHandle = (TestConnectorDeleteTableHandle) deserializedDeleteHandle.getHandle().getConnectorHandle(); + assertEquals(receivedHandle.getTableName(), deleteHandle.getTableName(), "DeleteTableHandle should match after round-trip"); + } + finally { + deleteTask.cancel(); + httpRemoteTaskFactory.stop(); + } + } + + @Test(timeOut = 50000) + public void testConnectorHandlesBinarySerialization() + throws Exception + { + AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); + TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, TestHttpRemoteTaskWithEventLoop.FailureScenario.NO_FAILURE); + + String connectorName = "test-codec-handles"; + Injector injector = createInjectorWithCodec(connectorName, testingTaskResource); + HttpRemoteTaskFactory httpRemoteTaskFactory = injector.getInstance(HttpRemoteTaskFactory.class); + JsonCodec planFragmentCodec = injector.getInstance(Key.get(new TypeLiteral<>() {})); + + PlanFragment planFragment = createPlanFragmentWithCodecHandles(connectorName); + RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory, planFragment); + + try { + testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo()); + remoteTask.start(); + + TestHttpRemoteTaskWithEventLoop.poll(() -> testingTaskResource.getLastTaskUpdateRequest() != null); + TestHttpRemoteTaskWithEventLoop.poll(() -> testingTaskResource.getLastTaskUpdateRequest().getFragment().isPresent()); + + TaskUpdateRequest taskUpdateRequest = testingTaskResource.getLastTaskUpdateRequest(); + 
byte[] fragmentBytes = taskUpdateRequest.getFragment().get(); + String json = new String(fragmentBytes, UTF_8); + JsonNode root = OBJECT_MAPPER.readTree(json); + + JsonNode tableHandleNode = root.at("/root/table/connectorHandle"); + assertTrue(tableHandleNode.has("customSerializedValue"), + "TableHandle should have customSerializedValue for binary serialization"); + assertFalse(tableHandleNode.has("tableName"), + "TableHandle should not have inline tableName field"); + + JsonNode layoutHandleNode = root.at("/root/table/connectorTableLayout"); + assertTrue(layoutHandleNode.has("customSerializedValue"), + "TableLayoutHandle should have customSerializedValue for binary serialization"); + assertFalse(layoutHandleNode.has("layoutName"), + "TableLayoutHandle should not have inline layoutName field"); + + JsonNode assignmentsNode = root.at("/root/assignments"); + assertTrue(assignmentsNode.isObject() && assignmentsNode.size() > 0, + "Should have at least one column assignment"); + assignmentsNode.fields().forEachRemaining(entry -> { + JsonNode columnHandleNode = entry.getValue(); + assertTrue(columnHandleNode.has("customSerializedValue"), + "ColumnHandle should have customSerializedValue for binary serialization"); + assertFalse(columnHandleNode.has("columnName"), + "ColumnHandle should not have inline columnName field"); + assertFalse(columnHandleNode.has("columnType"), + "ColumnHandle should not have inline columnType field"); + }); + + PlanFragment receivedFragment = planFragmentCodec.fromJson(json); + assertNotNull(receivedFragment, "Deserialized PlanFragment should not be null"); + assertNotNull(receivedFragment.getRoot(), "Deserialized PlanFragment should have a root node"); + + TableScanNode originalScan = (TableScanNode) planFragment.getRoot(); + TableScanNode receivedScan = (TableScanNode) receivedFragment.getRoot(); + + TestConnectorTableHandle originalTableHandle = (TestConnectorTableHandle) originalScan.getTable().getConnectorHandle(); + 
TestConnectorTableHandle receivedTableHandle = (TestConnectorTableHandle) receivedScan.getTable().getConnectorHandle(); + assertEquals(receivedTableHandle.getTableName(), originalTableHandle.getTableName(), "TableHandle should match after round-trip"); + + TestConnectorTableLayoutHandle originalLayoutHandle = (TestConnectorTableLayoutHandle) originalScan.getTable().getLayout().get(); + TestConnectorTableLayoutHandle receivedLayoutHandle = (TestConnectorTableLayoutHandle) receivedScan.getTable().getLayout().get(); + assertEquals(receivedLayoutHandle.getLayoutName(), originalLayoutHandle.getLayoutName(), "TableLayoutHandle should match after round-trip"); + + TestConnectorColumnHandle originalColumnHandle = (TestConnectorColumnHandle) originalScan.getAssignments().values().iterator().next(); + TestConnectorColumnHandle receivedColumnHandle = (TestConnectorColumnHandle) receivedScan.getAssignments().values().iterator().next(); + assertEquals(receivedColumnHandle.getColumnName(), originalColumnHandle.getColumnName(), "ColumnHandle name should match after round-trip"); + assertEquals(receivedColumnHandle.getColumnType(), originalColumnHandle.getColumnType(), "ColumnHandle type should match after round-trip"); + } + finally { + remoteTask.cancel(); + httpRemoteTaskFactory.stop(); + } + } + + @Test(timeOut = 50000) + public void testMixedConnectorSerializationWithAndWithoutCodec() + throws Exception + { + AtomicLong lastActivityNanos = new AtomicLong(System.nanoTime()); + TestingTaskResource testingTaskResource = new TestingTaskResource(lastActivityNanos, TestHttpRemoteTaskWithEventLoop.FailureScenario.NO_FAILURE); + + String connectorWithCodec = "test-with-codec"; + String connectorWithoutCodec = "test-without-codec"; + Injector injector = createInjectorWithMixedConnectors(connectorWithCodec, connectorWithoutCodec, testingTaskResource); + HttpRemoteTaskFactory httpRemoteTaskFactory = injector.getInstance(HttpRemoteTaskFactory.class); + JsonCodec jsonCodec = 
injector.getInstance(Key.get(new TypeLiteral<>() {})); + + RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory); + try { + testingTaskResource.setInitialTaskInfo(remoteTask.getTaskInfo()); + remoteTask.start(); + + Lifespan lifespan = driverGroup(1); + + TestConnectorWithCodecSplit splitWithCodec = new TestConnectorWithCodecSplit("codec-data", 100); + TestConnectorWithoutCodecSplit splitWithoutCodec = new TestConnectorWithoutCodecSplit("json-data", 200); + + remoteTask.addSplits(ImmutableMultimap.of( + TaskTestUtils.TABLE_SCAN_NODE_ID, + new Split(new ConnectorId(connectorWithCodec), TestingTransactionHandle.create(), splitWithCodec, lifespan, NON_CACHEABLE), + TaskTestUtils.TABLE_SCAN_NODE_ID, + new Split(new ConnectorId(connectorWithoutCodec), TestingTransactionHandle.create(), splitWithoutCodec, lifespan, NON_CACHEABLE))); + + TestHttpRemoteTaskWithEventLoop.poll(() -> testingTaskResource.getTaskSource(TaskTestUtils.TABLE_SCAN_NODE_ID) != null); + TestHttpRemoteTaskWithEventLoop.poll(() -> testingTaskResource.getTaskSource(TaskTestUtils.TABLE_SCAN_NODE_ID).getSplits().size() == 2); + + TaskUpdateRequest taskUpdateRequest = testingTaskResource.getLastTaskUpdateRequest(); + assertNotNull(taskUpdateRequest, "TaskUpdateRequest should not be null"); + + String json = jsonCodec.toJson(taskUpdateRequest); + JsonNode root = OBJECT_MAPPER.readTree(json); + JsonNode splitsNode = root.at("/sources/0/splits"); + + assertTrue(splitsNode.isArray() && splitsNode.size() == 2, + "Should have exactly 2 splits"); + + JsonNode codecSplitNode = null; + JsonNode jsonSplitNode = null; + + for (JsonNode splitWrapper : splitsNode) { + JsonNode connectorIdNode = splitWrapper.at("/split/connectorId"); + String catalogName = connectorIdNode.asText(); + JsonNode connectorSplitNode = splitWrapper.at("/split/connectorSplit"); + + if (connectorWithCodec.equals(catalogName)) { + codecSplitNode = connectorSplitNode; + } + else if (connectorWithoutCodec.equals(catalogName)) { + 
jsonSplitNode = connectorSplitNode; + } + } + + assertNotNull(codecSplitNode, "Should find split from connector with codec"); + assertNotNull(jsonSplitNode, "Should find split from connector without codec"); + + assertTrue(codecSplitNode.has("customSerializedValue"), + "Split with codec should have customSerializedValue for binary serialization"); + assertFalse(codecSplitNode.has("data"), + "Split with codec should not have inline data field"); + + assertFalse(jsonSplitNode.has("customSerializedValue"), + "Split without codec should not have customSerializedValue"); + assertTrue(jsonSplitNode.has("data"), + "Split without codec should have inline data field for JSON serialization"); + assertTrue(jsonSplitNode.has("sequence"), + "Split without codec should have inline sequence field for JSON serialization"); + + TaskUpdateRequest deserializedRequest = jsonCodec.fromJson(json); + TaskSource deserializedSource = deserializedRequest.getSources().get(0); + List deserializedSplits = deserializedSource.getSplits().stream() + .map(ScheduledSplit::getSplit) + .collect(toImmutableList()); + + assertEquals(deserializedSplits.size(), 2, "Should have 2 deserialized splits"); + + boolean foundCodecSplit = false; + boolean foundJsonSplit = false; + + for (Split split : deserializedSplits) { + if (split.getConnectorSplit() instanceof TestConnectorWithCodecSplit) { + TestConnectorWithCodecSplit deserialized = (TestConnectorWithCodecSplit) split.getConnectorSplit(); + assertEquals(deserialized, splitWithCodec, "Codec split should match after round-trip"); + foundCodecSplit = true; + } + else if (split.getConnectorSplit() instanceof TestConnectorWithoutCodecSplit) { + TestConnectorWithoutCodecSplit deserialized = (TestConnectorWithoutCodecSplit) split.getConnectorSplit(); + assertEquals(deserialized, splitWithoutCodec, "JSON split should match after round-trip"); + foundJsonSplit = true; + } + } + + assertTrue(foundCodecSplit, "Should have found and verified the codec split"); + 
assertTrue(foundJsonSplit, "Should have found and verified the JSON split"); + } + finally { + remoteTask.cancel(); + httpRemoteTaskFactory.stop(); + } + } + + private static RemoteTask createRemoteTask(HttpRemoteTaskFactory httpRemoteTaskFactory) + { + return createRemoteTask(httpRemoteTaskFactory, createPlanFragment()); + } + + private static RemoteTask createRemoteTask(HttpRemoteTaskFactory httpRemoteTaskFactory, PlanFragment planFragment) + { + return createRemoteTask(httpRemoteTaskFactory, planFragment, new TableWriteInfo(Optional.empty(), Optional.empty())); + } + + private static RemoteTask createRemoteTask(HttpRemoteTaskFactory httpRemoteTaskFactory, PlanFragment planFragment, TableWriteInfo tableWriteInfo) + { + return httpRemoteTaskFactory.createRemoteTask( + SessionTestUtils.TEST_SESSION, + new TaskId("test", 1, 0, 2, 0), + new InternalNode("node-id", URI.create("http://fake.invalid/"), new NodeVersion("version"), false), + planFragment, + ImmutableMultimap.of(), + createInitialEmptyOutputBuffers(OutputBuffers.BufferType.BROADCAST), + new NodeTaskMap.NodeStatsTracker(i -> {}, i -> {}, (age, i) -> {}), + true, + tableWriteInfo, + SchedulerStatsTracker.NOOP); + } + + private static PlanFragment createPlanFragmentWithCodecHandles(String connectorName) + { + ConnectorId connectorId = new ConnectorId(connectorName); + TestConnectorTableHandle tableHandle = new TestConnectorTableHandle("test_table"); + TestConnectorTableLayoutHandle layoutHandle = new TestConnectorTableLayoutHandle("test_layout"); + TestConnectorColumnHandle columnHandle = new TestConnectorColumnHandle("test_column", "VARCHAR"); + VariableReferenceExpression variable = new VariableReferenceExpression(Optional.empty(), "test_column", com.facebook.presto.common.type.VarcharType.VARCHAR); + + return new PlanFragment( + new PlanFragmentId(0), + new TableScanNode( + Optional.empty(), + TaskTestUtils.TABLE_SCAN_NODE_ID, + new TableHandle(connectorId, tableHandle, TestingTransactionHandle.create(), 
Optional.of(layoutHandle)), + ImmutableList.of(variable), + ImmutableMap.of(variable, columnHandle), + TupleDomain.all(), + TupleDomain.all(), + Optional.empty()), + ImmutableSet.of(variable), + SOURCE_DISTRIBUTION, + ImmutableList.of(TaskTestUtils.TABLE_SCAN_NODE_ID), + new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(variable)) + .withBucketToPartition(Optional.of(new int[1])), + Optional.empty(), + StageExecutionDescriptor.ungroupedExecution(), + false, + Optional.of(StatsAndCosts.empty()), + Optional.empty()); + } + + private static Injector createInjectorWithCodec(String connectorName, TestingTaskResource testingTaskResource) + throws Exception + { + return createInjectorWithMixedConnectors(connectorName, "unused-connector", testingTaskResource); + } + + private static Injector createInjectorWithMixedConnectors( + String connectorWithCodec, + String connectorWithoutCodec, + TestingTaskResource testingTaskResource) + throws Exception + { + InternalCommunicationConfig internalCommunicationConfig = new InternalCommunicationConfig().setThriftTransportEnabled(false); + Bootstrap app = new Bootstrap( + new JsonModule(), + new SmileModule(), + new ThriftCodecModule(), + new Module() + { + @Override + public void configure(Binder binder) + { + binder.bind(JsonMapper.class); + binder.bind(ThriftMapper.class); + + FeaturesConfig featuresConfig = new FeaturesConfig(); + featuresConfig.setUseConnectorProvidedSerializationCodecs(true); + binder.bind(FeaturesConfig.class).toInstance(featuresConfig); + + TestConnectorWithCodecProvider codecProvider = new TestConnectorWithCodecProvider(); + + Map> splitCodecMap = new ConcurrentHashMap<>(); + splitCodecMap.put(connectorWithCodec, codecProvider.getConnectorSplitCodec().get()); + + Map> tableHandleCodecMap = new ConcurrentHashMap<>(); + tableHandleCodecMap.put(connectorWithCodec, codecProvider.getConnectorTableHandleCodec().get()); + + Map> columnHandleCodecMap = new 
ConcurrentHashMap<>(); + columnHandleCodecMap.put(connectorWithCodec, codecProvider.getConnectorColumnHandleCodec().get()); + + Map> tableLayoutHandleCodecMap = new ConcurrentHashMap<>(); + tableLayoutHandleCodecMap.put(connectorWithCodec, codecProvider.getConnectorTableLayoutHandleCodec().get()); + + Map> outputTableHandleCodecMap = new ConcurrentHashMap<>(); + outputTableHandleCodecMap.put(connectorWithCodec, codecProvider.getConnectorOutputTableHandleCodec().get()); + + Map> insertTableHandleCodecMap = new ConcurrentHashMap<>(); + insertTableHandleCodecMap.put(connectorWithCodec, codecProvider.getConnectorInsertTableHandleCodec().get()); + + Map> deleteTableHandleCodecMap = new ConcurrentHashMap<>(); + deleteTableHandleCodecMap.put(connectorWithCodec, codecProvider.getConnectorDeleteTableHandleCodec().get()); + + Map> mergeTableHandleCodecMap = new ConcurrentHashMap<>(); + mergeTableHandleCodecMap.put(connectorWithCodec, codecProvider.getConnectorMergeTableHandleCodec().get()); + + HandleResolver handleResolver = new HandleResolver(); + handleResolver.addConnectorName(connectorWithCodec, new TestConnectorWithCodecHandleResolver()); + handleResolver.addConnectorName(connectorWithoutCodec, new TestConnectorWithoutCodecHandleResolver()); + binder.bind(HandleResolver.class).toInstance(handleResolver); + + Function>> tableHandleCodecExtractor = + connectorId -> Optional.ofNullable(tableHandleCodecMap.get(connectorId.getCatalogName())); + Function>> tableLayoutHandleCodecExtractor = + connectorId -> Optional.ofNullable(tableLayoutHandleCodecMap.get(connectorId.getCatalogName())); + Function>> columnHandleCodecExtractor = + connectorId -> Optional.ofNullable(columnHandleCodecMap.get(connectorId.getCatalogName())); + Function>> outputTableHandleCodecExtractor = + connectorId -> Optional.ofNullable(outputTableHandleCodecMap.get(connectorId.getCatalogName())); + Function>> insertTableHandleCodecExtractor = + connectorId -> 
Optional.ofNullable(insertTableHandleCodecMap.get(connectorId.getCatalogName())); + Function>> deleteTableHandleCodecExtractor = + connectorId -> Optional.ofNullable(deleteTableHandleCodecMap.get(connectorId.getCatalogName())); + Function>> mergeTableHandleCodecExtractor = + connectorId -> Optional.ofNullable(mergeTableHandleCodecMap.get(connectorId.getCatalogName())); + Function>> noOpIndexCodec = + connectorId -> Optional.empty(); + Function>> noOpTransactionCodec = + connectorId -> Optional.empty(); + Function>> noOpPartitioningCodec = + connectorId -> Optional.empty(); + Function>> noOpDistributedProcedureCodec = + connectorId -> Optional.empty(); + Function>> splitCodecExtractor = + connectorId -> Optional.ofNullable(splitCodecMap.get(connectorId.getCatalogName())); + + jsonBinder(binder).addModuleBinding().toInstance(new TableHandleJacksonModule(handleResolver, featuresConfig, tableHandleCodecExtractor)); + jsonBinder(binder).addModuleBinding().toInstance(new TableLayoutHandleJacksonModule(handleResolver, featuresConfig, tableLayoutHandleCodecExtractor)); + jsonBinder(binder).addModuleBinding().toInstance(new ColumnHandleJacksonModule(handleResolver, featuresConfig, columnHandleCodecExtractor)); + jsonBinder(binder).addModuleBinding().toInstance(new OutputTableHandleJacksonModule(handleResolver, featuresConfig, outputTableHandleCodecExtractor)); + jsonBinder(binder).addModuleBinding().toInstance(new InsertTableHandleJacksonModule(handleResolver, featuresConfig, insertTableHandleCodecExtractor)); + jsonBinder(binder).addModuleBinding().toInstance(new DeleteTableHandleJacksonModule(handleResolver, featuresConfig, deleteTableHandleCodecExtractor)); + jsonBinder(binder).addModuleBinding().toInstance(new MergeTableHandleJacksonModule(handleResolver, featuresConfig, mergeTableHandleCodecExtractor)); + jsonBinder(binder).addModuleBinding().toInstance(new com.facebook.presto.index.IndexHandleJacksonModule(handleResolver, featuresConfig, noOpIndexCodec)); + 
jsonBinder(binder).addModuleBinding().toInstance(new TransactionHandleJacksonModule(handleResolver, featuresConfig, noOpTransactionCodec)); + jsonBinder(binder).addModuleBinding().toInstance(new PartitioningHandleJacksonModule(handleResolver, featuresConfig, noOpPartitioningCodec)); + jsonBinder(binder).addModuleBinding().toInstance(new FunctionHandleJacksonModule(handleResolver)); + jsonBinder(binder).addModuleBinding().toInstance(new SplitJacksonModule(handleResolver, featuresConfig, splitCodecExtractor)); + jsonBinder(binder).addModuleBinding().toInstance(new DistributedProcedureHandleJacksonModule(handleResolver, featuresConfig, noOpDistributedProcedureCodec)); + + FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + binder.bind(TypeManager.class).toInstance(functionAndTypeManager); + jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + newSetBinder(binder, Type.class); + smileCodecBinder(binder).bindSmileCodec(TaskStatus.class); + smileCodecBinder(binder).bindSmileCodec(TaskInfo.class); + smileCodecBinder(binder).bindSmileCodec(TaskUpdateRequest.class); + smileCodecBinder(binder).bindSmileCodec(PlanFragment.class); + jsonCodecBinder(binder).bindJsonCodec(TaskStatus.class); + jsonCodecBinder(binder).bindJsonCodec(TaskInfo.class); + jsonCodecBinder(binder).bindJsonCodec(TaskUpdateRequest.class); + jsonCodecBinder(binder).bindJsonCodec(PlanFragment.class); + jsonCodecBinder(binder).bindJsonCodec(TableWriteInfo.class); + jsonBinder(binder).addKeySerializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionSerializer.class); + jsonBinder(binder).addKeyDeserializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionDeserializer.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorSplit.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorTransactionHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ColumnHandle.class); + 
jsonCodecBinder(binder).bindJsonCodec(ConnectorTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorOutputTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorDeleteTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorInsertTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorTableLayoutHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorMergeTableHandle.class); + + binder.bind(ConnectorCodecManager.class).in(Scopes.SINGLETON); + + thriftCodecBinder(binder).bindCustomThriftCodec(ConnectorSplitThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(TransactionHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(OutputTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(InsertTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(DeleteTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(MergeTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(TableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(TableLayoutHandleThriftCodec.class); + thriftCodecBinder(binder).bindThriftCodec(TaskStatus.class); + thriftCodecBinder(binder).bindThriftCodec(TaskInfo.class); + thriftCodecBinder(binder).bindThriftCodec(TaskUpdateRequest.class); + thriftCodecBinder(binder).bindCustomThriftCodec(LocaleToLanguageTagCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(JodaDateTimeToEpochMillisThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(DurationToMillisThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(DataSizeToBytesThriftCodec.class); + } + + @Provides + private HttpRemoteTaskFactory createHttpRemoteTaskFactory( + JsonMapper jsonMapper, + ThriftMapper thriftMapper, + JsonCodec taskStatusJsonCodec, + SmileCodec taskStatusSmileCodec, + ThriftCodec taskStatusThriftCodec, + JsonCodec 
taskInfoJsonCodec, + ThriftCodec taskInfoThriftCodec, + SmileCodec taskInfoSmileCodec, + JsonCodec taskUpdateRequestJsonCodec, + SmileCodec taskUpdateRequestSmileCodec, + ThriftCodec taskUpdateRequestThriftCodec, + JsonCodec planFragmentJsonCodec, + SmileCodec planFragmentSmileCodec) + { + JaxrsTestingHttpProcessor jaxrsTestingHttpProcessor = new JaxrsTestingHttpProcessor(URI.create("http://fake.invalid/"), testingTaskResource, jsonMapper, thriftMapper); + TestingHttpClient testingHttpClient = new TestingHttpClient(jaxrsTestingHttpProcessor.setTrace(false)); + testingTaskResource.setHttpClient(testingHttpClient); + return new HttpRemoteTaskFactory( + new QueryManagerConfig(), + TASK_MANAGER_CONFIG, + testingHttpClient, + new TestSqlTaskManager.MockLocationFactory(), + taskStatusJsonCodec, + taskStatusSmileCodec, + taskStatusThriftCodec, + taskInfoJsonCodec, + taskInfoSmileCodec, + taskInfoThriftCodec, + taskUpdateRequestJsonCodec, + taskUpdateRequestSmileCodec, + taskUpdateRequestThriftCodec, + planFragmentJsonCodec, + planFragmentSmileCodec, + new RemoteTaskStats(), + internalCommunicationConfig, + createTestMetadataManager(), + new TestQueryManager(), + new HandleResolver()); + } + }); + Injector injector = app + .doNotInitializeLogging() + .quiet() + .initialize(); + HandleResolver handleResolver = injector.getInstance(HandleResolver.class); + handleResolver.addConnectorName("test", new com.facebook.presto.testing.TestingHandleResolver()); + + ConnectorCodecManager codecManager = injector.getInstance(ConnectorCodecManager.class); + codecManager.addConnectorCodecProvider(new ConnectorId(connectorWithCodec), new TestConnectorWithCodecProvider()); + + return injector; + } + + /** + * Test connector split that supports binary serialization via codec + */ + public static class TestConnectorWithCodecSplit + implements ConnectorSplit + { + private final String data; + private final int sequence; + + @JsonCreator + public TestConnectorWithCodecSplit( + 
@JsonProperty("data") String data, + @JsonProperty("sequence") int sequence) + { + this.data = data; + this.sequence = sequence; + } + + @JsonProperty + public String getData() + { + return data; + } + + @JsonProperty + public int getSequence() + { + return sequence; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NodeSelectionStrategy.NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return ImmutableList.of(); + } + + @Override + public Object getInfo() + { + return ImmutableMap.of("data", data, "sequence", sequence); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + TestConnectorWithCodecSplit that = (TestConnectorWithCodecSplit) obj; + return sequence == that.sequence && Objects.equals(data, that.data); + } + + @Override + public int hashCode() + { + return Objects.hash(data, sequence); + } + } + + /** + * Test connector codec provider that provides binary serialization + */ + public static class TestConnectorWithCodecProvider + implements ConnectorCodecProvider + { + @Override + public Optional> getConnectorSplitCodec() + { + return Optional.of(new ConnectorCodec<>() + { + @Override + public byte[] serialize(ConnectorSplit split) + { + TestConnectorWithCodecSplit codecSplit = (TestConnectorWithCodecSplit) split; + return (codecSplit.getData() + "|" + codecSplit.getSequence()).getBytes(UTF_8); + } + + @Override + public ConnectorSplit deserialize(byte[] data) + { + String[] parts = new String(data, UTF_8).split("\\|"); + return new TestConnectorWithCodecSplit(parts[0], Integer.parseInt(parts[1])); + } + }); + } + + @Override + public Optional> getConnectorTableHandleCodec() + { + return Optional.of(new ConnectorCodec<>() + { + @Override + public byte[] serialize(ConnectorTableHandle handle) + { + TestConnectorTableHandle tableHandle = 
(TestConnectorTableHandle) handle; + return tableHandle.getTableName().getBytes(UTF_8); + } + + @Override + public ConnectorTableHandle deserialize(byte[] data) + { + return new TestConnectorTableHandle(new String(data, UTF_8)); + } + }); + } + + public Optional> getConnectorColumnHandleCodec() + { + return Optional.of(new ConnectorCodec<>() + { + @Override + public byte[] serialize(ColumnHandle handle) + { + TestConnectorColumnHandle columnHandle = (TestConnectorColumnHandle) handle; + return (columnHandle.getColumnName() + ":" + columnHandle.getColumnType()).getBytes(UTF_8); + } + + @Override + public ColumnHandle deserialize(byte[] data) + { + String[] parts = new String(data, UTF_8).split(":"); + return new TestConnectorColumnHandle(parts[0], parts[1]); + } + }); + } + + public Optional> getConnectorTableLayoutHandleCodec() + { + return Optional.of(new ConnectorCodec<>() + { + @Override + public byte[] serialize(ConnectorTableLayoutHandle handle) + { + TestConnectorTableLayoutHandle layoutHandle = (TestConnectorTableLayoutHandle) handle; + return layoutHandle.getLayoutName().getBytes(UTF_8); + } + + @Override + public ConnectorTableLayoutHandle deserialize(byte[] data) + { + return new TestConnectorTableLayoutHandle(new String(data, UTF_8)); + } + }); + } + + public Optional> getConnectorOutputTableHandleCodec() + { + return Optional.of(new ConnectorCodec<>() + { + @Override + public byte[] serialize(ConnectorOutputTableHandle handle) + { + TestConnectorOutputTableHandle outputHandle = (TestConnectorOutputTableHandle) handle; + return outputHandle.getTableName().getBytes(UTF_8); + } + + @Override + public ConnectorOutputTableHandle deserialize(byte[] data) + { + return new TestConnectorOutputTableHandle(new String(data, UTF_8)); + } + }); + } + + public Optional> getConnectorInsertTableHandleCodec() + { + return Optional.of(new ConnectorCodec<>() + { + @Override + public byte[] serialize(ConnectorInsertTableHandle handle) + { + TestConnectorInsertTableHandle 
insertHandle = (TestConnectorInsertTableHandle) handle; + return insertHandle.getTableName().getBytes(UTF_8); + } + + @Override + public ConnectorInsertTableHandle deserialize(byte[] data) + { + return new TestConnectorInsertTableHandle(new String(data, UTF_8)); + } + }); + } + + public Optional> getConnectorDeleteTableHandleCodec() + { + return Optional.of(new ConnectorCodec<>() + { + @Override + public byte[] serialize(ConnectorDeleteTableHandle handle) + { + TestConnectorDeleteTableHandle deleteHandle = (TestConnectorDeleteTableHandle) handle; + return deleteHandle.getTableName().getBytes(UTF_8); + } + + @Override + public ConnectorDeleteTableHandle deserialize(byte[] data) + { + return new TestConnectorDeleteTableHandle(new String(data, UTF_8)); + } + }); + } + + public Optional> getConnectorMergeTableHandleCodec() + { + return Optional.of(new ConnectorCodec<>() + { + @Override + public byte[] serialize(ConnectorMergeTableHandle handle) + { + TestConnectorMergeTableHandle mergeTableHandle = (TestConnectorMergeTableHandle) handle; + return mergeTableHandle.getTableName().getBytes(UTF_8); + } + + @Override + public ConnectorMergeTableHandle deserialize(byte[] data) + { + return new TestConnectorMergeTableHandle(new String(data, UTF_8)); + } + }); + } + } + + /** + * Test table handle with binary serialization support + */ + public static class TestConnectorTableHandle + implements ConnectorTableHandle + { + private final String tableName; + + @JsonCreator + public TestConnectorTableHandle(@JsonProperty("tableName") String tableName) + { + this.tableName = tableName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + TestConnectorTableHandle that = (TestConnectorTableHandle) obj; + return Objects.equals(tableName, that.tableName); + } + + @Override + public int 
hashCode() + { + return Objects.hash(tableName); + } + } + + /** + * Test table layout handle with binary serialization support + */ + public static class TestConnectorTableLayoutHandle + implements ConnectorTableLayoutHandle + { + private final String layoutName; + + public TestConnectorTableLayoutHandle(String layoutName) + { + this.layoutName = layoutName; + } + + public String getLayoutName() + { + return layoutName; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestConnectorTableLayoutHandle that = (TestConnectorTableLayoutHandle) o; + return Objects.equals(layoutName, that.layoutName); + } + + @Override + public int hashCode() + { + return layoutName.hashCode(); + } + } + + /** + * Test column handle with binary serialization support + */ + public static class TestConnectorColumnHandle + implements ColumnHandle + { + private final String columnName; + private final String columnType; + + @JsonCreator + public TestConnectorColumnHandle( + @JsonProperty("columnName") String columnName, + @JsonProperty("columnType") String columnType) + { + this.columnName = columnName; + this.columnType = columnType; + } + + @JsonProperty + public String getColumnName() + { + return columnName; + } + + @JsonProperty + public String getColumnType() + { + return columnType; + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + TestConnectorColumnHandle that = (TestConnectorColumnHandle) obj; + return Objects.equals(columnName, that.columnName) && + Objects.equals(columnType, that.columnType); + } + + @Override + public int hashCode() + { + return Objects.hash(columnName, columnType); + } + } + + /** + * Test connector handle resolver for codec-enabled connector + */ + public static class TestConnectorWithCodecHandleResolver + implements 
ConnectorHandleResolver + { + @Override + public Class getTableHandleClass() + { + return TestConnectorTableHandle.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return TestConnectorTableLayoutHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return TestConnectorColumnHandle.class; + } + + @Override + public Class getSplitClass() + { + return TestConnectorWithCodecSplit.class; + } + + @Override + public Class getOutputTableHandleClass() + { + return TestConnectorOutputTableHandle.class; + } + + @Override + public Class getInsertTableHandleClass() + { + return TestConnectorInsertTableHandle.class; + } + + @Override + public Class getDeleteTableHandleClass() + { + return TestConnectorDeleteTableHandle.class; + } + } + + /** + * Test output table handle with binary serialization support + */ + public static class TestConnectorOutputTableHandle + implements ConnectorOutputTableHandle + { + private final String tableName; + + @JsonCreator + public TestConnectorOutputTableHandle( + @JsonProperty("tableName") String tableName) + { + this.tableName = tableName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestConnectorOutputTableHandle that = (TestConnectorOutputTableHandle) o; + return tableName.equals(that.tableName); + } + + @Override + public int hashCode() + { + return Objects.hash(tableName); + } + } + + /** + * Test insert table handle with binary serialization support + */ + public static class TestConnectorInsertTableHandle + implements ConnectorInsertTableHandle + { + private final String tableName; + + @JsonCreator + public TestConnectorInsertTableHandle( + @JsonProperty("tableName") String tableName) + { + this.tableName = tableName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + 
+ @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestConnectorInsertTableHandle that = (TestConnectorInsertTableHandle) o; + return tableName.equals(that.tableName); + } + + @Override + public int hashCode() + { + return Objects.hash(tableName); + } + } + + /** + * Test delete table handle with binary serialization support + */ + public static class TestConnectorDeleteTableHandle + implements ConnectorDeleteTableHandle + { + private final String tableName; + + @JsonCreator + public TestConnectorDeleteTableHandle( + @JsonProperty("tableName") String tableName) + { + this.tableName = tableName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestConnectorDeleteTableHandle that = (TestConnectorDeleteTableHandle) o; + return tableName.equals(that.tableName); + } + + @Override + public int hashCode() + { + return Objects.hash(tableName); + } + } + + /** + * Test merge table handle with binary serialization support + */ + public static class TestConnectorMergeTableHandle + implements ConnectorMergeTableHandle + { + private final String tableName; + + @JsonCreator + public TestConnectorMergeTableHandle( + @JsonProperty("tableName") String tableName) + { + this.tableName = tableName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @Override + public ConnectorTableHandle getTableHandle() + { + throw new UnsupportedOperationException("Merge table handles not supported"); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TestConnectorMergeTableHandle that = (TestConnectorMergeTableHandle) o; + return 
tableName.equals(that.tableName); + } + + @Override + public int hashCode() + { + return Objects.hash(tableName); + } + } + + public static class TestConnectorWithoutCodecSplit + implements ConnectorSplit + { + private final String data; + private final int sequence; + + @JsonCreator + public TestConnectorWithoutCodecSplit( + @JsonProperty("data") String data, + @JsonProperty("sequence") int sequence) + { + this.data = data; + this.sequence = sequence; + } + + @JsonProperty + public String getData() + { + return data; + } + + @JsonProperty + public int getSequence() + { + return sequence; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NodeSelectionStrategy.NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return ImmutableList.of(); + } + + @Override + public Object getInfo() + { + return ImmutableMap.of("data", data, "sequence", sequence); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + TestConnectorWithoutCodecSplit that = (TestConnectorWithoutCodecSplit) obj; + return sequence == that.sequence && Objects.equals(data, that.data); + } + + @Override + public int hashCode() + { + return Objects.hash(data, sequence); + } + } + + public static class TestConnectorWithoutCodecHandleResolver + implements ConnectorHandleResolver + { + @Override + public Class getTableHandleClass() + { + throw new UnsupportedOperationException("Table handles not supported"); + } + + @Override + public Class getTableLayoutHandleClass() + { + throw new UnsupportedOperationException("Table layout handles not supported"); + } + + @Override + public Class getColumnHandleClass() + { + throw new UnsupportedOperationException("Column handles not supported"); + } + + @Override + public Class getSplitClass() + { + return TestConnectorWithoutCodecSplit.class; + } + + @Override + public 
Class getOutputTableHandleClass() + { + throw new UnsupportedOperationException("Output table handles not supported"); + } + + @Override + public Class getInsertTableHandleClass() + { + throw new UnsupportedOperationException("Insert table handles not supported"); + } + + @Override + public Class getDeleteTableHandleClass() + { + throw new UnsupportedOperationException("Delete table handles not supported"); + } + + @Override + public Class getMergeTableHandleClass() + { + throw new UnsupportedOperationException("Merge table handles not supported"); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskWithEventLoop.java b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskWithEventLoop.java index ca08715f8c1b4..8ede62846822e 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskWithEventLoop.java +++ b/presto-main/src/test/java/com/facebook/presto/server/remotetask/TestHttpRemoteTaskWithEventLoop.java @@ -22,6 +22,8 @@ import com.facebook.airlift.json.JsonModule; import com.facebook.airlift.json.smile.SmileCodec; import com.facebook.airlift.json.smile.SmileModule; +import com.facebook.airlift.units.DataSize; +import com.facebook.airlift.units.Duration; import com.facebook.drift.codec.ThriftCodec; import com.facebook.drift.codec.guice.ThriftCodecModule; import com.facebook.drift.codec.utils.DataSizeToBytesThriftCodec; @@ -32,7 +34,8 @@ import com.facebook.presto.common.ErrorCode; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeManager; -import com.facebook.presto.connector.ConnectorTypeSerdeManager; +import com.facebook.presto.connector.ConnectorCodecManager; +import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.execution.Lifespan; import com.facebook.presto.execution.NodeTaskMap; import com.facebook.presto.execution.QueryManagerConfig; @@ -52,15 +55,26 @@ import 
com.facebook.presto.metadata.HandleJsonModule; import com.facebook.presto.metadata.HandleResolver; import com.facebook.presto.metadata.InternalNode; -import com.facebook.presto.metadata.MetadataUpdates; import com.facebook.presto.metadata.Split; -import com.facebook.presto.server.ConnectorMetadataUpdateHandleJsonSerde; import com.facebook.presto.server.InternalCommunicationConfig; import com.facebook.presto.server.TaskUpdateRequest; -import com.facebook.presto.server.thrift.MetadataUpdatesCodec; -import com.facebook.presto.server.thrift.SplitCodec; -import com.facebook.presto.server.thrift.TableWriteInfoCodec; +import com.facebook.presto.server.thrift.ConnectorSplitThriftCodec; +import com.facebook.presto.server.thrift.DeleteTableHandleThriftCodec; +import com.facebook.presto.server.thrift.InsertTableHandleThriftCodec; +import com.facebook.presto.server.thrift.MergeTableHandleThriftCodec; +import com.facebook.presto.server.thrift.OutputTableHandleThriftCodec; +import com.facebook.presto.server.thrift.TableHandleThriftCodec; +import com.facebook.presto.server.thrift.TableLayoutHandleThriftCodec; +import com.facebook.presto.server.thrift.TransactionHandleThriftCodec; +import com.facebook.presto.spi.ConnectorDeleteTableHandle; import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorMergeTableHandle; +import com.facebook.presto.spi.ConnectorOutputTableHandle; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.Serialization; @@ -77,25 +91,23 @@ import com.google.inject.Injector; import com.google.inject.Module; import com.google.inject.Provides; -import 
io.airlift.units.DataSize; -import io.airlift.units.Duration; +import com.google.inject.Scopes; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.UriInfo; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import javax.ws.rs.Consumes; -import javax.ws.rs.DELETE; -import javax.ws.rs.DefaultValue; -import javax.ws.rs.GET; -import javax.ws.rs.HeaderParam; -import javax.ws.rs.POST; -import javax.ws.rs.Path; -import javax.ws.rs.PathParam; -import javax.ws.rs.Produces; -import javax.ws.rs.QueryParam; -import javax.ws.rs.core.Context; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.UriInfo; - import java.lang.reflect.Method; import java.net.URI; import java.util.HashMap; @@ -147,8 +159,7 @@ public class TestHttpRemoteTaskWithEventLoop private static final TaskManagerConfig TASK_MANAGER_CONFIG = new TaskManagerConfig() // Shorten status refresh wait and info update interval so that we can have a shorter test timeout .setStatusRefreshMaxWait(new Duration(IDLE_TIMEOUT.roundTo(MILLISECONDS) / 100, MILLISECONDS)) - .setInfoUpdateInterval(new Duration(IDLE_TIMEOUT.roundTo(MILLISECONDS) / 10, MILLISECONDS)) - .setEventLoopEnabled(true); + .setInfoUpdateInterval(new Duration(IDLE_TIMEOUT.roundTo(MILLISECONDS) / 10, MILLISECONDS)); private static final boolean TRACE_HTTP = false; @@ -366,6 +377,7 @@ private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskReso @Override public void configure(Binder binder) { + binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); binder.bind(JsonMapper.class); binder.bind(ThriftMapper.class); 
configBinder(binder).bindConfig(FeaturesConfig.class); @@ -377,22 +389,35 @@ public void configure(Binder binder) smileCodecBinder(binder).bindSmileCodec(TaskInfo.class); smileCodecBinder(binder).bindSmileCodec(TaskUpdateRequest.class); smileCodecBinder(binder).bindSmileCodec(PlanFragment.class); - smileCodecBinder(binder).bindSmileCodec(MetadataUpdates.class); jsonCodecBinder(binder).bindJsonCodec(TaskStatus.class); jsonCodecBinder(binder).bindJsonCodec(TaskInfo.class); jsonCodecBinder(binder).bindJsonCodec(TaskUpdateRequest.class); jsonCodecBinder(binder).bindJsonCodec(PlanFragment.class); - jsonCodecBinder(binder).bindJsonCodec(MetadataUpdates.class); jsonCodecBinder(binder).bindJsonCodec(TableWriteInfo.class); - jsonCodecBinder(binder).bindJsonCodec(Split.class); jsonBinder(binder).addKeySerializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionSerializer.class); jsonBinder(binder).addKeyDeserializerBinding(VariableReferenceExpression.class).to(Serialization.VariableReferenceExpressionDeserializer.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorSplit.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorTransactionHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorOutputTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorDeleteTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorInsertTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorMergeTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorTableHandle.class); + jsonCodecBinder(binder).bindJsonCodec(ConnectorTableLayoutHandle.class); + + binder.bind(ConnectorCodecManager.class).in(Scopes.SINGLETON); + + thriftCodecBinder(binder).bindCustomThriftCodec(ConnectorSplitThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(TransactionHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(OutputTableHandleThriftCodec.class); + 
thriftCodecBinder(binder).bindCustomThriftCodec(InsertTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(DeleteTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(MergeTableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(TableHandleThriftCodec.class); + thriftCodecBinder(binder).bindCustomThriftCodec(TableLayoutHandleThriftCodec.class); thriftCodecBinder(binder).bindThriftCodec(TaskStatus.class); thriftCodecBinder(binder).bindThriftCodec(TaskInfo.class); thriftCodecBinder(binder).bindThriftCodec(TaskUpdateRequest.class); - thriftCodecBinder(binder).bindCustomThriftCodec(MetadataUpdatesCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(SplitCodec.class); - thriftCodecBinder(binder).bindCustomThriftCodec(TableWriteInfoCodec.class); thriftCodecBinder(binder).bindCustomThriftCodec(LocaleToLanguageTagCodec.class); thriftCodecBinder(binder).bindCustomThriftCodec(JodaDateTimeToEpochMillisThriftCodec.class); thriftCodecBinder(binder).bindCustomThriftCodec(DurationToMillisThriftCodec.class); @@ -413,9 +438,7 @@ private HttpRemoteTaskFactory createHttpRemoteTaskFactory( SmileCodec taskUpdateRequestSmileCodec, ThriftCodec taskUpdateRequestThriftCodec, JsonCodec planFragmentJsonCodec, - SmileCodec planFragmentSmileCodec, - JsonCodec metadataUpdatesJsonCodec, - SmileCodec metadataUpdatesSmileCodec) + SmileCodec planFragmentSmileCodec) { JaxrsTestingHttpProcessor jaxrsTestingHttpProcessor = new JaxrsTestingHttpProcessor(URI.create("http://fake.invalid/"), testingTaskResource, jsonMapper, thriftMapper); TestingHttpClient testingHttpClient = new TestingHttpClient(jaxrsTestingHttpProcessor.setTrace(TRACE_HTTP)); @@ -436,14 +459,11 @@ private HttpRemoteTaskFactory createHttpRemoteTaskFactory( taskUpdateRequestThriftCodec, planFragmentJsonCodec, planFragmentSmileCodec, - metadataUpdatesJsonCodec, - metadataUpdatesSmileCodec, new RemoteTaskStats(), internalCommunicationConfig, 
createTestMetadataManager(), new TestQueryManager(), - new HandleResolver(), - new ConnectorTypeSerdeManager(new ConnectorMetadataUpdateHandleJsonSerde())); + new HandleResolver()); } }); Injector injector = app @@ -455,7 +475,7 @@ private HttpRemoteTaskFactory createHttpRemoteTaskFactory( return injector.getInstance(HttpRemoteTaskFactory.class); } - private static void poll(BooleanSupplier success) + static void poll(BooleanSupplier success) throws InterruptedException { long failAt = System.nanoTime() + FAIL_TIMEOUT.roundTo(NANOSECONDS); @@ -469,7 +489,7 @@ private static void poll(BooleanSupplier success) } } - private static void waitUntilTaskFinish(RemoteTask task) + static void waitUntilTaskFinish(RemoteTask task) throws Exception { SettableFuture taskFinished = SettableFuture.create(); @@ -482,7 +502,7 @@ private static void waitUntilTaskFinish(RemoteTask task) taskFinished.get(); } - private enum FailureScenario + enum FailureScenario { NO_FAILURE, TASK_MISMATCH, @@ -535,6 +555,7 @@ public synchronized TaskInfo getTaskInfo( } Map taskSourceMap = new HashMap<>(); + private TaskUpdateRequest lastTaskUpdateRequest; @POST @Path("{taskId}") @@ -545,6 +566,7 @@ public synchronized TaskInfo createOrUpdateTask( TaskUpdateRequest taskUpdateRequest, @Context UriInfo uriInfo) { + this.lastTaskUpdateRequest = taskUpdateRequest; for (TaskSource source : taskUpdateRequest.getSources()) { taskSourceMap.compute(source.getPlanNodeId(), (planNodeId, taskSource) -> taskSource == null ? 
source : taskSource.update(source)); } @@ -561,6 +583,11 @@ public synchronized TaskSource getTaskSource(PlanNodeId planNodeId) return new TaskSource(source.getPlanNodeId(), source.getSplits(), source.getNoMoreSplitsForLifespan(), source.isNoMoreSplits()); } + public synchronized TaskUpdateRequest getLastTaskUpdateRequest() + { + return lastTaskUpdateRequest; + } + @GET @Path("{taskId}/status") @Produces({MediaType.APPLICATION_JSON, APPLICATION_THRIFT_BINARY, APPLICATION_THRIFT_COMPACT, APPLICATION_THRIFT_FB_COMPACT}) @@ -622,7 +649,6 @@ private TaskInfo buildTaskInfo() initialTaskInfo.getNoMoreSplits(), initialTaskInfo.getStats(), initialTaskInfo.isNeedsPlan(), - initialTaskInfo.getMetadataUpdates(), initialTaskInfo.getNodeId()); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/server/security/TestCustomPrestoAuthenticator.java b/presto-main/src/test/java/com/facebook/presto/server/security/TestCustomPrestoAuthenticator.java similarity index 56% rename from presto-main-base/src/test/java/com/facebook/presto/server/security/TestCustomPrestoAuthenticator.java rename to presto-main/src/test/java/com/facebook/presto/server/security/TestCustomPrestoAuthenticator.java index b48fb2ee539c8..55e6c8ee24dc8 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/server/security/TestCustomPrestoAuthenticator.java +++ b/presto-main/src/test/java/com/facebook/presto/server/security/TestCustomPrestoAuthenticator.java @@ -13,33 +13,30 @@ */ package com.facebook.presto.server.security; +import com.facebook.airlift.http.server.AuthenticationException; import com.facebook.presto.security.BasicPrincipal; import com.facebook.presto.server.MockHttpServletRequest; import com.facebook.presto.spi.security.AccessDeniedException; +import com.facebook.presto.spi.security.AuthenticatorNotApplicableException; import com.facebook.presto.spi.security.PrestoAuthenticator; import com.facebook.presto.spi.security.PrestoAuthenticatorFactory; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import jakarta.servlet.http.HttpServletRequest; import org.testng.annotations.Test; -import javax.servlet.http.HttpServletRequest; - import java.security.Principal; -import java.util.List; import java.util.Map; -import java.util.Optional; -import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static java.util.Collections.list; import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertTrue; public class TestCustomPrestoAuthenticator { private static final String TEST_HEADER = "test_header"; + private static final String TEST_INVALID_HEADER = "test_invalid_header"; private static final String TEST_HEADER_VALID_VALUE = "VALID"; private static final String TEST_HEADER_INVALID_VALUE = "INVALID"; private static final String TEST_FACTORY = "test_factory"; @@ -48,15 +45,9 @@ public class TestCustomPrestoAuthenticator @Test public void testPrestoAuthenticator() + throws Exception { - SecurityConfig mockSecurityConfig = new SecurityConfig(); - mockSecurityConfig.setAuthenticationTypes(ImmutableList.of(SecurityConfig.AuthenticationType.CUSTOM)); - PrestoAuthenticatorManager prestoAuthenticatorManager = new PrestoAuthenticatorManager(mockSecurityConfig); - // Add Test Presto Authenticator Factory - prestoAuthenticatorManager.addPrestoAuthenticatorFactory( - new TestingPrestoAuthenticatorFactory( - TEST_FACTORY, - TEST_HEADER_VALID_VALUE)); + PrestoAuthenticatorManager prestoAuthenticatorManager = getPrestoAuthenticatorManager(); prestoAuthenticatorManager.loadAuthenticator(TEST_FACTORY); @@ -66,52 +57,77 @@ public void testPrestoAuthenticator() TEST_REMOTE_ADDRESS, ImmutableMap.of()); - Optional principal = 
checkAuthentication(prestoAuthenticatorManager.getAuthenticator(), request); - assertTrue(principal.isPresent()); - assertEquals(principal.get().getName(), TEST_USER); + CustomPrestoAuthenticator customPrestoAuthenticator = new CustomPrestoAuthenticator(prestoAuthenticatorManager); + Principal principal = customPrestoAuthenticator.authenticate(request); + + assertEquals(principal.getName(), TEST_USER); + } + + @Test(expectedExceptions = AuthenticationException.class, expectedExceptionsMessageRegExp = "Access Denied: Authentication Failed!") + public void testPrestoAuthenticatorFailedAuthentication() + throws AuthenticationException + { + PrestoAuthenticatorManager prestoAuthenticatorManager = getPrestoAuthenticatorManager(); + + prestoAuthenticatorManager.loadAuthenticator(TEST_FACTORY); // Test failed authentication - request = new MockHttpServletRequest( + HttpServletRequest request = new MockHttpServletRequest( ImmutableListMultimap.of(TEST_HEADER, TEST_HEADER_INVALID_VALUE + ":" + TEST_USER), TEST_REMOTE_ADDRESS, ImmutableMap.of()); - principal = checkAuthentication(prestoAuthenticatorManager.getAuthenticator(), request); - assertFalse(principal.isPresent()); + CustomPrestoAuthenticator customPrestoAuthenticator = new CustomPrestoAuthenticator(prestoAuthenticatorManager); + customPrestoAuthenticator.authenticate(request); } - private Optional checkAuthentication(PrestoAuthenticator authenticator, HttpServletRequest request) + @Test + public void testPrestoAuthenticatorNotApplicable() { - try { - // Converting HttpServletRequest to Map - Map> headers = getHeadersMap(request); + PrestoAuthenticatorManager prestoAuthenticatorManager = getPrestoAuthenticatorManager(); - // Passing the headers Map to the authenticator - return Optional.of(authenticator.createAuthenticatedPrincipal(headers)); - } - catch (AccessDeniedException e) { - return Optional.empty(); - } + prestoAuthenticatorManager.loadAuthenticator(TEST_FACTORY); + + // Test invalid authenticator + 
HttpServletRequest request = new MockHttpServletRequest( + ImmutableListMultimap.of(TEST_INVALID_HEADER, TEST_HEADER_VALID_VALUE + ":" + TEST_USER), + TEST_REMOTE_ADDRESS, + ImmutableMap.of()); + + CustomPrestoAuthenticator customPrestoAuthenticator = new CustomPrestoAuthenticator(prestoAuthenticatorManager); + + assertThatThrownBy(() -> customPrestoAuthenticator.authenticate(request)) + .isInstanceOf(AuthenticationException.class) + .hasMessage(null); } - private Map> getHeadersMap(HttpServletRequest request) + private static PrestoAuthenticatorManager getPrestoAuthenticatorManager() { - return list(request.getHeaderNames()) - .stream() - .collect(toImmutableMap( - headerName -> headerName, - headerName -> list(request.getHeaders(headerName)))); + SecurityConfig mockSecurityConfig = new SecurityConfig(); + mockSecurityConfig.setAuthenticationTypes(ImmutableList.of(SecurityConfig.AuthenticationType.CUSTOM)); + PrestoAuthenticatorManager prestoAuthenticatorManager = new PrestoAuthenticatorManager(mockSecurityConfig); + + // Add Test Presto Authenticator Factory + prestoAuthenticatorManager.addPrestoAuthenticatorFactory( + new TestingPrestoAuthenticatorFactory( + TEST_FACTORY, + TEST_HEADER, + TEST_HEADER_VALID_VALUE)); + + return prestoAuthenticatorManager; } private static class TestingPrestoAuthenticatorFactory implements PrestoAuthenticatorFactory { private final String name; + private final String validHeaderName; private final String validHeaderValue; - TestingPrestoAuthenticatorFactory(String name, String validHeaderValue) + TestingPrestoAuthenticatorFactory(String name, String validHeaderName, String validHeaderValue) { this.name = requireNonNull(name, "name is null"); + this.validHeaderName = requireNonNull(validHeaderName, "validHeaderName is null"); this.validHeaderValue = requireNonNull(validHeaderValue, "validHeaderValue is null"); } @@ -125,8 +141,12 @@ public String getName() public PrestoAuthenticator create(Map config) { return (headers) -> { - // 
TEST_HEADER will have value of the form PART1:PART2 - String[] header = headers.get(TEST_HEADER).get(0).split(":"); + if (!headers.containsKey(this.validHeaderName)) { + throw new AuthenticatorNotApplicableException("Invalid authenticator: required headers are missing"); + } + + // HEADER will have value of the form PART1:PART2 + String[] header = headers.get(this.validHeaderName).get(0).split(":"); if (header[0].equals(this.validHeaderValue)) { return new BasicPrincipal(header[1]); diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/TestInternalAuthenticationFilter.java b/presto-main/src/test/java/com/facebook/presto/server/security/TestInternalAuthenticationFilter.java index a4f976269685c..0102cce8a1f95 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/security/TestInternalAuthenticationFilter.java +++ b/presto-main/src/test/java/com/facebook/presto/server/security/TestInternalAuthenticationFilter.java @@ -20,19 +20,18 @@ import com.google.common.hash.Hashing; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.SignatureAlgorithm; +import jakarta.ws.rs.container.ResourceInfo; import org.testng.annotations.Test; -import javax.ws.rs.container.ResourceInfo; - import java.lang.reflect.Method; import java.time.ZonedDateTime; import java.util.Date; import java.util.Optional; import static com.facebook.presto.server.InternalAuthenticationManager.PRESTO_INTERNAL_BEARER; +import static jakarta.servlet.http.HttpServletResponse.SC_OK; +import static jakarta.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; import static java.nio.charset.StandardCharsets.UTF_8; -import static javax.servlet.http.HttpServletResponse.SC_OK; -import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/TestJsonWebTokenAuthenticator.java 
b/presto-main/src/test/java/com/facebook/presto/server/security/TestJsonWebTokenAuthenticator.java index 761ac37033c91..7f447c3a0ac66 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/security/TestJsonWebTokenAuthenticator.java +++ b/presto-main/src/test/java/com/facebook/presto/server/security/TestJsonWebTokenAuthenticator.java @@ -20,12 +20,11 @@ import com.google.common.collect.ImmutableMap; import com.google.common.io.Files; import io.jsonwebtoken.Jwts; +import jakarta.servlet.http.HttpServletRequest; import org.testng.annotations.AfterTest; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; -import javax.servlet.http.HttpServletRequest; - import java.io.IOException; import java.nio.file.Path; import java.security.Principal; diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/BaseOAuth2AuthenticationFilterTest.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/BaseOAuth2AuthenticationFilterTest.java new file mode 100644 index 0000000000000..4540c655cb655 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/BaseOAuth2AuthenticationFilterTest.java @@ -0,0 +1,403 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.log.Level; +import com.facebook.airlift.log.Logging; +import com.facebook.airlift.testing.Closeables; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Key; +import io.jsonwebtoken.impl.DefaultClaims; +import okhttp3.Cookie; +import okhttp3.CookieJar; +import okhttp3.HttpUrl; +import okhttp3.JavaNetCookieJar; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.CookieManager; +import java.net.CookieStore; +import java.net.HttpCookie; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import static com.facebook.airlift.testing.Assertions.assertLessThan; +import static com.facebook.airlift.units.Duration.nanosSince; +import static com.facebook.presto.client.OkHttpUtil.setupInsecureSsl; +import static com.facebook.presto.server.security.oauth2.JwtUtil.newJwtBuilder; +import static com.facebook.presto.server.security.oauth2.OAuthWebUiCookie.OAUTH2_COOKIE; +import static com.facebook.presto.server.security.oauth2.TokenEndpointAuthMethod.CLIENT_SECRET_BASIC; +import static jakarta.servlet.http.HttpServletResponse.SC_OK; +import static jakarta.ws.rs.core.HttpHeaders.AUTHORIZATION; +import static jakarta.ws.rs.core.HttpHeaders.LOCATION; +import static jakarta.ws.rs.core.Response.Status.OK; +import static jakarta.ws.rs.core.Response.Status.SEE_OTHER; +import 
static jakarta.ws.rs.core.Response.Status.UNAUTHORIZED; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.testng.Assert.assertEquals; + +public abstract class BaseOAuth2AuthenticationFilterTest +{ + protected static final Duration TTL_ACCESS_TOKEN_IN_SECONDS = Duration.ofSeconds(5); + + protected static final String PRESTO_CLIENT_ID = "presto-client"; + protected static final String PRESTO_CLIENT_SECRET = "presto-secret"; + private static final String PRESTO_AUDIENCE = PRESTO_CLIENT_ID; + private static final String ADDITIONAL_AUDIENCE = "https://external-service.com"; + protected static final String TRUSTED_CLIENT_ID = "trusted-client"; + protected static final String TRUSTED_CLIENT_SECRET = "trusted-secret"; + private static final String UNTRUSTED_CLIENT_ID = "untrusted-client"; + private static final String UNTRUSTED_CLIENT_SECRET = "untrusted-secret"; + private static final String UNTRUSTED_CLIENT_AUDIENCE = "https://untrusted.com"; + + private final Logging logging = Logging.initialize(); + protected final OkHttpClient httpClient; + protected TestingHydraIdentityProvider hydraIdP; + + private TestingPrestoServer server; + + private SimpleProxyServer simpleProxy; + private URI uiUri; + + private URI proxyURI; + + protected BaseOAuth2AuthenticationFilterTest() + { + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + setupInsecureSsl(httpClientBuilder); + httpClientBuilder.followRedirects(false); + httpClient = httpClientBuilder.build(); + } + + static void waitForNodeRefresh(TestingPrestoServer server) + throws InterruptedException + { + long start = System.nanoTime(); + while (server.refreshNodes().getActiveNodes().size() < 1) { + assertLessThan(nanosSince(start), new com.facebook.airlift.units.Duration(10, SECONDS)); + MILLISECONDS.sleep(10); + } + } 
+ + @BeforeClass + public void setup() + throws Exception + { + logging.setLevel(OAuth2Service.class.getName(), Level.DEBUG); + + hydraIdP = getHydraIdp(); + String idpUrl = "https://localhost:" + hydraIdP.getAuthPort(); + server = new TestingPrestoServer(getOAuth2Config(idpUrl)); + server.getInstance(Key.get(OAuth2Client.class)).load(); + waitForNodeRefresh(server); + // Due to problems with the Presto OSS project related to the AuthenticationFilter we have to run Presto behind a Proxy and terminate SSL at the proxy. + simpleProxy = new SimpleProxyServer(server.getHttpBaseUrl()); + MILLISECONDS.sleep(1000); + proxyURI = URI.create("https://127.0.0.1:" + simpleProxy.getHttpsBaseUrl().getPort()); + uiUri = proxyURI.resolve("/"); + + hydraIdP.createClient( + PRESTO_CLIENT_ID, + PRESTO_CLIENT_SECRET, + CLIENT_SECRET_BASIC, + ImmutableList.of(PRESTO_AUDIENCE, ADDITIONAL_AUDIENCE), + proxyURI + "/oauth2/callback"); + hydraIdP.createClient( + TRUSTED_CLIENT_ID, + TRUSTED_CLIENT_SECRET, + CLIENT_SECRET_BASIC, + ImmutableList.of(TRUSTED_CLIENT_ID), + proxyURI + "/oauth2/callback"); + hydraIdP.createClient( + UNTRUSTED_CLIENT_ID, + UNTRUSTED_CLIENT_SECRET, + CLIENT_SECRET_BASIC, + ImmutableList.of(UNTRUSTED_CLIENT_AUDIENCE), + "https://untrusted.com/callback"); + } + + protected abstract ImmutableMap getOAuth2Config(String idpUrl) + throws IOException; + + protected abstract TestingHydraIdentityProvider getHydraIdp() + throws Exception; + + @AfterClass(alwaysRun = true) + public void tearDown() + throws Exception + { + Closeables.closeAll(server, hydraIdP, simpleProxy); + } + + @Test + public void testUnauthorizedApiCall() + throws IOException + { + try (Response response = httpClient + .newCall(apiCall().build()) + .execute()) { + assertUnauthorizedResponse(response); + } + } + + @Test + public void testUnauthorizedUICall() + throws IOException + { + try (Response response = httpClient + .newCall(uiCall().build()) + .execute()) { + assertRedirectResponse(response); + } + } 
+ + @Test + public void testUnsignedToken() + throws NoSuchAlgorithmException, IOException + { + KeyPairGenerator keyGenerator = KeyPairGenerator.getInstance("RSA"); + keyGenerator.initialize(4096); + long now = Instant.now().getEpochSecond(); + String token = newJwtBuilder() + .setHeaderParam("alg", "RS256") + .setHeaderParam("kid", "public:f467aa08-1c1b-4cde-ba45-84b0ef5d2ba8") + .setHeaderParam("typ", "JWT") + .setClaims( + new DefaultClaims( + ImmutableMap.builder() + .put("aud", ImmutableList.of()) + .put("client_id", PRESTO_CLIENT_ID) + .put("exp", now + 60L) + .put("iat", now) + .put("iss", "https://hydra:4444/") + .put("jti", UUID.randomUUID()) + .put("nbf", now) + .put("scp", ImmutableList.of("openid")) + .put("sub", "foo@bar.com") + .build())) + .signWith(keyGenerator.generateKeyPair().getPrivate()) + .compact(); + try (Response response = httpClientWithOAuth2Cookie(token, false) + .newCall(apiCall().build()) + .execute()) { + assertUnauthorizedResponse(response); + } + } + + @Test + public void testTokenWithInvalidAudience() + throws IOException + { + String token = hydraIdP.getToken(UNTRUSTED_CLIENT_ID, UNTRUSTED_CLIENT_SECRET, ImmutableList.of(UNTRUSTED_CLIENT_AUDIENCE)); + try (Response response = httpClientWithOAuth2Cookie(token, false) + .newCall(apiCall().build()) + .execute()) { + assertUnauthorizedResponse(response); + } + } + + @Test + public void testTokenFromTrustedClient() + throws IOException + { + String token = hydraIdP.getToken(TRUSTED_CLIENT_ID, TRUSTED_CLIENT_SECRET, ImmutableList.of(TRUSTED_CLIENT_ID)); + assertUICallWithCookie(token); + } + + @Test + public void testTokenWithMultipleAudiences() + throws IOException + { + String token = hydraIdP.getToken(PRESTO_CLIENT_ID, PRESTO_CLIENT_SECRET, ImmutableList.of(PRESTO_AUDIENCE, ADDITIONAL_AUDIENCE)); + assertUICallWithCookie(token); + } + + @Test + public void testSuccessfulFlow() + throws Exception + { + // create a new HttpClient which follows redirects and give access to cookies + 
CookieManager cookieManager = new CookieManager(); + CookieStore cookieStore = cookieManager.getCookieStore(); + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + setupInsecureSsl(httpClientBuilder); + OkHttpClient httpClient = httpClientBuilder + .followRedirects(true) + .cookieJar(new JavaNetCookieJar(cookieManager)) + .build(); + + assertThat(cookieStore.get(uiUri)).isEmpty(); + + // access UI and follow redirects in order to get OAuth2 cookie + Response response = httpClient.newCall( + new Request.Builder() + .url(uiUri.toURL()) + .get() + .build()) + .execute(); + + assertEquals(response.code(), SC_OK); + + Optional oauth2Cookie = cookieStore.get(uiUri) + .stream() + .filter(cookie -> cookie.getName().equals(OAUTH2_COOKIE)) + .findFirst(); + assertThat(oauth2Cookie).isNotEmpty(); + assertOAuth2Cookie(oauth2Cookie.get()); + assertUICallWithCookie(oauth2Cookie.get().getValue()); + } + + @Test + public void testExpiredAccessToken() + throws Exception + { + String token = hydraIdP.getToken(PRESTO_CLIENT_ID, PRESTO_CLIENT_SECRET, ImmutableList.of(PRESTO_AUDIENCE)); + assertUICallWithCookie(token); + Thread.sleep(TTL_ACCESS_TOKEN_IN_SECONDS.plusSeconds(1).toMillis()); // wait for the token expiration = ttl of access token + 1 sec + try (Response response = httpClientWithOAuth2Cookie(token, false).newCall(apiCall().build()).execute()) { + assertUnauthorizedResponse(response); + } + } + + private Request.Builder uiCall() + { + return new Request.Builder() + .url(proxyURI.resolve("/").toString()) + .get(); + } + + private Request.Builder apiCall() + { + return new Request.Builder() + .url(proxyURI.resolve("/v1/cluster").toString()) + .get(); + } + + private void assertOAuth2Cookie(HttpCookie cookie) + { + assertThat(cookie.getName()).isEqualTo(OAUTH2_COOKIE); + assertThat(cookie.getDomain()).isIn(proxyURI.getHost()); + assertThat(cookie.getPath()).isEqualTo("/"); + assertThat(cookie.getSecure()).isTrue(); + assertThat(cookie.isHttpOnly()).isTrue(); 
+ assertThat(cookie.getMaxAge()).isLessThanOrEqualTo(TTL_ACCESS_TOKEN_IN_SECONDS.getSeconds()); + validateAccessToken(cookie.getValue()); + } + + protected void validateAccessToken(String cookieValue) + { + Request request = new Request.Builder().url("https://localhost:" + hydraIdP.getAuthPort() + "/userinfo").addHeader(AUTHORIZATION, "Bearer " + cookieValue).build(); + try (Response response = httpClient.newCall(request).execute()) { + assertThat(response.body()).isNotNull(); + DefaultClaims claims = new DefaultClaims(JsonCodec.mapJsonCodec(String.class, Object.class).fromJson(response.body().bytes())); + assertThat(claims.getSubject()).isEqualTo("foo@bar.com"); + } + catch (IOException e) { + fail("Exception while calling /userinfo", e); + } + } + + private void assertUICallWithCookie(String cookieValue) + throws IOException + { + OkHttpClient httpClient = httpClientWithOAuth2Cookie(cookieValue, true); + // pass access token in Presto UI cookie + try (Response response = httpClient.newCall(uiCall().build()) + .execute()) { + assertThat(response.code()).isEqualTo(OK.getStatusCode()); + } + } + + @SuppressWarnings("NullableProblems") + private OkHttpClient httpClientWithOAuth2Cookie(String cookieValue, boolean followRedirects) + { + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + setupInsecureSsl(httpClientBuilder); + httpClientBuilder.followRedirects(followRedirects); + httpClientBuilder.cookieJar(new CookieJar() + { + @Override + public void saveFromResponse(HttpUrl url, List cookies) + { + } + + @Override + public List loadForRequest(HttpUrl url) + { + Cookie cookie = new Cookie.Builder() + .domain(proxyURI.getHost()) + .path("/") + .name(OAUTH2_COOKIE) + .value(cookieValue) + .secure() + .build(); + return ImmutableList.of(cookie); + } + }); + return httpClientBuilder.build(); + } + + private void assertRedirectResponse(Response response) + throws MalformedURLException + { + 
assertThat(response.code()).isEqualTo(SEE_OTHER.getStatusCode()); + assertRedirectUrl(response.header(LOCATION)); + } + + private void assertUnauthorizedResponse(Response response) + throws IOException + { + assertThat(response.code()).isEqualTo(UNAUTHORIZED.getStatusCode()); + assertThat(response.body()).isNotNull(); + // NOTE that our errors come in looking like an HTML page since we don't do anything special on the server side so it just is like that. + assertThat(response.body().string()).contains("Invalid Credentials"); + } + + private void assertRedirectUrl(String redirectUrl) + throws MalformedURLException + { + assertThat(redirectUrl).isNotNull(); + URL location = new URL(redirectUrl); + HttpUrl url = HttpUrl.parse(redirectUrl); + assertThat(url).isNotNull(); + assertThat(location.getProtocol()).isEqualTo("https"); + assertThat(location.getHost()).isEqualTo("localhost"); + assertThat(location.getPort()).isEqualTo(hydraIdP.getAuthPort()); + assertThat(location.getPath()).isEqualTo("/oauth2/auth"); + assertThat(url.queryParameterValues("response_type")).isEqualTo(ImmutableList.of("code")); + assertThat(url.queryParameterValues("scope")).isEqualTo(ImmutableList.of("openid")); + assertThat(url.queryParameterValues("redirect_uri")).isEqualTo(ImmutableList.of(proxyURI + "/oauth2/callback")); + assertThat(url.queryParameterValues("client_id")).isEqualTo(ImmutableList.of(PRESTO_CLIENT_ID)); + assertThat(url.queryParameterValues("state")).isNotNull(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/SimpleProxyServer.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/SimpleProxyServer.java new file mode 100644 index 0000000000000..99e8bb43c4b8d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/SimpleProxyServer.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.http.server.HttpServerConfig; +import com.facebook.airlift.http.server.HttpServerInfo; +import com.facebook.airlift.http.server.testing.TestingHttpServer; +import com.facebook.airlift.log.Logger; +import com.facebook.airlift.node.NodeInfo; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.ws.rs.core.UriBuilder; +import okhttp3.Headers; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; + +import java.io.Closeable; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.client.OkHttpUtil.setupInsecureSsl; +import static com.facebook.presto.server.security.oauth2.OAuth2Utils.getFullRequestURL; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static com.google.common.net.HttpHeaders.HOST; +import static com.google.common.net.HttpHeaders.X_FORWARDED_FOR; +import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class SimpleProxyServer + implements Closeable +{ + private final 
TestingHttpServer server; + + public SimpleProxyServer(URI forwardBaseURI) + throws Exception + { + server = createSimpleProxyServer(forwardBaseURI); + server.start(); + } + + @Override + public void close() + throws IOException + { + try { + server.stop(); + } + catch (Exception e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } + + public URI getHttpsBaseUrl() + { + return server.getHttpServerInfo().getHttpsUri(); + } + + private TestingHttpServer createSimpleProxyServer(URI forwardBaseURI) + throws IOException + { + NodeInfo nodeInfo = new NodeInfo("test"); + HttpServerConfig config = new HttpServerConfig() + .setHttpPort(0) + .setHttpsEnabled(true) + .setHttpsPort(0) + .setKeystorePath(Resources.getResource("cert/localhost.pem").getPath()); + HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo); + return new TestingHttpServer(httpServerInfo, nodeInfo, config, new SimpleProxy(forwardBaseURI), ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); + } + + private class SimpleProxy + extends HttpServlet + { + private final OkHttpClient httpClient; + private final URI forwardBaseURI; + + private final Logger logger = Logger.get(SimpleProxy.class); + + public SimpleProxy(URI forwardBaseURI) + { + this.forwardBaseURI = forwardBaseURI; + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + setupInsecureSsl(httpClientBuilder); + httpClient = httpClientBuilder + .followRedirects(false) + .connectTimeout(100, SECONDS) + .writeTimeout(100, SECONDS) + .readTimeout(100, SECONDS) + .build(); + } + + @Override + protected void service(HttpServletRequest request, HttpServletResponse servletResponse) + throws ServletException, IOException + { + UriBuilder requestUriBuilder = UriBuilder.fromUri(getFullRequestURL(request)); + requestUriBuilder + .scheme("http") + .host(forwardBaseURI.getHost()) + .port(forwardBaseURI.getPort()); + + String hostHeader = new 
StringBuilder().append(request.getRemoteHost()).append(":").append(request.getLocalPort()).toString(); + Cookie[] cookies = Optional.ofNullable(request.getCookies()).orElse(new Cookie[0]); + String requestUri = requestUriBuilder.build().toString(); + Request.Builder reqBuilder = new Request.Builder() + .url(requestUri) + .addHeader(X_FORWARDED_PROTO, "https") + .addHeader(X_FORWARDED_FOR, request.getRemoteAddr()) + .addHeader(HOST, hostHeader) + .get(); + + if (cookies.length > 0) { + for (Cookie cookie : cookies) { + reqBuilder.addHeader("Cookie", cookie.getName() + "=" + cookie.getValue()); + } + } + Response response; + try { + response = httpClient.newCall(reqBuilder.build()).execute(); + servletResponse.setStatus(response.code()); + + Headers responseHeaders = response.headers(); + responseHeaders.names().stream().forEach(headerName -> { + // Headers can have multiple values + List headerValues = responseHeaders.values(headerName); + headerValues.forEach(headerValue -> { + servletResponse.addHeader(headerName, headerValue); + }); + }); + + //copy the response body to the servlet response. + InputStream is = response.body().byteStream(); + OutputStream os = servletResponse.getOutputStream(); + byte[] buffer = new byte[10 * 1024]; + int read; + while ((read = is.read(buffer)) != -1) { + os.write(buffer, 0, read); + } + } + catch (Exception e) { + logger.error(format("Encountered an error while proxying request to %s", requestUri), e); + servletResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } + } + } + + // This is just to help iterate changes that might need to be made in the future to the Simple Proxy Server for test purposes. + // Waiting for the tests to run can be really slow so having this helper here is nice if you want quicker feedback. 
+ private static void runTestServer() + throws Exception + { + SimpleProxyServer test = new SimpleProxyServer(new URI("")); + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + setupInsecureSsl(httpClientBuilder); + OkHttpClient client = httpClientBuilder.build(); + Request.Builder req = new Request.Builder().url(test.getHttpsBaseUrl().resolve("/v1/query").toString()).get(); + Logger logger = Logger.get("Run Test Server Debug Helper"); + try { + client.newCall(req.build()).execute(); + } + catch (Exception e) { + logger.error(e); + } + + test.close(); + } + + public static void main(String[] args) + throws Exception + { + runTestServer(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestDualAuthenticationFilterWithOAuth.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestDualAuthenticationFilterWithOAuth.java new file mode 100644 index 0000000000000..089c6b9c80aaf --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestDualAuthenticationFilterWithOAuth.java @@ -0,0 +1,249 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.log.Level; +import com.facebook.airlift.log.Logging; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import com.google.inject.Key; +import okhttp3.Cookie; +import okhttp3.CookieJar; +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.URI; +import java.time.Duration; +import java.util.Base64; +import java.util.List; + +import static com.facebook.airlift.testing.Assertions.assertLessThan; +import static com.facebook.airlift.units.Duration.nanosSince; +import static com.facebook.presto.client.OkHttpUtil.setupInsecureSsl; +import static com.facebook.presto.server.security.oauth2.OAuthWebUiCookie.OAUTH2_COOKIE; +import static com.facebook.presto.server.security.oauth2.TokenEndpointAuthMethod.CLIENT_SECRET_BASIC; +import static com.google.common.net.HttpHeaders.AUTHORIZATION; +import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE; +import static jakarta.ws.rs.core.Response.Status.OK; +import static jakarta.ws.rs.core.Response.Status.UNAUTHORIZED; +import static java.io.File.createTempFile; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestDualAuthenticationFilterWithOAuth +{ + protected static final Duration TTL_ACCESS_TOKEN_IN_SECONDS = Duration.ofSeconds(5); + protected static final String PRESTO_CLIENT_ID = "presto-client"; + protected static final String PRESTO_CLIENT_SECRET = "presto-secret"; + private static final String PRESTO_AUDIENCE = PRESTO_CLIENT_ID; + private static final String 
ADDITIONAL_AUDIENCE = "https://external-service.com"; + protected static final String TRUSTED_CLIENT_ID = "trusted-client"; + + private final Logging logging = Logging.initialize(); + protected final OkHttpClient httpClient; + protected TestingHydraIdentityProvider hydraIdP; + + private TestingPrestoServer server; + + private SimpleProxyServer simpleProxy; + + private URI proxyURI; + + protected TestDualAuthenticationFilterWithOAuth() + { + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + setupInsecureSsl(httpClientBuilder); + httpClientBuilder.followRedirects(false); + httpClient = httpClientBuilder.build(); + } + + static void waitForNodeRefresh(TestingPrestoServer server) + throws InterruptedException + { + long start = System.nanoTime(); + while (server.refreshNodes().getActiveNodes().size() < 1) { + assertLessThan(nanosSince(start), new com.facebook.airlift.units.Duration(10, SECONDS)); + MILLISECONDS.sleep(10); + } + } + + protected ImmutableMap getConfig(String idpUrl) + throws IOException + { + return ImmutableMap.builder() + .put("http-server.authentication.allow-forwarded-https", "true") + .put("http-server.authentication.type", "OAUTH2,PASSWORD") + .put("http-server.authentication.oauth2.issuer", "https://localhost:4444/") + .put("http-server.authentication.oauth2.auth-url", idpUrl + "/oauth2/auth") + .put("http-server.authentication.oauth2.token-url", idpUrl + "/oauth2/token") + .put("http-server.authentication.oauth2.jwks-url", idpUrl + "/.well-known/jwks.json") + .put("http-server.authentication.oauth2.client-id", PRESTO_CLIENT_ID) + .put("http-server.authentication.oauth2.client-secret", PRESTO_CLIENT_SECRET) + .put("http-server.authentication.oauth2.additional-audiences", TRUSTED_CLIENT_ID) + .put("http-server.authentication.oauth2.max-clock-skew", "0s") + .put("http-server.authentication.oauth2.user-mapping.pattern", "(.*)(@.*)?") + .put("http-server.authentication.oauth2.oidc.discovery", "false") + 
.put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .put("oauth2-jwk.http-client.trust-store-path", Resources.getResource("cert/localhost.pem").getPath()) + .build(); + } + + @BeforeClass + public void setup() + throws Exception + { + logging.setLevel(OAuth2Service.class.getName(), Level.DEBUG); + hydraIdP = new TestingHydraIdentityProvider(TTL_ACCESS_TOKEN_IN_SECONDS, true, false); + hydraIdP.start(); + String idpUrl = "https://localhost:" + hydraIdP.getAuthPort(); + server = new TestingPrestoServer(getConfig(idpUrl)); + server.getInstance(Key.get(OAuth2Client.class)).load(); + waitForNodeRefresh(server); + // Due to problems with the Presto OSS project related to the AuthenticationFilter we have to run Presto behind a Proxy and terminate SSL at the proxy. + simpleProxy = new SimpleProxyServer(server.getHttpBaseUrl()); + MILLISECONDS.sleep(1000); + proxyURI = URI.create("https://127.0.0.1:" + simpleProxy.getHttpsBaseUrl().getPort()); + + hydraIdP.createClient( + PRESTO_CLIENT_ID, + PRESTO_CLIENT_SECRET, + CLIENT_SECRET_BASIC, + ImmutableList.of(PRESTO_AUDIENCE, ADDITIONAL_AUDIENCE), + simpleProxy.getHttpsBaseUrl() + "/oauth2/callback"); + } + + @Test + public void testExpiredOAuthToken() + throws Exception + { + String token = hydraIdP.getToken(PRESTO_CLIENT_ID, PRESTO_CLIENT_SECRET, ImmutableList.of(PRESTO_AUDIENCE)); + assertUICallWithCookie(token); + Thread.sleep(TTL_ACCESS_TOKEN_IN_SECONDS.plusSeconds(1).toMillis()); // wait for the token expiration = ttl of access token + 1 sec + try (Response response = httpClientWithOAuth2Cookie(token, false).newCall(apiCall().build()).execute()) { + assertUnauthorizedOAuthOnlyHeaders(response); + } + } + + @Test + public void testNoAuth() + throws Exception + { + try (Response response = httpClient + .newCall(apiCall().build()) + .execute()) { + assertAllUnauthorizedHeaders(response); + } + } + + @Test + public void testInvalidBasicAuth() + 
throws Exception + { + String userPass = "test:password"; + String basicAuth = "Basic " + Base64.getEncoder().encodeToString(userPass.getBytes()); + try (Response response = httpClient + .newCall(apiCall().addHeader(AUTHORIZATION, basicAuth).build()) + .execute()) { + assertAllUnauthorizedHeaders(response); + } + } + + private Request.Builder apiCall() + { + return new Request.Builder() + .url(proxyURI.resolve("/v1/cluster").toString()) + .get(); + } + + private void assertUnauthorizedOAuthOnlyHeaders(Response response) + throws IOException + { + String redirectServer = "x_redirect_server=\"" + proxyURI.resolve("/oauth2/token/initiate/"); + String tokenServer = "x_token_server=\"" + proxyURI.resolve("/oauth2/token/"); + assertUnauthorizedResponse(response); + List headers = response.headers(WWW_AUTHENTICATE); + assertThat(headers.size()).isEqualTo(1); + assertThat(headers.get(0)).contains(tokenServer, redirectServer); + } + + private void assertAllUnauthorizedHeaders(Response response) + throws IOException + { + String redirectServer = "x_redirect_server=\"" + proxyURI.resolve("/oauth2/token/initiate/").toString(); + String tokenServer = "x_token_server=\"" + proxyURI.resolve("/oauth2/token/"); + assertUnauthorizedResponse(response); + List headers = response.headers(WWW_AUTHENTICATE); + assertThat(headers.size()).isEqualTo(2); + assertThat(headers.stream().allMatch(h -> + h.contains("Basic realm=\"Presto\"") || + (h.contains(redirectServer) && h.contains(tokenServer)) + )).isTrue(); + } + + private void assertUICallWithCookie(String cookieValue) + throws IOException + { + OkHttpClient httpClient = httpClientWithOAuth2Cookie(cookieValue, true); + // pass access token in Presto UI cookie + try (Response response = httpClient.newCall(apiCall().build()) + .execute()) { + assertThat(response.code()).isEqualTo(OK.getStatusCode()); + } + } + + private OkHttpClient httpClientWithOAuth2Cookie(String cookieValue, boolean followRedirects) + { + OkHttpClient.Builder 
httpClientBuilder = new OkHttpClient.Builder(); + setupInsecureSsl(httpClientBuilder); + httpClientBuilder.followRedirects(followRedirects); + httpClientBuilder.cookieJar(new CookieJar() + { + @Override + public void saveFromResponse(HttpUrl url, List cookies) + { + } + + @Override + public List loadForRequest(HttpUrl url) + { + return ImmutableList.of(new Cookie.Builder() + .domain(proxyURI.getHost()) + .path("/") + .name(OAUTH2_COOKIE) + .value(cookieValue) + .secure() + .build()); + } + }); + return httpClientBuilder.build(); + } + + private void assertUnauthorizedResponse(Response response) + throws IOException + { + assertThat(response.code()).isEqualTo(UNAUTHORIZED.getStatusCode()); + assertThat(response.body()).isNotNull(); + // NOTE that our errors come in looking like an HTML page since we don't do anything special on the server side so it just is like that. + assertThat(response.body().string()).contains("Invalid Credentials"); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestJweTokenSerializer.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestJweTokenSerializer.java new file mode 100644 index 0000000000000..e29442b667523 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestJweTokenSerializer.java @@ -0,0 +1,274 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.units.Duration; +import com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair; +import com.nimbusds.jose.KeyLengthException; +import io.jsonwebtoken.ExpiredJwtException; +import io.jsonwebtoken.Jwts; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.net.URI; +import java.security.GeneralSecurityException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.util.Base64; +import java.util.Calendar; +import java.util.Date; +import java.util.Map; +import java.util.Optional; +import java.util.Random; + +import static com.facebook.airlift.units.Duration.succinctDuration; +import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.accessAndRefreshTokens; +import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.withAccessAndRefreshTokens; +import static java.time.temporal.ChronoUnit.MILLIS; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; + +public class TestJweTokenSerializer +{ + @Test + public void testSerialization() + throws Exception + { + JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), randomEncodedSecret()); + + Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime(); + String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token")); + TokenPair deserializedTokenPair = serializer.deserialize(serializedTokenPair); + + 
assertThat(deserializedTokenPair.getAccessToken()).isEqualTo("access_token"); + assertThat(deserializedTokenPair.getExpiration()).isEqualTo(expiration); + assertThat(deserializedTokenPair.getRefreshToken()).isEqualTo(Optional.of("refresh_token")); + } + + @Test(dataProvider = "wrongSecretsProvider") + public void testDeserializationWithWrongSecret(String encryptionSecret, String decryptionSecret) + { + assertThatThrownBy(() -> assertRoundTrip(Optional.ofNullable(encryptionSecret), Optional.ofNullable(decryptionSecret))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Decryption failed") + .hasStackTraceContaining("Tag mismatch!"); + } + + @DataProvider + public Object[][] wrongSecretsProvider() + { + return new Object[][] { + {randomEncodedSecret(), randomEncodedSecret()}, + {randomEncodedSecret(16), randomEncodedSecret(24)}, + {null, null}, // This will generate two different secret keys + {null, randomEncodedSecret()}, + {randomEncodedSecret(), null} + }; + } + + @Test + public void testSerializationDeserializationRoundTripWithDifferentKeyLengths() + throws Exception + { + for (int keySize : new int[] {16, 24, 32}) { + String secret = randomEncodedSecret(keySize); + assertRoundTrip(secret, secret); + } + } + + @Test + public void testSerializationFailsWithWrongKeySize() + { + for (int wrongKeySize : new int[] {8, 64, 128}) { + String tooShortSecret = randomEncodedSecret(wrongKeySize); + assertThatThrownBy(() -> assertRoundTrip(tooShortSecret, tooShortSecret)) + .hasStackTraceContaining("The Key Encryption Key length must be 128 bits (16 bytes), 192 bits (24 bytes) or 256 bits (32 bytes)"); + } + } + + private void assertRoundTrip(String serializerSecret, String deserializerSecret) + throws Exception + { + assertRoundTrip(Optional.of(serializerSecret), Optional.of(deserializerSecret)); + } + + private void assertRoundTrip(Optional serializerSecret, Optional deserializerSecret) + throws Exception + { + JweTokenSerializer serializer = 
tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), serializerSecret); + JweTokenSerializer deserializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), deserializerSecret); + Date expiration = new Calendar.Builder().setDate(2023, 6, 22).build().getTime(); + TokenPair tokenPair = withAccessAndRefreshTokens(randomEncodedSecret(), expiration, randomEncodedSecret()); + TokenPair postSerPair = deserializer.deserialize(serializer.serialize(tokenPair)); + assertEquals(tokenPair.getAccessToken(), postSerPair.getAccessToken()); + assertEquals(tokenPair.getRefreshToken(), postSerPair.getRefreshToken()); + assertEquals(tokenPair.getExpiration(), postSerPair.getExpiration()); + } + + @Test + public void testTokenDeserializationAfterTimeoutButBeforeExpirationExtension() + throws Exception + { + TestingClock clock = new TestingClock(); + JweTokenSerializer serializer = tokenSerializer( + clock, + succinctDuration(12, MINUTES), + randomEncodedSecret()); + Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime(); + String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token")); + clock.advanceBy(succinctDuration(10, MINUTES)); + TokenPair deserializedTokenPair = serializer.deserialize(serializedTokenPair); + + assertThat(deserializedTokenPair.getAccessToken()).isEqualTo("access_token"); + assertThat(deserializedTokenPair.getExpiration()).isEqualTo(expiration); + assertThat(deserializedTokenPair.getRefreshToken()).isEqualTo(Optional.of("refresh_token")); + } + + @Test + public void testTokenDeserializationAfterTimeoutAndExpirationExtension() + throws Exception + { + TestingClock clock = new TestingClock(); + + JweTokenSerializer serializer = tokenSerializer( + clock, + succinctDuration(12, MINUTES), + randomEncodedSecret()); + Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime(); + String serializedTokenPair = 
serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token")); + + clock.advanceBy(succinctDuration(20, MINUTES)); + assertThatThrownBy(() -> serializer.deserialize(serializedTokenPair)) + .isExactlyInstanceOf(ExpiredJwtException.class); + } + + private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration) + throws GeneralSecurityException, KeyLengthException + { + return new JweTokenSerializer( + new RefreshTokensConfig(), + new Oauth2ClientStub(), + "presto_coordinator_test_version", + "presto_coordinator", + "sub", + clock, + tokenExpiration); + } + + private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, String encodedSecretKey) + throws GeneralSecurityException, KeyLengthException + { + return tokenSerializer(clock, tokenExpiration, Optional.of(encodedSecretKey)); + } + + private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, Optional secretKey) + throws NoSuchAlgorithmException, KeyLengthException + { + RefreshTokensConfig refreshTokensConfig = new RefreshTokensConfig(); + secretKey.ifPresent(refreshTokensConfig::setSecretKey); + return new JweTokenSerializer( + refreshTokensConfig, + new Oauth2ClientStub(), + "presto_coordinator_test_version", + "presto_coordinator", + "sub", + clock, + tokenExpiration); + } + + private static String randomEncodedSecret() + { + return randomEncodedSecret(24); + } + + private static String randomEncodedSecret(int length) + { + Random random = new SecureRandom(); + final byte[] buffer = new byte[length]; + random.nextBytes(buffer); + return Base64.getEncoder().encodeToString(buffer); + } + + static class Oauth2ClientStub + implements OAuth2Client + { + private final Map claims = Jwts.claims() + .setSubject("user"); + + @Override + public void load() + { + } + + @Override + public Request createAuthorizationRequest(String state, URI callbackUri) + { + throw new UnsupportedOperationException("operation is not yet supported"); + 
} + + @Override + public Response getOAuth2Response(String code, URI callbackUri, Optional nonce) + { + throw new UnsupportedOperationException("operation is not yet supported"); + } + + @Override + public Optional> getClaims(String accessToken) + { + return Optional.of(claims); + } + + @Override + public Response refreshTokens(String refreshToken) + { + throw new UnsupportedOperationException("operation is not yet supported"); + } + } + + private static class TestingClock + extends Clock + { + private Instant currentTime = ZonedDateTime.of(2022, 5, 6, 10, 15, 0, 0, ZoneId.systemDefault()).toInstant(); + + @Override + public ZoneId getZone() + { + return ZoneId.systemDefault(); + } + + @Override + public Clock withZone(ZoneId zone) + { + return this; + } + + @Override + public Instant instant() + { + return currentTime; + } + + public void advanceBy(Duration currentTimeDelta) + { + this.currentTime = currentTime.plus(currentTimeDelta.toMillis(), MILLIS); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2AuthenticationFilterWithJwt.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2AuthenticationFilterWithJwt.java new file mode 100644 index 0000000000000..71806b590bc00 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2AuthenticationFilterWithJwt.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; + +import java.io.IOException; + +import static java.io.File.createTempFile; + +public class TestOAuth2AuthenticationFilterWithJwt + extends BaseOAuth2AuthenticationFilterTest +{ + @Override + protected ImmutableMap getOAuth2Config(String idpUrl) + throws IOException + { + return ImmutableMap.builder() + .put("http-server.authentication.allow-forwarded-https", "true") + .put("http-server.authentication.type", "OAUTH2") + .put("http-server.authentication.oauth2.issuer", "https://localhost:4444/") + .put("http-server.authentication.oauth2.auth-url", idpUrl + "/oauth2/auth") + .put("http-server.authentication.oauth2.token-url", idpUrl + "/oauth2/token") + .put("http-server.authentication.oauth2.jwks-url", idpUrl + "/.well-known/jwks.json") + .put("http-server.authentication.oauth2.client-id", PRESTO_CLIENT_ID) + .put("http-server.authentication.oauth2.client-secret", PRESTO_CLIENT_SECRET) + .put("http-server.authentication.oauth2.additional-audiences", TRUSTED_CLIENT_ID) + .put("http-server.authentication.oauth2.max-clock-skew", "0s") + .put("http-server.authentication.oauth2.user-mapping.pattern", "(.*)(@.*)?") + .put("http-server.authentication.oauth2.oidc.discovery", "false") + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .put("oauth2-jwk.http-client.trust-store-path", Resources.getResource("cert/localhost.pem").getPath()) + .build(); + } + + @Override + protected TestingHydraIdentityProvider getHydraIdp() + throws Exception + { + TestingHydraIdentityProvider hydraIdP = new TestingHydraIdentityProvider(TTL_ACCESS_TOKEN_IN_SECONDS, true, false); + hydraIdP.start(); + + return hydraIdP; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2AuthenticationFilterWithOpaque.java 
b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2AuthenticationFilterWithOpaque.java new file mode 100644 index 0000000000000..114b9b0ed20d1 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2AuthenticationFilterWithOpaque.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; + +import java.io.IOException; + +import static java.io.File.createTempFile; + +public class TestOAuth2AuthenticationFilterWithOpaque + extends BaseOAuth2AuthenticationFilterTest +{ + @Override + protected ImmutableMap getOAuth2Config(String idpUrl) + throws IOException + { + return ImmutableMap.builder() + .put("http-server.authentication.allow-forwarded-https", "true") + .put("http-server.authentication.type", "OAUTH2") + .put("http-server.authentication.oauth2.issuer", "https://localhost:4444/") + .put("http-server.authentication.oauth2.auth-url", idpUrl + "/oauth2/auth") + .put("http-server.authentication.oauth2.token-url", idpUrl + "/oauth2/token") + .put("http-server.authentication.oauth2.jwks-url", idpUrl + "/.well-known/jwks.json") + .put("http-server.authentication.oauth2.userinfo-url", idpUrl + "/userinfo") + .put("http-server.authentication.oauth2.client-id", PRESTO_CLIENT_ID) + .put("http-server.authentication.oauth2.client-secret", 
PRESTO_CLIENT_SECRET) + // This is necessary as Hydra does not return `sub` from `/userinfo` for client credential grants. + .put("http-server.authentication.oauth2.principal-field", "iss") + .put("http-server.authentication.oauth2.additional-audiences", TRUSTED_CLIENT_ID) + .put("http-server.authentication.oauth2.max-clock-skew", "0s") + .put("http-server.authentication.oauth2.user-mapping.pattern", "(.*)(@.*)?") + .put("http-server.authentication.oauth2.oidc.discovery", "false") + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .put("oauth2-jwk.http-client.trust-store-path", Resources.getResource("cert/localhost.pem").getPath()) + .build(); + } + + @Override + protected TestingHydraIdentityProvider getHydraIdp() + throws Exception + { + TestingHydraIdentityProvider hydraIdP = new TestingHydraIdentityProvider(TTL_ACCESS_TOKEN_IN_SECONDS, false, false); + hydraIdP.start(); + + return hydraIdP; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Config.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Config.java new file mode 100644 index 0000000000000..bb2735b95e7f6 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Config.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.units.Duration; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class TestOAuth2Config +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(OAuth2Config.class) + .setStateKey(null) + .setIssuer(null) + .setClientId(null) + .setClientSecret(null) + .setScopes("openid") + .setChallengeTimeout(new Duration(15, MINUTES)) + .setPrincipalField("sub") + .setGroupsField(null) + .setAdditionalAudiences("") + .setMaxClockSkew(new Duration(1, MINUTES)) + .setUserMappingPattern(null) + .setUserMappingFile(null) + .setEnableRefreshTokens(false) + .setEnableDiscovery(true)); + } + + @Test + public void testExplicitPropertyMappings() + throws IOException + { + Path userMappingFile = Files.createTempFile(null, null); + Map properties = ImmutableMap.builder() + .put("http-server.authentication.oauth2.state-key", "key-secret") + .put("http-server.authentication.oauth2.issuer", "http://127.0.0.1:9000/oauth2") + .put("http-server.authentication.oauth2.client-id", "another-consumer") + .put("http-server.authentication.oauth2.client-secret", "consumer-secret") + .put("http-server.authentication.oauth2.scopes", "email,offline") + .put("http-server.authentication.oauth2.principal-field", "some-field") + .put("http-server.authentication.oauth2.groups-field", "groups") + 
.put("http-server.authentication.oauth2.additional-audiences", "test-aud1,test-aud2") + .put("http-server.authentication.oauth2.challenge-timeout", "90s") + .put("http-server.authentication.oauth2.max-clock-skew", "15s") + .put("http-server.authentication.oauth2.user-mapping.pattern", "(.*)@something") + .put("http-server.authentication.oauth2.user-mapping.file", userMappingFile.toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.oidc.discovery", "false") + .build(); + + OAuth2Config expected = new OAuth2Config() + .setStateKey("key-secret") + .setIssuer("http://127.0.0.1:9000/oauth2") + .setClientId("another-consumer") + .setClientSecret("consumer-secret") + .setScopes("email, offline") + .setPrincipalField("some-field") + .setGroupsField("groups") + .setAdditionalAudiences("test-aud1,test-aud2") + .setChallengeTimeout(new Duration(90, SECONDS)) + .setMaxClockSkew(new Duration(15, SECONDS)) + .setUserMappingPattern("(.*)@something") + .setUserMappingFile(userMappingFile.toFile()) + .setEnableRefreshTokens(true) + .setEnableDiscovery(false); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Utils.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Utils.java new file mode 100644 index 0000000000000..119fd7213bd96 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOAuth2Utils.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.presto.server.MockHttpServletRequest; +import com.google.common.collect.ImmutableListMultimap; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.ws.rs.core.UriBuilder; +import org.testng.annotations.Test; + +import static com.facebook.presto.server.security.oauth2.OAuth2Utils.getSchemeUriBuilder; +import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; +import static org.testng.Assert.assertEquals; + +public class TestOAuth2Utils +{ + @Test + public void testGetSchemeUriBuilderNoProtoHeader() + { + HttpServletRequest request = new MockHttpServletRequest( + ImmutableListMultimap.builder() + .build(), + "testRemote", + "http://www.example.com"); + + UriBuilder builder = getSchemeUriBuilder(request); + assertEquals(builder.build().getScheme(), "http"); + } + + @Test + public void testGetSchemeUriBuilderProtoHeader() + { + HttpServletRequest request = new MockHttpServletRequest( + ImmutableListMultimap.builder() + .put(X_FORWARDED_PROTO, "https") + .build(), + "testRemote", + "http://www.example.com"); + + UriBuilder builder = getSchemeUriBuilder(request); + assertEquals(builder.build().getScheme(), "https"); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOidcDiscovery.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOidcDiscovery.java new file mode 100644 index 0000000000000..679095b9c7cd4 --- /dev/null +++ 
b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOidcDiscovery.java @@ -0,0 +1,373 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.http.server.Authenticator; +import com.facebook.airlift.http.server.HttpServerConfig; +import com.facebook.airlift.http.server.HttpServerInfo; +import com.facebook.airlift.http.server.testing.TestingHttpServer; +import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.server.security.oauth2.OAuth2ServerConfigProvider.OAuth2ServerConfig; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import com.google.inject.Key; +import com.google.inject.TypeLiteral; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.URI; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.airlift.http.client.HttpStatus.TOO_MANY_REQUESTS; +import static com.facebook.presto.server.security.oauth2.BaseOAuth2AuthenticationFilterTest.PRESTO_CLIENT_ID; +import static 
com.facebook.presto.server.security.oauth2.BaseOAuth2AuthenticationFilterTest.PRESTO_CLIENT_SECRET; +import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static java.io.File.createTempFile; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestOidcDiscovery +{ + @Test(dataProvider = "staticConfiguration") + public void testStaticConfiguration(Optional accessTokenPath, Optional userinfoPath) + throws Exception + { + try (MetadataServer metadataServer = new MetadataServer(ImmutableMap.of("/jwks.json", "jwk/jwk-public.json"))) { + URI issuer = metadataServer.getBaseUrl(); + Optional accessTokenIssuer = accessTokenPath.map(issuer::resolve); + Optional userinfoUrl = userinfoPath.map(issuer::resolve); + ImmutableMap.Builder properties = ImmutableMap.builder() + .put("http-server.authentication.oauth2.issuer", metadataServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.oidc.discovery", "false") + .put("http-server.authentication.oauth2.auth-url", issuer.resolve("/connect/authorize").toString()) + .put("http-server.authentication.oauth2.token-url", issuer.resolve("/connect/token").toString()) + .put("http-server.authentication.oauth2.jwks-url", issuer.resolve("/jwks.json").toString()) + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()); + accessTokenIssuer.map(URI::toString).ifPresent(uri -> properties.put("http-server.authentication.oauth2.access-token-issuer", uri)); + userinfoUrl.map(URI::toString).ifPresent(uri -> properties.put("http-server.authentication.oauth2.userinfo-url", uri)); + try (TestingPrestoServer server = createServer(properties.build())) { + assertConfiguration(server, issuer, 
accessTokenIssuer.map(issuer::resolve), userinfoUrl.map(issuer::resolve)); + } + } + } + + @DataProvider(name = "staticConfiguration") + public static Object[][] staticConfiguration() + { + return new Object[][] { + {Optional.empty(), Optional.empty()}, + {Optional.of("/access-token-issuer"), Optional.of("/userinfo")}, + }; + } + + @Test(dataProvider = "oidcDiscovery") + public void testOidcDiscovery(String configuration, Optional accessTokenIssuer, Optional userinfoUrl) + throws Exception + { + try (MetadataServer metadataServer = new MetadataServer( + ImmutableMap.builder() + .put("/.well-known/openid-configuration", "oidc/" + configuration) + .put("/jwks.json", "jwk/jwk-public.json") + .build()); + TestingPrestoServer server = createServer( + ImmutableMap.builder() + .put("http-server.authentication.oauth2.issuer", metadataServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.oidc.discovery", "true") + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .build())) { + URI issuer = metadataServer.getBaseUrl(); + assertConfiguration(server, issuer, accessTokenIssuer.map(issuer::resolve), userinfoUrl.map(issuer::resolve)); + } + } + + @DataProvider(name = "oidcDiscovery") + public static Object[][] oidcDiscovery() + { + return new Object[][] { + {"openid-configuration.json", Optional.empty(), Optional.of("/connect/userinfo")}, + {"openid-configuration-without-userinfo.json", Optional.empty(), Optional.empty()}, + {"openid-configuration-with-access-token-issuer.json", Optional.of("http://access-token-issuer.com/adfs/services/trust"), Optional.of("/connect/userinfo")}, + }; + } + + @Test + public void testIssuerCheck() + { + assertThatThrownBy(() -> { + try (MetadataServer metadataServer = new MetadataServer( + ImmutableMap.builder() + .put("/.well-known/openid-configuration", "oidc/openid-configuration-invalid-issuer.json") + .put("/jwks.json", "jwk/jwk-public.json") + 
.build()); + TestingPrestoServer server = createServer( + ImmutableMap.builder() + .put("http-server.authentication.oauth2.issuer", metadataServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.oidc.discovery", "true") + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .build())) { + // should throw an exception + server.getInstance(Key.get(OAuth2ServerConfigProvider.class)).get(); + } + }).hasMessageContaining( + "Invalid response from OpenID Metadata endpoint. " + + "The value of the \"issuer\" claim in Metadata document different than the Issuer URL used for the Configuration Request."); + } + + @Test + public void testStopOnClientError() + { + assertThatThrownBy(() -> { + try (MetadataServer metadataServer = new MetadataServer(ImmutableMap.of()); + TestingPrestoServer server = createServer( + ImmutableMap.builder() + .put("http-server.authentication.oauth2.issuer", metadataServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.oidc.discovery", "true") + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .build())) { + // should throw an exception + server.getInstance(Key.get(OAuth2ServerConfigProvider.class)).get(); + } + }).hasMessageContaining("Invalid response from OpenID Metadata endpoint. 
Expected response code to be 200, but was 404"); + } + + @Test + public void testOidcDiscoveryRetrying() + throws Exception + { + try (MetadataServer metadataServer = new MetadataServer(new MetadataServletWithStartup( + ImmutableMap.builder() + .put("/.well-known/openid-configuration", "oidc/openid-configuration.json") + .put("/jwks.json", "jwk/jwk-public.json") + .build(), 5)); + TestingPrestoServer server = createServer( + ImmutableMap.builder() + .put("http-server.authentication.oauth2.issuer", metadataServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.oidc.discovery", "true") + .put("http-server.authentication.oauth2.oidc.discovery.timeout", "10s") + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .build())) { + URI issuer = metadataServer.getBaseUrl(); + assertConfiguration(server, issuer, Optional.empty(), Optional.of(issuer.resolve("/connect/userinfo"))); + } + } + + @Test + public void testOidcDiscoveryTimesOut() + { + assertThatThrownBy(() -> { + try (MetadataServer metadataServer = new MetadataServer(new MetadataServletWithStartup( + ImmutableMap.builder() + .put("/.well-known/openid-configuration", "oidc/openid-configuration.json") + .put("/jwks.json", "jwk/jwk-public.json") + .build(), 10)); + TestingPrestoServer server = createServer( + ImmutableMap.builder() + .put("http-server.authentication.oauth2.issuer", metadataServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.oidc.discovery", "true") + .put("http-server.authentication.oauth2.oidc.discovery.timeout", "5s") + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .build())) { + // should throw an exception + server.getInstance(Key.get(OAuth2ServerConfigProvider.class)).get(); + } + }).hasMessageContaining("Invalid response from OpenID Metadata endpoint: 429"); + } + + @Test + public void 
testIgnoringUserinfoUrl() + throws Exception + { + try (MetadataServer metadataServer = new MetadataServer( + ImmutableMap.builder() + .put("/.well-known/openid-configuration", "oidc/openid-configuration.json") + .put("/jwks.json", "jwk/jwk-public.json") + .build()); + TestingPrestoServer server = createServer( + ImmutableMap.builder() + .put("http-server.authentication.oauth2.issuer", metadataServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.oidc.discovery", "true") + .put("http-server.authentication.oauth2.oidc.use-userinfo-endpoint", "false") + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .build())) { + URI issuer = metadataServer.getBaseUrl(); + assertConfiguration(server, issuer, Optional.empty(), Optional.empty()); + } + } + + @Test + public void testBackwardCompatibility() + throws Exception + { + try (MetadataServer metadataServer = new MetadataServer( + ImmutableMap.builder() + .put("/.well-known/openid-configuration", "oidc/openid-configuration-with-access-token-issuer.json") + .put("/jwks.json", "jwk/jwk-public.json") + .build())) { + URI issuer = metadataServer.getBaseUrl(); + URI authUrl = issuer.resolve("/custom-authorize"); + URI tokenUrl = issuer.resolve("/custom-token"); + URI jwksUrl = issuer.resolve("/custom-jwks.json"); + String accessTokenIssuer = issuer.resolve("/custom-access-token-issuer").toString(); + URI userinfoUrl = issuer.resolve("/custom-userinfo-url"); + try (TestingPrestoServer server = createServer( + ImmutableMap.builder() + .put("http-server.authentication.oauth2.issuer", issuer.toString()) + .put("http-server.authentication.oauth2.oidc.discovery", "true") + .put("http-server.authentication.oauth2.auth-url", authUrl.toString()) + .put("http-server.authentication.oauth2.token-url", tokenUrl.toString()) + .put("http-server.authentication.oauth2.jwks-url", jwksUrl.toString()) + 
.put("http-server.authentication.oauth2.access-token-issuer", accessTokenIssuer) + .put("http-server.authentication.oauth2.userinfo-url", userinfoUrl.toString()) + .put("configuration-based-authorizer.role-regex-map.file-path", createTempFile("regex-map", null).getAbsolutePath().toString()) + .build())) { + assertComponents(server); + OAuth2ServerConfig config = server.getInstance(Key.get(OAuth2ServerConfigProvider.class)).get(); + assertThat(config.getAccessTokenIssuer()).isEqualTo(Optional.of(accessTokenIssuer)); + assertThat(config.getAuthUrl()).isEqualTo(authUrl); + assertThat(config.getTokenUrl()).isEqualTo(tokenUrl); + assertThat(config.getJwksUrl()).isEqualTo(jwksUrl); + assertThat(config.getUserinfoUrl()).isEqualTo(Optional.of(userinfoUrl)); + } + } + } + + private static void assertConfiguration(TestingPrestoServer server, URI issuer, Optional accessTokenIssuer, Optional userinfoUrl) + { + assertComponents(server); + OAuth2ServerConfig config = server.getInstance(Key.get(OAuth2ServerConfigProvider.class)).get(); + assertThat(config.getAccessTokenIssuer()).isEqualTo(accessTokenIssuer.map(URI::toString)); + assertThat(config.getAuthUrl()).isEqualTo(issuer.resolve("/connect/authorize")); + assertThat(config.getTokenUrl()).isEqualTo(issuer.resolve("/connect/token")); + assertThat(config.getJwksUrl()).isEqualTo(issuer.resolve("/jwks.json")); + assertThat(config.getUserinfoUrl()).isEqualTo(userinfoUrl); + } + + private static void assertComponents(TestingPrestoServer server) + { + List authenticators = server.getInstance(Key.get(new TypeLiteral>() {})); + assertThat(authenticators).hasSize(1); + assertThat(authenticators.get(0)).isInstanceOf(OAuth2Authenticator.class); +// assertThat(server.getInstance(Key.get(WebUiAuthenticationFilter.class))).isInstanceOf(OAuth2WebUiAuthenticationFilter.class); + // does not throw an exception + server.getInstance(Key.get(OAuth2Client.class)).load(); + } + + private static TestingPrestoServer createServer(Map configuration) + 
throws Exception + { + ImmutableMap config = ImmutableMap.builder() + .put("http-server.authentication.allow-forwarded-https", "true") + .put("http-server.authentication.type", "OAUTH2") + .put("http-server.authentication.oauth2.client-id", PRESTO_CLIENT_ID) + .put("http-server.authentication.oauth2.client-secret", PRESTO_CLIENT_SECRET) + .putAll(configuration) + .build(); + + return new TestingPrestoServer(config); + } + + public static class MetadataServer + implements AutoCloseable + { + private final TestingHttpServer httpServer; + + public MetadataServer(Map responseMapping) + throws Exception + { + this(new MetadataServlet(responseMapping)); + } + + public MetadataServer(HttpServlet servlet) + throws Exception + { + NodeInfo nodeInfo = new NodeInfo("test"); + HttpServerConfig config = new HttpServerConfig().setHttpPort(0); + HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo); + httpServer = new TestingHttpServer(httpServerInfo, nodeInfo, config, servlet, ImmutableMap.of(), ImmutableMap.of(), Optional.empty()); + httpServer.start(); + } + + public URI getBaseUrl() + { + return httpServer.getBaseUrl(); + } + + @Override + public void close() + throws Exception + { + httpServer.stop(); + } + } + + public static class MetadataServlet + extends HttpServlet + { + private final Map responseMapping; + + public MetadataServlet(Map responseMapping) + { + this.responseMapping = requireNonNull(responseMapping, "responseMapping is null"); + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + String fileName = responseMapping.get(request.getPathInfo()); + if (fileName == null) { + response.setStatus(404); + return; + } + response.setHeader(CONTENT_TYPE, APPLICATION_JSON); + String body = Resources.toString(Resources.getResource(fileName), UTF_8); + body = body.replaceAll("https://issuer.com", request.getRequestURL().toString().replace("/.well-known/openid-configuration", "")); + 
response.getWriter().write(body); + } + } + + public static class MetadataServletWithStartup + extends MetadataServlet + { + private final Instant startTime; + + public MetadataServletWithStartup(Map responseMapping, int startupInSeconds) + { + super(responseMapping); + startTime = Instant.now().plusSeconds(startupInSeconds); + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + if (Instant.now().isBefore(startTime)) { + response.setStatus(TOO_MANY_REQUESTS.code()); + return; + } + super.doGet(request, response); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOidcDiscoveryConfig.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOidcDiscoveryConfig.java new file mode 100644 index 0000000000000..37c4dd317924b --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestOidcDiscoveryConfig.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.units.Duration; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; + +public class TestOidcDiscoveryConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(OidcDiscoveryConfig.class) + .setDiscoveryTimeout(new Duration(30, SECONDS)) + .setUserinfoEndpointEnabled(true) + .setAccessTokenIssuer(null) + .setAuthUrl(null) + .setTokenUrl(null) + .setJwksUrl(null) + .setUserinfoUrl(null)); + } + + @Test + public void testExplicitPropertyMapping() + { + Map properties = ImmutableMap.builder() + .put("http-server.authentication.oauth2.oidc.discovery.timeout", "1m") + .put("http-server.authentication.oauth2.oidc.use-userinfo-endpoint", "false") + .put("http-server.authentication.oauth2.access-token-issuer", "https://issuer.com/at") + .put("http-server.authentication.oauth2.auth-url", "https://issuer.com/auth") + .put("http-server.authentication.oauth2.token-url", "https://issuer.com/token") + .put("http-server.authentication.oauth2.jwks-url", "https://issuer.com/jwks.json") + .put("http-server.authentication.oauth2.userinfo-url", "https://issuer.com/user") + .build(); + + OidcDiscoveryConfig expected = new OidcDiscoveryConfig() + .setDiscoveryTimeout(new Duration(1, MINUTES)) + .setUserinfoEndpointEnabled(false) + .setAccessTokenIssuer("https://issuer.com/at") + .setAuthUrl("https://issuer.com/auth") + .setTokenUrl("https://issuer.com/token") + .setJwksUrl("https://issuer.com/jwks.json") + 
.setUserinfoUrl("https://issuer.com/user"); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestRefreshTokensConfig.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestRefreshTokensConfig.java new file mode 100644 index 0000000000000..29c7d5b11c603 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestRefreshTokensConfig.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; + +import java.security.NoSuchAlgorithmException; +import java.util.Map; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; +import static com.facebook.airlift.units.Duration.succinctDuration; +import static io.jsonwebtoken.io.Encoders.BASE64; +import static java.util.concurrent.TimeUnit.HOURS; + +public class TestRefreshTokensConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(RefreshTokensConfig.class) + .setTokenExpiration(succinctDuration(1, HOURS)) + .setIssuer("Presto_coordinator") + .setAudience("Presto_coordinator") + .setSecretKey(null)); + } + + @Test + public void testExplicitPropertyMappings() + throws Exception + { + String encodedBase64SecretKey = BASE64.encode(generateKey()); + + Map properties = ImmutableMap.builder() + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", "24h") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.issuer", "issuer") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.audience", "audience") + .put("http-server.authentication.oauth2.refresh-tokens.secret-key", encodedBase64SecretKey) + .build(); + + RefreshTokensConfig expected = new RefreshTokensConfig() + .setTokenExpiration(succinctDuration(24, HOURS)) + .setIssuer("issuer") + .setAudience("audience") + .setSecretKey(encodedBase64SecretKey); + + assertFullMapping(properties, expected); + } + + private byte[] generateKey() + throws NoSuchAlgorithmException + { + KeyGenerator generator = KeyGenerator.getInstance("AES"); + 
generator.init(256); + SecretKey key = generator.generateKey(); + return key.getEncoded(); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestingHydraIdentityProvider.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestingHydraIdentityProvider.java new file mode 100644 index 0000000000000..0daa570a29842 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestingHydraIdentityProvider.java @@ -0,0 +1,391 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.server.security.oauth2; + +import com.facebook.airlift.http.server.HttpServerConfig; +import com.facebook.airlift.http.server.HttpServerInfo; +import com.facebook.airlift.http.server.testing.TestingHttpServer; +import com.facebook.airlift.log.Level; +import com.facebook.airlift.log.Logging; +import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.util.AutoCloseableCloser; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import com.google.inject.Key; +import com.nimbusds.oauth2.sdk.GrantType; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.ws.rs.core.HttpHeaders; +import okhttp3.Credentials; +import okhttp3.FormBody; +import okhttp3.HttpUrl; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import org.testcontainers.containers.FixedHostPortGenericContainer; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.containers.startupcheck.OneShotStartupCheckStrategy; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.containers.wait.strategy.WaitAllStrategy; +import org.testcontainers.utility.MountableFile; + +import java.io.Closeable; +import java.io.IOException; +import java.net.URI; +import java.time.Duration; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.client.OkHttpUtil.setupInsecureSsl; +import static 
com.facebook.presto.server.security.oauth2.TokenEndpointAuthMethod.CLIENT_SECRET_BASIC; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static jakarta.servlet.http.HttpServletResponse.SC_NOT_FOUND; +import static jakarta.servlet.http.HttpServletResponse.SC_OK; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static java.util.Objects.requireNonNull; + +public class TestingHydraIdentityProvider + implements Closeable +{ + private static final String HYDRA_IMAGE = "oryd/hydra:v1.10.6"; + private static final String ISSUER = "https://localhost:4444/"; + private static final String DSN = "postgres://hydra:mysecretpassword@database:5432/hydra?sslmode=disable"; + + private final Network network = Network.newNetwork(); + + private final PostgreSQLContainer databaseContainer = new PostgreSQLContainer<>() + .withNetwork(network) + .withNetworkAliases("database") + .withUsername("hydra") + .withPassword("mysecretpassword") + .withDatabaseName("hydra"); + + private final GenericContainer migrationContainer = createHydraContainer() + .withCommand("migrate", "sql", "--yes", DSN) + .withStartupCheckStrategy(new OneShotStartupCheckStrategy().withTimeout(Duration.ofMinutes(5))); + + private final AutoCloseableCloser closer = AutoCloseableCloser.create(); + private final ObjectMapper mapper = new ObjectMapper(); + private final Duration ttlAccessToken; + private final boolean useJwt; + private final boolean exposeFixedPorts; + private final OkHttpClient httpClient; + private FixedHostPortGenericContainer hydraContainer; + + public TestingHydraIdentityProvider() + { + this(Duration.ofMinutes(30), true, false); + } + + public TestingHydraIdentityProvider(Duration ttlAccessToken, boolean useJwt, boolean exposeFixedPorts) + { + this.ttlAccessToken = requireNonNull(ttlAccessToken, "ttlAccessToken is null"); + this.useJwt = useJwt; + this.exposeFixedPorts = exposeFixedPorts; + 
OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + setupInsecureSsl(httpClientBuilder); + httpClientBuilder.followRedirects(false); + httpClient = httpClientBuilder.build(); + closer.register(network); + closer.register(databaseContainer); + closer.register(migrationContainer); + } + + public void start() + throws Exception + { + databaseContainer.start(); + migrationContainer.start(); + TestingHttpServer loginAndConsentServer = createTestingLoginAndConsentServer(); + closer.register(loginAndConsentServer::stop); + loginAndConsentServer.start(); + URI loginAndConsentBaseUrl = loginAndConsentServer.getBaseUrl(); + + hydraContainer = createHydraContainer() + .withNetworkAliases("hydra") + .withExposedPorts(4444, 4445) + .withEnv("DSN", DSN) + .withEnv("URLS_SELF_ISSUER", ISSUER) + .withEnv("URLS_CONSENT", loginAndConsentBaseUrl + "/consent") + .withEnv("URLS_LOGIN", loginAndConsentBaseUrl + "/login") + .withEnv("SERVE_TLS_KEY_PATH", "/tmp/certs/localhost.pem") + .withEnv("SERVE_TLS_CERT_PATH", "/tmp/certs/localhost.pem") + .withEnv("TTL_ACCESS_TOKEN", ttlAccessToken.getSeconds() + "s") + .withEnv("STRATEGIES_ACCESS_TOKEN", useJwt ? 
"jwt" : null) + .withEnv("LOG_LEAK_SENSITIVE_VALUES", "true") + .withCommand("serve", "all") + .withCopyFileToContainer(MountableFile.forClasspathResource("/cert"), "/tmp/certs") + .waitingFor(new WaitAllStrategy() + .withStrategy(Wait.forLogMessage(".*Setting up http server on :4444.*", 1)) + .withStrategy(Wait.forLogMessage(".*Setting up http server on :4445.*", 1))); + if (exposeFixedPorts) { + hydraContainer = hydraContainer + .withFixedExposedPort(4444, 4444) + .withFixedExposedPort(4445, 4445); + } + closer.register(hydraContainer); + hydraContainer.start(); + } + + public FixedHostPortGenericContainer createHydraContainer() + { + return new FixedHostPortGenericContainer<>(HYDRA_IMAGE).withNetwork(network); + } + + public void createClient( + String clientId, + String clientSecret, + TokenEndpointAuthMethod tokenEndpointAuthMethod, + List audiences, + String callbackUrl) + { + createHydraContainer() + .withCommand("clients", "create", + "--endpoint", "https://hydra:4445", + "--skip-tls-verify", + "--id", clientId, + "--secret", clientSecret, + "--audience", String.join(",", audiences), + "--grant-types", "authorization_code,refresh_token,client_credentials", + "--response-types", "token,code,id_token", + "--scope", "openid,offline", + "--token-endpoint-auth-method", tokenEndpointAuthMethod.getValue(), + "--callbacks", callbackUrl) + .withStartupCheckStrategy(new OneShotStartupCheckStrategy().withTimeout(Duration.ofSeconds(30))) + .start(); + } + + public String getToken(String clientId, String clientSecret, List audiences) + throws IOException + { + try (Response response = httpClient + .newCall( + new Request.Builder() + .url("https://localhost:" + getAuthPort() + "/oauth2/token") + .addHeader(HttpHeaders.AUTHORIZATION, Credentials.basic(clientId, clientSecret)) + .post(new FormBody.Builder() + .add("grant_type", GrantType.CLIENT_CREDENTIALS.getValue()) + .add("audience", String.join(" ", audiences)) + .build()) + .build()) + .execute()) { + 
checkState(response.code() == SC_OK); + requireNonNull(response.body()); + return mapper.readTree(response.body().byteStream()) + .get("access_token") + .textValue(); + } + } + + public int getAuthPort() + { + return hydraContainer.getMappedPort(4444); + } + + public int getAdminPort() + { + return hydraContainer.getMappedPort(4445); + } + + @Override + public void close() + throws IOException + { + try { + closer.close(); + } + catch (Exception e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } + + private TestingHttpServer createTestingLoginAndConsentServer() + throws IOException + { + NodeInfo nodeInfo = new NodeInfo("test"); + HttpServerConfig config = new HttpServerConfig().setHttpPort(0); + HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo); + return new TestingHttpServer( + httpServerInfo, + nodeInfo, + config, + new AcceptAllLoginsAndConsentsServlet(), + ImmutableMap.of(), + ImmutableMap.of(), + Optional.empty()); + } + + private class AcceptAllLoginsAndConsentsServlet + extends HttpServlet + { + private final ObjectMapper mapper = new ObjectMapper(); + private final OkHttpClient httpClient; + + public AcceptAllLoginsAndConsentsServlet() + { + OkHttpClient.Builder httpClientBuilder = new OkHttpClient.Builder(); + setupInsecureSsl(httpClientBuilder); + httpClient = httpClientBuilder.build(); + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + if (request.getPathInfo().equals("/login")) { + acceptLogin(request, response); + return; + } + if (request.getPathInfo().contains("/consent")) { + acceptConsent(request, response); + return; + } + response.setStatus(SC_NOT_FOUND); + } + + private void acceptLogin(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + String loginChallenge = request.getParameter("login_challenge"); + try (Response loginAcceptResponse = acceptLogin(loginChallenge)) { + 
sendRedirect(loginAcceptResponse, response); + } + } + + private void acceptConsent(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + String consentChallenge = request.getParameter("consent_challenge"); + JsonNode consentRequest = getConsentRequest(consentChallenge); + try (Response acceptConsentResponse = acceptConsent(consentChallenge, consentRequest)) { + sendRedirect(acceptConsentResponse, response); + } + } + + private Response acceptLogin(String loginChallenge) + throws IOException + { + return httpClient.newCall( + new Request.Builder() + .url("https://localhost:" + getAdminPort() + "/oauth2/auth/requests/login/accept?login_challenge=" + loginChallenge) + .put(RequestBody.create( + MediaType.parse(APPLICATION_JSON), + mapper.writeValueAsString(mapper.createObjectNode().put("subject", "foo@bar.com")))) + .build()) + .execute(); + } + + private JsonNode getConsentRequest(String consentChallenge) + throws IOException + { + try (Response response = httpClient.newCall( + new Request.Builder() + .url("https://localhost:" + getAdminPort() + "/oauth2/auth/requests/consent?consent_challenge=" + consentChallenge) + .get() + .build()) + .execute()) { + requireNonNull(response.body()); + return mapper.readTree(response.body().byteStream()); + } + } + + private Response acceptConsent(String consentChallenge, JsonNode consentRequest) + throws IOException + { + return httpClient.newCall( + new Request.Builder() + .url("https://localhost:" + getAdminPort() + "/oauth2/auth/requests/consent/accept?consent_challenge=" + consentChallenge) + .put(RequestBody.create( + MediaType.parse(APPLICATION_JSON), + mapper.writeValueAsString(mapper.createObjectNode() + .set("grant_scope", consentRequest.get("requested_scope")) + .set("grant_access_token_audience", consentRequest.get("requested_access_token_audience"))))) + .build()) + .execute(); + } + + private void sendRedirect(Response redirectResponse, HttpServletResponse response) + throws IOException 
+ { + requireNonNull(redirectResponse.body()); + response.sendRedirect( + toHostUrl(mapper.readTree(redirectResponse.body().byteStream()) + .get("redirect_to") + .textValue())); + } + + private String toHostUrl(String url) + { + return HttpUrl.get(URI.create(url)) + .newBuilder() + .port(getAuthPort()) + .toString(); + } + } + + private static void runTestServer(boolean useJwt) + throws Exception + { + try (TestingHydraIdentityProvider service = new TestingHydraIdentityProvider(Duration.ofMinutes(30), useJwt, true)) { + service.start(); + service.createClient( + "presto-client", + "presto-secret", + CLIENT_SECRET_BASIC, + ImmutableList.of("https://localhost:8443/ui"), + "https://localhost:8443/oauth2/callback"); + ImmutableMap.Builder config = ImmutableMap.builder() + .put("http-server.https.port", "8443") + .put("http-server.https.enabled", "true") + .put("http-server.https.keystore.path", Resources.getResource("cert/localhost.pem").getPath()) + .put("http-server.https.keystore.key", "") + .put("http-server.authentication.type", "OAUTH2") + .put("http-server.authentication.oauth2.issuer", ISSUER) + .put("http-server.authentication.oauth2.client-id", "presto-client") + .put("http-server.authentication.oauth2.client-secret", "presto-secret") + .put("http-server.authentication.oauth2.user-mapping.pattern", "(.*)@.*") + .put("http-server.authentication.oauth2.oidc.use-userinfo-endpoint", String.valueOf(!useJwt)) + .put("oauth2-jwk.http-client.trust-store-path", Resources.getResource("cert/localhost.pem").getPath()); + try (TestingPrestoServer server = new TestingPrestoServer(config.build())) { + server.getInstance(Key.get(OAuth2Client.class)).load(); + Thread.sleep(Long.MAX_VALUE); + } + } + } + + public static void main(String[] args) + throws Exception + { + Logging logging = Logging.initialize(); + try { + logging.setLevel(OAuth2Service.class.getName(), Level.DEBUG); + runTestServer(false); + } + finally { + logging.setLevel(OAuth2Service.class.getName(), 
Level.INFO); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TokenEndpointAuthMethod.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TokenEndpointAuthMethod.java new file mode 100644 index 0000000000000..c05e636e9b11c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TokenEndpointAuthMethod.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.server.security.oauth2; + +import static java.util.Objects.requireNonNull; + +public enum TokenEndpointAuthMethod +{ + CLIENT_SECRET_BASIC("client_secret_basic"); + + private final String value; + + TokenEndpointAuthMethod(String value) + { + this.value = requireNonNull(value, "value is null"); + } + + public String getValue() + { + return value; + } +} diff --git a/presto-main/src/test/resources/cert/generate.sh b/presto-main/src/test/resources/cert/generate.sh new file mode 100755 index 0000000000000..a84749045a5e4 --- /dev/null +++ b/presto-main/src/test/resources/cert/generate.sh @@ -0,0 +1,7 @@ +#!/bin/sh + +set -eux + +openssl req -new -x509 -newkey rsa:4096 -sha256 -nodes -keyout localhost.key -days 3560 -out localhost.crt -config localhost.conf +cat localhost.crt localhost.key > localhost.pem + diff --git a/presto-main/src/test/resources/cert/localhost.conf b/presto-main/src/test/resources/cert/localhost.conf new file mode 100644 index 
0000000000000..560e1a454c8d9 --- /dev/null +++ b/presto-main/src/test/resources/cert/localhost.conf @@ -0,0 +1,20 @@ +[req] +default_bits = 4096 +prompt = no +default_md = sha256 +x509_extensions = v3_req +distinguished_name = dn + +[dn] +C = US +ST = California +L = Palo Alto +O = PrestoTest +CN = PrestoTest + +[v3_req] +subjectAltName = @alt_names + +[alt_names] +IP.1 = 127.0.0.1 +DNS.1 = localhost diff --git a/presto-main/src/test/resources/cert/localhost.pem b/presto-main/src/test/resources/cert/localhost.pem new file mode 100644 index 0000000000000..2d8a8fb289876 --- /dev/null +++ b/presto-main/src/test/resources/cert/localhost.pem @@ -0,0 +1,83 @@ +-----BEGIN CERTIFICATE----- +MIIFYTCCA0mgAwIBAgIJAKUofzuCtcfnMA0GCSqGSIb3DQEBCwUAMGAxCzAJBgNV +BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRIwEAYDVQQHDAlQYWxvIEFsdG8x +EzARBgNVBAoMClByZXN0b1Rlc3QxEzARBgNVBAMMClByZXN0b1Rlc3QwHhcNMjIw +OTA2MTgyOTUxWhcNMzIwNjA1MTgyOTUxWjBgMQswCQYDVQQGEwJVUzETMBEGA1UE +CAwKQ2FsaWZvcm5pYTESMBAGA1UEBwwJUGFsbyBBbHRvMRMwEQYDVQQKDApQcmVz +dG9UZXN0MRMwEQYDVQQDDApQcmVzdG9UZXN0MIICIjANBgkqhkiG9w0BAQEFAAOC +Ag8AMIICCgKCAgEA28P/OFPTMWu5AUt2YF3IvfFtZ2FRioB02+FIE21KtYM1r8w2 +2GRyvLuT8vBaYoh8bNSGI2x2R1NtfOalaUCUfr9XRyZPcP4FEE5x0QRK2SYYOfzr +URx5gv3SSlhahmSjsFAojpG7lUBsKpopFcjZb0wSq3hFDVHQN57Xzmt1YHbTZrEt +5yFyqt2AYRVHz8XxJbsUOy514/YGQfLLZqukSLYk055qFIclzFqXU+/cg6UVpl7U +hlLTo0GApBQ2eLGCBDZqXhkCf2U1lMGGVLsJFNmGaumLV88yZmRYQC9MJWlCPCGG +ZcyKNxjZq70SbmjIA6s0FVcXYZ0z6xQqDpVBichLebrtR8ShKU29u1ITL2kaF1iK +gZi8FzEKwnzxLlTZACeBfGeELl9HKUmUvOwU9LHp0UX4fZLlcj5IYXk0IMExFEYa +qfKYxThmdo1Gmpl1orW3mnx+BK8/VtMn7RquHTgQof/dry3EAMo+7/N9v3tI4m+5 +99g+DoHyOdpWYOsTtGmkYcWLG8/ka2lzcaLx347VBRBgNa/afKquPK3ogi04wYYl +K8wBciyhg+J7MkC0k1Q5ek25qynotHLkIGC+LsmEOuw8kApc8/ZFvA5yavn4Dsm5 +x94CmVhJQqL5Dqr6IEmzxvWjFWCC4mD2os52Ff1RlEGuEB5B9CAeaWCSuGECAwEA +AaMeMBwwGgYDVR0RBBMwEYcEfwAAAYIJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUA +A4ICAQCnGqgmdRRIuEij5bAHHF+CkbtkHbOO0fTxmd5/Cd5vkiCUR0uhwBg9JUbC 
+5zl/wBTx8guc0b7Wjvr3gwBUWQhGR2LKQMfatFZSy8nPcfSkZUGjY4s7KYI8mjPY +1Lri/EE1gu2p+iB9Dw4EnHW3QSneyy4yrkpLcjywbkF93SsThgHQ27gEK3/HeGzY +dFQY17z6zRNnlkMU2JVh0VptE6xtdR3WAMhscVx4dpEjFz9FSKnAYE2svTAy0OTD +8+W795a4/eVCxdHw/3PqR4XSO8isdSimeSTtdpRqsDrW4jx4IIGVFMQPin5XABAl +Wbbs8VMZPB1OvLSpmPtV79o/EklWS9x0MtbXF6iT5VBIbP1JoQLHQ9O0+V4PT0CH +8f8+Wc2UqLskqJZOyDzV3Y81mcdmKYcWNN8LJBk6PA1Pt/RDAg58QW1KfUVleey8 +eaAMw7d9BoIM/nUe0Q4TPll0WJnFMWHLUItgs7YuN/39kOnxGQi3iKQCMV2qwRip +tHqTvw3fHXQxEbZZhLxInC04+pOF3ZpqdzkeaZsXpwklV+uDw1u+5rzlheVrErwA +BzAXENpJKZLs4mjHko+Z4loNzGsJzRci/YRkwSKW6stW04RWWEKQkc9YhFsDYDmt +I70j8+R7et3nx2DqaQ18WgQ/5xVsEoXg5sgFjj1cRVFDQdrexA== +-----END CERTIFICATE----- +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQDbw/84U9Mxa7kB +S3ZgXci98W1nYVGKgHTb4UgTbUq1gzWvzDbYZHK8u5Py8FpiiHxs1IYjbHZHU218 +5qVpQJR+v1dHJk9w/gUQTnHRBErZJhg5/OtRHHmC/dJKWFqGZKOwUCiOkbuVQGwq +mikVyNlvTBKreEUNUdA3ntfOa3VgdtNmsS3nIXKq3YBhFUfPxfEluxQ7LnXj9gZB +8stmq6RItiTTnmoUhyXMWpdT79yDpRWmXtSGUtOjQYCkFDZ4sYIENmpeGQJ/ZTWU +wYZUuwkU2YZq6YtXzzJmZFhAL0wlaUI8IYZlzIo3GNmrvRJuaMgDqzQVVxdhnTPr +FCoOlUGJyEt5uu1HxKEpTb27UhMvaRoXWIqBmLwXMQrCfPEuVNkAJ4F8Z4QuX0cp +SZS87BT0senRRfh9kuVyPkhheTQgwTEURhqp8pjFOGZ2jUaamXWitbeafH4Erz9W +0yftGq4dOBCh/92vLcQAyj7v832/e0jib7n32D4OgfI52lZg6xO0aaRhxYsbz+Rr +aXNxovHfjtUFEGA1r9p8qq48reiCLTjBhiUrzAFyLKGD4nsyQLSTVDl6TbmrKei0 +cuQgYL4uyYQ67DyQClzz9kW8DnJq+fgOybnH3gKZWElCovkOqvogSbPG9aMVYILi +YPaiznYV/VGUQa4QHkH0IB5pYJK4YQIDAQABAoICAFSVcz2yxa5Xz7T33m/oqMKy +kXEgu8ma919JrfwMLJ0AC0HGT7Wps5+AcskmSSNzdLBOe/JWZI+/RHy2KSQBfyXp +byYrUJgkrL5B8vyHsmcxilGHTurBEuOf3bhPmUfwpC/QKkv1O0WOrhMXkoiX7Vgw +516nw6wEuScvM9B2+45NLcBwoUI8VW3+ItM65ZDKlq32+ypsD2PV5UKsuCykE28I +69OnPRz5h0rH80aTI0Rn3ZVTGmk4p8xGAcUlInIBoBEPAJGG/rcZtS2z7ofeFPi5 +YEr16HO7g6J1LKJHkf80LBIItTmpJ+lc3yqCcv2bxp/i3QD5rD4dy0XHVQiX2cj/ +KQSh64MAszRtXZE+GOIAT6UOxyEZqYrPNopXZWEROpM4W/c20HsA1ffpXu0abOcY +rkZhWmKk5mJNBAe+EE6aoxGJ9Th6x7WO7WSjiAmXIxjoAk0YClOwMf6y8+5fuqg8 
+aEtG4GOz2bfUksbE3AzaXMcu2o4tt+4+paZA0pkhB9Za1ySP5N4pCQ1kvKhdstZi +e9laJtBS3YSS+rm2pe3g/rA9x5OrTDpms8S3LvFflG1KdPzJeBX7zRg6AHFBhAPI +OpFvsdeVDHOmGe/MYjC0gpLkoVBJLY7cqiicqUUWYaFm0JMoma8VS2h+Y4+cV39N +erXXLflO+Zn1EhiWrWYBAoIBAQD8K3evijZepPil3wtLff/dHgWzOsIubkBWEALu +M+Ab9zm0ahZ1muEej3p0U+Eq8TUbOAqaOWWndFMiqTcjco3zluSIh7wYajY827uG +5Fuyw3zIUm1g5HicZyVsqBJxRcSxkNqzFMMnvhFqUNyrsKHL+/GsCnWiuZhz8z8h +j9l616DF6LpefZrDL1nne4xfliZuE99BWNBi7n1n48q/yxpZ+PyNjAZFUu2wcOFE ++/z5wuuNwywk0gsWEPiqt1LAIbLwHYaTfIwK/OO19ScfFjzw8ZRNLcNEhFbFyZfR +aVkL22uqU7vtBzA1ld2AtogDTEdMQ1ruFd/ZIoOiHK6OpUhJAoIBAQDfGoeTi51A +Fh6gDh4/eiXKpggUvxBbLOkyfcb3j9VCbPmm+09YYR+b99FD7jhM2gKY1QXscvm5 +LFwMiXBNOv110PL1H0lGaNnv5sRFgQwAd7BpBNSAxYQCjF7RZz1GBHq3iLaAYtY4 +jRVPGL+n55E84jzi2Ip856Nf4TaTf98rsslSPbr6zwFvsGALFOki7hJivcelS1Qm +3U6UuJ3CmWv6GA3BY6BCj+5pXzR8+5bisYuIFUm1uW6sSn+1aaogWtOTRIFqf3mT +WSewoOcJpXHs5kBHkRVMycTIUaijbMdoANQzsZ8d6dfcyxpfrncxEaf+VJWKxEVp +OVSTUbpn/d9ZAoIBAEBF0/up/rGg2r9sWjSjqNNzE4DbOSMcdsl4y0ZrcnOuT8bs +Q002bKqdZ1i/CGUplZ+aaRlmB8Lmo0nyV1txlzy++QDTl92hNLHOT73R9o1ZxjRI +zhgkI5m5sJBBRnIYlkmr4hJC+HrotweiFJyuKI8VaEOxZspTA8iJ901WnNfyncfT +yazL1uZo60FU/DJg0uq7pevB91s/7jbMmKDJ462LCNQLHI4O1QZjvwcWMyR1yhQX +6uh3oNu+96KLl0vhSvpojCSLWiZyzpdSJOaHhIDlEieZwmt0T6mZ+FgnwcqD4q1H +Kl7/tgnyaMKlw4UTrBiEEmkcqjFt2p83MEarWgECggEBAKzFnrhkJiK6/nx0cng1 +345PhXKLg98XqH+xZ6PPfxcxzSPC+m82x4PBJg21LWRWcCxqy2uQnlMIR0BuLsmg +JShX585rrBMan6toyhYJGYJDLhol42rViqVujv8bNBhE38PB25MQ91RT7WyTfdhJ +O/AqQ3xotNaFi790aQ9Qt0Lf8Yf+xg30wOf9bmMmjmS+eP5+eV1IOKLgPzpsvb81 +kKjcd8qLnE/vpnFziPJA41gqpiN8WNiiAVLrXnremSD1NWOWaaJPlZbGNDZUZJbT +yKXsqVrCv/v3RKzcj/v/AW1JNwvRQaeor8IMhyARu7wEMFSErEoKNLaH7zcm03Q0 +5gECggEAcaeUlvCqcQiU/i42f9cHu2wx7kDKil6JS8ZYsJ0dGl5g9u2+VyeUBJvI +jLTrnwaGyn+Gx+PlGaC0QK5i9WjbblxrPuWFJk63pGDzLW35jrM55QC9kiI4C1Q7 +TyWOfjGBGa+Gi+v5bo+B3osgAHx7F2SQaoZ2C6Qbo+CMvDCihYFsPcOVngNzf+7+ +QkPOOO5ixlopHyekKe6dcMvb1PzovSEK3DZoh0XH4M5hz3LNk3Z57seNyN9pxcLI +U4qSYvcgYwbYZWjXymU6LVsx5lKJ4RrJGS56K0Q3ebC9nz7pNoVGKxvFAo4OI179 
+pTI3VZB66Rqurj0RWq452LJE1Onb5A== +-----END PRIVATE KEY----- diff --git a/presto-main/src/test/resources/oidc/openid-configuration-invalid-issuer.json b/presto-main/src/test/resources/oidc/openid-configuration-invalid-issuer.json new file mode 100644 index 0000000000000..2f23e8ac98154 --- /dev/null +++ b/presto-main/src/test/resources/oidc/openid-configuration-invalid-issuer.json @@ -0,0 +1,106 @@ +{ + "issuer": "https://invalid-issuer.com", + "authorization_endpoint": "https://invalid-issuer.com/connect/authorize", + "token_endpoint": "https://invalid-issuer.com/connect/token", + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "private_key_jwt" + ], + "token_endpoint_auth_signing_alg_values_supported": [ + "RS256", + "ES256" + ], + "userinfo_endpoint": "https://invalid-issuer.com/connect/userinfo", + "check_session_iframe": "https://invalid-issuer.com/connect/check_session", + "end_session_endpoint": "https://invalid-issuer.com/connect/end_session", + "jwks_uri": "https://invalid-issuer.com/jwks.json", + "registration_endpoint": "https://invalid-issuer.com/connect/register", + "scopes_supported": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access" + ], + "response_types_supported": [ + "code", + "code id_token", + "id_token", + "token id_token" + ], + "acr_values_supported": [ + "urn:mace:incommon:iap:silver", + "urn:mace:incommon:iap:bronze" + ], + "subject_types_supported": [ + "public", + "pairwise" + ], + "userinfo_signing_alg_values_supported": [ + "RS256", + "ES256", + "HS256" + ], + "userinfo_encryption_alg_values_supported": [ + "RSA1_5", + "A128KW" + ], + "userinfo_encryption_enc_values_supported": [ + "A128CBC-HS256", + "A128GCM" + ], + "id_token_signing_alg_values_supported": [ + "RS256", + "ES256", + "HS256" + ], + "id_token_encryption_alg_values_supported": [ + "RSA1_5", + "A128KW" + ], + "id_token_encryption_enc_values_supported": [ + "A128CBC-HS256", + "A128GCM" + ], + 
"request_object_signing_alg_values_supported": [ + "none", + "RS256", + "ES256" + ], + "display_values_supported": [ + "page", + "popup" + ], + "claim_types_supported": [ + "normal", + "distributed" + ], + "claims_supported": [ + "sub", + "iss", + "auth_time", + "acr", + "name", + "given_name", + "family_name", + "nickname", + "profile", + "picture", + "website", + "email", + "email_verified", + "locale", + "zoneinfo", + "http://example.info/claims/groups" + ], + "claims_parameter_supported": true, + "service_documentation": "http://invalid-issuer.com/connect/service_documentation.html", + "ui_locales_supported": [ + "en-US", + "en-GB", + "en-CA", + "fr-FR", + "fr-CA" + ] +} diff --git a/presto-main/src/test/resources/oidc/openid-configuration-with-access-token-issuer.json b/presto-main/src/test/resources/oidc/openid-configuration-with-access-token-issuer.json new file mode 100644 index 0000000000000..905032ce3df1d --- /dev/null +++ b/presto-main/src/test/resources/oidc/openid-configuration-with-access-token-issuer.json @@ -0,0 +1,87 @@ +{ + "issuer": "https://issuer.com", + "authorization_endpoint": "https://issuer.com/connect/authorize", + "token_endpoint": "https://issuer.com/connect/token", + "jwks_uri": "https://issuer.com/jwks.json", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", + "client_secret_basic", + "private_key_jwt", + "windows_client_authentication" + ], + "response_types_supported": [ + "code", + "id_token", + "code id_token", + "id_token token", + "code token", + "code id_token token" + ], + "response_modes_supported": [ + "query", + "fragment", + "form_post" + ], + "grant_types_supported": [ + "authorization_code", + "refresh_token", + "client_credentials", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "implicit", + "password", + "srv_challenge", + "urn:ietf:params:oauth:grant-type:device_code", + "device_code" + ], + "subject_types_supported": [ + "pairwise" + ], + "scopes_supported": [ + "user_impersonation", + 
"vpn_cert", + "email", + "openid", + "profile", + "allatclaims", + "logon_cert", + "aza", + "winhello_cert" + ], + "id_token_signing_alg_values_supported": [ + "RS256" + ], + "token_endpoint_auth_signing_alg_values_supported": [ + "RS256" + ], + "access_token_issuer": "http://access-token-issuer.com/adfs/services/trust", + "claims_supported": [ + "aud", + "iss", + "iat", + "exp", + "auth_time", + "nonce", + "at_hash", + "c_hash", + "sub", + "upn", + "unique_name", + "pwd_url", + "pwd_exp", + "mfa_auth_time", + "sid", + "nbf" + ], + "microsoft_multi_refresh_token": true, + "userinfo_endpoint": "https://issuer.com/connect/userinfo", + "capabilities": [], + "end_session_endpoint": "https://issuer.com/adfs/oauth2/logout", + "as_access_token_token_binding_supported": true, + "as_refresh_token_token_binding_supported": true, + "resource_access_token_token_binding_supported": true, + "op_id_token_token_binding_supported": true, + "rp_id_token_token_binding_supported": true, + "frontchannel_logout_supported": true, + "frontchannel_logout_session_supported": true, + "device_authorization_endpoint": "https://issuer.com/adfs/oauth2/devicecode" +} diff --git a/presto-main/src/test/resources/oidc/openid-configuration-without-userinfo.json b/presto-main/src/test/resources/oidc/openid-configuration-without-userinfo.json new file mode 100644 index 0000000000000..96f8c6048acf6 --- /dev/null +++ b/presto-main/src/test/resources/oidc/openid-configuration-without-userinfo.json @@ -0,0 +1,92 @@ +{ + "issuer": "https://issuer.com", + "authorization_endpoint": "https://issuer.com/connect/authorize", + "token_endpoint": "https://issuer.com/connect/token", + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "private_key_jwt" + ], + "token_endpoint_auth_signing_alg_values_supported": [ + "RS256", + "ES256" + ], + "check_session_iframe": "https://issuer.com/connect/check_session", + "end_session_endpoint": "https://issuer.com/connect/end_session", + "jwks_uri": 
"https://issuer.com/jwks.json", + "registration_endpoint": "https://issuer.com/connect/register", + "scopes_supported": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access" + ], + "response_types_supported": [ + "code", + "code id_token", + "id_token", + "token id_token" + ], + "acr_values_supported": [ + "urn:mace:incommon:iap:silver", + "urn:mace:incommon:iap:bronze" + ], + "subject_types_supported": [ + "public", + "pairwise" + ], + "id_token_signing_alg_values_supported": [ + "RS256", + "ES256", + "HS256" + ], + "id_token_encryption_alg_values_supported": [ + "RSA1_5", + "A128KW" + ], + "id_token_encryption_enc_values_supported": [ + "A128CBC-HS256", + "A128GCM" + ], + "request_object_signing_alg_values_supported": [ + "none", + "RS256", + "ES256" + ], + "display_values_supported": [ + "page", + "popup" + ], + "claim_types_supported": [ + "normal", + "distributed" + ], + "claims_supported": [ + "sub", + "iss", + "auth_time", + "acr", + "name", + "given_name", + "family_name", + "nickname", + "profile", + "picture", + "website", + "email", + "email_verified", + "locale", + "zoneinfo", + "http://example.info/claims/groups" + ], + "claims_parameter_supported": true, + "service_documentation": "http://issuer.com/connect/service_documentation.html", + "ui_locales_supported": [ + "en-US", + "en-GB", + "en-CA", + "fr-FR", + "fr-CA" + ] +} diff --git a/presto-main/src/test/resources/oidc/openid-configuration.json b/presto-main/src/test/resources/oidc/openid-configuration.json new file mode 100644 index 0000000000000..39c456fca85b6 --- /dev/null +++ b/presto-main/src/test/resources/oidc/openid-configuration.json @@ -0,0 +1,106 @@ +{ + "issuer": "https://issuer.com", + "authorization_endpoint": "https://issuer.com/connect/authorize", + "token_endpoint": "https://issuer.com/connect/token", + "token_endpoint_auth_methods_supported": [ + "client_secret_basic", + "private_key_jwt" + ], + "token_endpoint_auth_signing_alg_values_supported": [ + 
"RS256", + "ES256" + ], + "userinfo_endpoint": "https://issuer.com/connect/userinfo", + "check_session_iframe": "https://issuer.com/connect/check_session", + "end_session_endpoint": "https://issuer.com/connect/end_session", + "jwks_uri": "https://issuer.com/jwks.json", + "registration_endpoint": "https://issuer.com/connect/register", + "scopes_supported": [ + "openid", + "profile", + "email", + "address", + "phone", + "offline_access" + ], + "response_types_supported": [ + "code", + "code id_token", + "id_token", + "token id_token" + ], + "acr_values_supported": [ + "urn:mace:incommon:iap:silver", + "urn:mace:incommon:iap:bronze" + ], + "subject_types_supported": [ + "public", + "pairwise" + ], + "userinfo_signing_alg_values_supported": [ + "RS256", + "ES256", + "HS256" + ], + "userinfo_encryption_alg_values_supported": [ + "RSA1_5", + "A128KW" + ], + "userinfo_encryption_enc_values_supported": [ + "A128CBC-HS256", + "A128GCM" + ], + "id_token_signing_alg_values_supported": [ + "RS256", + "ES256", + "HS256" + ], + "id_token_encryption_alg_values_supported": [ + "RSA1_5", + "A128KW" + ], + "id_token_encryption_enc_values_supported": [ + "A128CBC-HS256", + "A128GCM" + ], + "request_object_signing_alg_values_supported": [ + "none", + "RS256", + "ES256" + ], + "display_values_supported": [ + "page", + "popup" + ], + "claim_types_supported": [ + "normal", + "distributed" + ], + "claims_supported": [ + "sub", + "iss", + "auth_time", + "acr", + "name", + "given_name", + "family_name", + "nickname", + "profile", + "picture", + "website", + "email", + "email_verified", + "locale", + "zoneinfo", + "http://example.info/claims/groups" + ], + "claims_parameter_supported": true, + "service_documentation": "http://issuer.com/connect/service_documentation.html", + "ui_locales_supported": [ + "en-US", + "en-GB", + "en-CA", + "fr-FR", + "fr-CA" + ] +} diff --git a/presto-matching/pom.xml b/presto-matching/pom.xml index f4e8b0b44b0ef..98bffcbcb38d4 100644 --- 
a/presto-matching/pom.xml +++ b/presto-matching/pom.xml @@ -18,7 +18,7 @@ presto-root com.facebook.presto - 0.293 + 0.297-edge10.1-SNAPSHOT presto-matching @@ -26,6 +26,8 @@ ${project.parent.basedir} + 8 + true diff --git a/presto-memory-context/pom.xml b/presto-memory-context/pom.xml index 0455c8b08a6dc..24121d363ba20 100644 --- a/presto-memory-context/pom.xml +++ b/presto-memory-context/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-memory-context @@ -14,15 +14,22 @@ ${project.parent.basedir} + 8 + true - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true + + jakarta.annotation + jakarta.annotation-api + + com.google.guava guava @@ -42,7 +49,7 @@ - io.airlift + com.facebook.airlift units test diff --git a/presto-memory-context/src/main/java/com/facebook/presto/memory/context/AbstractAggregatedMemoryContext.java b/presto-memory-context/src/main/java/com/facebook/presto/memory/context/AbstractAggregatedMemoryContext.java index 4df13a7e2cf50..1a39c36a75741 100644 --- a/presto-memory-context/src/main/java/com/facebook/presto/memory/context/AbstractAggregatedMemoryContext.java +++ b/presto-memory-context/src/main/java/com/facebook/presto/memory/context/AbstractAggregatedMemoryContext.java @@ -15,10 +15,9 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.annotation.Nullable; import static com.google.common.base.MoreObjects.toStringHelper; import static java.lang.String.format; diff --git a/presto-memory-context/src/main/java/com/facebook/presto/memory/context/MemoryTrackingContext.java 
b/presto-memory-context/src/main/java/com/facebook/presto/memory/context/MemoryTrackingContext.java index b9f1256b57cf5..9ab8e257f222f 100644 --- a/presto-memory-context/src/main/java/com/facebook/presto/memory/context/MemoryTrackingContext.java +++ b/presto-memory-context/src/main/java/com/facebook/presto/memory/context/MemoryTrackingContext.java @@ -14,8 +14,7 @@ package com.facebook.presto.memory.context; import com.google.common.io.Closer; - -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; import java.io.IOException; diff --git a/presto-memory-context/src/main/java/com/facebook/presto/memory/context/SimpleLocalMemoryContext.java b/presto-memory-context/src/main/java/com/facebook/presto/memory/context/SimpleLocalMemoryContext.java index 8a6d9d6d84d68..0878534936855 100644 --- a/presto-memory-context/src/main/java/com/facebook/presto/memory/context/SimpleLocalMemoryContext.java +++ b/presto-memory-context/src/main/java/com/facebook/presto/memory/context/SimpleLocalMemoryContext.java @@ -15,9 +15,8 @@ import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; diff --git a/presto-memory-context/src/test/java/com/facebook/presto/memory/context/TestMemoryContexts.java b/presto-memory-context/src/test/java/com/facebook/presto/memory/context/TestMemoryContexts.java index 463004321d500..85e2e439baaff 100644 --- a/presto-memory-context/src/test/java/com/facebook/presto/memory/context/TestMemoryContexts.java +++ b/presto-memory-context/src/test/java/com/facebook/presto/memory/context/TestMemoryContexts.java @@ -13,17 +13,17 @@ */ 
package com.facebook.presto.memory.context; +import com.facebook.airlift.units.DataSize; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; import org.testng.annotations.Test; import java.io.IOException; +import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE; import static com.facebook.presto.memory.context.AggregatedMemoryContext.newRootAggregatedMemoryContext; import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; -import static io.airlift.units.DataSize.Unit.MEGABYTE; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; diff --git a/presto-memory/pom.xml b/presto-memory/pom.xml index 74f9b930218b0..50d07e1735743 100644 --- a/presto-memory/pom.xml +++ b/presto-memory/pom.xml @@ -5,15 +5,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-memory + presto-memory Presto - Memory Connector presto-plugin ${project.parent.basedir} + true @@ -38,8 +40,8 @@ - com.google.code.findbugs - jsr305 + com.google.errorprone + error_prone_annotations true @@ -49,13 +51,13 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -72,7 +74,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -84,7 +86,7 @@ - io.airlift + com.facebook.airlift units provided diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryColumnHandle.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryColumnHandle.java index bf0fb7a3cc33d..b440acd42ae1f 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryColumnHandle.java +++ 
b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryColumnHandle.java @@ -75,6 +75,14 @@ public ColumnMetadata toColumnMetadata() .build(); } + public ColumnMetadata toColumnMetadata(String name) + { + return ColumnMetadata.builder() + .setName(name) + .setType(columnType) + .build(); + } + @Override public int hashCode() { diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryConfig.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryConfig.java index f9ad6dbefc3c2..127fed18e7f7d 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryConfig.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryConfig.java @@ -14,9 +14,8 @@ package com.facebook.presto.plugin.memory; import com.facebook.airlift.configuration.Config; -import io.airlift.units.DataSize; - -import javax.validation.constraints.NotNull; +import com.facebook.airlift.units.DataSize; +import jakarta.validation.constraints.NotNull; public class MemoryConfig { diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryConnector.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryConnector.java index a82851bbae1ed..f316bdf59766a 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryConnector.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryConnector.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; public class MemoryConnector implements Connector diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryInsertTableHandle.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryInsertTableHandle.java index 
fda8f2075aa8d..9d7bb488e2c02 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryInsertTableHandle.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryInsertTableHandle.java @@ -27,14 +27,22 @@ public class MemoryInsertTableHandle { private final MemoryTableHandle table; private final Set activeTableIds; + private final boolean insertOverwrite; @JsonCreator public MemoryInsertTableHandle( @JsonProperty("table") MemoryTableHandle table, - @JsonProperty("activeTableIds") Set activeTableIds) + @JsonProperty("activeTableIds") Set activeTableIds, + @JsonProperty("insertOverwrite") boolean insertOverwrite) { this.table = requireNonNull(table, "table is null"); this.activeTableIds = requireNonNull(activeTableIds, "activeTableIds is null"); + this.insertOverwrite = insertOverwrite; + } + + public MemoryInsertTableHandle(MemoryTableHandle table, Set activeTableIds) + { + this(table, activeTableIds, false); } @JsonProperty @@ -49,6 +57,12 @@ public Set getActiveTableIds() return activeTableIds; } + @JsonProperty + public boolean isInsertOverwrite() + { + return insertOverwrite; + } + @Override public String toString() { diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryMetadata.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryMetadata.java index b600eabc8ea16..b2844e0e66d30 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryMetadata.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryMetadata.java @@ -28,6 +28,8 @@ import com.facebook.presto.spi.ConnectorViewDefinition; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.MaterializedViewDefinition; +import com.facebook.presto.spi.MaterializedViewStatus; import com.facebook.presto.spi.Node; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PrestoException; @@ -39,11 
+41,11 @@ import com.facebook.presto.spi.connector.ConnectorOutputMetadata; import com.facebook.presto.spi.statistics.ComputedStatistics; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.ThreadSafe; import io.airlift.slice.Slice; - -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.Collection; @@ -64,6 +66,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; +import static java.lang.System.currentTimeMillis; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toMap; @@ -83,6 +86,12 @@ public class MemoryMetadata private final Map> tableDataFragments = new HashMap<>(); private final Map views = new HashMap<>(); + private final Map materializedViews = new HashMap<>(); + private final Map tableVersions = new HashMap<>(); + private final Map> mvRefreshVersions = new HashMap<>(); + private final Map storageTableToMaterializedView = new HashMap<>(); + private final Map mvLastRefreshTimes = new HashMap<>(); + @Inject public MemoryMetadata(NodeManager nodeManager, MemoryConnectorId connectorId) { @@ -169,17 +178,33 @@ public synchronized Map> listTableColumns( { return tables.values().stream() .filter(table -> prefix.matches(table.toSchemaTableName())) - .collect(toMap(MemoryTableHandle::toSchemaTableName, handle -> handle.toTableMetadata().getColumns())); + .collect(toImmutableMap(MemoryTableHandle::toSchemaTableName, handle -> toTableMetadata(handle, session).getColumns())); + } + + public ConnectorTableMetadata toTableMetadata(MemoryTableHandle memoryTableHandle, ConnectorSession session) + { + List columns = 
memoryTableHandle.getColumnHandles().stream() + .map(column -> column.toColumnMetadata(normalizeIdentifier(session, column.getName()))) + .collect(toImmutableList()); + + return new ConnectorTableMetadata(memoryTableHandle.toSchemaTableName(), columns); } @Override public synchronized void dropTable(ConnectorSession session, ConnectorTableHandle tableHandle) { MemoryTableHandle handle = (MemoryTableHandle) tableHandle; - Long tableId = tableIds.remove(handle.toSchemaTableName()); + SchemaTableName tableName = handle.toSchemaTableName(); + + if (storageTableToMaterializedView.containsKey(tableName)) { + throw new PrestoException(NOT_FOUND, format("Cannot drop table [%s] because it is a materialized view storage table. Use DROP MATERIALIZED VIEW instead.", tableName)); + } + + Long tableId = tableIds.remove(tableName); if (tableId != null) { tables.remove(tableId); tableDataFragments.remove(tableId); + tableVersions.remove(tableName); } } @@ -189,6 +214,11 @@ public synchronized void renameTable(ConnectorSession session, ConnectorTableHan checkSchemaExists(newTableName.getSchemaName()); checkTableNotExists(newTableName); MemoryTableHandle oldTableHandle = (MemoryTableHandle) tableHandle; + SchemaTableName oldTableName = oldTableHandle.toSchemaTableName(); + + if (storageTableToMaterializedView.containsKey(oldTableName)) { + throw new PrestoException(NOT_FOUND, format("Cannot rename table [%s] because it is a materialized view storage table", oldTableName)); + } MemoryTableHandle newTableHandle = new MemoryTableHandle( oldTableHandle.getConnectorId(), newTableName.getSchemaName(), @@ -254,6 +284,8 @@ public synchronized Optional finishCreateTable(Connecto MemoryOutputTableHandle memoryOutputHandle = (MemoryOutputTableHandle) tableHandle; updateRowsOnHosts(memoryOutputHandle.getTable(), fragments); + incrementTableVersion(memoryOutputHandle.getTable().toSchemaTableName()); + return Optional.empty(); } @@ -271,6 +303,8 @@ public synchronized Optional 
finishInsert(ConnectorSess MemoryInsertTableHandle memoryInsertHandle = (MemoryInsertTableHandle) insertHandle; updateRowsOnHosts(memoryInsertHandle.getTable(), fragments); + incrementTableVersion(memoryInsertHandle.getTable().toSchemaTableName()); + return Optional.empty(); } @@ -347,8 +381,13 @@ private void updateRowsOnHosts(MemoryTableHandle table, Collection fragme } } + private void incrementTableVersion(SchemaTableName tableName) + { + tableVersions.put(tableName, tableVersions.getOrDefault(tableName, 0L) + 1); + } + @Override - public synchronized List getTableLayouts( + public synchronized ConnectorTableLayoutResult getTableLayoutForConstraint( ConnectorSession session, ConnectorTableHandle handle, Constraint constraint, @@ -367,7 +406,7 @@ public synchronized List getTableLayouts( tableDataFragments.get(memoryTableHandle.getTableId()).values()); MemoryTableLayoutHandle layoutHandle = new MemoryTableLayoutHandle(memoryTableHandle, expectedFragments); - return ImmutableList.of(new ConnectorTableLayoutResult(getTableLayout(session, layoutHandle), constraint.getSummary())); + return new ConnectorTableLayoutResult(getTableLayout(session, layoutHandle), constraint.getSummary()); } @Override @@ -382,4 +421,143 @@ public synchronized ConnectorTableLayout getTableLayout(ConnectorSession session Optional.empty(), ImmutableList.of()); } + + @Override + public synchronized void createMaterializedView( + ConnectorSession session, + ConnectorTableMetadata viewMetadata, + MaterializedViewDefinition viewDefinition, + boolean ignoreExisting) + { + SchemaTableName viewName = viewMetadata.getTable(); + checkSchemaExists(viewName.getSchemaName()); + + if (materializedViews.containsKey(viewName)) { + if (ignoreExisting) { + return; + } + throw new PrestoException(ALREADY_EXISTS, "Materialized view already exists: " + viewName); + } + + if (getTableHandle(session, viewName) != null) { + throw new PrestoException(ALREADY_EXISTS, "Table already exists: " + viewName); + } + + if 
(views.containsKey(viewName)) { + throw new PrestoException(ALREADY_EXISTS, "View already exists: " + viewName); + } + + SchemaTableName storageTableName = new SchemaTableName( + viewDefinition.getSchema(), + viewDefinition.getTable()); + + ConnectorTableMetadata storageTableMetadata = new ConnectorTableMetadata( + storageTableName, + viewMetadata.getColumns(), + viewMetadata.getProperties(), + viewMetadata.getComment()); + + createTable(session, storageTableMetadata, false); + + materializedViews.put(viewName, viewDefinition); + Map baseTableVersionSnapshot = new HashMap<>(); + for (SchemaTableName baseTable : viewDefinition.getBaseTables()) { + baseTableVersionSnapshot.put(baseTable, 0L); + } + mvRefreshVersions.put(viewName, baseTableVersionSnapshot); + storageTableToMaterializedView.put(storageTableName, viewName); + } + + @Override + public synchronized Optional getMaterializedView(ConnectorSession session, SchemaTableName viewName) + { + return Optional.ofNullable(materializedViews.get(viewName)); + } + + @Override + public synchronized void dropMaterializedView(ConnectorSession session, SchemaTableName viewName) + { + MaterializedViewDefinition removed = materializedViews.remove(viewName); + if (removed == null) { + throw new PrestoException(NOT_FOUND, "Materialized view not found: " + viewName); + } + mvRefreshVersions.remove(viewName); + + SchemaTableName storageTableName = new SchemaTableName( + removed.getSchema(), + removed.getTable()); + storageTableToMaterializedView.remove(storageTableName); + + ConnectorTableHandle storageTableHandle = getTableHandle(session, storageTableName); + if (storageTableHandle != null) { + dropTable(session, storageTableHandle); + } + } + + @Override + public synchronized MaterializedViewStatus getMaterializedViewStatus( + ConnectorSession session, + SchemaTableName materializedViewName, + TupleDomain baseQueryDomain) + { + MaterializedViewDefinition mvDefinition = materializedViews.get(materializedViewName); + if 
(mvDefinition == null) { + throw new PrestoException(NOT_FOUND, "Materialized view not found: " + materializedViewName); + } + + Map baseTableVersionSnapshot = mvRefreshVersions.getOrDefault(materializedViewName, ImmutableMap.of()); + Optional lastFreshTime = Optional.ofNullable(mvLastRefreshTimes.get(materializedViewName)); + + for (SchemaTableName baseTable : mvDefinition.getBaseTables()) { + long currentVersion = tableVersions.getOrDefault(baseTable, 0L); + long refreshedVersion = baseTableVersionSnapshot.getOrDefault(baseTable, 0L); + if (currentVersion != refreshedVersion) { + return new MaterializedViewStatus( + MaterializedViewStatus.MaterializedViewState.NOT_MATERIALIZED, + ImmutableMap.of(), + lastFreshTime); + } + } + + return new MaterializedViewStatus( + MaterializedViewStatus.MaterializedViewState.FULLY_MATERIALIZED, + ImmutableMap.of(), + lastFreshTime); + } + + @Override + public synchronized ConnectorInsertTableHandle beginRefreshMaterializedView( + ConnectorSession session, + ConnectorTableHandle tableHandle) + { + MemoryTableHandle memoryTableHandle = (MemoryTableHandle) tableHandle; + tableDataFragments.put(memoryTableHandle.getTableId(), new HashMap<>()); + return new MemoryInsertTableHandle(memoryTableHandle, ImmutableSet.copyOf(tableIds.values()), true); + } + + @Override + public synchronized Optional finishRefreshMaterializedView( + ConnectorSession session, + ConnectorInsertTableHandle insertHandle, + Collection fragments, + Collection computedStatistics) + { + Optional result = finishInsert(session, insertHandle, fragments, computedStatistics); + + MemoryInsertTableHandle memoryInsertHandle = (MemoryInsertTableHandle) insertHandle; + SchemaTableName storageTableName = memoryInsertHandle.getTable().toSchemaTableName(); + + SchemaTableName materializedViewName = storageTableToMaterializedView.get(storageTableName); + checkState(materializedViewName != null, "No materialized view found for storage table: %s", storageTableName); + + 
MaterializedViewDefinition mvDefinition = materializedViews.get(materializedViewName); + Map baseTableVersionSnapshot = new HashMap<>(); + for (SchemaTableName baseTable : mvDefinition.getBaseTables()) { + baseTableVersionSnapshot.put(baseTable, tableVersions.getOrDefault(baseTable, 0L)); + } + mvRefreshVersions.put(materializedViewName, baseTableVersionSnapshot); + mvLastRefreshTimes.put(materializedViewName, currentTimeMillis()); + + return result; + } } diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSinkProvider.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSinkProvider.java index 88215b0300253..9cddad538e0d7 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSinkProvider.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSinkProvider.java @@ -26,8 +26,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.Collection; import java.util.concurrent.CompletableFuture; @@ -82,6 +81,9 @@ public ConnectorPageSink createPageSink(ConnectorTransactionHandle transactionHa checkState(memoryInsertTableHandle.getActiveTableIds().contains(tableId)); pagesStore.cleanUp(memoryInsertTableHandle.getActiveTableIds()); + if (memoryInsertTableHandle.isInsertOverwrite()) { + pagesStore.clearTable(tableId); + } pagesStore.initialize(tableId); return new MemoryPageSink(pagesStore, currentHostAddress, tableId); } diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSourceProvider.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSourceProvider.java index ef69bad294cb8..8dbc30f4170e0 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSourceProvider.java +++ 
b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPageSourceProvider.java @@ -22,8 +22,7 @@ import com.facebook.presto.spi.SplitContext; import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPagesStore.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPagesStore.java index 7876b66547e8d..f913d768efa58 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPagesStore.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemoryPagesStore.java @@ -17,10 +17,9 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.spi.PrestoException; import com.google.common.collect.ImmutableList; - -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; -import javax.inject.Inject; +import com.google.errorprone.annotations.ThreadSafe; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import jakarta.inject.Inject; import java.util.ArrayList; import java.util.Collections; @@ -105,6 +104,17 @@ public synchronized boolean contains(Long tableId) return tables.containsKey(tableId); } + public synchronized void clearTable(Long tableId) + { + TableData tableData = tables.get(tableId); + if (tableData != null) { + for (Page page : tableData.getPages()) { + currentBytes -= page.getRetainedSizeInBytes(); + } + tables.put(tableId, new TableData()); + } + } + public synchronized void cleanUp(Set activeTableIds) { // We have to remember that there might be some race conditions when there are two tables created at once. 
diff --git a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplitManager.java b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplitManager.java index 32984dbc70b9f..f7a71959c33f9 100644 --- a/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplitManager.java +++ b/presto-memory/src/main/java/com/facebook/presto/plugin/memory/MemorySplitManager.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMaterializedViewAccessControl.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMaterializedViewAccessControl.java new file mode 100644 index 0000000000000..9fc23a5fabdbb --- /dev/null +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMaterializedViewAccessControl.java @@ -0,0 +1,835 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.memory; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.security.ViewExpression; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_TABLE; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_VIEW; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.CREATE_VIEW_WITH_SELECT_COLUMNS; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DELETE_TABLE; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_TABLE; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.DROP_VIEW; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.INSERT_TABLE; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SELECT_COLUMN; +import static com.facebook.presto.testing.TestingAccessControlManager.TestingPrivilegeType.SHOW_CREATE_TABLE; +import static com.facebook.presto.testing.TestingAccessControlManager.privilege; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestMaterializedViewAccessControl + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder() + .setCatalog("memory") + 
.setSchema("default") + .setSystemProperty("legacy_materialized_views", "false") + .build(); + + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session) + .setNodeCount(4) + .setExtraProperties(ImmutableMap.of("experimental.allow-legacy-materialized-views-toggle", "true")) + .build(); + + queryRunner.installPlugin(new MemoryPlugin()); + queryRunner.createCatalog("memory", "memory", ImmutableMap.of()); + + return queryRunner; + } + + @Test + public void testCreateMaterializedViewRequiresBothCreateTableAndCreateView() + { + // Setup: Create a base table + assertUpdate("CREATE TABLE test_base (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO test_base VALUES (1, 'test')", 1); + + try { + // Deny only CREATE_VIEW - should fail + getQueryRunner().getAccessControl().deny(privilege("test_mv_create_table_only", CREATE_VIEW)); + assertQueryFails( + "CREATE MATERIALIZED VIEW test_mv_create_table_only AS SELECT * FROM test_base", + ".*Cannot create view.*"); + getQueryRunner().getAccessControl().reset(); + + // Deny only CREATE_TABLE - should fail + getQueryRunner().getAccessControl().deny(privilege("test_mv_create_view_only", CREATE_TABLE)); + assertQueryFails( + "CREATE MATERIALIZED VIEW test_mv_create_view_only AS SELECT * FROM test_base", + ".*Cannot create table.*"); + getQueryRunner().getAccessControl().reset(); + + // Allow both - should succeed + assertUpdate("CREATE MATERIALIZED VIEW test_mv_both_perms AS SELECT * FROM test_base"); + assertUpdate("DROP MATERIALIZED VIEW test_mv_both_perms"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS test_base"); + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testRefreshMaterializedViewRequiresBothDeleteAndInsert() + { + assertUpdate("CREATE TABLE refresh_base (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO refresh_base VALUES (1, 'test')", 1); + assertUpdate("CREATE MATERIALIZED VIEW test_mv_refresh AS SELECT * FROM refresh_base"); + + try { + // Deny 
only INSERT_TABLE - should fail + getQueryRunner().getAccessControl().deny(privilege("test_mv_refresh", INSERT_TABLE)); + assertQueryFails( + "REFRESH MATERIALIZED VIEW test_mv_refresh", + ".*Cannot insert into table.*"); + getQueryRunner().getAccessControl().reset(); + + // Deny only DELETE_TABLE - should fail + getQueryRunner().getAccessControl().deny(privilege("test_mv_refresh", DELETE_TABLE)); + assertQueryFails( + "REFRESH MATERIALIZED VIEW test_mv_refresh", + ".*Cannot delete from table.*"); + getQueryRunner().getAccessControl().reset(); + + // Allow both - should succeed + assertUpdate("REFRESH MATERIALIZED VIEW test_mv_refresh", 1); + } + finally { + assertUpdate("DROP MATERIALIZED VIEW IF EXISTS test_mv_refresh"); + assertUpdate("DROP TABLE IF EXISTS refresh_base"); + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testDropMaterializedViewRequiresBothDropTableAndDropView() + { + try { + // Deny only DROP_VIEW - should fail + assertUpdate("CREATE TABLE drop_base1 (id BIGINT)"); + assertUpdate("CREATE MATERIALIZED VIEW test_mv_drop1 AS SELECT * FROM drop_base1"); + + getQueryRunner().getAccessControl().deny(privilege("test_mv_drop1", DROP_VIEW)); + assertQueryFails( + "DROP MATERIALIZED VIEW test_mv_drop1", + ".*Cannot drop view.*"); + getQueryRunner().getAccessControl().reset(); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_drop1"); + assertUpdate("DROP TABLE drop_base1"); + + // Deny only DROP_TABLE - should fail + assertUpdate("CREATE TABLE drop_base2 (id BIGINT)"); + assertUpdate("CREATE MATERIALIZED VIEW test_mv_drop2 AS SELECT * FROM drop_base2"); + + getQueryRunner().getAccessControl().deny(privilege("test_mv_drop2", DROP_TABLE)); + assertQueryFails( + "DROP MATERIALIZED VIEW test_mv_drop2", + ".*Cannot drop table.*"); + getQueryRunner().getAccessControl().reset(); + + assertUpdate("DROP MATERIALIZED VIEW test_mv_drop2"); + assertUpdate("DROP TABLE drop_base2"); + + // Allow both - should succeed + assertUpdate("CREATE 
TABLE drop_base3 (id BIGINT)"); + assertUpdate("CREATE MATERIALIZED VIEW test_mv_drop3 AS SELECT * FROM drop_base3"); + assertUpdate("DROP MATERIALIZED VIEW test_mv_drop3"); + assertUpdate("DROP TABLE drop_base3"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testSecurityDefinerAllowsUnprivilegedUserToQuery() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE secure_base (id BIGINT, secret VARCHAR, value BIGINT)"); + assertUpdate(adminSession, "INSERT INTO secure_base VALUES (1, 'confidential', 100), (2, 'classified', 200)", 2); + + try { + // Deny restricted_user access to secure_base + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "secure_base", SELECT_COLUMN)); + + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_definer " + + "SECURITY DEFINER AS " + + "SELECT id, secret, value FROM secure_base"); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_definer", "SELECT 2"); + assertQuery(adminSession, "SELECT id, value FROM mv_definer ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertQueryFails(restrictedSession, "SELECT COUNT(*) FROM secure_base", ".*Access Denied.*"); + + assertQuery(restrictedSession, "SELECT COUNT(*) FROM mv_definer", "SELECT 2"); + assertQuery(restrictedSession, "SELECT id, value FROM mv_definer ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_definer"); + assertUpdate(adminSession, "DROP TABLE secure_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testSecurityInvokerUsesCurrentUserPermissions() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE invoker_base (id BIGINT, data VARCHAR, value 
BIGINT)"); + assertUpdate(adminSession, "INSERT INTO invoker_base VALUES (1, 'data1', 100), (2, 'data2', 200)", 2); + + try { + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "invoker_base", SELECT_COLUMN)); + + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_invoker " + + "SECURITY INVOKER AS " + + "SELECT id, data, value FROM invoker_base"); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_invoker", "SELECT 2"); + assertQuery(adminSession, "SELECT id, value FROM mv_invoker ORDER BY id", + "VALUES (1, 100), (2, 200)"); + + assertQueryFails(restrictedSession, "SELECT COUNT(*) FROM mv_invoker", + ".*Access Denied.*"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_invoker"); + assertUpdate(adminSession, "DROP TABLE invoker_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testRefreshWithSecurityDefiner() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE refresh_definer_base (id BIGINT, value BIGINT)"); + assertUpdate(adminSession, "INSERT INTO refresh_definer_base VALUES (1, 100), (2, 200)", 2); + + try { + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "refresh_definer_base", SELECT_COLUMN)); + + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_refresh_definer " + + "SECURITY DEFINER AS " + + "SELECT id, value FROM refresh_definer_base"); + + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_refresh_definer", 2); + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_refresh_definer", "SELECT 2"); + + assertUpdate(adminSession, "INSERT INTO refresh_definer_base VALUES (3, 300)", 1); + + assertUpdate(restrictedSession, "REFRESH MATERIALIZED VIEW mv_refresh_definer", 3); + assertQuery(restrictedSession, "SELECT COUNT(*) FROM mv_refresh_definer", "SELECT 3"); + + assertUpdate(adminSession, "DROP 
MATERIALIZED VIEW mv_refresh_definer"); + assertUpdate(adminSession, "DROP TABLE refresh_definer_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testRefreshWithSecurityInvoker() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE refresh_invoker_base (id BIGINT, value BIGINT)"); + assertUpdate(adminSession, "INSERT INTO refresh_invoker_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_refresh_invoker " + + "SECURITY INVOKER AS " + + "SELECT id, value FROM refresh_invoker_base"); + + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_refresh_invoker", 2); + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_refresh_invoker", "SELECT 2"); + + assertUpdate(adminSession, "INSERT INTO refresh_invoker_base VALUES (3, 300)", 1); + + assertUpdate(restrictedSession, "REFRESH MATERIALIZED VIEW mv_refresh_invoker", 3); + assertQuery(restrictedSession, "SELECT COUNT(*) FROM mv_refresh_invoker", "SELECT 3"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_refresh_invoker"); + assertUpdate(adminSession, "DROP TABLE refresh_invoker_base"); + } + + @Test + public void testDefaultViewSecurityModeDefiner() + { + Session adminSession = Session.builder(createSessionForUser("admin")) + .setSystemProperty("default_view_security_mode", "DEFINER") + .build(); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE default_definer_base (id BIGINT, secret VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO default_definer_base VALUES (1, 'secret1'), (2, 'secret2')", 2); + + try { + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "default_definer_base", SELECT_COLUMN)); + + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_default_definer AS " + + "SELECT id, 
secret FROM default_definer_base"); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_default_definer", "SELECT 2"); + + assertQueryFails(restrictedSession, "SELECT COUNT(*) FROM default_definer_base", ".*Access Denied.*"); + + assertQuery(restrictedSession, "SELECT COUNT(*) FROM mv_default_definer", "SELECT 2"); + + String showCreate = (String) computeScalar(adminSession, "SHOW CREATE MATERIALIZED VIEW mv_default_definer"); + assertTrue(showCreate.contains("SECURITY DEFINER"), + "SHOW CREATE should include SECURITY DEFINER for MV created with default_view_security_mode=DEFINER"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_default_definer"); + assertUpdate(adminSession, "DROP TABLE default_definer_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testDefaultViewSecurityModeInvoker() + { + Session adminSession = Session.builder(createSessionForUser("admin")) + .setSystemProperty("default_view_security_mode", "INVOKER") + .build(); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE default_invoker_base (id BIGINT, data VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO default_invoker_base VALUES (1, 'data1'), (2, 'data2')", 2); + + try { + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "default_invoker_base", SELECT_COLUMN)); + + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_default_invoker AS " + + "SELECT id, data FROM default_invoker_base"); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_default_invoker", "SELECT 2"); + + assertQueryFails(restrictedSession, "SELECT COUNT(*) FROM mv_default_invoker", + ".*Access Denied.*"); + + String showCreate = (String) computeScalar(adminSession, "SHOW CREATE MATERIALIZED VIEW mv_default_invoker"); + assertTrue(showCreate.contains("SECURITY INVOKER"), + "SHOW CREATE should include SECURITY INVOKER for MV created with 
default_view_security_mode=INVOKER"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_default_invoker"); + assertUpdate(adminSession, "DROP TABLE default_invoker_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testAccessControlOnMaterializedViewObject() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE accessible_base (id BIGINT, value BIGINT)"); + assertUpdate(adminSession, "INSERT INTO accessible_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_no_access " + + "SECURITY DEFINER AS " + + "SELECT id, value FROM accessible_base"); + + try { + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "mv_no_access", SELECT_COLUMN)); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_no_access", "SELECT 2"); + + assertQueryFails(restrictedSession, "SELECT COUNT(*) FROM mv_no_access", + ".*Access Denied.*"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_no_access"); + assertUpdate(adminSession, "DROP TABLE accessible_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testSecurityInvokerWithRowFiltersAlwaysTreatedAsStale() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE row_filter_base (id BIGINT, user_id BIGINT, value VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO row_filter_base VALUES (1, 1, 'user1_data'), (2, 2, 'user2_data'), (3, 1, 'more_user1')", 3); + + try { + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_with_row_filters " + + "SECURITY INVOKER AS " + + "SELECT id, user_id, value FROM row_filter_base"); + + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_with_row_filters", 3); + + // Add row 
filter on the base table for restricted_user to trigger staleness + getQueryRunner().getAccessControl().rowFilter( + QualifiedObjectName.valueOf("memory.default.row_filter_base"), + "restricted_user", + new ViewExpression("restricted_user", Optional.empty(), Optional.empty(), "user_id = 999")); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_with_row_filters", "SELECT 3"); + + // Since the row filter is "user_id = 999" and no data matches, should return 0 + assertQuery(restrictedSession, "SELECT COUNT(*) FROM mv_with_row_filters", "SELECT 0"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_with_row_filters"); + assertUpdate(adminSession, "DROP TABLE row_filter_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testSecurityInvokerWithColumnMasksAlwaysTreatedAsStale() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE column_mask_base (id BIGINT, sensitive_data VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO column_mask_base VALUES (1, 'secret1'), (2, 'secret2'), (3, 'secret3')", 3); + + try { + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_with_column_masks " + + "SECURITY INVOKER AS " + + "SELECT id, sensitive_data FROM column_mask_base"); + + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_with_column_masks", 3); + + getQueryRunner().getAccessControl().columnMask( + QualifiedObjectName.valueOf("memory.default.column_mask_base"), + "sensitive_data", + "restricted_user", + new ViewExpression("restricted_user", Optional.empty(), Optional.empty(), "'MASKED'")); + + assertQuery(adminSession, "SELECT sensitive_data FROM mv_with_column_masks WHERE id = 1", "SELECT 'secret1'"); + + // Uses the view query plan that queries the base table as restricted_user, applying the column mask + assertQuery(restrictedSession, "SELECT sensitive_data FROM 
mv_with_column_masks WHERE id = 1", "SELECT 'MASKED'"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_with_column_masks"); + assertUpdate(adminSession, "DROP TABLE column_mask_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testNestedViewsWithDifferentSecurityModes() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE nested_base (id BIGINT, data VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO nested_base VALUES (1, 'data1'), (2, 'data2')", 2); + + try { + assertUpdate(adminSession, "CREATE VIEW v_inner SECURITY DEFINER AS SELECT * FROM nested_base"); + + assertUpdate(adminSession, "CREATE MATERIALIZED VIEW mv_outer SECURITY INVOKER AS SELECT * FROM v_inner"); + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_outer", 2); + + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "nested_base", SELECT_COLUMN)); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_outer", "SELECT 2"); + + assertQuery(restrictedSession, "SELECT COUNT(*) FROM mv_outer", "SELECT 2"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_outer"); + assertUpdate(adminSession, "DROP VIEW v_inner"); + assertUpdate(adminSession, "DROP TABLE nested_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testConcurrentAccessWithDifferentSecurityContexts() + { + Session adminSession = createSessionForUser("admin"); + Session user1Session = createSessionForUser("user1"); + Session user2Session = createSessionForUser("user2"); + + assertUpdate(adminSession, "CREATE TABLE concurrent_base (id BIGINT, value VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO concurrent_base VALUES (1, 'a'), (2, 'b'), (3, 'c')", 3); + + try { + assertUpdate(adminSession, "CREATE MATERIALIZED VIEW mv_concurrent SECURITY INVOKER AS SELECT * FROM 
concurrent_base"); + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_concurrent", 3); + + getQueryRunner().getAccessControl().deny(privilege("user1", "concurrent_base", SELECT_COLUMN)); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_concurrent", "SELECT 3"); + + assertQueryFails(user1Session, "SELECT COUNT(*) FROM mv_concurrent", + ".*Access Denied.*concurrent_base.*"); + + assertQuery(user2Session, "SELECT COUNT(*) FROM mv_concurrent", "SELECT 3"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_concurrent"); + assertUpdate(adminSession, "DROP TABLE concurrent_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testDefinerModeWithRowFilters() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE row_filter_base (id BIGINT, owner VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO row_filter_base VALUES (1, 'admin'), (2, 'other'), (3, 'admin')", 3); + + try { + assertUpdate(adminSession, "CREATE MATERIALIZED VIEW mv_definer_filter SECURITY DEFINER AS SELECT * FROM row_filter_base"); + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_definer_filter", 3); + + getQueryRunner().getAccessControl().rowFilter( + QualifiedObjectName.valueOf("memory.default.row_filter_base"), + "restricted_user", + new ViewExpression("restricted_user", Optional.empty(), Optional.empty(), "owner = 'restricted_user'")); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_definer_filter", "SELECT 3"); + + assertQuery(restrictedSession, "SELECT COUNT(*) FROM mv_definer_filter", "SELECT 3"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_definer_filter"); + assertUpdate(adminSession, "DROP TABLE row_filter_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testInvokerModeWithRowFilters() + { + Session adminSession = 
createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE row_filter_invoker_base (id BIGINT, data VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO row_filter_invoker_base VALUES (1, 'visible'), (2, 'hidden'), (3, 'visible')", 3); + + try { + assertUpdate(adminSession, "CREATE MATERIALIZED VIEW mv_invoker_filter SECURITY INVOKER AS SELECT * FROM row_filter_invoker_base"); + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_invoker_filter", 3); + + getQueryRunner().getAccessControl().rowFilter( + QualifiedObjectName.valueOf("memory.default.row_filter_invoker_base"), + "restricted_user", + new ViewExpression("restricted_user", Optional.empty(), Optional.empty(), "data = 'visible'")); + + assertQuery(adminSession, "SELECT COUNT(*) FROM mv_invoker_filter", "SELECT 3"); + + assertQuery(restrictedSession, "SELECT COUNT(*) FROM mv_invoker_filter", "SELECT 2"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_invoker_filter"); + assertUpdate(adminSession, "DROP TABLE row_filter_invoker_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testSecurityDefinerWithDataConsistencyDisabled() + { + Session adminSession = Session.builder(createSessionForUser("admin")) + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + Session restrictedSession = Session.builder(createSessionForUser("restricted_user")) + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + + assertUpdate(adminSession, "CREATE TABLE bypass_test_base (id BIGINT, secret VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO bypass_test_base VALUES (1, 'confidential'), (2, 'classified')", 2); + + try { + // Create a SECURITY INVOKER materialized view + assertUpdate(restrictedSession, + "CREATE MATERIALIZED VIEW mv_bypass_test " + + "SECURITY DEFINER AS " + + "SELECT id, secret FROM 
bypass_test_base"); + + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_bypass_test", 2); + + // Deny restricted_user's ability to delegate access through views + // (ViewAccessControl checks CREATE_VIEW_WITH_SELECT_COLUMNS for DEFINER mode) + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "bypass_test_base", CREATE_VIEW_WITH_SELECT_COLUMNS)); + + // restricted_user (owner) can still query their own MV (uses regular access control, only checks SELECT) + assertQuery(restrictedSession, "SELECT COUNT(*) FROM mv_bypass_test", "SELECT 2"); + + // And restricted_user (owner) can still read directly from the base table; only delegation is restricted + assertQuery(restrictedSession, "SELECT COUNT(*) FROM bypass_test_base", "SELECT 2"); + + // But admin cannot access the MV because restricted_user lacks CREATE_VIEW_WITH_SELECT_COLUMNS + assertQueryFails(adminSession, "SELECT COUNT(*) FROM mv_bypass_test", + ".*View owner 'restricted_user' cannot create view that selects from.*bypass_test_base.*"); + + // Reset access control and show that admin can query again + getQueryRunner().getAccessControl().reset(); + assertQuerySucceeds(adminSession, "SELECT COUNT(*) FROM mv_bypass_test"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_bypass_test"); + assertUpdate(adminSession, "DROP TABLE bypass_test_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testSecurityDefinerValidatesDefinerViewPermissions() + { + // Test that SECURITY DEFINER mode checks the definer's CREATE_VIEW_WITH_SELECT_COLUMNS + // permission at query time (for non-owner queries), matching regular view behavior. 
+ + Session aliceSession = createSessionForUser("alice"); + Session bobSession = createSessionForUser("bob"); + + assertUpdate("CREATE TABLE sensitive_data (id BIGINT, secret VARCHAR)"); + assertUpdate("INSERT INTO sensitive_data VALUES (1, 'confidential'), (2, 'classified')", 2); + + try { + // Alice creates MV with SECURITY DEFINER + assertUpdate(aliceSession, + "CREATE MATERIALIZED VIEW alice_mv SECURITY DEFINER AS SELECT * FROM sensitive_data"); + assertUpdate("REFRESH MATERIALIZED VIEW alice_mv", 2); + + // Verify Alice and Bob can query it + assertQuery(aliceSession, "SELECT COUNT(*) FROM alice_mv", "SELECT 2"); + assertQuery(bobSession, "SELECT COUNT(*) FROM alice_mv", "SELECT 2"); + + // Revoke Alice's CREATE_VIEW_WITH_SELECT_COLUMNS permission on base table + // (this is what ViewAccessControl checks for non-owner queries) + getQueryRunner().getAccessControl().deny(privilege("alice", "sensitive_data", CREATE_VIEW_WITH_SELECT_COLUMNS)); + + // Alice (owner) can still query her own MV (uses regular access control) + assertQuery(aliceSession, "SELECT COUNT(*) FROM alice_mv", "SELECT 2"); + + // Bob should NOT be able to query it (definer lacks CREATE_VIEW_WITH_SELECT_COLUMNS) + assertQueryFails(bobSession, "SELECT COUNT(*) FROM alice_mv", + ".*View owner 'alice' cannot create view that selects from.*sensitive_data.*"); + + assertUpdate("DROP MATERIALIZED VIEW alice_mv"); + assertUpdate("DROP TABLE sensitive_data"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testShowCreateMaterializedViewAccessDenied() + { + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE show_create_base (id BIGINT, value VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO show_create_base VALUES (1, 'test')", 1); + + try { + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_show_create_test SECURITY DEFINER AS 
" + + "SELECT id, value FROM show_create_base"); + + // Admin can show create + String showCreate = (String) computeScalar(adminSession, "SHOW CREATE MATERIALIZED VIEW mv_show_create_test"); + assertTrue(showCreate.contains("mv_show_create_test")); + + // Deny SHOW_CREATE_TABLE for restricted user on the MV + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "mv_show_create_test", SHOW_CREATE_TABLE)); + + // Restricted user should be denied + assertQueryFails(restrictedSession, + "SHOW CREATE MATERIALIZED VIEW mv_show_create_test", + ".*Cannot show create table.*mv_show_create_test.*"); + } + finally { + assertUpdate(adminSession, "DROP MATERIALIZED VIEW IF EXISTS mv_show_create_test"); + assertUpdate(adminSession, "DROP TABLE IF EXISTS show_create_base"); + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testDefinerMvPreventsPrivilegeEscalation() + { + Session aliceSession = createSessionForUser("alice"); + Session bobSession = createSessionForUser("bob"); + + assertUpdate("CREATE TABLE escalation_test_base (id BIGINT, secret VARCHAR)"); + assertUpdate("INSERT INTO escalation_test_base VALUES (1, 'confidential'), (2, 'classified')", 2); + + try { + // Alice has SELECT but NOT CREATE_VIEW_WITH_SELECT_COLUMNS on the base table + // (This simulates Alice having read access but not permission to delegate access) + getQueryRunner().getAccessControl().deny(privilege("alice", "escalation_test_base", CREATE_VIEW_WITH_SELECT_COLUMNS)); + + // Bob has no access to the base table + getQueryRunner().getAccessControl().deny(privilege("bob", "escalation_test_base", SELECT_COLUMN)); + + // Alice creates a DEFINER MV - creation succeeds (permissions checked at query time) + assertUpdate(aliceSession, + "CREATE MATERIALIZED VIEW mv_escalation_test SECURITY DEFINER AS " + + "SELECT id, secret FROM escalation_test_base"); + + // Alice refreshes the MV (this should work as she has SELECT) + assertUpdate(aliceSession, "REFRESH 
MATERIALIZED VIEW mv_escalation_test", 2); + + // Bob should NOT be able to access data through Alice's DEFINER MV + // because Alice lacks CREATE_VIEW_WITH_SELECT_COLUMNS (the privilege to delegate access). + assertQueryFails(bobSession, "SELECT * FROM mv_escalation_test", + ".*View owner 'alice' cannot create view that selects from.*escalation_test_base.*"); + + // Alice (the owner) querying her own MV should succeed. + assertQuery(aliceSession, "SELECT COUNT(*) FROM mv_escalation_test", "SELECT 2"); + + assertUpdate(aliceSession, "DROP MATERIALIZED VIEW mv_escalation_test"); + assertUpdate("DROP TABLE escalation_test_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testSessionUserDoesNotNeedSelectOnBaseTableForDefinerMv() + { + // Test that session user doesn't need SELECT or CREATE_VIEW_WITH_SELECT_COLUMNS + // on underlying table for SECURITY DEFINER MVs + // This mirrors AbstractTestDistributedQueries.testViewAccessControl() lines 1297-1302 + Session mvOwnerSession = createSessionForUser("mv_owner"); + Session queryingSession = createSessionForUser("querying_user"); + + assertUpdate("CREATE TABLE session_user_base (id BIGINT, secret VARCHAR)"); + assertUpdate("INSERT INTO session_user_base VALUES (1, 'secret1'), (2, 'secret2')", 2); + + try { + // Create SECURITY DEFINER MV + assertUpdate(mvOwnerSession, + "CREATE MATERIALIZED VIEW mv_session_user SECURITY DEFINER AS " + + "SELECT id, secret FROM session_user_base"); + + assertUpdate(mvOwnerSession, "REFRESH MATERIALIZED VIEW mv_session_user", 2); + + // Deny SELECT and CREATE_VIEW_WITH_SELECT_COLUMNS for the querying user on base table + getQueryRunner().getAccessControl().deny(privilege("querying_user", "session_user_base", SELECT_COLUMN)); + getQueryRunner().getAccessControl().deny(privilege("querying_user", "session_user_base", CREATE_VIEW_WITH_SELECT_COLUMNS)); + + // Querying user should still be able to query the DEFINER MV + 
assertQuery(queryingSession, "SELECT COUNT(*) FROM mv_session_user", "SELECT 2"); + assertQuery(queryingSession, "SELECT id, secret FROM mv_session_user ORDER BY id", + "VALUES (1, 'secret1'), (2, 'secret2')"); + + assertUpdate(mvOwnerSession, "DROP MATERIALIZED VIEW mv_session_user"); + assertUpdate("DROP TABLE session_user_base"); + } + finally { + getQueryRunner().getAccessControl().reset(); + } + } + + @Test + public void testColumnLevelAccessControlWithSecurityInvoker() + { + // Test column-level access control with SECURITY INVOKER + // Invoker should need access to specific columns being queried + Session adminSession = createSessionForUser("admin"); + Session restrictedSession = createSessionForUser("restricted_user"); + + assertUpdate(adminSession, "CREATE TABLE column_level_base (id BIGINT, public_col VARCHAR, secret_col VARCHAR)"); + assertUpdate(adminSession, "INSERT INTO column_level_base VALUES (1, 'public1', 'secret1'), (2, 'public2', 'secret2')", 2); + + try { + // Create INVOKER MV selecting all columns + assertUpdate(adminSession, + "CREATE MATERIALIZED VIEW mv_column_level SECURITY INVOKER AS " + + "SELECT id, public_col, secret_col FROM column_level_base"); + assertUpdate(adminSession, "REFRESH MATERIALIZED VIEW mv_column_level", 2); + + // Admin can query all columns + assertQuery(adminSession, "SELECT id, public_col, secret_col FROM mv_column_level WHERE id = 1", + "VALUES (1, 'public1', 'secret1')"); + + // Deny restricted_user access to column_level_base entirely + getQueryRunner().getAccessControl().deny(privilege("restricted_user", "column_level_base", SELECT_COLUMN)); + + // Restricted user should be denied access to the INVOKER MV + assertQueryFails(restrictedSession, "SELECT id FROM mv_column_level", + ".*Access Denied.*column_level_base.*"); + + assertUpdate(adminSession, "DROP MATERIALIZED VIEW mv_column_level"); + assertUpdate(adminSession, "DROP TABLE column_level_base"); + } + finally { + 
getQueryRunner().getAccessControl().reset(); + } + } + + private Session createSessionForUser(String user) + { + return Session.builder(getSession()) + .setIdentity(new Identity(user, Optional.empty())) + .build(); + } +} diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMaterializedViewPlanner.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMaterializedViewPlanner.java new file mode 100644 index 0000000000000..080645b6310ea --- /dev/null +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMaterializedViewPlanner.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.memory; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +/** + * Plan-level tests for materialized views in the Memory connector. + * Tests verify plan structure with legacy_materialized_views=false. 
+ */ +@Test(singleThreaded = true) +public class TestMemoryMaterializedViewPlanner + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder() + .setCatalog("memory") + .setSchema("default") + .setSystemProperty("legacy_materialized_views", "false") + .build(); + + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session) + .setNodeCount(4) + .setExtraProperties(ImmutableMap.of("experimental.allow-legacy-materialized-views-toggle", "true")) + .build(); + + queryRunner.installPlugin(new MemoryPlugin()); + queryRunner.createCatalog("memory", "memory", ImmutableMap.of()); + + return queryRunner; + } + + @Test + public void testMaterializedViewNotRefreshed() + { + assertUpdate("CREATE TABLE base_table (id BIGINT, name VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO base_table VALUES (1, 'Alice', 100), (2, 'Bob', 200)", 2); + assertUpdate("CREATE MATERIALIZED VIEW simple_mv AS SELECT id, name, value FROM base_table"); + + assertPlan(getSession(), "SELECT * FROM simple_mv", + anyTree(tableScan("base_table"))); + + assertUpdate("DROP MATERIALIZED VIEW simple_mv"); + assertUpdate("DROP TABLE base_table"); + } + + @Test + public void testMaterializedViewRefreshed() + { + assertUpdate("CREATE TABLE base_table (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO base_table VALUES (1, 100), (2, 200)", 2); + assertUpdate("CREATE MATERIALIZED VIEW mv AS SELECT id, value FROM base_table"); + assertUpdate("REFRESH MATERIALIZED VIEW mv", 2); + + assertPlan("SELECT * FROM mv", + anyTree(tableScan("mv"))); + + assertUpdate("DROP MATERIALIZED VIEW mv"); + assertUpdate("DROP TABLE base_table"); + } + + @Test + public void testQueryDroppedMaterializedView() + { + assertUpdate("CREATE TABLE base_table (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO base_table VALUES (1, 100), (2, 200)", 2); + assertUpdate("CREATE MATERIALIZED VIEW dropped_mv AS SELECT 
id, value FROM base_table"); + + assertUpdate("DROP MATERIALIZED VIEW dropped_mv"); + + assertQueryFails("SELECT * FROM dropped_mv", ".*Table memory\\.default\\.dropped_mv does not exist.*"); + + assertUpdate("DROP TABLE base_table"); + } +} diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMaterializedViews.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMaterializedViews.java new file mode 100644 index 0000000000000..11f072f765b8a --- /dev/null +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMaterializedViews.java @@ -0,0 +1,640 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.plugin.memory; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestMemoryMaterializedViews + extends AbstractTestQueryFramework +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + Session session = testSessionBuilder() + .setCatalog("memory") + .setSchema("default") + .setSystemProperty("legacy_materialized_views", "false") + .build(); + + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session) + .setNodeCount(4) + .setExtraProperties(ImmutableMap.of("experimental.allow-legacy-materialized-views-toggle", "true")) + .build(); + + queryRunner.installPlugin(new MemoryPlugin()); + queryRunner.createCatalog("memory", "memory", ImmutableMap.of()); + + return queryRunner; + } + + @Test + public void testCreateMaterializedView() + { + assertUpdate("CREATE TABLE base_table (id BIGINT, name VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO base_table VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW mv_simple AS SELECT id, name, value FROM base_table"); + + assertQuery("SELECT COUNT(*) FROM mv_simple", "SELECT 3"); + assertQuery("SELECT * FROM mv_simple ORDER BY id", + "VALUES (1, 'Alice', 100), (2, 'Bob', 200), (3, 'Charlie', 300)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_simple"); + assertUpdate("DROP TABLE base_table"); + } + + @Test + public void testCreateMaterializedViewDuplicateName() + { + assertUpdate("CREATE TABLE dup_base (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO dup_base VALUES (1, 
'test')", 1); + + assertUpdate("CREATE MATERIALIZED VIEW mv_dup AS SELECT id, value FROM dup_base"); + + assertQueryFails("CREATE MATERIALIZED VIEW mv_dup AS SELECT id FROM dup_base", + ".*Materialized view .* already exists.*"); + + assertUpdate("DROP MATERIALIZED VIEW mv_dup"); + assertUpdate("DROP TABLE dup_base"); + } + + @Test + public void testCreateMaterializedViewWithFilter() + { + assertUpdate("CREATE TABLE filtered_base (id BIGINT, status VARCHAR, amount BIGINT)"); + assertUpdate("INSERT INTO filtered_base VALUES (1, 'active', 100), (2, 'inactive', 200), (3, 'active', 300)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW mv_filtered AS SELECT id, amount FROM filtered_base WHERE status = 'active'"); + + assertQuery("SELECT COUNT(*) FROM mv_filtered", "SELECT 2"); + assertQuery("SELECT * FROM mv_filtered ORDER BY id", + "VALUES (1, 100), (3, 300)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_filtered"); + assertUpdate("DROP TABLE filtered_base"); + } + + @Test + public void testCreateMaterializedViewWithComplexFilter() + { + assertUpdate("CREATE TABLE complex_filter_base (id BIGINT, status VARCHAR, amount BIGINT, priority INTEGER)"); + assertUpdate("INSERT INTO complex_filter_base VALUES (1, 'active', 100, 1), (2, 'inactive', 200, 2), (3, 'active', 50, 3), (4, 'active', 150, 1)", 4); + + assertUpdate("CREATE MATERIALIZED VIEW mv_complex_filter AS " + + "SELECT id, amount, priority FROM complex_filter_base " + + "WHERE status = 'active' AND amount > 75 AND priority = 1"); + + assertQuery("SELECT COUNT(*) FROM mv_complex_filter", "SELECT 2"); + assertQuery("SELECT * FROM mv_complex_filter ORDER BY id", + "VALUES (1, 100, 1), (4, 150, 1)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_complex_filter"); + assertUpdate("DROP TABLE complex_filter_base"); + } + + @Test + public void testCreateMaterializedViewWithAggregation() + { + assertUpdate("CREATE TABLE sales (product_id BIGINT, category VARCHAR, revenue BIGINT)"); + assertUpdate("INSERT INTO sales VALUES (1, 
'Electronics', 1000), (2, 'Electronics', 1500), (3, 'Books', 500), (4, 'Books', 300)", 4); + + assertUpdate("CREATE MATERIALIZED VIEW mv_category_sales AS " + + "SELECT category, COUNT(*) as product_count, SUM(revenue) as total_revenue " + + "FROM sales GROUP BY category"); + + assertQuery("SELECT COUNT(*) FROM mv_category_sales", "SELECT 2"); + assertQuery("SELECT * FROM mv_category_sales ORDER BY category", + "VALUES ('Books', 2, 800), ('Electronics', 2, 2500)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_category_sales"); + assertUpdate("DROP TABLE sales"); + } + + @Test + public void testCreateMaterializedViewWithComputedColumns() + { + assertUpdate("CREATE TABLE transactions (trans_id BIGINT, amount BIGINT, tax_rate DOUBLE)"); + assertUpdate("INSERT INTO transactions VALUES (1, 100, 0.08), (2, 200, 0.08), (3, 150, 0.10)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW mv_computed AS " + + "SELECT trans_id, amount, tax_rate, " + + "CAST(amount * tax_rate AS BIGINT) as tax_amount, " + + "CAST(amount * (1 + tax_rate) AS BIGINT) as total_amount " + + "FROM transactions"); + + assertQuery("SELECT COUNT(*) FROM mv_computed", "SELECT 3"); + assertQuery("SELECT trans_id, amount, tax_amount, total_amount FROM mv_computed ORDER BY trans_id", + "VALUES (1, 100, 8, 108), (2, 200, 16, 216), (3, 150, 15, 165)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_computed"); + assertUpdate("DROP TABLE transactions"); + } + + @Test + public void testCreateMaterializedViewWithJoin() + { + assertUpdate("CREATE TABLE customer_orders (order_id BIGINT, customer_id BIGINT, amount BIGINT)"); + assertUpdate("CREATE TABLE customers (customer_id BIGINT, customer_name VARCHAR)"); + + assertUpdate("INSERT INTO customer_orders VALUES (1, 100, 50), (2, 200, 75), (3, 100, 25)", 3); + assertUpdate("INSERT INTO customers VALUES (100, 'Alice'), (200, 'Bob')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_customer_orders AS " + + "SELECT o.order_id, c.customer_name, o.amount " + + "FROM 
customer_orders o JOIN customers c ON o.customer_id = c.customer_id"); + + assertQuery("SELECT COUNT(*) FROM mv_customer_orders", "SELECT 3"); + assertQuery("SELECT * FROM mv_customer_orders ORDER BY order_id", + "VALUES (1, 'Alice', 50), (2, 'Bob', 75), (3, 'Alice', 25)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_customer_orders"); + assertUpdate("DROP TABLE customers"); + assertUpdate("DROP TABLE customer_orders"); + } + + @Test + public void testRefreshMaterializedView() + { + assertUpdate("CREATE TABLE refresh_base (id BIGINT, value BIGINT)"); + assertUpdate("INSERT INTO refresh_base VALUES (1, 100), (2, 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_refresh AS SELECT id, value FROM refresh_base"); + + assertQuery("SELECT COUNT(*) FROM mv_refresh", "SELECT 2"); + assertQuery("SELECT * FROM mv_refresh ORDER BY id", "VALUES (1, 100), (2, 200)"); + + assertUpdate("INSERT INTO refresh_base VALUES (3, 300)", 1); + + assertQuery("SELECT COUNT(*) FROM mv_refresh", "SELECT 3"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv_refresh", 3); + + assertQuery("SELECT COUNT(*) FROM mv_refresh", "SELECT 3"); + assertQuery("SELECT * FROM mv_refresh ORDER BY id", + "VALUES (1, 100), (2, 200), (3, 300)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_refresh"); + assertUpdate("DROP TABLE refresh_base"); + } + + @Test + public void testRefreshMaterializedViewWithAggregation() + { + assertUpdate("CREATE TABLE agg_refresh_base (category VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO agg_refresh_base VALUES ('A', 10), ('B', 20), ('A', 15)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW mv_agg_refresh AS " + + "SELECT category, SUM(value) as total FROM agg_refresh_base GROUP BY category"); + + assertQuery("SELECT * FROM mv_agg_refresh ORDER BY category", + "VALUES ('A', 25), ('B', 20)"); + + assertUpdate("INSERT INTO agg_refresh_base VALUES ('A', 5), ('C', 30)", 2); + + assertUpdate("REFRESH MATERIALIZED VIEW mv_agg_refresh", 3); + + assertQuery("SELECT * FROM 
mv_agg_refresh ORDER BY category", + "VALUES ('A', 30), ('B', 20), ('C', 30)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_agg_refresh"); + assertUpdate("DROP TABLE agg_refresh_base"); + } + + @Test + public void testRefreshNonExistentMaterializedView() + { + assertQueryFails("REFRESH MATERIALIZED VIEW mv_nonexistent", + ".*Materialized view .* does not exist.*"); + } + + @Test + public void testDropMaterializedView() + { + assertUpdate("CREATE TABLE drop_base (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO drop_base VALUES (1, 'test')", 1); + + assertUpdate("CREATE MATERIALIZED VIEW mv_drop AS SELECT id, value FROM drop_base"); + + assertQuery("SELECT COUNT(*) FROM mv_drop", "SELECT 1"); + + assertUpdate("DROP MATERIALIZED VIEW mv_drop"); + + assertQuery("SELECT COUNT(*) FROM drop_base", "SELECT 1"); + + assertUpdate("DROP TABLE drop_base"); + } + + @Test + public void testDropNonExistentMaterializedView() + { + assertQueryFails("DROP MATERIALIZED VIEW mv_nonexistent", + ".*Materialized view .* does not exist.*"); + } + + @Test + public void testCreateMaterializedViewWithEmptyBaseTable() + { + assertUpdate("CREATE TABLE empty_base (id BIGINT, value VARCHAR)"); + + assertUpdate("CREATE MATERIALIZED VIEW mv_empty AS SELECT id, value FROM empty_base"); + + assertQuery("SELECT COUNT(*) FROM mv_empty", "SELECT 0"); + + assertUpdate("DROP MATERIALIZED VIEW mv_empty"); + assertUpdate("DROP TABLE empty_base"); + } + + @Test + public void testMultipleMaterializedViews() + { + assertUpdate("CREATE TABLE multi_base (id BIGINT, category VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO multi_base VALUES (1, 'A', 100), (2, 'B', 200), (3, 'A', 150)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW mv_multi_1 AS SELECT id, value FROM multi_base WHERE category = 'A'"); + assertUpdate("CREATE MATERIALIZED VIEW mv_multi_2 AS SELECT category, SUM(value) as total FROM multi_base GROUP BY category"); + + assertQuery("SELECT COUNT(*) FROM mv_multi_1", "SELECT 2"); + 
assertQuery("SELECT * FROM mv_multi_1 ORDER BY id", "VALUES (1, 100), (3, 150)"); + + assertQuery("SELECT COUNT(*) FROM mv_multi_2", "SELECT 2"); + assertQuery("SELECT * FROM mv_multi_2 ORDER BY category", + "VALUES ('A', 250), ('B', 200)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_multi_1"); + assertUpdate("DROP MATERIALIZED VIEW mv_multi_2"); + assertUpdate("DROP TABLE multi_base"); + } + + @Test + public void testCreateMaterializedViewWithMultiTableJoin() + { + assertUpdate("CREATE TABLE orders (order_id BIGINT, customer_id BIGINT, product_id BIGINT, quantity BIGINT)"); + assertUpdate("CREATE TABLE customers (customer_id BIGINT, customer_name VARCHAR, region VARCHAR)"); + assertUpdate("CREATE TABLE products (product_id BIGINT, product_name VARCHAR, unit_price BIGINT)"); + + assertUpdate("INSERT INTO orders VALUES (1, 100, 1, 2), (2, 200, 2, 1), (3, 100, 2, 3)", 3); + assertUpdate("INSERT INTO customers VALUES (100, 'Alice', 'East'), (200, 'Bob', 'West')", 2); + assertUpdate("INSERT INTO products VALUES (1, 'Widget', 50), (2, 'Gadget', 75)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_order_details AS " + + "SELECT o.order_id, c.customer_name, c.region, p.product_name, o.quantity, " + + "CAST(p.unit_price * o.quantity AS BIGINT) as total_price " + + "FROM orders o " + + "JOIN customers c ON o.customer_id = c.customer_id " + + "JOIN products p ON o.product_id = p.product_id"); + + assertQuery("SELECT COUNT(*) FROM mv_order_details", "SELECT 3"); + assertQuery("SELECT order_id, customer_name, product_name, total_price FROM mv_order_details ORDER BY order_id", + "VALUES (1, 'Alice', 'Widget', 100), (2, 'Bob', 'Gadget', 75), (3, 'Alice', 'Gadget', 225)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_order_details"); + assertUpdate("DROP TABLE products"); + assertUpdate("DROP TABLE customers"); + assertUpdate("DROP TABLE orders"); + } + + @Test + public void testRefreshMaterializedViewAfterBaseTableDropped() + { + assertUpdate("CREATE TABLE temp_base (id 
BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO temp_base VALUES (1, 'test'), (2, 'data')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_temp AS SELECT id, value FROM temp_base"); + + assertQuery("SELECT COUNT(*) FROM mv_temp", "SELECT 2"); + + assertUpdate("DROP TABLE temp_base"); + + assertQueryFails("REFRESH MATERIALIZED VIEW mv_temp", + ".*Table .* does not exist.*"); + + assertUpdate("DROP MATERIALIZED VIEW mv_temp"); + } + + @Test + public void testMaterializedViewBecomesUnqueryableAfterBaseTableDropped() + { + assertUpdate("CREATE TABLE persist_base (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO persist_base VALUES (1, 'test'), (2, 'data')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_persist AS SELECT id, value FROM persist_base"); + + assertQuery("SELECT COUNT(*) FROM mv_persist", "SELECT 2"); + assertQuery("SELECT * FROM mv_persist ORDER BY id", "VALUES (1, 'test'), (2, 'data')"); + + assertUpdate("INSERT INTO persist_base VALUES (3, 'more')", 1); + assertUpdate("REFRESH MATERIALIZED VIEW mv_persist", 3); + + assertQuery("SELECT COUNT(*) FROM mv_persist", "SELECT 3"); + assertQuery("SELECT * FROM mv_persist ORDER BY id", "VALUES (1, 'test'), (2, 'data'), (3, 'more')"); + + assertUpdate("DROP TABLE persist_base"); + + assertQueryFails("SELECT COUNT(*) FROM mv_persist", + ".*Table .* does not exist.*"); + + assertQueryFails("REFRESH MATERIALIZED VIEW mv_persist", + ".*Table .* does not exist.*"); + + assertUpdate("DROP MATERIALIZED VIEW mv_persist"); + } + + @Test + public void testMaterializedViewStalenessDetection() + { + assertUpdate("CREATE TABLE base (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO base VALUES (1, 'first')", 1); + + assertUpdate("CREATE MATERIALIZED VIEW mv AS SELECT id, value FROM base"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv", 1); + assertQuery("SELECT * FROM mv", "VALUES (1, 'first')"); + + assertUpdate("INSERT INTO base VALUES (2, 'second')", 1); + assertQuery("SELECT COUNT(*) FROM mv", 
"SELECT 2"); + assertUpdate("REFRESH MATERIALIZED VIEW mv", 2); + assertQuery("SELECT COUNT(*) FROM mv", "SELECT 2"); + + assertUpdate("INSERT INTO base VALUES (3, 'third')", 1); + assertQuery("SELECT COUNT(*) FROM mv", "SELECT 3"); + assertUpdate("REFRESH MATERIALIZED VIEW mv", 3); + assertQuery("SELECT COUNT(*) FROM mv", "SELECT 3"); + + assertUpdate("DROP MATERIALIZED VIEW mv"); + assertUpdate("DROP TABLE base"); + } + + @Test + public void testMaterializedViewWithMultipleBaseTables() + { + assertUpdate("CREATE TABLE orders (order_id BIGINT, customer_id BIGINT)"); + assertUpdate("CREATE TABLE customers (customer_id BIGINT, name VARCHAR)"); + + assertUpdate("INSERT INTO orders VALUES (1, 100)", 1); + assertUpdate("INSERT INTO customers VALUES (100, 'Alice')", 1); + + assertUpdate("CREATE MATERIALIZED VIEW mv_join AS " + + "SELECT o.order_id, c.name FROM orders o JOIN customers c ON o.customer_id = c.customer_id"); + assertQuery("SELECT * FROM mv_join", "VALUES (1, 'Alice')"); + assertUpdate("REFRESH MATERIALIZED VIEW mv_join", 1); + assertQuery("SELECT * FROM mv_join", "VALUES (1, 'Alice')"); + + assertUpdate("INSERT INTO orders VALUES (2, 100)", 1); + assertQuery("SELECT COUNT(*) FROM mv_join", "SELECT 2"); + assertUpdate("REFRESH MATERIALIZED VIEW mv_join", 2); + assertQuery("SELECT COUNT(*) FROM mv_join", "SELECT 2"); + + assertUpdate("INSERT INTO customers VALUES (200, 'Bob')", 1); + assertUpdate("INSERT INTO orders VALUES (3, 200)", 1); + assertQuery("SELECT COUNT(*) FROM mv_join", "SELECT 3"); + assertUpdate("REFRESH MATERIALIZED VIEW mv_join", 3); + assertQuery("SELECT COUNT(*) FROM mv_join", "SELECT 3"); + + assertUpdate("DROP MATERIALIZED VIEW mv_join"); + assertUpdate("DROP TABLE customers"); + assertUpdate("DROP TABLE orders"); + } + + @Test + public void testMultipleMaterializedViewsIndependentTracking() + { + assertUpdate("CREATE TABLE shared (id BIGINT, category VARCHAR)"); + assertUpdate("INSERT INTO shared VALUES (1, 'A'), (2, 'B')", 2); + + 
assertUpdate("CREATE MATERIALIZED VIEW mv1 AS SELECT * FROM shared WHERE category = 'A'"); + assertUpdate("CREATE MATERIALIZED VIEW mv2 AS SELECT * FROM shared WHERE category = 'B'"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv1", 1); + assertUpdate("REFRESH MATERIALIZED VIEW mv2", 1); + + assertUpdate("INSERT INTO shared VALUES (3, 'A'), (4, 'B')", 2); + + assertQuery("SELECT COUNT(*) FROM mv1", "SELECT 2"); + assertUpdate("REFRESH MATERIALIZED VIEW mv1", 2); + assertQuery("SELECT COUNT(*) FROM mv1", "SELECT 2"); + + assertQuery("SELECT COUNT(*) FROM mv2", "SELECT 2"); + assertUpdate("REFRESH MATERIALIZED VIEW mv2", 2); + assertQuery("SELECT COUNT(*) FROM mv2", "SELECT 2"); + + assertUpdate("DROP MATERIALIZED VIEW mv1"); + assertUpdate("DROP MATERIALIZED VIEW mv2"); + assertUpdate("DROP TABLE shared"); + } + + @Test + public void testMaterializedViewWithDataConsistencyDisabled() + { + assertUpdate("CREATE TABLE consistency_test (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO consistency_test VALUES (1, 'initial'), (2, 'data')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_consistency AS SELECT id, value FROM consistency_test"); + + Session session = Session.builder(getSession()) + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + + assertQuery(session, "SELECT COUNT(*) FROM mv_consistency", "SELECT 2"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv_consistency", 2); + + assertQuery(session, "SELECT COUNT(*) FROM mv_consistency", "SELECT 2"); + assertQuery(session, "SELECT * FROM mv_consistency ORDER BY id", + "VALUES (1, 'initial'), (2, 'data')"); + + assertUpdate("INSERT INTO consistency_test VALUES (3, 'new')", 1); + + // Still reads fresh data from base tables + assertQuery(session, "SELECT COUNT(*) FROM mv_consistency", "SELECT 3"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv_consistency", 3); + assertQuery(session, "SELECT COUNT(*) FROM mv_consistency", "SELECT 3"); + + assertUpdate("DROP 
MATERIALIZED VIEW mv_consistency"); + assertUpdate("DROP TABLE consistency_test"); + } + + @Test + public void testMaterializedViewStalenessWithDataConsistencyDisabled() + { + assertUpdate("CREATE TABLE stale_base (id BIGINT, category VARCHAR, amount BIGINT)"); + assertUpdate("INSERT INTO stale_base VALUES (1, 'A', 100), (2, 'B', 200)", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_stale AS " + + "SELECT category, SUM(amount) as total FROM stale_base GROUP BY category"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv_stale", 2); + + Session sessionWithConsistencyDisabled = Session.builder(getSession()) + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + + assertQuery(sessionWithConsistencyDisabled, "SELECT * FROM mv_stale ORDER BY category", + "VALUES ('A', 100), ('B', 200)"); + + assertUpdate("INSERT INTO stale_base VALUES (3, 'A', 50), (4, 'C', 150)", 2); + + assertQuery(sessionWithConsistencyDisabled, "SELECT * FROM mv_stale ORDER BY category", + "VALUES ('A', 150), ('B', 200), ('C', 150)"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv_stale", 3); + assertQuery(sessionWithConsistencyDisabled, "SELECT * FROM mv_stale ORDER BY category", + "VALUES ('A', 150), ('B', 200), ('C', 150)"); + + assertUpdate("DROP MATERIALIZED VIEW mv_stale"); + assertUpdate("DROP TABLE stale_base"); + } + + @Test + public void testMaterializedViewBecomesUnqueryableAfterBaseTableRenamed() + { + assertUpdate("CREATE TABLE rename_base (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO rename_base VALUES (1, 'test'), (2, 'data')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_rename AS SELECT id, value FROM rename_base"); + + assertQuery("SELECT COUNT(*) FROM mv_rename", "SELECT 2"); + assertQuery("SELECT * FROM mv_rename ORDER BY id", "VALUES (1, 'test'), (2, 'data')"); + + assertUpdate("INSERT INTO rename_base VALUES (3, 'more')", 1); + assertUpdate("REFRESH MATERIALIZED VIEW mv_rename", 3); + + assertQuery("SELECT COUNT(*) FROM 
mv_rename", "SELECT 3"); + assertQuery("SELECT * FROM mv_rename ORDER BY id", "VALUES (1, 'test'), (2, 'data'), (3, 'more')"); + + assertUpdate("ALTER TABLE rename_base RENAME TO rename_base_new"); + + assertQueryFails("SELECT COUNT(*) FROM mv_rename", + ".*Table .* does not exist.*"); + + assertQueryFails("REFRESH MATERIALIZED VIEW mv_rename", + ".*Table .* does not exist.*"); + + assertUpdate("DROP MATERIALIZED VIEW mv_rename"); + assertUpdate("DROP TABLE rename_base_new"); + } + + @Test + public void testMaterializedViewWithDataConsistencyDisabledAfterBaseTableDropped() + { + assertUpdate("CREATE TABLE drop_consistency_test (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO drop_consistency_test VALUES (1, 'initial'), (2, 'data')", 2); + + assertUpdate("CREATE MATERIALIZED VIEW mv_drop_consistency AS SELECT id, value FROM drop_consistency_test"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv_drop_consistency", 2); + + Session session = Session.builder(getSession()) + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + + assertQuery(session, "SELECT COUNT(*) FROM mv_drop_consistency", "SELECT 2"); + assertQuery(session, "SELECT * FROM mv_drop_consistency ORDER BY id", + "VALUES (1, 'initial'), (2, 'data')"); + + assertUpdate("DROP TABLE drop_consistency_test"); + + assertQueryFails(session, "SELECT COUNT(*) FROM mv_drop_consistency", + ".*Table .* does not exist.*"); + assertQueryFails(session, "SELECT * FROM mv_drop_consistency ORDER BY id", + ".*Table .* does not exist.*"); + + assertQueryFails("REFRESH MATERIALIZED VIEW mv_drop_consistency", + ".*Table .* does not exist.*"); + + assertUpdate("DROP MATERIALIZED VIEW mv_drop_consistency"); + } + + @Test + public void testMaterializedViewWithDataConsistencyDisabledAfterBaseTableRenamed() + { + assertUpdate("CREATE TABLE rename_consistency_test (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO rename_consistency_test VALUES (1, 'initial'), (2, 'data')", 2); + + 
assertUpdate("CREATE MATERIALIZED VIEW mv_rename_consistency AS SELECT id, value FROM rename_consistency_test"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv_rename_consistency", 2); + + Session session = Session.builder(getSession()) + .setSystemProperty("materialized_view_data_consistency_enabled", "false") + .build(); + + assertQuery(session, "SELECT COUNT(*) FROM mv_rename_consistency", "SELECT 2"); + assertQuery(session, "SELECT * FROM mv_rename_consistency ORDER BY id", + "VALUES (1, 'initial'), (2, 'data')"); + + assertUpdate("ALTER TABLE rename_consistency_test RENAME TO rename_consistency_test_new"); + + assertQueryFails(session, "SELECT COUNT(*) FROM mv_rename_consistency", + ".*Table .* does not exist.*"); + assertQueryFails(session, "SELECT * FROM mv_rename_consistency ORDER BY id", + ".*Table .* does not exist.*"); + + assertQueryFails("REFRESH MATERIALIZED VIEW mv_rename_consistency", + ".*Table .* does not exist.*"); + + assertUpdate("DROP MATERIALIZED VIEW mv_rename_consistency"); + assertUpdate("DROP TABLE rename_consistency_test_new"); + } + + @Test + public void testRefreshMaterializedViewWithWhereClause() + { + assertUpdate("CREATE TABLE where_base (id BIGINT, category VARCHAR, value BIGINT)"); + assertUpdate("INSERT INTO where_base VALUES (1, 'A', 100), (2, 'B', 200), (3, 'A', 150)", 3); + + assertUpdate("CREATE MATERIALIZED VIEW mv_where AS SELECT id, category, value FROM where_base"); + + assertUpdate("REFRESH MATERIALIZED VIEW mv_where", 3); + + assertQueryFails("REFRESH MATERIALIZED VIEW mv_where WHERE category = 'A'", + ".*WHERE clause in REFRESH MATERIALIZED VIEW is not supported.*"); + + assertUpdate("DROP MATERIALIZED VIEW mv_where"); + assertUpdate("DROP TABLE where_base"); + } + + @Test + public void testShowCreateIncludesSecurityMode() + { + assertUpdate("CREATE TABLE show_security_base (id BIGINT, value VARCHAR)"); + assertUpdate("INSERT INTO show_security_base VALUES (1, 'test')", 1); + + assertUpdate("CREATE MATERIALIZED VIEW 
mv_show_definer SECURITY DEFINER AS SELECT id, value FROM show_security_base"); + + String definerStatement = (String) computeScalar("SHOW CREATE MATERIALIZED VIEW mv_show_definer"); + assertTrue(definerStatement.contains("SECURITY DEFINER"), + "SHOW CREATE should include SECURITY DEFINER, but got: " + definerStatement); + + assertUpdate("CREATE MATERIALIZED VIEW mv_show_invoker SECURITY INVOKER AS SELECT id, value FROM show_security_base"); + + String invokerStatement = (String) computeScalar("SHOW CREATE MATERIALIZED VIEW mv_show_invoker"); + assertTrue(invokerStatement.contains("SECURITY INVOKER"), + "SHOW CREATE should include SECURITY INVOKER, but got: " + invokerStatement); + + assertUpdate("DROP MATERIALIZED VIEW mv_show_definer"); + assertUpdate("DROP MATERIALIZED VIEW mv_show_invoker"); + assertUpdate("DROP TABLE show_security_base"); + } +} diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMetadata.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMetadata.java index 529b46cf580e4..3eb2834008551 100644 --- a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMetadata.java +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryMetadata.java @@ -150,9 +150,9 @@ public void testReadTableBeforeCreationCompleted() assertTrue(tableNames.size() == 1, "Expected exactly one table"); ConnectorTableHandle tableHandle = metadata.getTableHandle(SESSION, tableName); - List tableLayouts = metadata.getTableLayouts(SESSION, tableHandle, Constraint.alwaysTrue(), Optional.empty()); - assertTrue(tableLayouts.size() == 1, "Expected exactly one layout."); - ConnectorTableLayout tableLayout = tableLayouts.get(0).getTableLayout(); + ConnectorTableLayoutResult tableLayoutResult = metadata.getTableLayoutForConstraint(SESSION, tableHandle, Constraint.alwaysTrue(), Optional.empty()); + assertTrue(tableLayoutResult != null, "Table layout is null."); + ConnectorTableLayout 
tableLayout = tableLayoutResult.getTableLayout(); ConnectorTableLayoutHandle tableLayoutHandle = tableLayout.getHandle(); assertTrue(tableLayoutHandle instanceof MemoryTableLayoutHandle); assertTrue(((MemoryTableLayoutHandle) tableLayoutHandle).getDataFragments().isEmpty(), "Data fragments should be empty"); diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryPagesStore.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryPagesStore.java index 32308f8cb2ef8..ad10dc89816a3 100644 --- a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryPagesStore.java +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryPagesStore.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.plugin.memory; +import com.facebook.airlift.units.DataSize; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.spi.ConnectorInsertTableHandle; @@ -25,7 +26,6 @@ import com.facebook.presto.testing.TestingConnectorSession; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.airlift.units.DataSize; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; diff --git a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryWorkerCrash.java b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryWorkerCrash.java index 6d33c46a13550..1c2d2bb27f983 100644 --- a/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryWorkerCrash.java +++ b/presto-memory/src/test/java/com/facebook/presto/plugin/memory/TestMemoryWorkerCrash.java @@ -13,18 +13,18 @@ */ package com.facebook.presto.plugin.memory; +import com.facebook.airlift.units.Duration; import com.facebook.presto.server.testing.TestingPrestoServer; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; import 
com.facebook.presto.tests.DistributedQueryRunner; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; import static com.facebook.airlift.testing.Assertions.assertLessThan; -import static io.airlift.units.Duration.nanosSince; +import static com.facebook.airlift.units.Duration.nanosSince; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; diff --git a/presto-ml/pom.xml b/presto-ml/pom.xml index 6f4bd6dcbc958..d497bf5b26541 100644 --- a/presto-ml/pom.xml +++ b/presto-ml/pom.xml @@ -4,15 +4,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-ml + presto-ml Presto - Machine Learning Plugin presto-plugin ${project.parent.basedir} + true @@ -65,7 +67,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -83,7 +85,7 @@ - io.airlift + com.facebook.airlift units provided @@ -137,5 +139,11 @@ test-jar test + + + com.facebook.presto + presto-main-tests + test + diff --git a/presto-ml/src/test/java/com/facebook/presto/ml/AbstractTestMLFunctions.java b/presto-ml/src/test/java/com/facebook/presto/ml/AbstractTestMLFunctions.java index dffbefd9189aa..4ee697957e540 100644 --- a/presto-ml/src/test/java/com/facebook/presto/ml/AbstractTestMLFunctions.java +++ b/presto-ml/src/test/java/com/facebook/presto/ml/AbstractTestMLFunctions.java @@ -15,13 +15,21 @@ package com.facebook.presto.ml; import com.facebook.presto.operator.scalar.AbstractTestFunctions; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; import org.testng.annotations.BeforeClass; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.metadata.FunctionExtractor.extractFunctions; abstract class AbstractTestMLFunctions extends AbstractTestFunctions { + public AbstractTestMLFunctions() + { + super(TEST_SESSION, new 
FeaturesConfig(), new FunctionsConfig(), false); + } + @BeforeClass protected void registerFunctions() { diff --git a/presto-mongodb/pom.xml b/presto-mongodb/pom.xml index 11d5ddf24ea19..288463c49d3cd 100644 --- a/presto-mongodb/pom.xml +++ b/presto-mongodb/pom.xml @@ -4,17 +4,19 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-mongodb + presto-mongodb Presto - mongodb Connector presto-plugin ${project.parent.basedir} - 3.6.0 - 1.5.0 + 3.12.14 + 1.47.0 + true @@ -37,8 +39,8 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -72,8 +74,8 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api @@ -100,7 +102,7 @@ - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -112,7 +114,7 @@ - io.airlift + com.facebook.airlift units provided @@ -220,17 +222,15 @@ + - test-mongo-distributed-queries + ci-full-tests org.apache.maven.plugins maven-surefire-plugin - - **/TestMongoDistributedQueries.java - diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoClientConfig.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoClientConfig.java index ebbb1833b7f88..ac21c4a29bb4f 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoClientConfig.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoClientConfig.java @@ -14,20 +14,24 @@ package com.facebook.presto.mongodb; import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigSecuritySensitive; import com.facebook.airlift.configuration.DefunctConfig; +import com.facebook.airlift.configuration.LegacyConfig; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.mongodb.MongoCredential; import com.mongodb.ServerAddress; import com.mongodb.Tag; import com.mongodb.TagSet; +import jakarta.validation.constraints.AssertTrue; +import jakarta.validation.constraints.Min; +import 
jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotNull; -import javax.validation.constraints.Size; - +import java.io.File; import java.util.Arrays; import java.util.List; +import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; import static com.mongodb.MongoCredential.createCredential; @@ -39,7 +43,6 @@ public class MongoClientConfig private static final Splitter PORT_SPLITTER = Splitter.on(':').trimResults().omitEmptyStrings(); private static final Splitter TAGSET_SPLITTER = Splitter.on('&').trimResults().omitEmptyStrings(); private static final Splitter TAG_SPLITTER = Splitter.on(':').trimResults().omitEmptyStrings(); - private String schemaCollection = "_schema"; private List seeds = ImmutableList.of(); private List credentials = ImmutableList.of(); @@ -50,7 +53,6 @@ public class MongoClientConfig private int connectionTimeout = 10_000; private int socketTimeout; private boolean socketKeepAlive; - private boolean sslEnabled; // query configurations private int cursorBatchSize; // use driver default @@ -60,6 +62,91 @@ public class MongoClientConfig private WriteConcernType writeConcern = WriteConcernType.ACKNOWLEDGED; private String requiredReplicaSetName; private String implicitRowFieldPrefix = "_pos"; + private boolean tlsEnabled; + private File keystorePath; + private String keystorePassword; + private File truststorePath; + private String truststorePassword; + private boolean caseSensitiveNameMatchingEnabled; + + @AssertTrue(message = "'mongodb.tls.keystore-path', 'mongodb.tls.keystore-password', 'mongodb.tls.truststore-path' and 'mongodb.tls.truststore-password' must be empty when TLS is disabled") + public boolean isValidTlsConfig() + { + if (!tlsEnabled) { + return keystorePath == null && keystorePassword == null && truststorePath == null && truststorePassword == null; + } + // When TLS is enabled, 
validate keystore and truststore configurations + boolean validKeystore = (keystorePath == null && keystorePassword == null) || + (keystorePath != null && keystorePassword != null); + + boolean validTruststore = (truststorePath == null && truststorePassword == null) || + (truststorePath != null && truststorePassword != null); + + return validKeystore && validTruststore; + } + + public boolean isTlsEnabled() + { + return this.tlsEnabled; + } + + @Config("mongodb.tls.enabled") + @LegacyConfig("mongodb.ssl.enabled") + public MongoClientConfig setTlsEnabled(boolean tlsEnabled) + { + this.tlsEnabled = tlsEnabled; + return this; + } + + public Optional getKeystorePath() + { + return Optional.ofNullable(keystorePath); + } + + @Config("mongodb.tls.keystore-path") + public MongoClientConfig setKeystorePath(File keystorePath) + { + this.keystorePath = keystorePath; + return this; + } + + public Optional getKeystorePassword() + { + return Optional.ofNullable(keystorePassword); + } + + @Config("mongodb.tls.keystore-password") + @ConfigSecuritySensitive + public MongoClientConfig setKeystorePassword(String keystorePassword) + { + this.keystorePassword = keystorePassword; + return this; + } + + public Optional getTruststorePath() + { + return Optional.ofNullable(truststorePath); + } + + @Config("mongodb.tls.truststore-path") + public MongoClientConfig setTruststorePath(File truststorePath) + { + this.truststorePath = truststorePath; + return this; + } + + public Optional getTruststorePassword() + { + return Optional.ofNullable(truststorePassword); + } + + @Config("mongodb.tls.truststore-password") + @ConfigSecuritySensitive + public MongoClientConfig setTruststorePassword(String truststorePassword) + { + this.truststorePassword = truststorePassword; + return this; + } @NotNull public String getSchemaCollection() @@ -318,15 +405,15 @@ public MongoClientConfig setImplicitRowFieldPrefix(String implicitRowFieldPrefix return this; } - public boolean getSslEnabled() + public boolean 
isCaseSensitiveNameMatchingEnabled() { - return this.sslEnabled; + return caseSensitiveNameMatchingEnabled; } - @Config("mongodb.ssl.enabled") - public MongoClientConfig setSslEnabled(boolean sslEnabled) + @Config("case-sensitive-name-matching") + public MongoClientConfig setCaseSensitiveNameMatchingEnabled(boolean caseSensitiveNameMatchingEnabled) { - this.sslEnabled = sslEnabled; + this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled; return this; } } diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoClientModule.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoClientModule.java index 0043ffe230e16..a0c0a5e33f4cd 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoClientModule.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoClientModule.java @@ -14,14 +14,14 @@ package com.facebook.presto.mongodb; import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.plugin.base.security.SslContextProvider; import com.google.inject.Binder; import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; import com.mongodb.MongoClient; import com.mongodb.MongoClientOptions; - -import javax.inject.Singleton; +import jakarta.inject.Singleton; import static com.facebook.airlift.configuration.ConfigBinder.configBinder; import static java.util.Objects.requireNonNull; @@ -46,13 +46,11 @@ public static MongoSession createMongoSession(TypeManager typeManager, MongoClie { requireNonNull(config, "config is null"); - MongoClientOptions.Builder options = MongoClientOptions.builder(); - - options.connectionsPerHost(config.getConnectionsPerHost()) + MongoClientOptions.Builder options = MongoClientOptions.builder() + .connectionsPerHost(config.getConnectionsPerHost()) .connectTimeout(config.getConnectionTimeout()) .socketTimeout(config.getSocketTimeout()) .socketKeepAlive(config.getSocketKeepAlive()) - 
.sslEnabled(config.getSslEnabled()) .maxWaitTime(config.getMaxWaitTime()) .minConnectionsPerHost(config.getMinConnectionsPerHost()) .writeConcern(config.getWriteConcern().getWriteConcern()); @@ -61,18 +59,38 @@ public static MongoSession createMongoSession(TypeManager typeManager, MongoClie options.requiredReplicaSetName(config.getRequiredReplicaSetName()); } + configureReadPreference(options, config); + configureSsl(options, config); + + MongoClient client = new MongoClient(config.getSeeds(), config.getCredentials(), options.build()); + + return new MongoSession(typeManager, client, config); + } + + private static void configureReadPreference(MongoClientOptions.Builder options, MongoClientConfig config) + { if (config.getReadPreferenceTags().isEmpty()) { options.readPreference(config.getReadPreference().getReadPreference()); } else { options.readPreference(config.getReadPreference().getReadPreferenceWithTags(config.getReadPreferenceTags())); } + } - MongoClient client = new MongoClient(config.getSeeds(), config.getCredentials(), options.build()); + private static void configureSsl(MongoClientOptions.Builder options, MongoClientConfig config) + { + if (config.isTlsEnabled()) { + SslContextProvider sslContextProvider = new SslContextProvider( + config.getKeystorePath(), + config.getKeystorePassword(), + config.getTruststorePath(), + config.getTruststorePassword()); - return new MongoSession( - typeManager, - client, - config); + sslContextProvider.buildSslContext() + .ifPresent(sslContext -> { + options.sslContext(sslContext); + options.sslEnabled(true); + }); + } } } diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoColumnHandle.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoColumnHandle.java index 582ff78bbfe5f..909ff6bc27fbb 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoColumnHandle.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoColumnHandle.java @@ -72,6 +72,15 
@@ public ColumnMetadata toColumnMetadata() .build(); } + public ColumnMetadata toColumnMetadata(String name) + { + return ColumnMetadata.builder() + .setName(name) + .setType(type) + .setHidden(hidden) + .build(); + } + public Document getDocument() { return new Document().append("name", name) diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoConnector.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoConnector.java index b59394c81aa98..f7feee5e7eaf3 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoConnector.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoConnector.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.facebook.presto.spi.transaction.IsolationLevel; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -42,18 +41,21 @@ public class MongoConnector private final MongoPageSinkProvider pageSinkProvider; private final ConcurrentMap transactions = new ConcurrentHashMap<>(); + private final MongoClientConfig mongoClientConfig; @Inject public MongoConnector( MongoSession mongoSession, MongoSplitManager splitManager, MongoPageSourceProvider pageSourceProvider, - MongoPageSinkProvider pageSinkProvider) + MongoPageSinkProvider pageSinkProvider, + MongoClientConfig mongoClientConfig) { this.mongoSession = mongoSession; this.splitManager = requireNonNull(splitManager, "splitManager is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); this.pageSinkProvider = requireNonNull(pageSinkProvider, "pageSinkProvider is null"); + this.mongoClientConfig = requireNonNull(mongoClientConfig, "mongoClientConfig is null"); } @Override @@ -61,7 +63,7 @@ public ConnectorTransactionHandle beginTransaction(IsolationLevel 
isolationLevel { checkConnectorSupports(READ_UNCOMMITTED, isolationLevel); MongoTransactionHandle transaction = new MongoTransactionHandle(); - transactions.put(transaction, new MongoMetadata(mongoSession)); + transactions.put(transaction, new MongoMetadata(mongoSession, mongoClientConfig)); return transaction; } diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoMetadata.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoMetadata.java index a7868852e55f0..cc348ff4ad08a 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoMetadata.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoMetadata.java @@ -49,7 +49,7 @@ import java.util.concurrent.atomic.AtomicReference; import static com.google.common.base.Preconditions.checkState; -import static java.util.Locale.ENGLISH; +import static java.util.Locale.ROOT; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -61,10 +61,12 @@ public class MongoMetadata private final MongoSession mongoSession; private final AtomicReference rollbackAction = new AtomicReference<>(); + private final MongoClientConfig mongoClientConfig; - public MongoMetadata(MongoSession mongoSession) + public MongoMetadata(MongoSession mongoSession, MongoClientConfig mongoClientConfig) { this.mongoSession = requireNonNull(mongoSession, "mongoSession is null"); + this.mongoClientConfig = mongoClientConfig; } @Override @@ -101,7 +103,7 @@ public List listTables(ConnectorSession session, String schemaN for (String schemaName : listSchemas(session, schemaNameOrNull)) { for (String tableName : mongoSession.getAllTables(schemaName)) { - tableNames.add(new SchemaTableName(schemaName, tableName.toLowerCase(ENGLISH))); + tableNames.add(new SchemaTableName(schemaName, normalizeIdentifier(session, tableName))); } } return tableNames.build(); @@ -151,7 +153,11 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, 
ConnectorTable } @Override - public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + public ConnectorTableLayoutResult getTableLayoutForConstraint( + ConnectorSession session, + ConnectorTableHandle table, + Constraint constraint, + Optional> desiredColumns) { MongoTableHandle tableHandle = (MongoTableHandle) table; @@ -181,7 +187,7 @@ public List getTableLayouts(ConnectorSession session Optional.empty(), localProperties.build()); - return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + return new ConnectorTableLayoutResult(layout, constraint.getSummary()); } @Override @@ -190,8 +196,7 @@ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTa MongoTableLayoutHandle layout = (MongoTableLayoutHandle) handle; // tables in this connector have a single layout - return getTableLayouts(session, layout.getTable(), Constraint.alwaysTrue(), Optional.empty()) - .get(0) + return getTableLayoutForConstraint(session, layout.getTable(), Constraint.alwaysTrue(), Optional.empty()) .getTableLayout(); } @@ -281,7 +286,7 @@ private ConnectorTableMetadata getTableMetadata(ConnectorSession session, Schema List columns = ImmutableList.copyOf( getColumnHandles(session, tableHandle).values().stream() .map(MongoColumnHandle.class::cast) - .map(MongoColumnHandle::toColumnMetadata) + .map(column -> column.toColumnMetadata(normalizeIdentifier(session, column.getName()))) .collect(toList())); return new ConnectorTableMetadata(tableName, columns); @@ -319,4 +324,10 @@ public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandl { mongoSession.dropColumn(((MongoTableHandle) tableHandle), ((MongoColumnHandle) column).getName()); } + + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return mongoClientConfig.isCaseSensitiveNameMatchingEnabled() ? 
identifier : identifier.toLowerCase(ROOT); + } } diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSinkProvider.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSinkProvider.java index c2859865c5821..0cc77b794cb3a 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSinkProvider.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSinkProvider.java @@ -20,8 +20,7 @@ import com.facebook.presto.spi.PageSinkContext; import com.facebook.presto.spi.connector.ConnectorPageSinkProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; - -import javax.inject.Inject; +import jakarta.inject.Inject; import static com.google.common.base.Preconditions.checkArgument; diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSourceProvider.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSourceProvider.java index 6e9754184d5b9..e4ef71ce59adf 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSourceProvider.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoPageSourceProvider.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java index 9f39b701a21f3..c73f949ac551d 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSession.java @@ -404,7 +404,7 @@ private Document getTableMetadata(SchemaTableName schemaTableName) public boolean collectionExists(MongoDatabase db, 
String collectionName) { for (String name : db.listCollectionNames()) { - if (name.equalsIgnoreCase(collectionName)) { + if (name.equals(collectionName)) { return true; } } @@ -457,14 +457,10 @@ private boolean deleteTableMetadata(SchemaTableName schemaTableName) String tableName = schemaTableName.getTableName(); MongoDatabase db = client.getDatabase(schemaName); - if (!collectionExists(db, tableName)) { - return false; - } - DeleteResult result = db.getCollection(schemaCollection) .deleteOne(new Document(TABLE_NAME_KEY, tableName)); - return result.getDeletedCount() == 1; + return result.getDeletedCount() == 1 || !collectionExists(db, tableName); } private List guessTableFields(SchemaTableName schemaTableName) @@ -583,8 +579,8 @@ public void addColumn(MongoTableHandle table, ColumnMetadata columnMetadata) Document newColumn = new Document() .append(FIELDS_NAME_KEY, columnMetadata.getName()) .append(FIELDS_TYPE_KEY, columnMetadata.getType().getTypeSignature().toString()) - .append(COMMENT_KEY, columnMetadata.getComment()) .append(FIELDS_HIDDEN_KEY, false); + columnMetadata.getComment().ifPresent(comment -> newColumn.append(COMMENT_KEY, comment)); columns.add(newColumn); metadata.append(FIELDS_KEY, columns); diff --git a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSplitManager.java b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSplitManager.java index ac8f76701a116..8012387ca9141 100644 --- a/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSplitManager.java +++ b/presto-mongodb/src/main/java/com/facebook/presto/mongodb/MongoSplitManager.java @@ -21,8 +21,7 @@ import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; import com.google.common.collect.ImmutableList; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.util.List; diff --git a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/MongoQueryRunner.java 
b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/MongoQueryRunner.java index 9c07bed90d15a..1b079c6c93645 100644 --- a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/MongoQueryRunner.java +++ b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/MongoQueryRunner.java @@ -24,6 +24,7 @@ import io.airlift.tpch.TpchTable; import java.net.InetSocketAddress; +import java.util.HashMap; import java.util.Map; import static com.facebook.airlift.testing.Closeables.closeAllSuppress; @@ -53,10 +54,10 @@ private MongoQueryRunner(Session session, int workers) public static MongoQueryRunner createMongoQueryRunner(TpchTable... tables) throws Exception { - return createMongoQueryRunner(ImmutableList.copyOf(tables)); + return createMongoQueryRunner(ImmutableList.copyOf(tables), ImmutableMap.of()); } - public static MongoQueryRunner createMongoQueryRunner(Iterable> tables) + public static MongoQueryRunner createMongoQueryRunner(Iterable> tables, Map connectorProperties) throws Exception { MongoQueryRunner queryRunner = null; @@ -66,12 +67,12 @@ public static MongoQueryRunner createMongoQueryRunner(Iterable> tab queryRunner.installPlugin(new TpchPlugin()); queryRunner.createCatalog("tpch", "tpch"); - Map properties = ImmutableMap.of( - "mongodb.seeds", queryRunner.getAddress().getHostString() + ":" + queryRunner.getAddress().getPort(), - "mongodb.socket-keep-alive", "true"); + connectorProperties = new HashMap<>(connectorProperties); + connectorProperties.putIfAbsent("mongodb.seeds", queryRunner.getAddress().getHostString() + ":" + queryRunner.getAddress().getPort()); + connectorProperties.putIfAbsent("mongodb.socket-keep-alive", "true"); queryRunner.installPlugin(new MongoPlugin()); - queryRunner.createCatalog("mongodb", "mongodb", properties); + queryRunner.createCatalog("mongodb", "mongodb", connectorProperties); copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, createSession(), tables); diff --git 
a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/SyncMemoryBackend.java b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/SyncMemoryBackend.java index 828230a075424..c8c7abdcc392c 100644 --- a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/SyncMemoryBackend.java +++ b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/SyncMemoryBackend.java @@ -13,7 +13,7 @@ */ package com.facebook.presto.mongodb; -import de.bwaldvogel.mongo.MongoBackend; +import de.bwaldvogel.mongo.backend.CursorRegistry; import de.bwaldvogel.mongo.backend.memory.MemoryBackend; import de.bwaldvogel.mongo.backend.memory.MemoryDatabase; import de.bwaldvogel.mongo.exception.MongoServerException; @@ -25,16 +25,16 @@ public class SyncMemoryBackend public MemoryDatabase openOrCreateDatabase(String databaseName) throws MongoServerException { - return new SyncMemoryDatabase(this, databaseName); + return new SyncMemoryDatabase(databaseName, this.getCursorRegistry()); } private static class SyncMemoryDatabase extends MemoryDatabase { - public SyncMemoryDatabase(MongoBackend backend, String databaseName) + public SyncMemoryDatabase(String databaseName, CursorRegistry cursorRegistry) throws MongoServerException { - super(backend, databaseName); + super(databaseName, cursorRegistry); } } } diff --git a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoClientConfig.java b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoClientConfig.java index 42e4c35c2edfe..1e0e9d5e309a6 100644 --- a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoClientConfig.java +++ b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoClientConfig.java @@ -16,14 +16,45 @@ import com.facebook.airlift.configuration.testing.ConfigAssertions; import com.google.common.collect.ImmutableMap; import com.mongodb.MongoCredential; +import jakarta.validation.constraints.AssertTrue; +import org.testng.annotations.AfterClass; +import 
org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.Map; +import static com.facebook.airlift.testing.ValidationAssertions.assertFailsValidation; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; public class TestMongoClientConfig { + private static Path keystoreFile; + private static Path truststoreFile; + + @BeforeClass + public void setUp() + throws IOException + { + keystoreFile = Files.createTempFile("test-keystore", ".jks"); + truststoreFile = Files.createTempFile("test-truststore", ".jks"); + } + + @AfterClass + public void tearDown() + throws IOException + { + if (keystoreFile != null) { + Files.deleteIfExists(keystoreFile); + } + if (truststoreFile != null) { + Files.deleteIfExists(truststoreFile); + } + } + @Test public void testDefaults() { @@ -37,17 +68,22 @@ public void testDefaults() .setConnectionTimeout(10_000) .setSocketTimeout(0) .setSocketKeepAlive(false) - .setSslEnabled(false) + .setTlsEnabled(false) + .setKeystorePath(null) + .setKeystorePassword(null) + .setTruststorePath(null) + .setTruststorePassword(null) .setCursorBatchSize(0) .setReadPreference(ReadPreferenceType.PRIMARY) .setReadPreferenceTags("") .setWriteConcern(WriteConcernType.ACKNOWLEDGED) .setRequiredReplicaSetName(null) - .setImplicitRowFieldPrefix("_pos")); + .setImplicitRowFieldPrefix("_pos") + .setCaseSensitiveNameMatchingEnabled(false)); } @Test - public void testExplicitPropertyMappings() + public void testExplicitPropertyMappings() throws IOException { Map properties = new ImmutableMap.Builder() .put("mongodb.schema-collection", "_my_schema") @@ -59,13 +95,18 @@ public void testExplicitPropertyMappings() .put("mongodb.connection-timeout", "9999") .put("mongodb.socket-timeout", "1") .put("mongodb.socket-keep-alive", "true") - .put("mongodb.ssl.enabled", "true") + .put("mongodb.tls.enabled", "true") // Use the 
primary TLS config + .put("mongodb.tls.keystore-path", keystoreFile.toString()) + .put("mongodb.tls.keystore-password", "keystore-password") + .put("mongodb.tls.truststore-path", truststoreFile.toString()) + .put("mongodb.tls.truststore-password", "truststore-password") .put("mongodb.cursor-batch-size", "1") .put("mongodb.read-preference", "NEAREST") .put("mongodb.read-preference-tags", "tag_name:tag_value") .put("mongodb.write-concern", "UNACKNOWLEDGED") .put("mongodb.required-replica-set", "replica_set") .put("mongodb.implicit-row-field-prefix", "_prefix") + .put("case-sensitive-name-matching", "true") .build(); MongoClientConfig expected = new MongoClientConfig() @@ -77,18 +118,34 @@ public void testExplicitPropertyMappings() .setMaxWaitTime(120_001) .setConnectionTimeout(9_999) .setSocketTimeout(1) - .setSocketKeepAlive(true) - .setSslEnabled(true) - .setCursorBatchSize(1) + .setSocketKeepAlive(true); + + configureTlsProperties(expected, "keystore-password", "truststore-password"); + + expected.setCursorBatchSize(1) .setReadPreference(ReadPreferenceType.NEAREST) .setReadPreferenceTags("tag_name:tag_value") .setWriteConcern(WriteConcernType.UNACKNOWLEDGED) .setRequiredReplicaSetName("replica_set") - .setImplicitRowFieldPrefix("_prefix"); + .setImplicitRowFieldPrefix("_prefix") + .setCaseSensitiveNameMatchingEnabled(true); ConfigAssertions.assertFullMapping(properties, expected); } + @Test + public void testTlsConfigurationWithAllProperties() throws IOException + { + MongoClientConfig config = new MongoClientConfig(); + configureTlsProperties(config, "keystore-password", "truststore-password"); + + assertTrue(config.isTlsEnabled(), "TLS should be enabled when explicitly set to true"); + assertEquals(config.getKeystorePath().get(), keystoreFile.toFile()); + assertEquals(config.getKeystorePassword().get(), "keystore-password"); + assertEquals(config.getTruststorePath().get(), truststoreFile.toFile()); + assertEquals(config.getTruststorePassword().get(), 
"truststore-password"); + } + @Test public void testSpecialCharacterCredential() { @@ -99,4 +156,54 @@ public void testSpecialCharacterCredential() MongoCredential expected = MongoCredential.createCredential("username", "database", "P@ss:w0rd".toCharArray()); assertEquals(credential, expected); } + + @Test + public void testTlsPropertyValidationFailsIfTlsIsDisabled() + throws Exception + { + assertFailsTlsValidation(new MongoClientConfig().setKeystorePath(keystoreFile.toFile())); + assertFailsTlsValidation(new MongoClientConfig().setKeystorePassword("keystore password")); + assertFailsTlsValidation(new MongoClientConfig().setTruststorePath(truststoreFile.toFile())); + assertFailsTlsValidation(new MongoClientConfig().setTruststorePassword("truststore password")); + } + + @Test + public void testTlsPropertyValidationPassesIfTlsIsEnabled() + throws Exception + { + // These should all pass validation when TLS is enabled + MongoClientConfig config1 = new MongoClientConfig() + .setTlsEnabled(true) + .setKeystorePath(keystoreFile.toFile()) + .setKeystorePassword("keystore password"); + assertTrue(config1.isValidTlsConfig()); + + MongoClientConfig config2 = new MongoClientConfig() + .setTlsEnabled(true) + .setTruststorePath(truststoreFile.toFile()) + .setTruststorePassword("truststore password"); + assertTrue(config2.isValidTlsConfig()); + + MongoClientConfig config3 = new MongoClientConfig(); + configureTlsProperties(config3, "keystore password", "truststore password"); + assertTrue(config3.isValidTlsConfig()); + } + + private static void configureTlsProperties(MongoClientConfig config, String keystorePassword, String truststorePassword) + { + config.setTlsEnabled(true) + .setKeystorePath(keystoreFile.toFile()) + .setKeystorePassword(keystorePassword) + .setTruststorePath(truststoreFile.toFile()) + .setTruststorePassword(truststorePassword); + } + + private static void assertFailsTlsValidation(MongoClientConfig config) + { + assertFailsValidation( + config, + 
"validTlsConfig", + "'mongodb.tls.keystore-path', 'mongodb.tls.keystore-password', 'mongodb.tls.truststore-path' and 'mongodb.tls.truststore-password' must be empty when TLS is disabled", + AssertTrue.class); + } } diff --git a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoDbIntegrationMixedCaseTest.java b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoDbIntegrationMixedCaseTest.java new file mode 100644 index 0000000000000..24ba8fb81e1bd --- /dev/null +++ b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoDbIntegrationMixedCaseTest.java @@ -0,0 +1,223 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.mongodb; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import com.mongodb.MongoClient; +import io.airlift.tpch.TpchTable; +import org.bson.Document; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.stream.Collectors; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.mongodb.MongoQueryRunner.createMongoQueryRunner; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +@Test +public class TestMongoDbIntegrationMixedCaseTest + extends AbstractTestQueryFramework +{ + private MongoQueryRunner mongoQueryRunner; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createMongoQueryRunner(TpchTable.getTables(), ImmutableMap.of("case-sensitive-name-matching", "true")); + } + + @BeforeClass + public void setUp() + { + mongoQueryRunner = (MongoQueryRunner) getQueryRunner(); + } + + @AfterClass(alwaysRun = true) + public final void destroy() + { + if (mongoQueryRunner != null) { + mongoQueryRunner.shutdown(); + } + } + + public void testDescribeTableWithDifferentCaseInSameSchema() + { + try { + getQueryRunner().execute("CREATE TABLE ORDERS AS SELECT * FROM orders"); + + assertTrue(getQueryRunner().tableExists(getSession(), "orders")); + assertTrue(getQueryRunner().tableExists(getSession(), "ORDERS")); + + MaterializedResult actualColumns = computeActual("DESC 
ORDERS").toTestTypes(); + + MaterializedResult expectedColumns = MaterializedResult.resultBuilder(getQueryRunner().getDefaultSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("orderkey", "bigint", "", "", 19L, null, null) + .row("custkey", "bigint", "", "", 19L, null, null) + .row("orderstatus", "varchar(1)", "", "", null, null, 1L) + .row("totalprice", "double", "", "", 53L, null, null) + .row("orderdate", "date", "", "", null, null, null) + .row("orderpriority", "varchar(15)", "", "", null, null, 15L) + .row("clerk", "varchar(15)", "", "", null, null, 15L) + .row("shippriority", "integer", "", "", 10L, null, null) + .row("comment", "varchar(79)", "", "", null, null, 79L) + .build(); + assertEquals(actualColumns, expectedColumns); + } + finally { + assertUpdate("DROP TABLE tpch.ORDERS"); + } + } + + @Test + public void testCreateAndDropTable() + { + Session session = testSessionBuilder() + .setCatalog("mongodb") + .setSchema("Mixed_Test_Database") + .build(); + + try { + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATE(name VARCHAR(50), id int)"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATE")); + assertFalse(getQueryRunner().tableExists(session, "test_create")); + + getQueryRunner().execute(session, "CREATE TABLE test_create(name VARCHAR(50), id int)"); + assertTrue(getQueryRunner().tableExists(session, "test_create")); + assertFalse(getQueryRunner().tableExists(session, "Test_Create")); + } + + finally { + assertUpdate(session, "DROP TABLE TEST_CREATE"); + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATE")); + + assertUpdate(session, "DROP TABLE test_create"); + assertFalse(getQueryRunner().tableExists(session, "test_create")); + } + } + + @Test + public void testCreateTableAs() + { + Session session = testSessionBuilder() + .setCatalog("mongodb") + .setSchema("Mixed_Test_Database") + .build(); + + try { + getQueryRunner().execute(session, "CREATE TABLE TEST_CTAS AS SELECT * FROM 
tpch.region"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CTAS")); + + getQueryRunner().execute(session, "CREATE TABLE IF NOT EXISTS test_ctas AS SELECT * FROM tpch.region"); + assertTrue(getQueryRunner().tableExists(session, "test_ctas")); + + getQueryRunner().execute(session, "CREATE TABLE TEST_CTAS_Join AS SELECT c.custkey, o.orderkey FROM " + + "tpch.customer c INNER JOIN tpch.orders o ON c.custkey = o.custkey WHERE c.mktsegment = 'BUILDING'"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CTAS_Join")); + + assertQueryFails("CREATE TABLE Mixed_Test_Database.TEST_CTAS_FAIL_Join AS SELECT c.custkey, o.orderkey FROM " + + "tpch.customer c INNER JOIN tpch.ORDERS1 o ON c.custkey = o.custkey WHERE c.mktsegment = 'BUILDING'", "Table mongodb.tpch.ORDERS1 does not exist"); //failure scenario since tpch.ORDERS1 doesn't exist + assertFalse(getQueryRunner().tableExists(session, "TEST_CTAS_FAIL_Join")); + + getQueryRunner().execute(session, "CREATE TABLE Test_CTAS_Mixed_Join AS SELECT Cus.custkey, Ord.orderkey FROM " + + "tpch.customer Cus INNER JOIN tpch.orders Ord ON Cus.custkey = Ord.custkey WHERE Cus.mktsegment = 'BUILDING'"); + assertTrue(getQueryRunner().tableExists(session, "Test_CTAS_Mixed_Join")); + } + finally { + getQueryRunner().execute(session, "DROP TABLE IF EXISTS TEST_CTAS"); + getQueryRunner().execute(session, "DROP TABLE IF EXISTS test_ctas"); + getQueryRunner().execute(session, "DROP TABLE IF EXISTS TEST_CTAS_Join"); + getQueryRunner().execute(session, "DROP TABLE IF EXISTS Test_CTAS_Mixed_Join"); + } + } + + @Test + public void testInsert() + { + Session session = testSessionBuilder() + .setCatalog("mongodb") + .setSchema("Mixed_Test_Database") + .build(); + + try { + getQueryRunner().execute(session, "CREATE TABLE Test_Insert (x bigint, y varchar(100))"); + getQueryRunner().execute(session, "INSERT INTO Test_Insert VALUES (123, 'test')"); + assertTrue(getQueryRunner().tableExists(session, "Test_Insert")); + assertQuery("SELECT * 
FROM Mixed_Test_Database.Test_Insert", "SELECT 123 x, 'test' y"); + + getQueryRunner().execute(session, "CREATE TABLE TEST_INSERT (x bigint, Y varchar(100))"); + getQueryRunner().execute(session, "INSERT INTO TEST_INSERT VALUES (1234, 'test1')"); + assertTrue(getQueryRunner().tableExists(session, "TEST_INSERT")); + assertQuery("SELECT * FROM Mixed_Test_Database.TEST_INSERT", "SELECT 1234 x, 'test1' Y"); + } + finally { + getQueryRunner().execute(session, "DROP TABLE IF EXISTS Test_Insert"); + getQueryRunner().execute(session, "DROP TABLE IF EXISTS TEST_INSERT"); + } + } + + @Test + public void testMixedCaseColumns() + { + try { + assertUpdate("CREATE TABLE test (a integer, B integer)"); + assertUpdate("CREATE TABLE TEST (a integer, aA integer)"); + assertUpdate("INSERT INTO TEST VALUES (123, 12)", 1); + assertTableColumnNames("TEST", "a", "aA"); + assertUpdate("ALTER TABLE TEST ADD COLUMN EMAIL varchar"); + assertUpdate("ALTER TABLE TEST RENAME COLUMN a TO a_New"); + assertTableColumnNames("TEST", "a_New", "aA", "EMAIL"); + assertUpdate("ALTER TABLE TEST DROP COLUMN aA"); + assertTableColumnNames("TEST", "a_New", "EMAIL"); + } + finally { + assertUpdate("DROP TABLE test"); + assertUpdate("DROP TABLE TEST"); + } + } + + @Test + public void testShowSchemas() + { + // Create two MongoDB databases directly, since Presto doesn't support create schema for mongodb + MongoClient mongoClient = mongoQueryRunner.getMongoClient(); + try { + mongoClient.getDatabase("TESTDB1").getCollection("dummy").insertOne(new Document("x", 1)); + mongoClient.getDatabase("testdb1").getCollection("dummy").insertOne(new Document("x", 1)); + + MaterializedResult result = computeActual("SHOW SCHEMAS"); + List schemaNames = result.getMaterializedRows().stream() + .map(row -> row.getField(0).toString()) + .collect(Collectors.toList()); + + assertTrue(schemaNames.contains("TESTDB1")); + assertTrue(schemaNames.contains("testdb1")); + } + finally { + mongoClient.getDatabase("TESTDB1").drop(); + 
mongoClient.getDatabase("testdb1").drop(); + } + } +} diff --git a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoDistributedQueries.java b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoDistributedQueries.java index 08e6e9f04feac..23c58a846d886 100644 --- a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoDistributedQueries.java +++ b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoDistributedQueries.java @@ -15,6 +15,7 @@ import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueries; +import com.google.common.collect.ImmutableMap; import io.airlift.tpch.TpchTable; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -32,7 +33,7 @@ public class TestMongoDistributedQueries protected QueryRunner createQueryRunner() throws Exception { - return createMongoQueryRunner(TpchTable.getTables()); + return createMongoQueryRunner(TpchTable.getTables(), ImmutableMap.of()); } @BeforeClass diff --git a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoIntegrationSmokeTest.java b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoIntegrationSmokeTest.java index a5e820d4a1dfb..fd35e5ea284ee 100644 --- a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoIntegrationSmokeTest.java +++ b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoIntegrationSmokeTest.java @@ -21,6 +21,7 @@ import com.google.common.collect.ImmutableMap; import com.mongodb.client.MongoCollection; import org.bson.Document; +import org.bson.types.ObjectId; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -146,7 +147,8 @@ public void testInsertWithEveryType() assertEquals(row.getField(5), LocalDate.of(1980, 5, 7)); assertEquals(row.getField(6), LocalDateTime.of(1980, 5, 7, 11, 22, 33, 456_000_000)); assertEquals(row.getField(7), 
LocalTime.of(11, 22, 33, 456_000_000)); - assertEquals(row.getField(8), "{\"name\":\"alice\"}"); + assertEquals(new ObjectId((byte[]) row.getField(8)), new ObjectId("ffffffffffffffffffffffff")); + assertEquals(row.getField(9).toString(), "{\"name\":\"alice\"}"); assertUpdate("DROP TABLE test_insert_types_table"); assertFalse(getQueryRunner().tableExists(getSession(), "test_insert_types_table")); } @@ -218,13 +220,13 @@ public void testMaps() assertUpdate("CREATE TABLE test.tmp_map9 (col VARCHAR)"); mongoQueryRunner.getMongoClient().getDatabase("test").getCollection("tmp_map9").insertOne(new Document( ImmutableMap.of("col", new Document(ImmutableMap.of("key1", "value1", "key2", "value2"))))); - assertQuery("SELECT col FROM test.tmp_map9", "SELECT '{ \"key1\" : \"value1\", \"key2\" : \"value2\" }'"); + assertQuery("SELECT col FROM test.tmp_map9", "SELECT '{\"key1\": \"value1\", \"key2\": \"value2\"}'"); assertUpdate("CREATE TABLE test.tmp_map10 (col VARCHAR)"); mongoQueryRunner.getMongoClient().getDatabase("test").getCollection("tmp_map10").insertOne(new Document( ImmutableMap.of("col", ImmutableList.of(new Document(ImmutableMap.of("key1", "value1", "key2", "value2")), new Document(ImmutableMap.of("key3", "value3", "key4", "value4")))))); - assertQuery("SELECT col FROM test.tmp_map10", "SELECT '[{ \"key1\" : \"value1\", \"key2\" : \"value2\" }, { \"key3\" : \"value3\", \"key4\" : \"value4\" }]'"); + assertQuery("SELECT col FROM test.tmp_map10", "SELECT '[{\"key1\": \"value1\", \"key2\": \"value2\"}, {\"key3\": \"value3\", \"key4\": \"value4\"}]'"); assertUpdate("CREATE TABLE test.tmp_map11 (col VARCHAR)"); mongoQueryRunner.getMongoClient().getDatabase("test").getCollection("tmp_map11").insertOne(new Document( diff --git a/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoTlsConfiguration.java b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoTlsConfiguration.java new file mode 100644 index 0000000000000..5683c11988f64 --- /dev/null 
+++ b/presto-mongodb/src/test/java/com/facebook/presto/mongodb/TestMongoTlsConfiguration.java @@ -0,0 +1,334 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.mongodb; + +import com.facebook.presto.plugin.base.security.SslContextProvider; +import com.facebook.presto.tests.SslKeystoreManager; +import com.mongodb.MongoClientOptions; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import javax.net.ssl.SSLContext; + +import java.io.File; +import java.util.Optional; + +import static com.facebook.presto.tests.SslKeystoreManager.SSL_STORE_PASSWORD; +import static com.facebook.presto.tests.SslKeystoreManager.getKeystorePath; +import static com.facebook.presto.tests.SslKeystoreManager.getTruststorePath; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +/** + * Integration tests for MongoDB TLS configuration. + * Tests the complete flow from MongoClientConfig through SslContextProvider to MongoClient. 
+ */ +public class TestMongoTlsConfiguration +{ + private File keystoreFile; + private File truststoreFile; + + @BeforeClass + public void setUp() throws Exception + { + SslKeystoreManager.initializeKeystoreAndTruststore(); + + keystoreFile = new File(getKeystorePath()); + truststoreFile = new File(getTruststorePath()); + } + + @Test + public void testTlsDisabledByDefault() + { + MongoClientConfig config = new MongoClientConfig(); + + assertFalse(config.isTlsEnabled(), "TLS should be disabled by default"); + assertFalse(config.getKeystorePath().isPresent(), "Keystore path should be empty by default"); + assertFalse(config.getKeystorePassword().isPresent(), "Keystore password should be empty by default"); + assertFalse(config.getTruststorePath().isPresent(), "Truststore path should be empty by default"); + assertFalse(config.getTruststorePassword().isPresent(), "Truststore password should be empty by default"); + } + + @Test + public void testTlsEnabledWithKeystoreAndTruststore() + { + MongoClientConfig config = new MongoClientConfig(); + configureTlsProperties(config); + + assertTrue(config.isTlsEnabled(), "TLS should be enabled"); + assertTrue(config.getKeystorePath().isPresent(), "Keystore path should be present"); + assertTrue(config.getKeystorePassword().isPresent(), "Keystore password should be present"); + assertTrue(config.getTruststorePath().isPresent(), "Truststore path should be present"); + assertTrue(config.getTruststorePassword().isPresent(), "Truststore password should be present"); + assertEquals(config.getKeystorePath().get(), keystoreFile); + assertEquals(config.getTruststorePath().get(), truststoreFile); + } + + @Test + public void testTlsEnabledWithKeystoreOnly() + { + MongoClientConfig config = new MongoClientConfig() + .setTlsEnabled(true) + .setKeystorePath(keystoreFile) + .setKeystorePassword(SSL_STORE_PASSWORD); + + assertTrue(config.isTlsEnabled(), "TLS should be enabled"); + assertTrue(config.getKeystorePath().isPresent(), "Keystore path 
should be present"); + assertTrue(config.getKeystorePassword().isPresent(), "Keystore password should be present"); + assertFalse(config.getTruststorePath().isPresent(), "Truststore path should be empty"); + assertFalse(config.getTruststorePassword().isPresent(), "Truststore password should be empty"); + assertTrue(config.isValidTlsConfig(), "TLS config should be valid with keystore only"); + } + + @Test + public void testTlsEnabledWithTruststoreOnly() + { + MongoClientConfig config = new MongoClientConfig() + .setTlsEnabled(true) + .setTruststorePath(truststoreFile) + .setTruststorePassword(SSL_STORE_PASSWORD); + + assertTrue(config.isTlsEnabled(), "TLS should be enabled"); + assertFalse(config.getKeystorePath().isPresent(), "Keystore path should be empty"); + assertFalse(config.getKeystorePassword().isPresent(), "Keystore password should be empty"); + assertTrue(config.getTruststorePath().isPresent(), "Truststore path should be present"); + assertTrue(config.getTruststorePassword().isPresent(), "Truststore password should be present"); + assertTrue(config.isValidTlsConfig(), "TLS config should be valid with truststore only"); + } + + @Test + public void testSslContextProviderIntegration() + { + MongoClientConfig config = new MongoClientConfig(); + configureTlsProperties(config); + + SslContextProvider provider = createSslContextProvider(config); + Optional sslContext = provider.buildSslContext(); + + assertTrue(sslContext.isPresent(), "SSL context should be created"); + assertNotNull(sslContext.get(), "SSL context should not be null"); + assertEquals(sslContext.get().getProtocol(), "TLS", "SSL context should use TLS protocol"); + } + + @Test + public void testMongoClientOptionsWithTlsEnabled() + { + MongoClientConfig config = new MongoClientConfig() + .setSeeds("localhost:27017"); + configureTlsProperties(config); + + MongoClientOptions.Builder optionsBuilder = MongoClientOptions.builder() + .connectionsPerHost(config.getConnectionsPerHost()) + 
.connectTimeout(config.getConnectionTimeout()) + .socketTimeout(config.getSocketTimeout()) + .socketKeepAlive(config.getSocketKeepAlive()) + .maxWaitTime(config.getMaxWaitTime()) + .minConnectionsPerHost(config.getMinConnectionsPerHost()) + .writeConcern(config.getWriteConcern().getWriteConcern()); + + // Configure SSL + if (config.isTlsEnabled()) { + SslContextProvider sslContextProvider = createSslContextProvider(config); + + sslContextProvider.buildSslContext().ifPresent(sslContext -> { + optionsBuilder.sslContext(sslContext); + optionsBuilder.sslEnabled(true); + }); + } + + MongoClientOptions options = optionsBuilder.build(); + + assertTrue(options.isSslEnabled(), "SSL should be enabled in MongoClientOptions"); + assertNotNull(options.getSslContext(), "SSL context should be set in MongoClientOptions"); + } + + @Test + public void testMongoClientOptionsWithTlsDisabled() + { + MongoClientConfig config = new MongoClientConfig() + .setSeeds("localhost:27017") + .setTlsEnabled(false); + + MongoClientOptions.Builder optionsBuilder = MongoClientOptions.builder() + .connectionsPerHost(config.getConnectionsPerHost()) + .connectTimeout(config.getConnectionTimeout()) + .socketTimeout(config.getSocketTimeout()) + .socketKeepAlive(config.getSocketKeepAlive()) + .maxWaitTime(config.getMaxWaitTime()) + .minConnectionsPerHost(config.getMinConnectionsPerHost()) + .writeConcern(config.getWriteConcern().getWriteConcern()); + + // Configure SSL + if (config.isTlsEnabled()) { + SslContextProvider sslContextProvider = createSslContextProvider(config); + + sslContextProvider.buildSslContext().ifPresent(sslContext -> { + optionsBuilder.sslContext(sslContext); + optionsBuilder.sslEnabled(true); + }); + } + + MongoClientOptions options = optionsBuilder.build(); + + assertFalse(options.isSslEnabled(), "SSL should be disabled in MongoClientOptions"); + } + + @Test + public void testLegacyPropertySupport() + { + // Test that the legacy mongodb.ssl.enabled property still works + 
MongoClientConfig config = new MongoClientConfig(); + + // The @LegacyConfig annotation should map mongodb.ssl.enabled to mongodb.tls.enabled + // This would be tested through the configuration system, but we can verify the setter works + config.setTlsEnabled(true); + + assertTrue(config.isTlsEnabled(), "TLS should be enabled via legacy property mapping"); + } + + @Test + public void testTlsConfigurationValidationWithPartialKeystore() + { + // Test that having only keystore path without password fails validation + MongoClientConfig config = new MongoClientConfig() + .setTlsEnabled(true) + .setKeystorePath(keystoreFile); + + assertFalse(config.isValidTlsConfig(), + "TLS config should be invalid when keystore path is set without password"); + } + + @Test + public void testTlsConfigurationValidationWithPartialTruststore() + { + // Test that having only truststore path without password fails validation + MongoClientConfig config = new MongoClientConfig() + .setTlsEnabled(true) + .setTruststorePath(truststoreFile); + + assertFalse(config.isValidTlsConfig(), + "TLS config should be invalid when truststore path is set without password"); + } + + @Test + public void testFullMongoClientCreationFlow() + { + // This tests the complete flow similar to what happens in MongoClientModule + MongoClientConfig config = new MongoClientConfig() + .setSeeds("localhost:27017") + .setConnectionsPerHost(50) + .setConnectionTimeout(5000) + .setReadPreference(ReadPreferenceType.PRIMARY); + configureTlsProperties(config); + // Verify configuration + assertTrue(config.isTlsEnabled(), "TLS should be enabled"); + assertTrue(config.isValidTlsConfig(), "TLS configuration should be valid"); + assertEquals(config.getConnectionsPerHost(), 50); + assertEquals(config.getConnectionTimeout(), 5000); + + // Create SSL context + SslContextProvider sslContextProvider = createSslContextProvider(config); + + Optional sslContext = sslContextProvider.buildSslContext(); + assertTrue(sslContext.isPresent(), "SSL 
context should be created"); + + // Build MongoClientOptions + MongoClientOptions.Builder optionsBuilder = MongoClientOptions.builder() + .connectionsPerHost(config.getConnectionsPerHost()) + .connectTimeout(config.getConnectionTimeout()) + .readPreference(config.getReadPreference().getReadPreference()); + + sslContext.ifPresent(ctx -> { + optionsBuilder.sslContext(ctx); + optionsBuilder.sslEnabled(true); + }); + + MongoClientOptions options = optionsBuilder.build(); + + // Verify final options + assertTrue(options.isSslEnabled(), "SSL should be enabled"); + assertNotNull(options.getSslContext(), "SSL context should be set"); + assertEquals(options.getConnectionsPerHost(), 50); + assertEquals(options.getConnectTimeout(), 5000); + } + + @Test + public void testTlsEnabledWithKeystoreOnlyUsesKeystoreAsTruststore() + { + // When only keystore is provided, it should be used as truststore for backward compatibility + MongoClientConfig config = new MongoClientConfig() + .setTlsEnabled(true) + .setKeystorePath(keystoreFile) + .setKeystorePassword(SSL_STORE_PASSWORD); + + SslContextProvider provider = new SslContextProvider( + config.getKeystorePath(), + config.getKeystorePassword(), + Optional.empty(), + Optional.empty()); + + Optional sslContext = provider.buildSslContext(); + + assertTrue(sslContext.isPresent(), "SSL context should be created when only keystore is provided"); + assertNotNull(sslContext.get(), "SSL context should not be null"); + } + + @Test + public void testTlsValidationWithBothKeystoreAndTruststore() + { + MongoClientConfig config = new MongoClientConfig(); + configureTlsProperties(config); + + assertTrue(config.isValidTlsConfig(), "TLS config should be valid with complete keystore and truststore"); + } + + @Test + public void testTlsValidationFailsWhenDisabledWithProperties() + { + // Test that TLS properties cannot be set when TLS is disabled + MongoClientConfig config1 = new MongoClientConfig() + .setTlsEnabled(false) + 
.setKeystorePath(keystoreFile); + + assertFalse(config1.isValidTlsConfig(), + "TLS config should be invalid when TLS is disabled but keystore path is set"); + + MongoClientConfig config2 = new MongoClientConfig() + .setTlsEnabled(false) + .setTruststorePath(truststoreFile); + + assertFalse(config2.isValidTlsConfig(), + "TLS config should be invalid when TLS is disabled but truststore path is set"); + } + + private void configureTlsProperties(MongoClientConfig config) + { + config.setTlsEnabled(true) + .setKeystorePath(keystoreFile) + .setKeystorePassword(SSL_STORE_PASSWORD) + .setTruststorePath(truststoreFile) + .setTruststorePassword(SSL_STORE_PASSWORD); + } + + private SslContextProvider createSslContextProvider(MongoClientConfig config) + { + return new SslContextProvider( + config.getKeystorePath(), + config.getKeystorePassword(), + config.getTruststorePath(), + config.getTruststorePassword()); + } +} diff --git a/presto-mysql/pom.xml b/presto-mysql/pom.xml index 6c3ba6c74835a..ba570731426ae 100644 --- a/presto-mysql/pom.xml +++ b/presto-mysql/pom.xml @@ -5,15 +5,17 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-mysql + presto-mysql Presto - MySQL Connector presto-plugin ${project.parent.basedir} + true @@ -38,8 +40,8 @@ - javax.validation - validation-api + jakarta.validation + jakarta.validation-api @@ -48,8 +50,13 @@ - javax.inject - javax.inject + jakarta.inject + jakarta.inject-api + + + + com.esri.geometry + esri-geometry-api @@ -65,11 +72,6 @@ provided - - com.esri.geometry - esri-geometry-api - - org.openjdk.jol jol-core @@ -77,12 +79,7 @@ - com.facebook.presto - presto-geospatial-toolkit - - - - com.facebook.drift + com.facebook.airlift.drift drift-api provided @@ -94,7 +91,7 @@ - io.airlift + com.facebook.airlift units provided @@ -162,21 +159,20 @@ - com.facebook.presto - testing-mysql-server-8 + org.testcontainers + testcontainers test - com.facebook.presto - testing-mysql-server-base + org.testcontainers + mysql test 
org.jetbrains annotations - 19.0.0 test @@ -193,6 +189,15 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + com.esri.geometry:esri-geometry-api + + + diff --git a/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClient.java b/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClient.java index a2e2e487b1d38..419ddbeedad24 100644 --- a/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClient.java +++ b/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlClient.java @@ -13,7 +13,6 @@ */ package com.facebook.presto.plugin.mysql; -import com.esri.core.geometry.ogc.OGCGeometry; import com.facebook.presto.common.type.TimestampType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; @@ -28,7 +27,7 @@ import com.facebook.presto.plugin.jdbc.JdbcTableHandle; import com.facebook.presto.plugin.jdbc.JdbcTypeHandle; import com.facebook.presto.plugin.jdbc.QueryBuilder; -import com.facebook.presto.plugin.jdbc.ReadMapping; +import com.facebook.presto.plugin.jdbc.mapping.ReadMapping; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.PrestoException; @@ -36,9 +35,7 @@ import com.google.common.collect.ImmutableSet; import com.mysql.cj.jdbc.JdbcStatement; import com.mysql.jdbc.Driver; -import io.airlift.slice.Slice; - -import javax.inject.Inject; +import jakarta.inject.Inject; import java.sql.Connection; import java.sql.DatabaseMetaData; @@ -52,31 +49,22 @@ import java.util.Optional; import java.util.Properties; -import static com.esri.core.geometry.ogc.OGCGeometry.fromBinary; import static com.facebook.presto.common.type.RealType.REAL; import static com.facebook.presto.common.type.StandardTypes.GEOMETRY; import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; import static 
com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; -import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.common.type.Varchars.isVarcharType; -import static com.facebook.presto.geospatial.GeometryUtils.wktFromJtsGeometry; -import static com.facebook.presto.geospatial.serde.EsriGeometrySerde.serialize; -import static com.facebook.presto.geospatial.serde.JtsGeometrySerde.deserialize; import static com.facebook.presto.plugin.jdbc.DriverConnectionFactory.basicConnectionProperties; import static com.facebook.presto.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static com.facebook.presto.plugin.jdbc.QueryBuilder.quote; -import static com.facebook.presto.plugin.jdbc.ReadMapping.sliceReadMapping; +import static com.facebook.presto.plugin.jdbc.mapping.StandardColumnMappings.geometryReadMapping; import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.slice.Slices.utf8Slice; -import static io.airlift.slice.Slices.wrappedBuffer; import static java.lang.String.format; import static java.util.Locale.ENGLISH; -import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; public class MySqlClient @@ -85,9 +73,11 @@ public class MySqlClient /** * Error code corresponding to code thrown when a table already exists. * The code is derived from the MySQL documentation. 
+ * * @see MySQL documentation */ private static final String SQL_STATE_ER_TABLE_EXISTS_ERROR = "42S01"; + @Inject public MySqlClient(JdbcConnectorId connectorId, BaseJdbcConfig config, MySqlConfig mySqlConfig) throws SQLException @@ -219,7 +209,8 @@ protected String toSqlType(Type type) } @Override - public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, List columnHandles) throws SQLException + public PreparedStatement buildSql(ConnectorSession session, Connection connection, JdbcSplit split, List columnHandles) + throws SQLException { Map columnExpressions = columnHandles.stream() .filter(handle -> handle.getJdbcTypeHandle().getJdbcTypeName().equalsIgnoreCase(GEOMETRY)) @@ -253,31 +244,6 @@ public Optional toPrestoType(ConnectorSession session, JdbcTypeHand return super.toPrestoType(session, typeHandle); } - protected static ReadMapping geometryReadMapping() - { - return sliceReadMapping(VARCHAR, - (resultSet, columnIndex) -> getAsText(stGeomFromBinary(wrappedBuffer(resultSet.getBytes(columnIndex))))); - } - - protected static Slice getAsText(Slice input) - { - return utf8Slice(wktFromJtsGeometry(deserialize(input))); - } - - private static Slice stGeomFromBinary(Slice input) - { - requireNonNull(input, "input is null"); - OGCGeometry geometry; - try { - geometry = fromBinary(input.toByteBuffer().slice()); - } - catch (IllegalArgumentException | IndexOutOfBoundsException e) { - throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Invalid Well-Known Binary (WKB)", e); - } - geometry.setSpatialReference(null); - return serialize(geometry); - } - @Override public void createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata) { @@ -323,4 +289,10 @@ protected void renameTable(JdbcIdentity identity, String catalogName, SchemaTabl // catalogName parameter to null it will be omitted in the alter table statement. 
super.renameTable(identity, null, oldTable, newTable); } + + @Override + public String normalizeIdentifier(ConnectorSession session, String identifier) + { + return caseSensitiveNameMatchingEnabled ? identifier : identifier.toLowerCase(ENGLISH); + } } diff --git a/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlConfig.java b/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlConfig.java index 6425d3f57abc7..28c1278504e49 100644 --- a/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlConfig.java +++ b/presto-mysql/src/main/java/com/facebook/presto/plugin/mysql/MySqlConfig.java @@ -14,9 +14,8 @@ package com.facebook.presto.plugin.mysql; import com.facebook.airlift.configuration.Config; -import io.airlift.units.Duration; - -import javax.validation.constraints.Min; +import com.facebook.airlift.units.Duration; +import jakarta.validation.constraints.Min; import java.util.concurrent.TimeUnit; diff --git a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/MySqlQueryRunner.java b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/MySqlQueryRunner.java index ecc0ad9d19095..2a0fb593c80bd 100644 --- a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/MySqlQueryRunner.java +++ b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/MySqlQueryRunner.java @@ -15,10 +15,8 @@ import com.facebook.presto.Session; import com.facebook.presto.testing.QueryRunner; -import com.facebook.presto.testing.mysql.TestingMySqlServer; import com.facebook.presto.tests.DistributedQueryRunner; import com.facebook.presto.tpch.TpchPlugin; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.tpch.TpchTable; @@ -38,25 +36,13 @@ private MySqlQueryRunner() private static final String TPCH_SCHEMA = "tpch"; - public static QueryRunner createMySqlQueryRunner(TestingMySqlServer server, TpchTable... 
tables) + public static QueryRunner createMySqlQueryRunner(String jdbcUrl, Map connectorProperties, Iterable> tables) throws Exception { - return createMySqlQueryRunner(server, ImmutableMap.of(), ImmutableList.copyOf(tables)); + return createMySqlQueryRunner(jdbcUrl, connectorProperties, tables, "testuser", "testpass"); } - public static QueryRunner createMySqlQueryRunner(TestingMySqlServer server, Map connectorProperties, Iterable> tables) - throws Exception - { - try { - return createMySqlQueryRunner(server.getJdbcUrl(), connectorProperties, tables); - } - catch (Throwable e) { - closeAllSuppress(e, server); - throw e; - } - } - - public static QueryRunner createMySqlQueryRunner(String jdbcUrl, Map connectorProperties, Iterable> tables) + public static QueryRunner createMySqlQueryRunner(String jdbcUrl, Map connectorProperties, Iterable> tables, String username, String password) throws Exception { DistributedQueryRunner queryRunner = null; @@ -66,8 +52,12 @@ public static QueryRunner createMySqlQueryRunner(String jdbcUrl, Map mysqlContainer; private final QueryRunner mySqlQueryRunner; public TestCredentialPassthrough() throws Exception { - mysqlServer = new TestingMySqlServer("testuser", "testpass", TEST_SCHEMA); - mySqlQueryRunner = createQueryRunner(mysqlServer); + mysqlContainer = new MySQLContainer<>("mysql:8.0") + .withDatabaseName(TEST_SCHEMA) + .withUsername(TEST_USER) + .withPassword(TEST_PASSWORD); + mysqlContainer.start(); + + try (Connection connection = DriverManager.getConnection(mysqlContainer.getJdbcUrl(), TEST_USER, TEST_PASSWORD); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE IF NOT EXISTS " + TEST_SCHEMA); + } + catch (SQLException e) { + throw new RuntimeException("Failed to create " + TEST_SCHEMA, e); + } + + mySqlQueryRunner = createQueryRunner(mysqlContainer); + } + + @AfterClass(alwaysRun = true) + public void destroy() + { + if (mysqlContainer != null) { + mysqlContainer.stop(); + } } @Test 
public void testCredentialPassthrough() - throws Exception { - mySqlQueryRunner.execute(getSession(mysqlServer), "CREATE TABLE test_create (a bigint, b double, c varchar)"); + mySqlQueryRunner.execute(getSession(mysqlContainer), "CREATE TABLE test_create (a bigint, b double, c varchar)"); } - public static QueryRunner createQueryRunner(TestingMySqlServer mySqlServer) + public static QueryRunner createQueryRunner(MySQLContainer mysqlContainer) throws Exception { DistributedQueryRunner queryRunner = null; @@ -56,7 +85,7 @@ public static QueryRunner createQueryRunner(TestingMySqlServer mySqlServer) queryRunner = DistributedQueryRunner.builder(testSessionBuilder().build()).build(); queryRunner.installPlugin(new MySqlPlugin()); Map properties = ImmutableMap.builder() - .put("connection-url", getConnectionUrl(mySqlServer)) + .put("connection-url", getConnectionUrl(mysqlContainer)) .put("user-credential-name", "mysql.user") .put("password-credential-name", "mysql.password") .build(); @@ -65,19 +94,19 @@ public static QueryRunner createQueryRunner(TestingMySqlServer mySqlServer) return queryRunner; } catch (Exception e) { - closeAllSuppress(e, queryRunner, mySqlServer); + closeAllSuppress(e, queryRunner); throw e; } } - private static Session getSession(TestingMySqlServer mySqlServer) + private static Session getSession(MySQLContainer mysqlContainer) { - Map extraCredentials = ImmutableMap.of("mysql.user", mySqlServer.getUser(), "mysql.password", mySqlServer.getPassword()); + Map extraCredentials = ImmutableMap.of("mysql.user", mysqlContainer.getUsername(), "mysql.password", mysqlContainer.getPassword()); return testSessionBuilder() .setCatalog("mysql") .setSchema(TEST_SCHEMA) .setIdentity(new Identity( - mySqlServer.getUser(), + mysqlContainer.getUsername(), Optional.empty(), ImmutableMap.of(), extraCredentials, @@ -87,8 +116,9 @@ private static Session getSession(TestingMySqlServer mySqlServer) .build(); } - private static String getConnectionUrl(TestingMySqlServer 
mySqlServer) + private static String getConnectionUrl(MySQLContainer mysqlContainer) { - return format("jdbc:mysql://localhost:%s?useSSL=false&allowPublicKeyRetrieval=true", mySqlServer.getPort()); + String jdbcUrlWithoutDatabase = removeDatabaseFromJdbcUrl(mysqlContainer.getJdbcUrl()); + return format("%s?useSSL=false&allowPublicKeyRetrieval=true", jdbcUrlWithoutDatabase.split("\\?")[0]); } } diff --git a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlClient.java b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlClient.java index 6ccb43cc207ba..10096bde0fba5 100644 --- a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlClient.java +++ b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlClient.java @@ -16,7 +16,7 @@ import com.esri.core.geometry.Point; import com.esri.core.geometry.ogc.OGCGeometry; import com.esri.core.geometry.ogc.OGCPoint; -import com.facebook.presto.plugin.jdbc.SliceReadFunction; +import com.facebook.presto.plugin.jdbc.mapping.functions.SliceReadFunction; import com.facebook.presto.spi.PrestoException; import io.airlift.slice.Slice; import io.airlift.slice.Slices; @@ -27,8 +27,8 @@ import java.sql.SQLException; import static com.facebook.presto.geospatial.GeoFunctions.stGeomFromBinary; -import static com.facebook.presto.plugin.mysql.MySqlClient.geometryReadMapping; -import static com.facebook.presto.plugin.mysql.MySqlClient.getAsText; +import static com.facebook.presto.plugin.jdbc.GeometryUtils.getAsText; +import static com.facebook.presto.plugin.jdbc.mapping.StandardColumnMappings.geometryReadMapping; import static org.testng.Assert.assertEquals; import static org.testng.Assert.fail; diff --git a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlConfig.java b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlConfig.java index 543145fe12efb..1cda3184aa235 100644 --- 
a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlConfig.java +++ b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlConfig.java @@ -13,8 +13,8 @@ */ package com.facebook.presto.plugin.mysql; +import com.facebook.airlift.units.Duration; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.testng.annotations.Test; import java.util.Map; diff --git a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlDistributedQueries.java b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlDistributedQueries.java index 5f01ee7b22997..9cec5920c9be1 100644 --- a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlDistributedQueries.java +++ b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlDistributedQueries.java @@ -15,45 +15,40 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; -import com.facebook.presto.testing.mysql.MySqlOptions; -import com.facebook.presto.testing.mysql.TestingMySqlServer; import com.facebook.presto.tests.AbstractTestDistributedQueries; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.tpch.TpchTable; -import io.airlift.units.Duration; +import org.testcontainers.containers.MySQLContainer; import org.testng.annotations.AfterClass; +import org.testng.annotations.Optional; import org.testng.annotations.Test; -import java.io.IOException; - +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner; import static com.facebook.presto.testing.MaterializedResult.resultBuilder; import static com.facebook.presto.testing.assertions.Assert.assertEquals; -import static java.util.concurrent.TimeUnit.SECONDS; @Test(singleThreaded = true) public 
class TestMySqlDistributedQueries extends AbstractTestDistributedQueries { - private static final MySqlOptions MY_SQL_OPTIONS = MySqlOptions.builder() - .setCommandTimeout(new Duration(90, SECONDS)) - .build(); - - private final TestingMySqlServer mysqlServer; + private final MySQLContainer mysqlContainer; public TestMySqlDistributedQueries() - throws Exception { - this.mysqlServer = new TestingMySqlServer("testuser", "testpass", ImmutableList.of("tpch"), MY_SQL_OPTIONS); + this.mysqlContainer = new MySQLContainer<>("mysql:8.0") + .withDatabaseName("tpch") + .withUsername("testuser") + .withPassword("testpass"); + this.mysqlContainer.start(); } @Override protected QueryRunner createQueryRunner() throws Exception { - return createMySqlQueryRunner(mysqlServer, ImmutableMap.of(), TpchTable.getTables()); + return createMySqlQueryRunner(mysqlContainer.getJdbcUrl(), ImmutableMap.of(), TpchTable.getTables()); } @Override @@ -64,26 +59,25 @@ protected boolean supportsViews() @AfterClass(alwaysRun = true) public final void destroy() - throws IOException { - mysqlServer.close(); + mysqlContainer.stop(); } @Override - public void testShowColumns() + public void testShowColumns(@Optional("PARQUET") String storageFormat) { - MaterializedResult actual = computeActual("SHOW COLUMNS FROM orders"); - - MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("orderkey", "bigint", "", "") - .row("custkey", "bigint", "", "") - .row("orderstatus", "varchar(255)", "", "") - .row("totalprice", "double", "", "") - .row("orderdate", "date", "", "") - .row("orderpriority", "varchar(255)", "", "") - .row("clerk", "varchar(255)", "", "") - .row("shippriority", "integer", "", "") - .row("comment", "varchar(255)", "", "") + MaterializedResult actual = computeActual("SHOW COLUMNS FROM orders").toTestTypes(); + + MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, 
BIGINT, BIGINT) + .row("orderkey", "bigint", "", "", 19L, null, null) + .row("custkey", "bigint", "", "", 19L, null, null) + .row("orderstatus", "varchar(255)", "", "", null, null, 255L) + .row("totalprice", "double", "", "", 53L, null, null) + .row("orderdate", "date", "", "", null, null, null) + .row("orderpriority", "varchar(255)", "", "", null, null, 255L) + .row("clerk", "varchar(255)", "", "", null, null, 255L) + .row("shippriority", "integer", "", "", 10L, null, null) + .row("comment", "varchar(255)", "", "", null, null, 255L) .build(); assertEquals(actual, expectedParametrizedVarchar); @@ -120,5 +114,53 @@ public void testUpdate() // Updates are not supported by the connector } + @Override + public void testNonAutoCommitTransactionWithRollback() + { + // JDBC connectors do not support multi-statement writes within transactions + } + + @Override + public void testNonAutoCommitTransactionWithCommit() + { + // JDBC connectors do not support multi-statement writes within transactions + } + + @Override + public void testNonAutoCommitTransactionWithFailAndRollback() + { + // JDBC connectors do not support multi-statement writes within transactions + } + + @Override + public void testPayloadJoinApplicability() + { + // MySQL does not support MAP type + } + + @Override + public void testPayloadJoinCorrectness() + { + // MySQL does not support MAP type + } + + @Override + public void testRemoveRedundantCastToVarcharInJoinClause() + { + // MySQL does not support MAP type + } + + @Override + public void testSubfieldAccessControl() + { + // MySQL does not support ROW type + } + + @Override + public void testStringFilters() + { + // MySQL maps char types to varchar(255), causing type mismatches + } + // MySQL specific tests should normally go in TestMySqlIntegrationSmokeTest } diff --git a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlIntegrationMixedCaseTest.java 
b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlIntegrationMixedCaseTest.java new file mode 100644 index 0000000000000..26f8d6f277e99 --- /dev/null +++ b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlIntegrationMixedCaseTest.java @@ -0,0 +1,208 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.plugin.mysql; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.google.common.collect.ImmutableMap; +import io.airlift.tpch.TpchTable; +import org.testcontainers.containers.MySQLContainer; +import org.testng.annotations.AfterClass; +import org.testng.annotations.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +@Test +public class TestMySqlIntegrationMixedCaseTest + extends 
AbstractTestQueryFramework +{ + private final MySQLContainer mysqlContainer; + + public TestMySqlIntegrationMixedCaseTest() + throws Exception + { + this.mysqlContainer = new MySQLContainer<>("mysql:8.0") + .withDatabaseName("tpch") + .withUsername("testuser") + .withPassword("testpass"); + this.mysqlContainer.start(); + + mysqlContainer.execInContainer("mysql", + "-u", "root", + "-p" + mysqlContainer.getPassword(), + "-e", "CREATE DATABASE IF NOT EXISTS Mixed_Test_Database; GRANT ALL PRIVILEGES ON Mixed_Test_Database.* TO 'testuser'@'%';"); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createMySqlQueryRunner(mysqlContainer.getJdbcUrl(), ImmutableMap.of("case-sensitive-name-matching", "true"), TpchTable.getTables()); + } + + @AfterClass(alwaysRun = true) + public final void destroy() + { + mysqlContainer.stop(); + } + + public void testDescribeTable() + { + // CI tests run on Linux, where MySQL is case-sensitive by default (lower_case_table_names=0), + // treating "orders" and "ORDERS" as different tables. + // Since the test runs with case-sensitive-name-matching=true, ensure "ORDERS" exists if not already present. + try { + execute("CREATE TABLE IF NOT EXISTS tpch.ORDERS AS SELECT * FROM tpch.orders"); + } + catch (SQLException e) { + throw new RuntimeException(e); + } + + // we need specific implementation of this tests due to specific Presto<->Mysql varchar length mapping. 
+ MaterializedResult actualColumns = computeActual("DESC ORDERS").toTestTypes(); + + MaterializedResult expectedColumns = MaterializedResult.resultBuilder(getQueryRunner().getDefaultSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("orderkey", "bigint", "", "", 19L, null, null) + .row("custkey", "bigint", "", "", 19L, null, null) + .row("orderstatus", "varchar(255)", "", "", null, null, 255L) + .row("totalprice", "double", "", "", 53L, null, null) + .row("orderdate", "date", "", "", null, null, null) + .row("orderpriority", "varchar(255)", "", "", null, null, 255L) + .row("clerk", "varchar(255)", "", "", null, null, 255L) + .row("shippriority", "integer", "", "", 10L, null, null) + .row("comment", "varchar(255)", "", "", null, null, 255L) + .build(); + assertEquals(actualColumns, expectedColumns); + } + + @Test + public void testCreateTable() + { + Session session = testSessionBuilder() + .setCatalog("mysql") + .setSchema("Mixed_Test_Database") + .build(); + + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATE(name VARCHAR(50), id int)"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATE")); + + getQueryRunner().execute(session, "CREATE TABLE IF NOT EXISTS test_create(name VARCHAR(50), id int)"); + assertTrue(getQueryRunner().tableExists(session, "test_create")); + + assertUpdate(session, "DROP TABLE IF EXISTS TEST_CREATE"); + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATE")); + + assertUpdate(session, "DROP TABLE IF EXISTS test_create"); + assertFalse(getQueryRunner().tableExists(session, "test_create")); + } + + @Test + public void testCreateTableAs() + { + Session session = testSessionBuilder() + .setCatalog("mysql") + .setSchema("Mixed_Test_Database") + .build(); + + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATEAS AS SELECT * FROM tpch.region"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATEAS")); + + getQueryRunner().execute(session, "CREATE TABLE IF NOT EXISTS 
test_createas AS SELECT * FROM tpch.region"); + assertTrue(getQueryRunner().tableExists(session, "test_createas")); + + getQueryRunner().execute(session, "CREATE TABLE TEST_CREATEAS_Join AS SELECT c.custkey, o.orderkey FROM " + + "tpch.customer c INNER JOIN tpch.orders o ON c.custkey = o.custkey WHERE c.mktsegment = 'BUILDING'"); + assertTrue(getQueryRunner().tableExists(session, "TEST_CREATEAS_Join")); + + assertQueryFails("CREATE TABLE Mixed_Test_Database.TEST_CREATEAS_FAIL_Join AS SELECT c.custkey, o.orderkey FROM " + + "tpch.customer c INNER JOIN tpch.ORDERS1 o ON c.custkey = o.custkey WHERE c.mktsegment = 'BUILDING'", "Table mysql.tpch.ORDERS1 does not exist"); //failure scenario since tpch.ORDERS1 doesn't exist + assertFalse(getQueryRunner().tableExists(session, "TEST_CREATEAS_FAIL_Join")); + + getQueryRunner().execute(session, "CREATE TABLE Test_CreateAs_Mixed_Join AS SELECT Cus.custkey, Ord.orderkey FROM " + + "tpch.customer Cus INNER JOIN tpch.orders Ord ON Cus.custkey = Ord.custkey WHERE Cus.mktsegment = 'BUILDING'"); + assertTrue(getQueryRunner().tableExists(session, "Test_CreateAs_Mixed_Join")); + } + + @Test + public void testInsert() + { + Session session = testSessionBuilder() + .setCatalog("mysql") + .setSchema("Mixed_Test_Database") + .build(); + + getQueryRunner().execute(session, "CREATE TABLE Test_Insert (x bigint, y varchar(100))"); + getQueryRunner().execute(session, "INSERT INTO Test_Insert VALUES (123, 'test')"); + assertTrue(getQueryRunner().tableExists(session, "Test_Insert")); + assertQuery("SELECT * FROM Mixed_Test_Database.Test_Insert", "SELECT 123 x, 'test' y"); + + getQueryRunner().execute(session, "CREATE TABLE IF NOT EXISTS TEST_INSERT (x bigint, y varchar(100))"); + getQueryRunner().execute(session, "INSERT INTO TEST_INSERT VALUES (1234, 'test1')"); + assertTrue(getQueryRunner().tableExists(session, "TEST_INSERT")); + + getQueryRunner().execute(session, "DROP TABLE IF EXISTS Test_Insert"); + getQueryRunner().execute(session, "DROP 
TABLE IF EXISTS TEST_INSERT"); + } + + @Test + public void testSelectInformationSchemaColumnIsNullable() + { + assertUpdate("CREATE TABLE test_column (name VARCHAR NOT NULL, email VARCHAR)"); + assertQueryFails("SELECT is_nullable FROM Information_Schema.columns WHERE table_name = 'test_column'", "Schema Information_Schema does not exist"); + assertQuery("SELECT is_nullable FROM information_schema.columns WHERE table_name = 'test_column'", "VALUES 'NO','YES'"); + } + + @Test + public void testDuplicatedRowCreateTable() + { + assertQueryFails("CREATE TABLE test (a integer, a integer)", + "line 1:31: Column name 'a' specified more than once"); + assertQueryFails("CREATE TABLE TEST (a integer, a integer)", + "line 1:31: Column name 'a' specified more than once"); + assertQueryFails("CREATE TABLE test (a integer, orderkey integer, LIKE orders INCLUDING PROPERTIES)", + "line 1:49: Column name 'orderkey' specified more than once"); + + assertQueryFails("CREATE TABLE test (a integer, A integer)", + "Duplicate column name 'A'"); + assertQueryFails("CREATE TABLE TEST (a integer, A integer)", + "Duplicate column name 'A'"); + assertQueryFails("CREATE TABLE test (a integer, OrderKey integer, LIKE orders INCLUDING PROPERTIES)", + "Duplicate column name 'orderkey'"); + } + + private void execute(String sql) + throws SQLException + { + try (Connection connection = DriverManager.getConnection( + mysqlContainer.getJdbcUrl(), + mysqlContainer.getUsername(), + mysqlContainer.getPassword()); + Statement statement = connection.createStatement()) { + statement.execute(sql); + } + } +} diff --git a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlIntegrationSmokeTest.java b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlIntegrationSmokeTest.java index 789f934b9607f..ec32ea5594b30 100644 --- a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlIntegrationSmokeTest.java +++ 
b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlIntegrationSmokeTest.java @@ -17,32 +17,30 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.MaterializedRow; import com.facebook.presto.testing.QueryRunner; -import com.facebook.presto.testing.mysql.MySqlOptions; -import com.facebook.presto.testing.mysql.TestingMySqlServer; import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; import com.facebook.presto.tests.DistributedQueryRunner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; import org.intellij.lang.annotations.Language; +import org.testcontainers.containers.MySQLContainer; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; -import java.io.IOException; import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; import java.sql.Statement; import java.util.Map; +import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner; +import static com.facebook.presto.plugin.mysql.MySqlQueryRunner.removeDatabaseFromJdbcUrl; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static com.facebook.presto.testing.assertions.Assert.assertEquals; import static com.google.common.collect.Iterables.getOnlyElement; import static io.airlift.tpch.TpchTable.ORDERS; import static java.lang.String.format; -import static java.util.concurrent.TimeUnit.SECONDS; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -50,30 +48,34 @@ public class TestMySqlIntegrationSmokeTest extends AbstractTestIntegrationSmokeTest { - private static final MySqlOptions MY_SQL_OPTIONS = MySqlOptions.builder() - .setCommandTimeout(new Duration(90, SECONDS)) - .build(); - - 
private final TestingMySqlServer mysqlServer; + private final MySQLContainer mysqlContainer; public TestMySqlIntegrationSmokeTest() throws Exception { - this.mysqlServer = new TestingMySqlServer("testuser", "testpass", ImmutableList.of("tpch", "test_database"), MY_SQL_OPTIONS); + this.mysqlContainer = new MySQLContainer<>("mysql:8.0") + .withDatabaseName("tpch") + .withUsername("testuser") + .withPassword("testpass"); + this.mysqlContainer.start(); + + mysqlContainer.execInContainer("mysql", + "-u", "root", + "-p" + mysqlContainer.getPassword(), + "-e", "CREATE DATABASE IF NOT EXISTS test_database; GRANT ALL PRIVILEGES ON test_database.* TO 'testuser'@'%';"); } @Override protected QueryRunner createQueryRunner() throws Exception { - return createMySqlQueryRunner(mysqlServer, ORDERS); + return createMySqlQueryRunner(mysqlContainer.getJdbcUrl(), ImmutableMap.of(), ImmutableList.of(ORDERS)); } @AfterClass(alwaysRun = true) public final void destroy() - throws IOException { - mysqlServer.close(); + mysqlContainer.stop(); } @Override @@ -82,16 +84,16 @@ public void testDescribeTable() // we need specific implementation of this tests due to specific Presto<->Mysql varchar length mapping. 
MaterializedResult actualColumns = computeActual("DESC ORDERS").toTestTypes(); - MaterializedResult expectedColumns = MaterializedResult.resultBuilder(getQueryRunner().getDefaultSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("orderkey", "bigint", "", "") - .row("custkey", "bigint", "", "") - .row("orderstatus", "varchar(255)", "", "") - .row("totalprice", "double", "", "") - .row("orderdate", "date", "", "") - .row("orderpriority", "varchar(255)", "", "") - .row("clerk", "varchar(255)", "", "") - .row("shippriority", "integer", "", "") - .row("comment", "varchar(255)", "", "") + MaterializedResult expectedColumns = MaterializedResult.resultBuilder(getQueryRunner().getDefaultSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("orderkey", "bigint", "", "", 19L, null, null) + .row("custkey", "bigint", "", "", 19L, null, null) + .row("orderstatus", "varchar(255)", "", "", null, null, 255L) + .row("totalprice", "double", "", "", 53L, null, null) + .row("orderdate", "date", "", "", null, null, null) + .row("orderpriority", "varchar(255)", "", "", null, null, 255L) + .row("clerk", "varchar(255)", "", "", null, null, 255L) + .row("shippriority", "integer", "", "", 10L, null, null) + .row("comment", "varchar(255)", "", "", null, null, 255L) .build(); assertEquals(actualColumns, expectedColumns); } @@ -158,8 +160,8 @@ public void testMySqlTinyint1() execute("CREATE TABLE tpch.mysql_test_tinyint1 (c_tinyint tinyint(1))"); MaterializedResult actual = computeActual("SHOW COLUMNS FROM mysql_test_tinyint1"); - MaterializedResult expected = MaterializedResult.resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) - .row("c_tinyint", "tinyint", "", "") + MaterializedResult expected = MaterializedResult.resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR, BIGINT, BIGINT, BIGINT) + .row("c_tinyint", "tinyint", "", "", 3L, null, null) .build(); assertEquals(actual, expected); @@ -218,6 +220,46 @@ public void testMysqlGeometry() 
assertUpdate("DROP TABLE tpch.test_geometry"); } + @Test + public void testMysqlDecimal() + { + assertUpdate("CREATE TABLE test_decimal (d DECIMAL(10, 2))"); + + assertUpdate("INSERT INTO test_decimal VALUES (123.45)", 1); + assertUpdate("INSERT INTO test_decimal VALUES (67890.12)", 1); + assertUpdate("INSERT INTO test_decimal VALUES (0.99)", 1); + + assertQuery( + "SELECT d FROM test_decimal WHERE d<200.00 AND d>0.00", + "VALUES " + + "CAST('123.45' AS DECIMAL), " + + "CAST('0.99' AS DECIMAL)"); + + assertUpdate("DROP TABLE test_decimal"); + } + + @Test + public void testMysqlTime() + { + assertUpdate("CREATE TABLE test_time (datatype_time time)"); + + assertUpdate("INSERT INTO test_time VALUES (time '01:02:03.456')", 1); + + assertQuery( + "SELECT datatype_time FROM test_time", + "VALUES " + + "CAST('01:02:03.456' AS time)"); + + assertUpdate("DROP TABLE test_time"); + } + + @Test + public void testMysqlUnsupportedTimeTypes() + { + assertQueryFails("CREATE TABLE test_timestamp_with_timezone (timestamp_with_time_zone timestamp with time zone)", "Unsupported column type: timestamp with time zone"); + assertQueryFails("CREATE TABLE test_time_with_timezone (time_with_with_time_zone time with time zone)", "Unsupported column type: time with time zone"); + } + @Test public void testCharTrailingSpace() throws Exception @@ -232,7 +274,13 @@ public void testCharTrailingSpace() assertEquals(getQueryRunner().execute("SELECT * FROM char_trailing_space WHERE x = char ' test'").getRowCount(), 0); Map properties = ImmutableMap.of("deprecated.legacy-char-to-varchar-coercion", "true"); - Map connectorProperties = ImmutableMap.of("connection-url", mysqlServer.getJdbcUrl()); + String jdbcUrlWithoutDatabase = removeDatabaseFromJdbcUrl(mysqlContainer.getJdbcUrl()); + String jdbcUrlWithCredentials = format("%s%suser=%s&password=%s", + jdbcUrlWithoutDatabase, + jdbcUrlWithoutDatabase.contains("?") ? 
"&" : "?", + mysqlContainer.getUsername(), + mysqlContainer.getPassword()); + Map connectorProperties = ImmutableMap.of("connection-url", jdbcUrlWithCredentials); try (QueryRunner queryRunner = new DistributedQueryRunner(getSession(), 3, properties)) { queryRunner.installPlugin(new MySqlPlugin()); @@ -321,7 +369,10 @@ public void testColumnComment() private void execute(String sql) throws SQLException { - try (Connection connection = DriverManager.getConnection(mysqlServer.getJdbcUrl()); + try (Connection connection = DriverManager.getConnection( + mysqlContainer.getJdbcUrl(), + mysqlContainer.getUsername(), + mysqlContainer.getPassword()); Statement statement = connection.createStatement()) { statement.execute(sql); } diff --git a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlTypeMapping.java b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlTypeMapping.java index 92f3550f75784..9adc4bfe1169f 100644 --- a/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlTypeMapping.java +++ b/presto-mysql/src/test/java/com/facebook/presto/plugin/mysql/TestMySqlTypeMapping.java @@ -17,8 +17,6 @@ import com.facebook.presto.common.type.TimeZoneKey; import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.testing.QueryRunner; -import com.facebook.presto.testing.mysql.MySqlOptions; -import com.facebook.presto.testing.mysql.TestingMySqlServer; import com.facebook.presto.tests.AbstractTestQueryFramework; import com.facebook.presto.tests.datatype.CreateAndInsertDataSetup; import com.facebook.presto.tests.datatype.CreateAsSelectDataSetup; @@ -28,12 +26,15 @@ import com.facebook.presto.tests.sql.PrestoSqlExecutor; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.units.Duration; +import org.testcontainers.containers.MySQLContainer; import org.testng.annotations.AfterClass; import org.testng.annotations.Test; -import java.io.IOException; import 
java.math.BigDecimal; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; import java.time.LocalDate; import java.time.ZoneId; @@ -56,37 +57,37 @@ import static com.google.common.base.Strings.repeat; import static com.google.common.base.Verify.verify; import static java.lang.String.format; -import static java.util.concurrent.TimeUnit.SECONDS; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; @Test public class TestMySqlTypeMapping extends AbstractTestQueryFramework { private static final String CHARACTER_SET_UTF8 = "CHARACTER SET utf8"; - private static final MySqlOptions MY_SQL_OPTIONS = MySqlOptions.builder() - .setCommandTimeout(new Duration(90, SECONDS)) - .build(); - private final TestingMySqlServer mysqlServer; + private final MySQLContainer mysqlContainer; public TestMySqlTypeMapping() - throws Exception { - this.mysqlServer = new TestingMySqlServer("testuser", "testpass", ImmutableList.of("tpch"), MY_SQL_OPTIONS); + this.mysqlContainer = new MySQLContainer<>("mysql:8.0") + .withDatabaseName("tpch") + .withUsername("testuser") + .withPassword("testpass"); + this.mysqlContainer.start(); } @Override protected QueryRunner createQueryRunner() throws Exception { - return createMySqlQueryRunner(mysqlServer, ImmutableMap.of(), ImmutableList.of()); + return createMySqlQueryRunner(mysqlContainer.getJdbcUrl(), ImmutableMap.of(), ImmutableList.of()); } @AfterClass(alwaysRun = true) public final void destroy() - throws IOException { - mysqlServer.close(); + mysqlContainer.stop(); } @Test @@ -230,8 +231,8 @@ public void testDate() ZoneId jvmZone = ZoneId.systemDefault(); checkState(jvmZone.getId().equals("America/Bahia_Banderas"), "This test assumes certain JVM time zone"); - LocalDate dateOfLocalTimeChangeForwardAtMidnightInJvmZone = LocalDate.of(1970, 1, 1); - 
verify(jvmZone.getRules().getValidOffsets(dateOfLocalTimeChangeForwardAtMidnightInJvmZone.atStartOfDay()).isEmpty()); + LocalDate dateOfLocalTimeChangeForwardAtHour2InJvmZone = LocalDate.of(2012, 4, 1); + verify(jvmZone.getRules().getValidOffsets(dateOfLocalTimeChangeForwardAtHour2InJvmZone.atTime(2, 1)).isEmpty()); ZoneId someZone = ZoneId.of("Europe/Vilnius"); LocalDate dateOfLocalTimeChangeForwardAtMidnightInSomeZone = LocalDate.of(1983, 4, 1); @@ -240,12 +241,12 @@ public void testDate() verify(someZone.getRules().getValidOffsets(dateOfLocalTimeChangeBackwardAtMidnightInSomeZone.atStartOfDay().minusMinutes(1)).size() == 2); DataTypeTest testCases = DataTypeTest.create() - .addRoundTrip(dateDataType(), LocalDate.of(1952, 4, 3)) // before epoch + .addRoundTrip(dateDataType(), LocalDate.of(1937, 4, 3)) // before epoch .addRoundTrip(dateDataType(), LocalDate.of(1970, 1, 1)) .addRoundTrip(dateDataType(), LocalDate.of(1970, 2, 3)) .addRoundTrip(dateDataType(), LocalDate.of(2017, 7, 1)) // summer on northern hemisphere (possible DST) .addRoundTrip(dateDataType(), LocalDate.of(2017, 1, 1)) // winter on northern hemisphere (possible DST on southern hemisphere) - .addRoundTrip(dateDataType(), dateOfLocalTimeChangeForwardAtMidnightInJvmZone) + .addRoundTrip(dateDataType(), dateOfLocalTimeChangeForwardAtHour2InJvmZone) .addRoundTrip(dateDataType(), dateOfLocalTimeChangeForwardAtMidnightInSomeZone) .addRoundTrip(dateDataType(), dateOfLocalTimeChangeBackwardAtMidnightInSomeZone); @@ -259,20 +260,165 @@ public void testDate() } @Test - public void testDatetime() + public void testDatetimeUnderlyingStorageVerification() + throws Exception { - // TODO MySQL datetime is not correctly read (see comment in StandardReadMappings.timestampReadMapping), but testing this is hard because of #7122 + String jdbcUrl = mysqlContainer.getJdbcUrl(); + String jdbcUrlWithCredentials = format("%s%suser=%s&password=%s", + jdbcUrl, + jdbcUrl.contains("?") ? 
"&" : "?", + mysqlContainer.getUsername(), + mysqlContainer.getPassword()); + JdbcSqlExecutor jdbcExecutor = new JdbcSqlExecutor(jdbcUrlWithCredentials); + + try { + jdbcExecutor.execute("CREATE TABLE tpch.test_datetime_storage (" + + "id INT PRIMARY KEY, " + + "dt DATETIME(6), " + + "source VARCHAR(10))"); + + // MySQL insertion, MySQL retrieval, and Presto retrieval all agree on wall clock time + jdbcExecutor.execute("INSERT INTO tpch.test_datetime_storage VALUES (1, '1970-01-01 00:00:00.000000', 'jdbc')"); + + try (Connection conn = DriverManager.getConnection(jdbcUrlWithCredentials); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT CAST(dt AS CHAR) FROM tpch.test_datetime_storage WHERE id = 1")) { + assertTrue(rs.next(), "Expected one row"); + String dbValue1 = rs.getString(1); + assertEquals(dbValue1, "1970-01-01 00:00:00.000000", "JDBC insert should store wall clock time 1970-01-01 00:00:00 in DB"); + } + + Session session = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty("legacy_timestamp", "false") + .build(); + assertQuery(session, + "SELECT dt FROM mysql.tpch.test_datetime_storage WHERE id = 1", + "VALUES TIMESTAMP '1970-01-01 00:00:00.000000'"); + + // Presto insertion, retrieval via MySQL, and retrieval via Presto all agree on wall clock time + assertUpdate(session, "INSERT INTO mysql.tpch.test_datetime_storage VALUES (2, TIMESTAMP '2023-06-15 14:30:00.000000', 'presto')", 1); + + try (Connection conn = DriverManager.getConnection(jdbcUrlWithCredentials); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT CAST(dt AS CHAR) FROM tpch.test_datetime_storage WHERE id = 2")) { + assertTrue(rs.next(), "Expected one row"); + String dbValue2 = rs.getString(1); + assertEquals(dbValue2, "2023-06-15 14:30:00.000000", "Presto insert should store wall clock time 2023-06-15 14:30:00 in DB"); + } + + assertQuery(session, + "SELECT dt FROM mysql.tpch.test_datetime_storage 
WHERE id = 2", + "VALUES TIMESTAMP '2023-06-15 14:30:00.000000'"); + + for (String timeZoneId : ImmutableList.of("UTC", "America/New_York", "Asia/Tokyo", "Europe/Warsaw")) { + Session sessionWithTimezone = Session.builder(getQueryRunner().getDefaultSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey(timeZoneId)) + .setSystemProperty("legacy_timestamp", "false") + .build(); + + assertQuery(sessionWithTimezone, + "SELECT dt FROM mysql.tpch.test_datetime_storage WHERE id = 1", + "VALUES TIMESTAMP '1970-01-01 00:00:00.000000'"); + + assertQuery(sessionWithTimezone, + "SELECT dt FROM mysql.tpch.test_datetime_storage WHERE id = 2", + "VALUES TIMESTAMP '2023-06-15 14:30:00.000000'"); + } + } + finally { + jdbcExecutor.execute("DROP TABLE IF EXISTS tpch.test_datetime_storage"); + } } @Test - public void testTimestamp() + public void testDatetimeLegacyUnderlyingStorageVerification() + throws Exception { - // TODO MySQL timestamp is not correctly read (see comment in StandardReadMappings.timestampReadMapping), but testing this is hard because of #7122 + String jdbcUrl = mysqlContainer.getJdbcUrl(); + String jdbcUrlWithCredentials = format("%s%suser=%s&password=%s", + jdbcUrl, + jdbcUrl.contains("?") ? 
"&" : "?", + mysqlContainer.getUsername(), + mysqlContainer.getPassword()); + JdbcSqlExecutor jdbcExecutor = new JdbcSqlExecutor(jdbcUrlWithCredentials); + + try { + jdbcExecutor.execute("CREATE TABLE tpch.test_datetime_legacy_storage (" + + "id INT PRIMARY KEY, " + + "dt DATETIME(6), " + + "source VARCHAR(10))"); + + // MySQL insertion and MySQL retrieval agree, Presto incorrectly interprets DB value due to legacy mode + jdbcExecutor.execute("INSERT INTO tpch.test_datetime_legacy_storage VALUES (1, '1970-01-01 00:00:00.000000', 'jdbc')"); + + // Prove that the value is 1970-01-01 00:00:00 by reading directly from the DB via JDBC + try (Connection conn = DriverManager.getConnection(jdbcUrlWithCredentials); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT CAST(dt AS CHAR) FROM tpch.test_datetime_legacy_storage WHERE id = 1")) { + assertTrue(rs.next(), "Expected one row"); + String dbValue1 = rs.getString(1); + assertEquals(dbValue1, "1970-01-01 00:00:00.000000", "JDBC insert should store wall clock time 1970-01-01 00:00:00 in DB"); + } + + // In legacy mode, DB value 1970-01-01 00:00:00 is interpreted as if it's in JVM timezone (America/Bahia_Banderas UTC-7) + // and then converted to the session timezone. 
Since both are the same (America/Bahia_Banderas), + // the offset comes from treating the wall-clock DB time as UTC, resulting in 1969-12-31 20:00:00 + Session legacySession = Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty("legacy_timestamp", "true") + .build(); + assertQuery(legacySession, + "SELECT dt FROM mysql.tpch.test_datetime_legacy_storage WHERE id = 1", + "VALUES TIMESTAMP '1969-12-31 20:00:00.000000'"); + + // Presto insertion with legacy mode, verify DB storage via JDBC (should apply JVM timezone conversion during write) + assertUpdate(legacySession, "INSERT INTO mysql.tpch.test_datetime_legacy_storage VALUES (2, TIMESTAMP '2023-06-15 14:30:00.000000', 'presto')", 1); + + try (Connection conn = DriverManager.getConnection(jdbcUrlWithCredentials); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT CAST(dt AS CHAR) FROM tpch.test_datetime_legacy_storage WHERE id = 2")) { + assertTrue(rs.next(), "Expected one row"); + String dbValue2 = rs.getString(1); + // JVM timezone is America/Bahia_Banderas (UTC-7), so 2023-06-15 14:30:00 becomes 2023-06-14 19:30:00.000000 + assertEquals(dbValue2, "2023-06-14 19:30:00.000000", "Legacy mode applies timezone conversion during write, expected 2023-06-14 19:30:00.000000"); + } + + // Verify Presto reads it back correctly in legacy mode (round-trip should work) + assertQuery(legacySession, + "SELECT dt FROM mysql.tpch.test_datetime_legacy_storage WHERE id = 2", + "VALUES TIMESTAMP '2023-06-15 14:30:00.000000'"); + + // DB value 1970-01-01 00:00:00 is interpreted as JVM timezone (America/Bahia_Banderas UTC-7), + // then converted to the session timezone + Session legacyUtcSession = Session.builder(getQueryRunner().getDefaultSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("UTC")) + .setSystemProperty("legacy_timestamp", "true") + .build(); + assertQuery(legacyUtcSession, + "SELECT dt FROM mysql.tpch.test_datetime_legacy_storage WHERE id = 1", + "VALUES 
TIMESTAMP '1970-01-01 07:00:00.000000'"); + + Session legacyTokyoSession = Session.builder(getQueryRunner().getDefaultSession()) + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("Asia/Tokyo")) + .setSystemProperty("legacy_timestamp", "true") + .build(); + assertQuery(legacyTokyoSession, + "SELECT dt FROM mysql.tpch.test_datetime_legacy_storage WHERE id = 1", + "VALUES TIMESTAMP '1970-01-01 16:00:00.000000'"); + } + finally { + jdbcExecutor.execute("DROP TABLE IF EXISTS tpch.test_datetime_legacy_storage"); + } } private void testUnsupportedDataType(String databaseDataType) { - JdbcSqlExecutor jdbcSqlExecutor = new JdbcSqlExecutor(mysqlServer.getJdbcUrl()); + String jdbcUrl = mysqlContainer.getJdbcUrl(); + String jdbcUrlWithCredentials = format("%s%suser=%s&password=%s", + jdbcUrl, + jdbcUrl.contains("?") ? "&" : "?", + mysqlContainer.getUsername(), + mysqlContainer.getPassword()); + JdbcSqlExecutor jdbcSqlExecutor = new JdbcSqlExecutor(jdbcUrlWithCredentials); jdbcSqlExecutor.execute(format("CREATE TABLE tpch.test_unsupported_data_type(supported_column varchar(5), unsupported_column %s)", databaseDataType)); try { assertQuery( @@ -291,7 +437,13 @@ private DataSetup prestoCreateAsSelect(String tableNamePrefix) private DataSetup mysqlCreateAndInsert(String tableNamePrefix) { - JdbcSqlExecutor mysqlUnicodeExecutor = new JdbcSqlExecutor(mysqlServer.getJdbcUrl() + "&useUnicode=true&characterEncoding=utf8"); + String jdbcUrl = mysqlContainer.getJdbcUrl(); + String jdbcUrlWithCredentials = format("%s%suser=%s&password=%s&useUnicode=true&characterEncoding=utf8", + jdbcUrl, + jdbcUrl.contains("?") ? 
"&" : "?", + mysqlContainer.getUsername(), + mysqlContainer.getPassword()); + JdbcSqlExecutor mysqlUnicodeExecutor = new JdbcSqlExecutor(jdbcUrlWithCredentials); return new CreateAndInsertDataSetup(mysqlUnicodeExecutor, tableNamePrefix); } } diff --git a/presto-native-execution/.clang-tidy b/presto-native-execution/.clang-tidy new file mode 100644 index 0000000000000..5b36ac93d48d2 --- /dev/null +++ b/presto-native-execution/.clang-tidy @@ -0,0 +1,96 @@ +Checks: > + *, + -abseil-*, + -android-*, + -cert-err58-cpp, + -cert-err58-cpp, + -clang-analyzer-osx-*, + -cppcoreguidelines-avoid-c-arrays, + -cppcoreguidelines-avoid-goto, + -cppcoreguidelines-avoid-magic-numbers, + -cppcoreguidelines-avoid-non-const-global-variables, + -cppcoreguidelines-owning-memory, + -cppcoreguidelines-pro-bounds-array-to-pointer-decay, + -cppcoreguidelines-pro-bounds-pointer-arithmetic, + -cppcoreguidelines-pro-type-reinterpret-cast, + -cppcoreguidelines-pro-type-vararg, + -cppcoreguidelines-pro-type-vararg, + -cppcoreguidelines-special-member-functions, + -fuchsia-*, + -google-*, + -hicpp-avoid-c-arrays, + -hicpp-avoid-goto, + -hicpp-deprecated-headers, + -hicpp-no-array-decay, + -hicpp-special-member-functions, + -hicpp-use-equals-default, + -hicpp-vararg, + -hicpp-vararg, + -llvm-header-guard, + -llvm-include-order, + -llvmlibc-*, + -misc-no-recursion, + -misc-no-recursion, + -misc-non-private-member-variables-in-classes, + -misc-unused-parameters, + -modernize-avoid-c-arrays, + -modernize-deprecated-headers, + -modernize-use-nodiscard, + -modernize-use-trailing-return-type, + -mpi-*, + -objc-*, + -openmp-*, + -readability-avoid-const-params-in-decls, + -readability-convert-member-functions-to-static, + -readability-implicit-bool-conversion, + -readability-magic-numbers, + -zircon-*, + +HeaderFilterRegex: '.*' + +WarningsAsErrors: '' + +CheckOptions: + # Naming conventions as explicitly stated in CODING_STYLE.md + - key: readability-identifier-naming.ClassCase + value: CamelCase + - 
key: readability-identifier-naming.StructCase + value: CamelCase + - key: readability-identifier-naming.EnumCase + value: CamelCase + - key: readability-identifier-naming.TypeAliasCase + value: CamelCase + - key: readability-identifier-naming.TypeTemplateParameterCase + value: CamelCase + - key: readability-identifier-naming.FunctionCase + value: camelBack + - key: readability-identifier-naming.VariableCase + value: camelBack + - key: readability-identifier-naming.ParameterCase + value: camelBack + - key: readability-identifier-naming.PrivateMemberCase + value: camelBack + - key: readability-identifier-naming.PrivateMemberSuffix + value: _ + - key: readability-identifier-naming.ProtectedMemberCase + value: camelBack + - key: readability-identifier-naming.ProtectedMemberSuffix + value: _ + - key: readability-identifier-naming.MacroDefinitionCase + value: UPPER_CASE + - key: readability-identifier-naming.NamespaceCase + value: lower_case + - key: readability-identifier-naming.StaticConstantPrefix + value: k + - key: readability-identifier-naming.EnumConstantCase + value: CamelCase + - key: readability-identifier-naming.EnumConstantPrefix + value: k + + # Use nullptr instead of NULL or 0 + - key: modernize-use-nullptr.NullMacros + value: 'NULL' + + # Prefer enum class over enum + - key: modernize-use-using.IgnoreUsingStdAllocator + value: 1 diff --git a/presto-native-execution/.devcontainer/Dockerfile b/presto-native-execution/.devcontainer/Dockerfile new file mode 100644 index 0000000000000..86e72e86a4318 --- /dev/null +++ b/presto-native-execution/.devcontainer/Dockerfile @@ -0,0 +1,19 @@ +FROM presto/prestissimo-dependency:centos9 + +# Install necessary packages to run dev containers in CLion +# https://www.jetbrains.com/help/clion/prerequisites-for-dev-containers.html#remote_container +RUN dnf -y update && dnf -y install epel-release +RUN dnf -y --skip-broken install \ + curl \ + unzip \ + procps \ + libXext \ + libXrender \ + libXtst \ + libXi \ + freetype \ + 
procps \ + java-17-openjdk-headless \ + python3.12 + +RUN pip3 install pre-commit diff --git a/presto-native-execution/.devcontainer/README.md b/presto-native-execution/.devcontainer/README.md new file mode 100644 index 0000000000000..92bc62b93d9b9 --- /dev/null +++ b/presto-native-execution/.devcontainer/README.md @@ -0,0 +1,34 @@ +# How to develop Presto C++ with dev-containers in CLion + +> **_NOTE:_** For this to work you need CLion 2025.2.2 or greater. + +If you can't build, or want to build the development environment on your machine, you can use dev-containers. With them, you can have your IDE frontend working against a CLion backend running on a docker container. To set it up, run the following command: + +```sh +docker compose build centos-native-dependency +``` +Once the image is built, open the `presto-native-execution` module on CLion. + +Right-click on `.devcontainer\devcontainer.json`, and in the contextual menu select `Dev Containers->Create Dev Container and mount sources...->CLion`. Wait for the container to be up and running before you continue. + +The source code is mounted from your machine so any change made into it from the dev-container will also be on your machine. + +## Debug or execute `presto_server` + +Reload CMake project and configure the `presto_server` executable. See [Setup Presto with IntelliJ IDEA and Prestissimo with CLion](https://github.com/prestodb/presto/tree/master/presto-native-execution#setup-presto-with-intellij-idea-and-prestissimo-with-clion). Compile the project as needed. + +Then, execute the script `./devcontainer/install-shared-libs.sh` inside the container. This will create a directory `/runtime-libraries` and copy all the shared libraries needed for your compilation runtime in there. + +Edit the `presto_server` configuration to add the environment variable `LD_LIBRARY_PATH=/runtime-libraries`. This way, you'll have the same environment as distributed prestissimo images. 
+ +## Known errors + - In some cases an error such as `Computing backend... error. Collection contains no element matching the predicate` can appear. The feature is still in beta. In this case, the container will be created and running, but there might have been an issue starting the CLion backend inside the container. + +To resolve this issue, close CLion and reopen it. + +In the `Welcome to CLion` window go to `Remote Development (beta)->Dev Containers`. You should see that the container `Presto C++ Dev Container` is up and running, so connect to it. In this case, the backend should start properly and the project should be opened. + + - If you can't use git inside the container, you need to manually add the mounted repo to the trusted directories for the dev-container + ```sh + git config --global --add safe.directory /workspace/presto + ``` diff --git a/presto-native-execution/.devcontainer/devcontainer.json b/presto-native-execution/.devcontainer/devcontainer.json new file mode 100644 index 0000000000000..d72ba896b0efd --- /dev/null +++ b/presto-native-execution/.devcontainer/devcontainer.json @@ -0,0 +1,8 @@ +{ + "name": "Presto C++ Dev Container", + "build": { + "dockerfile": "Dockerfile" + }, + "workspaceMount": "source=${localWorkspaceFolder}/..,target=/workspace/presto,type=bind", + "workspaceFolder": "/workspace/presto/presto-native-execution" +} diff --git a/presto-native-execution/.devcontainer/install-shared-libs.sh b/presto-native-execution/.devcontainer/install-shared-libs.sh new file mode 100755 index 0000000000000..9e935f1c5f559 --- /dev/null +++ b/presto-native-execution/.devcontainer/install-shared-libs.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Before executing this script, you should compile presto_server +# Copy shared libs to the directory /runtime-libraries +mkdir /runtime-libraries && + bash -c '!(LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib:/usr/local/lib64 ldd ../cmake-build-debug/presto_cpp/main/presto_server | grep "not found")' && + LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib:/usr/local/lib64 ldd ../cmake-build-debug/presto_cpp/main/presto_server | awk 'NF == 4 { system("cp " $3 " /runtime-libraries") }' diff --git a/presto-native-execution/.dockerignore b/presto-native-execution/.dockerignore new file mode 100644 index 0000000000000..b25c310f81620 --- /dev/null +++ b/presto-native-execution/.dockerignore @@ -0,0 +1,4 @@ +# Ignore build directories +_build/ +cmake-build-debug/ +cmake-build-release/ diff --git a/presto-native-execution/.gersemirc b/presto-native-execution/.gersemirc new file mode 100644 index 0000000000000..380ff2031d1c3 --- /dev/null +++ b/presto-native-execution/.gersemirc @@ -0,0 +1,6 @@ +# vim: set filetype=yaml : + +line_length: 100 +indent: 2 +definitions: +- presto_cpp/main/thrift/ThriftLibrary.cmake diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt index 5474fe970e6d8..fc5d1ddf62696 100644 --- a/presto-native-execution/CMakeLists.txt +++ b/presto-native-execution/CMakeLists.txt @@ -22,44 +22,27 @@ execute_process( bash -c "( source ${CMAKE_CURRENT_SOURCE_DIR}/velox/scripts/setup-helper-functions.sh && echo -n $(get_cxx_flags $ENV{CPU_TARGET}))" OUTPUT_VARIABLE SCRIPT_CXX_FLAGS - RESULT_VARIABLE 
COMMAND_STATUS) + RESULT_VARIABLE COMMAND_STATUS +) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED True) message("Appending CMAKE_CXX_FLAGS with ${SCRIPT_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SCRIPT_CXX_FLAGS}") -# Known warnings that are benign can be disabled: -# -# * `restrict` since it triggers a bug in gcc 12. See -# https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105651 -set(DISABLED_WARNINGS - "-Wno-nullability-completeness \ - -Wno-deprecated-declarations \ - -Wno-restrict") - -# Disable -Wstringop-overflow to avoid a false positive in the following -# compiler versions. See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=117983 -# -# NOTE: Typically, this is disabled using a VELOX_SUPPRESS_STRINGOP guard within -# the code; but that doesn't apply to 3rd-party libraries, so we need to disable -# it globally for now. -if("GNU" STREQUAL "${CMAKE_CXX_COMPILER_ID}" - AND ((CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12 - AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 12.5) - OR (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13 - AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 13.4) - OR (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 14 - AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 14.3) - )) - set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-stringop-overflow") +# Known warnings that are benign can be disabled. +set(DISABLED_WARNINGS "-Wno-nullability-completeness -Wno-deprecated-declarations") + +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "14.0.0") + string(APPEND DISABLED_WARNINGS " -Wno-error=template-id-cdtor") + endif() endif() # Important warnings that must be explicitly enabled. set(ENABLE_WARNINGS "-Wreorder") -set(CMAKE_CXX_FLAGS - "${CMAKE_CXX_FLAGS} ${DISABLED_WARNINGS} ${ENABLE_WARNINGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${DISABLED_WARNINGS} ${ENABLE_WARNINGS}") # Add all Presto options below. 
@@ -76,75 +59,71 @@ option(PRESTO_ENABLE_GCS "Build GCS support" OFF) option(PRESTO_ENABLE_ABFS "Build ABFS support" OFF) # Forwards user input to VELOX_ENABLE_PARQUET. -option(PRESTO_ENABLE_PARQUET "Enable Parquet support" OFF) +option(PRESTO_ENABLE_PARQUET "Enable Parquet support" ON) + +# Forwards user input to VELOX_ENABLE_CUDF. +option(PRESTO_ENABLE_CUDF "Enable cuDF support" OFF) # Forwards user input to VELOX_ENABLE_REMOTE_FUNCTIONS. option(PRESTO_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" OFF) +option(PRESTO_ENABLE_EXAMPLES "Enable Presto examples" OFF) + option(PRESTO_ENABLE_TESTING "Enable tests" ON) option(PRESTO_ENABLE_JWT "Enable JWT (JSON Web Token) authentication" OFF) option(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR "Enable Arrow Flight connector" OFF) -# Set all Velox options below. +option(PRESTO_ENABLE_SPATIAL "Enable spatial support" ON) -# Make sure that if we include folly headers or other dependency headers that -# include folly headers we turn off the coroutines and turn on int128. +# Set all Velox options below and make sure that if we include folly headers or +# other dependency headers that include folly headers we turn off the coroutines +# and turn on int128. 
add_compile_definitions(FOLLY_HAVE_INT128_T=1 FOLLY_CFG_NO_COROUTINES) -if(PRESTO_ENABLE_S3) - set(VELOX_ENABLE_S3 - ON - CACHE BOOL "Build S3 support") -endif() +set(VELOX_ENABLE_S3 ${PRESTO_ENABLE_S3} CACHE BOOL "Build S3 support") -if(PRESTO_ENABLE_HDFS) - set(VELOX_ENABLE_HDFS - ON - CACHE BOOL "Build HDFS support") -endif() +set(VELOX_ENABLE_HDFS ${PRESTO_ENABLE_HDFS} CACHE BOOL "Build HDFS support") -if(PRESTO_ENABLE_GCS) - set(VELOX_ENABLE_GCS - ON - CACHE BOOL "Build GCS support") -endif() +set(VELOX_ENABLE_GCS ${PRESTO_ENABLE_GCS} CACHE BOOL "Build GCS support") -if(PRESTO_ENABLE_ABFS) - set(VELOX_ENABLE_ABFS - ON - CACHE BOOL "Build ABFS support") -endif() +set(VELOX_ENABLE_ABFS ${PRESTO_ENABLE_ABFS} CACHE BOOL "Build ABFS support") -if(PRESTO_ENABLE_PARQUET) - set(VELOX_ENABLE_PARQUET - ON - CACHE BOOL "Enable Parquet support") -endif() +set(VELOX_ENABLE_PARQUET ${PRESTO_ENABLE_PARQUET} CACHE BOOL "Enable Parquet support") +set( + VELOX_ENABLE_REMOTE_FUNCTIONS + ${PRESTO_ENABLE_REMOTE_FUNCTIONS} + CACHE BOOL + "Enable remote function support in Velox" +) if(PRESTO_ENABLE_REMOTE_FUNCTIONS) - set(VELOX_ENABLE_REMOTE_FUNCTIONS - ON - CACHE BOOL "Enable remote function support in Velox") add_compile_definitions(PRESTO_ENABLE_REMOTE_FUNCTIONS) endif() -set(VELOX_BUILD_TESTING - OFF - CACHE BOOL "Enable Velox tests") +set(VELOX_ENABLE_CUDF ${PRESTO_ENABLE_CUDF} CACHE BOOL "Enable cuDF support") +if(PRESTO_ENABLE_CUDF) + add_compile_definitions(PRESTO_ENABLE_CUDF) + enable_language(CUDA) + # Determine CUDA_ARCHITECTURES automatically. 
+ cmake_policy(SET CMP0104 NEW) +endif() -set(VELOX_ENABLE_SPARK_FUNCTIONS - OFF - CACHE BOOL "Enable Velox Spark functions") +set( + VELOX_ENABLE_GEO + ${PRESTO_ENABLE_SPATIAL} + CACHE BOOL + "Enable Velox Geometry (aka spatial) support" +) -set(VELOX_ENABLE_EXAMPLES - OFF - CACHE BOOL "Enable Velox examples") +set(VELOX_BUILD_TESTING OFF CACHE BOOL "Enable Velox tests") -set(VELOX_BUILD_TEST_UTILS - ${PRESTO_ENABLE_TESTING} - CACHE BOOL "Enable Velox test utils") +set(VELOX_ENABLE_SPARK_FUNCTIONS OFF CACHE BOOL "Enable Velox Spark functions") + +set(VELOX_ENABLE_EXAMPLES OFF CACHE BOOL "Enable Velox examples") + +set(VELOX_BUILD_TEST_UTILS ${PRESTO_ENABLE_TESTING} CACHE BOOL "Enable Velox test utils") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -157,15 +136,8 @@ set(Boost_USE_MULTITHREADED TRUE) find_package( Boost 1.66.0 - REQUIRED - program_options - context - filesystem - regex - thread - system - date_time - atomic) + REQUIRED program_options context filesystem regex thread system date_time url atomic +) include_directories(SYSTEM ${Boost_INCLUDE_DIRS}) find_package(gflags COMPONENTS shared) @@ -185,16 +157,18 @@ find_package(ZLIB) find_library(SNAPPY snappy) find_package(folly CONFIG REQUIRED) -set(FOLLY_WITH_DEPENDENCIES - ${FOLLY_LIBRARIES} - ${DOUBLE_CONVERSION} - Boost::context - dl - ${EVENT} - ${SNAPPY} - ${LZ4} - ${ZSTD} - ${ZLIB_LIBRARIES}) +set( + FOLLY_WITH_DEPENDENCIES + ${FOLLY_LIBRARIES} + ${DOUBLE_CONVERSION} + Boost::context + dl + ${EVENT} + ${SNAPPY} + ${LZ4} + ${ZSTD} + ${ZLIB_LIBRARIES} +) find_package(BZip2 MODULE) if(BZIP2_FOUND) @@ -218,10 +192,16 @@ find_package(wangle CONFIG) find_package(FBThrift) include_directories(SYSTEM ${FBTHRIFT_INCLUDE_DIR}) -set(PROXYGEN_LIBRARIES ${PROXYGEN_HTTP_SERVER} ${PROXYGEN} ${WANGLE} ${FIZZ} - ${MVFST_EXCEPTION}) +set( + PROXYGEN_LIBRARIES + ${PROXYGEN_HTTP_SERVER} + ${PROXYGEN} + ${WANGLE} + ${FIZZ} + ${MVFST_EXCEPTION} +) find_path(PROXYGEN_DIR NAMES include/proxygen) -set(PROXYGEN_INCLUDE_DIR 
"${PROXYGEN_DIR}/include/proxygen") +set(PROXYGEN_INCLUDE_DIR "${PROXYGEN_DIR}/include/") include_directories(SYSTEM ${OPENSSL_INCLUDE_DIR} ${PROXYGEN_INCLUDE_DIR}) include_directories(.) @@ -233,6 +213,9 @@ include_directories(${CMAKE_BINARY_DIR}) # set this for backwards compatibility, will be overwritten in velox/ set(VELOX_GTEST_INCUDE_DIR "velox/third_party/googletest/googletest/include") +# Do not use the Mono library because it causes link errors. +set(VELOX_MONO_LIBRARY OFF CACHE BOOL "Build Velox mono library") + add_subdirectory(velox) if(PRESTO_ENABLE_TESTING) @@ -249,8 +232,7 @@ if(PRESTO_ENABLE_JWT) endif() if("${MAX_LINK_JOBS}") - set_property(GLOBAL APPEND PROPERTY JOB_POOLS - "presto_link_job_pool=${MAX_LINK_JOBS}") + set_property(GLOBAL APPEND PROPERTY JOB_POOLS "presto_link_job_pool=${MAX_LINK_JOBS}") else() set_property(GLOBAL APPEND PROPERTY JOB_POOLS "presto_link_job_pool=8") endif() diff --git a/presto-native-execution/Makefile b/presto-native-execution/Makefile index f3fb5f709f4d5..1a0a4e18ed963 100644 --- a/presto-native-execution/Makefile +++ b/presto-native-execution/Makefile @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -.PHONY: all cmake build cmake-and-build clean debug release unittest submodules velox-submodule +.PHONY: all cmake build cmake-and-build clean debug release unittest submodules velox-submodule format-fix format-check header-fix header-check BUILD_BASE_DIR=_build BUILD_DIR=release @@ -24,26 +24,38 @@ PYTHON_VENV ?= .venv EXTRA_CMAKE_FLAGS ?= "" +define deprecate_message + $(eval $@_VAR_NAME = $(1)) + $(warning ${$@_VAR_NAME} environment variable is deprecated and will be removed in the future. Use EXTRA_CMAKE_FLAGS.) 
+endef + ifeq ($(PRESTO_ENABLE_PARQUET), ON) EXTRA_CMAKE_FLAGS += -DPRESTO_ENABLE_PARQUET=ON + output := $(call deprecate_message,"PRESTO_ENABLE_PARQUET") endif ifeq ($(PRESTO_ENABLE_S3), ON) EXTRA_CMAKE_FLAGS += -DPRESTO_ENABLE_S3=ON + output := $(call deprecate_message,"PRESTO_ENABLE_S3") endif ifeq ($(PRESTO_ENABLE_HDFS), ON) EXTRA_CMAKE_FLAGS += -DPRESTO_ENABLE_HDFS=ON + output := $(call deprecate_message,"PRESTO_ENABLE_HDFS") endif ifeq ($(PRESTO_ENABLE_REMOTE_FUNCTIONS), ON) EXTRA_CMAKE_FLAGS += -DPRESTO_ENABLE_REMOTE_FUNCTIONS=ON + output := $(call deprecate_message,"PRESTO_ENABLE_REMOTE_FUNCTIONS") endif ifeq ($(PRESTO_ENABLE_JWT), ON) EXTRA_CMAKE_FLAGS += -DPRESTO_ENABLE_JWT=ON + output := $(call deprecate_message,"PRESTO_ENABLE_JWT") endif ifneq ($(PRESTO_STATS_REPORTER_TYPE),) EXTRA_CMAKE_FLAGS += -DPRESTO_STATS_REPORTER_TYPE=$(PRESTO_STATS_REPORTER_TYPE) + output := $(call deprecate_message,"PRESTO_STATS_REPORTER_TYPE") endif ifneq ($(PRESTO_MEMORY_CHECKER_TYPE),) EXTRA_CMAKE_FLAGS += -DPRESTO_MEMORY_CHECKER_TYPE=$(PRESTO_MEMORY_CHECKER_TYPE) + output := $(call deprecate_message,"PRESTO_MEMORY_CHECKER_TYPE") endif CMAKE_FLAGS := -DTREAT_WARNINGS_AS_ERRORS=${TREAT_WARNINGS_AS_ERRORS} @@ -51,6 +63,16 @@ CMAKE_FLAGS += -DENABLE_ALL_WARNINGS=${ENABLE_WALL} CMAKE_FLAGS += -DCMAKE_PREFIX_PATH=$(CMAKE_PREFIX_PATH) CMAKE_FLAGS += -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) +ifdef CUDA_ARCHITECTURES +CMAKE_FLAGS += -DCMAKE_CUDA_ARCHITECTURES="$(CUDA_ARCHITECTURES)" +endif +ifdef CUDA_COMPILER +CMAKE_FLAGS += -DCMAKE_CUDA_COMPILER="$(CUDA_COMPILER)" +endif +ifdef CUDA_FLAGS +CMAKE_FLAGS += -DCMAKE_CUDA_FLAGS="$(CUDA_FLAGS)" +endif + SHELL := /bin/bash # Use Ninja if available. If Ninja is used, pass through parallelism control flags. 
@@ -73,7 +95,7 @@ clean: #: Delete all build artifacts rm -rf $(BUILD_BASE_DIR) velox-submodule: #: Check out code for velox submodule - git submodule sync --recursive + git submodule sync --recursive && \ git submodule update --init --recursive submodules: velox-submodule @@ -85,7 +107,7 @@ build: #: Build the software based in BUILD_DIR and BUILD_TYPE variables cmake --build $(BUILD_BASE_DIR)/$(BUILD_DIR) -j $(NUM_THREADS) debug: #: Build with debugging symbols - $(MAKE) cmake BUILD_DIR=debug BUILD_TYPE=Debug + $(MAKE) cmake BUILD_DIR=debug BUILD_TYPE=Debug && \ $(MAKE) build BUILD_DIR=debug release: #: Build the release version @@ -93,11 +115,11 @@ release: #: Build the release version $(MAKE) build BUILD_DIR=release cmake-and-build: #: cmake and build without updating submodules which requires git - cmake -B "$(BUILD_BASE_DIR)/$(BUILD_DIR)" $(FORCE_COLOR) $(CMAKE_FLAGS) $(EXTRA_CMAKE_FLAGS) + cmake -B "$(BUILD_BASE_DIR)/$(BUILD_DIR)" $(FORCE_COLOR) $(CMAKE_FLAGS) $(EXTRA_CMAKE_FLAGS) && \ cmake --build $(BUILD_BASE_DIR)/$(BUILD_DIR) -j $(NUM_THREADS) unittest: debug #: Build with debugging and run unit tests - cd $(BUILD_BASE_DIR)/debug && ctest -j $(NUM_THREADS) -VV --output-on-failure --exclude-regex velox.* + cd $(BUILD_BASE_DIR)/debug && ctest -j $(NUM_THREADS) -VV --output-on-failure --exclude-regex ^velox.* presto_protocol: #: Build the presto_protocol serde library cd presto_cpp/presto_protocol; $(MAKE) presto_protocol diff --git a/presto-native-execution/README.md b/presto-native-execution/README.md index b48aee0c27c95..f40b753cebe78 100644 --- a/presto-native-execution/README.md +++ b/presto-native-execution/README.md @@ -65,33 +65,45 @@ The supported architectures are `x86_64 (avx, sse)`, and `AArch64 (apple-m1+crc, Prestissimo can be built by a variety of compilers (and versions) but not all. Compilers (and versions) not mentioned are known to not work or have not been tried. 
-#### Recommended +#### Minimum required | OS | compiler | | -- | -------- | -| CentOS 9/RHEL 9 | `gcc12` | | Ubuntu 22.04 | `gcc11` | | macOS | `clang15` | +| CentOS 9/RHEL 9 | `gcc11` | -#### Older alternatives +#### Recommended | OS | compiler | | -- | -------- | -| Ubuntu 20.04 | `gcc9` | -| macOS | `clang14` | +| CentOS 9/RHEL 9 | `gcc12` | +| Ubuntu 22.04 | `gcc11` | +| macOS | `clang15 (or later)` | ### Build Prestissimo #### Parquet and S3 Support -To enable Parquet and S3 support, set `PRESTO_ENABLE_PARQUET = "ON"`, -`PRESTO_ENABLE_S3 = "ON"` in the environment. +Parquet support is enabled by default. To disable it, add `-DPRESTO_ENABLE_PARQUET=OFF` +to the `EXTRA_CMAKE_FLAGS` environment variable. + +`export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DPRESTO_ENABLE_PARQUET=OFF"` + +To enable S3 support, add `-DPRESTO_ENABLE_S3=ON` to the `EXTRA_CMAKE_FLAGS` +environment variable. + +`export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DPRESTO_ENABLE_S3=ON"` S3 support needs the [AWS SDK C++](https://github.com/aws/aws-sdk-cpp) library. This dependency can be installed by running the target platform build script from the `presto/presto-native-execution` directory. -`./velox/scripts/setup-centos9.sh install_aws` +`./velox/scripts/setup-centos9.sh install_aws_deps` + Or +`./velox/scripts/setup-ubuntu.sh install_aws_deps` #### JWT Authentication -To enable JWT authentication support, set `PRESTO_ENABLE_JWT = "ON"` in -the environment. +To enable JWT authentication support, add `-DPRESTO_ENABLE_JWT=ON` to the +`EXTRA_CMAKE_FLAGS` environment variable. + +`export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DPRESTO_ENABLE_JWT=ON"` JWT authentication support needs the [JWT CPP](https://github.com/Thalhammer/jwt-cpp) library.
This dependency can be installed by running the script below from the @@ -108,6 +120,8 @@ follow these steps: *CMake flags:* `PRESTO_STATS_REPORTER_TYPE=PROMETHEUS` +`export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DPRESTO_STATS_REPORTER_TYPE=PROMETHEUS"` + *Runtime configuration:* `runtime-metrics-collection-enabled=true` * After installing the above dependencies, from the @@ -116,14 +130,33 @@ follow these steps: * Use `make unittest` to build and run tests. #### Arrow Flight Connector -To enable Arrow Flight connector support, add to the extra cmake flags: -`EXTRA_CMAKE_FLAGS = -DPRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR=ON` +To enable Arrow Flight connector support, add to the `EXTRA_CMAKE_FLAGS` environment variable: +`export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DPRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR=ON"` The Arrow Flight connector requires the Arrow Flight library. You can install this dependency by running the following script from the `presto/presto-native-execution` directory: `./scripts/setup-adapters.sh arrow_flight` +#### Nvidia cuDF GPU Support + +To enable support with [cuDF](https://github.com/facebookincubator/velox/tree/main/velox/experimental/cudf), +add to the `EXTRA_CMAKE_FLAGS` environment variable: +`export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DPRESTO_ENABLE_CUDF=ON"` + +In some environments, the CUDA_ARCHITECTURES and CUDA_COMPILER location must be explicitly set. +The make command will look like: + +`CUDA_ARCHITECTURES=80 CUDA_COMPILER=/usr/local/cuda/bin/nvcc EXTRA_CMAKE_FLAGS=" -DPRESTO_ENABLE_CUDF=ON" make` + +The required dependencies are bundled from the Velox setup scripts. + +#### Spatial type and function support +Spatial type and function support is enabled by default. To disable it, add to `EXTRA_CMAKE_FLAGS` environment variable: +`export EXTRA_CMAKE_FLAGS="${EXTRA_CMAKE_FLAGS} -DPRESTO_ENABLE_SPATIAL=OFF"` + +The spatial support adds new types (OGC geometry types) and functionality for spatial calculations.
+ ### Makefile Targets A reminder of the available Makefile targets can be obtained using `make help` ``` @@ -172,23 +205,33 @@ Run IcebergExternalWorkerQueryRunner, * Main class: `com.facebook.presto.nativeworker.IcebergExternalWorkerQueryRunner`. * VM options: `-ea -Xmx5G -XX:+ExitOnOutOfMemoryError -Duser.timezone=America/Bahia_Banderas -Dhive.security=legacy`. * Working directory: `$MODULE_DIR$` - * Environment variables: `PRESTO_SERVER=/Users//git/presto/presto-native-execution/cmake-build-debug/presto_cpp/main/presto_server;DATA_DIR=/Users//Desktop/data;WORKER_COUNT=0` - * When `addStorageFormatToPath = false` **(Default)**, - - `$DATA_DIR/iceberg_data/`. Here `catalog_type` could be `HIVE | HADOOP | NESSIE | REST`. - - `addStorageFormatToPath` is `false` by default because Java `HiveQueryRunner` and `IcebergQueryRunner` do not add the file format to the path. - * When `addStorageFormatToPath = true`, - - `$DATA_DIR/iceberg_data//`. Here `file_format` could be `PARQUET | ORC | AVRO` and `catalog_type` could be `HIVE | HADOOP | NESSIE | REST`. + * Environment variables: + - PRESTO_SERVER: Absolute path to the native worker binary. For example: `/Users//git/presto/presto-native-execution/cmake-build-debug/presto_cpp/main/presto_server` + - DATA_DIR: Base data directory for test data and catalog warehouses. For example: `/Users//Desktop/data` + - WORKER_COUNT: Number of native workers to launch (default: 4) + - CATALOG_TYPE: Iceberg catalog type to use. One of `HADOOP | HIVE` (default: `HIVE`) + + Example: + `PRESTO_SERVER=/Users//git/presto/presto-native-execution/cmake-build-debug/presto_cpp/main/presto_server;DATA_DIR=/Users//Desktop/data;WORKER_COUNT=1;CATALOG_TYPE=HIVE` * Use classpath of module: choose `presto-native-execution` module. +Run NativeSidecarPluginQueryRunner: +* Edit/Create `NativeSidecarPluginQueryRunner` Application Run/Debug Configuration (alter paths accordingly). 
+ * Main class: `com.facebook.presto.sidecar.NativeSidecarPluginQueryRunner`. + * VM options: `-ea -Xmx5G -XX:+ExitOnOutOfMemoryError -Duser.timezone=America/Bahia_Banderas -Dhive.security=legacy`. + * Working directory: `$MODULE_DIR$` + * Environment variables: `PRESTO_SERVER=/Users//git/presto/presto-native-execution/cmake-build-debug/presto_cpp/main/presto_server;DATA_DIR=/Users//Desktop/data;WORKER_COUNT=0` + * Use classpath of module: choose `presto-native-sidecar-plugin` module. + Run CLion: * File->Close Project if any is open. * Open `presto/presto-native-execution` directory as CMake project and wait till CLion loads/generates cmake files, symbols, etc. * Edit configuration for `presto_server` module (alter paths accordingly). * Program arguments: `--logtostderr=1 --v=1 --etc_dir=/Users//git/presto/presto-native-execution/etc` * Working directory: `/Users//git/presto/presto-native-execution` +* For sidecar, Edit configuration for `presto_server` module (alter paths accordingly). + * Program arguments: `--logtostderr=1 --v=1 --etc_dir=/Users//git/presto/presto-native-execution/etc_sidecar` + * Working directory: `/Users//git/presto/presto-native-execution` * Edit menu CLion->Preferences->Build, Execution, Deployment->CMake * CMake options: `-DVELOX_BUILD_TESTING=ON -DCMAKE_BUILD_TYPE=Debug` * Build options: `-- -j 12` @@ -198,7 +241,11 @@ Run CLion: * To enable clang format you need * Open any h or cpp file in the editor and select `Enable ClangFormat` by clicking `4 spaces` rectangle in the status bar (bottom right) which is next to `UTF-8` bar. - ![ScreenShot](cl_clangformat_switcherenable.png) + ![ScreenShot](docs/images/cl_clangformat_switcherenable.png) + +### Setup Presto C++ with dev containers using [CLion](https://www.jetbrains.com/clion/) + +See [How to develop Presto C++ with dev-containers in CLion](.devcontainer/README.md).
### Run Presto Coordinator + Worker * Note that everything below can be done without using IDEs by running command line commands (not in this readme). @@ -206,12 +253,24 @@ Run CLion: * For Hive, Run `HiveExternalWorkerQueryRunner` from IntelliJ and wait until it starts (`======== SERVER STARTED ========` is displayed in the log output). * For Iceberg, Run `IcebergExternalWorkerQueryRunner` from IntelliJ and wait until it starts (`======== SERVER STARTED ========` is displayed in the log output). * Scroll up the log output and find `Discovery URL http://127.0.0.1:50555`. The port is 'random' with every start. -* Copy that port (or the whole URL) to the `discovery.uri` field in `presto/presto-native-execution/etc/config.properties` for the worker to discover the Coordinator. +* Copy that port (or the whole URL) to the `discovery.uri` field in `presto/presto-native-execution/etc/config.properties` for the worker to announce itself to the Coordinator. +* In CLion run "presto_server" module. Connection success will be indicated by `Announcement succeeded: 202` line in the log output. +* See **Run Presto Client** to start executing queries on the running local setup. + +### Run Presto Coordinator + Sidecar +* Note that everything below can be done without using IDEs by running command line commands (not in this readme). +* Add a property `presto.default-namespace=native.default` to `presto-native-execution/etc/config.properties`. +* Run `NativeSidecarPluginQueryRunner` from IntelliJ and wait until it starts (`======== SERVER STARTED ========` is displayed in the log output). +* Scroll up the log output and find `Discovery URL http://127.0.0.1:50555`. The port is 'random' with every startup. +* Copy that port (or the whole URL) to the `discovery.uri` field in`presto/presto-native-execution/etc_sidecar/config.properties` for the sidecar to announce itself to the Coordinator. * In CLion run "presto_server" module. 
Connection success will be indicated by `Announcement succeeded: 202` line in the log output. -* Two ways to run Presto client to start executing queries on the running local setup: - 1. In command line from presto root directory run the presto client: - * `java -jar presto-cli/target/presto-cli-*-executable.jar --catalog hive --schema tpch` - 2. Run `Presto Client` Application (see above on how to create and setup the configuration) inside IntelliJ +* See **Run Presto Client** to start executing queries on the running local setup. + +### Run Presto Client +* Run the following command from the presto root directory to start the Presto client: + ``` + java -jar presto-cli/target/presto-cli-*-executable.jar --catalog hive --schema tpch + ``` * You can start from `show tables;` and `describe table;` queries and execute more queries as needed. ### Run Integration (End to End or E2E) Tests @@ -224,43 +283,36 @@ Run CLion: ### Code formatting, headers, and clang-tidy -Makefile targets exist for showing, fixing and checking formatting, license -headers and clang-tidy warnings. These targets are shortcuts for calling -`presto/presto-native-execution/scripts/check.py` . - -GitHub Actions run `make format-check`, `make header-check` and -`make tidy-check` as part of our continuous integration. Pull requests should -pass linux-build, format-check, header-check and other jobs without errors -before being accepted. +Code formatting, license headers, and other checks are handled by pre-commit. -Formatting issues found on the changed lines in the current commit can be -displayed using `make format-show`. These issues can be fixed by using `make -format-fix`. This will apply formatting changes to changed lines in the -current commit. +The [pre-commit](https://pre-commit.com/) configuration in `.pre-commit-config.yaml` +provides Git hooks that run automatically before commits and pushes to check and fix +formatting and license headers. 
-Header issues found on the changed files in the current commit can be displayed -using `make header-show`. These issues can be fixed by using `make -header-fix`. This will apply license header updates to the files in the current -commit. +GitHub Actions run pre-commit checks as part of our continuous integration. Using +pre-commit hooks locally ensures pull requests pass these checks before they have +the chance to fail. When pre-commit automatically fixes issues on commit, it is a good +idea to manually check the modified files to ensure pre-commit did not make unintended +changes. -Similar commands `make tidy-show`, `make-tidy-fix`, `make tidy-check` exist for -running clang-tidy, but these checks are currently advisory only. +To install the pre-commit hooks, first ensure your Python version is 3.9 or higher. +Then run: -An entire directory tree of files can be formatted and have license headers added -using the `tree` variant of the format.sh commands: ``` -presto/presto-native-execution/scripts/check.py format tree -presto/presto-native-execution/scripts/check.py format tree --fix - -presto/presto-native-execution/scripts/check.py header tree -presto/presto-native-execution/scripts/check.py header tree --fix +pip install pre-commit +pre-commit install --allow-missing-config ``` -All the available formatting commands can be displayed by using -`presto/presto-native-execution/scripts/check.py help`. +The option `--allow-missing-config` will allow commits and pushes to succeed locally +if the config is missing (e.g. you are working on an older branch). + +In addition to the Git hooks, pre-commit can be run manually on changed files using +`pre-commit run` or on all files using `pre-commit run -a`. To run a +specific hook, use `pre-commit run [hook-id]` and refer to a specific hook `id` in +`.pre-commit-config.yaml`. -There is currently no mechanism to *opt out* files or directories from the -checks. When we need one it can be added. 
+The `clang-tidy` hook is not run locally or in CI by default, but +can be run manually for optional checks using `pre-commit run --hook-stage manual clang-tidy`. ## Create Pull Request * Submit PRs as usual following [Presto repository guidelines](https://github.com/prestodb/presto/wiki/Review-and-Commit-guidelines). diff --git a/presto-native-execution/docker-compose.yml b/presto-native-execution/docker-compose.yml index 534e2a8d79a95..f1f39ddc8e16d 100644 --- a/presto-native-execution/docker-compose.yml +++ b/presto-native-execution/docker-compose.yml @@ -26,15 +26,13 @@ services: image: presto/prestissimo-runtime:ubuntu-22.04 build: args: - # A few files in Velox require significant memory to compile and link. - # Build requires 18GB of memory for 2 threads. + # A few files in Velox require significant memory to compile and link. + # Build requires 18GB of memory for 2 threads. - NUM_THREADS=2 # default value for NUM_THREADS. - DEPENDENCY_IMAGE=presto/prestissimo-dependency:ubuntu-22.04 - BASE_IMAGE=ubuntu:22.04 - OSNAME=ubuntu - - EXTRA_CMAKE_FLAGS=-DPRESTO_ENABLE_TESTING=OFF - -DPRESTO_ENABLE_PARQUET=ON - -DPRESTO_ENABLE_S3=ON + - EXTRA_CMAKE_FLAGS=-DPRESTO_ENABLE_TESTING=OFF -DPRESTO_ENABLE_PARQUET=ON -DPRESTO_ENABLE_S3=ON context: . dockerfile: scripts/dockerfiles/prestissimo-runtime.dockerfile @@ -54,11 +52,9 @@ services: image: presto/prestissimo-runtime:centos9 build: args: - # A few files in Velox require significant memory to compile and link. - # Build requires 18GB of memory for 2 threads. + # A few files in Velox require significant memory to compile and link. + # Build requires 18GB of memory for 2 threads. - NUM_THREADS=2 # default value for NUM_THREADS - - EXTRA_CMAKE_FLAGS=-DPRESTO_ENABLE_TESTING=OFF - -DPRESTO_ENABLE_PARQUET=ON - -DPRESTO_ENABLE_S3=ON + - EXTRA_CMAKE_FLAGS=-DPRESTO_ENABLE_TESTING=OFF -DPRESTO_ENABLE_PARQUET=ON -DPRESTO_ENABLE_S3=ON -DPRESTO_ENABLE_CUDF=${GPU:-OFF} context: . 
dockerfile: scripts/dockerfiles/prestissimo-runtime.dockerfile diff --git a/presto-native-execution/docs/images/cl_clangformat_switcherenable.png b/presto-native-execution/docs/images/cl_clangformat_switcherenable.png new file mode 100644 index 0000000000000..ef3e7e664e6a4 Binary files /dev/null and b/presto-native-execution/docs/images/cl_clangformat_switcherenable.png differ diff --git a/presto-native-execution/etc/catalog/tpchstandard.properties b/presto-native-execution/etc/catalog/tpchstandard.properties index 16e833ca8f436..75110c5acf145 100644 --- a/presto-native-execution/etc/catalog/tpchstandard.properties +++ b/presto-native-execution/etc/catalog/tpchstandard.properties @@ -1 +1 @@ -connector.name=tpch \ No newline at end of file +connector.name=tpch diff --git a/presto-native-execution/etc/config.properties b/presto-native-execution/etc/config.properties index 37db096a9402a..bec551886110b 100644 --- a/presto-native-execution/etc/config.properties +++ b/presto-native-execution/etc/config.properties @@ -1,4 +1,4 @@ -discovery.uri=http://127.0.0.1:58215 +discovery.uri=http://127.0.0.1: presto.version=testversion http-server.http.port=7777 shutdown-onset-sec=1 diff --git a/presto-native-execution/etc/velox.properties b/presto-native-execution/etc/velox.properties index 6c2506bd99a8e..e69de29bb2d1d 100644 --- a/presto-native-execution/etc/velox.properties +++ b/presto-native-execution/etc/velox.properties @@ -1 +0,0 @@ -mutable-config=true \ No newline at end of file diff --git a/presto-native-execution/etc_sidecar/catalog/hive.properties b/presto-native-execution/etc_sidecar/catalog/hive.properties new file mode 100644 index 0000000000000..466b7e664e44f --- /dev/null +++ b/presto-native-execution/etc_sidecar/catalog/hive.properties @@ -0,0 +1 @@ +connector.name=hive diff --git a/presto-native-execution/etc_sidecar/catalog/iceberg.properties b/presto-native-execution/etc_sidecar/catalog/iceberg.properties new file mode 100644 index 
0000000000000..f3a43dcb28126 --- /dev/null +++ b/presto-native-execution/etc_sidecar/catalog/iceberg.properties @@ -0,0 +1 @@ +connector.name=iceberg diff --git a/presto-native-execution/etc_sidecar/catalog/tpchstandard.properties b/presto-native-execution/etc_sidecar/catalog/tpchstandard.properties new file mode 100644 index 0000000000000..75110c5acf145 --- /dev/null +++ b/presto-native-execution/etc_sidecar/catalog/tpchstandard.properties @@ -0,0 +1 @@ +connector.name=tpch diff --git a/presto-native-execution/etc_sidecar/config.properties b/presto-native-execution/etc_sidecar/config.properties new file mode 100644 index 0000000000000..76ea4efa7c059 --- /dev/null +++ b/presto-native-execution/etc_sidecar/config.properties @@ -0,0 +1,7 @@ +discovery.uri=http://127.0.0.1: +presto.version=testversion +http-server.http.port=7778 +shutdown-onset-sec=1 +runtime-metrics-collection-enabled=true +native-sidecar=true +presto.default-namespace=native.default diff --git a/presto-native-execution/etc_sidecar/node.properties b/presto-native-execution/etc_sidecar/node.properties new file mode 100644 index 0000000000000..1d92b7ace8087 --- /dev/null +++ b/presto-native-execution/etc_sidecar/node.properties @@ -0,0 +1,3 @@ +node.environment=testing +node.internal-address=127.0.0.1 +node.location=testing-location diff --git a/presto-native-execution/pom.xml b/presto-native-execution/pom.xml index 200ffa6834afc..84217668d47a6 100644 --- a/presto-native-execution/pom.xml +++ b/presto-native-execution/pom.xml @@ -5,7 +5,7 @@ com.facebook.presto presto-root - 0.293 + 0.297-edge10.1-SNAPSHOT presto-native-execution @@ -15,42 +15,104 @@ ${project.parent.basedir} src/checkstyle/presto-checks.xml + 17 + true + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + + + + javax.ws.rs + javax.ws.rs-api + test + + + javax.servlet + javax.servlet-api + 3.1.0 + test + + com.facebook.presto presto-tpch + test + + + + org.jetbrains + annotations + test + + + + com.facebook.airlift + json + 
test + + + + org.weakref + jmxutils + test io.airlift.tpch tpch + test com.google.guava guava + test + + + + com.google.inject + guice + test org.testng testng + test org.assertj assertj-core + test com.facebook.presto presto-main-base + test + + + + com.facebook.presto + presto-main-tests + test com.facebook.presto presto-main + test @@ -63,16 +125,31 @@ com.facebook.presto presto-common + test com.facebook.presto presto-tests + test com.facebook.presto presto-function-namespace-managers + test + + + + com.facebook.presto + presto-function-namespace-managers-common + test + + + + com.facebook.presto + presto-built-in-worker-function-tools + test @@ -85,12 +162,29 @@ com.facebook.presto presto-spi + test + + + + com.facebook.presto + presto-parser + test + + + com.google.guava + guava + + + com.facebook.presto + presto-spi + + com.facebook.presto presto-iceberg - runtime + test org.xerial.snappy @@ -100,18 +194,6 @@ org.apache.hudi hudi-presto-bundle - - org.apache.parquet - parquet-column - - - org.apache.parquet - parquet-common - - - org.apache.parquet - parquet-format-structures - org.apache.commons commons-lang3 @@ -123,35 +205,64 @@ com.facebook.presto presto-iceberg test-jar + test com.esotericsoftware kryo-shaded + + + + + com.facebook.presto + presto-delta + test-jar + test + + + org.scala-lang + scala-library + - org.apache.parquet - parquet-column + org.apache.commons + commons-lang3 + + + + + com.facebook.presto + presto-delta + test + - org.apache.parquet - parquet-common + org.scala-lang + scala-library - org.apache.parquet - parquet-format-structures + org.apache.commons + commons-lang3 + + org.xerial.snappy + snappy-java + test + + com.facebook.presto presto-hive-metastore + test com.facebook.presto presto-hive - runtime + test org.xerial.snappy @@ -171,12 +282,14 @@ com.facebook.presto.hive hive-apache + test com.facebook.presto presto-hive test-jar + test org.xerial.snappy @@ -203,34 +316,53 @@ com.facebook.presto presto-spark-base test + + + 
com.facebook.drift + * + + com.facebook.presto.spark spark-core test + + + com.facebook.drift + * + + - com.facebook.presto presto-spark-base test-jar test + + + com.facebook.drift + * + + - com.facebook.airlift log + test com.facebook.airlift log-manager + test com.facebook.airlift testing + test org.testcontainers @@ -248,22 +380,31 @@ presto-jdbc test + org.apache.commons commons-lang3 test + + com.facebook.presto + presto-sql-invoked-functions-plugin + ${project.version} + test + + com.facebook.presto presto-clp - runtime + test com.facebook.presto presto-clp test-jar + test @@ -306,7 +447,7 @@ 1 false - remote-function,textfile_reader + remote-function /root/project/build/debug/presto_cpp/main/presto_server /tmp/velox @@ -358,7 +499,7 @@ org.apache.maven.plugins maven-surefire-plugin - writer,parquet,remote-function,textfile_reader,no_textfile_reader,async_data_cache + writer,parquet,remote-function,textfile @@ -396,6 +537,12 @@ presto-cli-*-executable.jar + + ${project.parent.basedir}/presto-function-server/target + + presto-function-server-executable.jar + + ${project.parent.basedir}/presto-server/target @@ -467,7 +614,7 @@ Release presto-native-dependency:latest - -DPRESTO_ENABLE_TESTING=OFF + "-DPRESTO_ENABLE_TESTING=OFF -DPRESTO_ENABLE_REMOTE_FUNCTIONS=ON" 2 ubuntu:22.04 ubuntu diff --git a/presto-native-execution/presto_cpp/docs/develop/orc-dump-output.rst b/presto-native-execution/presto_cpp/docs/develop/orc-dump-output.rst index 8f09b7a3ecd11..a79ab6d7b171a 100644 --- a/presto-native-execution/presto_cpp/docs/develop/orc-dump-output.rst +++ b/presto-native-execution/presto_cpp/docs/develop/orc-dump-output.rst @@ -84,4 +84,4 @@ Sample output of an orcfiledump tool orc.writer.version -> 1 presto_query_id -> 20210814_094649_15363_c5483 orc.writer.name -> presto - presto_version -> 0.259.1 \ No newline at end of file + presto_version -> 0.259.1 diff --git a/presto-native-execution/presto_cpp/docs/index.rst 
b/presto-native-execution/presto_cpp/docs/index.rst index 588b88d5eb374..6b60868f0600d 100644 --- a/presto-native-execution/presto_cpp/docs/index.rst +++ b/presto-native-execution/presto_cpp/docs/index.rst @@ -6,4 +6,3 @@ Presto Native Execution Documentation :maxdepth: 2 develop - diff --git a/presto-native-execution/presto_cpp/external/json/nlohmann/json.hpp b/presto-native-execution/presto_cpp/external/json/nlohmann/json.hpp index 77bd1739c8789..ee4763544e9a1 100644 --- a/presto-native-execution/presto_cpp/external/json/nlohmann/json.hpp +++ b/presto-native-execution/presto_cpp/external/json/nlohmann/json.hpp @@ -2235,8 +2235,6 @@ struct static_const static constexpr T value{}; }; -template -constexpr T static_const::value; } // namespace detail } // namespace nlohmann diff --git a/presto-native-execution/presto_cpp/main/Announcer.cpp b/presto-native-execution/presto_cpp/main/Announcer.cpp index 0c9433ac335e0..6051fc290b356 100644 --- a/presto-native-execution/presto_cpp/main/Announcer.cpp +++ b/presto-native-execution/presto_cpp/main/Announcer.cpp @@ -16,8 +16,6 @@ #include #include #include -#include -#include #include "presto_cpp/external/json/nlohmann/json.hpp" namespace facebook::presto { diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index 8a038eab6b378..9138bb9be597b 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -15,6 +15,12 @@ add_subdirectory(http) add_subdirectory(common) add_subdirectory(thrift) add_subdirectory(connectors) +add_subdirectory(functions) +add_subdirectory(tool) + +add_library(presto_session_properties SessionProperties.cpp) + +target_link_libraries(presto_session_properties ${FOLLY_WITH_DEPENDENCIES}) add_library( presto_server_lib @@ -27,19 +33,27 @@ add_library( PrestoServer.cpp PrestoServerOperations.cpp PrestoTask.cpp + PrestoToVeloxQueryConfig.cpp QueryContextManager.cpp 
ServerOperation.cpp SignalHandler.cpp - SessionProperties.cpp TaskManager.cpp TaskResource.cpp PeriodicHeartbeatManager.cpp - PeriodicServiceInventoryManager.cpp) + PeriodicServiceInventoryManager.cpp +) -add_dependencies(presto_server_lib presto_operators presto_protocol - presto_types presto_thrift-cpp2 presto_thrift_extra) +add_dependencies( + presto_server_lib + presto_operators + presto_protocol + presto_types + presto_thrift-cpp2 + presto_thrift_extra +) target_include_directories(presto_server_lib PRIVATE ${presto_thrift_INCLUDES}) + target_link_libraries( presto_server_lib $ @@ -47,15 +61,17 @@ target_link_libraries( $ presto_common presto_exception + presto_expression_optimizer presto_function_metadata presto_connectors presto_http presto_operators + presto_session_properties presto_velox_plan_conversion + presto_hive_functions velox_abfs velox_aggregates velox_caching - velox_clp_connector velox_common_base velox_core velox_dwio_common_exception @@ -64,6 +80,9 @@ target_link_libraries( velox_dwio_orc_reader velox_dwio_parquet_reader velox_dwio_parquet_writer + velox_dwio_text_reader_register + velox_dwio_text_writer_register + velox_dynamic_library_loader velox_encode velox_exec velox_file @@ -75,10 +94,10 @@ target_link_libraries( velox_hive_iceberg_splitreader velox_hive_partition_function velox_presto_serializer + velox_presto_type_parser velox_s3fs velox_serialization velox_time - velox_type_parser velox_type velox_type_fbhive velox_type_tz @@ -88,39 +107,55 @@ target_link_libraries( ${FOLLY_WITH_DEPENDENCIES} ${GLOG} ${GFLAGS_LIBRARIES} - pthread) + pthread +) + +if(PRESTO_ENABLE_CUDF) + target_link_libraries(presto_server_lib velox_cudf_exec) +endif() # Enabling Parquet causes build errors with missing symbols on MacOS. This is # likely due to a conflict between Arrow Thrift from velox_hive_connector and # FBThrift libraries. The build issue is fixed by linking velox_hive_connector # dependencies first followed by FBThrift. 
-target_link_libraries(presto_server_lib presto_thrift-cpp2 presto_thrift_extra - ${THRIFT_LIBRARY}) +target_link_libraries(presto_server_lib presto_thrift-cpp2 presto_thrift_extra ${THRIFT_LIBRARY}) if(PRESTO_ENABLE_REMOTE_FUNCTIONS) - add_library(presto_server_remote_function JsonSignatureParser.cpp - RemoteFunctionRegisterer.cpp) + add_library(presto_server_remote_function JsonSignatureParser.cpp RemoteFunctionRegisterer.cpp) - target_link_libraries(presto_server_remote_function velox_expression - velox_functions_remote ${FOLLY_WITH_DEPENDENCIES}) + target_link_libraries( + presto_server_remote_function + velox_expression + velox_functions_remote + ${FOLLY_WITH_DEPENDENCIES} + ) target_link_libraries(presto_server_lib presto_server_remote_function) endif() -set_property(TARGET presto_server_lib PROPERTY JOB_POOL_LINK - presto_link_job_pool) +set_property(TARGET presto_server_lib PROPERTY JOB_POOL_LINK presto_link_job_pool) add_executable(presto_server PrestoMain.cpp) +# The below additional flags are necessary for resolving dependencies for +# loading dynamic libraries. +if(APPLE) + target_link_options(presto_server BEFORE PUBLIC "-Wl,-undefined,dynamic_lookup") +else() + target_link_options(presto_server BEFORE PUBLIC "-Wl,-export-dynamic") +endif() + # velox_tpch_connector is an OBJECT target in Velox and so needs to be linked to # the executable or use TARGET_OBJECT linkage for the presto_server_lib target. # However, we also would need to add its dependencies (tpch_gen etc). TODO # change the target in Velox to a library target then we can move this to the # presto_server_lib. -target_link_libraries(presto_server presto_server_lib velox_tpch_connector) +target_link_libraries(presto_server presto_server_lib velox_tpch_connector velox_tpcds_connector) # Clang requires explicit linking with libatomic. 
-if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" - AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_GREATER_EQUAL 15) +if( + "${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" + AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_GREATER_EQUAL 15 +) target_link_libraries(presto_server atomic) endif() @@ -136,9 +171,7 @@ if(PRESTO_STATS_REPORTER_TYPE) add_subdirectory(runtime-metrics) target_link_libraries(presto_server presto_prometheus_reporter) else() - message( - FATAL_ERROR - "${PRESTO_STATS_REPORTER_TYPE} is not a valid stats reporter name") + message(FATAL_ERROR "${PRESTO_STATS_REPORTER_TYPE} is not a valid stats reporter name") endif() endif() @@ -146,14 +179,10 @@ if(PRESTO_MEMORY_CHECKER_TYPE) add_compile_definitions(PRESTO_MEMORY_CHECKER_TYPE) # Check if the current platform is Linux and the memory checker type is # LINUX_MEMORY_CHECKER. - if(UNIX - AND NOT APPLE - AND (PRESTO_MEMORY_CHECKER_TYPE STREQUAL "LINUX_MEMORY_CHECKER")) + if(UNIX AND NOT APPLE AND (PRESTO_MEMORY_CHECKER_TYPE STREQUAL "LINUX_MEMORY_CHECKER")) add_library(presto_linux_memory_checker OBJECT LinuxMemoryChecker.cpp) target_link_libraries(presto_server presto_linux_memory_checker) else() - message( - FATAL_ERROR - "${PRESTO_MEMORY_CHECKER_TYPE} is not a valid memory checker name") + message(FATAL_ERROR "${PRESTO_MEMORY_CHECKER_TYPE} is not a valid memory checker name") endif() endif() diff --git a/presto-native-execution/presto_cpp/main/LinuxMemoryChecker.cpp b/presto-native-execution/presto_cpp/main/LinuxMemoryChecker.cpp index 3d9aa0ac77c71..a344d24aadd8f 100644 --- a/presto-native-execution/presto_cpp/main/LinuxMemoryChecker.cpp +++ b/presto-native-execution/presto_cpp/main/LinuxMemoryChecker.cpp @@ -63,7 +63,7 @@ class LinuxMemoryChecker : public PeriodicMemoryChecker { ~LinuxMemoryChecker() override {} int64_t getUsedMemory() { - return systemUsedMemoryBytes(); + return systemUsedMemoryBytes(/*fetchFresh=*/true); } void setStatFile(std::string statFile) { @@ -180,7 +180,7 @@ class LinuxMemoryChecker 
: public PeriodicMemoryChecker { // value. It may be better than what we currently use. For // consistency we will match cgroup V1 and change if // necessary. - int64_t systemUsedMemoryBytes() override { + void loadSystemMemoryUsage() override { size_t memAvailable = 0; size_t memTotal = 0; size_t inactiveAnon = 0; @@ -207,7 +207,7 @@ class LinuxMemoryChecker : public PeriodicMemoryChecker { // Unit is in bytes. const auto memBytes = inactiveAnon + activeAnon; cachedSystemUsedMemoryBytes_ = memBytes; - return memBytes; + return; } // Last resort use host machine info. @@ -231,7 +231,6 @@ class LinuxMemoryChecker : public PeriodicMemoryChecker { const auto memBytes = (memAvailable && memTotal) ? memTotal - memAvailable : 0; cachedSystemUsedMemoryBytes_ = memBytes; - return memBytes; } int64_t mallocBytes() const override { diff --git a/presto-native-execution/presto_cpp/main/PeriodicHeartbeatManager.cpp b/presto-native-execution/presto_cpp/main/PeriodicHeartbeatManager.cpp index efad3c23b51c2..b6406222c9407 100644 --- a/presto-native-execution/presto_cpp/main/PeriodicHeartbeatManager.cpp +++ b/presto-native-execution/presto_cpp/main/PeriodicHeartbeatManager.cpp @@ -12,7 +12,6 @@ * limitations under the License. 
*/ #include "presto_cpp/main/PeriodicHeartbeatManager.h" -#include namespace facebook::presto { PeriodicHeartbeatManager::PeriodicHeartbeatManager( diff --git a/presto-native-execution/presto_cpp/main/PeriodicMemoryChecker.cpp b/presto-native-execution/presto_cpp/main/PeriodicMemoryChecker.cpp index 76302be846c31..a4eadbb4d1d18 100644 --- a/presto-native-execution/presto_cpp/main/PeriodicMemoryChecker.cpp +++ b/presto-native-execution/presto_cpp/main/PeriodicMemoryChecker.cpp @@ -22,8 +22,8 @@ #include "velox/common/time/Timer.h" namespace facebook::presto { -PeriodicMemoryChecker::PeriodicMemoryChecker(Config config) - : config_(std::move(config)) { +PeriodicMemoryChecker::PeriodicMemoryChecker(const Config& config) + : config_(config) { if (config_.systemMemPushbackEnabled) { VELOX_CHECK_GT(config_.systemMemLimitBytes, 0); } @@ -71,6 +71,7 @@ void PeriodicMemoryChecker::start() { scheduler_->setThreadName("MemoryCheckerThread"); scheduler_->addFunction( [&]() { + loadSystemMemoryUsage(); periodicCb(); if (config_.mallocMemHeapDumpEnabled) { maybeDumpHeap(); @@ -92,6 +93,13 @@ void PeriodicMemoryChecker::stop() { scheduler_.reset(); } +int64_t PeriodicMemoryChecker::systemUsedMemoryBytes(bool fetchFresh) { + if (fetchFresh) { + loadSystemMemoryUsage(); + } + return cachedSystemUsedMemoryBytes_; +} + std::string PeriodicMemoryChecker::createHeapDumpFilePath() const { const size_t now = velox::getCurrentTimeMs() / 1000; // Format as follow: @@ -210,7 +218,9 @@ void PeriodicMemoryChecker::pushbackMemory() { RECORD_HISTOGRAM_METRIC_VALUE( kCounterMemoryPushbackLatencyMs, latencyUs / 1000); const auto actualFreedBytes = std::max( - 0, static_cast(currentMemBytes) - systemUsedMemoryBytes()); + 0, + static_cast(currentMemBytes) - + systemUsedMemoryBytes(/*fetchFresh=*/true)); RECORD_HISTOGRAM_METRIC_VALUE( kCounterMemoryPushbackExpectedReductionBytes, freedBytes); RECORD_HISTOGRAM_METRIC_VALUE( diff --git a/presto-native-execution/presto_cpp/main/PeriodicMemoryChecker.h 
b/presto-native-execution/presto_cpp/main/PeriodicMemoryChecker.h index da4b3e7582a7a..2f7d61286ac0a 100644 --- a/presto-native-execution/presto_cpp/main/PeriodicMemoryChecker.h +++ b/presto-native-execution/presto_cpp/main/PeriodicMemoryChecker.h @@ -65,7 +65,7 @@ class PeriodicMemoryChecker { size_t mallocBytesUsageDumpThreshold{20UL * 1024 * 1024 * 1024}; }; - explicit PeriodicMemoryChecker(Config config); + explicit PeriodicMemoryChecker(const Config& config); virtual ~PeriodicMemoryChecker() = default; @@ -74,17 +74,17 @@ class PeriodicMemoryChecker { virtual void start(); /// Stops the 'PeriodicMemoryChecker'. - void stop(); + virtual void stop(); - /// Returns the last known cached 'current' system memory usage in bytes. - int64_t cachedSystemUsedMemoryBytes() const { - return cachedSystemUsedMemoryBytes_; - } + /// Returns the last known cached 'current' system memory usage in bytes. If + /// 'fetchFresh' is true, retrieves and returns the current system memory + /// usage. The returned value is used to compare with + /// 'Config::systemMemLimitBytes'. + int64_t systemUsedMemoryBytes(bool fetchFresh = false); protected: - /// Fetches and returns current system memory usage in bytes. - /// The returned value is used to compare with 'Config::systemMemLimitBytes'. - virtual int64_t systemUsedMemoryBytes() = 0; + /// Fetches current system memory usage in bytes and stores it in the cache. + virtual void loadSystemMemoryUsage() = 0; /// Returns current bytes allocated by malloc. 
The returned value is used to /// compare with 'Config::mallocBytesUsageDumpThreshold' diff --git a/presto-native-execution/presto_cpp/main/PeriodicTaskManager.cpp b/presto-native-execution/presto_cpp/main/PeriodicTaskManager.cpp index 8537f173fe564..5b21a163b5e0f 100644 --- a/presto-native-execution/presto_cpp/main/PeriodicTaskManager.cpp +++ b/presto-native-execution/presto_cpp/main/PeriodicTaskManager.cpp @@ -148,27 +148,34 @@ class HiveConnectorStatsReporter { explicit HiveConnectorStatsReporter( std::shared_ptr connector) : connector_(std::move(connector)), - numElementsMetricName_(fmt::format( - kCounterHiveFileHandleCacheNumElementsFormat, - connector_->connectorId())), - pinnedSizeMetricName_(fmt::format( - kCounterHiveFileHandleCachePinnedSizeFormat, - connector_->connectorId())), - curSizeMetricName_(fmt::format( - kCounterHiveFileHandleCacheCurSizeFormat, - connector_->connectorId())), - numAccumulativeHitsMetricName_(fmt::format( - kCounterHiveFileHandleCacheNumAccumulativeHitsFormat, - connector_->connectorId())), - numAccumulativeLookupsMetricName_(fmt::format( - kCounterHiveFileHandleCacheNumAccumulativeLookupsFormat, - connector_->connectorId())), - numHitsMetricName_(fmt::format( - kCounterHiveFileHandleCacheNumHitsFormat, - connector_->connectorId())), - numLookupsMetricName_(fmt::format( - kCounterHiveFileHandleCacheNumLookupsFormat, - connector_->connectorId())) { + numElementsMetricName_( + fmt::format( + kCounterHiveFileHandleCacheNumElementsFormat, + connector_->connectorId())), + pinnedSizeMetricName_( + fmt::format( + kCounterHiveFileHandleCachePinnedSizeFormat, + connector_->connectorId())), + curSizeMetricName_( + fmt::format( + kCounterHiveFileHandleCacheCurSizeFormat, + connector_->connectorId())), + numAccumulativeHitsMetricName_( + fmt::format( + kCounterHiveFileHandleCacheNumAccumulativeHitsFormat, + connector_->connectorId())), + numAccumulativeLookupsMetricName_( + fmt::format( + 
kCounterHiveFileHandleCacheNumAccumulativeLookupsFormat, + connector_->connectorId())), + numHitsMetricName_( + fmt::format( + kCounterHiveFileHandleCacheNumHitsFormat, + connector_->connectorId())), + numLookupsMetricName_( + fmt::format( + kCounterHiveFileHandleCacheNumLookupsFormat, + connector_->connectorId())) { DEFINE_METRIC(numElementsMetricName_, velox::StatType::AVG); DEFINE_METRIC(pinnedSizeMetricName_, velox::StatType::AVG); DEFINE_METRIC(curSizeMetricName_, velox::StatType::AVG); @@ -427,45 +434,33 @@ void PeriodicTaskManager::addConnectorStatsTask() { } void PeriodicTaskManager::updateOperatingSystemStats() { - struct rusage usage {}; + struct rusage usage{}; memset(&usage, 0, sizeof(usage)); getrusage(RUSAGE_SELF, &usage); const int64_t userCpuTimeUs{ static_cast(usage.ru_utime.tv_sec) * 1'000'000 + static_cast(usage.ru_utime.tv_usec)}; - RECORD_METRIC_VALUE( - kCounterOsUserCpuTimeMicros, userCpuTimeUs - lastUserCpuTimeUs_); - lastUserCpuTimeUs_ = userCpuTimeUs; + RECORD_METRIC_VALUE(kCounterOsUserCpuTimeMicros, userCpuTimeUs); const int64_t systemCpuTimeUs{ static_cast(usage.ru_stime.tv_sec) * 1'000'000 + static_cast(usage.ru_stime.tv_usec)}; - RECORD_METRIC_VALUE( - kCounterOsSystemCpuTimeMicros, systemCpuTimeUs - lastSystemCpuTimeUs_); - lastSystemCpuTimeUs_ = systemCpuTimeUs; + RECORD_METRIC_VALUE(kCounterOsSystemCpuTimeMicros, systemCpuTimeUs); const int64_t softPageFaults{usage.ru_minflt}; - RECORD_METRIC_VALUE( - kCounterOsNumSoftPageFaults, softPageFaults - lastSoftPageFaults_); - lastSoftPageFaults_ = softPageFaults; + RECORD_METRIC_VALUE(kCounterOsNumSoftPageFaults, softPageFaults); const int64_t hardPageFaults{usage.ru_majflt}; - RECORD_METRIC_VALUE( - kCounterOsNumHardPageFaults, hardPageFaults - lastHardPageFaults_); - lastHardPageFaults_ = hardPageFaults; + RECORD_METRIC_VALUE(kCounterOsNumHardPageFaults, hardPageFaults); const int64_t voluntaryContextSwitches{usage.ru_nvcsw}; RECORD_METRIC_VALUE( - 
kCounterOsNumVoluntaryContextSwitches, - voluntaryContextSwitches - lastVoluntaryContextSwitches_); - lastVoluntaryContextSwitches_ = voluntaryContextSwitches; + kCounterOsNumVoluntaryContextSwitches, voluntaryContextSwitches); const int64_t forcedContextSwitches{usage.ru_nivcsw}; RECORD_METRIC_VALUE( - kCounterOsNumForcedContextSwitches, - forcedContextSwitches - lastForcedContextSwitches_); - lastForcedContextSwitches_ = forcedContextSwitches; + kCounterOsNumForcedContextSwitches, forcedContextSwitches); } void PeriodicTaskManager::addOperatingSystemStatsUpdateTask() { @@ -533,6 +528,8 @@ void PeriodicTaskManager::addWatchdogTask() { } RECORD_METRIC_VALUE(kCounterNumStuckDrivers, stuckOpCalls.size()); + const char* detachReason = nullptr; + // Detach worker from the cluster if more than a certain number of // driver threads are blocked by stuck operators (one unique operator // can only get stuck on one unique thread). @@ -540,9 +537,33 @@ void PeriodicTaskManager::addWatchdogTask() { SystemConfig::instance()->driverNumStuckOperatorsToDetachWorker(), numDriverThreads_); if (stuckOpCalls.size() >= numStuckOperatorsToDetachWorker) { - detachWorker("detected stuck operators"); + detachReason = "detected stuck operators"; } else if (!deadlockTasks.empty()) { - detachWorker("starving or deadlocked task"); + detachReason = "starving or deadlocked task"; + } + + // Detach worker from the cluster if it has been overloaded for too + // long. + const auto now = velox::getCurrentTimeSec(); + const auto lastNotOverloadedTime = + taskManager_->lastNotOverloadedTimeInSecs(); + const auto overloadedDurationSec = + taskManager_->isServerOverloaded() && (now > lastNotOverloadedTime) + ? 
now - lastNotOverloadedTime + : 0UL; + RECORD_METRIC_VALUE( + kCounterOverloadedDurationSec, overloadedDurationSec); + if (detachReason == nullptr) { + const uint64_t secondsThreshold = + SystemConfig::instance()->workerOverloadedSecondsToDetachWorker(); + if (secondsThreshold > 0 && + overloadedDurationSec > secondsThreshold) { + detachReason = "worker has been overloaded for too long"; + } + } + + if (detachReason != nullptr) { + detachWorker(detachReason); } else { maybeAttachWorker(); } diff --git a/presto-native-execution/presto_cpp/main/PeriodicTaskManager.h b/presto-native-execution/presto_cpp/main/PeriodicTaskManager.h index 74ffb2fe96a6c..35c555e68e0d2 100644 --- a/presto-native-execution/presto_cpp/main/PeriodicTaskManager.h +++ b/presto-native-execution/presto_cpp/main/PeriodicTaskManager.h @@ -13,8 +13,8 @@ */ #pragma once -#include #include +#include #include "velox/common/memory/Memory.h" namespace folly { @@ -142,14 +142,6 @@ class PeriodicTaskManager { std::shared_ptr>& connectors_; PrestoServer* server_; - // Operating system related stats. - int64_t lastUserCpuTimeUs_{0}; - int64_t lastSystemCpuTimeUs_{0}; - int64_t lastSoftPageFaults_{0}; - int64_t lastHardPageFaults_{0}; - int64_t lastVoluntaryContextSwitches_{0}; - int64_t lastForcedContextSwitches_{0}; - int64_t lastHttpClientNumConnectionsCreated_{0}; // NOTE: declare last since the threads access other members of `this`. 
diff --git a/presto-native-execution/presto_cpp/main/PrestoExchangeSource.cpp b/presto-native-execution/presto_cpp/main/PrestoExchangeSource.cpp index fc246db319c9c..130bc49690292 100644 --- a/presto-native-execution/presto_cpp/main/PrestoExchangeSource.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoExchangeSource.cpp @@ -18,8 +18,8 @@ #include #include -#include "presto_cpp/main/QueryContextManager.h" #include "presto_cpp/main/common/Counters.h" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" #include "velox/common/base/Exceptions.h" #include "velox/common/testutil/TestValue.h" @@ -144,7 +144,7 @@ folly::SemiFuture PrestoExchangeSource::request( VELOX_CHECK(requestPending_); // This call cannot be made concurrently from multiple threads, but other // calls that mutate promise_ can be called concurrently. - auto promise = VeloxPromise("PrestoExchangeSource::request"); + VeloxPromise promise{"PrestoExchangeSource::request"}; auto future = promise.getSemiFuture(); velox::common::testutil::TestValue::adjust( "facebook::presto::PrestoExchangeSource::request", this); @@ -159,10 +159,10 @@ folly::SemiFuture PrestoExchangeSource::request( } failedAttempts_ = 0; - dataRequestRetryState_ = - RetryState(std::chrono::duration_cast( - SystemConfig::instance()->exchangeMaxErrorDuration()) - .count()); + dataRequestRetryState_ = RetryState( + std::chrono::duration_cast( + SystemConfig::instance()->exchangeMaxErrorDuration()) + .count()); doRequest(dataRequestRetryState_.nextDelayMs(), maxBytes, maxWait); return future; @@ -197,6 +197,7 @@ void PrestoExchangeSource::doRequest( protocol::PRESTO_MAX_WAIT_HTTP_HEADER, protocol::Duration(maxWait.count(), protocol::TimeUnit::MICROSECONDS) .toString()) + .header(proxygen::HTTP_HEADER_HOST, fmt::format("{}:{}", host_, port_)) .send(httpClient_.get(), "", delayMs) .via(driverExecutor_) .thenTry( @@ -245,7 +246,8 @@ void PrestoExchangeSource::handleDataResponse( } else if (response->hasError()) { 
processDataError(httpRequestPath, maxBytes, maxWait, response->error()); } else { - processDataResponse(std::move(response)); + const bool isGetDataSizeRequest = (maxBytes == 0); + processDataResponse(std::move(response), isGetDataSizeRequest); } } catch (const std::exception& e) { processDataError(httpRequestPath, maxBytes, maxWait, e.what()); @@ -254,11 +256,29 @@ void PrestoExchangeSource::handleDataResponse( } void PrestoExchangeSource::processDataResponse( - std::unique_ptr response) { - RECORD_HISTOGRAM_METRIC_VALUE( - kCounterExchangeRequestDuration, dataRequestRetryState_.durationMs()); - RECORD_HISTOGRAM_METRIC_VALUE( - kCounterExchangeRequestNumTries, dataRequestRetryState_.numTries()); + std::unique_ptr response, + bool isGetDataSizeRequest) { + if (isGetDataSizeRequest) { + int64_t waitTimeMs = 0; + auto waitTimeMsString = response->headers()->getHeaders().getSingleOrEmpty( + protocol::PRESTO_BUFFER_WAIT_TIME_MS_HEADER); + if (!waitTimeMsString.empty()) { + waitTimeMs = std::stoll(waitTimeMsString); + getDataSizeNs_.addValue( + (dataRequestRetryState_.durationMs() - waitTimeMs) * 1'000'000); + RECORD_HISTOGRAM_METRIC_VALUE( + kCounterExchangeGetDataSizeDuration, + dataRequestRetryState_.durationMs() - waitTimeMs); + } + RECORD_HISTOGRAM_METRIC_VALUE( + kCounterExchangeGetDataSizeNumTries, dataRequestRetryState_.numTries()); + } else { + getDataNs_.addValue(dataRequestRetryState_.durationMs() * 1'000'000); + RECORD_HISTOGRAM_METRIC_VALUE( + kCounterExchangeRequestDuration, dataRequestRetryState_.durationMs()); + RECORD_HISTOGRAM_METRIC_VALUE( + kCounterExchangeRequestNumTries, dataRequestRetryState_.numTries()); + } if (closed_.load()) { // If PrestoExchangeSource is already closed, just free all buffers // allocated without doing any processing. 
This can happen when a super slow @@ -267,9 +287,11 @@ void PrestoExchangeSource::processDataResponse( return; } auto* headers = response->headers(); - VELOX_CHECK( - !headers->getIsChunked(), - "Chunked http transferring encoding is not supported."); + if (!SystemConfig::instance()->httpClientHttp2Enabled()) { + VELOX_CHECK( + !headers->getIsChunked(), + "Chunked http transferring encoding is not supported."); + } const uint64_t contentLength = atol(headers->getHeaders() .getSingleOrEmpty(proxygen::HTTP_HEADER_CONTENT_LENGTH) @@ -310,7 +332,7 @@ void PrestoExchangeSource::processDataResponse( contentLength, 0, "next token is not set in non-empty data response"); } - std::unique_ptr page; + std::unique_ptr page; const bool empty = response->empty(); if (!empty) { std::vector> iobufs; @@ -319,20 +341,27 @@ void PrestoExchangeSource::processDataResponse( } else { iobufs.emplace_back(response->consumeBody(pool_.get())); } - int64_t totalBytes{0}; + int64_t iobufBytes{0}; std::unique_ptr singleChain; for (auto& buf : iobufs) { - totalBytes += buf->capacity(); + iobufBytes += buf->capacity(); if (!singleChain) { singleChain = std::move(buf); } else { singleChain->prev()->appendChain(std::move(buf)); } } - PrestoExchangeSource::updateMemoryUsage(totalBytes); + PrestoExchangeSource::updateMemoryUsage(iobufBytes); + + // Record IOBuf size counter when not a get-data-size request + if (!isGetDataSizeRequest) { + iobufBytes_.addValue(iobufBytes); + RECORD_HISTOGRAM_METRIC_VALUE( + kCounterExchangeRequestPageSize, iobufBytes); + } if (enableBufferCopy_) { - page = std::make_unique( + page = std::make_unique( std::move(singleChain), [pool = pool_](folly::IOBuf& iobuf) { int64_t freedBytes{0}; // Free the backed memory from MemoryAllocator on page dtor @@ -346,15 +375,15 @@ void PrestoExchangeSource::processDataResponse( PrestoExchangeSource::updateMemoryUsage(-freedBytes); }); } else { - page = std::make_unique( - std::move(singleChain), [totalBytes](folly::IOBuf& iobuf) { - 
PrestoExchangeSource::updateMemoryUsage(-totalBytes); + page = std::make_unique( + std::move(singleChain), [iobufBytes](folly::IOBuf& iobuf) { + PrestoExchangeSource::updateMemoryUsage(-iobufBytes); }); } } const int64_t pageSize = empty ? 0 : page->size(); - VeloxPromise requestPromise; + VeloxPromise requestPromise{VeloxPromise::makeEmpty()}; std::vector queuePromises; { std::lock_guard l(queue_->mutex()); @@ -362,7 +391,7 @@ void PrestoExchangeSource::processDataResponse( VLOG(1) << "Enqueuing page for " << basePath_ << "/" << sequence_ << ": " << pageSize << " bytes"; ++numPages_; - totalBytes_ += pageSize; + pageSize_ += pageSize; queue_->enqueueLocked(std::move(page), queuePromises); } if (complete) { @@ -473,10 +502,10 @@ void PrestoExchangeSource::abortResults() { return; } - abortRetryState_ = - RetryState(std::chrono::duration_cast( - SystemConfig::instance()->exchangeMaxErrorDuration()) - .count()); + abortRetryState_ = RetryState( + std::chrono::duration_cast( + SystemConfig::instance()->exchangeMaxErrorDuration()) + .count()); VLOG(1) << "Sending abort results " << basePath_; doAbortResults(abortRetryState_.nextDelayMs()); } @@ -517,7 +546,7 @@ void PrestoExchangeSource::handleAbortResponse( } bool PrestoExchangeSource::checkSetRequestPromise() { - VeloxPromise promise; + VeloxPromise promise{VeloxPromise::makeEmpty()}; { std::lock_guard l(queue_->mutex()); promise = std::move(promise_); diff --git a/presto-native-execution/presto_cpp/main/PrestoExchangeSource.h b/presto-native-execution/presto_cpp/main/PrestoExchangeSource.h index 009a84d4439b0..394627776780e 100644 --- a/presto-native-execution/presto_cpp/main/PrestoExchangeSource.h +++ b/presto-native-execution/presto_cpp/main/PrestoExchangeSource.h @@ -137,12 +137,22 @@ class PrestoExchangeSource : public velox::exec::ExchangeSource { folly::F14FastMap metrics() const override { - return { + folly::F14FastMap result = { {"prestoExchangeSource.numPages", velox::RuntimeMetric(numPages_)}, - 
{"prestoExchangeSource.totalBytes", - velox::RuntimeMetric( - totalBytes_, velox::RuntimeCounter::Unit::kBytes)}, + {"prestoExchangeSource.pageSize", + velox::RuntimeMetric(pageSize_, velox::RuntimeCounter::Unit::kBytes)}, }; + if (getDataNs_.count > 0) { + result["prestoExchangeSource.getDataNanos"] = getDataNs_; + } + if (getDataSizeNs_.count > 0) { + result["prestoExchangeSource.getDataSizeNanos"] = getDataSizeNs_; + } + if (iobufBytes_.count > 0) { + result["prestoExchangeSource.iobufBytes"] = iobufBytes_; + } + + return result; } folly::dynamic toJson() override { @@ -154,7 +164,7 @@ class PrestoExchangeSource : public velox::exec::ExchangeSource { obj["basePath"] = basePath_; obj["host"] = host_; obj["numPages"] = numPages_; - obj["totalBytes"] = totalBytes_; + obj["pageSize"] = pageSize_; obj["closed"] = std::to_string(closed_); obj["abortResultsIssued"] = std::to_string(abortResultsIssued_); obj["atEnd"] = atEnd_; @@ -200,7 +210,9 @@ class PrestoExchangeSource : public velox::exec::ExchangeSource { // response without an end marker. Sends delete-results if received an end // marker. The sequence of operations is: add data or end marker to the // queue; complete the future, send ack or delete-results. - void processDataResponse(std::unique_ptr response); + void processDataResponse( + std::unique_ptr response, + bool isGetDataSizeRequest); // If 'retry' is true, then retry the http request failure until reaches the // retry limit, otherwise just set exchange source error without retry. As @@ -280,13 +292,17 @@ class PrestoExchangeSource : public velox::exec::ExchangeSource { int failedAttempts_; // The number of pages received from this presto exchange source. 
uint64_t numPages_{0}; - uint64_t totalBytes_{0}; + uint64_t pageSize_{0}; std::atomic_bool closed_{false}; // A boolean indicating whether abortResults() call was issued std::atomic_bool abortResultsIssued_{false}; velox::VeloxPromise promise_{ velox::VeloxPromise::makeEmpty()}; + velox::RuntimeMetric getDataNs_{velox::RuntimeCounter::Unit::kNanos}; + velox::RuntimeMetric getDataSizeNs_{velox::RuntimeCounter::Unit::kNanos}; + velox::RuntimeMetric iobufBytes_{velox::RuntimeCounter::Unit::kBytes}; + friend class test::PrestoExchangeSourceTestHelper; }; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/PrestoMain.cpp b/presto-native-execution/presto_cpp/main/PrestoMain.cpp index bca2dab8dd1bd..526d42311cc8d 100644 --- a/presto-native-execution/presto_cpp/main/PrestoMain.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoMain.cpp @@ -16,6 +16,7 @@ #include #include #include "presto_cpp/main/PrestoServer.h" +#include "presto_cpp/main/common/Exception.h" #include "presto_cpp/main/common/Utils.h" #include "velox/common/base/StatsReporter.h" @@ -37,3 +38,16 @@ folly::Singleton reporter([]() { return new facebook::velox::DummyStatsReporter(); }); #endif + +// Initialize singleton for the exception translator. +// NOTE: folly::Singleton enforces that only ONE registration per type can +// exist. If another file tries to register VeloxToPrestoExceptionTranslator +// again, the program will fail during static initialization with a duplicate +// registration error. 
Extended servers should register a DERIVED class instead: +// folly::Singleton customTranslator([]() { +// return new CustomExceptionTranslator(); // derived class +// }); +folly::Singleton + exceptionTranslator([]() { + return new facebook::presto::VeloxToPrestoExceptionTranslator(); + }); diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index fb926119c25d9..9c06fa12d4bdd 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -15,11 +15,14 @@ #include #include #include +#include #include +#include #include "presto_cpp/main/Announcer.h" #include "presto_cpp/main/CoordinatorDiscoverer.h" #include "presto_cpp/main/PeriodicMemoryChecker.h" #include "presto_cpp/main/PeriodicTaskManager.h" +#include "presto_cpp/main/SessionProperties.h" #include "presto_cpp/main/SignalHandler.h" #include "presto_cpp/main/TaskResource.h" #include "presto_cpp/main/common/ConfigReader.h" @@ -27,6 +30,8 @@ #include "presto_cpp/main/common/Utils.h" #include "presto_cpp/main/connectors/Registration.h" #include "presto_cpp/main/connectors/SystemConnector.h" +#include "presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.h" +#include "presto_cpp/main/functions/FunctionMetadata.h" #include "presto_cpp/main/http/HttpConstants.h" #include "presto_cpp/main/http/filters/AccessLogFilter.h" #include "presto_cpp/main/http/filters/HttpEndpointLatencyFilter.h" @@ -34,20 +39,23 @@ #include "presto_cpp/main/http/filters/StatsFilter.h" #include "presto_cpp/main/operators/BroadcastExchangeSource.h" #include "presto_cpp/main/operators/BroadcastWrite.h" -#include "presto_cpp/main/operators/LocalPersistentShuffle.h" +#include "presto_cpp/main/operators/LocalShuffle.h" #include "presto_cpp/main/operators/PartitionAndSerialize.h" +#include "presto_cpp/main/operators/ShuffleExchangeSource.h" #include "presto_cpp/main/operators/ShuffleRead.h" 
-#include "presto_cpp/main/operators/UnsafeRowExchangeSource.h" -#include "presto_cpp/main/types/FunctionMetadata.h" +#include "presto_cpp/main/operators/ShuffleWrite.h" +#include "presto_cpp/main/types/ExpressionOptimizer.h" #include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" #include "presto_cpp/main/types/VeloxPlanConversion.h" #include "velox/common/base/Counters.h" #include "velox/common/base/StatsReporter.h" #include "velox/common/caching/CacheTTLController.h" #include "velox/common/caching/SsdCache.h" +#include "velox/common/dynamic_registry/DynamicLibraryLoader.h" #include "velox/common/file/FileSystems.h" #include "velox/common/memory/SharedArbitrator.h" #include "velox/connectors/Connector.h" +#include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/storage_adapters/abfs/RegisterAbfsFileSystem.h" #include "velox/connectors/hive/storage_adapters/gcs/RegisterGcsFileSystem.h" #include "velox/connectors/hive/storage_adapters/hdfs/RegisterHdfsFileSystem.h" @@ -57,7 +65,10 @@ #include "velox/dwio/orc/reader/OrcReader.h" #include "velox/dwio/parquet/RegisterParquetReader.h" #include "velox/dwio/parquet/RegisterParquetWriter.h" +#include "velox/dwio/text/RegisterTextReader.h" +#include "velox/dwio/text/RegisterTextWriter.h" #include "velox/exec/OutputBufferManager.h" +#include "velox/exec/TraceUtil.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" @@ -65,6 +76,11 @@ #include "velox/serializers/PrestoSerializer.h" #include "velox/serializers/UnsafeRowSerializer.h" +#ifdef PRESTO_ENABLE_CUDF +#include "velox/experimental/cudf/CudfConfig.h" +#include "velox/experimental/cudf/exec/ToCudf.h" +#endif + #ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS #include "presto_cpp/main/RemoteFunctionRegisterer.h" #endif @@ -85,6 +101,10 @@ constexpr char const* kHttps = "https"; 
constexpr char const* kTaskUriFormat = "{}://{}:{}"; // protocol, address and port constexpr char const* kConnectorName = "connector.name"; +constexpr char const* kLinuxSharedLibExt = ".so"; +constexpr char const* kMacOSSharedLibExt = ".dylib"; +constexpr char const* kOptimized = "OPTIMIZED"; +constexpr char const* kEvaluated = "EVALUATED"; protocol::NodeState convertNodeState(presto::NodeState nodeState) { switch (nodeState) { @@ -99,11 +119,10 @@ protocol::NodeState convertNodeState(presto::NodeState nodeState) { } void enableChecksum() { - velox::exec::OutputBufferManager::getInstanceRef()->setListenerFactory( - []() { - return std::make_unique< - velox::serializer::presto::PrestoOutputStreamListener>(); - }); + velox::exec::OutputBufferManager::getInstanceRef()->setListenerFactory([]() { + return std::make_unique< + velox::serializer::presto::PrestoOutputStreamListener>(); + }); } // Log only the catalog keys that are configured to avoid leaking @@ -141,6 +160,87 @@ bool cachePeriodicPersistenceEnabled() { std::chrono::seconds::zero(); } +bool isSharedLibrary(const fs::path& path) { + std::string pathExt = path.extension().string(); + std::transform(pathExt.begin(), pathExt.end(), pathExt.begin(), ::tolower); + return pathExt == kLinuxSharedLibExt || pathExt == kMacOSSharedLibExt; +} + +void registerVeloxCudf() { +#ifdef PRESTO_ENABLE_CUDF + // Disable by default. 
+ velox::cudf_velox::CudfConfig::getInstance().enabled = false; + auto systemConfig = SystemConfig::instance(); + velox::cudf_velox::CudfConfig::getInstance().functionNamePrefix = + systemConfig->prestoDefaultNamespacePrefix(); + if (systemConfig->values().contains( + velox::cudf_velox::CudfConfig::kCudfEnabled)) { + velox::cudf_velox::CudfConfig::getInstance().initialize( + systemConfig->values()); + if (velox::cudf_velox::CudfConfig::getInstance().enabled) { + velox::cudf_velox::registerCudf(); + PRESTO_STARTUP_LOG(INFO) << "cuDF is registered."; + } + } +#endif +} + +void unregisterVeloxCudf() { +#ifdef PRESTO_ENABLE_CUDF + auto systemConfig = SystemConfig::instance(); + if (systemConfig->values().contains( + velox::cudf_velox::CudfConfig::kCudfEnabled) && + velox::cudf_velox::CudfConfig::getInstance().enabled) { + velox::cudf_velox::unregisterCudf(); + PRESTO_SHUTDOWN_LOG(INFO) << "cuDF is unregistered."; + } +#endif +} + +json::array_t getOptimizedExpressions( + const proxygen::HTTPHeaders& httpHeaders, + const std::vector>& body, + folly::Executor* executor, + velox::memory::MemoryPool* pool) { + static constexpr char const* kOptimizerLevelHeader = + "X-Presto-Expression-Optimizer-Level"; + const auto& optimizerLevelString = + httpHeaders.getSingleOrEmpty(kOptimizerLevelHeader); + VELOX_USER_CHECK( + (optimizerLevelString == kOptimized) || + (optimizerLevelString == kEvaluated), + "Optimizer level should be OPTIMIZED or EVALUATED, received {}.", + optimizerLevelString); + auto optimizerLevel = (optimizerLevelString == kOptimized) + ? 
expression::OptimizerLevel::kOptimized + : expression::OptimizerLevel::kEvaluated; + + static constexpr char const* kTimezoneHeader = "X-Presto-Time-Zone"; + const auto& timezone = httpHeaders.getSingleOrEmpty(kTimezoneHeader); + std::unordered_map config( + {{velox::core::QueryConfig::kSessionTimezone, timezone}, + {velox::core::QueryConfig::kAdjustTimestampToTimezone, "true"}}); + auto queryConfig = velox::core::QueryConfig{std::move(config)}; + auto queryCtx = + velox::core::QueryCtx::create(executor, std::move(queryConfig)); + + json input = json::parse(util::extractMessageBody(body)); + VELOX_USER_CHECK(input.is_array(), "Body of request should be a JSON array."); + const json::array_t expressionList = static_cast(input); + std::vector expressions; + for (const auto& j : expressionList) { + expressions.push_back(j); + } + const auto optimizedList = expression::optimizeExpressions( + expressions, optimizerLevel, queryCtx.get(), pool); + + json::array_t result; + for (const auto& optimized : optimizedList) { + result.push_back(optimized); + } + return result; +} + } // namespace std::string nodeState2String(NodeState nodeState) { @@ -166,7 +266,6 @@ PrestoServer::~PrestoServer() {} void PrestoServer::run() { auto systemConfig = SystemConfig::instance(); auto nodeConfig = NodeConfig::instance(); - auto baseVeloxQueryConfig = BaseVeloxQueryConfig::instance(); int httpPort{0}; std::string certPath; @@ -182,9 +281,6 @@ void PrestoServer::run() { fmt::format("{}/config.properties", configDirectoryPath_)); nodeConfig->initialize( fmt::format("{}/node.properties", configDirectoryPath_)); - // velox.properties is optional. 
- baseVeloxQueryConfig->initialize( - fmt::format("{}/velox.properties", configDirectoryPath_), true); httpPort = systemConfig->httpServerHttpPort(); if (systemConfig->httpServerHttpsEnabled()) { @@ -215,8 +311,10 @@ void PrestoServer::run() { "Https Client Certificates are not configured correctly"); } - sslContext_ = - util::createSSLContext(optionalClientCertPath.value(), ciphers); + sslContext_ = util::createSSLContext( + optionalClientCertPath.value(), + ciphers, + systemConfig->httpClientHttp2Enabled()); } if (systemConfig->internalCommunicationJwtEnabled()) { @@ -245,13 +343,17 @@ void PrestoServer::run() { exit(EXIT_FAILURE); } - registerFileSinks(); registerFileSystems(); + registerFileSinks(); registerFileReadersAndWriters(); registerMemoryArbitrators(); registerShuffleInterfaceFactories(); registerCustomOperators(); + // We need to register cuDF before the connectors so that the cuDF connector + // factories can be used. + registerVeloxCudf(); + // Register Presto connector factories and connectors registerConnectors(); @@ -289,8 +391,18 @@ void PrestoServer::run() { httpsSocketAddress.setFromLocalPort(httpsPort.value()); } + const bool http2Enabled = + SystemConfig::instance()->httpServerHttp2Enabled(); + const std::string clientCaFile = + SystemConfig::instance()->httpsClientCaFile().value_or(""); httpsConfig = std::make_unique( - httpsSocketAddress, certPath, keyPath, ciphers, reusePort); + httpsSocketAddress, + certPath, + keyPath, + ciphers, + reusePort, + http2Enabled, + clientCaFile); } httpServer_ = std::make_unique( @@ -321,6 +433,14 @@ void PrestoServer::run() { json infoStateJson = convertNodeState(server->nodeState()); http::sendOkResponse(downstream, infoStateJson); }); + httpServer_->registerGet( + "/v1/info/stats", + [server = this]( + proxygen::HTTPMessage* /*message*/, + const std::vector>& /*body*/, + proxygen::ResponseHandler* downstream) { + server->reportNodeStats(downstream); + }); httpServer_->registerPut( "/v1/info/state", [server 
= this]( @@ -355,13 +475,23 @@ void PrestoServer::run() { if (folly::Singleton::try_get()) { httpServer_->registerGet( "/v1/info/metrics", - [](proxygen::HTTPMessage* /*message*/, + [](proxygen::HTTPMessage* message, const std::vector>& /*body*/, proxygen::ResponseHandler* downstream) { - http::sendOkResponse( - downstream, - folly::Singleton::try_get() - ->fetchMetrics()); + auto acceptHeader = message->getHeaders().getSingleOrEmpty( + proxygen::HTTPHeaderCode::HTTP_HEADER_ACCEPT); + if (acceptHeader.find(http::kMimeTypeTextPlain) != + std::string::npos) { + http::sendOkTextResponse( + downstream, + folly::Singleton::try_get() + ->fetchMetrics()); + } else { + http::sendOkResponse( + downstream, + folly::Singleton::try_get() + ->fetchMetrics()); + } }); } } @@ -369,39 +499,8 @@ void PrestoServer::run() { registerRemoteFunctions(); registerVectorSerdes(); registerPrestoPlanNodeSerDe(); - - const auto numExchangeHttpClientIoThreads = std::max( - systemConfig->exchangeHttpClientNumIoThreadsHwMultiplier() * - std::thread::hardware_concurrency(), - 1); - exchangeHttpIoExecutor_ = std::make_shared( - numExchangeHttpClientIoThreads, - std::make_shared("ExchangeIO")); - - PRESTO_STARTUP_LOG(INFO) << "Exchange Http IO executor '" - << exchangeHttpIoExecutor_->getName() << "' has " - << exchangeHttpIoExecutor_->numThreads() - << " threads."; - - const auto numExchangeHttpClientCpuThreads = std::max( - systemConfig->exchangeHttpClientNumCpuThreadsHwMultiplier() * - std::thread::hardware_concurrency(), - 1); - - exchangeHttpCpuExecutor_ = std::make_shared( - numExchangeHttpClientCpuThreads, - std::make_shared("ExchangeCPU")); - - PRESTO_STARTUP_LOG(INFO) << "Exchange Http CPU executor '" - << exchangeHttpCpuExecutor_->getName() << "' has " - << exchangeHttpCpuExecutor_->numThreads() - << " threads."; - - if (systemConfig->exchangeEnableConnectionPool()) { - PRESTO_STARTUP_LOG(INFO) << "Enable exchange Http Client connection pool."; - exchangeSourceConnectionPool_ = - 
std::make_unique(); - } + registerTraceNodeFactories(); + registerDynamicFunctions(); facebook::velox::exec::ExchangeSource::registerFactory( [this]( @@ -421,7 +520,7 @@ void PrestoServer::run() { }); velox::exec::ExchangeSource::registerFactory( - operators::UnsafeRowExchangeSource::createExchangeSource); + operators::ShuffleExchangeSource::createExchangeSource); // Batch broadcast exchange source. velox::exec::ExchangeSource::registerFactory( @@ -432,8 +531,7 @@ void PrestoServer::run() { nativeWorkerPool_ = velox::memory::MemoryManager::getInstance()->addLeafPool( "PrestoNativeWorker"); - taskManager_ = std::make_unique( - driverExecutor_.get(), httpSrvCpuExecutor_.get(), spillerExecutor_.get()); + createTaskManager(); if (systemConfig->prestoNativeSidecar()) { registerSidecarEndpoints(); @@ -466,6 +564,10 @@ void PrestoServer::run() { } } + if (auto factory = getSplitListenerFactory()) { + velox::exec::registerSplitListenerFactory(factory); + } + if (systemConfig->enableVeloxExprSetLogging()) { if (auto listener = getExprSetListener()) { velox::exec::registerExprSetListener(listener); @@ -486,8 +588,8 @@ void PrestoServer::run() { }); PRESTO_STARTUP_LOG(INFO) << "Driver CPU executor '" - << driverExecutor_->getName() << "' has " - << driverExecutor_->numThreads() << " threads."; + << driverCpuExecutor_->getName() << "' has " + << driverCpuExecutor_->numThreads() << " threads."; if (httpServer_->getExecutor()) { PRESTO_STARTUP_LOG(INFO) << "HTTP Server IO executor '" << httpServer_->getExecutor()->getName() @@ -497,11 +599,18 @@ void PrestoServer::run() { PRESTO_STARTUP_LOG(INFO) << "HTTP Server CPU executor '" << httpSrvCpuExecutor_->getName() << "' has " << httpSrvCpuExecutor_->numThreads() << " threads."; + for (auto evb : httpSrvIoExecutor_->getAllEventBases()) { + evb->setMaxLatency( + std::chrono::milliseconds( + systemConfig->httpSrvIoEvbViolationThresholdMs()), + []() { RECORD_METRIC_VALUE(kCounterHttpServerIoEvbViolation, 1); }, + /*dampen=*/false); + } 
} if (spillerExecutor_ != nullptr) { PRESTO_STARTUP_LOG(INFO) - << "Spiller CPU executor '" << spillerExecutor_->getName() << "', has " - << spillerExecutor_->numThreads() << " threads."; + << "Spiller CPU executor '" << spillerCpuExecutor_->getName() + << "', has " << spillerCpuExecutor_->numThreads() << " threads."; } else { PRESTO_STARTUP_LOG(INFO) << "Spill executor was not configured."; } @@ -511,8 +620,8 @@ void PrestoServer::run() { auto* memoryAllocator = velox::memory::memoryManager()->allocator(); auto* asyncDataCache = velox::cache::AsyncDataCache::getInstance(); periodicTaskManager_ = std::make_unique( - driverExecutor_.get(), - spillerExecutor_.get(), + driverCpuExecutor_, + spillerCpuExecutor_, httpSrvIoExecutor_.get(), httpSrvCpuExecutor_.get(), exchangeHttpIoExecutor_.get(), @@ -637,12 +746,13 @@ void PrestoServer::run() { unregisterFileReadersAndWriters(); unregisterFileSystems(); unregisterConnectors(); + unregisterVeloxCudf(); PRESTO_SHUTDOWN_LOG(INFO) - << "Joining Driver CPU Executor '" << driverExecutor_->getName() - << "': threads: " << driverExecutor_->numActiveThreads() << "/" - << driverExecutor_->numThreads() - << ", task queue: " << driverExecutor_->getTaskQueueSize(); + << "Joining Driver CPU Executor '" << driverCpuExecutor_->getName() + << "': threads: " << driverCpuExecutor_->numActiveThreads() << "/" + << driverCpuExecutor_->numThreads() + << ", task queue: " << driverCpuExecutor_->getTaskQueueSize(); // Schedule release of SessionPools held by HttpClients before the exchange // HTTP IO executor threads are joined. 
driverExecutor_.reset(); @@ -738,7 +848,7 @@ void PrestoServer::yieldTasks() { return; } static std::atomic numYields = 0; - const auto numQueued = driverExecutor_->getTaskQueueSize(); + const auto numQueued = driverCpuExecutor_->getTaskQueueSize(); if (numQueued > 0) { numYields += taskManager_->yieldTasks(numQueued, timeslice); } @@ -770,7 +880,7 @@ class BatchThreadFactory : public folly::NamedThreadFactory { #endif void PrestoServer::initializeThreadPools() { - const auto hwConcurrency = std::thread::hardware_concurrency(); + const auto hwConcurrency = folly::hardware_concurrency(); auto* systemConfig = SystemConfig::instance(); const auto numDriverCpuThreads = std::max( @@ -787,8 +897,10 @@ void PrestoServer::initializeThreadPools() { threadFactory = std::make_shared("Driver"); } - driverExecutor_ = std::make_shared( + driverExecutor_ = std::make_unique( numDriverCpuThreads, threadFactory); + driverCpuExecutor_ = velox::checkedPointerCast( + driverExecutor_.get()); const auto numIoThreads = std::max( systemConfig->httpServerNumIoThreadsHwMultiplier() * hwConcurrency, 1); @@ -797,15 +909,57 @@ void PrestoServer::initializeThreadPools() { const auto numCpuThreads = std::max( systemConfig->httpServerNumCpuThreadsHwMultiplier() * hwConcurrency, 1); - httpSrvCpuExecutor_ = std::make_shared( + httpSrvCpuExecutor_ = std::make_unique( numCpuThreads, std::make_shared("HTTPSrvCpu")); const auto numSpillerCpuThreads = std::max( systemConfig->spillerNumCpuThreadsHwMultiplier() * hwConcurrency, 0); if (numSpillerCpuThreads > 0) { - spillerExecutor_ = std::make_shared( + spillerExecutor_ = std::make_unique( numSpillerCpuThreads, std::make_shared("Spiller")); + spillerCpuExecutor_ = + velox::checkedPointerCast( + spillerExecutor_.get()); + } + const auto numExchangeHttpClientIoThreads = std::max( + systemConfig->exchangeHttpClientNumIoThreadsHwMultiplier() * + folly::hardware_concurrency(), + 1); + exchangeHttpIoExecutor_ = std::make_unique( + numExchangeHttpClientIoThreads, 
+ std::make_shared("ExchangeIO")); + + PRESTO_STARTUP_LOG(INFO) << "Exchange Http IO executor '" + << exchangeHttpIoExecutor_->getName() << "' has " + << exchangeHttpIoExecutor_->numThreads() + << " threads."; + for (auto evb : exchangeHttpIoExecutor_->getAllEventBases()) { + evb->setMaxLatency( + std::chrono::milliseconds( + systemConfig->exchangeIoEvbViolationThresholdMs()), + []() { RECORD_METRIC_VALUE(kCounterExchangeIoEvbViolation, 1); }, + /*dampen=*/false); + } + + const auto numExchangeHttpClientCpuThreads = std::max( + systemConfig->exchangeHttpClientNumCpuThreadsHwMultiplier() * + folly::hardware_concurrency(), + 1); + + exchangeHttpCpuExecutor_ = std::make_unique( + numExchangeHttpClientCpuThreads, + std::make_shared("ExchangeCPU")); + + PRESTO_STARTUP_LOG(INFO) << "Exchange Http CPU executor '" + << exchangeHttpCpuExecutor_->getName() << "' has " + << exchangeHttpCpuExecutor_->numThreads() + << " threads."; + + if (systemConfig->exchangeEnableConnectionPool()) { + PRESTO_STARTUP_LOG(INFO) << "Enable exchange Http Client connection pool."; + exchangeSourceConnectionPool_ = + std::make_unique(); } } @@ -827,7 +981,8 @@ std::unique_ptr PrestoServer::setupSsdCache() { systemConfig->asyncCacheSsdCheckpointGb() << 30, systemConfig->asyncCacheSsdDisableFileCow(), systemConfig->ssdCacheChecksumEnabled(), - systemConfig->ssdCacheReadVerificationEnabled()); + systemConfig->ssdCacheReadVerificationEnabled(), + systemConfig->ssdCacheMaxEntries()); PRESTO_STARTUP_LOG(INFO) << "Initializing SSD cache with " << cacheConfig.toString(); return std::make_unique(cacheConfig); @@ -839,7 +994,7 @@ void PrestoServer::initializeVeloxMemory() { PRESTO_STARTUP_LOG(INFO) << "Starting with node memory " << memoryGb << "GB"; // Set up velox memory manager. 
- velox::memory::MemoryManagerOptions options; + velox::memory::MemoryManager::Options options; options.allocatorCapacity = memoryGb << 30; if (systemConfig->useMmapAllocator()) { options.useMmapAllocator = true; @@ -987,7 +1142,7 @@ size_t PrestoServer::numDriverThreads() const { VELOX_CHECK( driverExecutor_ != nullptr, "Driver executor is expected to be not null, but it is null!"); - return driverExecutor_->numThreads(); + return driverCpuExecutor_->numThreads(); } void PrestoServer::detachWorker() { @@ -1039,6 +1194,16 @@ void PrestoServer::addServerPeriodicTasks() { 1'000'000, // 1 second "populate_mem_cpu_info"); + periodicTaskManager_->addTask( + [start = start_]() { + const auto seconds = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + RECORD_METRIC_VALUE(kCounterWorkerRuntimeUptimeSecs, seconds); + }, + 2'000'000, // 2 seconds + "worker_runtime_uptime_secs"); + const auto timeslice = SystemConfig::instance()->taskRunTimeSliceMicros(); if (timeslice > 0) { periodicTaskManager_->addTask( @@ -1111,6 +1276,11 @@ PrestoServer::getExprSetListener() { return nullptr; } +std::shared_ptr +PrestoServer::getSplitListenerFactory() { + return nullptr; +} + std::vector> PrestoServer::getHttpServerFilters() const { std::vector> filters; @@ -1144,12 +1314,12 @@ std::vector PrestoServer::registerVeloxConnectors( const auto numConnectorCpuThreads = std::max( SystemConfig::instance()->connectorNumCpuThreadsHwMultiplier() * - std::thread::hardware_concurrency(), + folly::hardware_concurrency(), 0); if (numConnectorCpuThreads > 0) { connectorCpuExecutor_ = std::make_unique( numConnectorCpuThreads, - std::make_shared("Connector")); + std::make_shared("ConnectorCPU")); PRESTO_STARTUP_LOG(INFO) << "Connector CPU executor has " << connectorCpuExecutor_->numThreads() @@ -1158,12 +1328,12 @@ std::vector PrestoServer::registerVeloxConnectors( const auto numConnectorIoThreads = std::max( SystemConfig::instance()->connectorNumIoThreadsHwMultiplier() 
* - std::thread::hardware_concurrency(), + folly::hardware_concurrency(), 0); if (numConnectorIoThreads > 0) { connectorIoExecutor_ = std::make_unique( numConnectorIoThreads, - std::make_shared("Connector")); + std::make_shared("ConnectorIO")); PRESTO_STARTUP_LOG(INFO) << "Connector IO executor has " << connectorIoExecutor_->numThreads() @@ -1196,14 +1366,12 @@ std::vector PrestoServer::registerVeloxConnectors( // make sure connector type is supported getPrestoToVeloxConnector(connectorName); - - std::shared_ptr connector = - velox::connector::getConnectorFactory(connectorName) - ->newConnector( - catalogName, - std::move(properties), - connectorIoExecutor_.get(), - connectorCpuExecutor_.get()); + auto connector = getConnectorFactory(connectorName) + ->newConnector( + catalogName, + std::move(properties), + connectorIoExecutor_.get(), + connectorCpuExecutor_.get()); velox::connector::registerConnector(connector); } } @@ -1271,6 +1439,12 @@ void PrestoServer::registerFunctions() { prestoBuiltinFunctionPrefix_); velox::window::prestosql::registerAllWindowFunctions( prestoBuiltinFunctionPrefix_); + + if (velox::connector::hasConnector( + velox::connector::hive::HiveConnectorFactory::kHiveConnectorName) || + velox::connector::hasConnector("hive-hadoop2")) { + hive::functions::registerHiveNativeFunctions(); + } } void PrestoServer::registerRemoteFunctions() { @@ -1291,11 +1465,12 @@ void PrestoServer::registerRemoteFunctions() { << catalogName << "' catalog."; } else { VELOX_FAIL( - "To register remote functions using a json file path you need to " - "specify the remote server location using '{}', '{}' or '{}'.", + "To register remote functions you need to specify the remote server " + "location using '{}', '{}' or '{}' or {}.", SystemConfig::kRemoteFunctionServerThriftAddress, SystemConfig::kRemoteFunctionServerThriftPort, - SystemConfig::kRemoteFunctionServerThriftUdsPath); + SystemConfig::kRemoteFunctionServerThriftUdsPath, + 
SystemConfig::kRemoteFunctionServerRestURL); } } #endif @@ -1344,6 +1519,12 @@ void PrestoServer::registerFileReadersAndWriters() { velox::orc::registerOrcReaderFactory(); velox::parquet::registerParquetReaderFactory(); velox::parquet::registerParquetWriterFactory(); + if (SystemConfig::instance()->textWriterEnabled()) { + velox::text::registerTextWriterFactory(); + } + if (SystemConfig::instance()->textReaderEnabled()) { + velox::text::registerTextReaderFactory(); + } } void PrestoServer::unregisterFileReadersAndWriters() { @@ -1351,6 +1532,12 @@ void PrestoServer::unregisterFileReadersAndWriters() { velox::dwrf::unregisterDwrfWriterFactory(); velox::parquet::unregisterParquetReaderFactory(); velox::parquet::unregisterParquetWriterFactory(); + if (SystemConfig::instance()->textWriterEnabled()) { + velox::text::unregisterTextWriterFactory(); + } + if (SystemConfig::instance()->textReaderEnabled()) { + velox::text::unregisterTextReaderFactory(); + } } void PrestoServer::registerStatsCounters() { @@ -1388,7 +1575,7 @@ void PrestoServer::enableWorkerStatsReporting() { void PrestoServer::initVeloxPlanValidator() { VELOX_CHECK_NULL(planValidator_); - planValidator_ = std::make_shared(); + planValidator_ = std::make_unique(); } VeloxPlanValidator* PrestoServer::getVeloxPlanValidator() { @@ -1438,7 +1625,7 @@ void PrestoServer::checkOverload() { systemConfig->workerOverloadedThresholdMemGb() * 1024 * 1024 * 1024; if (overloadedThresholdMemBytes > 0) { const auto currentUsedMemoryBytes = (memoryChecker_ != nullptr) - ? memoryChecker_->cachedSystemUsedMemoryBytes() + ? 
memoryChecker_->systemUsedMemoryBytes() : 0; const bool memOverloaded = (currentUsedMemoryBytes > overloadedThresholdMemBytes); @@ -1464,20 +1651,32 @@ void PrestoServer::checkOverload() { memOverloaded_ = memOverloaded; } + static const auto hwConcurrency = folly::hardware_concurrency(); const auto overloadedThresholdCpuPct = systemConfig->workerOverloadedThresholdCpuPct(); - if (overloadedThresholdCpuPct > 0) { + const auto overloadedThresholdQueuedDrivers = hwConcurrency * + systemConfig->workerOverloadedThresholdNumQueuedDriversHwMultiplier(); + if (overloadedThresholdCpuPct > 0 && overloadedThresholdQueuedDrivers > 0) { const auto currentUsedCpuPct = cpuMon_.getCPULoadPct(); - const bool cpuOverloaded = (currentUsedCpuPct > overloadedThresholdCpuPct); + const auto currentQueuedDrivers = taskManager_->numQueuedDrivers(); + const bool cpuOverloaded = + (currentUsedCpuPct > overloadedThresholdCpuPct) && + (currentQueuedDrivers > overloadedThresholdQueuedDrivers); if (cpuOverloaded && !cpuOverloaded_) { LOG(WARNING) << "OVERLOAD: Server CPU is overloaded. Currently used: " << currentUsedCpuPct - << "%, threshold: " << overloadedThresholdCpuPct << "%"; + << "% CPU (threshold: " << overloadedThresholdCpuPct + << "%), " << currentQueuedDrivers + << " queued drivers (threshold: " + << overloadedThresholdQueuedDrivers << ")"; } else if (!cpuOverloaded && cpuOverloaded_) { LOG(INFO) << "OVERLOAD: Server CPU is no longer overloaded. Currently used: " - << currentUsedCpuPct << "%, threshold: " << overloadedThresholdCpuPct - << "%"; + << currentUsedCpuPct + << "% CPU (threshold: " << overloadedThresholdCpuPct << "%), " + << currentQueuedDrivers + << " queued drivers (threshold: " << overloadedThresholdQueuedDrivers + << ")"; } RECORD_METRIC_VALUE(kCounterOverloadedCpu, cpuOverloaded ? 
100 : 0); cpuOverloaded_ = cpuOverloaded; @@ -1575,9 +1774,8 @@ void PrestoServer::registerSidecarEndpoints() { proxygen::HTTPMessage* /*message*/, const std::vector>& /*body*/, proxygen::ResponseHandler* downstream) { - auto sessionProperties = - taskManager_->getQueryContextManager()->getSessionProperties(); - http::sendOkResponse(downstream, sessionProperties.serialize()); + const auto* sessionProperties = SessionProperties::instance(); + http::sendOkResponse(downstream, sessionProperties->serialize()); }); httpServer_->registerGet( "/v1/functions", @@ -1586,6 +1784,30 @@ void PrestoServer::registerSidecarEndpoints() { proxygen::ResponseHandler* downstream) { http::sendOkResponse(downstream, getFunctionsMetadata()); }); + httpServer_->registerGet( + R"(/v1/functions/([^/]+))", + [](proxygen::HTTPMessage* /*message*/, + const std::vector& pathMatch) { + return new http::CallbackRequestHandler( + [catalog = pathMatch[1]]( + proxygen::HTTPMessage* /*message*/, + std::vector>& /*body*/, + proxygen::ResponseHandler* downstream) { + http::sendOkResponse(downstream, getFunctionsMetadata(catalog)); + }); + }); + httpServer_->registerPost( + "/v1/expressions", + [this]( + proxygen::HTTPMessage* message, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + const auto& httpHeaders = message->getHeaders(); + const auto result = getOptimizedExpressions( + httpHeaders, body, driverExecutor_.get(), nativeWorkerPool_.get()); + http::sendOkResponse(downstream, result); + }); + httpServer_->registerPost( "/v1/velox/plan", [server = this]( @@ -1624,7 +1846,7 @@ protocol::NodeStatus PrestoServer::fetchNodeStatus() { address_, address_, **memoryInfo_.rlock(), - (int)std::thread::hardware_concurrency(), + (int)folly::hardware_concurrency(), cpuLoadPct, cpuLoadPct, pool_ ? 
pool_->usedBytes() : 0, @@ -1634,4 +1856,91 @@ protocol::NodeStatus PrestoServer::fetchNodeStatus() { return nodeStatus; } +void PrestoServer::registerDynamicFunctions() { + // For using the non-throwing overloads of functions below. + std::error_code ec; + const auto systemConfig = SystemConfig::instance(); + fs::path pluginDir = std::filesystem::current_path().append("plugin"); + if (!systemConfig->pluginDir().empty()) { + pluginDir = systemConfig->pluginDir(); + } + // If it is a valid directory, traverse and call dynamic function loader. + if (fs::is_directory(pluginDir, ec)) { + PRESTO_STARTUP_LOG(INFO) + << "Loading dynamic libraries from directory path: " << pluginDir; + for (const auto& dirEntry : + std::filesystem::directory_iterator(pluginDir)) { + if (isSharedLibrary(dirEntry.path())) { + PRESTO_STARTUP_LOG(INFO) + << "Loading dynamic libraries from: " << dirEntry.path().string(); + velox::loadDynamicLibrary(dirEntry.path().c_str()); + } + } + } else { + PRESTO_STARTUP_LOG(INFO) + << "Plugin directory path: " << pluginDir << " is invalid."; + return; + } +} + +void PrestoServer::createTaskManager() { + taskManager_ = std::make_unique( + driverExecutor_.get(), httpSrvCpuExecutor_.get(), spillerExecutor_.get()); +} + +void PrestoServer::reportNodeStats(proxygen::ResponseHandler* downstream) { + protocol::NodeStats nodeStats; + + auto loadMetrics = std::make_shared(); + loadMetrics->cpuOverload = cpuOverloaded_; + loadMetrics->memoryOverload = memOverloaded_; + + nodeStats.loadMetrics = loadMetrics; + nodeStats.nodeState = convertNodeState(this->nodeState()); + + http::sendOkResponse(downstream, json(nodeStats)); +} + +void PrestoServer::registerTraceNodeFactories() { + // Register trace node factory for BroadcastWrite operator. 
+ velox::exec::trace::registerTraceNodeFactory( + "BroadcastWrite", + [](const velox::core::PlanNode* traceNode, + const velox::core::PlanNodeId& nodeId) -> velox::core::PlanNodePtr { + if (const auto* broadcastWriteNode = + dynamic_cast(traceNode)) { + return std::make_shared( + nodeId, + broadcastWriteNode->basePath(), + broadcastWriteNode->maxBroadcastBytes(), + broadcastWriteNode->serdeRowType(), + std::make_shared( + broadcastWriteNode->sources().front()->outputType())); + } + return nullptr; + }); + + // Register trace node factory for PartitionAndSerialize operator + velox::exec::trace::registerTraceNodeFactory( + "PartitionAndSerialize", + [](const velox::core::PlanNode* traceNode, + const velox::core::PlanNodeId& nodeId) -> velox::core::PlanNodePtr { + if (const auto* partitionAndSerializeNode = + dynamic_cast( + traceNode)) { + return std::make_shared( + nodeId, + partitionAndSerializeNode->keys(), + partitionAndSerializeNode->numPartitions(), + partitionAndSerializeNode->serializedRowType(), + std::make_shared( + partitionAndSerializeNode->sources().front()->outputType()), + partitionAndSerializeNode->isReplicateNullsAndAny(), + partitionAndSerializeNode->partitionFunctionFactory(), + partitionAndSerializeNode->sortingOrders(), + partitionAndSerializeNode->sortingKeys()); + } + return nullptr; + }); +} } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.h b/presto-native-execution/presto_cpp/main/PrestoServer.h index dda92ea2675af..5a240a13c44c9 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.h +++ b/presto-native-execution/presto_cpp/main/PrestoServer.h @@ -120,16 +120,21 @@ class PrestoServer { /// Hook for derived PrestoServer implementations to add/stop additional /// periodic tasks. 
- virtual void addAdditionalPeriodicTasks(){}; + virtual void addAdditionalPeriodicTasks() {}; - virtual void stopAdditionalPeriodicTasks(){}; + virtual void stopAdditionalPeriodicTasks() {}; virtual void initializeCoordinatorDiscoverer(); + virtual void initializeThreadPools(); + virtual std::shared_ptr getTaskListener(); virtual std::shared_ptr getExprSetListener(); + virtual std::shared_ptr + getSplitListenerFactory(); + virtual std::vector registerVeloxConnectors( const fs::path& configDirectoryPath); @@ -160,6 +165,8 @@ class PrestoServer { virtual void registerMemoryArbitrators(); + virtual void registerTraceNodeFactories(); + /// Invoked after creating global (singleton) config objects (SystemConfig and /// NodeConfig) and before loading their properties from the file. /// In the implementation any extra config properties can be registered. @@ -167,8 +174,8 @@ class PrestoServer { /// Invoked to get the ip address of the process. In certain deployment /// setup, each process has different ip address. Deployment environment - /// may provide there own library to get process specific ip address. - /// In such cases, getLocalIp can be overriden to pass process specific + /// may provide their own library to get process specific ip address. + /// In such cases, getLocalIp can be overridden to pass process specific /// ip address. virtual std::string getLocalIp() const; @@ -183,17 +190,16 @@ class PrestoServer { VeloxPlanValidator* getVeloxPlanValidator(); + void registerDynamicFunctions(); + /// Invoked to get the list of filters passed to the http server. 
virtual std::vector> getHttpServerFilters() const; void initializeVeloxMemory(); - void initializeThreadPools(); - void registerStatsCounters(); - protected: void updateAnnouncerDetails(); void addServerPeriodicTasks(); @@ -204,6 +210,8 @@ class PrestoServer { void reportNodeStatus(proxygen::ResponseHandler* downstream); + void reportNodeStats(proxygen::ResponseHandler* downstream); + void handleGracefulShutdown( const std::vector>& body, proxygen::ResponseHandler* downstream); @@ -223,6 +231,8 @@ class PrestoServer { void checkOverload(); + virtual void createTaskManager(); + const std::string configDirectoryPath_; std::shared_ptr coordinatorDiscoverer_; @@ -237,24 +247,37 @@ class PrestoServer { std::unique_ptr connectorIoExecutor_; // Executor for exchange data over http. - std::shared_ptr exchangeHttpIoExecutor_; + std::unique_ptr exchangeHttpIoExecutor_; // Executor for exchange request processing. - std::shared_ptr exchangeHttpCpuExecutor_; + std::unique_ptr exchangeHttpCpuExecutor_; // Executor for HTTP request dispatching std::shared_ptr httpSrvIoExecutor_; // Executor for HTTP request processing after dispatching - std::shared_ptr httpSrvCpuExecutor_; - - // Executor for query engine driver executions. - std::shared_ptr driverExecutor_; - - // Executor for spilling. - std::shared_ptr spillerExecutor_; - - std::shared_ptr planValidator_; + std::unique_ptr httpSrvCpuExecutor_; + + // Executor for query engine driver executions. The underlying thread pool + // executor is a folly::CPUThreadPoolExecutor. The executor is stored as + // abstract type to provide flexibility of thread pool monitoring. The + // underlying folly::CPUThreadPoolExecutor can be obtained through + // 'driverCpuExecutor()' method. + std::unique_ptr driverExecutor_; + // Raw pointer pointing to the underlying folly::CPUThreadPoolExecutor of + // 'driverExecutor_'. + folly::CPUThreadPoolExecutor* driverCpuExecutor_; + + // Executor for spilling. 
The underlying thread pool executor is a + // folly::CPUThreadPoolExecutor. The executor is stored as abstract type to + // provide flexibility of thread pool monitoring. The underlying + // folly::CPUThreadPoolExecutor can be obtained through 'spillerCpuExecutor_'. + std::unique_ptr spillerExecutor_; + // Raw pointer pointing to the underlying folly::CPUThreadPoolExecutor of + // 'spillerExecutor_'. + folly::CPUThreadPoolExecutor* spillerCpuExecutor_; + + std::unique_ptr planValidator_; std::unique_ptr exchangeSourceConnectionPool_; diff --git a/presto-native-execution/presto_cpp/main/PrestoServerOperations.cpp b/presto-native-execution/presto_cpp/main/PrestoServerOperations.cpp index 7de490c948e79..1f85c15fa8e74 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServerOperations.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServerOperations.cpp @@ -177,7 +177,7 @@ std::string PrestoServerOperations::veloxQueryConfigOperation( "Have set system property value '{}' to '{}'. Old value was '{}'.\n", name, value, - BaseVeloxQueryConfig::instance() + SystemConfig::instance() ->setValue(name, value) .value_or("")); } @@ -190,7 +190,7 @@ std::string PrestoServerOperations::veloxQueryConfigOperation( ServerOperation::actionString(op.action)); return fmt::format( "{}\n", - BaseVeloxQueryConfig::instance()->optionalProperty(name).value_or( + SystemConfig::instance()->optionalProperty(name).value_or( "")); } default: @@ -222,7 +222,7 @@ std::string PrestoServerOperations::taskOperation( limit = limitStr == proxygen::empty_string ? 
std::numeric_limits::max() : stoi(limitStr); - } catch (std::exception& ex) { + } catch (const std::exception& /* unused */) { VELOX_USER_FAIL("Invalid limit provided '{}'.", limitStr); } std::stringstream oss; diff --git a/presto-native-execution/presto_cpp/main/PrestoTask.cpp b/presto-native-execution/presto_cpp/main/PrestoTask.cpp index 89d4f438a21a3..afdf9f4f82dc5 100644 --- a/presto-native-execution/presto_cpp/main/PrestoTask.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoTask.cpp @@ -74,17 +74,19 @@ PrestoTaskState toPrestoTaskState(exec::TaskState state) { return PrestoTaskState::kAborted; } -protocol::TaskState toProtocolTaskState(exec::TaskState state) { +protocol::TaskState toProtocolTaskState(PrestoTaskState state) { switch (state) { - case exec::TaskState::kRunning: + case PrestoTaskState::kRunning: return protocol::TaskState::RUNNING; - case exec::TaskState::kFinished: + case PrestoTaskState::kFinished: return protocol::TaskState::FINISHED; - case exec::TaskState::kCanceled: + case PrestoTaskState::kCanceled: return protocol::TaskState::CANCELED; - case exec::TaskState::kFailed: + case PrestoTaskState::kFailed: return protocol::TaskState::FAILED; - case exec::TaskState::kAborted: + case PrestoTaskState::kPlanned: + return protocol::TaskState::PLANNED; + case PrestoTaskState::kAborted: [[fallthrough]]; default: return protocol::TaskState::ABORTED; @@ -95,9 +97,9 @@ protocol::ExecutionFailureInfo toPrestoError(std::exception_ptr ex) { try { rethrow_exception(ex); } catch (const VeloxException& e) { - return VeloxToPrestoExceptionTranslator::translate(e); + return translateToPrestoException(e); } catch (const std::exception& e) { - return VeloxToPrestoExceptionTranslator::translate(e); + return translateToPrestoException(e); } } @@ -349,7 +351,8 @@ void updatePipelineStats( protocol::PipelineStats& prestoPipelineStats) { prestoPipelineStats.inputPipeline = veloxPipelineStats.inputPipeline; prestoPipelineStats.outputPipeline = 
veloxPipelineStats.outputPipeline; - prestoPipelineStats.firstStartTimeInMillis = prestoTaskStats.createTimeInMillis; + prestoPipelineStats.firstStartTimeInMillis = + prestoTaskStats.createTimeInMillis; prestoPipelineStats.lastStartTimeInMillis = prestoTaskStats.endTimeInMillis; prestoPipelineStats.lastEndTimeInMillis = prestoTaskStats.endTimeInMillis; @@ -440,12 +443,18 @@ void updatePipelineStats( prestoOp.blockedWall = protocol::Duration( veloxOp.blockedWallNanos, protocol::TimeUnit::NANOSECONDS); - prestoOp.userMemoryReservationInBytes = veloxOp.memoryStats.userMemoryReservation; - prestoOp.revocableMemoryReservationInBytes = veloxOp.memoryStats.revocableMemoryReservation; - prestoOp.systemMemoryReservationInBytes = veloxOp.memoryStats.systemMemoryReservation; - prestoOp.peakUserMemoryReservationInBytes = veloxOp.memoryStats.peakUserMemoryReservation; - prestoOp.peakSystemMemoryReservationInBytes = veloxOp.memoryStats.peakSystemMemoryReservation; - prestoOp.peakTotalMemoryReservationInBytes = veloxOp.memoryStats.peakTotalMemoryReservation; + prestoOp.userMemoryReservationInBytes = + veloxOp.memoryStats.userMemoryReservation; + prestoOp.revocableMemoryReservationInBytes = + veloxOp.memoryStats.revocableMemoryReservation; + prestoOp.systemMemoryReservationInBytes = + veloxOp.memoryStats.systemMemoryReservation; + prestoOp.peakUserMemoryReservationInBytes = + veloxOp.memoryStats.peakUserMemoryReservation; + prestoOp.peakSystemMemoryReservationInBytes = + veloxOp.memoryStats.peakSystemMemoryReservation; + prestoOp.peakTotalMemoryReservationInBytes = + veloxOp.memoryStats.peakTotalMemoryReservation; prestoOp.spilledDataSizeInBytes = veloxOp.spilledBytes; @@ -557,14 +566,6 @@ void PrestoTask::recordProcessCpuTime() { } protocol::TaskStatus PrestoTask::updateStatusLocked() { - if (!taskStarted && (error == nullptr)) { - protocol::TaskStatus ret = info.taskStatus; - if (ret.state != protocol::TaskState::ABORTED) { - ret.state = protocol::TaskState::PLANNED; - } - 
return ret; - } - // Error occurs when creating task or even before task is created. Set error // and return immediately if (error != nullptr) { @@ -575,11 +576,16 @@ protocol::TaskStatus PrestoTask::updateStatusLocked() { recordProcessCpuTime(); return info.taskStatus; } - VELOX_CHECK_NOT_NULL(task, "task is null when updating status"); + + // We can be here before the fragment plan is received and exec task created. + if (task == nullptr) { + VELOX_CHECK(!taskStarted); + return info.taskStatus; + } const auto veloxTaskStats = task->taskStats(); - info.taskStatus.state = toProtocolTaskState(task->state()); + info.taskStatus.state = toProtocolTaskState(taskState()); // Presto has a Driver per split. When splits represent partitions // of data, there is a queue of them per Task. We represent @@ -640,7 +646,7 @@ void PrestoTask::updateOutputBufferInfoLocked( const auto& outputBufferStats = veloxTaskStats.outputBufferStats.value(); auto& outputBufferInfo = info.outputBuffers; outputBufferInfo.type = - velox::core::PartitionedOutputNode::kindString(outputBufferStats.kind); + velox::core::PartitionedOutputNode::toName(outputBufferStats.kind); outputBufferInfo.canAddBuffers = !outputBufferStats.noMoreBuffers; outputBufferInfo.canAddPages = !outputBufferStats.noMoreData; outputBufferInfo.totalBufferedBytes = outputBufferStats.bufferedBytes; @@ -711,7 +717,7 @@ protocol::TaskInfo PrestoTask::updateInfoLocked(bool summarize) { for (const auto it : veloxTaskStats.numBlockedDrivers) { addRuntimeMetricIfNotZero( taskRuntimeStats, - fmt::format("drivers.{}", exec::blockingReasonToString(it.first)), + fmt::format("drivers.{}", exec::BlockingReasonName::toName(it.first)), it.second); } if (veloxTaskStats.longestRunningOpCallMs != 0) { @@ -770,6 +776,12 @@ void PrestoTask::updateTimeInfoLocked( taskRuntimeStats["endTime"].addValue(veloxTaskStats.endTimeMs); } taskRuntimeStats.insert({"nativeProcessCpuTime", fromNanos(processCpuTime_)}); + // Represents the time between receiving 
first taskUpdate and task creation + // time. + taskRuntimeStats.insert( + {"taskCreationTime", + fromNanos( + (createFinishTimeMs - firstTimeReceiveTaskUpdateMs) * 1'000'000)}); } void PrestoTask::updateMemoryInfoLocked( @@ -823,21 +835,34 @@ void PrestoTask::updateExecutionInfoLocked( prestoTaskStats.outputPositions = 0; prestoTaskStats.outputDataSizeInBytes = 0; - // Presto Java reports number of drivers to number of splits in Presto UI - // because split and driver are 1 to 1 mapping relationship. This is not true - // in Prestissimo where 1 driver handles many splits. In order to quickly - // unblock developers from viewing the correct progress of splits in - // Prestissimo's coordinator UI, we put number of splits in total, queued, and - // finished to indicate the progress of the query. Number of running drivers - // are passed as it is to have a proper running drivers count in UI. + // NOTE: This logic is implemented in a backwards-compatible way because + // the coordinator and worker may not be upgraded at the same time. + // + // To ensure safe rollout: + // - We are introducing new fields (e.g., `totalNewDrivers`) instead of + // modifying or removing existing ones. + // - The worker is updated first to populate both old and new fields. + // - The coordinator continues to use the old fields until it is updated to + // handle the new ones. // - // TODO: We should really extend the API (protocol::TaskStats and Presto - // coordinator UI) to have splits information as a proper fix. + // Once both coordinator and worker support the new fields, we can safely + // remove the legacy fields in a follow-up cleanup PR. 
+ prestoTaskStats.totalDrivers = veloxTaskStats.numTotalSplits; prestoTaskStats.queuedDrivers = veloxTaskStats.numQueuedSplits; prestoTaskStats.runningDrivers = veloxTaskStats.numRunningDrivers; prestoTaskStats.completedDrivers = veloxTaskStats.numFinishedSplits; + prestoTaskStats.totalNewDrivers = veloxTaskStats.numTotalDrivers; + prestoTaskStats.queuedNewDrivers = veloxTaskStats.numQueuedDrivers; + prestoTaskStats.runningNewDrivers = veloxTaskStats.numRunningDrivers; + prestoTaskStats.completedNewDrivers = veloxTaskStats.numCompletedDrivers; + + prestoTaskStats.totalSplits = veloxTaskStats.numTotalSplits; + prestoTaskStats.queuedSplits = veloxTaskStats.numQueuedSplits; + prestoTaskStats.runningSplits = veloxTaskStats.numRunningSplits; + prestoTaskStats.completedSplits = veloxTaskStats.numFinishedSplits; + if (includePipelineStats) { prestoTaskStats.pipelines.resize(veloxTaskStats.pipelineStats.size()); } else { diff --git a/presto-native-execution/presto_cpp/main/PrestoTask.h b/presto-native-execution/presto_cpp/main/PrestoTask.h index bceae01d1005d..5748267bd2dc5 100644 --- a/presto-native-execution/presto_cpp/main/PrestoTask.h +++ b/presto-native-execution/presto_cpp/main/PrestoTask.h @@ -14,6 +14,7 @@ #pragma once #include +#include #include "presto_cpp/main/http/HttpServer.h" #include "presto_cpp/main/types/PrestoTaskId.h" #include "presto_cpp/presto_protocol/core/presto_protocol_core.h" @@ -70,6 +71,7 @@ struct Result { std::unique_ptr data; bool complete; std::vector remainingBytes; + int64_t waitTimeMs; }; struct ResultRequest { @@ -120,7 +122,12 @@ struct PrestoTask { uint64_t lastTaskStatsUpdateMs{0}; uint64_t lastMemoryReservation{0}; + /// Time point (in ms) when the time we start task creating. uint64_t createTimeMs{0}; + /// Time point (in ms) when the first time we receive task update. + uint64_t firstTimeReceiveTaskUpdateMs{0}; + /// Time point (in ms) when the time we finish task creating. 
+ uint64_t createFinishTimeMs{0}; uint64_t startTimeMs{0}; uint64_t firstSplitStartTimeMs{0}; uint64_t lastEndTimeMs{0}; @@ -143,6 +150,10 @@ struct PrestoTask { /// Info request. May arrive before there is a Task. PromiseHolderWeakPtr> infoRequest; + /// If the task has not been started yet, we collect all plan node IDs that + /// had 'no more splits' message to process them after the task starts. + std::unordered_set delayedNoMoreSplitsPlanNodes_; + /// @param taskId Task ID. /// @param nodeId Node ID. /// @param startCpuTime CPU time in nanoseconds recorded when request to diff --git a/presto-native-execution/presto_cpp/main/PrestoToVeloxQueryConfig.cpp b/presto-native-execution/presto_cpp/main/PrestoToVeloxQueryConfig.cpp new file mode 100644 index 0000000000000..2929ccfcd7c0b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/PrestoToVeloxQueryConfig.cpp @@ -0,0 +1,288 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "presto_cpp/main/PrestoToVeloxQueryConfig.h" +#include "presto_cpp/main/SessionProperties.h" +#include "presto_cpp/main/common/Configs.h" +#include "velox/common/compression/Compression.h" +#include "velox/core/QueryConfig.h" +#include "velox/type/tz/TimeZoneMap.h" + +namespace facebook::presto { +namespace { + +void updateVeloxConfigsWithSpecialCases( + std::unordered_map& configStrings) { + // If `legacy_timestamp` is true, the coordinator expects timestamp + // conversions without a timezone to be converted to the user's + // session_timezone. + auto it = configStrings.find("legacy_timestamp"); + // `legacy_timestamp` default value is true in the coordinator. + if ((it == configStrings.end()) || (folly::to(it->second))) { + configStrings.emplace( + velox::core::QueryConfig::kAdjustTimestampToTimezone, "true"); + } + // TODO: remove this once cpu driver slicing config is turned on by default in + // Velox. + it = configStrings.find(velox::core::QueryConfig::kDriverCpuTimeSliceLimitMs); + if (it == configStrings.end()) { + // Set it to 1 second to be aligned with Presto Java. 
+ configStrings.emplace( + velox::core::QueryConfig::kDriverCpuTimeSliceLimitMs, "1000"); + } +} + +void updateFromSessionConfigs( + const protocol::SessionRepresentation& session, + std::unordered_map& queryConfigs) { + auto* sessionProperties = SessionProperties::instance(); + std::optional traceFragmentId; + std::optional traceShardId; + for (const auto& it : session.systemProperties) { + if (it.first == SessionProperties::kQueryTraceFragmentId) { + traceFragmentId = it.second; + } else if (it.first == SessionProperties::kQueryTraceShardId) { + traceShardId = it.second; + } else if (it.first == SessionProperties::kShuffleCompressionCodec) { + auto compression = it.second; + std::transform( + compression.begin(), + compression.end(), + compression.begin(), + ::tolower); + velox::common::CompressionKind compressionKind = + velox::common::stringToCompressionKind(compression); + queryConfigs[velox::core::QueryConfig::kShuffleCompressionKind] = + velox::common::compressionKindToString(compressionKind); + } else if (!sessionProperties->hasVeloxConfig(it.first)) { + sessionProperties->updateSessionPropertyValue(it.first, it.second); + } else { + queryConfigs[sessionProperties->toVeloxConfig(it.first)] = it.second; + } + } + + if (session.startTime) { + queryConfigs[velox::core::QueryConfig::kSessionStartTime] = + std::to_string(session.startTime); + } + + if (session.source) { + queryConfigs[velox::core::QueryConfig::kSource] = *session.source; + } + if (!session.clientTags.empty()) { + queryConfigs[velox::core::QueryConfig::kClientTags] = + folly::join(',', session.clientTags); + } + + // If there's a timeZoneKey, convert to timezone name and add to the + // configs. Throws if timeZoneKey can't be resolved. + if (session.timeZoneKey != 0) { + queryConfigs.emplace( + velox::core::QueryConfig::kSessionTimezone, + velox::tz::getTimeZoneName(session.timeZoneKey)); + } + + // Construct query tracing regex and pass to Velox config. 
+ // It replaces the given native_query_trace_task_reg_exp if also set. + if (traceFragmentId.has_value() || traceShardId.has_value()) { + queryConfigs.emplace( + velox::core::QueryConfig::kQueryTraceTaskRegExp, + ".*\\." + traceFragmentId.value_or(".*") + "\\..*\\." + + traceShardId.value_or(".*") + "\\..*"); + } +} + +void updateFromSystemConfigs( + std::unordered_map& queryConfigs) { + const auto& systemConfig = SystemConfig::instance(); + struct ConfigMapping { + std::string prestoSystemConfig; + std::string veloxConfig; + std::function + toVeloxPropertyValueConverter{ + [](const std::string& prestoValue) { return prestoValue; }}; + }; + + static const std::vector veloxToPrestoConfigMapping{ + {.prestoSystemConfig = std::string(SystemConfig::kQueryMaxMemoryPerNode), + .veloxConfig = velox::core::QueryConfig::kQueryMaxMemoryPerNode}, + + { + .prestoSystemConfig = + std::string(SystemConfig::kSpillerFileCreateConfig), + .veloxConfig = velox::core::QueryConfig::kSpillFileCreateConfig, + }, + + {.prestoSystemConfig = std::string(SystemConfig::kSpillEnabled), + .veloxConfig = velox::core::QueryConfig::kSpillEnabled}, + + {.prestoSystemConfig = std::string(SystemConfig::kJoinSpillEnabled), + .veloxConfig = velox::core::QueryConfig::kJoinSpillEnabled}, + + {.prestoSystemConfig = std::string(SystemConfig::kOrderBySpillEnabled), + .veloxConfig = velox::core::QueryConfig::kOrderBySpillEnabled}, + + {.prestoSystemConfig = + std::string(SystemConfig::kAggregationSpillEnabled), + .veloxConfig = velox::core::QueryConfig::kAggregationSpillEnabled}, + + {.prestoSystemConfig = + std::string(SystemConfig::kRequestDataSizesMaxWaitSec), + .veloxConfig = velox::core::QueryConfig::kRequestDataSizesMaxWaitSec}, + + {.prestoSystemConfig = std::string(SystemConfig::kDriverMaxSplitPreload), + .veloxConfig = velox::core::QueryConfig::kMaxSplitPreloadPerDriver}, + + {.prestoSystemConfig = + std::string(SystemConfig::kMaxLocalExchangeBufferSize), + .veloxConfig = 
velox::core::QueryConfig::kMaxLocalExchangeBufferSize}, + + {.prestoSystemConfig = + std::string(SystemConfig::kMaxLocalExchangePartitionBufferSize), + .veloxConfig = + velox::core::QueryConfig::kMaxLocalExchangePartitionBufferSize}, + + {.prestoSystemConfig = + std::string(SystemConfig::kParallelOutputJoinBuildRowsEnabled), + .veloxConfig = + velox::core::QueryConfig::kParallelOutputJoinBuildRowsEnabled}, + + {.prestoSystemConfig = + std::string(SystemConfig::kHashProbeBloomFilterPushdownMaxSize), + .veloxConfig = + velox::core::QueryConfig::kHashProbeBloomFilterPushdownMaxSize}, + + {.prestoSystemConfig = std::string(SystemConfig::kUseLegacyArrayAgg), + .veloxConfig = velox::core::QueryConfig::kPrestoArrayAggIgnoreNulls}, + + {.prestoSystemConfig = std::string{SystemConfig::kTaskWriterCount}, + .veloxConfig = velox::core::QueryConfig::kTaskWriterCount}, + + {.prestoSystemConfig = + std::string{SystemConfig::kTaskPartitionedWriterCount}, + .veloxConfig = velox::core::QueryConfig::kTaskPartitionedWriterCount}, + + {.prestoSystemConfig = std::string{SystemConfig::kExchangeMaxBufferSize}, + .veloxConfig = velox::core::QueryConfig::kMaxExchangeBufferSize, + .toVeloxPropertyValueConverter = + [](const auto& value) { + return folly::to(velox::config::toCapacity( + value, velox::config::CapacityUnit::BYTE)); + }}, + + {.prestoSystemConfig = std::string(SystemConfig::kSinkMaxBufferSize), + .veloxConfig = velox::core::QueryConfig::kMaxOutputBufferSize, + .toVeloxPropertyValueConverter = + [](const auto& value) { + return folly::to(velox::config::toCapacity( + value, velox::config::CapacityUnit::BYTE)); + }}, + + {.prestoSystemConfig = + std::string(SystemConfig::kDriverMaxPagePartitioningBufferSize), + .veloxConfig = velox::core::QueryConfig::kMaxPartitionedOutputBufferSize, + .toVeloxPropertyValueConverter = + [](const auto& value) { + return folly::to(velox::config::toCapacity( + value, velox::config::CapacityUnit::BYTE)); + }}, + + {.prestoSystemConfig = + 
std::string(SystemConfig::kTaskMaxPartialAggregationMemory), + .veloxConfig = velox::core::QueryConfig::kMaxPartialAggregationMemory, + .toVeloxPropertyValueConverter = + [](const auto& value) { + return folly::to(velox::config::toCapacity( + value, velox::config::CapacityUnit::BYTE)); + }}, + + {.prestoSystemConfig = + std::string(SystemConfig::kExchangeLazyFetchingEnabled), + .veloxConfig = velox::core::QueryConfig::kExchangeLazyFetchingEnabled}, + }; + + for (const auto& configMapping : veloxToPrestoConfigMapping) { + const auto& veloxConfigName = configMapping.veloxConfig; + const auto& systemConfigName = configMapping.prestoSystemConfig; + const auto propertyOpt = systemConfig->optionalProperty(systemConfigName); + if (propertyOpt.has_value()) { + queryConfigs[veloxConfigName] = + configMapping.toVeloxPropertyValueConverter(propertyOpt.value()); + } + } +} +} // namespace + +std::unordered_map toVeloxConfigs( + const protocol::SessionRepresentation& session) { + std::unordered_map configs; + + // Firstly apply Presto system properties to Velox query config. + updateFromSystemConfigs(configs); + + // Secondly apply and possibly override with Presto session properties. + updateFromSessionConfigs(session, configs); + + // Finally apply special case configs. 
+ updateVeloxConfigsWithSpecialCases(configs); + return configs; +} + +velox::core::QueryConfig toVeloxConfigs( + const protocol::SessionRepresentation& session, + const std::map& extraCredentials) { + // Start with the session-based configuration + auto configs = toVeloxConfigs(session); + + // If there are any extra credentials, add them all to the config + if (!extraCredentials.empty()) { + // Create new config map with all extra credentials added + configs.insert(extraCredentials.begin(), extraCredentials.end()); + } + return velox::core::QueryConfig(configs); +} + +std::unordered_map> +toConnectorConfigs(const protocol::TaskUpdateRequest& taskUpdateRequest) { + std::unordered_map> + connectorConfigs; + for (const auto& entry : taskUpdateRequest.session.catalogProperties) { + std::unordered_map connectorConfig; + // remove native prefix from native connector session property names + for (const auto& sessionProperty : entry.second) { + auto veloxConfig = (sessionProperty.first.rfind("native_", 0) == 0) + ? 
sessionProperty.first.substr(7) + : sessionProperty.first; + connectorConfig.emplace(veloxConfig, sessionProperty.second); + } + connectorConfig.insert( + taskUpdateRequest.extraCredentials.begin(), + taskUpdateRequest.extraCredentials.end()); + connectorConfig.insert({"user", taskUpdateRequest.session.user}); + if (taskUpdateRequest.session.source) { + connectorConfig.insert({"source", *taskUpdateRequest.session.source}); + } + if (taskUpdateRequest.session.schema) { + connectorConfig.insert({"schema", *taskUpdateRequest.session.schema}); + } + connectorConfigs.insert( + {entry.first, + std::make_shared( + std::move(connectorConfig))}); + } + + return connectorConfigs; +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/PrestoToVeloxQueryConfig.h b/presto-native-execution/presto_cpp/main/PrestoToVeloxQueryConfig.h new file mode 100644 index 0000000000000..8e6b4cbf9261e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/PrestoToVeloxQueryConfig.h @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +namespace facebook::velox::config { +class ConfigBase; +} + +namespace facebook::velox::core { +class QueryConfig; +} + +namespace facebook::presto { + +/// Translates Presto configs to Velox 'QueryConfig' config map. Presto query +/// session properties take precedence over Presto system config properties. 
+std::unordered_map toVeloxConfigs( + const protocol::SessionRepresentation& session); + +/// Translates Presto configs to Velox 'QueryConfig' config map. It is the +/// temporary overload that builds a QueryConfig from session properties and +/// extraCredentials, including all extraCredentials so they can be consumed by +/// UDFs and connectors. +/// This implementation is a temporary solution until a more unified +/// configuration mechanism (TokenProvider) is available. +velox::core::QueryConfig toVeloxConfigs( + const protocol::SessionRepresentation& session, + const std::map& extraCredentials); + +std::unordered_map> +toConnectorConfigs(const protocol::TaskUpdateRequest& taskUpdateRequest); + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp index a67df141c41d9..3598603c300eb 100644 --- a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp +++ b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp @@ -14,12 +14,11 @@ #include "presto_cpp/main/QueryContextManager.h" #include +#include "presto_cpp/main/PrestoToVeloxQueryConfig.h" +#include "presto_cpp/main/SessionProperties.h" #include "presto_cpp/main/common/Configs.h" -#include "presto_cpp/main/common/Counters.h" -#include "velox/common/base/StatsReporter.h" #include "velox/connectors/hive/HiveConfig.h" #include "velox/core/QueryConfig.h" -#include "velox/type/tz/TimeZoneMap.h" using namespace facebook::velox; @@ -29,134 +28,162 @@ using facebook::presto::protocol::TaskId; namespace facebook::presto { namespace { -// Update passed in query session configs with system configs. For any pairing -// system/session configs if session config is present, it overrides system -// config, otherwise system config is fed in queryConfigs. E.g. -// "query.max-memory-per-node" system config and "query_max_memory_per_node" -// session config is a pairing config. 
If system config is 7GB but session -// config is not provided, then 7GB will be added to 'queryConfigs'. On the -// other hand if system config is 7GB but session config is 4GB then 4GB will -// be preserved in 'queryConfigs'. -void updateFromSystemConfigs( - std::unordered_map& queryConfigs) { - const auto& systemConfig = SystemConfig::instance(); - static const std::unordered_map - sessionSystemConfigMapping{ - {core::QueryConfig::kQueryMaxMemoryPerNode, - std::string(SystemConfig::kQueryMaxMemoryPerNode)}, - {core::QueryConfig::kSpillFileCreateConfig, - std::string(SystemConfig::kSpillerFileCreateConfig)}, - {core::QueryConfig::kSpillEnabled, - std::string(SystemConfig::kSpillEnabled)}, - {core::QueryConfig::kJoinSpillEnabled, - std::string(SystemConfig::kJoinSpillEnabled)}, - {core::QueryConfig::kOrderBySpillEnabled, - std::string(SystemConfig::kOrderBySpillEnabled)}, - {core::QueryConfig::kAggregationSpillEnabled, - std::string(SystemConfig::kAggregationSpillEnabled)}, - {core::QueryConfig::kRequestDataSizesMaxWaitSec, - std::string(SystemConfig::kRequestDataSizesMaxWaitSec)}}; - for (const auto& configNameEntry : sessionSystemConfigMapping) { - const auto& sessionName = configNameEntry.first; - const auto& systemConfigName = configNameEntry.second; - if (queryConfigs.count(sessionName) == 0) { - const auto propertyOpt = systemConfig->optionalProperty(systemConfigName); - if (propertyOpt.hasValue()) { - queryConfigs[sessionName] = propertyOpt.value(); - } - } - } +inline QueryId queryIdFromTaskId(const TaskId& taskId) { + return taskId.substr(0, taskId.find('.')); } -std::unordered_map> -toConnectorConfigs(const protocol::TaskUpdateRequest& taskUpdateRequest) { - std::unordered_map> - connectorConfigs; - for (const auto& entry : taskUpdateRequest.session.catalogProperties) { - std::unordered_map connectorConfig; - // remove native prefix from native connector session property names - for (const auto& sessionProperty : entry.second) { - auto veloxConfig = 
(sessionProperty.first.rfind("native_", 0) == 0) - ? sessionProperty.first.substr(7) - : sessionProperty.first; - connectorConfig.emplace(veloxConfig, sessionProperty.second); - } - connectorConfig.insert( - taskUpdateRequest.extraCredentials.begin(), - taskUpdateRequest.extraCredentials.end()); - connectorConfig.insert({"user", taskUpdateRequest.session.user}); - connectorConfigs.insert( - {entry.first, connectorConfig}); +} // namespace + +std::shared_ptr QueryContextCache::get( + const protocol::QueryId& queryId) { + auto iter = queryCtxs_.find(queryId); + if (iter == queryCtxs_.end()) { + return nullptr; } - return connectorConfigs; + queryIds_.erase(iter->second.idListIterator); + + if (auto queryCtx = iter->second.queryCtx.lock()) { + // Move the queryId to front, if queryCtx is still alive. + queryIds_.push_front(queryId); + iter->second.idListIterator = queryIds_.begin(); + return queryCtx; + } + queryCtxs_.erase(iter); + return nullptr; } -void updateVeloxConfigs( - std::unordered_map& configStrings) { - // If `legacy_timestamp` is true, the coordinator expects timestamp - // conversions without a timezone to be converted to the user's - // session_timezone. - auto it = configStrings.find("legacy_timestamp"); - // `legacy_timestamp` default value is true in the coordinator. - if ((it == configStrings.end()) || (folly::to(it->second))) { - configStrings.emplace( - core::QueryConfig::kAdjustTimestampToTimezone, "true"); +std::shared_ptr QueryContextCache::insert( + const protocol::QueryId& queryId, + std::shared_ptr queryCtx) { + if (queryCtxs_.size() >= capacity_) { + evict(); } - // TODO: remove this once cpu driver slicing config is turned on by default in - // Velox. - it = configStrings.find(core::QueryConfig::kDriverCpuTimeSliceLimitMs); - if (it == configStrings.end()) { - // Set it to 1 second to be aligned with Presto Java. 
- configStrings.emplace( - core::QueryConfig::kDriverCpuTimeSliceLimitMs, "1000"); + queryIds_.push_front(queryId); + queryCtxs_[queryId] = { + folly::to_weak_ptr(queryCtx), queryIds_.begin(), false}; + return queryCtx; +} + +bool QueryContextCache::hasStartedTasks( + const protocol::QueryId& queryId) const { + auto iter = queryCtxs_.find(queryId); + if (iter != queryCtxs_.end()) { + return iter->second.hasStartedTasks; } + return false; } -} // namespace +void QueryContextCache::setTasksStarted(const protocol::QueryId& queryId) { + auto iter = queryCtxs_.find(queryId); + if (iter != queryCtxs_.end()) { + iter->second.hasStartedTasks = true; + } +} + +void QueryContextCache::evict() { + // Evict least recently used queryCtx if it is not referenced elsewhere. + for (auto victim = queryIds_.end(); victim != queryIds_.begin();) { + --victim; + if (!queryCtxs_[*victim].queryCtx.lock()) { + queryCtxs_.erase(*victim); + queryIds_.erase(victim); + return; + } + } + + // All queries are still inflight. Increase capacity. 
+ capacity_ = std::max(kInitialCapacity, capacity_ * 2); +} QueryContextManager::QueryContextManager( folly::Executor* driverExecutor, folly::Executor* spillerExecutor) - : driverExecutor_(driverExecutor), - spillerExecutor_(spillerExecutor), - sessionProperties_(SessionProperties()) {} + : driverExecutor_(driverExecutor), spillerExecutor_(spillerExecutor) {} std::shared_ptr QueryContextManager::findOrCreateQueryCtx( const protocol::TaskId& taskId, const protocol::TaskUpdateRequest& taskUpdateRequest) { - return findOrCreateQueryCtx( + std::lock_guard lock(queryContextCacheMutex_); + return findOrCreateQueryCtxLocked( taskId, - toVeloxConfigs(taskUpdateRequest.session), + toVeloxConfigs( + taskUpdateRequest.session, taskUpdateRequest.extraCredentials), toConnectorConfigs(taskUpdateRequest)); } -std::shared_ptr QueryContextManager::findOrCreateQueryCtx( - const TaskId& taskId, - std::unordered_map&& configStrings, - std::unordered_map< - std::string, - std::unordered_map>&& - connectorConfigStrings) { - QueryId queryId = taskId.substr(0, taskId.find('.')); - - auto lockedCache = queryContextCache_.wlock(); - if (auto queryCtx = lockedCache->get(queryId)) { - return queryCtx; +std::shared_ptr +QueryContextManager::findOrCreateBatchQueryCtx( + const protocol::TaskId& taskId, + const protocol::TaskUpdateRequest& taskUpdateRequest) { + std::lock_guard lock(queryContextCacheMutex_); + auto queryCtx = findOrCreateQueryCtxLocked( + taskId, + toVeloxConfigs( + taskUpdateRequest.session, taskUpdateRequest.extraCredentials), + toConnectorConfigs(taskUpdateRequest)); + if (queryCtx->pool()->aborted()) { + // In Batch mode, only one query is running at a time. When tasks fail + // during memory arbitration, the query memory pool will be set + // aborted, failing any successive tasks immediately. Yet one task + // should not fail other newly admitted tasks because of task retries + // and server reuse. Failure control among tasks should be + // independent. 
So if query memory pool is aborted already, a cache clear is + // performed to allow successive tasks to create a new query context to + // continue execution. + VELOX_CHECK_EQ(queryContextCache_.size(), 1); + queryContextCache_.clear(); + queryCtx = findOrCreateQueryCtxLocked( + taskId, + toVeloxConfigs( + taskUpdateRequest.session, taskUpdateRequest.extraCredentials), + toConnectorConfigs(taskUpdateRequest)); } + return queryCtx; +} + +bool QueryContextManager::queryHasStartedTasks( + const protocol::TaskId& taskId) const { + std::lock_guard lock(queryContextCacheMutex_); + return queryContextCache_.hasStartedTasks(queryIdFromTaskId(taskId)); +} - updateVeloxConfigs(configStrings); +void QueryContextManager::setQueryHasStartedTasks( + const protocol::TaskId& taskId) { + std::lock_guard lock(queryContextCacheMutex_); + queryContextCache_.setTasksStarted(queryIdFromTaskId(taskId)); +} - std::unordered_map> - connectorConfigs; - for (auto& entry : connectorConfigStrings) { - connectorConfigs.insert( - {entry.first, - std::make_shared(std::move(entry.second))}); +std::shared_ptr +QueryContextManager::createAndCacheQueryCtxLocked( + const QueryId& queryId, + velox::core::QueryConfig&& queryConfig, + std::unordered_map>&& + connectorConfigs, + std::shared_ptr&& pool) { + auto queryCtx = core::QueryCtx::create( + driverExecutor_, + std::move(queryConfig), + std::move(connectorConfigs), + cache::AsyncDataCache::getInstance(), + std::move(pool), + spillerExecutor_, + queryId); + return queryContextCache_.insert(queryId, std::move(queryCtx)); +} + +std::shared_ptr QueryContextManager::findOrCreateQueryCtxLocked( + const TaskId& taskId, + velox::core::QueryConfig&& queryConfig, + std::unordered_map>&& + connectorConfigs) { + const QueryId queryId{queryIdFromTaskId(taskId)}; + + if (auto queryCtx = queryContextCache_.get(queryId)) { + return queryCtx; } - velox::core::QueryConfig queryConfig{std::move(configStrings)}; // NOTE: the monotonically increasing 'poolId' is 
appended to 'queryId' to // ensure that the name of root memory pool instance is always unique. In some // edge case, we found some background activities such as the long-running @@ -165,10 +192,12 @@ std::shared_ptr QueryContextManager::findOrCreateQueryCtx( // is still indexed by the query id. static std::atomic_uint64_t poolId{0}; std::optional poolDbgOpts; - const auto debugMemoryPoolNameRegex = queryConfig.debugMemoryPoolNameRegex(); + auto debugMemoryPoolNameRegex = queryConfig.debugMemoryPoolNameRegex(); if (!debugMemoryPoolNameRegex.empty()) { poolDbgOpts = memory::MemoryPool::DebugOptions{ - .debugPoolNameRegex = debugMemoryPoolNameRegex}; + .debugPoolNameRegex = std::move(debugMemoryPoolNameRegex), + .debugPoolWarnThresholdBytes = + queryConfig.debugMemoryPoolWarnThresholdBytes()}; } auto pool = memory::MemoryManager::getInstance()->addRootPool( fmt::format("{}_{}", queryId, poolId++), @@ -176,87 +205,33 @@ std::shared_ptr QueryContextManager::findOrCreateQueryCtx( nullptr, poolDbgOpts); - auto queryCtx = core::QueryCtx::create( - driverExecutor_, + return createAndCacheQueryCtxLocked( + queryId, std::move(queryConfig), - connectorConfigs, - cache::AsyncDataCache::getInstance(), - std::move(pool), - spillerExecutor_, - queryId); - - return lockedCache->insert(queryId, queryCtx); + std::move(connectorConfigs), + std::move(pool)); } void QueryContextManager::visitAllContexts( - std::function - visitor) const { - auto lockedCache = queryContextCache_.rlock(); - for (const auto& it : lockedCache->ctxs()) { - if (const auto queryCtxSP = it.second.first.lock()) { + const std::function< + void(const protocol::QueryId&, const velox::core::QueryCtx*)>& visitor) + const { + std::lock_guard lock(queryContextCacheMutex_); + for (const auto& it : queryContextCache_.ctxMap()) { + if (const auto queryCtxSP = it.second.queryCtx.lock()) { visitor(it.first, queryCtxSP.get()); } } } -void QueryContextManager::testingClearCache() { - 
queryContextCache_.wlock()->testingClear(); +void QueryContextManager::clearCache() { + std::lock_guard lock(queryContextCacheMutex_); + queryContextCache_.clear(); } -void QueryContextCache::testingClear() { +void QueryContextCache::clear() { queryCtxs_.clear(); queryIds_.clear(); } -std::unordered_map -QueryContextManager::toVeloxConfigs( - const protocol::SessionRepresentation& session) { - // Use base velox query config as the starting point and add Presto session - // properties on top of it. - auto configs = BaseVeloxQueryConfig::instance()->values(); - std::optional traceFragmentId; - std::optional traceShardId; - for (const auto& it : session.systemProperties) { - if (it.first == SessionProperties::kQueryTraceFragmentId) { - traceFragmentId = it.second; - } else if (it.first == SessionProperties::kQueryTraceShardId) { - traceShardId = it.second; - } else if (it.first == SessionProperties::kShuffleCompressionCodec) { - auto compression = it.second; - std::transform( - compression.begin(), - compression.end(), - compression.begin(), - ::tolower); - velox::common::CompressionKind compressionKind = - common::stringToCompressionKind(compression); - configs[core::QueryConfig::kShuffleCompressionKind] = - velox::common::compressionKindToString(compressionKind); - } else { - configs[sessionProperties_.toVeloxConfig(it.first)] = it.second; - sessionProperties_.updateVeloxConfig(it.first, it.second); - } - } - - // If there's a timeZoneKey, convert to timezone name and add to the - // configs. Throws if timeZoneKey can't be resolved. - if (session.timeZoneKey != 0) { - configs.emplace( - velox::core::QueryConfig::kSessionTimezone, - velox::tz::getTimeZoneName(session.timeZoneKey)); - } - - // Construct query tracing regex and pass to Velox config. - // It replaces the given native_query_trace_task_reg_exp if also set. - if (traceFragmentId.has_value() || traceShardId.has_value()) { - configs.emplace( - velox::core::QueryConfig::kQueryTraceTaskRegExp, - ".*\\." 
+ traceFragmentId.value_or(".*") + "\\..*\\." + - traceShardId.value_or(".*") + "\\..*"); - } - - updateFromSystemConfigs(configs); - return configs; -} - } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.h b/presto-native-execution/presto_cpp/main/QueryContextManager.h index f8b1a1836ce55..3a07b73c36335 100644 --- a/presto-native-execution/presto_cpp/main/QueryContextManager.h +++ b/presto-native-execution/presto_cpp/main/QueryContextManager.h @@ -17,9 +17,9 @@ #include #include #include +#include #include -#include "presto_cpp/main/SessionProperties.h" #include "presto_cpp/presto_protocol/core/presto_protocol_core.h" #include "velox/core/QueryCtx.h" @@ -28,7 +28,11 @@ class QueryContextCache { public: using QueryCtxWeakPtr = std::weak_ptr; using QueryIdList = std::list; - using QueryCtxCacheValue = std::pair; + struct QueryCtxCacheValue { + QueryCtxWeakPtr queryCtx; + QueryIdList::iterator idListIterator; + bool hasStartedTasks{false}; + }; using QueryCtxMap = std::unordered_map; QueryContextCache(size_t initial_capacity = kInitialCapacity) @@ -41,62 +45,31 @@ class QueryContextCache { return queryCtxs_.size(); } - std::shared_ptr get(protocol::QueryId queryId) { - auto iter = queryCtxs_.find(queryId); - if (iter != queryCtxs_.end()) { - queryIds_.erase(iter->second.second); - - if (auto queryCtx = iter->second.first.lock()) { - // Move the queryId to front, if queryCtx is still alive. 
- queryIds_.push_front(queryId); - iter->second.second = queryIds_.begin(); - return queryCtx; - } else { - queryCtxs_.erase(iter); - } - } - return nullptr; + const QueryCtxMap& ctxMap() const { + return queryCtxs_; } + std::shared_ptr get(const protocol::QueryId& queryId); + std::shared_ptr insert( - protocol::QueryId queryId, - std::shared_ptr queryCtx) { - if (queryCtxs_.size() >= capacity_) { - evict(); - } - queryIds_.push_front(queryId); - queryCtxs_[queryId] = - std::make_pair(folly::to_weak_ptr(queryCtx), queryIds_.begin()); - return queryCtx; - } + const protocol::QueryId& queryId, + std::shared_ptr queryCtx); - void evict() { - // Evict least recently used queryCtx if it is not referenced elsewhere. - for (auto victim = queryIds_.end(); victim != queryIds_.begin();) { - --victim; - if (!queryCtxs_[*victim].first.lock()) { - queryCtxs_.erase(*victim); - queryIds_.erase(victim); - return; - } - } - - // All queries are still inflight. Increase capacity. - capacity_ = std::max(kInitialCapacity, capacity_ * 2); - } - const QueryCtxMap& ctxs() const { - return queryCtxs_; - } + bool hasStartedTasks(const protocol::QueryId& queryId) const; + + void setTasksStarted(const protocol::QueryId& queryId); + + void evict(); - void testingClear(); + void clear(); private: + static constexpr size_t kInitialCapacity = 256UL; + size_t capacity_; QueryCtxMap queryCtxs_; QueryIdList queryIds_; - - static constexpr size_t kInitialCapacity = 256UL; }; class QueryContextManager { @@ -105,39 +78,53 @@ class QueryContextManager { folly::Executor* driverExecutor, folly::Executor* spillerExecutor); + virtual ~QueryContextManager() = default; + std::shared_ptr findOrCreateQueryCtx( const protocol::TaskId& taskId, const protocol::TaskUpdateRequest& taskUpdateRequest); + std::shared_ptr findOrCreateBatchQueryCtx( + const protocol::TaskId& taskId, + const protocol::TaskUpdateRequest& taskUpdateRequest); + + /// Returns true if the given task's query has at least one task started. 
+ bool queryHasStartedTasks(const protocol::TaskId& taskId) const; + + /// Sets flag indicating that task's query has at least one task started. + void setQueryHasStartedTasks(const protocol::TaskId& taskId); + /// Calls the given functor for every present query context. - void visitAllContexts(std::function visitor) const; + void visitAllContexts( + const std::function< + void(const protocol::QueryId&, const velox::core::QueryCtx*)>& + visitor) const; /// Test method to clear the query context cache. - void testingClearCache(); + void clearCache(); - const SessionProperties& getSessionProperties() const { - return sessionProperties_; - } + protected: + folly::Executor* const driverExecutor_{nullptr}; + folly::Executor* const spillerExecutor_{nullptr}; + QueryContextCache queryContextCache_; private: - std::shared_ptr findOrCreateQueryCtx( - const protocol::TaskId& taskId, - std::unordered_map&& configStrings, + virtual std::shared_ptr createAndCacheQueryCtxLocked( + const protocol::QueryId& queryId, + velox::core::QueryConfig&& queryConfig, std::unordered_map< std::string, - std::unordered_map>&& - connectorConfigStrings); - - std::unordered_map toVeloxConfigs( - const protocol::SessionRepresentation& session); + std::shared_ptr>&& connectorConfigs, + std::shared_ptr&& pool); - folly::Executor* const driverExecutor_{nullptr}; - folly::Executor* const spillerExecutor_{nullptr}; + std::shared_ptr findOrCreateQueryCtxLocked( + const protocol::TaskId& taskId, + velox::core::QueryConfig&& queryConfig, + std::unordered_map< + std::string, + std::shared_ptr>&& connectorConfigStrings); - folly::Synchronized queryContextCache_; - SessionProperties sessionProperties_; + mutable std::mutex queryContextCacheMutex_; }; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/RemoteFunctionRegisterer.cpp b/presto-native-execution/presto_cpp/main/RemoteFunctionRegisterer.cpp index c97ba39d58f18..a32ae0347c577 100644 --- 
a/presto-native-execution/presto_cpp/main/RemoteFunctionRegisterer.cpp +++ b/presto-native-execution/presto_cpp/main/RemoteFunctionRegisterer.cpp @@ -58,7 +58,7 @@ size_t processFile( std::stringstream buffer; buffer << stream.rdbuf(); - velox::functions::RemoteVectorFunctionMetadata metadata; + velox::functions::RemoteThriftVectorFunctionMetadata metadata; metadata.location = location; metadata.serdeFormat = fromSerdeString(serde); diff --git a/presto-native-execution/presto_cpp/main/SessionProperties.cpp b/presto-native-execution/presto_cpp/main/SessionProperties.cpp index 0e50644ef42b0..39781c0e778c3 100644 --- a/presto-native-execution/presto_cpp/main/SessionProperties.cpp +++ b/presto-native-execution/presto_cpp/main/SessionProperties.cpp @@ -24,14 +24,10 @@ const std::string boolToString(bool value) { } } // namespace -json SessionProperty::serialize() { - json j; - j["name"] = name_; - j["description"] = description_; - j["typeSignature"] = type_; - j["defaultValue"] = defaultValue_; - j["hidden"] = hidden_; - return j; +SessionProperties* SessionProperties::instance() { + static std::unique_ptr instance = + std::make_unique(); + return instance.get(); } void SessionProperties::addSessionProperty( @@ -39,15 +35,10 @@ void SessionProperties::addSessionProperty( const std::string& description, const TypePtr& type, bool isHidden, - const std::string& veloxConfigName, - const std::string& veloxDefault) { + const std::optional veloxConfig, + const std::string& defaultValue) { sessionProperties_[name] = std::make_shared( - name, - description, - type->toString(), - isHidden, - veloxConfigName, - veloxDefault); + name, description, type->toString(), isHidden, veloxConfig, defaultValue); } // List of native session properties is kept as the source of truth here. 
@@ -274,6 +265,19 @@ SessionProperties::SessionProperties() { QueryConfig::kDebugMemoryPoolNameRegex, c.debugMemoryPoolNameRegex()); + addSessionProperty( + kDebugMemoryPoolWarnThresholdBytes, + "Warning threshold in bytes for debug memory pools. When set to a " + "non-zero value, a warning will be logged once per memory pool when " + "allocations cause the pool to exceed this threshold. This is useful for " + "identifying memory usage patterns during debugging. Requires allocation " + "tracking to be enabled with `native_debug_memory_pool_name_regex` " + "for the pool. A value of 0 means no warning threshold is enforced.", + BIGINT(), + false, + QueryConfig::kDebugMemoryPoolWarnThresholdBytes, + std::to_string(c.debugMemoryPoolWarnThresholdBytes())); + addSessionProperty( kSelectiveNimbleReaderEnabled, "Temporary flag to control whether selective Nimble reader should be " @@ -301,13 +305,12 @@ SessionProperties::SessionProperties() { c.queryTraceDir()); addSessionProperty( - kQueryTraceNodeIds, - "A comma-separated list of plan node ids whose input data will be traced." - " Empty string if only want to trace the query metadata.", + kQueryTraceNodeId, + "The plan node id whose input data will be traced.", VARCHAR(), false, - QueryConfig::kQueryTraceNodeIds, - c.queryTraceNodeIds()); + QueryConfig::kQueryTraceNodeId, + c.queryTraceNodeId()); addSessionProperty( kQueryTraceMaxBytes, @@ -317,7 +320,6 @@ SessionProperties::SessionProperties() { QueryConfig::kQueryTraceMaxBytes, std::to_string(c.queryTraceMaxBytes())); - addSessionProperty( kOpTraceDirectoryCreateConfig, "Config used to create operator trace directory. This config is provided to" @@ -343,7 +345,7 @@ SessionProperties::SessionProperties() { "creating tiny SerializedPages. 
For " "PartitionedOutputNode::Kind::kPartitioned, PartitionedOutput operator" "would buffer up to that number of bytes / number of destinations for " - "each destination before producing a SerializedPage.", + "each destination before producing a SerializedPageBase.", BIGINT(), false, QueryConfig::kMaxPartitionedOutputBufferSize, @@ -486,40 +488,155 @@ SessionProperties::SessionProperties() { kRequestDataSizesMaxWaitSec, "Maximum wait time for exchange long poll requests in seconds.", INTEGER(), - 10, + false, QueryConfig::kRequestDataSizesMaxWaitSec, std::to_string(c.requestDataSizesMaxWaitSec())); -} -const std::unordered_map>& -SessionProperties::getSessionProperties() { - return sessionProperties_; -} + addSessionProperty( + kNativeQueryMemoryReclaimerPriority, + "Memory pool reclaimer priority.", + INTEGER(), + false, + QueryConfig::kQueryMemoryReclaimerPriority, + std::to_string(c.queryMemoryReclaimerPriority())); -const std::string SessionProperties::toVeloxConfig(const std::string& name) { - auto it = sessionProperties_.find(name); - return it == sessionProperties_.end() ? name - : it->second->getVeloxConfigName(); + addSessionProperty( + kMaxNumSplitsListenedTo, + "Maximum number of splits to listen to by SplitListener on native workers.", + INTEGER(), + true, + QueryConfig::kMaxNumSplitsListenedTo, + std::to_string(c.maxNumSplitsListenedTo())); + + addSessionProperty( + kMaxSplitPreloadPerDriver, + "Maximum number of splits to preload per driver. Set to 0 to disable preloading.", + INTEGER(), + false, + QueryConfig::kMaxSplitPreloadPerDriver, + std::to_string(c.maxSplitPreloadPerDriver())); + + addSessionProperty( + kIndexLookupJoinMaxPrefetchBatches, + "Specifies the max number of input batches to prefetch to do index" + "lookup ahead. 
If it is zero, then process one input batch at a time.", + INTEGER(), + false, + QueryConfig::kIndexLookupJoinMaxPrefetchBatches, + std::to_string(c.indexLookupJoinMaxPrefetchBatches())); + + addSessionProperty( + kIndexLookupJoinSplitOutput, + "If this is true, then the index join operator might split output for" + "each input batch based on the output batch size control. Otherwise, it tries to" + "produce a single output for each input batch.", + BOOLEAN(), + false, + QueryConfig::kIndexLookupJoinSplitOutput, + std::to_string(c.indexLookupJoinSplitOutput())); + + addSessionProperty( + kUnnestSplitOutput, + "In streaming aggregation, wait until we have enough number of output" + "rows to produce a batch of size specified by this. If set to 0, then" + "Operator::outputBatchRows will be used as the min output batch rows.", + BOOLEAN(), + false, + QueryConfig::kUnnestSplitOutput, + std::to_string(c.unnestSplitOutput())); + + addSessionProperty( + kPreferredOutputBatchBytes, + "Preferred memory budget for operator output batches. Used in tandem with average row size estimates when available.", + BIGINT(), + true, + QueryConfig::kPreferredOutputBatchBytes, + std::to_string(c.preferredOutputBatchBytes())); + + addSessionProperty( + kPreferredOutputBatchRows, + "Preferred row count per operator output batch. 
Used when average row size estimates are unknown.", + INTEGER(), + true, + QueryConfig::kPreferredOutputBatchRows, + std::to_string(c.preferredOutputBatchRows())); + + addSessionProperty( + kMaxOutputBatchRows, + "Upperbound for row count per output batch, used together with preferred_output_batch_bytes and average row size estimates.", + INTEGER(), + true, + QueryConfig::kMaxOutputBatchRows, + std::to_string(c.maxOutputBatchRows())); + + addSessionProperty( + kRowSizeTrackingMode, + "Enable (reader) row size tracker as a fallback to file level row size estimates.", + INTEGER(), + true, + QueryConfig::kRowSizeTrackingMode, + std::to_string(static_cast(c.rowSizeTrackingMode()))); + + addSessionProperty( + kUseVeloxGeospatialJoin, + "If this is true, then the protocol::SpatialJoinNode is converted to a" + "velox::core::SpatialJoinNode. Otherwise, it is converted to a" + "velox::core::NestedLoopJoinNode.", + BOOLEAN(), + false, + std::nullopt, + "true"); + + addSessionProperty( + kAggregationCompactionBytesThreshold, + "Memory threshold in bytes for triggering string compaction during global " + "aggregation. When total string storage exceeds this limit with high unused " + "memory ratio, compaction is triggered to reclaim dead strings. Disabled by " + "default (0). NOTE: Currently only applies to approx_most_frequent aggregate " + "with StringView type during global aggregation. May extend to other aggregates.", + BIGINT(), + false, + QueryConfig::kAggregationCompactionBytesThreshold, + std::to_string(c.aggregationCompactionBytesThreshold())); + + addSessionProperty( + kAggregationCompactionUnusedMemoryRatio, + "Ratio of unused (evicted) bytes to total bytes that triggers compaction. " + "The value is in the range of [0, 1). Default is 0.25. NOTE: Currently only applies " + "to approx_most_frequent aggregate with StringView type during global " + "aggregation. 
May extend to other aggregates.", + DOUBLE(), + false, + QueryConfig::kAggregationCompactionUnusedMemoryRatio, + std::to_string(c.aggregationCompactionUnusedMemoryRatio())); } -void SessionProperties::updateVeloxConfig( - const std::string& name, - const std::string& value) { +const std::string SessionProperties::toVeloxConfig( + const std::string& name) const { auto it = sessionProperties_.find(name); - // Velox config value is updated only for presto session properties. - if (it == sessionProperties_.end()) { - return; + if (it != sessionProperties_.end() && + it->second->getVeloxConfig().has_value()) { + return it->second->getVeloxConfig().value(); } - it->second->updateValue(value); + return name; } -json SessionProperties::serialize() { +json SessionProperties::serialize() const { json j = json::array(); - const auto sessionProperties = getSessionProperties(); - for (const auto& entry : sessionProperties) { - j.push_back(entry.second->serialize()); + json tj; + for (const auto& sessionProperty : sessionProperties_) { + protocol::to_json(tj, sessionProperty.second->getMetadata()); + j.push_back(tj); } return j; } +bool SessionProperties::useVeloxGeospatialJoin() const { + auto it = sessionProperties_.find(kUseVeloxGeospatialJoin); + if (it != sessionProperties_.end()) { + return it->second->getValue() == "true"; + } + VELOX_UNREACHABLE(); +} + } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/SessionProperties.h b/presto-native-execution/presto_cpp/main/SessionProperties.h index c1e9dd0bf93aa..ebe416e1e595c 100644 --- a/presto-native-execution/presto_cpp/main/SessionProperties.h +++ b/presto-native-execution/presto_cpp/main/SessionProperties.h @@ -14,6 +14,7 @@ #pragma once #include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" #include "velox/type/Type.h" using json = nlohmann::json; @@ -27,20 +28,24 @@ class SessionProperty { SessionProperty( const 
std::string& name, const std::string& description, - const std::string& type, + const std::string& typeSignature, bool hidden, - const std::string& veloxConfigName, + const std::optional veloxConfig, const std::string& defaultValue) - : name_(name), - description_(description), - type_(type), - hidden_(hidden), - veloxConfigName_(veloxConfigName), - defaultValue_(defaultValue), + : metadata_({name, description, typeSignature, defaultValue, hidden}), + veloxConfig_(veloxConfig), value_(defaultValue) {} - const std::string getVeloxConfigName() { - return veloxConfigName_; + const protocol::SessionPropertyMetadata getMetadata() { + return metadata_; + } + + const std::optional getVeloxConfig() { + return veloxConfig_; + } + + const std::string getValue() { + return value_; } void updateValue(const std::string& value) { @@ -48,23 +53,18 @@ class SessionProperty { } bool operator==(const SessionProperty& other) const { - return name_ == other.name_ && description_ == other.description_ && - type_ == other.type_ && hidden_ == other.hidden_ && - veloxConfigName_ == other.veloxConfigName_ && - defaultValue_ == other.defaultValue_; + const auto otherMetadata = other.metadata_; + return metadata_.name == otherMetadata.name && + metadata_.description == otherMetadata.description && + metadata_.typeSignature == otherMetadata.typeSignature && + metadata_.hidden == otherMetadata.hidden && + metadata_.defaultValue == otherMetadata.defaultValue && + veloxConfig_ == other.veloxConfig_; } - json serialize(); - private: - const std::string name_; - const std::string description_; - - // Datatype of presto native property. 
- const std::string type_; - const bool hidden_; - const std::string veloxConfigName_; - const std::string defaultValue_; + const protocol::SessionPropertyMetadata metadata_; + const std::optional veloxConfig_; std::string value_; }; @@ -184,6 +184,12 @@ class SessionProperties { static constexpr const char* kDebugMemoryPoolNameRegex = "native_debug_memory_pool_name_regex"; + /// Warning threshold in bytes for memory pool allocations. Logs callsites + /// when exceeded. Requires allocation tracking to be enabled with + /// `native_debug_memory_pool_name_regex` property for the pool. + static constexpr const char* kDebugMemoryPoolWarnThresholdBytes = + "native_debug_memory_pool_warn_threshold_bytes"; + /// Temporary flag to control whether selective Nimble reader should be used /// in this query or not. Will be removed after the selective Nimble reader /// is fully rolled out. @@ -230,10 +236,8 @@ class SessionProperties { /// Base dir of a query to store tracing data. static constexpr const char* kQueryTraceDir = "native_query_trace_dir"; - /// A comma-separated list of plan node ids whose input data will be traced. - /// Empty string if only want to trace the query metadata. - static constexpr const char* kQueryTraceNodeIds = - "native_query_trace_node_ids"; + /// The plan node id whose input data will be traced. + static constexpr const char* kQueryTraceNodeId = "native_query_trace_node_id"; /// The max trace bytes limit. Tracing is disabled if zero. static constexpr const char* kQueryTraceMaxBytes = @@ -264,7 +268,7 @@ class SessionProperties { /// creating tiny SerializedPages. For /// PartitionedOutputNode::Kind::kPartitioned, PartitionedOutput operator /// would buffer up to that number of bytes / number of destinations for each - /// destination before producing a SerializedPage. + /// destination before producing a SerializedPageBase. 
static constexpr const char* kMaxPartitionedOutputBufferSize = "native_max_page_partitioning_buffer_size"; @@ -310,30 +314,126 @@ class SessionProperties { "native_streaming_aggregation_min_output_batch_rows"; /// Maximum wait time for exchange long poll requests in seconds. - static constexpr const char* kRequestDataSizesMaxWaitSec = + static constexpr const char* kRequestDataSizesMaxWaitSec = "native_request_data_sizes_max_wait_sec"; - SessionProperties(); + /// Priority of memory pool reclaimer when deciding on memory pool to abort. + /// Lower value has higher priority and less likely to be chosen as candidate + /// for memory pool abort. + static constexpr const char* kNativeQueryMemoryReclaimerPriority = + "native_query_memory_reclaimer_priority"; + + /// Maximum number of splits to listen to by SplitListener on native workers. + static constexpr const char* kMaxNumSplitsListenedTo = + "native_max_num_splits_listened_to"; + + /// Maximum number of splits to preload per driver. Set to 0 to disable + /// preloading. + static constexpr const char* kMaxSplitPreloadPerDriver = + "native_max_split_preload_per_driver"; + + /// Specifies the max number of input batches to prefetch to do index lookup + /// ahead. If it is zero, then process one input batch at a time. + static constexpr const char* kIndexLookupJoinMaxPrefetchBatches = + "native_index_lookup_join_max_prefetch_batches"; + + /// If this is true, then the index join operator might split output for each + /// input batch based on the output batch size control. Otherwise, it tries to + /// produce a single output for each input batch. + static constexpr const char* kIndexLookupJoinSplitOutput = + "native_index_lookup_join_split_output"; + + /// If this is true, then the unnest operator might split output for each + /// input batch based on the output batch size control. Otherwise, it produces + /// a single output for each input batch. 
+ static constexpr const char* kUnnestSplitOutput = + "native_unnest_split_output"; + + /// Preferred size of batches in bytes to be returned by operators from + /// Operator::getOutput. It is used when an estimate of average row size is + /// known. Otherwise kPreferredOutputBatchRows is used. + static constexpr const char* kPreferredOutputBatchBytes = + "preferred_output_batch_bytes"; + + /// Preferred number of rows to be returned by operators from + /// Operator::getOutput. It is used when an estimate of average row size is + /// not known. When the estimate of average row size is known, + /// kPreferredOutputBatchBytes is used. + static constexpr const char* kPreferredOutputBatchRows = + "preferred_output_batch_rows"; + + /// Max number of rows that could be return by operators from + /// Operator::getOutput. It is used when an estimate of average row size is + /// known and kPreferredOutputBatchBytes is used to compute the number of + /// output rows. + static constexpr const char* kMaxOutputBatchRows = "max_output_batch_rows"; + + /// Enable (reader) row size tracker as a fallback to file level row size + /// estimates. + static constexpr const char* kRowSizeTrackingMode = "row_size_tracking_mode"; + + /// If this is true, then the protocol::SpatialJoinNode is converted to a + /// velox::core::SpatialJoinNode. Otherwise, it is converted to a + /// velox::core::NestedLoopJoinNode. + static constexpr const char* kUseVeloxGeospatialJoin = + "native_use_velox_geospatial_join"; + + /// Memory threshold in bytes for triggering string compaction during global + /// aggregation. When total string storage exceeds this limit with high unused + /// memory ratio, compaction is triggered to reclaim dead strings. Disabled by + /// default (0). + /// + /// NOTE: Currently only applies to approx_most_frequent aggregate with + /// StringView type during global aggregation. May extend to other aggregates. 
+ static constexpr const char* kAggregationCompactionBytesThreshold = + "native_aggregation_compaction_bytes_threshold"; + + /// Ratio of unused (evicted) bytes to total bytes that triggers compaction. + /// The value is in the range of [0, 1). Default is 0.25. + /// + /// NOTE: Currently only applies to approx_most_frequent aggregate with + /// StringView type during global aggregation. May extend to other aggregates. + static constexpr const char* kAggregationCompactionUnusedMemoryRatio = + "native_aggregation_compaction_unused_memory_ratio"; + + inline bool hasVeloxConfig(const std::string& key) { + auto sessionProperty = sessionProperties_.find(key); + if (sessionProperty == sessionProperties_.end()) { + // In this case a queryConfig is being created so we should return + // true since it will also have a veloxConfig. + return true; + } + return sessionProperty->second->getVeloxConfig().has_value(); + } + + inline void updateSessionPropertyValue( + const std::string& key, + const std::string& value) { + auto sessionProperty = sessionProperties_.find(key); + VELOX_CHECK(sessionProperty != sessionProperties_.end()); + sessionProperty->second->updateValue(value); + } - const std::unordered_map>& - getSessionProperties(); + static SessionProperties* instance(); + + SessionProperties(); /// Utility function to translate a config name in Presto to its equivalent in /// Velox. Returns 'name' as is if there is no mapping. 
- const std::string toVeloxConfig(const std::string& name); + const std::string toVeloxConfig(const std::string& name) const; - void updateVeloxConfig(const std::string& name, const std::string& value); + json serialize() const; - json serialize(); + bool useVeloxGeospatialJoin() const; - protected: + private: void addSessionProperty( const std::string& name, const std::string& description, const velox::TypePtr& type, bool isHidden, - const std::string& veloxConfigName, - const std::string& veloxDefault); + const std::optional veloxConfig, + const std::string& defaultValue); std::unordered_map> sessionProperties_; diff --git a/presto-native-execution/presto_cpp/main/TaskManager.cpp b/presto-native-execution/presto_cpp/main/TaskManager.cpp index 5af124f43a5bd..a25a85b6e954f 100644 --- a/presto-native-execution/presto_cpp/main/TaskManager.cpp +++ b/presto-native-execution/presto_cpp/main/TaskManager.cpp @@ -55,53 +55,53 @@ void cancelAbandonedTasksInternal(const TaskMap& taskMap, int32_t abandonedMs) { } } -// If spilling is enabled and the given Task can spill, then this helper -// generates the spilling directory path for the Task, and sets the path to it -// in the Task. -static void maybeSetupTaskSpillDirectory( +// If spilling is enabled and the task plan fragment can spill, then this helper +// generates the disk spilling options for the task. 
+std::optional getTaskSpillOptions( + const TaskId& taskId, const core::PlanFragment& planFragment, - exec::Task& execTask, - const std::string& baseSpillDirectory) { - if (baseSpillDirectory.empty() || - !planFragment.canSpill(execTask.queryCtx()->queryConfig())) { - return; + const std::shared_ptr& queryCtx, + const std::string& baseSpillDir) { + if (baseSpillDir.empty() || !planFragment.canSpill(queryCtx->queryConfig())) { + return std::nullopt; } - const auto includeNodeInSpillPath = + common::SpillDiskOptions spillDiskOpts; + const bool includeNodeInSpillPath = SystemConfig::instance()->includeNodeInSpillPath(); auto nodeConfig = NodeConfig::instance(); const auto [taskSpillDirPath, dateSpillDirPath] = TaskManager::buildTaskSpillDirectoryPath( - baseSpillDirectory, + baseSpillDir, nodeConfig->nodeInternalAddress(), nodeConfig->nodeId(), - execTask.queryCtx()->queryId(), - execTask.taskId(), + queryCtx->queryId(), + taskId, includeNodeInSpillPath); - execTask.setSpillDirectory(taskSpillDirPath, /*alreadyCreated=*/false); - - execTask.setCreateSpillDirectoryCb( - [spillDir = taskSpillDirPath, dateStrDir = dateSpillDirPath]() { - auto fs = filesystems::getFileSystem(dateStrDir, nullptr); - // First create the top level directory (date string of the query) with - // TTL or other configs if set. - filesystems::DirectoryOptions options; - // Do not fail if the directory already exist because another process - // may have already created the dateStrDir. - options.failIfExists = false; - auto config = SystemConfig::instance()->spillerDirectoryCreateConfig(); - if (!config.empty()) { - options.values.emplace( - filesystems::DirectoryOptions::kMakeDirectoryConfig.toString(), - config); - } - fs->mkdir(dateStrDir, options); - - // After the parent directory is created, - // then create the spill directory for the actual task. 
- fs->mkdir(spillDir); - return spillDir; - }); + spillDiskOpts.spillDirPath = taskSpillDirPath; + spillDiskOpts.spillDirCreated = false; + spillDiskOpts.spillDirCreateCb = [spillDir = taskSpillDirPath, + dateDir = dateSpillDirPath]() { + auto fs = filesystems::getFileSystem(dateDir, nullptr); + // First create the top level directory (date string of the query) with + // TTL or other configs if set. + filesystems::DirectoryOptions options; + // Do not fail if the directory already exist because another process + // may have already created the dateStrDir. + options.failIfExists = false; + auto config = SystemConfig::instance()->spillerDirectoryCreateConfig(); + if (!config.empty()) { + options.values.emplace( + filesystems::DirectoryOptions::kMakeDirectoryConfig.toString(), + config); + } + fs->mkdir(dateDir, options); + // After the parent directory is created, + // then create the spill directory for the actual task. + fs->mkdir(spillDir); + return spillDir; + }; + return spillDiskOpts; } // Keep outstanding Promises in RequestHandler's state itself. @@ -182,8 +182,11 @@ void getData( } } - VLOG(1) << "Task " << taskId << ", buffer " << bufferId << ", sequence " - << sequence << " Results size: " << bytes + int64_t waitTimeMs = getCurrentTimeMs() - startMs; + VLOG(1) << "Task " << taskId << " waited " << waitTimeMs + << "ms for data: " + << "buffer " << bufferId << ", sequence " << sequence + << " Results size: " << bytes << ", page count: " << pages.size() << ", remaining: " << folly::join(',', remainingBytes) << ", complete: " << std::boolalpha << complete; @@ -194,6 +197,7 @@ void getData( result->complete = complete; result->data = std::move(iobuf); result->remainingBytes = std::move(remainingBytes); + result->waitTimeMs = waitTimeMs; promiseHolder->promise.setValue(std::move(result)); @@ -323,17 +327,45 @@ struct ZombieTaskStatsSet { } } }; + +// Add task to the task queue. 
+void enqueueTask( + TaskQueue& taskQueue, + std::shared_ptr& prestoTask) { + auto execTask = prestoTask->task; + if (execTask == nullptr) { + return; + } + + // If an entry exists with tasks for the same query, then add the task to it. + for (auto& entry : taskQueue) { + if (!entry.empty()) { + if (auto queuedTask = entry[0].lock()) { + auto queuedExecTask = queuedTask->task; + if (queuedExecTask && + (queuedExecTask->queryCtx() == execTask->queryCtx())) { + entry.emplace_back(prestoTask); + return; + } + } + } + } + // Otherwise create a new entry. + taskQueue.push_back({prestoTask}); +} } // namespace TaskManager::TaskManager( folly::Executor* driverExecutor, folly::Executor* httpSrvCpuExecutor, folly::Executor* spillerExecutor) - : bufferManager_(velox::exec::OutputBufferManager::getInstanceRef()), - queryContextManager_(std::make_unique( - driverExecutor, - spillerExecutor)), - httpSrvCpuExecutor_(httpSrvCpuExecutor) { + : queryContextManager_( + std::make_unique( + driverExecutor, + spillerExecutor)), + bufferManager_(velox::exec::OutputBufferManager::getInstanceRef()), + httpSrvCpuExecutor_(httpSrvCpuExecutor), + lastNotOverloadedTimeInSecs_(velox::getCurrentTimeSec()) { VELOX_CHECK_NOT_NULL(bufferManager_, "invalid OutputBufferManager"); } @@ -444,6 +476,13 @@ TaskManager::buildTaskSpillDirectoryPath( std::move(taskSpillDirPath), std::move(dateSpillDirPath)); } +void TaskManager::setServerOverloaded(bool serverOverloaded) { + serverOverloaded_ = serverOverloaded; + if (!serverOverloaded) { + lastNotOverloadedTimeInSecs_ = velox::getCurrentTimeSec(); + } +} + void TaskManager::getDataForResultRequests( const std::unordered_map>& resultRequests) { @@ -510,13 +549,17 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( bool summarize, std::shared_ptr queryCtx, long startProcessCpuTime) { + auto receiveTaskUpdateMs = getCurrentTimeMs(); std::shared_ptr execTask; bool startTask = false; auto prestoTask = findOrCreateTask(taskId, startProcessCpuTime); + if 
(prestoTask->firstTimeReceiveTaskUpdateMs == 0) { + prestoTask->firstTimeReceiveTaskUpdateMs = receiveTaskUpdateMs; + } { std::lock_guard l(prestoTask->mutex); prestoTask->updateCoordinatorHeartbeatLocked(); - if (not prestoTask->task && planFragment.planNode) { + if ((prestoTask->task == nullptr) && (planFragment.planNode != nullptr)) { // If the task is aborted, no need to do anything else. // This takes care of DELETE task message coming before CREATE task. if (prestoTask->info.taskStatus.state == protocol::TaskState::ABORTED) { @@ -524,6 +567,10 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( prestoTask->updateInfoLocked(summarize)); } + const auto baseSpillDir = *(baseSpillDir_.rlock()); + auto spillDiskOpts = + getTaskSpillOptions(taskId, planFragment, queryCtx, baseSpillDir); + // Uses a temp variable to store the created velox task to destroy it // under presto task lock if spill directory setup fails. Otherwise, the // concurrent task creation retry from the coordinator might see the @@ -537,16 +584,13 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( std::move(queryCtx), exec::Task::ExecutionMode::kParallel, static_cast(nullptr), - prestoTask->id.stageId()); - // TODO: move spill directory creation inside velox task execution - // whenever spilling is triggered. It will reduce the unnecessary file - // operations on remote storage. 
- const auto baseSpillDir = *(baseSpillDir_.rlock()); - maybeSetupTaskSpillDirectory(planFragment, *newExecTask, baseSpillDir); + prestoTask->id.stageId(), + spillDiskOpts); prestoTask->task = std::move(newExecTask); prestoTask->info.needsPlan = false; startTask = true; + prestoTask->createFinishTimeMs = getCurrentTimeMs(); } execTask = prestoTask->task; } @@ -584,12 +628,38 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( VLOG(1) << "Failed to update output buffers for task: " << taskId; } + folly::F14FastMap sourcesMap; for (const auto& source : sources) { + auto it = sourcesMap.find(source.planNodeId); + if (it == sourcesMap.end()) { + // No existing source with same planNodeId, add as new + sourcesMap.emplace(source.planNodeId, source); + continue; + } + + // Merge with existing source that has the same planNodeId + auto& merged = it->second; + + // Merge splits + merged.splits.insert( + merged.splits.end(), source.splits.begin(), source.splits.end()); + + // Merge noMoreSplitsForLifespan + merged.noMoreSplitsForLifespan.insert( + merged.noMoreSplitsForLifespan.end(), + source.noMoreSplitsForLifespan.begin(), + source.noMoreSplitsForLifespan.end()); + + // Use OR logic for noMoreSplits flag + merged.noMoreSplits = merged.noMoreSplits || source.noMoreSplits; + } + + for (const auto& [_, source] : sourcesMap) { // Add all splits from the source to the task. VLOG(1) << "Adding " << source.splits.size() << " splits to " << taskId << " for node " << source.planNodeId; // Keep track of the max sequence for this batch of splits. 
- long maxSplitSequenceId{-1}; + int64_t maxSplitSequenceId{-1}; for (const auto& protocolSplit : source.splits) { auto split = toVeloxSplit(protocolSplit); if (split.hasConnectorSplit()) { @@ -613,7 +683,13 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( if (source.noMoreSplits) { LOG(INFO) << "No more splits for " << taskId << " for node " << source.planNodeId; - execTask->noMoreSplits(source.planNodeId); + // If the task has not been started yet, we collect the plan node to + // call 'no more splits' after the start. + if (prestoTask->taskStarted) { + execTask->noMoreSplits(source.planNodeId); + } else { + prestoTask->delayedNoMoreSplitsPlanNodes_.emplace(source.planNodeId); + } } } @@ -642,8 +718,10 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( void TaskManager::maybeStartTaskLocked( std::shared_ptr& prestoTask, bool& startNextQueuedTask) { - // Default behavior (no task queuing) is to start the new task immediately. - if (!SystemConfig::instance()->workerOverloadedTaskQueuingEnabled()) { + // Start the new task if the task queuing is disabled. + // Also start it if some tasks from this query have already started. + if (!SystemConfig::instance()->workerOverloadedTaskQueuingEnabled() || + getQueryContextManager()->queryHasStartedTasks(prestoTask->info.taskId)) { startTaskLocked(prestoTask); return; } @@ -652,10 +730,11 @@ void TaskManager::maybeStartTaskLocked( // If server is overloaded, we don't start anything, but queue the new task. LOG(INFO) << "TASK QUEUE: Server is overloaded. Queueing task " << prestoTask->info.taskId; - taskQueue_.wlock()->emplace(prestoTask); + auto lockedTaskQueue = taskQueue_.wlock(); + enqueueTask(*lockedTaskQueue, prestoTask); } else { // If server is not overloaded, then we start the new task if the task queue - // is empty, otherwise we queue the new task and start the next queued task + // is empty, otherwise we queue the new task and start the first queued task // instead. 
{ auto lockedTaskQueue = taskQueue_.wlock(); @@ -663,9 +742,9 @@ void TaskManager::maybeStartTaskLocked( LOG(INFO) << "TASK QUEUE: " "Server is not overloaded, but " << lockedTaskQueue->size() - << "queued tasks detected. Queueing task " + << " queued queries detected. Queueing task " << prestoTask->info.taskId; - lockedTaskQueue->emplace(prestoTask); + enqueueTask(*lockedTaskQueue, prestoTask); startNextQueuedTask = true; } } @@ -681,6 +760,8 @@ void TaskManager::startTaskLocked(std::shared_ptr& prestoTask) { return; } + getQueryContextManager()->setQueryHasStartedTasks(prestoTask->info.taskId); + const uint32_t maxDrivers = execTask->queryCtx()->queryConfig().get( kMaxDriversPerTask.data(), SystemConfig::instance()->maxDriversPerTask()); uint32_t concurrentLifespans = @@ -706,6 +787,8 @@ void TaskManager::startTaskLocked(std::shared_ptr& prestoTask) { // Record the time we spent between task creation and start, which is the // planned (queued) time. + // Note task could be created at getTaskStatus/getTaskInfo endpoint and later + // receive taskUpdate to create and start task. const auto queuedTimeInMs = velox::getCurrentTimeMs() - prestoTask->createTimeMs; prestoTask->info.stats.queuedTimeInNanos = queuedTimeInMs * 1'000'000; @@ -717,48 +800,75 @@ void TaskManager::maybeStartNextQueuedTask() { return; } - std::shared_ptr taskToStart; - size_t numQueuedTasks{0}; + // We will start all queued tasks from a single query. + std::vector> tasksToStart; // We run the loop here because some tasks might have failed or were aborted // or cancelled. Despite that we want to start at least one task. { auto lockedTaskQueue = taskQueue_.wlock(); while (!lockedTaskQueue->empty()) { - taskToStart = lockedTaskQueue->front().lock(); - lockedTaskQueue->pop(); + // Get the next entry. + auto queuedTasks = std::move(lockedTaskQueue->front()); + lockedTaskQueue->pop_front(); + + // Get all the still valid tasks from the entry. 
+ bool queryTasksAreGoodToStart{true}; + for (auto& queuedTask : queuedTasks) { + auto taskToStart = queuedTask.lock(); + + // Task is already gone or no Velox task (the latter will never happen). + if (taskToStart == nullptr || taskToStart->task == nullptr) { + LOG(WARNING) << "TASK QUEUE: Skipping null task in the queue."; + queryTasksAreGoodToStart = false; + break; + } - // Task is already gone or no Velox task (the latter will never happen). - if (taskToStart == nullptr || taskToStart->task == nullptr) { - LOG(WARNING) << "TASK QUEUE: Skipping null task in the queue."; - continue; - } + // Sanity check. + VELOX_CHECK( + !taskToStart->taskStarted, + "TASK QUEUE: " + "The queued task must not be started, but it is already started"); + + const auto taskState = taskToStart->taskState(); + // If the status is not 'planned' then the tasks were likely aborted. + if (taskState != PrestoTaskState::kPlanned) { + LOG(INFO) << "TASK QUEUE: Discarding (not starting) queued task " + << taskToStart->info.taskId << " because state is " + << prestoTaskStateString(taskState); + queryTasksAreGoodToStart = false; + break; + } - // Sanity check. - VELOX_CHECK( - !taskToStart->taskStarted, - "TASK QUEUE: " - "The queued task must not be started yet, but it is already started"); + tasksToStart.emplace_back(taskToStart); + } - const auto taskState = taskToStart->taskState(); - // If the status is 'planned' then we got a task to start, exit the loop. - if (taskState == PrestoTaskState::kPlanned) { + if (queryTasksAreGoodToStart) { break; } - - LOG(INFO) << "TASK QUEUE: Discarding (not starting) queued task " - << taskToStart->info.taskId << " because state is " - << prestoTaskStateString(taskState); + tasksToStart.clear(); } - numQueuedTasks = lockedTaskQueue->size(); } - if (taskToStart) { + for (auto& taskToStart : tasksToStart) { std::lock_guard l(taskToStart->mutex); LOG(INFO) << "TASK QUEUE: Picking task to start from the queue: " - << taskToStart->info.taskId << ". 
" << numQueuedTasks - << " queued tasks left"; + << taskToStart->info.taskId; startTaskLocked(taskToStart); + // Make sure we call 'no more splits' we might have received before the task + // started. + auto execTask = taskToStart->task; + if (execTask != nullptr) { + for (const auto& planNodeId : + taskToStart->delayedNoMoreSplitsPlanNodes_) { + execTask->noMoreSplits(planNodeId); + } + taskToStart->delayedNoMoreSplitsPlanNodes_.clear(); + } + } + const auto queuedTasksLeft = numQueuedTasks(); + if (queuedTasksLeft > 0) { + LOG(INFO) << "TASK QUEUE: " << numQueuedTasks() << " queued tasks left"; } } @@ -931,8 +1041,9 @@ folly::Future> TaskManager::getTaskInfo( auto prestoTask = findOrCreateTask(taskId); if (!currentState || !maxWait) { // Return current TaskInfo without waiting. - promise.setValue(std::make_unique( - prestoTask->updateInfo(summarize))); + promise.setValue( + std::make_unique( + prestoTask->updateInfo(summarize))); prestoTask->updateCoordinatorHeartbeat(); return std::move(future).via(httpSrvCpuExecutor_); } @@ -975,8 +1086,9 @@ folly::Future> TaskManager::getTaskInfo( prestoTask->task->stateChangeFuture(maxWaitMicros) .via(httpSrvCpuExecutor_) .thenValue([promiseHolder, prestoTask, summarize](auto&& /*done*/) { - promiseHolder->promise.setValue(std::make_unique( - prestoTask->updateInfo(summarize))); + promiseHolder->promise.setValue( + std::make_unique( + prestoTask->updateInfo(summarize))); }) .thenError( folly::tag_t{}, @@ -1187,7 +1299,6 @@ std::shared_ptr TaskManager::findOrCreateTask( std::make_shared(taskId, nodeId_, startProcessCpuTime); prestoTask->info.stats.createTimeInMillis = velox::getCurrentTimeMs(); prestoTask->info.needsPlan = true; - prestoTask->info.metadataUpdates.connectorId = "unused"; struct UuidSplit { int64_t lo; @@ -1236,7 +1347,7 @@ std::string TaskManager::toString() const { return out.str(); } -velox::exec::Task::DriverCounts TaskManager::getDriverCounts() const { +velox::exec::Task::DriverCounts 
TaskManager::getDriverCounts() { const auto taskMap = *taskMap_.rlock(); velox::exec::Task::DriverCounts ret; for (const auto& pair : taskMap) { @@ -1251,6 +1362,7 @@ velox::exec::Task::DriverCounts TaskManager::getDriverCounts() const { } } } + numQueuedDrivers_ = ret.numQueuedDrivers; return ret; } @@ -1337,7 +1449,12 @@ std::array TaskManager::getTaskNumbers(size_t& numTasks) const { } size_t TaskManager::numQueuedTasks() const { - return this->taskQueue_.rlock()->size(); + size_t num = 0; + auto lockedTaskQueue = taskQueue_.rlock(); + for (const auto& entry : *lockedTaskQueue) { + num += entry.size(); + } + return num; } int64_t TaskManager::getBytesProcessed() const { diff --git a/presto-native-execution/presto_cpp/main/TaskManager.h b/presto-native-execution/presto_cpp/main/TaskManager.h index c5a4b319759ec..6d0fac7af8396 100644 --- a/presto-native-execution/presto_cpp/main/TaskManager.h +++ b/presto-native-execution/presto_cpp/main/TaskManager.h @@ -14,8 +14,8 @@ #pragma once #include +#include #include -#include #include "presto_cpp/main/PrestoTask.h" #include "presto_cpp/main/QueryContextManager.h" #include "presto_cpp/main/http/HttpServer.h" @@ -24,7 +24,8 @@ namespace facebook::presto { -using TaskQueue = std::queue>; +// One entry can hold multiple queued tasks for the same query. +using TaskQueue = std::deque>>; class TaskManager { public: @@ -33,6 +34,8 @@ class TaskManager { folly::Executor* httpSrvExecutor, folly::Executor* spillerExecutor); + virtual ~TaskManager() = default; + /// Invoked by Presto server shutdown to wait for all the tasks to complete /// and cleanup the completed tasks. void shutdown(); @@ -145,7 +148,7 @@ class TaskManager { int64_t getBytesProcessed() const; /// Stores the number of drivers in various states of execution. 
- velox::exec::Task::DriverCounts getDriverCounts() const; + velox::exec::Task::DriverCounts getDriverCounts(); /// Returns array with number of tasks for each of six PrestoTaskState (enum /// defined in PrestoTask.h). @@ -176,8 +179,20 @@ class TaskManager { /// Presto Server can notify the Task Manager that the former is overloaded, /// so the Task Manager can optionally change Task admission algorithm. - void setServerOverloaded(bool serverOverloaded) { - serverOverloaded_ = serverOverloaded; + void setServerOverloaded(bool serverOverloaded); + + bool isServerOverloaded() const { + return serverOverloaded_; + } + + uint64_t lastNotOverloadedTimeInSecs() const { + return lastNotOverloadedTimeInSecs_; + } + + /// Returns last known number of queued drivers. Used in determining if the + /// server is CPU overloaded. + uint32_t numQueuedDrivers() const { + return numQueuedDrivers_; } /// Contains the logic on starting tasks if not overloaded. @@ -188,6 +203,9 @@ class TaskManager { /// See if we have any queued tasks that can be started. 
void maybeStartNextQueuedTask(); + protected: + std::unique_ptr queryContextManager_; + private: static constexpr folly::StringPiece kMaxDriversPerTask{ "max_drivers_per_task"}; @@ -222,9 +240,10 @@ class TaskManager { std::shared_ptr bufferManager_; folly::Synchronized taskMap_; folly::Synchronized taskQueue_; - std::unique_ptr queryContextManager_; folly::Executor* httpSrvCpuExecutor_; std::atomic_bool serverOverloaded_{false}; + std::atomic_uint64_t lastNotOverloadedTimeInSecs_; + std::atomic_uint32_t numQueuedDrivers_{0}; }; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/TaskResource.cpp b/presto-native-execution/presto_cpp/main/TaskResource.cpp index afee3d1fbba36..736fc489ea062 100644 --- a/presto-native-execution/presto_cpp/main/TaskResource.cpp +++ b/presto-native-execution/presto_cpp/main/TaskResource.cpp @@ -19,6 +19,7 @@ #include "presto_cpp/main/thrift/ThriftIO.h" #include "presto_cpp/main/thrift/gen-cpp2/PrestoThrift.h" #include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" +#include "velox/core/PlanConsistencyChecker.h" namespace facebook::presto { @@ -147,7 +148,9 @@ proxygen::RequestHandler* TaskResource::abortResults( taskManager_.abortResults(taskId, destination); return true; }) - .via(folly::EventBaseManager::get()->getEventBase()) + .via( + folly::getKeepAliveToken( + folly::EventBaseManager::get()->getEventBase())) .thenValue([downstream, handlerState](auto&& /* unused */) { if (!handlerState->requestExpired()) { http::sendOkResponse(downstream); @@ -182,17 +185,14 @@ proxygen::RequestHandler* TaskResource::acknowledgeResults( taskManager_.acknowledgeResults(taskId, bufferId, token); return true; }) - .via(folly::EventBaseManager::get()->getEventBase()) + .via( + folly::getKeepAliveToken( + folly::EventBaseManager::get()->getEventBase())) .thenValue([downstream, handlerState](auto&& /* unused */) { if (!handlerState->requestExpired()) { http::sendOkResponse(downstream); } }) - .thenError( - 
folly::tag_t{}, - [downstream](auto&& e) { - http::sendErrorResponse(downstream, e.what()); - }) .thenError( folly::tag_t{}, [downstream, handlerState](auto&& e) { @@ -215,31 +215,53 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTaskImpl( protocol::TaskId taskId = pathMatch[1]; bool summarize = message->hasQueryParam("summarize"); - auto& headers = message->getHeaders(); - const auto& acceptHeader = headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); + const auto& headers = message->getHeaders(); + const auto& acceptHeader = + headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); const auto sendThrift = acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos; - const auto& contentHeader = headers.getSingleOrEmpty(proxygen::HTTP_HEADER_CONTENT_TYPE); + const auto& contentHeader = + headers.getSingleOrEmpty(proxygen::HTTP_HEADER_CONTENT_TYPE); const auto receiveThrift = contentHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos; + const auto contentEncoding = headers.getSingleOrEmpty("Content-Encoding"); + const auto isCompressed = + !contentEncoding.empty() && contentEncoding != "identity"; return new http::CallbackRequestHandler( - [this, taskId, summarize, createOrUpdateFunc, sendThrift, receiveThrift]( + [this, + taskId, + summarize, + createOrUpdateFunc, + sendThrift, + receiveThrift, + contentEncoding, + isCompressed]( proxygen::HTTPMessage* /*message*/, const std::vector>& body, proxygen::ResponseHandler* downstream, std::shared_ptr handlerState) { folly::via( httpSrvCpuExecutor_, - [this, &body, taskId, summarize, createOrUpdateFunc, receiveThrift]() { + [this, + requestBody = isCompressed + ? 
util::decompressMessageBody(body, contentEncoding) + : util::extractMessageBody(body), + taskId, + summarize, + createOrUpdateFunc, + receiveThrift]() { const auto startProcessCpuTimeNs = util::getProcessCpuTimeNs(); - std::string requestBody = util::extractMessageBody(body); std::unique_ptr taskInfo; try { taskInfo = createOrUpdateFunc( - taskId, requestBody, summarize, startProcessCpuTimeNs, receiveThrift); - } catch (const velox::VeloxException& e) { + taskId, + requestBody, + summarize, + startProcessCpuTimeNs, + receiveThrift); + } catch (const velox::VeloxException&) { // Creating an empty task, putting errors inside so that next // status fetch from coordinator will catch the error and well // categorize it. @@ -249,13 +271,15 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTaskImpl( std::current_exception(), summarize, startProcessCpuTimeNs); - } catch (const velox::VeloxUserError& e) { + } catch (const velox::VeloxUserError&) { throw; } } return taskInfo; }) - .via(folly::EventBaseManager::get()->getEventBase()) + .via( + folly::getKeepAliveToken( + folly::EventBaseManager::get()->getEventBase())) .thenValue([downstream, handlerState, sendThrift](auto taskInfo) { if (!handlerState->requestExpired()) { if (sendThrift) { @@ -268,13 +292,6 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTaskImpl( } } }) - .thenError( - folly::tag_t{}, - [downstream, handlerState](auto&& e) { - if (!handlerState->requestExpired()) { - http::sendErrorResponse(downstream, e.what()); - } - }) .thenError( folly::tag_t{}, [downstream, handlerState](auto&& e) { @@ -316,7 +333,7 @@ proxygen::RequestHandler* TaskResource::createOrUpdateBatchTask( } auto queryCtx = - taskManager_.getQueryContextManager()->findOrCreateQueryCtx( + taskManager_.getQueryContextManager()->findOrCreateBatchQueryCtx( taskId, updateRequest); VeloxBatchQueryPlanConverter converter( @@ -327,6 +344,9 @@ proxygen::RequestHandler* TaskResource::createOrUpdateBatchTask( pool_); auto planFragment = 
converter.toVeloxQueryPlan( prestoPlan, updateRequest.tableWriteInfo, taskId); + if (SystemConfig::instance()->planConsistencyCheckEnabled()) { + velox::core::PlanConsistencyChecker::check(planFragment.planNode); + } return taskManager_.createOrUpdateBatchTask( taskId, @@ -351,7 +371,8 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTask( bool receiveThrift) { protocol::TaskUpdateRequest updateRequest; if (receiveThrift) { - auto thriftTaskUpdateRequest = std::make_shared(); + auto thriftTaskUpdateRequest = + std::make_shared(); thriftRead(requestBody, thriftTaskUpdateRequest); fromThrift(*thriftTaskUpdateRequest, updateRequest); } else { @@ -360,7 +381,10 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTask( velox::core::PlanFragment planFragment; std::shared_ptr queryCtx; if (updateRequest.fragment) { - protocol::PlanFragment prestoPlan = json::parse(receiveThrift ? *updateRequest.fragment : velox::encoding::Base64::decode(*updateRequest.fragment)); + protocol::PlanFragment prestoPlan = json::parse( + receiveThrift + ? 
*updateRequest.fragment + : velox::encoding::Base64::decode(*updateRequest.fragment)); queryCtx = taskManager_.getQueryContextManager()->findOrCreateQueryCtx( @@ -369,6 +393,9 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTask( VeloxInteractiveQueryPlanConverter converter(queryCtx.get(), pool_); planFragment = converter.toVeloxQueryPlan( prestoPlan, updateRequest.tableWriteInfo, taskId); + if (SystemConfig::instance()->planConsistencyCheckEnabled()) { + velox::core::PlanConsistencyChecker::check(planFragment.planNode); + } planValidator_->validatePlanFragment(planFragment); } @@ -392,9 +419,14 @@ proxygen::RequestHandler* TaskResource::deleteTask( message->getQueryParam(protocol::PRESTO_ABORT_TASK_URL_PARAM) == "true"; } bool summarize = message->hasQueryParam("summarize"); + const auto& headers = message->getHeaders(); + const auto& acceptHeader = + headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); + const auto sendThrift = + acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos; return new http::CallbackRequestHandler( - [this, taskId, abort, summarize]( + [this, taskId, abort, summarize, sendThrift]( proxygen::HTTPMessage* /*message*/, const std::vector>& /*body*/, proxygen::ResponseHandler* downstream, @@ -406,22 +438,26 @@ proxygen::RequestHandler* TaskResource::deleteTask( taskInfo = taskManager_.deleteTask(taskId, abort, summarize); return std::move(taskInfo); }) - .via(folly::EventBaseManager::get()->getEventBase()) - .thenValue([taskId, downstream, handlerState](auto&& taskInfo) { + .via( + folly::getKeepAliveToken( + folly::EventBaseManager::get()->getEventBase())) + .thenValue([taskId, downstream, handlerState, sendThrift]( + auto&& taskInfo) { if (!handlerState->requestExpired()) { if (taskInfo == nullptr) { sendTaskNotFound(downstream, taskId); + return; + } + if (sendThrift) { + thrift::TaskInfo thriftTaskInfo; + toThrift(*taskInfo, thriftTaskInfo); + http::sendOkThriftResponse( + downstream, 
thriftWrite(thriftTaskInfo)); + } else { + http::sendOkResponse(downstream, json(*taskInfo)); } - http::sendOkResponse(downstream, json(*taskInfo)); } }) - .thenError( - folly::tag_t{}, - [downstream, handlerState](auto&& e) { - if (!handlerState->requestExpired()) { - http::sendErrorResponse(downstream, e.what()); - } - }) .thenError( folly::tag_t{}, [downstream, handlerState](auto&& e) { @@ -459,11 +495,11 @@ proxygen::RequestHandler* TaskResource::getResults( const std::vector>& /*body*/, proxygen::ResponseHandler* downstream, std::shared_ptr handlerState) { - auto evb = folly::EventBaseManager::get()->getEventBase(); folly::via( httpSrvCpuExecutor_, [this, - evb, + evb = folly::getKeepAliveToken( + folly::EventBaseManager::get()->getEventBase()), taskId, bufferId, token, @@ -505,16 +541,13 @@ proxygen::RequestHandler* TaskResource::getResults( protocol::PRESTO_BUFFER_REMAINING_BYTES_HEADER, folly::join(',', result->remainingBytes)); } + if (result->waitTimeMs > 0) { + builder.header( + protocol::PRESTO_BUFFER_WAIT_TIME_MS_HEADER, + std::to_string(result->waitTimeMs)); + } builder.body(std::move(result->data)).sendWithEOM(); }) - .thenError( - folly::tag_t{}, - [downstream, - handlerState](const velox::VeloxException& e) { - if (!handlerState->requestExpired()) { - http::sendErrorResponse(downstream, e.what()); - } - }) .thenError( folly::tag_t{}, [downstream, handlerState](const std::exception& e) { @@ -533,8 +566,9 @@ proxygen::RequestHandler* TaskResource::getTaskStatus( auto currentState = getCurrentState(message); auto maxWait = getMaxWait(message); - auto& headers = message->getHeaders(); - const auto& acceptHeader = headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); + const auto& headers = message->getHeaders(); + const auto& acceptHeader = + headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); const auto sendThrift = acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos; @@ -544,11 +578,11 @@ proxygen::RequestHandler* 
TaskResource::getTaskStatus( const std::vector>& /*body*/, proxygen::ResponseHandler* downstream, std::shared_ptr handlerState) { - auto evb = folly::EventBaseManager::get()->getEventBase(); folly::via( httpSrvCpuExecutor_, [this, - evb, + evb = folly::getKeepAliveToken( + folly::EventBaseManager::get()->getEventBase()), sendThrift, taskId, currentState, @@ -573,14 +607,6 @@ proxygen::RequestHandler* TaskResource::getTaskStatus( } } }) - .thenError( - folly::tag_t{}, - [downstream, - handlerState](const velox::VeloxException& e) { - if (!handlerState->requestExpired()) { - http::sendErrorResponse(downstream, e.what()); - } - }) .thenError( folly::tag_t{}, [downstream, handlerState](const std::exception& e) { @@ -604,8 +630,14 @@ proxygen::RequestHandler* TaskResource::getTaskInfo( auto maxWait = getMaxWait(message); bool summarize = message->hasQueryParam("summarize"); + const auto& headers = message->getHeaders(); + const auto& acceptHeader = + headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); + const auto sendThrift = + acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos; + return new http::CallbackRequestHandler( - [this, taskId, currentState, maxWait, summarize]( + [this, taskId, currentState, maxWait, summarize, sendThrift]( proxygen::HTTPMessage* /*message*/, const std::vector>& /*body*/, proxygen::ResponseHandler* downstream, @@ -613,32 +645,32 @@ proxygen::RequestHandler* TaskResource::getTaskInfo( folly::via( httpSrvCpuExecutor_, [this, - evb = folly::EventBaseManager::get()->getEventBase(), + evb = folly::getKeepAliveToken( + folly::EventBaseManager::get()->getEventBase()), taskId, currentState, maxWait, summarize, handlerState, - downstream]() { + downstream, + sendThrift]() { taskManager_ .getTaskInfo( taskId, summarize, currentState, maxWait, handlerState) .via(evb) - .thenValue([downstream, taskId, handlerState]( + .thenValue([downstream, taskId, handlerState, sendThrift]( std::unique_ptr taskInfo) { if 
(!handlerState->requestExpired()) { - json taskInfoJson = *taskInfo; - http::sendOkResponse(downstream, taskInfoJson); + if (sendThrift) { + thrift::TaskInfo thriftTaskInfo; + toThrift(*taskInfo, thriftTaskInfo); + http::sendOkThriftResponse( + downstream, thriftWrite(thriftTaskInfo)); + } else { + http::sendOkResponse(downstream, json(*taskInfo)); + } } }) - .thenError( - folly::tag_t{}, - [downstream, - handlerState](const velox::VeloxException& e) { - if (!handlerState->requestExpired()) { - http::sendErrorResponse(downstream, e.what()); - } - }) .thenError( folly::tag_t{}, [downstream, handlerState](const std::exception& e) { @@ -670,19 +702,14 @@ proxygen::RequestHandler* TaskResource::removeRemoteSource( [this, taskId, remoteId, downstream]() { taskManager_.removeRemoteSource(taskId, remoteId); }) - .via(folly::EventBaseManager::get()->getEventBase()) + .via( + folly::getKeepAliveToken( + folly::EventBaseManager::get()->getEventBase())) .thenValue([downstream, handlerState](auto&& /* unused */) { if (!handlerState->requestExpired()) { http::sendOkResponse(downstream); } }) - .thenError( - folly::tag_t{}, - [downstream, handlerState](const velox::VeloxException& e) { - if (!handlerState->requestExpired()) { - http::sendErrorResponse(downstream, e.what()); - } - }) .thenError( folly::tag_t{}, [downstream, handlerState](const std::exception& e) { diff --git a/presto-native-execution/presto_cpp/main/common/CMakeLists.txt b/presto-native-execution/presto_cpp/main/common/CMakeLists.txt index c19e2c0b7eb2a..0f98aeb721bfd 100644 --- a/presto-native-execution/presto_cpp/main/common/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/common/CMakeLists.txt @@ -14,11 +14,15 @@ add_library(presto_exception Exception.cpp) add_library(presto_common Counters.cpp Utils.cpp ConfigReader.cpp Configs.cpp) target_link_libraries(presto_exception velox_exception) -set_property(TARGET presto_exception PROPERTY JOB_POOL_LINK - presto_link_job_pool) +set_property(TARGET 
presto_exception PROPERTY JOB_POOL_LINK presto_link_job_pool) -target_link_libraries(presto_common velox_common_config velox_core - velox_exception) +target_link_libraries( + presto_common + velox_common_config + velox_core + velox_exception + velox_presto_serializer +) set_property(TARGET presto_common PROPERTY JOB_POOL_LINK presto_link_job_pool) if(PRESTO_ENABLE_TESTING) diff --git a/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp b/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp index 027193a2f045a..b10d9962c0bb9 100644 --- a/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp +++ b/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp @@ -87,7 +87,7 @@ std::string requiredProperty( const velox::config::ConfigBase& properties, const std::string& name) { auto value = properties.get(name); - if (!value.hasValue()) { + if (!value.has_value()) { VELOX_USER_FAIL("Missing configuration property {}", name); } return value.value(); @@ -120,7 +120,7 @@ std::string getOptionalProperty( const std::string& name, const std::string& defaultValue) { auto value = properties.get(name); - if (!value.hasValue()) { + if (!value.has_value()) { return defaultValue; } return value.value(); diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp index 9e59436943d03..89f14085da3e2 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.cpp +++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp @@ -13,6 +13,7 @@ */ #include "presto_cpp/main/common/Configs.h" +#include #include "presto_cpp/main/common/ConfigReader.h" #include "presto_cpp/main/common/Utils.h" #include "velox/core/QueryConfig.h" @@ -20,6 +21,7 @@ #include #include #include +#include #if __has_include("filesystem") #include namespace fs = std::filesystem; @@ -37,23 +39,20 @@ std::string bool2String(bool value) { return value ? 
"true" : "false"; } -int getThreadCount() { - auto numThreads = std::thread::hardware_concurrency(); - // The spec says std::thread::hardware_concurrency() might return 0. - // But we depend on std::thread::hardware_concurrency() to create executors. +uint32_t hardwareConcurrency() { + const auto numLogicalCores = folly::hardware_concurrency(); + // The spec says folly::hardware_concurrency() might return 0. + // But we depend on folly::hardware_concurrency() to create executors. // Check to ensure numThreads is > 0. - VELOX_CHECK_GT(numThreads, 0); - return numThreads; + VELOX_CHECK_GT(numLogicalCores, 0); + return numLogicalCores; } -#define STR_PROP(_key_, _val_) \ - { std::string(_key_), std::string(_val_) } +#define STR_PROP(_key_, _val_) {std::string(_key_), std::string(_val_)} #define NUM_PROP(_key_, _val_) \ - { std::string(_key_), folly::to(_val_) } -#define BOOL_PROP(_key_, _val_) \ - { std::string(_key_), bool2String(_val_) } -#define NONE_PROP(_key_) \ - { std::string(_key_), folly::none } + {std::string(_key_), folly::to(_val_)} +#define BOOL_PROP(_key_, _val_) {std::string(_key_), bool2String(_val_)} +#define NONE_PROP(_key_) {std::string(_key_), folly::none} } // namespace void ConfigBase::initialize(const std::string& filePath, bool optionalConfig) { @@ -78,13 +77,6 @@ void ConfigBase::initialize(const std::string& filePath, bool optionalConfig) { std::move(values), mutableConfig); } -std::string ConfigBase::capacityPropertyAsBytesString( - std::string_view propertyName) const { - return folly::to(velox::config::toCapacity( - optionalProperty(propertyName).value(), - velox::config::CapacityUnit::BYTE)); -} - bool ConfigBase::registerProperty( const std::string& propertyName, const folly::Optional& defaultValue) { @@ -110,8 +102,8 @@ folly::Optional ConfigBase::setValue( propertyName); auto oldValue = config_->get(propertyName); config_->set(propertyName, value); - if (oldValue.hasValue()) { - return oldValue; + if (oldValue.has_value()) { + return 
oldValue.value(); } return registeredProps_[propertyName]; } @@ -147,21 +139,36 @@ SystemConfig::SystemConfig() { BOOL_PROP(kHttpServerReusePort, false), BOOL_PROP(kHttpServerBindToNodeInternalAddressOnlyEnabled, false), NONE_PROP(kDiscoveryUri), - NUM_PROP(kMaxDriversPerTask, getThreadCount()), + NUM_PROP(kMaxDriversPerTask, hardwareConcurrency()), NONE_PROP(kTaskWriterCount), NONE_PROP(kTaskPartitionedWriterCount), + NONE_PROP(kTaskMaxStorageBroadcastBytes), NUM_PROP(kConcurrentLifespansPerTask, 1), STR_PROP(kTaskMaxPartialAggregationMemory, "16MB"), + NUM_PROP(kDriverMaxSplitPreload, 2), NUM_PROP(kHttpServerNumIoThreadsHwMultiplier, 1.0), NUM_PROP(kHttpServerNumCpuThreadsHwMultiplier, 1.0), NONE_PROP(kHttpServerHttpsPort), BOOL_PROP(kHttpServerHttpsEnabled, false), + BOOL_PROP(kHttpServerHttp2Enabled, true), + NUM_PROP(kHttpServerIdleTimeoutMs, 60'000), + NUM_PROP(kHttpServerHttp2InitialReceiveWindow, 1 << 20), + NUM_PROP(kHttpServerHttp2ReceiveStreamWindowSize, 1 << 20), + NUM_PROP(kHttpServerHttp2ReceiveSessionWindowSize, 10 * (1 << 20)), + NUM_PROP(kHttpServerHttp2MaxConcurrentStreams, 100), + NUM_PROP(kHttpServerContentCompressionLevel, 4), + NUM_PROP(kHttpServerContentCompressionMinimumSize, 3584), + BOOL_PROP(kHttpServerEnableContentCompression, false), + BOOL_PROP(kHttpServerEnableZstdCompression, false), + NUM_PROP(kHttpServerZstdContentCompressionLevel, 8), + BOOL_PROP(kHttpServerEnableGzipCompression, false), STR_PROP( kHttpsSupportedCiphers, "ECDHE-ECDSA-AES256-GCM-SHA384,AES256-GCM-SHA384"), NONE_PROP(kHttpsCertPath), NONE_PROP(kHttpsKeyPath), NONE_PROP(kHttpsClientCertAndKeyPath), + NONE_PROP(kHttpsClientCaFile), NUM_PROP(kExchangeHttpClientNumIoThreadsHwMultiplier, 1.0), NUM_PROP(kExchangeHttpClientNumCpuThreadsHwMultiplier, 1.0), NUM_PROP(kConnectorNumCpuThreadsHwMultiplier, 0.0), @@ -171,7 +178,9 @@ SystemConfig::SystemConfig() { NUM_PROP(kDriverStuckOperatorThresholdMs, 30 * 60 * 1000), NUM_PROP( kDriverCancelTasksWithStuckOperatorsThresholdMs, 
40 * 60 * 1000), - NUM_PROP(kDriverNumStuckOperatorsToDetachWorker, 8), + NUM_PROP( + kDriverNumStuckOperatorsToDetachWorker, + std::round(0.5 * hardwareConcurrency())), NUM_PROP(kSpillerNumCpuThreadsHwMultiplier, 1.0), STR_PROP(kSpillerFileCreateConfig, ""), STR_PROP(kSpillerDirectoryCreateConfig, ""), @@ -185,7 +194,9 @@ SystemConfig::SystemConfig() { BOOL_PROP(kSystemMemPushbackAbortEnabled, false), NUM_PROP(kWorkerOverloadedThresholdMemGb, 0), NUM_PROP(kWorkerOverloadedThresholdCpuPct, 0), + NUM_PROP(kWorkerOverloadedThresholdNumQueuedDriversHwMultiplier, 0.0), NUM_PROP(kWorkerOverloadedCooldownPeriodSec, 5), + NUM_PROP(kWorkerOverloadedSecondsToDetachWorker, 0), BOOL_PROP(kWorkerOverloadedTaskQueuingEnabled, false), NUM_PROP(kMallocHeapDumpThresholdGb, 20), NUM_PROP(kMallocMemMinHeapDumpInterval, 10), @@ -202,6 +213,7 @@ SystemConfig::SystemConfig() { BOOL_PROP(kAsyncCacheSsdDisableFileCow, false), BOOL_PROP(kSsdCacheChecksumEnabled, false), BOOL_PROP(kSsdCacheReadVerificationEnabled, false), + NUM_PROP(kSsdCacheMaxEntries, 10'000'000), BOOL_PROP(kEnableSerializedPageChecksum, true), BOOL_PROP(kUseMmapAllocator, true), STR_PROP(kMemoryArbitratorKind, ""), @@ -231,12 +243,19 @@ SystemConfig::SystemConfig() { NUM_PROP(kLogNumZombieTasks, 20), NUM_PROP(kAnnouncementMaxFrequencyMs, 30'000), // 30s NUM_PROP(kHeartbeatFrequencyMs, 0), + BOOL_PROP(kHttpClientHttp2Enabled, false), + NUM_PROP(kHttpClientHttp2MaxStreamsPerConnection, 8), + NUM_PROP(kHttpClientHttp2InitialStreamWindow, 1 << 23 /*8MB*/), + NUM_PROP(kHttpClientHttp2StreamWindow, 1 << 23 /*8MB*/), + NUM_PROP(kHttpClientHttp2SessionWindow, 1 << 26 /*64MB*/), + BOOL_PROP(kHttpClientConnectionReuseCounterEnabled, true), STR_PROP(kExchangeMaxErrorDuration, "3m"), STR_PROP(kExchangeRequestTimeout, "20s"), STR_PROP(kExchangeConnectTimeout, "20s"), BOOL_PROP(kExchangeEnableConnectionPool, true), BOOL_PROP(kExchangeEnableBufferCopy, true), BOOL_PROP(kExchangeImmediateBufferTransfer, true), + 
STR_PROP(kExchangeMaxBufferSize, "32MB"), NUM_PROP(kTaskRunTimeSliceMicros, 50'000), BOOL_PROP(kIncludeNodeInSpillPath, false), NUM_PROP(kOldTaskCleanUpMs, 60'000), @@ -258,7 +277,22 @@ SystemConfig::SystemConfig() { BOOL_PROP(kJoinSpillEnabled, true), BOOL_PROP(kAggregationSpillEnabled, true), BOOL_PROP(kOrderBySpillEnabled, true), + NUM_PROP(kMaxSpillBytes, 100UL << 30), // 100GB + BOOL_PROP(kBroadcastJoinTableCachingEnabled, false), + BOOL_PROP(kExchangeLazyFetchingEnabled, false), NUM_PROP(kRequestDataSizesMaxWaitSec, 10), + STR_PROP(kPluginDir, ""), + NUM_PROP(kExchangeIoEvbViolationThresholdMs, 1000), + NUM_PROP(kHttpSrvIoEvbViolationThresholdMs, 1000), + NUM_PROP(kMaxLocalExchangeBufferSize, 32UL << 20), // 32MB + NUM_PROP(kMaxLocalExchangePartitionBufferSize, 65536), // 64KB + BOOL_PROP(kParallelOutputJoinBuildRowsEnabled, false), + NUM_PROP(kHashProbeBloomFilterPushdownMaxSize, 0), + BOOL_PROP(kTextWriterEnabled, true), + BOOL_PROP(kTextReaderEnabled, true), + BOOL_PROP(kCharNToVarcharImplicitCast, false), + BOOL_PROP(kEnumTypesEnabled, true), + BOOL_PROP(kPlanConsistencyCheckEnabled, false), }; } @@ -289,6 +323,60 @@ bool SystemConfig::httpServerHttpsEnabled() const { return optionalProperty(kHttpServerHttpsEnabled).value(); } +bool SystemConfig::httpServerHttp2Enabled() const { + return optionalProperty(kHttpServerHttp2Enabled).value(); +} + +uint32_t SystemConfig::httpServerIdleTimeoutMs() const { + return optionalProperty(kHttpServerIdleTimeoutMs).value(); +} + +uint32_t SystemConfig::httpServerHttp2InitialReceiveWindow() const { + return optionalProperty(kHttpServerHttp2InitialReceiveWindow) + .value(); +} + +uint32_t SystemConfig::httpServerHttp2ReceiveStreamWindowSize() const { + return optionalProperty(kHttpServerHttp2ReceiveStreamWindowSize) + .value(); +} + +uint32_t SystemConfig::httpServerHttp2ReceiveSessionWindowSize() const { + return optionalProperty(kHttpServerHttp2ReceiveSessionWindowSize) + .value(); +} + +uint32_t 
SystemConfig::httpServerHttp2MaxConcurrentStreams() const { + return optionalProperty(kHttpServerHttp2MaxConcurrentStreams) + .value(); +} + +uint32_t SystemConfig::httpServerContentCompressionLevel() const { + return optionalProperty(kHttpServerContentCompressionLevel).value(); +} + +uint32_t SystemConfig::httpServerContentCompressionMinimumSize() const { + return optionalProperty(kHttpServerContentCompressionMinimumSize) + .value(); +} + +bool SystemConfig::httpServerEnableContentCompression() const { + return optionalProperty(kHttpServerEnableContentCompression).value(); +} + +bool SystemConfig::httpServerEnableZstdCompression() const { + return optionalProperty(kHttpServerEnableZstdCompression).value(); +} + +uint32_t SystemConfig::httpServerZstdContentCompressionLevel() const { + return optionalProperty(kHttpServerZstdContentCompressionLevel) + .value(); +} + +bool SystemConfig::httpServerEnableGzipCompression() const { + return optionalProperty(kHttpServerEnableGzipCompression).value(); +} + std::string SystemConfig::httpsSupportedCiphers() const { return optionalProperty(kHttpsSupportedCiphers).value(); } @@ -305,6 +393,10 @@ folly::Optional SystemConfig::httpsClientCertAndKeyPath() const { return optionalProperty(kHttpsClientCertAndKeyPath); } +folly::Optional SystemConfig::httpsClientCaFile() const { + return optionalProperty(kHttpsClientCaFile); +} + std::string SystemConfig::prestoVersion() const { return requiredProperty(std::string(kPrestoVersion)); } @@ -338,6 +430,18 @@ bool SystemConfig::orderBySpillEnabled() const { return optionalProperty(kOrderBySpillEnabled).value(); } +bool SystemConfig::broadcastJoinTableCachingEnabled() const { + return optionalProperty(kBroadcastJoinTableCachingEnabled).value(); +} + +bool SystemConfig::exchangeLazyFetchingEnabled() const { + return optionalProperty(kExchangeLazyFetchingEnabled).value(); +} + +uint64_t SystemConfig::maxSpillBytes() const { + return optionalProperty(kMaxSpillBytes).value(); +} + int 
SystemConfig::requestDataSizesMaxWaitSec() const { return optionalProperty(kRequestDataSizesMaxWaitSec).value(); } @@ -355,7 +459,7 @@ SystemConfig::remoteFunctionServerLocation() const { // First check if there is a UDS path registered. If there's one, use it. auto remoteServerUdsPath = optionalProperty(kRemoteFunctionServerThriftUdsPath); - if (remoteServerUdsPath.hasValue()) { + if (remoteServerUdsPath.has_value()) { return folly::SocketAddress::makeFromPath(remoteServerUdsPath.value()); } @@ -365,13 +469,13 @@ SystemConfig::remoteFunctionServerLocation() const { auto remoteServerPort = optionalProperty(kRemoteFunctionServerThriftPort); - if (remoteServerPort.hasValue()) { + if (remoteServerPort.has_value()) { // Fallback to localhost if address is not specified. - return remoteServerAddress.hasValue() + return remoteServerAddress.has_value() ? folly:: SocketAddress{remoteServerAddress.value(), remoteServerPort.value()} : folly::SocketAddress{"::1", remoteServerPort.value()}; - } else if (remoteServerAddress.hasValue()) { + } else if (remoteServerAddress.has_value()) { VELOX_FAIL( "Remote function server port not provided using '{}'.", kRemoteFunctionServerThriftPort); @@ -394,10 +498,18 @@ std::string SystemConfig::remoteFunctionServerSerde() const { return optionalProperty(kRemoteFunctionServerSerde).value(); } +std::string SystemConfig::remoteFunctionServerRestURL() const { + return optionalProperty(kRemoteFunctionServerRestURL).value(); +} + int32_t SystemConfig::maxDriversPerTask() const { return optionalProperty(kMaxDriversPerTask).value(); } +int32_t SystemConfig::driverMaxSplitPreload() const { + return optionalProperty(kDriverMaxSplitPreload).value(); +} + folly::Optional SystemConfig::taskWriterCount() const { return optionalProperty(kTaskWriterCount); } @@ -406,6 +518,10 @@ folly::Optional SystemConfig::taskPartitionedWriterCount() const { return optionalProperty(kTaskPartitionedWriterCount); } +folly::Optional 
SystemConfig::taskMaxStorageBroadcastBytes() const { + return optionalProperty(kTaskMaxStorageBroadcastBytes); +} + int32_t SystemConfig::concurrentLifespansPerTask() const { return optionalProperty(kConcurrentLifespansPerTask).value(); } @@ -511,10 +627,22 @@ uint32_t SystemConfig::workerOverloadedThresholdCpuPct() const { return optionalProperty(kWorkerOverloadedThresholdCpuPct).value(); } +double SystemConfig::workerOverloadedThresholdNumQueuedDriversHwMultiplier() + const { + return optionalProperty( + kWorkerOverloadedThresholdNumQueuedDriversHwMultiplier) + .value(); +} + uint32_t SystemConfig::workerOverloadedCooldownPeriodSec() const { return optionalProperty(kWorkerOverloadedCooldownPeriodSec).value(); } +uint64_t SystemConfig::workerOverloadedSecondsToDetachWorker() const { + return optionalProperty(kWorkerOverloadedSecondsToDetachWorker) + .value(); +} + bool SystemConfig::workerOverloadedTaskQueuingEnabled() const { return optionalProperty(kWorkerOverloadedTaskQueuingEnabled).value(); } @@ -585,6 +713,10 @@ bool SystemConfig::ssdCacheReadVerificationEnabled() const { return optionalProperty(kSsdCacheReadVerificationEnabled).value(); } +uint64_t SystemConfig::ssdCacheMaxEntries() const { + return optionalProperty(kSsdCacheMaxEntries).value(); +} + std::string SystemConfig::shuffleName() const { return optionalProperty(kShuffleName).value(); } @@ -655,8 +787,9 @@ std::string SystemConfig::sharedArbitratorFastExponentialGrowthCapacityLimit() kSharedArbitratorFastExponentialGrowthCapacityLimitDefault = "512MB"; return optionalProperty( kSharedArbitratorFastExponentialGrowthCapacityLimit) - .value_or(std::string( - kSharedArbitratorFastExponentialGrowthCapacityLimitDefault)); + .value_or( + std::string( + kSharedArbitratorFastExponentialGrowthCapacityLimitDefault)); } std::string SystemConfig::sharedArbitratorSlowCapacityGrowPct() const { @@ -706,8 +839,9 @@ std::string SystemConfig::sharedArbitratorMemoryReclaimThreadsHwMultiplier() 
kSharedArbitratorMemoryReclaimThreadsHwMultiplierDefault = "0.5"; return optionalProperty( kSharedArbitratorMemoryReclaimThreadsHwMultiplier) - .value_or(std::string( - kSharedArbitratorMemoryReclaimThreadsHwMultiplierDefault)); + .value_or( + std::string( + kSharedArbitratorMemoryReclaimThreadsHwMultiplierDefault)); } std::string SystemConfig::sharedArbitratorGlobalArbitrationMemoryReclaimPct() @@ -716,8 +850,9 @@ std::string SystemConfig::sharedArbitratorGlobalArbitrationMemoryReclaimPct() kSharedArbitratorGlobalArbitrationMemoryReclaimPctDefault = "10"; return optionalProperty( kSharedArbitratorGlobalArbitrationMemoryReclaimPct) - .value_or(std::string( - kSharedArbitratorGlobalArbitrationMemoryReclaimPctDefault)); + .value_or( + std::string( + kSharedArbitratorGlobalArbitrationMemoryReclaimPctDefault)); } std::string SystemConfig::sharedArbitratorGlobalArbitrationAbortTimeRatio() @@ -796,6 +931,33 @@ uint64_t SystemConfig::heartbeatFrequencyMs() const { return optionalProperty(kHeartbeatFrequencyMs).value(); } +bool SystemConfig::httpClientHttp2Enabled() const { + return optionalProperty(kHttpClientHttp2Enabled).value(); +} + +uint32_t SystemConfig::httpClientHttp2MaxStreamsPerConnection() const { + return optionalProperty(kHttpClientHttp2MaxStreamsPerConnection) + .value(); +} + +uint32_t SystemConfig::httpClientHttp2InitialStreamWindow() const { + return optionalProperty(kHttpClientHttp2InitialStreamWindow) + .value(); +} + +uint32_t SystemConfig::httpClientHttp2StreamWindow() const { + return optionalProperty(kHttpClientHttp2StreamWindow).value(); +} + +uint32_t SystemConfig::httpClientHttp2SessionWindow() const { + return optionalProperty(kHttpClientHttp2SessionWindow).value(); +} + +bool SystemConfig::httpClientConnectionReuseCounterEnabled() const { + return optionalProperty(kHttpClientConnectionReuseCounterEnabled) + .value(); +} + std::chrono::duration SystemConfig::exchangeMaxErrorDuration() const { return velox::config::toDuration( 
optionalProperty(kExchangeMaxErrorDuration).value()); @@ -823,6 +985,12 @@ bool SystemConfig::exchangeImmediateBufferTransfer() const { return optionalProperty(kExchangeImmediateBufferTransfer).value(); } +uint64_t SystemConfig::exchangeMaxBufferSize() const { + return velox::config::toCapacity( + optionalProperty(kExchangeMaxBufferSize).value(), + velox::config::CapacityUnit::BYTE); +} + int32_t SystemConfig::taskRunTimeSliceMicros() const { return optionalProperty(kTaskRunTimeSliceMicros).value(); } @@ -884,6 +1052,56 @@ std::string SystemConfig::prestoDefaultNamespacePrefix() const { return optionalProperty(kPrestoDefaultNamespacePrefix).value().append("."); } +std::string SystemConfig::pluginDir() const { + return optionalProperty(kPluginDir).value(); +} + +int32_t SystemConfig::exchangeIoEvbViolationThresholdMs() const { + return optionalProperty(kExchangeIoEvbViolationThresholdMs).value(); +} + +int32_t SystemConfig::httpSrvIoEvbViolationThresholdMs() const { + return optionalProperty(kHttpSrvIoEvbViolationThresholdMs).value(); +} + +uint64_t SystemConfig::maxLocalExchangeBufferSize() const { + return optionalProperty(kMaxLocalExchangeBufferSize).value(); +} + +uint64_t SystemConfig::maxLocalExchangePartitionBufferSize() const { + return optionalProperty(kMaxLocalExchangePartitionBufferSize) + .value(); +} + +bool SystemConfig::parallelOutputJoinBuildRowsEnabled() const { + return optionalProperty(kParallelOutputJoinBuildRowsEnabled).value(); +} + +uint64_t SystemConfig::hashProbeBloomFilterPushdownMaxSize() const { + return optionalProperty(kHashProbeBloomFilterPushdownMaxSize) + .value(); +} + +bool SystemConfig::textWriterEnabled() const { + return optionalProperty(kTextWriterEnabled).value(); +} + +bool SystemConfig::textReaderEnabled() const { + return optionalProperty(kTextReaderEnabled).value(); +} + +bool SystemConfig::charNToVarcharImplicitCast() const { + return optionalProperty(kCharNToVarcharImplicitCast).value(); +} + +bool 
SystemConfig::enumTypesEnabled() const { + return optionalProperty(kEnumTypesEnabled).value(); +} + +bool SystemConfig::planConsistencyCheckEnabled() const { + return optionalProperty(kPlanConsistencyCheckEnabled).value(); +} + NodeConfig::NodeConfig() { registeredProps_ = std::unordered_map>{ @@ -892,6 +1110,7 @@ NodeConfig::NodeConfig() { NONE_PROP(kNodeIp), NONE_PROP(kNodeInternalAddress), NONE_PROP(kNodeLocation), + NONE_PROP(kNodePrometheusExecutorThreads), }; } @@ -904,9 +1123,18 @@ std::string NodeConfig::nodeEnvironment() const { return requiredProperty(kNodeEnvironment); } +int NodeConfig::prometheusExecutorThreads() const { + static constexpr int kNodePrometheusExecutorThreadsDefault = 2; + auto resultOpt = optionalProperty(kNodePrometheusExecutorThreads); + if (resultOpt.has_value()) { + return resultOpt.value(); + } + return kNodePrometheusExecutorThreadsDefault; +} + std::string NodeConfig::nodeId() const { auto resultOpt = optionalProperty(kNodeId); - if (resultOpt.hasValue()) { + if (resultOpt.has_value()) { return resultOpt.value(); } // Generate the nodeId which must be a UUID. nodeId must be a singleton. @@ -924,7 +1152,7 @@ std::string NodeConfig::nodeInternalAddress( auto resultOpt = optionalProperty(kNodeInternalAddress); /// node.ip(kNodeIp) is legacy config replaced with node.internal-address, but /// still valid config in Presto, so handling both. - if (!resultOpt.hasValue()) { + if (!resultOpt.has_value()) { resultOpt = optionalProperty(kNodeIp); } if (resultOpt.has_value()) { @@ -938,126 +1166,4 @@ std::string NodeConfig::nodeInternalAddress( } } -BaseVeloxQueryConfig::BaseVeloxQueryConfig() { - // Use empty instance to get default property values. 
- velox::core::QueryConfig c{{}}; - using namespace velox::core; - registeredProps_ = - std::unordered_map>{ - BOOL_PROP(kMutableConfig, false), - STR_PROP(QueryConfig::kSessionTimezone, c.sessionTimezone()), - BOOL_PROP( - QueryConfig::kAdjustTimestampToTimezone, - c.adjustTimestampToTimezone()), - BOOL_PROP(QueryConfig::kExprEvalSimplified, c.exprEvalSimplified()), - BOOL_PROP(QueryConfig::kExprTrackCpuUsage, c.exprTrackCpuUsage()), - BOOL_PROP( - QueryConfig::kOperatorTrackCpuUsage, c.operatorTrackCpuUsage()), - BOOL_PROP( - QueryConfig::kCastMatchStructByName, c.isMatchStructByName()), - NUM_PROP( - QueryConfig::kMaxLocalExchangeBufferSize, - c.maxLocalExchangeBufferSize()), - NUM_PROP( - QueryConfig::kMaxPartialAggregationMemory, - c.maxPartialAggregationMemoryUsage()), - NUM_PROP( - QueryConfig::kMaxExtendedPartialAggregationMemory, - c.maxExtendedPartialAggregationMemoryUsage()), - NUM_PROP( - QueryConfig::kAbandonPartialAggregationMinRows, - c.abandonPartialAggregationMinRows()), - NUM_PROP( - QueryConfig::kAbandonPartialAggregationMinPct, - c.abandonPartialAggregationMinPct()), - NUM_PROP( - QueryConfig::kMaxPartitionedOutputBufferSize, - c.maxPartitionedOutputBufferSize()), - NUM_PROP( - QueryConfig::kPreferredOutputBatchBytes, - c.preferredOutputBatchBytes()), - NUM_PROP( - QueryConfig::kPreferredOutputBatchRows, - c.preferredOutputBatchRows()), - NUM_PROP(QueryConfig::kMaxOutputBatchRows, c.maxOutputBatchRows()), - BOOL_PROP( - QueryConfig::kHashAdaptivityEnabled, c.hashAdaptivityEnabled()), - BOOL_PROP( - QueryConfig::kAdaptiveFilterReorderingEnabled, - c.adaptiveFilterReorderingEnabled()), - BOOL_PROP(QueryConfig::kSpillEnabled, c.spillEnabled()), - BOOL_PROP( - QueryConfig::kAggregationSpillEnabled, - c.aggregationSpillEnabled()), - BOOL_PROP(QueryConfig::kJoinSpillEnabled, c.joinSpillEnabled()), - BOOL_PROP(QueryConfig::kOrderBySpillEnabled, c.orderBySpillEnabled()), - NUM_PROP(QueryConfig::kMaxSpillLevel, c.maxSpillLevel()), - 
NUM_PROP(QueryConfig::kMaxSpillFileSize, c.maxSpillFileSize()), - NUM_PROP( - QueryConfig::kSpillStartPartitionBit, c.spillStartPartitionBit()), - NUM_PROP( - QueryConfig::kSpillNumPartitionBits, c.spillNumPartitionBits()), - NUM_PROP( - QueryConfig::kSpillableReservationGrowthPct, - c.spillableReservationGrowthPct()), - BOOL_PROP( - QueryConfig::kPrestoArrayAggIgnoreNulls, - c.prestoArrayAggIgnoreNulls()), - BOOL_PROP( - QueryConfig::kSelectiveNimbleReaderEnabled, - c.selectiveNimbleReaderEnabled()), - NUM_PROP(QueryConfig::kMaxOutputBufferSize, c.maxOutputBufferSize()), - }; -} - -BaseVeloxQueryConfig* BaseVeloxQueryConfig::instance() { - static std::unique_ptr instance = - std::make_unique(); - return instance.get(); -} - -void BaseVeloxQueryConfig::updateLoadedValues( - std::unordered_map& values) const { - // Update velox config with values from presto system config. - auto systemConfig = SystemConfig::instance(); - - using namespace velox::core; - std::unordered_map updatedValues{ - {QueryConfig::kPrestoArrayAggIgnoreNulls, - bool2String(systemConfig->useLegacyArrayAgg())}, - {QueryConfig::kMaxOutputBufferSize, - systemConfig->capacityPropertyAsBytesString( - SystemConfig::kSinkMaxBufferSize)}, - {QueryConfig::kMaxPartitionedOutputBufferSize, - systemConfig->capacityPropertyAsBytesString( - SystemConfig::kDriverMaxPagePartitioningBufferSize)}, - {QueryConfig::kMaxPartialAggregationMemory, - systemConfig->capacityPropertyAsBytesString( - SystemConfig::kTaskMaxPartialAggregationMemory)}, - }; - - auto taskWriterCount = systemConfig->taskWriterCount(); - if (taskWriterCount.has_value()) { - updatedValues[QueryConfig::kTaskWriterCount] = - std::to_string(taskWriterCount.value()); - } - auto taskPartitionedWriterCount = systemConfig->taskPartitionedWriterCount(); - if (taskPartitionedWriterCount.has_value()) { - updatedValues[QueryConfig::kTaskPartitionedWriterCount] = - std::to_string(taskPartitionedWriterCount.value()); - } - - std::stringstream updated; - for 
(const auto& pair : updatedValues) { - updated << " " << pair.first << "=" << pair.second << "\n"; - values[pair.first] = pair.second; - } - auto str = updated.str(); - if (!str.empty()) { - PRESTO_STARTUP_LOG(INFO) - << "Updated in '" << filePath_ << "' from SystemProperties:\n" - << str; - } -} - } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h index 5e7a6fd2e815f..2c44c3e845427 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.h +++ b/presto-native-execution/presto_cpp/main/common/Configs.h @@ -43,12 +43,12 @@ class ConfigBase { } /// DO NOT DELETE THIS METHOD! - /// The method is used to register new properties after the config class is created. - /// Returns true if succeeded, false if failed (due to the property already - /// registered). + /// The method is used to register new properties after the config class is + /// created. Returns true if succeeded, false if failed (due to the property + /// already registered). bool registerProperty( - const std::string& propertyName, - const folly::Optional& defaultValue = {}); + const std::string& propertyName, + const folly::Optional& defaultValue = {}); /// Adds or replaces value at the given key. Can be used by debugging or /// testing code. 
@@ -95,7 +95,7 @@ class ConfigBase { template folly::Optional optionalProperty(const std::string& propertyName) const { auto valOpt = config_->get(propertyName); - if (valOpt.hasValue()) { + if (valOpt.has_value()) { return valOpt.value(); } const auto it = registeredProps_.find(propertyName); @@ -115,8 +115,8 @@ class ConfigBase { folly::Optional optionalProperty( const std::string& propertyName) const { auto val = config_->get(propertyName); - if (val.hasValue()) { - return val; + if (val.has_value()) { + return val.value(); } const auto it = registeredProps_.find(propertyName); if (it != registeredProps_.end()) { @@ -131,10 +131,6 @@ class ConfigBase { return optionalProperty(std::string{propertyName}); } - /// Returns "N" as string containing capacity in bytes. - std::string capacityPropertyAsBytesString( - std::string_view propertyName) const; - /// Returns copy of the config values map. std::unordered_map values() const { return config_->rawConfigsCopy(); @@ -144,8 +140,9 @@ class ConfigBase { protected: ConfigBase() - : config_(std::make_unique( - std::unordered_map())){}; + : config_( + std::make_unique( + std::unordered_map())) {}; // Check if all properties are registered. void checkRegisteredProperties( @@ -183,6 +180,14 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kTaskWriterCount{"task.writer-count"}; static constexpr std::string_view kTaskPartitionedWriterCount{ "task.partitioned-writer-count"}; + + /// Maximum number of bytes per task that can be broadcast to storage for + /// storage-based broadcast joins. This property is only applicable to + /// storage-based broadcast join operations, currently used in the Presto on + /// Spark native stack. When the broadcast data size exceeds this limit, the + /// query fails. 
+ static constexpr std::string_view kTaskMaxStorageBroadcastBytes{ + "task.max-storage-broadcast-bytes"}; static constexpr std::string_view kConcurrentLifespansPerTask{ "task.concurrent-lifespans-per-task"}; static constexpr std::string_view kTaskMaxPartialAggregationMemory{ @@ -202,6 +207,41 @@ class SystemConfig : public ConfigBase { "http-server.https.port"}; static constexpr std::string_view kHttpServerHttpsEnabled{ "http-server.https.enabled"}; + static constexpr std::string_view kHttpServerHttp2Enabled{ + "http-server.http2.enabled"}; + /// HTTP/2 server idle timeout in milliseconds (default 60000ms). + static constexpr std::string_view kHttpServerIdleTimeoutMs{ + "http-server.http2.idle-timeout-ms"}; + /// HTTP/2 initial receive window size in bytes (default 1MB). + static constexpr std::string_view kHttpServerHttp2InitialReceiveWindow{ + "http-server.http2.initial-receive-window"}; + /// HTTP/2 receive stream window size in bytes (default 1MB). + static constexpr std::string_view kHttpServerHttp2ReceiveStreamWindowSize{ + "http-server.http2.receive-stream-window-size"}; + /// HTTP/2 receive session window size in bytes (default 10MB). + static constexpr std::string_view kHttpServerHttp2ReceiveSessionWindowSize{ + "http-server.http2.receive-session-window-size"}; + /// HTTP/2 maximum concurrent streams per connection (default 100). + static constexpr std::string_view kHttpServerHttp2MaxConcurrentStreams{ + "http-server.http2.max-concurrent-streams"}; + /// HTTP/2 content compression level (1-9, default 4 for speed). + static constexpr std::string_view kHttpServerContentCompressionLevel{ + "http-server.http2.content-compression-level"}; + /// HTTP/2 content compression minimum size in bytes (default 3584). + static constexpr std::string_view kHttpServerContentCompressionMinimumSize{ + "http-server.http2.content-compression-minimum-size"}; + /// Enable content compression (master switch, default true). 
+ static constexpr std::string_view kHttpServerEnableContentCompression{ + "http-server.http2.enable-content-compression"}; + /// Enable zstd compression (default false). + static constexpr std::string_view kHttpServerEnableZstdCompression{ + "http-server.http2.enable-zstd-compression"}; + /// Zstd compression level (-5 to 22, default 8). + static constexpr std::string_view kHttpServerZstdContentCompressionLevel{ + "http-server.http2.zstd-content-compression-level"}; + /// Enable gzip compression (default true). + static constexpr std::string_view kHttpServerEnableGzipCompression{ + "http-server.http2.enable-gzip-compression"}; /// List of comma separated ciphers the client can use. /// /// NOTE: the client needs to have at least one cipher shared with server @@ -213,6 +253,8 @@ class SystemConfig : public ConfigBase { /// Path to a .PEM file with certificate and key concatenated together. static constexpr std::string_view kHttpsClientCertAndKeyPath{ "https-client-cert-key-path"}; + /// Path to client CA file for SSL client certificate verification. + static constexpr std::string_view kHttpsClientCaFile{"https-client-ca-file"}; /// Floating point number used in calculating how many threads we would use /// for CPU executor for connectors mainly for async operators: @@ -230,6 +272,11 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kConnectorNumIoThreadsHwMultiplier{ "connector.num-io-threads-hw-multiplier"}; + /// Maximum number of splits to preload per driver. + /// Set to 0 to disable preloading. + static constexpr std::string_view kDriverMaxSplitPreload{ + "driver.max-split-preload"}; + /// Floating point number used in calculating how many threads we would use /// for Driver CPU executor: hw_concurrency x multiplier. 4.0 is default. 
static constexpr std::string_view kDriverNumCpuThreadsHwMultiplier{ @@ -257,6 +304,7 @@ class SystemConfig : public ConfigBase { /// The number of stuck operators (effectively stuck driver threads) when we /// detach the worker from the cluster in an attempt to keep the cluster /// operational. + /// 1/2 of the hardware concurrency is the default. static constexpr std::string_view kDriverNumStuckOperatorsToDetachWorker{ "driver.num-stuck-operators-to-detach-worker"}; @@ -313,11 +361,22 @@ class SystemConfig : public ConfigBase { /// Ignored if zero. Default is zero. static constexpr std::string_view kWorkerOverloadedThresholdCpuPct{ "worker-overloaded-threshold-cpu-pct"}; + /// Floating point number used in calculating how many drivers must be queued + /// for the worker to be considered overloaded. + /// Ignored if zero. Default is zero. + static constexpr std::string_view + kWorkerOverloadedThresholdNumQueuedDriversHwMultiplier{ + "worker-overloaded-threshold-num-queued-drivers-hw-multiplier"}; /// Specifies how many seconds worker has to be not overloaded (in terms of /// memory and CPU) before its status changes to not overloaded. /// This is to prevent spiky fluctuation of the overloaded status. static constexpr std::string_view kWorkerOverloadedCooldownPeriodSec{ "worker-overloaded-cooldown-period-sec"}; + /// The number of seconds the worker needs to be continuously overloaded for + /// us to detach the worker from the cluster in an attempt to keep the + /// cluster operational. Ignored if set to zero. Default is zero. + static constexpr std::string_view kWorkerOverloadedSecondsToDetachWorker{ + "worker-overloaded-seconds-to-detach-worker"}; /// If true, the worker starts queuing new tasks when overloaded, and /// starts them gradually when it stops being overloaded. static constexpr std::string_view kWorkerOverloadedTaskQueuingEnabled{ @@ -388,6 +447,12 @@ class SystemConfig : public ConfigBase { /// value when cache data is loaded from the SSD. 
static constexpr std::string_view kSsdCacheReadVerificationEnabled{ "ssd-cache-read-verification-enabled"}; + /// Maximum number of entries allowed in the SSD cache. A value of 0 means no + /// limit. When the limit is reached, new entry writes will be skipped. + /// Default is 10 million entries, which keeps metadata memory usage around + /// 500MB (each entry uses ~50-60 bytes for key, value, and hash overhead). + static constexpr std::string_view kSsdCacheMaxEntries{ + "ssd-cache-max-entries"}; static constexpr std::string_view kEnableSerializedPageChecksum{ "enable-serialized-page-checksum"}; @@ -618,6 +683,31 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kHeartbeatFrequencyMs{ "heartbeat-frequency-ms"}; + /// Whether HTTP/2 is enabled for HTTP client connections. + static constexpr std::string_view kHttpClientHttp2Enabled{ + "http-client.http2-enabled"}; + + /// Maximum concurrent streams per HTTP/2 connection + static constexpr std::string_view kHttpClientHttp2MaxStreamsPerConnection{ + "http-client.http2.max-streams-per-connection"}; + + /// HTTP/2 initial stream window size in bytes. + static constexpr std::string_view kHttpClientHttp2InitialStreamWindow{ + "http-client.http2.initial-stream-window"}; + + /// HTTP/2 stream window size in bytes. + static constexpr std::string_view kHttpClientHttp2StreamWindow{ + "http-client.http2.stream-window"}; + + /// HTTP/2 session window size in bytes. + static constexpr std::string_view kHttpClientHttp2SessionWindow{ + "http-client.http2.session-window"}; + + /// Whether to enable HTTP client connection reuse counter reporting. + /// When enabled, tracks connection first use and reuse metrics. 
+ static constexpr std::string_view kHttpClientConnectionReuseCounterEnabled{ + "http-client.connection-reuse-counter-enabled"}; + static constexpr std::string_view kExchangeMaxErrorDuration{ "exchange.max-error-duration"}; @@ -631,7 +721,7 @@ class SystemConfig : public ConfigBase { /// as soon as exchange gets its response back. Otherwise the memory transfer /// will happen later in driver thread pool. /// - /// NOTE: this only applies if 'exchange.no-buffer-copy' is false. + /// NOTE: this only applies if 'exchange.enable-buffer-copy' is true. static constexpr std::string_view kExchangeImmediateBufferTransfer{ "exchange.immediate-buffer-transfer"}; @@ -662,6 +752,11 @@ class SystemConfig : public ConfigBase { kExchangeHttpClientNumCpuThreadsHwMultiplier{ "exchange.http-client.num-cpu-threads-hw-multiplier"}; + /// Maximum size in bytes to accumulate in ExchangeQueue. Enforced + /// approximately, not strictly. + static constexpr std::string_view kExchangeMaxBufferSize{ + "exchange.max-buffer-size"}; + /// The maximum timeslice for a task on thread if there are threads queued. static constexpr std::string_view kTaskRunTimeSliceMicros{ "task-run-timeslice-micros"}; @@ -684,6 +779,10 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kRemoteFunctionServerThriftUdsPath{ "remote-function-server.thrift.uds-path"}; + /// HTTP URL used by the remote function rest server. + static constexpr std::string_view kRemoteFunctionServerRestURL{ + "remote-function-server.rest.url"}; + /// Path where json files containing signatures for remote functions can be /// found. 
static constexpr std::string_view @@ -709,9 +808,8 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kInternalCommunicationJwtExpirationSeconds{ "internal-communication.jwt.expiration-seconds"}; - /// Below are the Presto properties from config.properties that get converted - /// to their velox counterparts in BaseVeloxQueryConfig and used solely from - /// BaseVeloxQueryConfig. + /// Optional string containing the path to the plugin directory + static constexpr std::string_view kPluginDir{"plugin.dir"}; /// Uses legacy version of array_agg which ignores nulls. static constexpr std::string_view kUseLegacyArrayAgg{ @@ -736,10 +834,61 @@ class SystemConfig : public ConfigBase { "aggregation-spill-enabled"}; static constexpr std::string_view kOrderBySpillEnabled{ "order-by-spill-enabled"}; + static constexpr std::string_view kMaxSpillBytes{"max-spill-bytes"}; + + /// When enabled, hash tables built for broadcast joins are cached and reused + /// across tasks within the same query and stage. + static constexpr std::string_view kBroadcastJoinTableCachingEnabled{ + "broadcast-join-table-caching-enabled"}; + + /// If true, data fetching is deferred until next() is called on the exchange + /// client. If false (default), exchange clients will start fetching data + /// immediately when remote tasks are added. + static constexpr std::string_view kExchangeLazyFetchingEnabled{ + "exchange-lazy-fetching-enabled"}; // Max wait time for exchange request in seconds. 
static constexpr std::string_view kRequestDataSizesMaxWaitSec{ - "exchange.http-client.request-data-sizes-max-wait-sec"}; + "exchange.http-client.request-data-sizes-max-wait-sec"}; + + static constexpr std::string_view kExchangeIoEvbViolationThresholdMs{ + "exchange.io-evb-violation-threshold-ms"}; + static constexpr std::string_view kHttpSrvIoEvbViolationThresholdMs{ + "http-server.io-evb-violation-threshold-ms"}; + + static constexpr std::string_view kMaxLocalExchangeBufferSize{ + "local-exchange.max-buffer-size"}; + + static constexpr std::string_view kMaxLocalExchangePartitionBufferSize{ + "local-exchange.max-partition-buffer-size"}; + + static constexpr std::string_view kParallelOutputJoinBuildRowsEnabled{ + "join.parallel-output-build-rows-enabled"}; + + static constexpr std::string_view kHashProbeBloomFilterPushdownMaxSize{ + "join.hash-probe-bloom-filter-pushdown-max-size"}; + + // Add to temporarily help with gradual rollout for text writer + // TODO: remove once text writer is fully rolled out + static constexpr std::string_view kTextWriterEnabled{"text-writer-enabled"}; + + // Add to temporarily help with gradual rollout for text reader + // TODO: remove once text reader is fully rolled out + static constexpr std::string_view kTextReaderEnabled{"text-reader-enabled"}; + + /// Enable the type char(n) with the same behavior as unbounded varchar. + /// char(n) type is not supported by parser when set to false. + static constexpr std::string_view kCharNToVarcharImplicitCast{ + "char-n-to-varchar-implicit-cast"}; + + /// Enable BigintEnum and VarcharEnum types to be parsed and used in Velox. + /// When set to false, BigintEnum or VarcharEnum types will throw an + /// unsupported error during type parsing. + static constexpr std::string_view kEnumTypesEnabled{"enum-types-enabled"}; + + /// Enable velox plan consistency check. 
+ static constexpr std::string_view kPlanConsistencyCheckEnabled{ + "plan-consistency-check-enabled"}; SystemConfig(); @@ -757,6 +906,30 @@ class SystemConfig : public ConfigBase { int httpServerHttpsPort() const; + bool httpServerHttp2Enabled() const; + + uint32_t httpServerIdleTimeoutMs() const; + + uint32_t httpServerHttp2InitialReceiveWindow() const; + + uint32_t httpServerHttp2ReceiveStreamWindowSize() const; + + uint32_t httpServerHttp2ReceiveSessionWindowSize() const; + + uint32_t httpServerHttp2MaxConcurrentStreams() const; + + uint32_t httpServerContentCompressionLevel() const; + + uint32_t httpServerContentCompressionMinimumSize() const; + + bool httpServerEnableContentCompression() const; + + bool httpServerEnableZstdCompression() const; + + uint32_t httpServerZstdContentCompressionLevel() const; + + bool httpServerEnableGzipCompression() const; + /// A list of ciphers (comma separated) that are supported by /// server and client. Note Java and folly::SSLContext use different names to /// refer to the same cipher. For e.g. TLS_RSA_WITH_AES_256_GCM_SHA384 in Java @@ -781,6 +954,9 @@ class SystemConfig : public ConfigBase { /// later. folly::Optional httpsClientCertAndKeyPath() const; + /// Path to client CA file for SSL client certificate verification. 
+ folly::Optional httpsClientCaFile() const; + bool mutableConfig() const; std::string prestoVersion() const; @@ -796,12 +972,18 @@ class SystemConfig : public ConfigBase { std::string remoteFunctionServerSerde() const; + std::string remoteFunctionServerRestURL() const; + int32_t maxDriversPerTask() const; + int32_t driverMaxSplitPreload() const; + folly::Optional taskWriterCount() const; folly::Optional taskPartitionedWriterCount() const; + folly::Optional taskMaxStorageBroadcastBytes() const; + int32_t concurrentLifespansPerTask() const; double httpServerNumIoThreadsHwMultiplier() const; @@ -850,8 +1032,12 @@ class SystemConfig : public ConfigBase { uint32_t workerOverloadedThresholdCpuPct() const; + double workerOverloadedThresholdNumQueuedDriversHwMultiplier() const; + uint32_t workerOverloadedCooldownPeriodSec() const; + uint64_t workerOverloadedSecondsToDetachWorker() const; + bool workerOverloadedTaskQueuingEnabled() const; bool mallocMemHeapDumpEnabled() const; @@ -888,6 +1074,8 @@ class SystemConfig : public ConfigBase { bool ssdCacheReadVerificationEnabled() const; + uint64_t ssdCacheMaxEntries() const; + std::string shuffleName() const; bool enableSerializedPageChecksum() const; @@ -960,6 +1148,18 @@ class SystemConfig : public ConfigBase { uint64_t heartbeatFrequencyMs() const; + bool httpClientHttp2Enabled() const; + + uint32_t httpClientHttp2MaxStreamsPerConnection() const; + + uint32_t httpClientHttp2InitialStreamWindow() const; + + uint32_t httpClientHttp2StreamWindow() const; + + uint32_t httpClientHttp2SessionWindow() const; + + bool httpClientConnectionReuseCounterEnabled() const; + std::chrono::duration exchangeMaxErrorDuration() const; std::chrono::duration exchangeRequestTimeoutMs() const; @@ -972,6 +1172,8 @@ class SystemConfig : public ConfigBase { bool exchangeImmediateBufferTransfer() const; + uint64_t exchangeMaxBufferSize() const; + int32_t taskRunTimeSliceMicros() const; bool includeNodeInSpillPath() const; @@ -1012,7 +1214,37 @@ class 
SystemConfig : public ConfigBase { bool orderBySpillEnabled() const; + bool broadcastJoinTableCachingEnabled() const; + + bool exchangeLazyFetchingEnabled() const; + + uint64_t maxSpillBytes() const; + int requestDataSizesMaxWaitSec() const; + + std::string pluginDir() const; + + int32_t exchangeIoEvbViolationThresholdMs() const; + + int32_t httpSrvIoEvbViolationThresholdMs() const; + + uint64_t maxLocalExchangeBufferSize() const; + + uint64_t maxLocalExchangePartitionBufferSize() const; + + bool parallelOutputJoinBuildRowsEnabled() const; + + uint64_t hashProbeBloomFilterPushdownMaxSize() const; + + bool textWriterEnabled() const; + + bool textReaderEnabled() const; + + bool charNToVarcharImplicitCast() const; + + bool enumTypesEnabled() const; + + bool planConsistencyCheckEnabled() const; }; /// Provides access to node properties defined in node.properties file. @@ -1025,6 +1257,8 @@ class NodeConfig : public ConfigBase { static constexpr std::string_view kNodeInternalAddress{ "node.internal-address"}; static constexpr std::string_view kNodeLocation{"node.location"}; + static constexpr std::string_view kNodePrometheusExecutorThreads{ + "node.prometheus.num-executor-threads"}; NodeConfig(); @@ -1034,6 +1268,8 @@ class NodeConfig : public ConfigBase { std::string nodeEnvironment() const; + int prometheusExecutorThreads() const; + std::string nodeId() const; std::string nodeInternalAddress( @@ -1042,19 +1278,4 @@ class NodeConfig : public ConfigBase { std::string nodeLocation() const; }; -/// Used only in the single instance as the source of the initial properties for -/// velox::QueryConfig. Not designed for actual property access during a query -/// run. 
-class BaseVeloxQueryConfig : public ConfigBase { - public: - BaseVeloxQueryConfig(); - - virtual ~BaseVeloxQueryConfig() = default; - - void updateLoadedValues( - std::unordered_map& values) const override; - - static BaseVeloxQueryConfig* instance(); -}; - } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/common/Counters.cpp b/presto-native-execution/presto_cpp/main/common/Counters.cpp index 7bf9e3a8600d2..2704568c28c19 100644 --- a/presto-native-execution/presto_cpp/main/common/Counters.cpp +++ b/presto-native-execution/presto_cpp/main/common/Counters.cpp @@ -30,8 +30,21 @@ void registerPrestoMetrics() { DEFINE_METRIC(kCounterNumHTTPRequest, facebook::velox::StatType::COUNT); DEFINE_METRIC(kCounterNumHTTPRequestError, facebook::velox::StatType::COUNT); DEFINE_METRIC(kCounterHTTPRequestLatencyMs, facebook::velox::StatType::AVG); + DEFINE_HISTOGRAM_METRIC( + kCounterHTTPRequestSizeBytes, + 1 * 1024, // 1KB bucket size + 0, + 5 * 1024 * 1024, // 5MB max + 50, + 90, + 99, + 100); DEFINE_METRIC( kCounterHttpClientNumConnectionsCreated, facebook::velox::StatType::SUM); + DEFINE_METRIC( + kCounterHttpClientConnectionFirstUse, facebook::velox::StatType::SUM); + DEFINE_METRIC( + kCounterHttpClientConnectionReuse, facebook::velox::StatType::SUM); // Tracks http client transaction create delay in range of [0, 30s] with // 30 buckets and reports P50, P90, P99, and P100. 
DEFINE_HISTOGRAM_METRIC( @@ -93,11 +106,14 @@ void registerPrestoMetrics() { DEFINE_METRIC(kCounterOverloaded, facebook::velox::StatType::AVG); DEFINE_METRIC(kCounterNumStuckDrivers, facebook::velox::StatType::AVG); DEFINE_METRIC(kCounterTaskPlannedTimeMs, facebook::velox::StatType::AVG); + DEFINE_METRIC(kCounterOverloadedDurationSec, facebook::velox::StatType::AVG); DEFINE_METRIC( kCounterTotalPartitionedOutputBuffer, facebook::velox::StatType::AVG); DEFINE_METRIC( kCounterPartitionedOutputBufferGetDataLatencyMs, facebook::velox::StatType::AVG); + DEFINE_METRIC( + kCounterWorkerRuntimeUptimeSecs, facebook::velox::StatType::AVG); DEFINE_METRIC(kCounterOsUserCpuTimeMicros, facebook::velox::StatType::AVG); DEFINE_METRIC(kCounterOsSystemCpuTimeMicros, facebook::velox::StatType::AVG); DEFINE_METRIC(kCounterOsNumSoftPageFaults, facebook::velox::StatType::AVG); @@ -117,28 +133,42 @@ void registerPrestoMetrics() { 99, 100); - // Tracks exchange request duration in range of [0, 300s] with - // 300 buckets and reports P50, P90, P99, and P100. + // Tracks exchange request duration in range of [0, 10s] with + // 500 buckets and reports P50, P90, P99, and P100. DEFINE_HISTOGRAM_METRIC( kCounterExchangeRequestDuration, - 1'000, + 20, // 20ms bucket size 0, - 300'000, + 10'000, // 10s max 50, 90, 99, 100); - // Tracks exchange request num of retris in range of [0, 20] with + // Tracks exchange request num of tries in range of [0, 20] with // 20 buckets and reports P50, P90, P99, and P100. DEFINE_HISTOGRAM_METRIC( - kCounterExchangeRequestNumTries, - 1, + kCounterExchangeRequestNumTries, 1, 0, 20, 50, 90, 99, 100); + // Tracks exchange request page size in range of [0, 20MB] with + // 20K buckets and reports P50, P90, P99, and P100. 
+ DEFINE_HISTOGRAM_METRIC( + kCounterExchangeRequestPageSize, + 10 * 1024, // 10KB bucket size 0, - 20, + 20 * 1024 * 1024, // 20MB max 50, 90, 99, 100); + + // Tracks exchange get-data-size request duration in range of [0, 300s] with + // 300 buckets and reports P50, P90, P99, and P100. + DEFINE_HISTOGRAM_METRIC( + kCounterExchangeGetDataSizeDuration, 1'000, 0, 300'000, 50, 90, 99, 100); + // Tracks exchange get-data-size request num of tries in range of [0, 20] with + // 20 buckets and reports P50, P90, P99, and P100. + DEFINE_HISTOGRAM_METRIC( + kCounterExchangeGetDataSizeNumTries, 1, 0, 20, 50, 90, 99, 100); + DEFINE_METRIC(kCounterMemoryPushbackCount, facebook::velox::StatType::COUNT); DEFINE_HISTOGRAM_METRIC( kCounterMemoryPushbackLatencyMs, 10'000, 0, 100'000, 50, 90, 99, 100); @@ -161,6 +191,11 @@ void registerPrestoMetrics() { 99, 100); + DEFINE_METRIC( + kCounterExchangeIoEvbViolation, facebook::velox::StatType::COUNT); + DEFINE_METRIC( + kCounterHttpServerIoEvbViolation, facebook::velox::StatType::COUNT); + // NOTE: Metrics type exporting for thread pool executor counters are in // PeriodicTaskManager because they have dynamic names and report configs. The // following counters have their type exported there: diff --git a/presto-native-execution/presto_cpp/main/common/Counters.h b/presto-native-execution/presto_cpp/main/common/Counters.h index 74d86f588d46a..c6379c575f011 100644 --- a/presto-native-execution/presto_cpp/main/common/Counters.h +++ b/presto-native-execution/presto_cpp/main/common/Counters.h @@ -13,7 +13,7 @@ */ #pragma once -#include +#include // Here we have all the counters presto cpp worker would export. namespace facebook::presto { @@ -22,145 +22,168 @@ namespace facebook::presto { // See velox/common/base/StatsReporter.h for the interface. 
void registerPrestoMetrics(); -constexpr folly::StringPiece kCounterDriverCPUExecutorQueueSize{ +constexpr std::string_view kCounterDriverCPUExecutorQueueSize{ "presto_cpp.driver_cpu_executor_queue_size"}; -constexpr folly::StringPiece kCounterDriverCPUExecutorLatencyMs{ +constexpr std::string_view kCounterDriverCPUExecutorLatencyMs{ "presto_cpp.driver_cpu_executor_latency_ms"}; -constexpr folly::StringPiece kCounterSpillerExecutorQueueSize{ +constexpr std::string_view kCounterSpillerExecutorQueueSize{ "presto_cpp.spiller_executor_queue_size"}; -constexpr folly::StringPiece kCounterSpillerExecutorLatencyMs{ +constexpr std::string_view kCounterSpillerExecutorLatencyMs{ "presto_cpp.spiller_executor_latency_ms"}; -constexpr folly::StringPiece kCounterHTTPExecutorLatencyMs{ +constexpr std::string_view kCounterHTTPExecutorLatencyMs{ "presto_cpp.http_executor_latency_ms"}; -constexpr folly::StringPiece kCounterNumHTTPRequest{ +constexpr std::string_view kCounterNumHTTPRequest{ "presto_cpp.num_http_request"}; -constexpr folly::StringPiece kCounterNumHTTPRequestError{ +constexpr std::string_view kCounterNumHTTPRequestError{ "presto_cpp.num_http_request_error"}; -constexpr folly::StringPiece kCounterHTTPRequestLatencyMs{ +constexpr std::string_view kCounterHTTPRequestLatencyMs{ "presto_cpp.http_request_latency_ms"}; +constexpr std::string_view kCounterHTTPRequestSizeBytes{ + "presto_cpp.http_request_size_bytes"}; -constexpr folly::StringPiece kCounterHttpClientNumConnectionsCreated{ +constexpr std::string_view kCounterHttpClientNumConnectionsCreated{ "presto_cpp.http.client.num_connections_created"}; -constexpr folly::StringPiece kCounterHTTPClientTransactionCreateDelay{ +/// Number of HTTP requests that are the first request on a connection +// (seqNo == 0). +constexpr std::string_view kCounterHttpClientConnectionFirstUse{ + "presto_cpp.http.client.connection_first_use"}; +/// Number of HTTP requests sent on reused connections (seqNo > 0). 
+constexpr std::string_view kCounterHttpClientConnectionReuse{ + "presto_cpp.http.client.connection_reuse"}; +constexpr std::string_view kCounterHTTPClientTransactionCreateDelay{ "presto_cpp.http.client.transaction_create_delay_ms"}; /// Peak number of bytes queued in PrestoExchangeSource waiting for consume. -constexpr folly::StringPiece kCounterExchangeSourcePeakQueuedBytes{ +constexpr std::string_view kCounterExchangeSourcePeakQueuedBytes{ "presto_cpp.exchange_source_peak_queued_bytes"}; -constexpr folly::StringPiece kCounterExchangeRequestDuration{ +constexpr std::string_view kCounterExchangeRequestDuration{ "presto_cpp.exchange.request.duration"}; -constexpr folly::StringPiece kCounterExchangeRequestNumTries{ +constexpr std::string_view kCounterExchangeRequestNumTries{ "presto_cpp.exchange.request.num_tries"}; +constexpr std::string_view kCounterExchangeRequestPageSize{ + "presto_cpp.exchange.request.page_size"}; -constexpr folly::StringPiece kCounterNumQueryContexts{ +constexpr std::string_view kCounterExchangeGetDataSizeDuration{ + "presto_cpp.exchange.get_data_size.duration"}; +constexpr std::string_view kCounterExchangeGetDataSizeNumTries{ + "presto_cpp.exchange.get_data_size.num_tries"}; + +constexpr std::string_view kCounterNumQueryContexts{ "presto_cpp.num_query_contexts"}; /// Export total bytes used by memory manager (in queries' memory pools). 
-constexpr folly::StringPiece kCounterMemoryManagerTotalBytes{ +constexpr std::string_view kCounterMemoryManagerTotalBytes{ "presto_cpp.memory_manager_total_bytes"}; -constexpr folly::StringPiece kCounterNumTasks{"presto_cpp.num_tasks"}; -constexpr folly::StringPiece kCounterNumTasksBytesProcessed{ +constexpr std::string_view kCounterNumTasks{"presto_cpp.num_tasks"}; +constexpr std::string_view kCounterNumTasksBytesProcessed{ "presto_cpp.num_tasks_bytes_processed"}; -constexpr folly::StringPiece kCounterNumTasksRunning{ +constexpr std::string_view kCounterNumTasksRunning{ "presto_cpp.num_tasks_running"}; -constexpr folly::StringPiece kCounterNumTasksFinished{ +constexpr std::string_view kCounterNumTasksFinished{ "presto_cpp.num_tasks_finished"}; -constexpr folly::StringPiece kCounterNumTasksCancelled{ +constexpr std::string_view kCounterNumTasksCancelled{ "presto_cpp.num_tasks_cancelled"}; -constexpr folly::StringPiece kCounterNumTasksAborted{ +constexpr std::string_view kCounterNumTasksAborted{ "presto_cpp.num_tasks_aborted"}; -constexpr folly::StringPiece kCounterNumTasksFailed{ +constexpr std::string_view kCounterNumTasksFailed{ "presto_cpp.num_tasks_failed"}; /// Number of the created but not yet started tasks, including queued tasks. -constexpr folly::StringPiece kCounterNumTasksPlanned{ +constexpr std::string_view kCounterNumTasksPlanned{ "presto_cpp.num_tasks_planned"}; /// Number of the created tasks in the task queue. 
-constexpr folly::StringPiece kCounterNumTasksQueued{ +constexpr std::string_view kCounterNumTasksQueued{ "presto_cpp.num_tasks_queued"}; -constexpr folly::StringPiece kCounterNumZombieVeloxTasks{ +constexpr std::string_view kCounterNumZombieVeloxTasks{ "presto_cpp.num_zombie_velox_tasks"}; -constexpr folly::StringPiece kCounterNumZombiePrestoTasks{ +constexpr std::string_view kCounterNumZombiePrestoTasks{ "presto_cpp.num_zombie_presto_tasks"}; -constexpr folly::StringPiece kCounterNumTasksWithStuckOperator{ +constexpr std::string_view kCounterNumTasksWithStuckOperator{ "presto_cpp.num_tasks_with_stuck_operator"}; -constexpr folly::StringPiece kCounterNumCancelledTasksByStuckDriver{ +constexpr std::string_view kCounterNumCancelledTasksByStuckDriver{ "presto_cpp.num_cancelled_tasks_by_stuck_driver"}; -constexpr folly::StringPiece kCounterNumTasksDeadlock{ +constexpr std::string_view kCounterNumTasksDeadlock{ "presto_cpp.num_tasks_deadlock"}; -constexpr folly::StringPiece kCounterNumTaskManagerLockTimeOut{ +constexpr std::string_view kCounterNumTaskManagerLockTimeOut{ "presto_cpp.num_tasks_manager_lock_timeout"}; -constexpr folly::StringPiece kCounterNumQueuedDrivers{ +constexpr std::string_view kCounterNumQueuedDrivers{ "presto_cpp.num_queued_drivers"}; -constexpr folly::StringPiece kCounterNumOnThreadDrivers{ +constexpr std::string_view kCounterNumOnThreadDrivers{ "presto_cpp.num_on_thread_drivers"}; -constexpr folly::StringPiece kCounterNumSuspendedDrivers{ +constexpr std::string_view kCounterNumSuspendedDrivers{ "presto_cpp.num_suspended_drivers"}; -constexpr folly::StringPiece kCounterNumBlockedWaitForConsumerDrivers{ +constexpr std::string_view kCounterNumBlockedWaitForConsumerDrivers{ "presto_cpp.num_blocked_wait_for_consumer_drivers"}; -constexpr folly::StringPiece kCounterNumBlockedWaitForSplitDrivers{ +constexpr std::string_view kCounterNumBlockedWaitForSplitDrivers{ "presto_cpp.num_blocked_wait_for_split_drivers"}; -constexpr folly::StringPiece 
kCounterNumBlockedWaitForProducerDrivers{ +constexpr std::string_view kCounterNumBlockedWaitForProducerDrivers{ "presto_cpp.num_blocked_wait_for_producer_drivers"}; -constexpr folly::StringPiece kCounterNumBlockedWaitForJoinBuildDrivers{ +constexpr std::string_view kCounterNumBlockedWaitForJoinBuildDrivers{ "presto_cpp.num_blocked_wait_for_join_build_drivers"}; -constexpr folly::StringPiece kCounterNumBlockedWaitForJoinProbeDrivers{ +constexpr std::string_view kCounterNumBlockedWaitForJoinProbeDrivers{ "presto_cpp.num_blocked_wait_for_join_probe_drivers"}; -constexpr folly::StringPiece kCounterNumBlockedWaitForMergeJoinRightSideDrivers{ +constexpr std::string_view kCounterNumBlockedWaitForMergeJoinRightSideDrivers{ "presto_cpp.num_blocked_wait_for_merge_join_right_side_drivers"}; -constexpr folly::StringPiece kCounterNumBlockedWaitForMemoryDrivers{ +constexpr std::string_view kCounterNumBlockedWaitForMemoryDrivers{ "presto_cpp.num_blocked_wait_for_memory_drivers"}; -constexpr folly::StringPiece kCounterNumBlockedWaitForConnectorDrivers{ +constexpr std::string_view kCounterNumBlockedWaitForConnectorDrivers{ "presto_cpp.num_blocked_wait_for_connector_drivers"}; -constexpr folly::StringPiece kCounterNumBlockedYieldDrivers{ +constexpr std::string_view kCounterNumBlockedYieldDrivers{ "presto_cpp.num_blocked_yield_drivers"}; -constexpr folly::StringPiece kCounterNumStuckDrivers{ +constexpr std::string_view kCounterNumStuckDrivers{ "presto_cpp.num_stuck_drivers"}; /// Export 100 if worker is overloaded in terms of memory, 0 otherwise. -constexpr folly::StringPiece kCounterOverloadedMem{"presto_cpp.overloaded_mem"}; +constexpr std::string_view kCounterOverloadedMem{"presto_cpp.overloaded_mem"}; /// Export 100 if worker is overloaded in terms of CPU, 0 otherwise. 
-constexpr folly::StringPiece kCounterOverloadedCpu{"presto_cpp.overloaded_cpu"}; +constexpr std::string_view kCounterOverloadedCpu{"presto_cpp.overloaded_cpu"}; /// Export 100 if worker is overloaded in terms of memory or CPU, 0 otherwise. -constexpr folly::StringPiece kCounterOverloaded{"presto_cpp.overloaded"}; +constexpr std::string_view kCounterOverloaded{"presto_cpp.overloaded"}; /// Worker exports the average time tasks spend in the queue (considered /// planned) in milliseconds. -constexpr folly::StringPiece kCounterTaskPlannedTimeMs{ +constexpr std::string_view kCounterTaskPlannedTimeMs{ "presto_cpp.task_planned_time_ms"}; +/// Exports the current overloaded duration in seconds or 0 if not currently +/// overloaded. +constexpr std::string_view kCounterOverloadedDurationSec{ + "presto_cpp.overloaded_duration_sec"}; /// Number of total OutputBuffer managed by all /// OutputBufferManager -constexpr folly::StringPiece kCounterTotalPartitionedOutputBuffer{ +constexpr std::string_view kCounterTotalPartitionedOutputBuffer{ "presto_cpp.num_partitioned_output_buffer"}; /// Latency in millisecond of the get data call of a /// OutputBufferManager. -constexpr folly::StringPiece kCounterPartitionedOutputBufferGetDataLatencyMs{ +constexpr std::string_view kCounterPartitionedOutputBufferGetDataLatencyMs{ "presto_cpp.partitioned_output_buffer_get_data_latency_ms"}; +/// Worker runtime uptime in seconds after the worker process started. +constexpr std::string_view kCounterWorkerRuntimeUptimeSecs{ + "presto_cpp.worker_runtime_uptime_secs"}; /// ================== OS Counters ================= /// User CPU time of the presto_server process in microsecond since the process /// start. -constexpr folly::StringPiece kCounterOsUserCpuTimeMicros{ +constexpr std::string_view kCounterOsUserCpuTimeMicros{ "presto_cpp.os_user_cpu_time_micros"}; /// System CPU time of the presto_server process in microsecond since the /// process start. 
-constexpr folly::StringPiece kCounterOsSystemCpuTimeMicros{ +constexpr std::string_view kCounterOsSystemCpuTimeMicros{ "presto_cpp.os_system_cpu_time_micros"}; /// Total number of soft page faults of the presto_server process in microsecond /// since the process start. -constexpr folly::StringPiece kCounterOsNumSoftPageFaults{ +constexpr std::string_view kCounterOsNumSoftPageFaults{ "presto_cpp.os_num_soft_page_faults"}; /// Total number of hard page faults of the presto_server process in microsecond /// since the process start. -constexpr folly::StringPiece kCounterOsNumHardPageFaults{ +constexpr std::string_view kCounterOsNumHardPageFaults{ "presto_cpp.os_num_hard_page_faults"}; /// Total number of voluntary context switches in the presto_server process. -constexpr folly::StringPiece kCounterOsNumVoluntaryContextSwitches{ +constexpr std::string_view kCounterOsNumVoluntaryContextSwitches{ "presto_cpp.os_num_voluntary_context_switches"}; /// Total number of involuntary context switches in the presto_server process. -constexpr folly::StringPiece kCounterOsNumForcedContextSwitches{ +constexpr std::string_view kCounterOsNumForcedContextSwitches{ "presto_cpp.os_num_forced_context_switches"}; /// ================== HiveConnector Counters ================== @@ -196,25 +219,31 @@ constexpr std::string_view kCounterThreadPoolNumTotalTasksFormat{ constexpr std::string_view kCounterThreadPoolMaxIdleTimeNsFormat{ "presto_cpp.{}.max_idle_time_ns"}; +/// ================== EVB Counters ==================== +constexpr std::string_view kCounterExchangeIoEvbViolation{ + "presto_cpp.exchange_io_evb_violation_count"}; +constexpr std::string_view kCounterHttpServerIoEvbViolation{ + "presto_cpp.http_server_io_evb_violation_count"}; + /// ================== Memory Pushback Counters ================= /// Number of times memory pushback mechanism is triggered. 
-constexpr folly::StringPiece kCounterMemoryPushbackCount{ +constexpr std::string_view kCounterMemoryPushbackCount{ "presto_cpp.memory_pushback_count"}; /// Latency distribution of each memory pushback run in range of [0, 100s] and /// reports P50, P90, P99, and P100. -constexpr folly::StringPiece kCounterMemoryPushbackLatencyMs{ +constexpr std::string_view kCounterMemoryPushbackLatencyMs{ "presto_cpp.memory_pushback_latency_ms"}; /// Distribution of actual reduction in memory usage achieved by each memory /// pushback attempt. This is to gauge its effectiveness. In range of [0, 15GB] /// with 150 buckets and reports P50, P90, P99, and P100. -constexpr folly::StringPiece kCounterMemoryPushbackReductionBytes{ +constexpr std::string_view kCounterMemoryPushbackReductionBytes{ "presto_cpp.memory_pushback_reduction_bytes"}; /// Distribution of expected reduction in memory usage achieved by each memory /// pushback attempt. This is to gauge its effectiveness. In range of [0, 15GB] /// with 150 buckets and reports P50, P90, P99, and P100. The expected reduction /// can be different as other threads might have allocated memory in the /// meantime. 
-constexpr folly::StringPiece kCounterMemoryPushbackExpectedReductionBytes{ +constexpr std::string_view kCounterMemoryPushbackExpectedReductionBytes{ "presto_cpp.memory_pushback_expected_reduction_bytes"}; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/common/Exception.cpp b/presto-native-execution/presto_cpp/main/common/Exception.cpp index 0005b3c479161..03aa3ddad9232 100644 --- a/presto-native-execution/presto_cpp/main/common/Exception.cpp +++ b/presto-native-execution/presto_cpp/main/common/Exception.cpp @@ -15,10 +15,150 @@ #include "presto_cpp/main/common/Exception.h" namespace facebook::presto { + +VeloxToPrestoExceptionTranslator::VeloxToPrestoExceptionTranslator() { + // Register runtime errors + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kMemCapExceeded, + {.code = 0x00020007, + .name = "EXCEEDED_LOCAL_MEMORY_LIMIT", + .type = protocol::ErrorType::INSUFFICIENT_RESOURCES}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kMemAborted, + {.code = 0x00020000, + .name = "GENERIC_INSUFFICIENT_RESOURCES", + .type = protocol::ErrorType::INSUFFICIENT_RESOURCES}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kSpillLimitExceeded, + {.code = 0x00020006, + .name = "EXCEEDED_SPILL_LIMIT", + .type = protocol::ErrorType::INSUFFICIENT_RESOURCES}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kMemArbitrationFailure, + {.code = 0x00020000, + .name = "MEMORY_ARBITRATION_FAILURE", + .type = protocol::ErrorType::INSUFFICIENT_RESOURCES}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kMemArbitrationTimeout, + {.code = 0x00020000, + .name = "GENERIC_INSUFFICIENT_RESOURCES", + .type = protocol::ErrorType::INSUFFICIENT_RESOURCES}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kMemAllocError, + {.code = 0x00020000, + 
.name = "GENERIC_INSUFFICIENT_RESOURCES", + .type = protocol::ErrorType::INSUFFICIENT_RESOURCES}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kInvalidState, + {.code = 0x00010000, + .name = "GENERIC_INTERNAL_ERROR", + .type = protocol::ErrorType::INTERNAL_ERROR}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kGenericSpillFailure, + {.code = 0x00010023, + .name = "GENERIC_SPILL_FAILURE", + .type = protocol::ErrorType::INTERNAL_ERROR}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kUnreachableCode, + {.code = 0x00010000, + .name = "GENERIC_INTERNAL_ERROR", + .type = protocol::ErrorType::INTERNAL_ERROR}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kNotImplemented, + {.code = 0x00010000, + .name = "GENERIC_INTERNAL_ERROR", + .type = protocol::ErrorType::INTERNAL_ERROR}); + + registerError( + velox::error_source::kErrorSourceRuntime, + velox::error_code::kUnknown, + {.code = 0x00010000, + .name = "GENERIC_INTERNAL_ERROR", + .type = protocol::ErrorType::INTERNAL_ERROR}); + + registerError( + velox::error_source::kErrorSourceRuntime, + presto::error_code::kExceededLocalBroadcastJoinMemoryLimit, + {.code = 0x0002000C, + .name = "EXCEEDED_LOCAL_BROADCAST_JOIN_MEMORY_LIMIT", + .type = protocol::ErrorType::INSUFFICIENT_RESOURCES}); + + // Register user errors + registerError( + velox::error_source::kErrorSourceUser, + velox::error_code::kInvalidArgument, + {.code = 0x00000000, + .name = "GENERIC_USER_ERROR", + .type = protocol::ErrorType::USER_ERROR}); + + registerError( + velox::error_source::kErrorSourceUser, + velox::error_code::kUnsupported, + {.code = 0x0000000D, + .name = "NOT_SUPPORTED", + .type = protocol::ErrorType::USER_ERROR}); + + registerError( + velox::error_source::kErrorSourceUser, + velox::error_code::kUnsupportedInputUncatchable, + {.code = 0x0000000D, + .name = "NOT_SUPPORTED", + .type = 
protocol::ErrorType::USER_ERROR}); + + registerError( + velox::error_source::kErrorSourceUser, + velox::error_code::kArithmeticError, + {.code = 0x00000000, + .name = "GENERIC_USER_ERROR", + .type = protocol::ErrorType::USER_ERROR}); + + registerError( + velox::error_source::kErrorSourceUser, + velox::error_code::kSchemaMismatch, + {.code = 0x00000000, + .name = "GENERIC_USER_ERROR", + .type = protocol::ErrorType::USER_ERROR}); +} + +void VeloxToPrestoExceptionTranslator::registerError( + const std::string& errorSource, + const std::string& errorCode, + const protocol::ErrorCode& prestoErrorCode) { + auto& innerMap = errorMap_[errorSource]; + auto [it, inserted] = innerMap.emplace(errorCode, prestoErrorCode); + VELOX_CHECK( + inserted, + "Duplicate errorCode '{}' for errorSource '{}' is not allowed. " + "Existing mapping: [code={}, name={}, type={}]", + errorCode, + errorSource, + it->second.code, + it->second.name, + static_cast(it->second.type)); +} + protocol::ExecutionFailureInfo VeloxToPrestoExceptionTranslator::translate( - const velox::VeloxException& e) { + const velox::VeloxException& e) const { protocol::ExecutionFailureInfo error; - // Line number must be >= 1 error.errorLocation.lineNumber = e.line() >= 1 ? e.line() : 1; error.errorLocation.columnNumber = 1; error.type = e.exceptionName(); @@ -31,8 +171,6 @@ protocol::ExecutionFailureInfo VeloxToPrestoExceptionTranslator::translate( msg << " " << e.additionalContext(); } error.message = msg.str(); - // Stack trace may not be available if stack trace capturing is disabled or - // rate limited. 
if (e.stackTrace()) { error.stack = e.stackTrace()->toStrVector(); } @@ -40,8 +178,8 @@ protocol::ExecutionFailureInfo VeloxToPrestoExceptionTranslator::translate( const auto& errorSource = e.errorSource(); const auto& errorCode = e.errorCode(); - auto itrErrorCodesMap = translateMap().find(errorSource); - if (itrErrorCodesMap != translateMap().end()) { + auto itrErrorCodesMap = errorMap_.find(errorSource); + if (itrErrorCodesMap != errorMap_.end()) { auto itrErrorCode = itrErrorCodesMap->second.find(errorCode); if (itrErrorCode != itrErrorCodesMap->second.end()) { error.errorCode = itrErrorCode->second; @@ -55,7 +193,7 @@ protocol::ExecutionFailureInfo VeloxToPrestoExceptionTranslator::translate( } protocol::ExecutionFailureInfo VeloxToPrestoExceptionTranslator::translate( - const std::exception& e) { + const std::exception& e) const { protocol::ExecutionFailureInfo error; error.errorLocation.lineNumber = 1; error.errorLocation.columnNumber = 1; @@ -66,4 +204,14 @@ protocol::ExecutionFailureInfo VeloxToPrestoExceptionTranslator::translate( error.message = e.what(); return error; } + +protocol::NativeSidecarFailureInfo toNativeSidecarFailureInfo( + const protocol::ExecutionFailureInfo& failure) { + facebook::presto::protocol::NativeSidecarFailureInfo nativeSidecarFailureInfo; + nativeSidecarFailureInfo.type = failure.type; + nativeSidecarFailureInfo.message = failure.message; + nativeSidecarFailureInfo.stack = failure.stack; + nativeSidecarFailureInfo.errorCode = failure.errorCode; + return nativeSidecarFailureInfo; +} } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/common/Exception.h b/presto-native-execution/presto_cpp/main/common/Exception.h index 0dd52460cbfcb..4ff27e1acb1dd 100644 --- a/presto-native-execution/presto_cpp/main/common/Exception.h +++ b/presto-native-execution/presto_cpp/main/common/Exception.h @@ -13,6 +13,7 @@ */ #pragma once +#include #include #include 
"presto_cpp/presto_protocol/core/presto_protocol_core.h" #include "velox/common/base/VeloxException.h" @@ -27,96 +28,74 @@ struct ExecutionFailureInfo; struct ErrorCode; } // namespace protocol +namespace error_code { +using namespace folly::string_literals; + +/// An error raised when Presto broadcast join exceeds the broadcast size limit. +inline constexpr auto kExceededLocalBroadcastJoinMemoryLimit = + "EXCEEDED_LOCAL_BROADCAST_JOIN_MEMORY_LIMIT"_fs; +} // namespace error_code + +// Exception translator singleton for converting Velox exceptions to Presto +// errors. This follows the same pattern as velox/common/base/StatsReporter.h. +// +// IMPORTANT: folly::Singleton enforces single registration per type. +// - Only ONE registration of VeloxToPrestoExceptionTranslator can exist +// - Duplicate registrations will cause program to fail during static init +// - Extended servers must register a derived class class VeloxToPrestoExceptionTranslator { public: - // Translates to Presto error from Velox exceptions - static protocol::ExecutionFailureInfo translate( - const velox::VeloxException& e); + using ErrorCodeMap = std::unordered_map< + std::string, + std::unordered_map>; - // Translates to Presto error from std::exceptions - static protocol::ExecutionFailureInfo translate(const std::exception& e); + VeloxToPrestoExceptionTranslator(); - private: - static const std::unordered_map< - std::string, - std::unordered_map>& - translateMap() { - static const std::unordered_map< - std::string, - std::unordered_map> - kTranslateMap = { - {velox::error_source::kErrorSourceRuntime, - {{velox::error_code::kMemCapExceeded, - {0x00020007, - "EXCEEDED_LOCAL_MEMORY_LIMIT", - protocol::ErrorType::INSUFFICIENT_RESOURCES}}, - - {velox::error_code::kMemAborted, - {0x00020000, - "GENERIC_INSUFFICIENT_RESOURCES", - protocol::ErrorType::INSUFFICIENT_RESOURCES}}, - - {velox::error_code::kSpillLimitExceeded, - {0x00020006, - "EXCEEDED_SPILL_LIMIT", - 
protocol::ErrorType::INSUFFICIENT_RESOURCES}}, - - {velox::error_code::kMemArbitrationFailure, - {0x00020000, - "MEMORY_ARBITRATION_FAILURE", - protocol::ErrorType::INSUFFICIENT_RESOURCES}}, - - {velox::error_code::kMemArbitrationTimeout, - {0x00020000, - "GENERIC_INSUFFICIENT_RESOURCES", - protocol::ErrorType::INSUFFICIENT_RESOURCES}}, - - {velox::error_code::kMemAllocError, - {0x00020000, - "GENERIC_INSUFFICIENT_RESOURCES", - protocol::ErrorType::INSUFFICIENT_RESOURCES}}, - - {velox::error_code::kInvalidState, - {0x00010000, - "GENERIC_INTERNAL_ERROR", - protocol::ErrorType::INTERNAL_ERROR}}, - - {velox::error_code::kGenericSpillFailure, - {0x00010023, - "GENERIC_SPILL_FAILURE", - protocol::ErrorType::INTERNAL_ERROR}}, - - {velox::error_code::kUnreachableCode, - {0x00010000, - "GENERIC_INTERNAL_ERROR", - protocol::ErrorType::INTERNAL_ERROR}}, - - {velox::error_code::kNotImplemented, - {0x00010000, - "GENERIC_INTERNAL_ERROR", - protocol::ErrorType::INTERNAL_ERROR}}, - - {velox::error_code::kUnknown, - {0x00010000, - "GENERIC_INTERNAL_ERROR", - protocol::ErrorType::INTERNAL_ERROR}}}}, - - {velox::error_source::kErrorSourceUser, - {{velox::error_code::kInvalidArgument, - {0x00000000, - "GENERIC_USER_ERROR", - protocol::ErrorType::USER_ERROR}}, - {velox::error_code::kUnsupported, - {0x0000000D, "NOT_SUPPORTED", protocol::ErrorType::USER_ERROR}}, - {velox::error_code::kUnsupportedInputUncatchable, - {0x0000000D, "NOT_SUPPORTED", protocol::ErrorType::USER_ERROR}}, - {velox::error_code::kArithmeticError, - {0x00000000, - "GENERIC_USER_ERROR", - protocol::ErrorType::USER_ERROR}}}}, - - {velox::error_source::kErrorSourceSystem, {}}}; - return kTranslateMap; + virtual ~VeloxToPrestoExceptionTranslator() = default; + + virtual protocol::ExecutionFailureInfo translate( + const velox::VeloxException& e) const; + + virtual protocol::ExecutionFailureInfo translate( + const std::exception& e) const; + + // For testing purposes only - provides access to the error map + const 
ErrorCodeMap& testingErrorMap() const { + return errorMap_; } + + protected: + void registerError( + const std::string& errorSource, + const std::string& errorCode, + const protocol::ErrorCode& prestoErrorCode); + + ErrorCodeMap errorMap_; }; + +// Global inline function APIs to translate exceptions (returns +// ExecutionFailureInfo) Similar pattern to StatsReporter, but returns a value +// instead of recording +inline protocol::ExecutionFailureInfo translateToPrestoException( + const velox::VeloxException& e) { + const auto translator = + folly::Singleton::try_get_fast(); + VELOX_CHECK_NOT_NULL( + translator, + "VeloxToPrestoExceptionTranslator singleton must be registered"); + return translator->translate(e); +} + +inline protocol::ExecutionFailureInfo translateToPrestoException( + const std::exception& e) { + const auto translator = + folly::Singleton::try_get_fast(); + VELOX_CHECK_NOT_NULL( + translator, + "VeloxToPrestoExceptionTranslator singleton must be registered"); + return translator->translate(e); +} + +protocol::NativeSidecarFailureInfo toNativeSidecarFailureInfo( + const protocol::ExecutionFailureInfo& failure); } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/common/Utils.cpp b/presto-native-execution/presto_cpp/main/common/Utils.cpp index 715eb1874d0d9..6befe209fc0b8 100644 --- a/presto-native-execution/presto_cpp/main/common/Utils.cpp +++ b/presto-native-execution/presto_cpp/main/common/Utils.cpp @@ -14,8 +14,11 @@ #include "presto_cpp/main/common/Utils.h" #include +#include #include +#include #include +#include "velox/common/base/Exceptions.h" #include "velox/common/process/ThreadDebugInfo.h" namespace facebook::presto::util { @@ -31,12 +34,18 @@ DateTime toISOTimestamp(uint64_t timeMilli) { std::shared_ptr createSSLContext( const std::string& clientCertAndKeyPath, - const std::string& ciphers) { + const std::string& ciphers, + bool http2Enabled) { try { auto sslContext = std::make_shared(); 
sslContext->loadCertKeyPairFromFiles( clientCertAndKeyPath.c_str(), clientCertAndKeyPath.c_str()); sslContext->setCiphersOrThrow(ciphers); + if (http2Enabled) { + sslContext->setAdvertisedNextProtocols({"h2", "http/1.1"}); + } else { + sslContext->setAdvertisedNextProtocols({"http/1.1"}); + } return sslContext; } catch (const std::exception& ex) { LOG(FATAL) << fmt::format( @@ -83,4 +92,58 @@ std::string extractMessageBody( } return ret; } + +std::string decompressMessageBody( + const std::vector>& body, + const std::string& contentEncoding) { + try { + // Combine all IOBufs into a single chain + std::unique_ptr combined; + for (const auto& buf : body) { + if (!combined) { + combined = buf->clone(); + } else { + combined->appendToChain(buf->clone()); + } + } + + // Determine compression codec type; Support only ZSTD for now + folly::compression::CodecType codecType; + if (contentEncoding == "zstd") { + codecType = folly::compression::CodecType::ZSTD; + } else { + VELOX_USER_FAIL("Unsupported Content-Encoding: {}", contentEncoding); + } + + // Decompress the data + auto codec = folly::compression::getCodec( + codecType); // getCodec never return nullptr + auto decompressed = codec->uncompress(combined.get()); + + size_t decompressedSize = decompressed->computeChainDataLength(); + + // Convert decompressed IOBuf to string + std::string ret; + ret.resize(decompressedSize); + folly::io::Cursor cursor(decompressed.get()); + cursor.pull(ret.data(), decompressedSize); + + return ret; + } catch (const std::exception& e) { + VELOX_USER_FAIL( + "Failed to decompress request body with {}: {}", + contentEncoding, + e.what()); + } +} + +const std::vector getFunctionNameParts( + const std::string& registeredFunction) { + std::vector parts; + folly::split('.', registeredFunction, parts, true); + VELOX_USER_CHECK( + parts.size() == 3, + fmt::format("Prefix missing for function {}", registeredFunction)); + return parts; +} } // namespace facebook::presto::util diff --git 
a/presto-native-execution/presto_cpp/main/common/Utils.h b/presto-native-execution/presto_cpp/main/common/Utils.h index 48d9a97d691ec..60e0a1a4a32d7 100644 --- a/presto-native-execution/presto_cpp/main/common/Utils.h +++ b/presto-native-execution/presto_cpp/main/common/Utils.h @@ -12,9 +12,9 @@ * limitations under the License. */ #pragma once +#include #include #include -#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" namespace facebook::presto::util { @@ -30,7 +30,8 @@ DateTime toISOTimestamp(uint64_t timeMilli); std::shared_ptr createSSLContext( const std::string& clientCertAndKeyPath, - const std::string& ciphers); + const std::string& ciphers, + bool http2Enabled); /// Returns current process-wide CPU time in nanoseconds. long getProcessCpuTimeNs(); @@ -48,9 +49,21 @@ void installSignalHandler(); std::string extractMessageBody( const std::vector>& body); +/// Decompress message body based on Content-Encoding +/// Throws exception if decompression fails +std::string decompressMessageBody( + const std::vector>& body, + const std::string& contentEncoding); + inline std::string addDefaultNamespacePrefix( const std::string& prestoDefaultNamespacePrefix, const std::string& functionName) { return fmt::format("{}{}", prestoDefaultNamespacePrefix, functionName); } + +/// The keys in velox function maps are of the format +/// `catalog.schema.function_name`. This utility function extracts the +/// three parts, {catalog, schema, function_name}, from the registered function. 
+const std::vector getFunctionNameParts( + const std::string& registeredFunction); } // namespace facebook::presto::util diff --git a/presto-native-execution/presto_cpp/main/common/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/common/tests/CMakeLists.txt index 1761f70856c82..5fcc6b0ca4681 100644 --- a/presto-native-execution/presto_cpp/main/common/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/common/tests/CMakeLists.txt @@ -9,8 +9,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -add_executable(presto_common_test CommonTest.cpp ConfigTest.cpp - BaseVeloxQueryConfigTest.cpp) +add_library(presto_mutable_configs MutableConfigs.cpp) + +target_link_libraries(presto_mutable_configs presto_common velox_file) + +add_executable(presto_common_test CommonTest.cpp ConfigTest.cpp) add_test(presto_common_test presto_common_test) @@ -23,11 +26,11 @@ target_link_libraries( velox_file velox_functions_prestosql velox_function_registry + velox_presto_serializer velox_presto_types velox_window ${RE2} GTest::gtest - GTest::gtest_main) +) -set_property(TARGET presto_common_test PROPERTY JOB_POOL_LINK - presto_link_job_pool) +set_property(TARGET presto_common_test PROPERTY JOB_POOL_LINK presto_link_job_pool) diff --git a/presto-native-execution/presto_cpp/main/common/tests/CommonTest.cpp b/presto-native-execution/presto_cpp/main/common/tests/CommonTest.cpp index 7da469c3b113b..401dc89de9fc5 100644 --- a/presto-native-execution/presto_cpp/main/common/tests/CommonTest.cpp +++ b/presto-native-execution/presto_cpp/main/common/tests/CommonTest.cpp @@ -16,9 +16,17 @@ #include "presto_cpp/main/common/Utils.h" #include "velox/common/base/Exceptions.h" +using namespace facebook; using namespace facebook::velox; using namespace facebook::presto; +namespace { +folly::Singleton + exceptionTranslatorSingleton([]() { + return new 
facebook::presto::VeloxToPrestoExceptionTranslator(); + }); +} // namespace + TEST(VeloxToPrestoExceptionTranslatorTest, exceptionTranslation) { FLAGS_velox_exception_user_stacktrace_enabled = true; for (const bool withContext : {false, true}) { @@ -46,7 +54,7 @@ TEST(VeloxToPrestoExceptionTranslatorTest, exceptionTranslation) { "operator()", "test message", "", - error_code::kArithmeticError, + velox::error_code::kArithmeticError, false); EXPECT_THROW({ throw userException; }, VeloxException); @@ -55,9 +63,12 @@ TEST(VeloxToPrestoExceptionTranslatorTest, exceptionTranslation) { } catch (const VeloxException& e) { EXPECT_EQ(e.exceptionName(), "VeloxUserError"); EXPECT_EQ(e.errorSource(), error_source::kErrorSourceUser); - EXPECT_EQ(e.errorCode(), error_code::kArithmeticError); + EXPECT_EQ(e.errorCode(), velox::error_code::kArithmeticError); - auto failureInfo = VeloxToPrestoExceptionTranslator::translate(e); + auto translator = + folly::Singleton::try_get(); + ASSERT_NE(translator, nullptr); + auto failureInfo = translateToPrestoException(e); EXPECT_EQ(failureInfo.type, e.exceptionName()); EXPECT_EQ(failureInfo.errorLocation.lineNumber, e.line()); EXPECT_EQ(failureInfo.errorCode.name, "GENERIC_USER_ERROR"); @@ -82,7 +93,7 @@ TEST(VeloxToPrestoExceptionTranslatorTest, exceptionTranslation) { "operator()", "test message", "", - error_code::kInvalidState, + velox::error_code::kInvalidState, false); EXPECT_THROW({ throw runtimeException; }, VeloxException); @@ -91,9 +102,12 @@ TEST(VeloxToPrestoExceptionTranslatorTest, exceptionTranslation) { } catch (const VeloxException& e) { EXPECT_EQ(e.exceptionName(), "VeloxRuntimeError"); EXPECT_EQ(e.errorSource(), error_source::kErrorSourceRuntime); - EXPECT_EQ(e.errorCode(), error_code::kInvalidState); + EXPECT_EQ(e.errorCode(), velox::error_code::kInvalidState); - auto failureInfo = VeloxToPrestoExceptionTranslator::translate(e); + auto translator = + folly::Singleton::try_get(); + ASSERT_NE(translator, nullptr); + auto 
failureInfo = translateToPrestoException(e); EXPECT_EQ(failureInfo.type, e.exceptionName()); EXPECT_EQ(failureInfo.errorLocation.lineNumber, e.line()); EXPECT_EQ(failureInfo.errorCode.name, "GENERIC_INTERNAL_ERROR"); @@ -107,9 +121,12 @@ TEST(VeloxToPrestoExceptionTranslatorTest, exceptionTranslation) { } catch (const VeloxException& e) { EXPECT_EQ(e.exceptionName(), "VeloxUserError"); EXPECT_EQ(e.errorSource(), error_source::kErrorSourceUser); - EXPECT_EQ(e.errorCode(), error_code::kInvalidArgument); + EXPECT_EQ(e.errorCode(), velox::error_code::kInvalidArgument); - auto failureInfo = VeloxToPrestoExceptionTranslator::translate(e); + auto translator = + folly::Singleton::try_get(); + ASSERT_NE(translator, nullptr); + auto failureInfo = translateToPrestoException(e); EXPECT_EQ(failureInfo.type, e.exceptionName()); EXPECT_EQ(failureInfo.errorLocation.lineNumber, e.line()); EXPECT_EQ(failureInfo.errorCode.name, "GENERIC_USER_ERROR"); @@ -120,8 +137,10 @@ TEST(VeloxToPrestoExceptionTranslatorTest, exceptionTranslation) { } std::runtime_error stdRuntimeError("Test error message"); - auto failureInfo = - VeloxToPrestoExceptionTranslator::translate((stdRuntimeError)); + auto translator = + folly::Singleton::try_get(); + ASSERT_NE(translator, nullptr); + auto failureInfo = translateToPrestoException((stdRuntimeError)); EXPECT_EQ(failureInfo.type, "std::exception"); EXPECT_EQ(failureInfo.errorLocation.lineNumber, 1); EXPECT_EQ(failureInfo.errorCode.name, "GENERIC_INTERNAL_ERROR"); @@ -129,6 +148,72 @@ TEST(VeloxToPrestoExceptionTranslatorTest, exceptionTranslation) { EXPECT_EQ(failureInfo.errorCode.type, protocol::ErrorType::INTERNAL_ERROR); } +TEST(VeloxToPrestoExceptionTranslatorTest, allErrorCodeTranslations) { + // Test all error codes in the translation map to ensure they translate + // correctly + auto translator = folly::Singleton< + facebook::presto::VeloxToPrestoExceptionTranslator>::try_get(); + ASSERT_NE(translator, nullptr); + const auto& translateMap = 
translator->testingErrorMap(); + + for (const auto& [errorSource, errorCodeMap] : translateMap) { + for (const auto& [errorCode, expectedErrorCode] : errorCodeMap) { + SCOPED_TRACE( + fmt::format( + "errorSource: {}, errorCode: {}", errorSource, errorCode)); + + // Determine the exception type based on error source + if (errorSource == velox::error_source::kErrorSourceRuntime) { + VeloxRuntimeError runtimeException( + "test_file.cpp", + 42, + "testFunction()", + "testExpression", + "test error message", + "", + errorCode, + false); + + auto failureInfo = translator->translate(runtimeException); + + EXPECT_EQ(failureInfo.errorCode.code, expectedErrorCode.code) + << "Error code mismatch for " << errorCode; + EXPECT_EQ(failureInfo.errorCode.name, expectedErrorCode.name) + << "Error name mismatch for " << errorCode; + EXPECT_EQ(failureInfo.errorCode.type, expectedErrorCode.type) + << "Error type mismatch for " << errorCode; + EXPECT_EQ(failureInfo.type, "VeloxRuntimeError"); + EXPECT_EQ(failureInfo.errorLocation.lineNumber, 42); + + } else if (errorSource == velox::error_source::kErrorSourceUser) { + VeloxUserError userException( + "test_file.cpp", + 42, + "testFunction()", + "testExpression", + "test error message", + "", + errorCode, + false); + + auto failureInfo = translator->translate(userException); + + EXPECT_EQ(failureInfo.errorCode.code, expectedErrorCode.code) + << "Error code mismatch for " << errorCode; + EXPECT_EQ(failureInfo.errorCode.name, expectedErrorCode.name) + << "Error name mismatch for " << errorCode; + EXPECT_EQ(failureInfo.errorCode.type, expectedErrorCode.type) + << "Error type mismatch for " << errorCode; + EXPECT_EQ(failureInfo.type, "VeloxUserError"); + EXPECT_EQ(failureInfo.errorLocation.lineNumber, 42); + + } else if (errorSource == velox::error_source::kErrorSourceSystem) { + FAIL(); + } + } + } +} + TEST(UtilsTest, general) { EXPECT_EQ("2021-05-20T19:18:27.001Z", util::toISOTimestamp(1621538307001l)); 
EXPECT_EQ("2021-05-20T19:18:27.000Z", util::toISOTimestamp(1621538307000l)); @@ -144,4 +229,36 @@ TEST(UtilsTest, extractMessageBody) { body.push_back(std::move(iobuf)); auto messageBody = util::extractMessageBody(body); EXPECT_EQ(messageBody, "body1body2body3body4body5"); -} \ No newline at end of file +} + +TEST(UtilsTest, getFunctionNameParts) { + { + auto parts = util::getFunctionNameParts("presto.default.my_function"); + ASSERT_EQ(parts.size(), 3); + EXPECT_EQ(parts[0], "presto"); + EXPECT_EQ(parts[1], "default"); + EXPECT_EQ(parts[2], "my_function"); + } + + { + auto parts = util::getFunctionNameParts("remote.catalog.sum"); + ASSERT_EQ(parts.size(), 3); + EXPECT_EQ(parts[0], "remote"); + EXPECT_EQ(parts[1], "catalog"); + EXPECT_EQ(parts[2], "sum"); + } + + EXPECT_THROW(util::getFunctionNameParts("catalog.function"), VeloxException); + EXPECT_THROW(util::getFunctionNameParts("function"), VeloxException); + EXPECT_THROW( + util::getFunctionNameParts("prefix.catalog.schema.function"), + VeloxException); + EXPECT_THROW(util::getFunctionNameParts(""), VeloxException); + EXPECT_THROW(util::getFunctionNameParts(".."), VeloxException); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::SingletonVault::singleton()->registrationComplete(); + return RUN_ALL_TESTS(); +} diff --git a/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp b/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp index c119480ae4654..76096dad1145a 100644 --- a/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp +++ b/presto-native-execution/presto_cpp/main/common/tests/ConfigTest.cpp @@ -11,9 +11,13 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include + #include #include + +#include +#include + #include "presto_cpp/main/common/ConfigReader.h" #include "presto_cpp/main/common/Configs.h" #include "velox/common/base/Exceptions.h" @@ -212,12 +216,16 @@ TEST_F(ConfigTest, optionalNodeConfigs) { init(config, {{std::string(NodeConfig::kNodeIp), "127.0.0.1"}}); ASSERT_EQ( config.nodeInternalAddress([]() { return "0.0.0.0"; }), "127.0.0.1"); + + init( + config, {{std::string(NodeConfig::kNodePrometheusExecutorThreads), "4"}}); + ASSERT_EQ(config.prometheusExecutorThreads(), 4); } TEST_F(ConfigTest, optionalSystemConfigsWithDefault) { SystemConfig config; init(config, {}); - ASSERT_EQ(config.maxDriversPerTask(), std::thread::hardware_concurrency()); + ASSERT_EQ(config.maxDriversPerTask(), folly::hardware_concurrency()); init(config, {{std::string(SystemConfig::kMaxDriversPerTask), "1024"}}); ASSERT_EQ(config.maxDriversPerTask(), 1024); } diff --git a/presto-native-execution/presto_cpp/main/common/tests/MutableConfigs.cpp b/presto-native-execution/presto_cpp/main/common/tests/MutableConfigs.cpp new file mode 100644 index 0000000000000..cc729e58186ef --- /dev/null +++ b/presto-native-execution/presto_cpp/main/common/tests/MutableConfigs.cpp @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/common/tests/MutableConfigs.h" +#include "presto_cpp/main/common/Configs.h" +#include "velox/common/file/File.h" +#include "velox/common/file/FileSystems.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" + +using namespace facebook::velox; + +namespace facebook::presto::test { + +void setupMutableSystemConfig() { + auto dir = exec::test::TempDirectoryPath::create(); + auto sysConfigFilePath = fmt::format("{}/config.properties", dir->getPath()); + auto fileSystem = filesystems::getFileSystem(sysConfigFilePath, nullptr); + auto sysConfigFile = fileSystem->openFileForWrite(sysConfigFilePath); + sysConfigFile->append(fmt::format("{}=true\n", ConfigBase::kMutableConfig)); + sysConfigFile->append( + fmt::format("{}=4GB\n", SystemConfig::kQueryMaxMemoryPerNode)); + sysConfigFile->close(); + SystemConfig::instance()->initialize(sysConfigFilePath); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/common/tests/MutableConfigs.h b/presto-native-execution/presto_cpp/main/common/tests/MutableConfigs.h new file mode 100644 index 0000000000000..16c3f740a4236 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/common/tests/MutableConfigs.h @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +namespace facebook::presto::test { + +void setupMutableSystemConfig(); + +} diff --git a/presto-native-execution/presto_cpp/main/common/tests/test_json.h b/presto-native-execution/presto_cpp/main/common/tests/test_json.h index 169a48534a2be..e2aaafbb6a53e 100644 --- a/presto-native-execution/presto_cpp/main/common/tests/test_json.h +++ b/presto-native-execution/presto_cpp/main/common/tests/test_json.h @@ -13,10 +13,10 @@ */ #pragma once +#include +#include #include #include -#include -#include #include "presto_cpp/presto_protocol/core/presto_protocol_core.h" @@ -52,7 +52,9 @@ inline std::string slurp(const std::string& path) { return buf.str(); } -inline std::string getDataPath(const std::string& dirUnderFbcode, const std::string& fileName) { +inline std::string getDataPath( + const std::string& dirUnderFbcode, + const std::string& fileName) { std::string currentPath = fs::current_path().c_str(); if (boost::algorithm::ends_with(currentPath, "fbcode")) { return currentPath + dirUnderFbcode + fileName; diff --git a/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt index 6bcd3aaccadb2..29d56535a28ea 100644 --- a/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt @@ -9,13 +9,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-add_library(presto_connectors Registration.cpp PrestoToVeloxConnector.cpp - SystemConnector.cpp) +add_library( + presto_connectors + ClpPrestoToVeloxConnector.cpp + IcebergPrestoToVeloxConnector.cpp + PrestoToVeloxConnectorUtils.cpp + HivePrestoToVeloxConnector.cpp + Registration.cpp + PrestoToVeloxConnector.cpp + SystemConnector.cpp +) if(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) add_subdirectory(arrow_flight) + target_compile_definitions(presto_connectors PUBLIC PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) target_link_libraries(presto_connectors presto_flight_connector) endif() -target_link_libraries(presto_connectors presto_velox_expr_conversion - velox_clp_connector velox_type_fbhive) +if(PRESTO_ENABLE_CUDF) + target_link_libraries(presto_connectors velox_cudf_hive_connector cudf::cudf) +endif() + +target_link_libraries( + presto_connectors + presto_velox_expr_conversion + velox_clp_connector + velox_type_fbhive + velox_tpcds_connector +) + +add_subdirectory(hive) diff --git a/presto-native-execution/presto_cpp/main/connectors/ClpPrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/ClpPrestoToVeloxConnector.cpp new file mode 100644 index 0000000000000..26d6f83f1d950 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/ClpPrestoToVeloxConnector.cpp @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/connectors/ClpPrestoToVeloxConnector.h" +#include "presto_cpp/main/types/TypeParser.h" +#include "velox/connectors/clp/ClpColumnHandle.h" +#include "velox/connectors/clp/ClpConnectorSplit.h" +#include "velox/connectors/clp/ClpTableHandle.h" + +namespace facebook::presto { + +using namespace velox; + +std::unique_ptr +ClpPrestoToVeloxConnector::toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* /*splitContext*/) const { + auto clpSplit = dynamic_cast(connectorSplit); + VELOX_CHECK_NOT_NULL( + clpSplit, "Unexpected split type {}", connectorSplit->_type); + return std::make_unique( + catalogId, + clpSplit->path, + static_cast(clpSplit->type), + clpSplit->kqlQuery); +} + +std::unique_ptr +ClpPrestoToVeloxConnector::toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const { + auto clpColumn = dynamic_cast(column); + VELOX_CHECK_NOT_NULL( + clpColumn, "Unexpected column handle type {}", column->_type); + return std::make_unique( + clpColumn->columnName, + clpColumn->originalColumnName, + typeParser.parse(clpColumn->columnType)); +} + +std::unique_ptr +ClpPrestoToVeloxConnector::toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& /*exprConverter*/, + const TypeParser& /*typeParser*/) const { + auto clpLayout = + std::dynamic_pointer_cast( + tableHandle.connectorTableLayout); + VELOX_CHECK_NOT_NULL( + clpLayout, + "Unexpected layout type {}", + tableHandle.connectorTableLayout->_type); + return std::make_unique( + tableHandle.connectorId, clpLayout->table.schemaTableName.table); +} + +std::unique_ptr +ClpPrestoToVeloxConnector::createConnectorProtocol() const { + return std::make_unique(); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/ClpPrestoToVeloxConnector.h 
b/presto-native-execution/presto_cpp/main/connectors/ClpPrestoToVeloxConnector.h new file mode 100644 index 0000000000000..0e553a4a365c7 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/ClpPrestoToVeloxConnector.h @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" +#include "presto_cpp/presto_protocol/connector/clp/ClpConnectorProtocol.h" +#include "presto_cpp/presto_protocol/core/ConnectorProtocol.h" + +namespace facebook::presto { + +class ClpPrestoToVeloxConnector final : public PrestoToVeloxConnector { + public: + explicit ClpPrestoToVeloxConnector(std::string connectorName) + : PrestoToVeloxConnector(std::move(connectorName)) {} + + std::unique_ptr toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const final; + + std::unique_ptr toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const final; + + std::unique_ptr toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) const final; + + std::unique_ptr createConnectorProtocol() + const final; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/HivePrestoToVeloxConnector.cpp 
b/presto-native-execution/presto_cpp/main/connectors/HivePrestoToVeloxConnector.cpp new file mode 100644 index 0000000000000..e464614703fcf --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/HivePrestoToVeloxConnector.cpp @@ -0,0 +1,498 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/connectors/HivePrestoToVeloxConnector.h" + +#include "presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "presto_cpp/main/types/TypeParser.h" +#include "presto_cpp/presto_protocol/connector/hive/HiveConnectorProtocol.h" + +#include +#include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/HiveConnectorSplit.h" +#include "velox/connectors/hive/HiveDataSink.h" +#include "velox/connectors/hive/TableHandle.h" +#include "velox/type/Filter.h" + +namespace facebook::presto { +using namespace velox; + +namespace { + +connector::hive::LocationHandle::TableType toTableType( + protocol::hive::TableType tableType) { + switch (tableType) { + case protocol::hive::TableType::NEW: + // Temporary tables are written and read by the SPI in a single pipeline. + // So they can be treated as New. They do not require Append or Overwrite + // semantics as applicable for regular tables. 
+ case protocol::hive::TableType::TEMPORARY: + return connector::hive::LocationHandle::TableType::kNew; + case protocol::hive::TableType::EXISTING: + return connector::hive::LocationHandle::TableType::kExisting; + default: + VELOX_UNSUPPORTED("Unsupported table type: {}.", toJsonString(tableType)); + } +} + +std::shared_ptr toLocationHandle( + const protocol::hive::LocationHandle& locationHandle) { + return std::make_shared( + locationHandle.targetPath, + locationHandle.writePath, + toTableType(locationHandle.tableType)); +} + +velox::connector::hive::HiveBucketProperty::Kind toHiveBucketPropertyKind( + protocol::hive::BucketFunctionType bucketFuncType) { + switch (bucketFuncType) { + case protocol::hive::BucketFunctionType::PRESTO_NATIVE: + return velox::connector::hive::HiveBucketProperty::Kind::kPrestoNative; + case protocol::hive::BucketFunctionType::HIVE_COMPATIBLE: + return velox::connector::hive::HiveBucketProperty::Kind::kHiveCompatible; + default: + VELOX_USER_FAIL( + "Unknown hive bucket function: {}", toJsonString(bucketFuncType)); + } +} + +dwio::common::FileFormat toFileFormat( + const protocol::hive::HiveStorageFormat storageFormat, + const char* usage) { + switch (storageFormat) { + case protocol::hive::HiveStorageFormat::DWRF: + return dwio::common::FileFormat::DWRF; + case protocol::hive::HiveStorageFormat::PARQUET: + return dwio::common::FileFormat::PARQUET; + case protocol::hive::HiveStorageFormat::ALPHA: + // This has been renamed in Velox from ALPHA to NIMBLE. 
+ return dwio::common::FileFormat::NIMBLE; + case protocol::hive::HiveStorageFormat::TEXTFILE: + return dwio::common::FileFormat::TEXT; + default: + VELOX_UNSUPPORTED( + "Unsupported file format in {}: {}.", + usage, + toJsonString(storageFormat)); + } +} + +std::vector stringToTypes( + const std::shared_ptr>& typeStrings, + const TypeParser& typeParser) { + std::vector types; + types.reserve(typeStrings->size()); + for (const auto& typeString : *typeStrings) { + types.push_back(stringToType(typeString, typeParser)); + } + return types; +} + +core::SortOrder toSortOrder(protocol::hive::Order order) { + switch (order) { + case protocol::hive::Order::ASCENDING: + return core::SortOrder(true, true); + case protocol::hive::Order::DESCENDING: + return core::SortOrder(false, false); + default: + VELOX_USER_FAIL("Unknown sort order: {}", toJsonString(order)); + } +} + +std::shared_ptr toHiveSortingColumn( + const protocol::hive::SortingColumn& sortingColumn) { + return std::make_shared( + sortingColumn.columnName, toSortOrder(sortingColumn.order)); +} + +std::vector> +toHiveSortingColumns( + const protocol::List& sortedBy) { + std::vector> + sortingColumns; + sortingColumns.reserve(sortedBy.size()); + for (const auto& sortingColumn : sortedBy) { + sortingColumns.push_back(toHiveSortingColumn(sortingColumn)); + } + return sortingColumns; +} + +std::shared_ptr +toHiveBucketProperty( + const std::vector>& + inputColumns, + const std::shared_ptr& bucketProperty, + const TypeParser& typeParser) { + if (bucketProperty == nullptr) { + return nullptr; + } + + VELOX_USER_CHECK_GT( + bucketProperty->bucketCount, 0, "Bucket count must be a positive value"); + + VELOX_USER_CHECK( + !bucketProperty->bucketedBy.empty(), + "Bucketed columns must be set: {}", + toJsonString(*bucketProperty)); + + const velox::connector::hive::HiveBucketProperty::Kind kind = + toHiveBucketPropertyKind(bucketProperty->bucketFunctionType); + std::vector bucketedTypes; + if (kind == + 
velox::connector::hive::HiveBucketProperty::Kind::kHiveCompatible) { + VELOX_USER_CHECK_NULL( + bucketProperty->types, + "Unexpected bucketed types set for hive compatible bucket function: {}", + toJsonString(*bucketProperty)); + bucketedTypes.reserve(bucketProperty->bucketedBy.size()); + for (const auto& bucketedColumn : bucketProperty->bucketedBy) { + TypePtr bucketedType{nullptr}; + for (const auto& inputColumn : inputColumns) { + if (inputColumn->name() != bucketedColumn) { + continue; + } + VELOX_USER_CHECK_NOT_NULL(inputColumn->hiveType()); + bucketedType = inputColumn->hiveType(); + break; + } + VELOX_USER_CHECK_NOT_NULL( + bucketedType, "Bucketed column {} not found", bucketedColumn); + bucketedTypes.push_back(std::move(bucketedType)); + } + } else { + VELOX_USER_CHECK_EQ( + bucketProperty->types->size(), + bucketProperty->bucketedBy.size(), + "Bucketed types is not set properly for presto native bucket function: {}", + toJsonString(*bucketProperty)); + bucketedTypes = stringToTypes(bucketProperty->types, typeParser); + } + + const auto sortedBy = toHiveSortingColumns(bucketProperty->sortedBy); + + return std::make_shared( + toHiveBucketPropertyKind(bucketProperty->bucketFunctionType), + bucketProperty->bucketCount, + bucketProperty->bucketedBy, + bucketedTypes, + sortedBy); +} + +std::unique_ptr +toVeloxHiveColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) { + auto* hiveColumn = + dynamic_cast(column); + VELOX_CHECK_NOT_NULL( + hiveColumn, "Unexpected column handle type {}", column->_type); + velox::type::fbhive::HiveTypeParser hiveTypeParser; + // TODO(spershin): Should we pass something different than 'typeSignature' + // to 'hiveType' argument of the 'HiveColumnHandle' constructor? 
+ return std::make_unique( + hiveColumn->name, + toHiveColumnType(hiveColumn->columnType), + stringToType(hiveColumn->typeSignature, typeParser), + hiveTypeParser.parse(hiveColumn->hiveType), + toRequiredSubfields(hiveColumn->requiredSubfields)); +} + +velox::connector::hive::HiveBucketConversion toVeloxBucketConversion( + const protocol::hive::BucketConversion& bucketConversion) { + velox::connector::hive::HiveBucketConversion veloxBucketConversion; + // Current table bucket count (new). + veloxBucketConversion.tableBucketCount = bucketConversion.tableBucketCount; + // Partition bucket count (old). + veloxBucketConversion.partitionBucketCount = + bucketConversion.partitionBucketCount; + TypeParser typeParser; + for (const auto& column : bucketConversion.bucketColumnHandles) { + // Columns used as bucket input. + veloxBucketConversion.bucketColumnHandles.push_back( + toVeloxHiveColumnHandle(&column, typeParser)); + } + return veloxBucketConversion; +} + +dwio::common::FileFormat toVeloxFileFormat( + const presto::protocol::hive::StorageFormat& format) { + if (format.inputFormat == "com.facebook.hive.orc.OrcInputFormat") { + return dwio::common::FileFormat::DWRF; + } else if ( + format.inputFormat == "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat") { + return dwio::common::FileFormat::ORC; + } else if ( + format.inputFormat == + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat") { + return dwio::common::FileFormat::PARQUET; + } else if (format.inputFormat == "org.apache.hadoop.mapred.TextInputFormat") { + if (format.serDe == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { + return dwio::common::FileFormat::TEXT; + } else if (format.serDe == "org.apache.hive.hcatalog.data.JsonSerDe") { + return dwio::common::FileFormat::JSON; + } + } else if ( + format.inputFormat == + "org.apache.hadoop.hive.ql.io.SymlinkTextInputFormat") { + if (format.serDe == + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe") { + return 
dwio::common::FileFormat::PARQUET; + } + } else if (format.inputFormat == "com.facebook.alpha.AlphaInputFormat") { + // ALPHA has been renamed in Velox to NIMBLE. + return dwio::common::FileFormat::NIMBLE; + } + VELOX_UNSUPPORTED( + "Unsupported file format: {} {}", format.inputFormat, format.serDe); +} + +} // namespace + +std::unique_ptr +HivePrestoToVeloxConnector::toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const { + auto hiveSplit = + dynamic_cast(connectorSplit); + VELOX_CHECK_NOT_NULL( + hiveSplit, "Unexpected split type {}", connectorSplit->_type); + std::unordered_map> partitionKeys; + for (const auto& entry : hiveSplit->partitionKeys) { + partitionKeys.emplace( + entry.name, + entry.value == nullptr ? std::nullopt + : std::optional{*entry.value}); + } + std::unordered_map customSplitInfo; + for (const auto& [key, value] : hiveSplit->fileSplit.customSplitInfo) { + customSplitInfo[key] = value; + } + std::shared_ptr extraFileInfo; + if (hiveSplit->fileSplit.extraFileInfo) { + extraFileInfo = std::make_shared( + velox::encoding::Base64::decode(*hiveSplit->fileSplit.extraFileInfo)); + } + std::unordered_map serdeParameters; + serdeParameters.reserve(hiveSplit->storage.serdeParameters.size()); + for (const auto& [key, value] : hiveSplit->storage.serdeParameters) { + serdeParameters[key] = value; + } + std::unordered_map infoColumns = { + {"$path", hiveSplit->fileSplit.path}, + {"$file_size", std::to_string(hiveSplit->fileSplit.fileSize)}, + {"$file_modified_time", + std::to_string(hiveSplit->fileSplit.fileModifiedTime)}, + }; + if (hiveSplit->tableBucketNumber) { + infoColumns["$bucket"] = std::to_string(*hiveSplit->tableBucketNumber); + } + auto veloxSplit = + std::make_unique( + catalogId, + hiveSplit->fileSplit.path, + toVeloxFileFormat(hiveSplit->storage.storageFormat), + hiveSplit->fileSplit.start, + hiveSplit->fileSplit.length, + partitionKeys, + 
hiveSplit->tableBucketNumber + ? std::optional(*hiveSplit->tableBucketNumber) + : std::nullopt, + customSplitInfo, + extraFileInfo, + serdeParameters, + hiveSplit->splitWeight, + splitContext->cacheable, + infoColumns); + if (hiveSplit->bucketConversion) { + VELOX_CHECK_NOT_NULL(hiveSplit->tableBucketNumber); + veloxSplit->bucketConversion = + toVeloxBucketConversion(*hiveSplit->bucketConversion); + } + return veloxSplit; +} + +std::unique_ptr +HivePrestoToVeloxConnector::toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const { + return toVeloxHiveColumnHandle(column, typeParser); +} + +std::unique_ptr +HivePrestoToVeloxConnector::toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) const { + auto hiveLayout = + std::dynamic_pointer_cast( + tableHandle.connectorTableLayout); + VELOX_CHECK_NOT_NULL( + hiveLayout, + "Unexpected layout type {}", + tableHandle.connectorTableLayout->_type); + + std::unordered_set columnNames; + std::vector columnHandles; + for (const auto& entry : hiveLayout->partitionColumns) { + if (columnNames.emplace(entry.name).second) { + columnHandles.emplace_back( + std::dynamic_pointer_cast( + std::shared_ptr(toVeloxColumnHandle(&entry, typeParser)))); + } + } + + // Add synthesized columns to the TableScanNode columnHandles as well. + for (const auto& entry : hiveLayout->predicateColumns) { + if (columnNames.emplace(entry.second.name).second) { + columnHandles.emplace_back( + std::dynamic_pointer_cast( + std::shared_ptr(toVeloxColumnHandle(&entry.second, typeParser)))); + } + } + + auto hiveTableHandle = + std::dynamic_pointer_cast( + tableHandle.connectorHandle); + VELOX_CHECK_NOT_NULL( + hiveTableHandle, + "Unexpected table handle type {}", + tableHandle.connectorHandle->_type); + + // Use fully qualified name if available. + std::string tableName = hiveTableHandle->schemaName.empty() + ? 
hiveTableHandle->tableName + : fmt::format( + "{}.{}", hiveTableHandle->schemaName, hiveTableHandle->tableName); + + return toHiveTableHandle( + hiveLayout->domainPredicate, + hiveLayout->remainingPredicate, + hiveLayout->pushdownFilterEnabled, + tableName, + hiveLayout->dataColumns, + tableHandle, + columnHandles, + hiveLayout->tableParameters, + exprConverter, + typeParser); +} + +std::unique_ptr +HivePrestoToVeloxConnector::toVeloxInsertTableHandle( + const protocol::CreateHandle* createHandle, + const TypeParser& typeParser) const { + auto hiveOutputTableHandle = + std::dynamic_pointer_cast( + createHandle->handle.connectorHandle); + VELOX_CHECK_NOT_NULL( + hiveOutputTableHandle, + "Unexpected output table handle type {}", + createHandle->handle.connectorHandle->_type); + bool isPartitioned{false}; + const auto inputColumns = toHiveColumns( + hiveOutputTableHandle->inputColumns, typeParser, isPartitioned); + return std::make_unique( + inputColumns, + toLocationHandle(hiveOutputTableHandle->locationHandle), + toFileFormat(hiveOutputTableHandle->actualStorageFormat, "TableWrite"), + toHiveBucketProperty( + inputColumns, hiveOutputTableHandle->bucketProperty, typeParser), + std::optional( + toFileCompressionKind(hiveOutputTableHandle->compressionCodec))); +} + +std::unique_ptr +HivePrestoToVeloxConnector::toVeloxInsertTableHandle( + const protocol::InsertHandle* insertHandle, + const TypeParser& typeParser) const { + auto hiveInsertTableHandle = + std::dynamic_pointer_cast( + insertHandle->handle.connectorHandle); + VELOX_CHECK_NOT_NULL( + hiveInsertTableHandle, + "Unexpected insert table handle type {}", + insertHandle->handle.connectorHandle->_type); + bool isPartitioned{false}; + const auto inputColumns = toHiveColumns( + hiveInsertTableHandle->inputColumns, typeParser, isPartitioned); + + const auto table = hiveInsertTableHandle->pageSinkMetadata.table; + VELOX_USER_CHECK_NOT_NULL(table, "Table must not be null for insert query"); + return std::make_unique( + 
inputColumns, + toLocationHandle(hiveInsertTableHandle->locationHandle), + toFileFormat(hiveInsertTableHandle->actualStorageFormat, "TableWrite"), + toHiveBucketProperty( + inputColumns, hiveInsertTableHandle->bucketProperty, typeParser), + std::optional( + toFileCompressionKind(hiveInsertTableHandle->compressionCodec)), + std::unordered_map( + table->storage.serdeParameters.begin(), + table->storage.serdeParameters.end())); +} + +std::vector> +HivePrestoToVeloxConnector::toHiveColumns( + const protocol::List& inputColumns, + const TypeParser& typeParser, + bool& hasPartitionColumn) const { + hasPartitionColumn = false; + std::vector> + hiveColumns; + hiveColumns.reserve(inputColumns.size()); + for (const auto& columnHandle : inputColumns) { + hasPartitionColumn |= + columnHandle.columnType == protocol::hive::ColumnType::PARTITION_KEY; + hiveColumns.emplace_back( + std::dynamic_pointer_cast( + std::shared_ptr(toVeloxColumnHandle(&columnHandle, typeParser)))); + } + return hiveColumns; +} + +std::unique_ptr +HivePrestoToVeloxConnector::createVeloxPartitionFunctionSpec( + const protocol::ConnectorPartitioningHandle* partitioningHandle, + const std::vector& bucketToPartition, + const std::vector& channels, + const std::vector& constValues, + bool& effectivelyGather) const { + auto hivePartitioningHandle = + dynamic_cast( + partitioningHandle); + VELOX_CHECK_NOT_NULL( + hivePartitioningHandle, + "Unexpected partitioning handle type {}", + partitioningHandle->_type); + VELOX_USER_CHECK( + hivePartitioningHandle->bucketFunctionType == + protocol::hive::BucketFunctionType::HIVE_COMPATIBLE, + "Unsupported Hive bucket function type: {}", + toJsonString(hivePartitioningHandle->bucketFunctionType)); + effectivelyGather = hivePartitioningHandle->bucketCount == 1; + return std::make_unique( + hivePartitioningHandle->bucketCount, + bucketToPartition, + channels, + constValues); +} + +std::unique_ptr +HivePrestoToVeloxConnector::createConnectorProtocol() const { + return 
std::make_unique(); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/HivePrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/HivePrestoToVeloxConnector.h new file mode 100644 index 0000000000000..fdbc4fc6d4f6c --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/HivePrestoToVeloxConnector.h @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.h" +#include "presto_cpp/presto_protocol/core/ConnectorProtocol.h" +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/TableHandle.h" +#include "velox/core/PlanNode.h" +#include "velox/vector/ComplexVector.h" + +namespace facebook::presto { + +class HivePrestoToVeloxConnector final : public PrestoToVeloxConnector { + public: + explicit HivePrestoToVeloxConnector(std::string connectorName) + : PrestoToVeloxConnector(std::move(connectorName)) {} + + std::unique_ptr toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const final; + + std::unique_ptr toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const final; + + std::unique_ptr 
toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) const final; + + std::unique_ptr + toVeloxInsertTableHandle( + const protocol::CreateHandle* createHandle, + const TypeParser& typeParser) const final; + + std::unique_ptr + toVeloxInsertTableHandle( + const protocol::InsertHandle* insertHandle, + const TypeParser& typeParser) const final; + + std::unique_ptr + createVeloxPartitionFunctionSpec( + const protocol::ConnectorPartitioningHandle* partitioningHandle, + const std::vector& bucketToPartition, + const std::vector& channels, + const std::vector& constValues, + bool& effectivelyGather) const final; + + std::unique_ptr createConnectorProtocol() + const final; + + private: + std::vector> + toHiveColumns( + const protocol::List& inputColumns, + const TypeParser& typeParser, + bool& hasPartitionColumn) const; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.cpp new file mode 100644 index 0000000000000..b6ad329f0e5ae --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.cpp @@ -0,0 +1,387 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.h" + +#include "presto_cpp/presto_protocol/connector/iceberg/IcebergConnectorProtocol.h" +#include "velox/connectors/hive/iceberg/IcebergDataSink.h" +#include "velox/connectors/hive/iceberg/IcebergSplit.h" +#include "velox/type/fbhive/HiveTypeParser.h" + +namespace facebook::presto { + +namespace { + +velox::connector::hive::iceberg::FileContent toVeloxFileContent( + const presto::protocol::iceberg::FileContent content) { + if (content == protocol::iceberg::FileContent::DATA) { + return velox::connector::hive::iceberg::FileContent::kData; + } else if (content == protocol::iceberg::FileContent::POSITION_DELETES) { + return velox::connector::hive::iceberg::FileContent::kPositionalDeletes; + } + VELOX_UNSUPPORTED("Unsupported file content: {}", fmt::underlying(content)); +} + +velox::dwio::common::FileFormat toVeloxFileFormat( + const presto::protocol::iceberg::FileFormat format) { + if (format == protocol::iceberg::FileFormat::ORC) { + return velox::dwio::common::FileFormat::ORC; + } else if (format == protocol::iceberg::FileFormat::PARQUET) { + return velox::dwio::common::FileFormat::PARQUET; + } + VELOX_UNSUPPORTED("Unsupported file format: {}", fmt::underlying(format)); +} + +std::unique_ptr toIcebergTableHandle( + const protocol::TupleDomain& domainPredicate, + const std::shared_ptr& remainingPredicate, + bool isPushdownFilterEnabled, + const std::string& tableName, + const protocol::List& dataColumns, + const protocol::TableHandle& tableHandle, + const std::vector& + columnHandles, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) { + velox::common::SubfieldFilters subfieldFilters; + auto domains = domainPredicate.domains; + for (const auto& domain : *domains) { + auto filter = domain.second; + subfieldFilters[velox::common::Subfield(domain.first)] = + toFilter(domain.second, exprConverter, 
typeParser); + } + + auto remainingFilter = exprConverter.toVeloxExpr(remainingPredicate); + if (auto constant = + std::dynamic_pointer_cast( + remainingFilter)) { + bool value = constant->value().value(); + VELOX_CHECK(value, "Unexpected always-false remaining predicate"); + + // Use null for always-true filter. + remainingFilter = nullptr; + } + + velox::RowTypePtr finalDataColumns; + if (!dataColumns.empty()) { + std::vector names; + std::vector types; + velox::type::fbhive::HiveTypeParser hiveTypeParser; + names.reserve(dataColumns.size()); + types.reserve(dataColumns.size()); + for (auto& column : dataColumns) { + // For iceberg, the column name should be consistent with + // names in iceberg manifest file. The names in iceberg + // manifest file are consistent with the field names in + // parquet data file. + names.emplace_back(column.name); + auto parsedType = hiveTypeParser.parse(column.type); + // The type from the metastore may have upper case letters + // in field names, convert them all to lower case to be + // compatible with Presto. 
+ types.push_back(VELOX_DYNAMIC_TYPE_DISPATCH( + fieldNamesToLowerCase, parsedType->kind(), parsedType)); + } + finalDataColumns = ROW(std::move(names), std::move(types)); + } + + return std::make_unique( + tableHandle.connectorId, + tableName, + isPushdownFilterEnabled, + std::move(subfieldFilters), + remainingFilter, + finalDataColumns, + std::unordered_map{}, + columnHandles); +} + +velox::connector::hive::iceberg::IcebergPartitionSpec::Field +toVeloxIcebergPartitionField( + const protocol::iceberg::IcebergPartitionField& field, + const TypeParser& typeParser, + const protocol::iceberg::PrestoIcebergSchema& schema) { + std::string type; + for (const auto& column : schema.columns) { + if (column.name == field.name) { + type = column.prestoType; + break; + } + } + + VELOX_USER_CHECK( + !type.empty(), + "Partition column not found in table schema: {}", + field.name); + + return velox::connector::hive::iceberg::IcebergPartitionSpec::Field{ + field.name, + stringToType(type, typeParser), + static_cast( + field.transform), + field.parameter ? 
*field.parameter : std::optional()}; +} + +std::unique_ptr +toVeloxIcebergPartitionSpec( + const protocol::iceberg::PrestoIcebergPartitionSpec& spec, + const TypeParser& typeParser) { + std::vector + fields; + fields.reserve(spec.fields.size()); + for (const auto& field : spec.fields) { + fields.emplace_back( + toVeloxIcebergPartitionField(field, typeParser, spec.schema)); + } + return std::make_unique< + velox::connector::hive::iceberg::IcebergPartitionSpec>( + spec.specId, fields); +} + +} // namespace + +std::unique_ptr +IcebergPrestoToVeloxConnector::toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const { + auto icebergSplit = + dynamic_cast(connectorSplit); + VELOX_CHECK_NOT_NULL( + icebergSplit, "Unexpected split type {}", connectorSplit->_type); + + std::unordered_map> partitionKeys; + for (const auto& entry : icebergSplit->partitionKeys) { + partitionKeys.emplace( + entry.second.name, + entry.second.value == nullptr + ? 
std::nullopt + : std::optional{*entry.second.value}); + } + + std::unordered_map customSplitInfo; + customSplitInfo["table_format"] = "hive-iceberg"; + + std::vector deletes; + deletes.reserve(icebergSplit->deletes.size()); + for (const auto& deleteFile : icebergSplit->deletes) { + std::unordered_map lowerBounds( + deleteFile.lowerBounds.begin(), deleteFile.lowerBounds.end()); + + std::unordered_map upperBounds( + deleteFile.upperBounds.begin(), deleteFile.upperBounds.end()); + + velox::connector::hive::iceberg::IcebergDeleteFile icebergDeleteFile( + toVeloxFileContent(deleteFile.content), + deleteFile.path, + toVeloxFileFormat(deleteFile.format), + deleteFile.recordCount, + deleteFile.fileSizeInBytes, + std::vector(deleteFile.equalityFieldIds), + lowerBounds, + upperBounds); + + deletes.emplace_back(icebergDeleteFile); + } + + std::unordered_map infoColumns = { + {"$data_sequence_number", + std::to_string(icebergSplit->dataSequenceNumber)}, + {"$path", icebergSplit->path}}; + + return std::make_unique( + catalogId, + icebergSplit->path, + toVeloxFileFormat(icebergSplit->fileFormat), + icebergSplit->start, + icebergSplit->length, + partitionKeys, + std::nullopt, + customSplitInfo, + nullptr, + splitContext->cacheable, + deletes, + infoColumns); +} + +std::unique_ptr +IcebergPrestoToVeloxConnector::toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const { + auto icebergColumn = + dynamic_cast(column); + VELOX_CHECK_NOT_NULL( + icebergColumn, "Unexpected column handle type {}", column->_type); + // TODO(imjalpreet): Modify 'hiveType' argument of the 'HiveColumnHandle' + // constructor similar to how Hive Connector is handling for bucketing + velox::type::fbhive::HiveTypeParser hiveTypeParser; + auto type = stringToType(icebergColumn->type, typeParser); + velox::connector::hive::HiveColumnHandle::ColumnParseParameters + columnParseParameters; + if (type->isDate()) { + columnParseParameters.partitionDateValueFormat = 
velox::connector::hive:: + HiveColumnHandle::ColumnParseParameters::kDaysSinceEpoch; + } + return std::make_unique( + icebergColumn->columnIdentity.name, + toHiveColumnType(icebergColumn->columnType), + type, + type, + toRequiredSubfields(icebergColumn->requiredSubfields), + columnParseParameters); +} + +std::unique_ptr +IcebergPrestoToVeloxConnector::toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) const { + auto icebergLayout = std::dynamic_pointer_cast< + const protocol::iceberg::IcebergTableLayoutHandle>( + tableHandle.connectorTableLayout); + VELOX_CHECK_NOT_NULL( + icebergLayout, + "Unexpected layout type {}", + tableHandle.connectorTableLayout->_type); + + std::unordered_set columnNames; + std::vector columnHandles; + for (const auto& entry : icebergLayout->partitionColumns) { + if (columnNames.emplace(entry.columnIdentity.name).second) { + columnHandles.emplace_back( + std::dynamic_pointer_cast< + const velox::connector::hive::HiveColumnHandle>( + std::shared_ptr(toVeloxColumnHandle(&entry, typeParser)))); + } + } + + // Add synthesized columns to the TableScanNode columnHandles as well. + for (const auto& entry : icebergLayout->predicateColumns) { + if (columnNames.emplace(entry.second.columnIdentity.name).second) { + columnHandles.emplace_back( + std::dynamic_pointer_cast< + const velox::connector::hive::HiveColumnHandle>( + std::shared_ptr(toVeloxColumnHandle(&entry.second, typeParser)))); + } + } + + auto icebergTableHandle = + std::dynamic_pointer_cast( + tableHandle.connectorHandle); + VELOX_CHECK_NOT_NULL( + icebergTableHandle, + "Unexpected table handle type {}", + tableHandle.connectorHandle->_type); + + // Use fully qualified name if available. + std::string tableName = icebergTableHandle->schemaName.empty() + ? 
icebergTableHandle->icebergTableName.tableName + : fmt::format( + "{}.{}", + icebergTableHandle->schemaName, + icebergTableHandle->icebergTableName.tableName); + + return toIcebergTableHandle( + icebergLayout->domainPredicate, + icebergLayout->remainingPredicate, + icebergLayout->pushdownFilterEnabled, + tableName, + icebergLayout->dataColumns, + tableHandle, + columnHandles, + exprConverter, + typeParser); +} + +std::unique_ptr +IcebergPrestoToVeloxConnector::createConnectorProtocol() const { + return std::make_unique(); +} + +std::unique_ptr +IcebergPrestoToVeloxConnector::toVeloxInsertTableHandle( + const protocol::CreateHandle* createHandle, + const TypeParser& typeParser) const { + auto icebergOutputTableHandle = + std::dynamic_pointer_cast( + createHandle->handle.connectorHandle); + + VELOX_CHECK_NOT_NULL( + icebergOutputTableHandle, + "Unexpected output table handle type {}", + createHandle->handle.connectorHandle->_type); + + const auto inputColumns = + toHiveColumns(icebergOutputTableHandle->inputColumns, typeParser); + + return std::make_unique< + velox::connector::hive::iceberg::IcebergInsertTableHandle>( + inputColumns, + std::make_shared( + fmt::format("{}/data", icebergOutputTableHandle->outputPath), + fmt::format("{}/data", icebergOutputTableHandle->outputPath), + velox::connector::hive::LocationHandle::TableType::kNew), + toVeloxFileFormat(icebergOutputTableHandle->fileFormat), + toVeloxIcebergPartitionSpec( + icebergOutputTableHandle->partitionSpec, typeParser), + std::optional( + toFileCompressionKind(icebergOutputTableHandle->compressionCodec))); +} + +std::unique_ptr +IcebergPrestoToVeloxConnector::toVeloxInsertTableHandle( + const protocol::InsertHandle* insertHandle, + const TypeParser& typeParser) const { + auto icebergInsertTableHandle = + std::dynamic_pointer_cast( + insertHandle->handle.connectorHandle); + + VELOX_CHECK_NOT_NULL( + icebergInsertTableHandle, + "Unexpected insert table handle type {}", + 
insertHandle->handle.connectorHandle->_type); + + const auto inputColumns = + toHiveColumns(icebergInsertTableHandle->inputColumns, typeParser); + + return std::make_unique< + velox::connector::hive::iceberg::IcebergInsertTableHandle>( + inputColumns, + std::make_shared( + fmt::format("{}/data", icebergInsertTableHandle->outputPath), + fmt::format("{}/data", icebergInsertTableHandle->outputPath), + velox::connector::hive::LocationHandle::TableType::kExisting), + toVeloxFileFormat(icebergInsertTableHandle->fileFormat), + toVeloxIcebergPartitionSpec( + icebergInsertTableHandle->partitionSpec, typeParser), + std::optional( + toFileCompressionKind(icebergInsertTableHandle->compressionCodec))); +} + +std::vector +IcebergPrestoToVeloxConnector::toHiveColumns( + const protocol::List& inputColumns, + const TypeParser& typeParser) const { + std::vector hiveColumns; + hiveColumns.reserve(inputColumns.size()); + for (const auto& columnHandle : inputColumns) { + hiveColumns.emplace_back( + std::dynamic_pointer_cast( + std::shared_ptr(toVeloxColumnHandle(&columnHandle, typeParser)))); + } + return hiveColumns; +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.h new file mode 100644 index 0000000000000..c9336ba6c9bc4 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.h @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" +#include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h" + +namespace facebook::presto { + +class IcebergPrestoToVeloxConnector final : public PrestoToVeloxConnector { + public: + explicit IcebergPrestoToVeloxConnector(std::string connectorName) + : PrestoToVeloxConnector(std::move(connectorName)) {} + + std::unique_ptr toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext) const final; + + std::unique_ptr toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const final; + + std::unique_ptr toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) const final; + + std::unique_ptr createConnectorProtocol() + const final; + + std::unique_ptr + toVeloxInsertTableHandle( + const protocol::CreateHandle* createHandle, + const TypeParser& typeParser) const final; + + std::unique_ptr + toVeloxInsertTableHandle( + const protocol::InsertHandle* insertHandle, + const TypeParser& typeParser) const final; + + private: + std::vector toHiveColumns( + const protocol::List& + inputColumns, + const TypeParser& typeParser) const; +}; + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp index 0784ad5ab1528..150029db6e761 100644 --- a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.cpp @@ -13,28 +13,26 @@ */ #include "presto_cpp/main/connectors/PrestoToVeloxConnector.h" +#include 
"presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.h" #include "presto_cpp/main/types/PrestoToVeloxExpr.h" #include "presto_cpp/main/types/TypeParser.h" -#include "presto_cpp/presto_protocol/connector/clp/ClpConnectorProtocol.h" #include "presto_cpp/presto_protocol/connector/hive/HiveConnectorProtocol.h" -#include "presto_cpp/presto_protocol/connector/iceberg/IcebergConnectorProtocol.h" +#include "presto_cpp/presto_protocol/connector/tpcds/TpcdsConnectorProtocol.h" #include "presto_cpp/presto_protocol/connector/tpch/TpchConnectorProtocol.h" #include -#include "velox/connectors/clp/ClpColumnHandle.h" -#include "velox/connectors/clp/ClpConnectorSplit.h" -#include "velox/connectors/clp/ClpTableHandle.h" #include "velox/connectors/hive/HiveConnector.h" #include "velox/connectors/hive/HiveConnectorSplit.h" #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/TableHandle.h" -#include "velox/connectors/hive/iceberg/IcebergDeleteFile.h" -#include "velox/connectors/hive/iceberg/IcebergSplit.h" +#include "velox/connectors/tpcds/TpcdsConnector.h" +#include "velox/connectors/tpcds/TpcdsConnectorSplit.h" #include "velox/connectors/tpch/TpchConnector.h" #include "velox/connectors/tpch/TpchConnectorSplit.h" #include "velox/type/Filter.h" namespace facebook::presto { +using namespace velox; namespace { std::unordered_map>& @@ -71,1436 +69,6 @@ const PrestoToVeloxConnector& getPrestoToVeloxConnector( return *(it->second); } -namespace { -using namespace velox; - -dwio::common::FileFormat toVeloxFileFormat( - const presto::protocol::hive::StorageFormat& format) { - if (format.inputFormat == "com.facebook.hive.orc.OrcInputFormat") { - return dwio::common::FileFormat::DWRF; - } else if ( - format.inputFormat == "org.apache.hadoop.hive.ql.io.orc.OrcInputFormat") { - return dwio::common::FileFormat::ORC; - } else if ( - format.inputFormat == - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat") { - return dwio::common::FileFormat::PARQUET; - 
} else if (format.inputFormat == "org.apache.hadoop.mapred.TextInputFormat") { - if (format.serDe == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { - return dwio::common::FileFormat::TEXT; - } else if (format.serDe == "org.apache.hive.hcatalog.data.JsonSerDe") { - return dwio::common::FileFormat::JSON; - } - } else if ( - format.inputFormat == - "org.apache.hadoop.hive.ql.io.SymlinkTextInputFormat") { - if (format.serDe == - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe") { - return dwio::common::FileFormat::PARQUET; - } - } else if (format.inputFormat == "com.facebook.alpha.AlphaInputFormat") { - // ALPHA has been renamed in Velox to NIMBLE. - return dwio::common::FileFormat::NIMBLE; - } - VELOX_UNSUPPORTED( - "Unsupported file format: {} {}", format.inputFormat, format.serDe); -} - -dwio::common::FileFormat toVeloxFileFormat( - const presto::protocol::iceberg::FileFormat format) { - if (format == protocol::iceberg::FileFormat::ORC) { - return dwio::common::FileFormat::ORC; - } else if (format == protocol::iceberg::FileFormat::PARQUET) { - return dwio::common::FileFormat::PARQUET; - } - VELOX_UNSUPPORTED("Unsupported file format: {}", fmt::underlying(format)); -} - -template -std::string toJsonString(const T& value) { - return ((json)value).dump(); -} - -TypePtr stringToType( - const std::string& typeString, - const TypeParser& typeParser) { - return typeParser.parse(typeString); -} - -connector::hive::HiveColumnHandle::ColumnType toHiveColumnType( - protocol::hive::ColumnType type) { - switch (type) { - case protocol::hive::ColumnType::PARTITION_KEY: - return connector::hive::HiveColumnHandle::ColumnType::kPartitionKey; - case protocol::hive::ColumnType::REGULAR: - return connector::hive::HiveColumnHandle::ColumnType::kRegular; - case protocol::hive::ColumnType::SYNTHESIZED: - return connector::hive::HiveColumnHandle::ColumnType::kSynthesized; - default: - VELOX_UNSUPPORTED( - "Unsupported Hive column type: {}.", toJsonString(type)); - } 
-} - -std::vector toRequiredSubfields( - const protocol::List& subfields) { - std::vector result; - result.reserve(subfields.size()); - for (auto& subfield : subfields) { - result.emplace_back(subfield); - } - return result; -} - -template -TypePtr fieldNamesToLowerCase(const TypePtr& type) { - return type; -} - -template <> -TypePtr fieldNamesToLowerCase(const TypePtr& type); - -template <> -TypePtr fieldNamesToLowerCase(const TypePtr& type); - -template <> -TypePtr fieldNamesToLowerCase(const TypePtr& type); - -template <> -TypePtr fieldNamesToLowerCase(const TypePtr& type) { - auto& elementType = type->childAt(0); - return std::make_shared(VELOX_DYNAMIC_TYPE_DISPATCH( - fieldNamesToLowerCase, elementType->kind(), elementType)); -} - -template <> -TypePtr fieldNamesToLowerCase(const TypePtr& type) { - auto& keyType = type->childAt(0); - auto& valueType = type->childAt(1); - return std::make_shared( - VELOX_DYNAMIC_TYPE_DISPATCH( - fieldNamesToLowerCase, keyType->kind(), keyType), - VELOX_DYNAMIC_TYPE_DISPATCH( - fieldNamesToLowerCase, valueType->kind(), valueType)); -} - -template <> -TypePtr fieldNamesToLowerCase(const TypePtr& type) { - auto& rowType = type->asRow(); - std::vector names; - std::vector types; - names.reserve(type->size()); - types.reserve(type->size()); - for (int i = 0; i < rowType.size(); i++) { - std::string name = rowType.nameOf(i); - folly::toLowerAscii(name); - names.push_back(std::move(name)); - auto& childType = rowType.childAt(i); - types.push_back(VELOX_DYNAMIC_TYPE_DISPATCH( - fieldNamesToLowerCase, childType->kind(), childType)); - } - return std::make_shared(std::move(names), std::move(types)); -} - -int64_t toInt64( - const std::shared_ptr& block, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - auto value = exprConverter.getConstantValue(type, *block); - return VariantConverter::convert(value) - .value(); -} - -int128_t toInt128( - const std::shared_ptr& block, - const VeloxExprConverter& exprConverter, - 
const TypePtr& type) { - auto value = exprConverter.getConstantValue(type, *block); - return value.value(); -} - -Timestamp toTimestamp( - const std::shared_ptr& block, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - const auto value = exprConverter.getConstantValue(type, *block); - return value.value(); -} - -int64_t dateToInt64( - const std::shared_ptr& block, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - auto value = exprConverter.getConstantValue(type, *block); - return value.value(); -} - -template -T toFloatingPoint( - const std::shared_ptr& block, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - auto variant = exprConverter.getConstantValue(type, *block); - return variant.value(); -} - -std::string toString( - const std::shared_ptr& block, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - auto value = exprConverter.getConstantValue(type, *block); - if (type->isVarbinary()) { - return value.value(); - } - return value.value(); -} - -bool toBoolean( - const std::shared_ptr& block, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - auto variant = exprConverter.getConstantValue(type, *block); - return variant.value(); -} - -std::unique_ptr bigintRangeToFilter( - const protocol::Range& range, - bool nullAllowed, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - bool lowUnbounded = range.low.valueBlock == nullptr; - auto low = lowUnbounded ? std::numeric_limits::min() - : toInt64(range.low.valueBlock, exprConverter, type); - if (!lowUnbounded && range.low.bound == protocol::Bound::ABOVE) { - low++; - } - - bool highUnbounded = range.high.valueBlock == nullptr; - auto high = highUnbounded - ? 
std::numeric_limits::max() - : toInt64(range.high.valueBlock, exprConverter, type); - if (!highUnbounded && range.high.bound == protocol::Bound::BELOW) { - high--; - } - return std::make_unique(low, high, nullAllowed); -} - -std::unique_ptr hugeintRangeToFilter( - const protocol::Range& range, - bool nullAllowed, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - bool lowUnbounded = range.low.valueBlock == nullptr; - auto low = lowUnbounded ? std::numeric_limits::min() - : toInt128(range.low.valueBlock, exprConverter, type); - if (!lowUnbounded && range.low.bound == protocol::Bound::ABOVE) { - low++; - } - - bool highUnbounded = range.high.valueBlock == nullptr; - auto high = highUnbounded - ? std::numeric_limits::max() - : toInt128(range.high.valueBlock, exprConverter, type); - if (!highUnbounded && range.high.bound == protocol::Bound::BELOW) { - high--; - } - return std::make_unique(low, high, nullAllowed); -} - -std::unique_ptr timestampRangeToFilter( - const protocol::Range& range, - bool nullAllowed, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - const bool lowUnbounded = range.low.valueBlock == nullptr; - auto low = lowUnbounded - ? std::numeric_limits::min() - : toTimestamp(range.low.valueBlock, exprConverter, type); - if (!lowUnbounded && range.low.bound == protocol::Bound::ABOVE) { - ++low; - } - - const bool highUnbounded = range.high.valueBlock == nullptr; - auto high = highUnbounded - ? 
std::numeric_limits::max() - : toTimestamp(range.high.valueBlock, exprConverter, type); - if (!highUnbounded && range.high.bound == protocol::Bound::BELOW) { - --high; - } - return std::make_unique(low, high, nullAllowed); -} - -std::unique_ptr boolRangeToFilter( - const protocol::Range& range, - bool nullAllowed, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; - bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; - bool highExclusive = range.high.bound == protocol::Bound::BELOW; - bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; - - if (!lowUnbounded && !highUnbounded) { - bool lowValue = toBoolean(range.low.valueBlock, exprConverter, type); - bool highValue = toBoolean(range.high.valueBlock, exprConverter, type); - VELOX_CHECK_EQ( - lowValue, - highValue, - "Boolean range should not be [FALSE, TRUE] after coordinator " - "optimization."); - return std::make_unique(lowValue, nullAllowed); - } - // Presto coordinator has made optimizations to the bool range already. For - // example, [FALSE, TRUE) will be optimized and shown here as (-infinity, - // TRUE). Plus (-infinity, +infinity) case has been guarded in toFilter() - // method, here it can only be one side bounded scenarios. - VELOX_CHECK_NE( - lowUnbounded, - highUnbounded, - "Passed in boolean range can only have one side bounded range scenario"); - if (!lowUnbounded) { - VELOX_CHECK( - highUnbounded, - "Boolean range should not be double side bounded after coordinator " - "optimization."); - bool lowValue = toBoolean(range.low.valueBlock, exprConverter, type); - - // (TRUE, +infinity) case, should resolve to filter all - if (lowExclusive && lowValue) { - if (nullAllowed) { - return std::make_unique(); - } - return std::make_unique(); - } - - // Both cases (FALSE, +infinity) or [TRUE, +infinity) should evaluate to - // true. 
Case [FALSE, +infinity) should not be expected - VELOX_CHECK( - !(!lowExclusive && !lowValue), - "Case [FALSE, +infinity) should " - "not be expected"); - return std::make_unique(true, nullAllowed); - } - if (!highUnbounded) { - VELOX_CHECK( - lowUnbounded, - "Boolean range should not be double side bounded after coordinator " - "optimization."); - bool highValue = toBoolean(range.high.valueBlock, exprConverter, type); - - // (-infinity, FALSE) case, should resolve to filter all - if (highExclusive && !highValue) { - if (nullAllowed) { - return std::make_unique(); - } - return std::make_unique(); - } - - // Both cases (-infinity, TRUE) or (-infinity, FALSE] should evaluate to - // false. Case (-infinity, TRUE] should not be expected - VELOX_CHECK( - !(!highExclusive && highValue), - "Case (-infinity, TRUE] should " - "not be expected"); - return std::make_unique(false, nullAllowed); - } - VELOX_UNREACHABLE(); -} - -template -std::unique_ptr floatingPointRangeToFilter( - const protocol::Range& range, - bool nullAllowed, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; - bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; - auto low = lowUnbounded - ? (-1.0 * std::numeric_limits::infinity()) - : toFloatingPoint(range.low.valueBlock, exprConverter, type); - - bool highExclusive = range.high.bound == protocol::Bound::BELOW; - bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; - auto high = highUnbounded - ? std::numeric_limits::infinity() - : toFloatingPoint(range.high.valueBlock, exprConverter, type); - - // Handle NaN cases as NaN is not supported as a limit in Velox Filters - if (!lowUnbounded && std::isnan(low)) { - if (lowExclusive) { - // x > NaN is always false as NaN is considered the largest value. 
- return std::make_unique(); - } - // Equivalent to x > infinity as only NaN is greater than infinity - // Presto currently converts x >= NaN into the filter with domain - // [NaN, max), so ignoring the high value is fine. - low = std::numeric_limits::infinity(); - lowExclusive = true; - high = std::numeric_limits::infinity(); - highUnbounded = true; - highExclusive = false; - } else if (!highUnbounded && std::isnan(high)) { - high = std::numeric_limits::infinity(); - if (highExclusive) { - // equivalent to x in [low , infinity] or (low , infinity] - highExclusive = false; - } else { - if (lowUnbounded) { - // Anything <= NaN is true as NaN is the largest possible value. - return std::make_unique(); - } - // Equivalent to x > low or x >=low - highUnbounded = true; - } - } - - return std::make_unique>( - low, - lowUnbounded, - lowExclusive, - high, - highUnbounded, - highExclusive, - nullAllowed); -} - -std::unique_ptr varcharRangeToFilter( - const protocol::Range& range, - bool nullAllowed, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; - bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; - auto low = - lowUnbounded ? "" : toString(range.low.valueBlock, exprConverter, type); - - bool highExclusive = range.high.bound == protocol::Bound::BELOW; - bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; - auto high = - highUnbounded ? "" : toString(range.high.valueBlock, exprConverter, type); - return std::make_unique( - low, - lowUnbounded, - lowExclusive, - high, - highUnbounded, - highExclusive, - nullAllowed); -} - -std::unique_ptr dateRangeToFilter( - const protocol::Range& range, - bool nullAllowed, - const VeloxExprConverter& exprConverter, - const TypePtr& type) { - bool lowUnbounded = range.low.valueBlock == nullptr; - auto low = lowUnbounded - ? 
std::numeric_limits::min() - : dateToInt64(range.low.valueBlock, exprConverter, type); - if (!lowUnbounded && range.low.bound == protocol::Bound::ABOVE) { - low++; - } - - bool highUnbounded = range.high.valueBlock == nullptr; - auto high = highUnbounded - ? std::numeric_limits::max() - : dateToInt64(range.high.valueBlock, exprConverter, type); - if (!highUnbounded && range.high.bound == protocol::Bound::BELOW) { - high--; - } - - return std::make_unique(low, high, nullAllowed); -} - -std::unique_ptr combineIntegerRanges( - std::vector>& bigintFilters, - bool nullAllowed) { - bool allSingleValue = std::all_of( - bigintFilters.begin(), bigintFilters.end(), [](const auto& range) { - return range->isSingleValue(); - }); - - if (allSingleValue) { - std::vector values; - values.reserve(bigintFilters.size()); - for (const auto& filter : bigintFilters) { - values.emplace_back(filter->lower()); - } - return common::createBigintValues(values, nullAllowed); - } - - if (bigintFilters.size() == 2 && - bigintFilters[0]->lower() == std::numeric_limits::min() && - bigintFilters[1]->upper() == std::numeric_limits::max()) { - assert(bigintFilters[0]->upper() + 1 <= bigintFilters[1]->lower() - 1); - return std::make_unique( - bigintFilters[0]->upper() + 1, - bigintFilters[1]->lower() - 1, - nullAllowed); - } - - bool allNegatedValues = true; - bool foundMaximum = false; - assert(bigintFilters.size() > 1); // true by size checks on ranges - std::vector rejectedValues; - - // check if int64 min is a rejected value - if (bigintFilters[0]->lower() == std::numeric_limits::min() + 1) { - rejectedValues.emplace_back(std::numeric_limits::min()); - } - if (bigintFilters[0]->lower() > std::numeric_limits::min() + 1) { - // too many value at the lower end, bail out - return std::make_unique( - std::move(bigintFilters), nullAllowed); - } - rejectedValues.push_back(bigintFilters[0]->upper() + 1); - for (int i = 1; i < bigintFilters.size(); ++i) { - if (bigintFilters[i]->lower() != 
bigintFilters[i - 1]->upper() + 2) { - allNegatedValues = false; - break; - } - if (bigintFilters[i]->upper() == std::numeric_limits::max()) { - foundMaximum = true; - break; - } - rejectedValues.push_back(bigintFilters[i]->upper() + 1); - // make sure there is another range possible above this one - if (bigintFilters[i]->upper() == std::numeric_limits::max() - 1) { - foundMaximum = true; - break; - } - } - - if (allNegatedValues && foundMaximum) { - return common::createNegatedBigintValues(rejectedValues, nullAllowed); - } - - return std::make_unique( - std::move(bigintFilters), nullAllowed); -} - -std::unique_ptr combineBytesRanges( - std::vector>& bytesFilters, - bool nullAllowed) { - bool allSingleValue = std::all_of( - bytesFilters.begin(), bytesFilters.end(), [](const auto& range) { - return range->isSingleValue(); - }); - - if (allSingleValue) { - std::vector values; - values.reserve(bytesFilters.size()); - for (const auto& filter : bytesFilters) { - values.emplace_back(filter->lower()); - } - return std::make_unique(values, nullAllowed); - } - - int lowerUnbounded = 0, upperUnbounded = 0; - bool allExclusive = std::all_of( - bytesFilters.begin(), bytesFilters.end(), [](const auto& range) { - return range->lowerExclusive() && range->upperExclusive(); - }); - if (allExclusive) { - folly::F14FastSet unmatched; - std::vector rejectedValues; - rejectedValues.reserve(bytesFilters.size()); - for (int i = 0; i < bytesFilters.size(); ++i) { - if (bytesFilters[i]->isLowerUnbounded()) { - ++lowerUnbounded; - } else { - if (unmatched.contains(bytesFilters[i]->lower())) { - unmatched.erase(bytesFilters[i]->lower()); - rejectedValues.emplace_back(bytesFilters[i]->lower()); - } else { - unmatched.insert(bytesFilters[i]->lower()); - } - } - if (bytesFilters[i]->isUpperUnbounded()) { - ++upperUnbounded; - } else { - if (unmatched.contains(bytesFilters[i]->upper())) { - unmatched.erase(bytesFilters[i]->upper()); - rejectedValues.emplace_back(bytesFilters[i]->upper()); - } 
else { - unmatched.insert(bytesFilters[i]->upper()); - } - } - } - - if (lowerUnbounded == 1 && upperUnbounded == 1 && unmatched.size() == 0) { - return std::make_unique( - rejectedValues, nullAllowed); - } - } - - if (bytesFilters.size() == 2 && bytesFilters[0]->isLowerUnbounded() && - bytesFilters[1]->isUpperUnbounded()) { - // create a negated bytes range instead - return std::make_unique( - bytesFilters[0]->upper(), - false, - !bytesFilters[0]->upperExclusive(), - bytesFilters[1]->lower(), - false, - !bytesFilters[1]->lowerExclusive(), - nullAllowed); - } - - std::vector> bytesGeneric; - for (int i = 0; i < bytesFilters.size(); ++i) { - bytesGeneric.emplace_back(std::unique_ptr( - dynamic_cast(bytesFilters[i].release()))); - } - - return std::make_unique( - std::move(bytesGeneric), nullAllowed, false); -} - -std::unique_ptr toFilter( - const TypePtr& type, - const protocol::Range& range, - bool nullAllowed, - const VeloxExprConverter& exprConverter) { - if (type->isDate()) { - return dateRangeToFilter(range, nullAllowed, exprConverter, type); - } - switch (type->kind()) { - case TypeKind::TINYINT: - case TypeKind::SMALLINT: - case TypeKind::INTEGER: - case TypeKind::BIGINT: - return bigintRangeToFilter(range, nullAllowed, exprConverter, type); - case TypeKind::HUGEINT: - return hugeintRangeToFilter(range, nullAllowed, exprConverter, type); - case TypeKind::DOUBLE: - return floatingPointRangeToFilter( - range, nullAllowed, exprConverter, type); - case TypeKind::VARCHAR: - case TypeKind::VARBINARY: - return varcharRangeToFilter(range, nullAllowed, exprConverter, type); - case TypeKind::BOOLEAN: - return boolRangeToFilter(range, nullAllowed, exprConverter, type); - case TypeKind::REAL: - return floatingPointRangeToFilter( - range, nullAllowed, exprConverter, type); - case TypeKind::TIMESTAMP: - return timestampRangeToFilter(range, nullAllowed, exprConverter, type); - default: - VELOX_UNSUPPORTED("Unsupported range type: {}", type->toString()); - } -} - 
-std::unique_ptr toFilter( - const protocol::Domain& domain, - const VeloxExprConverter& exprConverter, - const TypeParser& typeParser) { - auto nullAllowed = domain.nullAllowed; - if (auto sortedRangeSet = - std::dynamic_pointer_cast(domain.values)) { - auto type = stringToType(sortedRangeSet->type, typeParser); - auto ranges = sortedRangeSet->ranges; - - if (ranges.empty()) { - VELOX_CHECK(nullAllowed, "Unexpected always-false filter"); - return std::make_unique(); - } - - if (ranges.size() == 1) { - // 'is not null' arrives as unbounded range with 'nulls not allowed'. - // We catch this case and create 'is not null' filter instead of the range - // filter. - const auto& range = ranges[0]; - bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; - bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; - bool highExclusive = range.high.bound == protocol::Bound::BELOW; - bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; - if (lowUnbounded && highUnbounded && !nullAllowed) { - return std::make_unique(); - } - - return toFilter(type, ranges[0], nullAllowed, exprConverter); - } - - if (type->isDate()) { - std::vector> dateFilters; - dateFilters.reserve(ranges.size()); - for (const auto& range : ranges) { - dateFilters.emplace_back( - dateRangeToFilter(range, nullAllowed, exprConverter, type)); - } - return std::make_unique( - std::move(dateFilters), nullAllowed); - } - - if (type->kind() == TypeKind::BIGINT || type->kind() == TypeKind::INTEGER || - type->kind() == TypeKind::SMALLINT || - type->kind() == TypeKind::TINYINT) { - std::vector> bigintFilters; - bigintFilters.reserve(ranges.size()); - for (const auto& range : ranges) { - bigintFilters.emplace_back( - bigintRangeToFilter(range, nullAllowed, exprConverter, type)); - } - return combineIntegerRanges(bigintFilters, nullAllowed); - } - - if (type->kind() == TypeKind::VARCHAR) { - std::vector> bytesFilters; - bytesFilters.reserve(ranges.size()); - for (const auto& range : 
ranges) { - bytesFilters.emplace_back( - varcharRangeToFilter(range, nullAllowed, exprConverter, type)); - } - return combineBytesRanges(bytesFilters, nullAllowed); - } - - if (type->kind() == TypeKind::BOOLEAN) { - VELOX_CHECK_EQ(ranges.size(), 2, "Multi bool ranges size can only be 2."); - std::unique_ptr boolFilter; - for (const auto& range : ranges) { - auto filter = - boolRangeToFilter(range, nullAllowed, exprConverter, type); - if (filter->kind() == common::FilterKind::kAlwaysFalse or - filter->kind() == common::FilterKind::kIsNull) { - continue; - } - VELOX_CHECK_NULL(boolFilter); - boolFilter = std::move(filter); - } - - VELOX_CHECK_NOT_NULL(boolFilter); - return boolFilter; - } - - std::vector> filters; - filters.reserve(ranges.size()); - for (const auto& range : ranges) { - filters.emplace_back(toFilter(type, range, nullAllowed, exprConverter)); - } - - return std::make_unique( - std::move(filters), nullAllowed, false); - } else if ( - auto equatableValueSet = - std::dynamic_pointer_cast( - domain.values)) { - if (equatableValueSet->entries.empty()) { - if (nullAllowed) { - return std::make_unique(); - } else { - return std::make_unique(); - } - } - VELOX_UNSUPPORTED( - "EquatableValueSet (with non-empty entries) to Velox filter conversion is not supported yet."); - } else if ( - auto allOrNoneValueSet = - std::dynamic_pointer_cast( - domain.values)) { - VELOX_UNSUPPORTED( - "AllOrNoneValueSet to Velox filter conversion is not supported yet."); - } - VELOX_UNSUPPORTED("Unsupported filter found."); -} - -std::unique_ptr toHiveTableHandle( - const protocol::TupleDomain& domainPredicate, - const std::shared_ptr& remainingPredicate, - bool isPushdownFilterEnabled, - const std::string& tableName, - const protocol::List& dataColumns, - const protocol::TableHandle& tableHandle, - const protocol::Map& tableParameters, - const VeloxExprConverter& exprConverter, - const TypeParser& typeParser) { - common::SubfieldFilters subfieldFilters; - auto domains = 
domainPredicate.domains; - for (const auto& domain : *domains) { - auto filter = domain.second; - subfieldFilters[common::Subfield(domain.first)] = - toFilter(domain.second, exprConverter, typeParser); - } - - auto remainingFilter = exprConverter.toVeloxExpr(remainingPredicate); - if (auto constant = std::dynamic_pointer_cast( - remainingFilter)) { - bool value = constant->value().value(); - VELOX_CHECK(value, "Unexpected always-false remaining predicate"); - - // Use null for always-true filter. - remainingFilter = nullptr; - } - - RowTypePtr finalDataColumns; - if (!dataColumns.empty()) { - std::vector names; - std::vector types; - velox::type::fbhive::HiveTypeParser hiveTypeParser; - names.reserve(dataColumns.size()); - types.reserve(dataColumns.size()); - for (auto& column : dataColumns) { - std::string name = column.name; - folly::toLowerAscii(name); - names.emplace_back(std::move(name)); - auto parsedType = hiveTypeParser.parse(column.type); - // The type from the metastore may have upper case letters - // in field names, convert them all to lower case to be - // compatible with Presto. 
- types.push_back(VELOX_DYNAMIC_TYPE_DISPATCH( - fieldNamesToLowerCase, parsedType->kind(), parsedType)); - } - finalDataColumns = ROW(std::move(names), std::move(types)); - } - - if (tableParameters.empty()) { - return std::make_unique( - tableHandle.connectorId, - tableName, - isPushdownFilterEnabled, - std::move(subfieldFilters), - remainingFilter, - finalDataColumns); - } - - std::unordered_map finalTableParameters = {}; - finalTableParameters.reserve(tableParameters.size()); - for (const auto& [key, value] : tableParameters) { - finalTableParameters[key] = value; - } - - return std::make_unique( - tableHandle.connectorId, - tableName, - isPushdownFilterEnabled, - std::move(subfieldFilters), - remainingFilter, - finalDataColumns, - finalTableParameters); -} - -connector::hive::LocationHandle::TableType toTableType( - protocol::hive::TableType tableType) { - switch (tableType) { - case protocol::hive::TableType::NEW: - // Temporary tables are written and read by the SPI in a single pipeline. - // So they can be treated as New. They do not require Append or Overwrite - // semantics as applicable for regular tables. 
- case protocol::hive::TableType::TEMPORARY: - return connector::hive::LocationHandle::TableType::kNew; - case protocol::hive::TableType::EXISTING: - return connector::hive::LocationHandle::TableType::kExisting; - default: - VELOX_UNSUPPORTED("Unsupported table type: {}.", toJsonString(tableType)); - } -} - -std::shared_ptr toLocationHandle( - const protocol::hive::LocationHandle& locationHandle) { - return std::make_shared( - locationHandle.targetPath, - locationHandle.writePath, - toTableType(locationHandle.tableType)); -} - -dwio::common::FileFormat toFileFormat( - const protocol::hive::HiveStorageFormat storageFormat, - const char* usage) { - switch (storageFormat) { - case protocol::hive::HiveStorageFormat::DWRF: - return dwio::common::FileFormat::DWRF; - case protocol::hive::HiveStorageFormat::PARQUET: - return dwio::common::FileFormat::PARQUET; - case protocol::hive::HiveStorageFormat::ALPHA: - // This has been renamed in Velox from ALPHA to NIMBLE. - return dwio::common::FileFormat::NIMBLE; - default: - VELOX_UNSUPPORTED( - "Unsupported file format in {}: {}.", - usage, - toJsonString(storageFormat)); - } -} - -velox::common::CompressionKind toFileCompressionKind( - const protocol::hive::HiveCompressionCodec& hiveCompressionCodec) { - switch (hiveCompressionCodec) { - case protocol::hive::HiveCompressionCodec::SNAPPY: - return velox::common::CompressionKind::CompressionKind_SNAPPY; - case protocol::hive::HiveCompressionCodec::GZIP: - return velox::common::CompressionKind::CompressionKind_GZIP; - case protocol::hive::HiveCompressionCodec::LZ4: - return velox::common::CompressionKind::CompressionKind_LZ4; - case protocol::hive::HiveCompressionCodec::ZSTD: - return velox::common::CompressionKind::CompressionKind_ZSTD; - case protocol::hive::HiveCompressionCodec::NONE: - return velox::common::CompressionKind::CompressionKind_NONE; - default: - VELOX_UNSUPPORTED( - "Unsupported file compression format: {}.", - toJsonString(hiveCompressionCodec)); - } -} - 
-velox::connector::hive::HiveBucketProperty::Kind toHiveBucketPropertyKind( - protocol::hive::BucketFunctionType bucketFuncType) { - switch (bucketFuncType) { - case protocol::hive::BucketFunctionType::PRESTO_NATIVE: - return velox::connector::hive::HiveBucketProperty::Kind::kPrestoNative; - case protocol::hive::BucketFunctionType::HIVE_COMPATIBLE: - return velox::connector::hive::HiveBucketProperty::Kind::kHiveCompatible; - default: - VELOX_USER_FAIL( - "Unknown hive bucket function: {}", toJsonString(bucketFuncType)); - } -} - -std::vector stringToTypes( - const std::shared_ptr>& typeStrings, - const TypeParser& typeParser) { - std::vector types; - types.reserve(typeStrings->size()); - for (const auto& typeString : *typeStrings) { - types.push_back(stringToType(typeString, typeParser)); - } - return types; -} - -core::SortOrder toSortOrder(protocol::hive::Order order) { - switch (order) { - case protocol::hive::Order::ASCENDING: - return core::SortOrder(true, true); - case protocol::hive::Order::DESCENDING: - return core::SortOrder(false, false); - default: - VELOX_USER_FAIL("Unknown sort order: {}", toJsonString(order)); - } -} - -std::shared_ptr toHiveSortingColumn( - const protocol::hive::SortingColumn& sortingColumn) { - return std::make_shared( - sortingColumn.columnName, toSortOrder(sortingColumn.order)); -} - -std::vector> -toHiveSortingColumns( - const protocol::List& sortedBy) { - std::vector> - sortingColumns; - sortingColumns.reserve(sortedBy.size()); - for (const auto& sortingColumn : sortedBy) { - sortingColumns.push_back(toHiveSortingColumn(sortingColumn)); - } - return sortingColumns; -} - -std::shared_ptr -toHiveBucketProperty( - const std::vector>& - inputColumns, - const std::shared_ptr& bucketProperty, - const TypeParser& typeParser) { - if (bucketProperty == nullptr) { - return nullptr; - } - - VELOX_USER_CHECK_GT( - bucketProperty->bucketCount, 0, "Bucket count must be a positive value"); - - VELOX_USER_CHECK( - 
!bucketProperty->bucketedBy.empty(), - "Bucketed columns must be set: {}", - toJsonString(*bucketProperty)); - - const velox::connector::hive::HiveBucketProperty::Kind kind = - toHiveBucketPropertyKind(bucketProperty->bucketFunctionType); - std::vector bucketedTypes; - if (kind == - velox::connector::hive::HiveBucketProperty::Kind::kHiveCompatible) { - VELOX_USER_CHECK_NULL( - bucketProperty->types, - "Unexpected bucketed types set for hive compatible bucket function: {}", - toJsonString(*bucketProperty)); - bucketedTypes.reserve(bucketProperty->bucketedBy.size()); - for (const auto& bucketedColumn : bucketProperty->bucketedBy) { - TypePtr bucketedType{nullptr}; - for (const auto& inputColumn : inputColumns) { - if (inputColumn->name() != bucketedColumn) { - continue; - } - VELOX_USER_CHECK_NOT_NULL(inputColumn->hiveType()); - bucketedType = inputColumn->hiveType(); - break; - } - VELOX_USER_CHECK_NOT_NULL( - bucketedType, "Bucketed column {} not found", bucketedColumn); - bucketedTypes.push_back(std::move(bucketedType)); - } - } else { - VELOX_USER_CHECK_EQ( - bucketProperty->types->size(), - bucketProperty->bucketedBy.size(), - "Bucketed types is not set properly for presto native bucket function: {}", - toJsonString(*bucketProperty)); - bucketedTypes = stringToTypes(bucketProperty->types, typeParser); - } - - const auto sortedBy = toHiveSortingColumns(bucketProperty->sortedBy); - - return std::make_shared( - toHiveBucketPropertyKind(bucketProperty->bucketFunctionType), - bucketProperty->bucketCount, - bucketProperty->bucketedBy, - bucketedTypes, - sortedBy); -} - -std::unique_ptr -toVeloxHiveColumnHandle( - const protocol::ColumnHandle* column, - const TypeParser& typeParser) { - auto* hiveColumn = - dynamic_cast(column); - VELOX_CHECK_NOT_NULL( - hiveColumn, "Unexpected column handle type {}", column->_type); - velox::type::fbhive::HiveTypeParser hiveTypeParser; - // TODO(spershin): Should we pass something different than 'typeSignature' - // to 'hiveType' 
argument of the 'HiveColumnHandle' constructor? - return std::make_unique( - hiveColumn->name, - toHiveColumnType(hiveColumn->columnType), - stringToType(hiveColumn->typeSignature, typeParser), - hiveTypeParser.parse(hiveColumn->hiveType), - toRequiredSubfields(hiveColumn->requiredSubfields)); -} - -velox::connector::hive::HiveBucketConversion toVeloxBucketConversion( - const protocol::hive::BucketConversion& bucketConversion) { - velox::connector::hive::HiveBucketConversion veloxBucketConversion; - // Current table bucket count (new). - veloxBucketConversion.tableBucketCount = bucketConversion.tableBucketCount; - // Partition bucket count (old). - veloxBucketConversion.partitionBucketCount = - bucketConversion.partitionBucketCount; - TypeParser typeParser; - for (const auto& column : bucketConversion.bucketColumnHandles) { - // Columns used as bucket input. - veloxBucketConversion.bucketColumnHandles.push_back( - toVeloxHiveColumnHandle(&column, typeParser)); - } - return veloxBucketConversion; -} - -velox::connector::hive::iceberg::FileContent toVeloxFileContent( - const presto::protocol::iceberg::FileContent content) { - if (content == protocol::iceberg::FileContent::DATA) { - return velox::connector::hive::iceberg::FileContent::kData; - } else if (content == protocol::iceberg::FileContent::POSITION_DELETES) { - return velox::connector::hive::iceberg::FileContent::kPositionalDeletes; - } - VELOX_UNSUPPORTED("Unsupported file content: {}", fmt::underlying(content)); -} - -} // namespace - -std::unique_ptr -HivePrestoToVeloxConnector::toVeloxSplit( - const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const { - auto hiveSplit = - dynamic_cast(connectorSplit); - VELOX_CHECK_NOT_NULL( - hiveSplit, "Unexpected split type {}", connectorSplit->_type); - std::unordered_map> partitionKeys; - for (const auto& entry : hiveSplit->partitionKeys) { - partitionKeys.emplace( - entry.name, - 
entry.value == nullptr ? std::nullopt - : std::optional{*entry.value}); - } - std::unordered_map customSplitInfo; - for (const auto& [key, value] : hiveSplit->fileSplit.customSplitInfo) { - customSplitInfo[key] = value; - } - std::shared_ptr extraFileInfo; - if (hiveSplit->fileSplit.extraFileInfo) { - extraFileInfo = std::make_shared( - velox::encoding::Base64::decode(*hiveSplit->fileSplit.extraFileInfo)); - } - std::unordered_map serdeParameters; - serdeParameters.reserve(hiveSplit->storage.serdeParameters.size()); - for (const auto& [key, value] : hiveSplit->storage.serdeParameters) { - serdeParameters[key] = value; - } - std::unordered_map infoColumns = { - {"$path", hiveSplit->fileSplit.path}, - {"$file_size", std::to_string(hiveSplit->fileSplit.fileSize)}, - {"$file_modified_time", - std::to_string(hiveSplit->fileSplit.fileModifiedTime)}, - }; - if (hiveSplit->tableBucketNumber) { - infoColumns["$bucket"] = std::to_string(*hiveSplit->tableBucketNumber); - } - auto veloxSplit = - std::make_unique( - catalogId, - hiveSplit->fileSplit.path, - toVeloxFileFormat(hiveSplit->storage.storageFormat), - hiveSplit->fileSplit.start, - hiveSplit->fileSplit.length, - partitionKeys, - hiveSplit->tableBucketNumber - ? 
std::optional(*hiveSplit->tableBucketNumber) - : std::nullopt, - customSplitInfo, - extraFileInfo, - serdeParameters, - hiveSplit->splitWeight, - splitContext->cacheable, - infoColumns); - if (hiveSplit->bucketConversion) { - VELOX_CHECK_NOT_NULL(hiveSplit->tableBucketNumber); - veloxSplit->bucketConversion = - toVeloxBucketConversion(*hiveSplit->bucketConversion); - } - return veloxSplit; -} - -std::unique_ptr -HivePrestoToVeloxConnector::toVeloxColumnHandle( - const protocol::ColumnHandle* column, - const TypeParser& typeParser) const { - return toVeloxHiveColumnHandle(column, typeParser); -} - -std::unique_ptr -HivePrestoToVeloxConnector::toVeloxTableHandle( - const protocol::TableHandle& tableHandle, - const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) const { - auto addSynthesizedColumn = [&](const std::string& name, - protocol::hive::ColumnType columnType, - const protocol::ColumnHandle& column) { - if (toHiveColumnType(columnType) == - velox::connector::hive::HiveColumnHandle::ColumnType::kSynthesized) { - if (assignments.count(name) == 0) { - assignments.emplace(name, toVeloxColumnHandle(&column, typeParser)); - } - } - }; - auto hiveLayout = - std::dynamic_pointer_cast( - tableHandle.connectorTableLayout); - VELOX_CHECK_NOT_NULL( - hiveLayout, - "Unexpected layout type {}", - tableHandle.connectorTableLayout->_type); - for (const auto& entry : hiveLayout->partitionColumns) { - assignments.emplace(entry.name, toVeloxColumnHandle(&entry, typeParser)); - } - - // Add synthesized columns to the TableScanNode columnHandles as well. 
- for (const auto& entry : hiveLayout->predicateColumns) { - addSynthesizedColumn(entry.first, entry.second.columnType, entry.second); - } - - auto hiveTableHandle = - std::dynamic_pointer_cast( - tableHandle.connectorHandle); - VELOX_CHECK_NOT_NULL( - hiveTableHandle, - "Unexpected table handle type {}", - tableHandle.connectorHandle->_type); - - // Use fully qualified name if available. - std::string tableName = hiveTableHandle->schemaName.empty() - ? hiveTableHandle->tableName - : fmt::format( - "{}.{}", hiveTableHandle->schemaName, hiveTableHandle->tableName); - - return toHiveTableHandle( - hiveLayout->domainPredicate, - hiveLayout->remainingPredicate, - hiveLayout->pushdownFilterEnabled, - tableName, - hiveLayout->dataColumns, - tableHandle, - hiveLayout->tableParameters, - exprConverter, - typeParser); -} - -std::unique_ptr -HivePrestoToVeloxConnector::toVeloxInsertTableHandle( - const protocol::CreateHandle* createHandle, - const TypeParser& typeParser) const { - auto hiveOutputTableHandle = - std::dynamic_pointer_cast( - createHandle->handle.connectorHandle); - VELOX_CHECK_NOT_NULL( - hiveOutputTableHandle, - "Unexpected output table handle type {}", - createHandle->handle.connectorHandle->_type); - bool isPartitioned{false}; - const auto inputColumns = toHiveColumns( - hiveOutputTableHandle->inputColumns, typeParser, isPartitioned); - return std::make_unique( - inputColumns, - toLocationHandle(hiveOutputTableHandle->locationHandle), - toFileFormat(hiveOutputTableHandle->actualStorageFormat, "TableWrite"), - toHiveBucketProperty( - inputColumns, hiveOutputTableHandle->bucketProperty, typeParser), - std::optional( - toFileCompressionKind(hiveOutputTableHandle->compressionCodec))); -} - -std::unique_ptr -HivePrestoToVeloxConnector::toVeloxInsertTableHandle( - const protocol::InsertHandle* insertHandle, - const TypeParser& typeParser) const { - auto hiveInsertTableHandle = - std::dynamic_pointer_cast( - insertHandle->handle.connectorHandle); - 
VELOX_CHECK_NOT_NULL( - hiveInsertTableHandle, - "Unexpected insert table handle type {}", - insertHandle->handle.connectorHandle->_type); - bool isPartitioned{false}; - const auto inputColumns = toHiveColumns( - hiveInsertTableHandle->inputColumns, typeParser, isPartitioned); - - const auto table = hiveInsertTableHandle->pageSinkMetadata.table; - VELOX_USER_CHECK_NOT_NULL(table, "Table must not be null for insert query"); - return std::make_unique( - inputColumns, - toLocationHandle(hiveInsertTableHandle->locationHandle), - toFileFormat(hiveInsertTableHandle->actualStorageFormat, "TableWrite"), - toHiveBucketProperty( - inputColumns, hiveInsertTableHandle->bucketProperty, typeParser), - std::optional( - toFileCompressionKind(hiveInsertTableHandle->compressionCodec)), - std::unordered_map( - table->storage.serdeParameters.begin(), - table->storage.serdeParameters.end())); -} - -std::vector> -HivePrestoToVeloxConnector::toHiveColumns( - const protocol::List& inputColumns, - const TypeParser& typeParser, - bool& hasPartitionColumn) const { - hasPartitionColumn = false; - std::vector> - hiveColumns; - hiveColumns.reserve(inputColumns.size()); - for (const auto& columnHandle : inputColumns) { - hasPartitionColumn |= - columnHandle.columnType == protocol::hive::ColumnType::PARTITION_KEY; - hiveColumns.emplace_back( - std::dynamic_pointer_cast( - std::shared_ptr(toVeloxColumnHandle(&columnHandle, typeParser)))); - } - return hiveColumns; -} - -std::unique_ptr -HivePrestoToVeloxConnector::createVeloxPartitionFunctionSpec( - const protocol::ConnectorPartitioningHandle* partitioningHandle, - const std::vector& bucketToPartition, - const std::vector& channels, - const std::vector& constValues, - bool& effectivelyGather) const { - auto hivePartitioningHandle = - dynamic_cast( - partitioningHandle); - VELOX_CHECK_NOT_NULL( - hivePartitioningHandle, - "Unexpected partitioning handle type {}", - partitioningHandle->_type); - VELOX_USER_CHECK( - 
hivePartitioningHandle->bucketFunctionType == - protocol::hive::BucketFunctionType::HIVE_COMPATIBLE, - "Unsupported Hive bucket function type: {}", - toJsonString(hivePartitioningHandle->bucketFunctionType)); - effectivelyGather = hivePartitioningHandle->bucketCount == 1; - return std::make_unique( - hivePartitioningHandle->bucketCount, - bucketToPartition, - channels, - constValues); -} - -std::unique_ptr -HivePrestoToVeloxConnector::createConnectorProtocol() const { - return std::make_unique(); -} - -std::unique_ptr -IcebergPrestoToVeloxConnector::toVeloxSplit( - const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const { - auto icebergSplit = - dynamic_cast(connectorSplit); - VELOX_CHECK_NOT_NULL( - icebergSplit, "Unexpected split type {}", connectorSplit->_type); - - std::unordered_map> partitionKeys; - for (const auto& entry : icebergSplit->partitionKeys) { - partitionKeys.emplace( - entry.second.name, - entry.second.value == nullptr - ? 
std::nullopt - : std::optional{*entry.second.value}); - } - - std::unordered_map customSplitInfo; - customSplitInfo["table_format"] = "hive-iceberg"; - - std::vector deletes; - deletes.reserve(icebergSplit->deletes.size()); - for (const auto& deleteFile : icebergSplit->deletes) { - std::unordered_map lowerBounds( - deleteFile.lowerBounds.begin(), deleteFile.lowerBounds.end()); - - std::unordered_map upperBounds( - deleteFile.upperBounds.begin(), deleteFile.upperBounds.end()); - - velox::connector::hive::iceberg::IcebergDeleteFile icebergDeleteFile( - toVeloxFileContent(deleteFile.content), - deleteFile.path, - toVeloxFileFormat(deleteFile.format), - deleteFile.recordCount, - deleteFile.fileSizeInBytes, - std::vector(deleteFile.equalityFieldIds), - lowerBounds, - upperBounds); - - deletes.emplace_back(icebergDeleteFile); - } - - std::unordered_map infoColumns = { - {"$data_sequence_number", - std::to_string(icebergSplit->dataSequenceNumber)}, - {"$path", icebergSplit->path}}; - - return std::make_unique( - catalogId, - icebergSplit->path, - toVeloxFileFormat(icebergSplit->fileFormat), - icebergSplit->start, - icebergSplit->length, - partitionKeys, - std::nullopt, - customSplitInfo, - nullptr, - splitContext->cacheable, - deletes, - infoColumns); -} - -std::unique_ptr -IcebergPrestoToVeloxConnector::toVeloxColumnHandle( - const protocol::ColumnHandle* column, - const TypeParser& typeParser) const { - auto icebergColumn = - dynamic_cast(column); - VELOX_CHECK_NOT_NULL( - icebergColumn, "Unexpected column handle type {}", column->_type); - // TODO(imjalpreet): Modify 'hiveType' argument of the 'HiveColumnHandle' - // constructor similar to how Hive Connector is handling for bucketing - velox::type::fbhive::HiveTypeParser hiveTypeParser; - auto type = stringToType(icebergColumn->type, typeParser); - connector::hive::HiveColumnHandle::ColumnParseParameters - columnParseParameters; - if (type->isDate()) { - columnParseParameters.partitionDateValueFormat = 
connector::hive:: - HiveColumnHandle::ColumnParseParameters::kDaysSinceEpoch; - } - return std::make_unique( - icebergColumn->columnIdentity.name, - toHiveColumnType(icebergColumn->columnType), - type, - type, - toRequiredSubfields(icebergColumn->requiredSubfields), - columnParseParameters); -} - -std::unique_ptr -IcebergPrestoToVeloxConnector::toVeloxTableHandle( - const protocol::TableHandle& tableHandle, - const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) const { - auto addSynthesizedColumn = [&](const std::string& name, - protocol::hive::ColumnType columnType, - const protocol::ColumnHandle& column) { - if (toHiveColumnType(columnType) == - velox::connector::hive::HiveColumnHandle::ColumnType::kSynthesized) { - if (assignments.count(name) == 0) { - assignments.emplace(name, toVeloxColumnHandle(&column, typeParser)); - } - } - }; - - auto icebergLayout = std::dynamic_pointer_cast< - const protocol::iceberg::IcebergTableLayoutHandle>( - tableHandle.connectorTableLayout); - VELOX_CHECK_NOT_NULL( - icebergLayout, - "Unexpected layout type {}", - tableHandle.connectorTableLayout->_type); - - for (const auto& entry : icebergLayout->partitionColumns) { - assignments.emplace( - entry.columnIdentity.name, toVeloxColumnHandle(&entry, typeParser)); - } - - // Add synthesized columns to the TableScanNode columnHandles as well. - for (const auto& entry : icebergLayout->predicateColumns) { - addSynthesizedColumn(entry.first, entry.second.columnType, entry.second); - } - - auto icebergTableHandle = - std::dynamic_pointer_cast( - tableHandle.connectorHandle); - VELOX_CHECK_NOT_NULL( - icebergTableHandle, - "Unexpected table handle type {}", - tableHandle.connectorHandle->_type); - - // Use fully qualified name if available. - std::string tableName = icebergTableHandle->schemaName.empty() - ? 
icebergTableHandle->icebergTableName.tableName - : fmt::format( - "{}.{}", - icebergTableHandle->schemaName, - icebergTableHandle->icebergTableName.tableName); - - return toHiveTableHandle( - icebergLayout->domainPredicate, - icebergLayout->remainingPredicate, - icebergLayout->pushdownFilterEnabled, - tableName, - icebergLayout->dataColumns, - tableHandle, - {}, - exprConverter, - typeParser); -} - -std::unique_ptr -IcebergPrestoToVeloxConnector::createConnectorProtocol() const { - return std::make_unique(); -} - std::unique_ptr TpchPrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, @@ -1533,10 +101,7 @@ std::unique_ptr TpchPrestoToVeloxConnector::toVeloxTableHandle( const protocol::TableHandle& tableHandle, const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) const { + const TypeParser& typeParser) const { auto tpchLayout = std::dynamic_pointer_cast( tableHandle.connectorTableLayout); @@ -1556,55 +121,53 @@ TpchPrestoToVeloxConnector::createConnectorProtocol() const { } std::unique_ptr -ClpPrestoToVeloxConnector::toVeloxSplit( +TpcdsPrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* connectorSplit, const protocol::SplitContext* splitContext) const { - auto clpSplit = dynamic_cast(connectorSplit); + auto tpcdsSplit = + dynamic_cast(connectorSplit); VELOX_CHECK_NOT_NULL( - clpSplit, "Unexpected split type {}", connectorSplit->_type); - return std::make_unique( + tpcdsSplit, "Unexpected split type {}", connectorSplit->_type); + return std::make_unique( catalogId, - clpSplit->path, - static_cast(clpSplit->type), - clpSplit->kqlQuery); + splitContext->cacheable, + tpcdsSplit->totalParts, + tpcdsSplit->partNumber); } std::unique_ptr -ClpPrestoToVeloxConnector::toVeloxColumnHandle( +TpcdsPrestoToVeloxConnector::toVeloxColumnHandle( const protocol::ColumnHandle* column, const TypeParser& typeParser) 
const { - auto clpColumn = dynamic_cast(column); + auto tpcdsColumn = + dynamic_cast(column); VELOX_CHECK_NOT_NULL( - clpColumn, "Unexpected column handle type {}", column->_type); - return std::make_unique( - clpColumn->columnName, - clpColumn->originalColumnName, - typeParser.parse(clpColumn->columnType)); + tpcdsColumn, "Unexpected column handle type {}", column->_type); + return std::make_unique( + tpcdsColumn->columnName); } std::unique_ptr -ClpPrestoToVeloxConnector::toVeloxTableHandle( +TpcdsPrestoToVeloxConnector::toVeloxTableHandle( const protocol::TableHandle& tableHandle, const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) const { - auto clpLayout = - std::dynamic_pointer_cast( + const TypeParser& typeParser) const { + auto tpcdsLayout = + std::dynamic_pointer_cast( tableHandle.connectorTableLayout); VELOX_CHECK_NOT_NULL( - clpLayout, + tpcdsLayout, "Unexpected layout type {}", tableHandle.connectorTableLayout->_type); - return std::make_unique( - tableHandle.connectorId, clpLayout->table.schemaTableName.table); + return std::make_unique( + tableHandle.connectorId, + tpcds::fromTableName(tpcdsLayout->table.tableName), + tpcdsLayout->table.scaleFactor); } std::unique_ptr -ClpPrestoToVeloxConnector::createConnectorProtocol() const { - return std::make_unique(); +TpcdsPrestoToVeloxConnector::createConnectorProtocol() const { + return std::make_unique(); } - } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h index c20fa3fd3a73e..910d9956eec2e 100644 --- a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnector.h @@ -58,11 +58,7 @@ class PrestoToVeloxConnector { toVeloxTableHandle( const protocol::TableHandle& tableHandle, 
const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) - const = 0; + const TypeParser& typeParser) const = 0; [[nodiscard]] virtual std::unique_ptr< velox::connector::ConnectorInsertTableHandle> @@ -118,85 +114,6 @@ class PrestoToVeloxConnector { const std::string connectorName_; }; -class HivePrestoToVeloxConnector final : public PrestoToVeloxConnector { - public: - explicit HivePrestoToVeloxConnector(std::string connectorName) - : PrestoToVeloxConnector(std::move(connectorName)) {} - - std::unique_ptr toVeloxSplit( - const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const final; - - std::unique_ptr toVeloxColumnHandle( - const protocol::ColumnHandle* column, - const TypeParser& typeParser) const final; - - std::unique_ptr toVeloxTableHandle( - const protocol::TableHandle& tableHandle, - const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) - const final; - - std::unique_ptr - toVeloxInsertTableHandle( - const protocol::CreateHandle* createHandle, - const TypeParser& typeParser) const final; - - std::unique_ptr - toVeloxInsertTableHandle( - const protocol::InsertHandle* insertHandle, - const TypeParser& typeParser) const final; - - std::unique_ptr - createVeloxPartitionFunctionSpec( - const protocol::ConnectorPartitioningHandle* partitioningHandle, - const std::vector& bucketToPartition, - const std::vector& channels, - const std::vector& constValues, - bool& effectivelyGather) const final; - - std::unique_ptr createConnectorProtocol() - const final; - - private: - std::vector> - toHiveColumns( - const protocol::List& inputColumns, - const TypeParser& typeParser, - bool& hasPartitionColumn) const; -}; - -class IcebergPrestoToVeloxConnector final : public PrestoToVeloxConnector { - public: - explicit 
IcebergPrestoToVeloxConnector(std::string connectorName) - : PrestoToVeloxConnector(std::move(connectorName)) {} - - std::unique_ptr toVeloxSplit( - const protocol::ConnectorId& catalogId, - const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const final; - - std::unique_ptr toVeloxColumnHandle( - const protocol::ColumnHandle* column, - const TypeParser& typeParser) const final; - - std::unique_ptr toVeloxTableHandle( - const protocol::TableHandle& tableHandle, - const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) - const final; - - std::unique_ptr createConnectorProtocol() - const final; -}; - class TpchPrestoToVeloxConnector final : public PrestoToVeloxConnector { public: explicit TpchPrestoToVeloxConnector(std::string connectorName) @@ -214,19 +131,15 @@ class TpchPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxTableHandle( const protocol::TableHandle& tableHandle, const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) - const final; + const TypeParser& typeParser) const final; std::unique_ptr createConnectorProtocol() const final; }; -class ClpPrestoToVeloxConnector final : public PrestoToVeloxConnector { +class TpcdsPrestoToVeloxConnector final : public PrestoToVeloxConnector { public: - explicit ClpPrestoToVeloxConnector(std::string connectorName) + explicit TpcdsPrestoToVeloxConnector(std::string connectorName) : PrestoToVeloxConnector(std::move(connectorName)) {} std::unique_ptr toVeloxSplit( @@ -241,11 +154,7 @@ class ClpPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxTableHandle( const protocol::TableHandle& tableHandle, const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) - 
const final; + const TypeParser& typeParser) const final; std::unique_ptr createConnectorProtocol() const final; diff --git a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.cpp b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.cpp new file mode 100644 index 0000000000000..f95b7c9f27a13 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.cpp @@ -0,0 +1,810 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.h" + +#include +#include "presto_cpp/main/types/TypeParser.h" +#include "velox/connectors/hive/TableHandle.h" +#include "velox/type/fbhive/HiveTypeParser.h" + +namespace facebook::presto { + +using namespace facebook::velox; + +namespace { + +int64_t toInt64( + const std::shared_ptr& block, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + auto value = exprConverter.getConstantValue(type, *block); + return VariantConverter::convert(value) + .value(); +} + +int128_t toInt128( + const std::shared_ptr& block, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + auto value = exprConverter.getConstantValue(type, *block); + return value.value(); +} + +Timestamp toTimestamp( + const std::shared_ptr& block, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + const auto value = exprConverter.getConstantValue(type, *block); + return value.value(); +} + +int64_t dateToInt64( + const std::shared_ptr& block, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + auto value = exprConverter.getConstantValue(type, *block); + return value.value(); +} + +template +T toFloatingPoint( + const std::shared_ptr& block, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + auto variant = exprConverter.getConstantValue(type, *block); + return variant.value(); +} + +std::string toString( + const std::shared_ptr& block, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + auto value = exprConverter.getConstantValue(type, *block); + if (type->isVarbinary()) { + return value.value(); + } + return value.value(); +} + +bool toBoolean( + const std::shared_ptr& block, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + auto variant = exprConverter.getConstantValue(type, *block); + return variant.value(); +} + +std::unique_ptr bigintRangeToFilter( + const protocol::Range& range, + bool nullAllowed, + const 
VeloxExprConverter& exprConverter, + const TypePtr& type) { + bool lowUnbounded = range.low.valueBlock == nullptr; + auto low = lowUnbounded ? std::numeric_limits::min() + : toInt64(range.low.valueBlock, exprConverter, type); + if (!lowUnbounded && range.low.bound == protocol::Bound::ABOVE) { + low++; + } + + bool highUnbounded = range.high.valueBlock == nullptr; + auto high = highUnbounded + ? std::numeric_limits::max() + : toInt64(range.high.valueBlock, exprConverter, type); + if (!highUnbounded && range.high.bound == protocol::Bound::BELOW) { + high--; + } + return std::make_unique(low, high, nullAllowed); +} + +std::unique_ptr hugeintRangeToFilter( + const protocol::Range& range, + bool nullAllowed, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + bool lowUnbounded = range.low.valueBlock == nullptr; + auto low = lowUnbounded ? std::numeric_limits::min() + : toInt128(range.low.valueBlock, exprConverter, type); + if (!lowUnbounded && range.low.bound == protocol::Bound::ABOVE) { + low++; + } + + bool highUnbounded = range.high.valueBlock == nullptr; + auto high = highUnbounded + ? std::numeric_limits::max() + : toInt128(range.high.valueBlock, exprConverter, type); + if (!highUnbounded && range.high.bound == protocol::Bound::BELOW) { + high--; + } + return std::make_unique(low, high, nullAllowed); +} + +std::unique_ptr timestampRangeToFilter( + const protocol::Range& range, + bool nullAllowed, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + const bool lowUnbounded = range.low.valueBlock == nullptr; + auto low = lowUnbounded + ? std::numeric_limits::min() + : toTimestamp(range.low.valueBlock, exprConverter, type); + if (!lowUnbounded && range.low.bound == protocol::Bound::ABOVE) { + ++low; + } + + const bool highUnbounded = range.high.valueBlock == nullptr; + auto high = highUnbounded + ? 
std::numeric_limits::max() + : toTimestamp(range.high.valueBlock, exprConverter, type); + if (!highUnbounded && range.high.bound == protocol::Bound::BELOW) { + --high; + } + return std::make_unique(low, high, nullAllowed); +} + +std::unique_ptr boolRangeToFilter( + const protocol::Range& range, + bool nullAllowed, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; + bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; + bool highExclusive = range.high.bound == protocol::Bound::BELOW; + bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; + + if (!lowUnbounded && !highUnbounded) { + bool lowValue = toBoolean(range.low.valueBlock, exprConverter, type); + bool highValue = toBoolean(range.high.valueBlock, exprConverter, type); + VELOX_CHECK_EQ( + lowValue, + highValue, + "Boolean range should not be [FALSE, TRUE] after coordinator " + "optimization."); + return std::make_unique(lowValue, nullAllowed); + } + // Presto coordinator has made optimizations to the bool range already. For + // example, [FALSE, TRUE) will be optimized and shown here as (-infinity, + // TRUE). Plus (-infinity, +infinity) case has been guarded in toFilter() + // method, here it can only be one side bounded scenarios. + VELOX_CHECK_NE( + lowUnbounded, + highUnbounded, + "Passed in boolean range can only have one side bounded range scenario"); + if (!lowUnbounded) { + VELOX_CHECK( + highUnbounded, + "Boolean range should not be double side bounded after coordinator " + "optimization."); + bool lowValue = toBoolean(range.low.valueBlock, exprConverter, type); + + // (TRUE, +infinity) case, should resolve to filter all + if (lowExclusive && lowValue) { + if (nullAllowed) { + return std::make_unique(); + } + return std::make_unique(); + } + + // Both cases (FALSE, +infinity) or [TRUE, +infinity) should evaluate to + // true. 
Case [FALSE, +infinity) should not be expected + VELOX_CHECK( + !(!lowExclusive && !lowValue), + "Case [FALSE, +infinity) should " + "not be expected"); + return std::make_unique(true, nullAllowed); + } + if (!highUnbounded) { + VELOX_CHECK( + lowUnbounded, + "Boolean range should not be double side bounded after coordinator " + "optimization."); + bool highValue = toBoolean(range.high.valueBlock, exprConverter, type); + + // (-infinity, FALSE) case, should resolve to filter all + if (highExclusive && !highValue) { + if (nullAllowed) { + return std::make_unique(); + } + return std::make_unique(); + } + + // Both cases (-infinity, TRUE) or (-infinity, FALSE] should evaluate to + // false. Case (-infinity, TRUE] should not be expected + VELOX_CHECK( + !(!highExclusive && highValue), + "Case (-infinity, TRUE] should " + "not be expected"); + return std::make_unique(false, nullAllowed); + } + VELOX_UNREACHABLE(); +} + +template +std::unique_ptr floatingPointRangeToFilter( + const protocol::Range& range, + bool nullAllowed, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; + bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; + auto low = lowUnbounded + ? (-1.0 * std::numeric_limits::infinity()) + : toFloatingPoint(range.low.valueBlock, exprConverter, type); + + bool highExclusive = range.high.bound == protocol::Bound::BELOW; + bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; + auto high = highUnbounded + ? std::numeric_limits::infinity() + : toFloatingPoint(range.high.valueBlock, exprConverter, type); + + // Handle NaN cases as NaN is not supported as a limit in Velox Filters + if (!lowUnbounded && std::isnan(low)) { + if (lowExclusive) { + // x > NaN is always false as NaN is considered the largest value. 
+ return std::make_unique(); + } + // Equivalent to x > infinity as only NaN is greater than infinity + // Presto currently converts x >= NaN into the filter with domain + // [NaN, max), so ignoring the high value is fine. + low = std::numeric_limits::infinity(); + lowExclusive = true; + high = std::numeric_limits::infinity(); + highUnbounded = true; + highExclusive = false; + } else if (!highUnbounded && std::isnan(high)) { + high = std::numeric_limits::infinity(); + if (highExclusive) { + // equivalent to x in [low , infinity] or (low , infinity] + highExclusive = false; + } else { + if (lowUnbounded) { + // Anything <= NaN is true as NaN is the largest possible value. + return std::make_unique(); + } + // Equivalent to x > low or x >=low + highUnbounded = true; + } + } + + return std::make_unique>( + low, + lowUnbounded, + lowExclusive, + high, + highUnbounded, + highExclusive, + nullAllowed); +} + +std::unique_ptr varcharRangeToFilter( + const protocol::Range& range, + bool nullAllowed, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; + bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; + auto low = + lowUnbounded ? "" : toString(range.low.valueBlock, exprConverter, type); + + bool highExclusive = range.high.bound == protocol::Bound::BELOW; + bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; + auto high = + highUnbounded ? "" : toString(range.high.valueBlock, exprConverter, type); + return std::make_unique( + low, + lowUnbounded, + lowExclusive, + high, + highUnbounded, + highExclusive, + nullAllowed); +} + +std::unique_ptr dateRangeToFilter( + const protocol::Range& range, + bool nullAllowed, + const VeloxExprConverter& exprConverter, + const TypePtr& type) { + bool lowUnbounded = range.low.valueBlock == nullptr; + auto low = lowUnbounded + ? 
std::numeric_limits::min() + : dateToInt64(range.low.valueBlock, exprConverter, type); + if (!lowUnbounded && range.low.bound == protocol::Bound::ABOVE) { + low++; + } + + bool highUnbounded = range.high.valueBlock == nullptr; + auto high = highUnbounded + ? std::numeric_limits::max() + : dateToInt64(range.high.valueBlock, exprConverter, type); + if (!highUnbounded && range.high.bound == protocol::Bound::BELOW) { + high--; + } + + return std::make_unique(low, high, nullAllowed); +} + +std::unique_ptr toFilter( + const TypePtr& type, + const protocol::Range& range, + bool nullAllowed, + const VeloxExprConverter& exprConverter) { + if (type->isDate()) { + return dateRangeToFilter(range, nullAllowed, exprConverter, type); + } + switch (type->kind()) { + case TypeKind::TINYINT: + case TypeKind::SMALLINT: + case TypeKind::INTEGER: + case TypeKind::BIGINT: + return bigintRangeToFilter(range, nullAllowed, exprConverter, type); + case TypeKind::HUGEINT: + return hugeintRangeToFilter(range, nullAllowed, exprConverter, type); + case TypeKind::DOUBLE: + return floatingPointRangeToFilter( + range, nullAllowed, exprConverter, type); + case TypeKind::VARCHAR: + case TypeKind::VARBINARY: + return varcharRangeToFilter(range, nullAllowed, exprConverter, type); + case TypeKind::BOOLEAN: + return boolRangeToFilter(range, nullAllowed, exprConverter, type); + case TypeKind::REAL: + return floatingPointRangeToFilter( + range, nullAllowed, exprConverter, type); + case TypeKind::TIMESTAMP: + return timestampRangeToFilter(range, nullAllowed, exprConverter, type); + default: + VELOX_UNSUPPORTED("Unsupported range type: {}", type->toString()); + } +} + +std::unique_ptr combineIntegerRanges( + std::vector>& bigintFilters, + bool nullAllowed) { + bool allSingleValue = std::all_of( + bigintFilters.begin(), bigintFilters.end(), [](const auto& range) { + return range->isSingleValue(); + }); + + if (allSingleValue) { + std::vector values; + values.reserve(bigintFilters.size()); + for (const auto& 
filter : bigintFilters) { + values.emplace_back(filter->lower()); + } + return common::createBigintValues(values, nullAllowed); + } + + if (bigintFilters.size() == 2 && + bigintFilters[0]->lower() == std::numeric_limits::min() && + bigintFilters[1]->upper() == std::numeric_limits::max()) { + assert(bigintFilters[0]->upper() + 1 <= bigintFilters[1]->lower() - 1); + return std::make_unique( + bigintFilters[0]->upper() + 1, + bigintFilters[1]->lower() - 1, + nullAllowed); + } + + bool allNegatedValues = true; + bool foundMaximum = false; + assert(bigintFilters.size() > 1); // true by size checks on ranges + std::vector rejectedValues; + + // check if int64 min is a rejected value + if (bigintFilters[0]->lower() == std::numeric_limits::min() + 1) { + rejectedValues.emplace_back(std::numeric_limits::min()); + } + if (bigintFilters[0]->lower() > std::numeric_limits::min() + 1) { + // too many value at the lower end, bail out + return std::make_unique( + std::move(bigintFilters), nullAllowed); + } + rejectedValues.push_back(bigintFilters[0]->upper() + 1); + for (int i = 1; i < bigintFilters.size(); ++i) { + if (bigintFilters[i]->lower() != bigintFilters[i - 1]->upper() + 2) { + allNegatedValues = false; + break; + } + if (bigintFilters[i]->upper() == std::numeric_limits::max()) { + foundMaximum = true; + break; + } + rejectedValues.push_back(bigintFilters[i]->upper() + 1); + // make sure there is another range possible above this one + if (bigintFilters[i]->upper() == std::numeric_limits::max() - 1) { + foundMaximum = true; + break; + } + } + + if (allNegatedValues && foundMaximum) { + return common::createNegatedBigintValues(rejectedValues, nullAllowed); + } + + return std::make_unique( + std::move(bigintFilters), nullAllowed); +} + +std::unique_ptr combineBytesRanges( + std::vector>& bytesFilters, + bool nullAllowed) { + bool allSingleValue = std::all_of( + bytesFilters.begin(), bytesFilters.end(), [](const auto& range) { + return range->isSingleValue(); + }); + + if 
(allSingleValue) { + std::vector values; + values.reserve(bytesFilters.size()); + for (const auto& filter : bytesFilters) { + values.emplace_back(filter->lower()); + } + return std::make_unique(values, nullAllowed); + } + + int lowerUnbounded = 0, upperUnbounded = 0; + bool allExclusive = std::all_of( + bytesFilters.begin(), bytesFilters.end(), [](const auto& range) { + return range->lowerExclusive() && range->upperExclusive(); + }); + if (allExclusive) { + folly::F14FastSet unmatched; + std::vector rejectedValues; + rejectedValues.reserve(bytesFilters.size()); + for (int i = 0; i < bytesFilters.size(); ++i) { + if (bytesFilters[i]->isLowerUnbounded()) { + ++lowerUnbounded; + } else { + if (unmatched.contains(bytesFilters[i]->lower())) { + unmatched.erase(bytesFilters[i]->lower()); + rejectedValues.emplace_back(bytesFilters[i]->lower()); + } else { + unmatched.insert(bytesFilters[i]->lower()); + } + } + if (bytesFilters[i]->isUpperUnbounded()) { + ++upperUnbounded; + } else { + if (unmatched.contains(bytesFilters[i]->upper())) { + unmatched.erase(bytesFilters[i]->upper()); + rejectedValues.emplace_back(bytesFilters[i]->upper()); + } else { + unmatched.insert(bytesFilters[i]->upper()); + } + } + } + + if (lowerUnbounded == 1 && upperUnbounded == 1 && unmatched.size() == 0) { + return std::make_unique( + rejectedValues, nullAllowed); + } + } + + if (bytesFilters.size() == 2 && bytesFilters[0]->isLowerUnbounded() && + bytesFilters[1]->isUpperUnbounded()) { + // create a negated bytes range instead + return std::make_unique( + bytesFilters[0]->upper(), + false, + !bytesFilters[0]->upperExclusive(), + bytesFilters[1]->lower(), + false, + !bytesFilters[1]->lowerExclusive(), + nullAllowed); + } + + std::vector> bytesGeneric; + for (int i = 0; i < bytesFilters.size(); ++i) { + bytesGeneric.emplace_back( + std::unique_ptr( + dynamic_cast(bytesFilters[i].release()))); + } + + return std::make_unique( + std::move(bytesGeneric), nullAllowed, false); +} + +} // namespace + 
+TypePtr stringToType( + const std::string& typeString, + const TypeParser& typeParser) { + return typeParser.parse(typeString); +} + +template <> +TypePtr fieldNamesToLowerCase(const TypePtr& type); + +template <> +TypePtr fieldNamesToLowerCase(const TypePtr& type); + +template <> +TypePtr fieldNamesToLowerCase(const TypePtr& type); + +template +TypePtr fieldNamesToLowerCase(const TypePtr& type) { + return type; +} + +template <> +TypePtr fieldNamesToLowerCase(const TypePtr& type) { + auto& rowType = type->asRow(); + std::vector names; + std::vector types; + names.reserve(type->size()); + types.reserve(type->size()); + for (int i = 0; i < rowType.size(); i++) { + std::string name = rowType.nameOf(i); + folly::toLowerAscii(name); + names.push_back(std::move(name)); + auto& childType = rowType.childAt(i); + types.push_back(VELOX_DYNAMIC_TYPE_DISPATCH( + fieldNamesToLowerCase, childType->kind(), childType)); + } + return std::make_shared(std::move(names), std::move(types)); +} + +template <> +TypePtr fieldNamesToLowerCase(const TypePtr& type) { + auto& keyType = type->childAt(0); + auto& valueType = type->childAt(1); + return std::make_shared( + VELOX_DYNAMIC_TYPE_DISPATCH( + fieldNamesToLowerCase, keyType->kind(), keyType), + VELOX_DYNAMIC_TYPE_DISPATCH( + fieldNamesToLowerCase, valueType->kind(), valueType)); +} + +template <> +TypePtr fieldNamesToLowerCase(const TypePtr& type) { + auto& elementType = type->childAt(0); + return std::make_shared(VELOX_DYNAMIC_TYPE_DISPATCH( + fieldNamesToLowerCase, elementType->kind(), elementType)); +} + +template TypePtr fieldNamesToLowerCase(const TypePtr&); +template TypePtr fieldNamesToLowerCase(const TypePtr&); +template TypePtr fieldNamesToLowerCase(const TypePtr&); +template TypePtr fieldNamesToLowerCase(const TypePtr&); +template TypePtr fieldNamesToLowerCase(const TypePtr&); +template TypePtr fieldNamesToLowerCase(const TypePtr&); +template TypePtr fieldNamesToLowerCase(const TypePtr&); +template TypePtr 
fieldNamesToLowerCase(const TypePtr&); +template TypePtr fieldNamesToLowerCase(const TypePtr&); +template TypePtr fieldNamesToLowerCase(const TypePtr&); +template TypePtr fieldNamesToLowerCase(const TypePtr&); + +std::unique_ptr toFilter( + const protocol::Domain& domain, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) { + auto nullAllowed = domain.nullAllowed; + if (auto sortedRangeSet = + std::dynamic_pointer_cast(domain.values)) { + auto type = stringToType(sortedRangeSet->type, typeParser); + auto ranges = sortedRangeSet->ranges; + + if (ranges.empty()) { + VELOX_CHECK(nullAllowed, "Unexpected always-false filter"); + return std::make_unique(); + } + + if (ranges.size() == 1) { + // 'is not null' arrives as unbounded range with 'nulls not allowed'. + // We catch this case and create 'is not null' filter instead of the range + // filter. + const auto& range = ranges[0]; + bool lowExclusive = range.low.bound == protocol::Bound::ABOVE; + bool lowUnbounded = range.low.valueBlock == nullptr && lowExclusive; + bool highExclusive = range.high.bound == protocol::Bound::BELOW; + bool highUnbounded = range.high.valueBlock == nullptr && highExclusive; + if (lowUnbounded && highUnbounded && !nullAllowed) { + return std::make_unique(); + } + + return toFilter(type, ranges[0], nullAllowed, exprConverter); + } + + if (type->isDate()) { + std::vector> dateFilters; + dateFilters.reserve(ranges.size()); + for (const auto& range : ranges) { + dateFilters.emplace_back( + dateRangeToFilter(range, nullAllowed, exprConverter, type)); + } + return std::make_unique( + std::move(dateFilters), nullAllowed); + } + + if (type->kind() == TypeKind::BIGINT || type->kind() == TypeKind::INTEGER || + type->kind() == TypeKind::SMALLINT || + type->kind() == TypeKind::TINYINT) { + std::vector> bigintFilters; + bigintFilters.reserve(ranges.size()); + for (const auto& range : ranges) { + bigintFilters.emplace_back( + bigintRangeToFilter(range, nullAllowed, exprConverter, 
type)); + } + return combineIntegerRanges(bigintFilters, nullAllowed); + } + + if (type->kind() == TypeKind::VARCHAR) { + std::vector> bytesFilters; + bytesFilters.reserve(ranges.size()); + for (const auto& range : ranges) { + bytesFilters.emplace_back( + varcharRangeToFilter(range, nullAllowed, exprConverter, type)); + } + return combineBytesRanges(bytesFilters, nullAllowed); + } + + if (type->kind() == TypeKind::BOOLEAN) { + VELOX_CHECK_EQ(ranges.size(), 2, "Multi bool ranges size can only be 2."); + std::unique_ptr boolFilter; + for (const auto& range : ranges) { + auto filter = + boolRangeToFilter(range, nullAllowed, exprConverter, type); + if (filter->kind() == common::FilterKind::kAlwaysFalse or + filter->kind() == common::FilterKind::kIsNull) { + continue; + } + VELOX_CHECK_NULL(boolFilter); + boolFilter = std::move(filter); + } + + VELOX_CHECK_NOT_NULL(boolFilter); + return boolFilter; + } + + std::vector> filters; + filters.reserve(ranges.size()); + for (const auto& range : ranges) { + filters.emplace_back(toFilter(type, range, nullAllowed, exprConverter)); + } + + return std::make_unique( + std::move(filters), nullAllowed, false); + } else if ( + auto equatableValueSet = + std::dynamic_pointer_cast( + domain.values)) { + if (equatableValueSet->entries.empty()) { + if (nullAllowed) { + return std::make_unique(); + } else { + return std::make_unique(); + } + } + VELOX_UNSUPPORTED( + "EquatableValueSet (with non-empty entries) to Velox filter conversion is not supported yet."); + } else if ( + auto allOrNoneValueSet = + std::dynamic_pointer_cast( + domain.values)) { + VELOX_UNSUPPORTED( + "AllOrNoneValueSet to Velox filter conversion is not supported yet."); + } + VELOX_UNSUPPORTED("Unsupported filter found."); +} + +std::vector toRequiredSubfields( + const protocol::List& subfields) { + std::vector result; + result.reserve(subfields.size()); + for (auto& subfield : subfields) { + result.emplace_back(subfield); + } + return result; +} + 
+velox::common::CompressionKind toFileCompressionKind( + const protocol::hive::HiveCompressionCodec& hiveCompressionCodec) { + switch (hiveCompressionCodec) { + case protocol::hive::HiveCompressionCodec::SNAPPY: + return velox::common::CompressionKind::CompressionKind_SNAPPY; + case protocol::hive::HiveCompressionCodec::GZIP: + return velox::common::CompressionKind::CompressionKind_GZIP; + case protocol::hive::HiveCompressionCodec::LZ4: + return velox::common::CompressionKind::CompressionKind_LZ4; + case protocol::hive::HiveCompressionCodec::ZSTD: + return velox::common::CompressionKind::CompressionKind_ZSTD; + case protocol::hive::HiveCompressionCodec::NONE: + return velox::common::CompressionKind::CompressionKind_NONE; + default: + VELOX_UNSUPPORTED( + "Unsupported file compression format: {}.", + toJsonString(hiveCompressionCodec)); + } +} + +connector::hive::HiveColumnHandle::ColumnType toHiveColumnType( + protocol::hive::ColumnType type) { + switch (type) { + case protocol::hive::ColumnType::PARTITION_KEY: + return connector::hive::HiveColumnHandle::ColumnType::kPartitionKey; + case protocol::hive::ColumnType::REGULAR: + return connector::hive::HiveColumnHandle::ColumnType::kRegular; + case protocol::hive::ColumnType::SYNTHESIZED: + return connector::hive::HiveColumnHandle::ColumnType::kSynthesized; + default: + VELOX_UNSUPPORTED( + "Unsupported Hive column type: {}.", toJsonString(type)); + } +} + +std::unique_ptr toHiveTableHandle( + const protocol::TupleDomain& domainPredicate, + const std::shared_ptr& remainingPredicate, + bool isPushdownFilterEnabled, + const std::string& tableName, + const protocol::List& dataColumns, + const protocol::TableHandle& tableHandle, + const std::vector& + columnHandles, + const protocol::Map& tableParameters, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser) { + common::SubfieldFilters subfieldFilters; + auto domains = domainPredicate.domains; + for (const auto& domain : *domains) { + auto filter = 
domain.second; + subfieldFilters[common::Subfield(domain.first)] = + toFilter(domain.second, exprConverter, typeParser); + } + + auto remainingFilter = exprConverter.toVeloxExpr(remainingPredicate); + if (auto constant = std::dynamic_pointer_cast( + remainingFilter)) { + bool value = constant->value().value(); + VELOX_CHECK(value, "Unexpected always-false remaining predicate"); + + remainingFilter = nullptr; + } + + RowTypePtr finalDataColumns; + if (!dataColumns.empty()) { + std::vector names; + std::vector types; + velox::type::fbhive::HiveTypeParser hiveTypeParser; + names.reserve(dataColumns.size()); + types.reserve(dataColumns.size()); + for (auto& column : dataColumns) { + std::string name = column.name; + folly::toLowerAscii(name); + names.emplace_back(std::move(name)); + auto parsedType = hiveTypeParser.parse(column.type); + types.push_back(VELOX_DYNAMIC_TYPE_DISPATCH( + fieldNamesToLowerCase, parsedType->kind(), parsedType)); + } + finalDataColumns = ROW(std::move(names), std::move(types)); + } + + std::unordered_map finalTableParameters = {}; + finalTableParameters.reserve(tableParameters.size()); + for (const auto& [key, value] : tableParameters) { + finalTableParameters[key] = value; + } + + return std::make_unique( + tableHandle.connectorId, + tableName, + isPushdownFilterEnabled, + std::move(subfieldFilters), + remainingFilter, + finalDataColumns, + finalTableParameters, + columnHandles); +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.h b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.h new file mode 100644 index 0000000000000..aec1f169ffad1 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/PrestoToVeloxConnectorUtils.h @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.h" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" +#include "velox/connectors/Connector.h" +#include "velox/connectors/hive/TableHandle.h" +#include "velox/dwio/common/Options.h" +#include "velox/type/Filter.h" +#include "velox/type/Type.h" + +namespace facebook::presto { + +velox::TypePtr stringToType( + const std::string& typeString, + const TypeParser& typeParser); + +template +velox::TypePtr fieldNamesToLowerCase(const velox::TypePtr& type); + +std::unique_ptr toFilter( + const protocol::Domain& domain, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser); + +template +std::string toJsonString(const T& value) { + return ((json)value).dump(); +} + +std::vector toRequiredSubfields( + const protocol::List& subfields); + +velox::common::CompressionKind toFileCompressionKind( + const protocol::hive::HiveCompressionCodec& hiveCompressionCodec); + +velox::connector::hive::HiveColumnHandle::ColumnType toHiveColumnType( + protocol::hive::ColumnType type); + +std::unique_ptr toHiveTableHandle( + const protocol::TupleDomain& domainPredicate, + const std::shared_ptr& remainingPredicate, + bool isPushdownFilterEnabled, + const std::string& tableName, + const protocol::List& dataColumns, + const protocol::TableHandle& tableHandle, + const std::vector& + columnHandles, + const protocol::Map& tableParameters, + const VeloxExprConverter& exprConverter, + const TypeParser& 
typeParser); + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp index e5040912b6018..9b689175f7c01 100644 --- a/presto-native-execution/presto_cpp/main/connectors/Registration.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.cpp @@ -12,6 +12,9 @@ * limitations under the License. */ #include "presto_cpp/main/connectors/Registration.h" +#include "presto_cpp/main/connectors/ClpPrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/HivePrestoToVeloxConnector.h" +#include "presto_cpp/main/connectors/IcebergPrestoToVeloxConnector.h" #include "presto_cpp/main/connectors/SystemConnector.h" #ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR @@ -21,7 +24,13 @@ #include "velox/connectors/clp/ClpConnector.h" #include "velox/connectors/hive/HiveConnector.h" +#include "velox/connectors/hive/iceberg/IcebergConnector.h" +#include "velox/connectors/tpcds/TpcdsConnector.h" #include "velox/connectors/tpch/TpchConnector.h" +#ifdef PRESTO_ENABLE_CUDF +#include "velox/experimental/cudf/CudfConfig.h" +#include "velox/experimental/cudf/connectors/hive/CudfHiveConnector.h" +#endif namespace facebook::presto { namespace { @@ -29,61 +38,34 @@ namespace { constexpr char const* kHiveHadoop2ConnectorName = "hive-hadoop2"; constexpr char const* kIcebergConnectorName = "iceberg"; -void registerConnectorFactories() { - // These checks for connector factories can be removed after we remove the - // registrations from the Velox library. 
- if (!velox::connector::hasConnectorFactory( - velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - velox::connector::registerConnectorFactory( - std::make_shared( - kHiveHadoop2ConnectorName)); - } - if (!velox::connector::hasConnectorFactory( - velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - } - - if (!velox::connector::hasConnectorFactory( - velox::connector::clp::ClpConnectorFactory::kClpConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - } - - // Register Velox connector factory for iceberg. - // The iceberg catalog is handled by the hive connector factory. - if (!velox::connector::hasConnectorFactory(kIcebergConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared( - kIcebergConnectorName)); - } +} // namespace -#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR - if (!velox::connector::hasConnectorFactory( - ArrowFlightConnectorFactory::kArrowFlightConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); +std::vector listConnectorFactories() { + std::vector names; + const auto& factories = detail::connectorFactories(); + names.reserve(factories.size()); + for (const auto& [name, _] : factories) { + names.push_back(name); } -#endif + return names; } -} // namespace void registerConnectors() { registerConnectorFactories(); - registerPrestoToVeloxConnector(std::make_unique( - velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)); + registerPrestoToVeloxConnector( + std::make_unique( + velox::connector::hive::HiveConnectorFactory::kHiveConnectorName)); registerPrestoToVeloxConnector( std::make_unique(kHiveHadoop2ConnectorName)); registerPrestoToVeloxConnector( std::make_unique(kIcebergConnectorName)); - registerPrestoToVeloxConnector(std::make_unique( - 
velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)); registerPrestoToVeloxConnector( - std::make_unique( - velox::connector::clp::ClpConnectorFactory::kClpConnectorName)); + std::make_unique( + velox::connector::tpch::TpchConnectorFactory::kTpchConnectorName)); + registerPrestoToVeloxConnector( + std::make_unique( + velox::connector::tpcds::TpcdsConnectorFactory::kTpcdsConnectorName)); // Presto server uses system catalog or system schema in other catalogs // in different places in the code. All these resolve to the SystemConnector. @@ -97,9 +79,71 @@ void registerConnectors() { registerPrestoToVeloxConnector( std::make_unique("$system@system")); + registerPrestoToVeloxConnector(std::make_unique( + velox::connector::clp::ClpConnectorFactory::kClpConnectorName)); + +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR + registerPrestoToVeloxConnector( + std::make_unique( + ArrowFlightConnectorFactory::kArrowFlightConnectorName)); +#endif +} + +void registerConnectorFactories() { + // Register all connector factories using the facebook::presto namespace + // factory registry + + // Register Hive connector factory + facebook::presto::registerConnectorFactory( + std::make_shared< + facebook::velox::connector::hive::HiveConnectorFactory>()); + + // Register Hive Hadoop2 connector factory + facebook::presto::registerConnectorFactory( + std::make_shared( + kHiveHadoop2ConnectorName)); +#ifdef PRESTO_ENABLE_CUDF + facebook::presto::unregisterConnectorFactory( + facebook::velox::connector::hive::HiveConnectorFactory:: + kHiveConnectorName); + facebook::presto::unregisterConnectorFactory(kHiveHadoop2ConnectorName); + + // Register cuDF Hive connector factory + facebook::presto::registerConnectorFactory( + std::make_shared()); + + // Register cudf Hive connector factory + facebook::presto::registerConnectorFactory( + std::make_shared( + kHiveHadoop2ConnectorName)); +#endif + + // Register TPC-DS connector factory + facebook::presto::registerConnectorFactory( + 
std::make_shared< + facebook::velox::connector::tpcds::TpcdsConnectorFactory>()); + + // Register TPCH connector factory + facebook::presto::registerConnectorFactory( + std::make_shared< + facebook::velox::connector::tpch::TpchConnectorFactory>()); + + facebook::presto::registerConnectorFactory( + std::make_shared()); + + facebook::presto::registerConnectorFactory( + std::make_shared< + facebook::velox::connector::clp::ClpConnectorFactory>()); + #ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR - registerPrestoToVeloxConnector(std::make_unique( - ArrowFlightConnectorFactory::kArrowFlightConnectorName)); + // Note: ArrowFlightConnectorFactory would need to be implemented in Presto + // namespace For now, keep the Velox version + facebook::presto::registerConnectorFactory( + std::make_shared()); #endif } + } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/Registration.h b/presto-native-execution/presto_cpp/main/connectors/Registration.h index c95aefaacfcaa..cbff3e83c2bbf 100644 --- a/presto-native-execution/presto_cpp/main/connectors/Registration.h +++ b/presto-native-execution/presto_cpp/main/connectors/Registration.h @@ -13,7 +13,79 @@ */ #pragma once +#include +#include +#include +#include + +#include "folly/Executor.h" +#include "velox/connectors/Connector.h" + +// Forward declaration for ConnectorFactory. +namespace facebook::velox::connector { +class ConnectorFactory; +} // namespace facebook::velox::connector + namespace facebook::presto { +using facebook::velox::connector::ConnectorFactory; + +namespace detail { +inline std::unordered_map>& +connectorFactories() { + static std::unordered_map> + factories; + return factories; +} +} // namespace detail + +/// Adds a factory for creating connectors to the registry using connector +/// name as the key. Throws if factory with the same name is already present. +/// Always returns true. The return value makes it easy to use with +/// FB_ANONYMOUS_VARIABLE. 
+inline bool registerConnectorFactory( + std::shared_ptr factory) { + const bool ok = detail::connectorFactories() + .insert({factory->connectorName(), factory}) + .second; + VELOX_CHECK( + ok, + "ConnectorFactory with name '{}' is already registered", + factory->connectorName()); + return true; +} + +/// Returns true if a connector with the specified name has been registered, +/// false otherwise. +inline bool hasConnectorFactory(const std::string& connectorName) { + return detail::connectorFactories().count(connectorName) == 1; +} + +/// Unregister a connector factory by name. +/// Returns true if a connector with the specified name has been +/// unregistered, false otherwise. +inline bool unregisterConnectorFactory(const std::string& connectorName) { + const auto count = detail::connectorFactories().erase(connectorName); + return count == 1; +} + +/// Returns a factory for creating connectors with the specified name. +/// Throws if factory doesn't exist. +inline std::shared_ptr getConnectorFactory( + const std::string& connectorName) { + auto it = detail::connectorFactories().find(connectorName); + VELOX_CHECK( + it != detail::connectorFactories().end(), + "ConnectorFactory with name '{}' not registered", + connectorName); + return it->second; +} + +/// Returns a list of all registered connector factory names. +std::vector listConnectorFactories(); + +/// Registers all connector factories using the facebook::presto namespace +/// factory registry. 
+void registerConnectorFactories(); void registerConnectors(); diff --git a/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp index eb9fb48196e9d..9251c70f1a257 100644 --- a/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.cpp @@ -28,7 +28,7 @@ static const std::string kTasksTable = "tasks"; } // namespace -const velox::RowTypePtr SystemTableHandle::taskSchema() { +const velox::RowTypePtr SystemTableHandle::taskSchema() const { static std::vector kTaskColumnNames = { "node_id", "task_id", @@ -74,6 +74,7 @@ SystemTableHandle::SystemTableHandle( std::string schemaName, std::string tableName) : ConnectorTableHandle(std::move(connectorId)), + name_(fmt::format("{}.{}", schemaName, tableName)), schemaName_(std::move(schemaName)), tableName_(std::move(tableName)) { VELOX_USER_CHECK_EQ( @@ -89,16 +90,14 @@ std::string SystemTableHandle::toString() const { } SystemDataSource::SystemDataSource( - const std::shared_ptr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const RowTypePtr& outputType, + const connector::ConnectorTableHandlePtr& tableHandle, + const connector::ColumnHandleMap& columnHandles, const TaskManager* taskManager, velox::memory::MemoryPool* FOLLY_NONNULL pool) : taskManager_(taskManager), pool_(pool) { auto systemTableHandle = - std::dynamic_pointer_cast(tableHandle); + std::dynamic_pointer_cast(tableHandle); VELOX_CHECK_NOT_NULL( systemTableHandle, "TableHandle must be an instance of SystemTableHandle"); @@ -112,7 +111,8 @@ SystemDataSource::SystemDataSource( "ColumnHandle is missing for output column '{}'", outputName); - auto handle = std::dynamic_pointer_cast(it->second); + auto handle = + std::dynamic_pointer_cast(it->second); VELOX_CHECK_NOT_NULL( handle, "ColumnHandle must be an 
instance of SystemColumnHandle " @@ -220,25 +220,25 @@ RowVectorPtr SystemDataSource::getTaskResults() { case TaskColumnEnum::kSplits: { auto flat = result->childAt(i)->as>(); - SET_TASK_COLUMN(taskInfo.stats.totalDrivers); + SET_TASK_COLUMN(taskInfo.stats.totalSplits); break; } case TaskColumnEnum::kQueuedSplits: { auto flat = result->childAt(i)->as>(); - SET_TASK_COLUMN(taskInfo.stats.queuedDrivers); + SET_TASK_COLUMN(taskInfo.stats.queuedSplits); break; } case TaskColumnEnum::kRunningSplits: { auto flat = result->childAt(i)->as>(); - SET_TASK_COLUMN(taskInfo.stats.runningDrivers); + SET_TASK_COLUMN(taskInfo.stats.runningSplits); break; } case TaskColumnEnum::kCompletedSplits: { auto flat = result->childAt(i)->as>(); - SET_TASK_COLUMN(taskInfo.stats.completedDrivers); + SET_TASK_COLUMN(taskInfo.stats.completedSplits); break; } @@ -376,10 +376,7 @@ std::unique_ptr SystemPrestoToVeloxConnector::toVeloxTableHandle( const protocol::TableHandle& tableHandle, const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) const { + const TypeParser& typeParser) const { auto systemLayout = std::dynamic_pointer_cast( tableHandle.connectorTableLayout); diff --git a/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h index e7ffd7f2519b6..e36426bb481db 100644 --- a/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/SystemConnector.h @@ -42,17 +42,22 @@ class SystemTableHandle : public velox::connector::ConnectorTableHandle { std::string toString() const override; - const std::string& schemaName() { + const std::string& name() const override { + return name_; + } + + const std::string& schemaName() const { return schemaName_; } - const std::string& tableName() { + const std::string& tableName() const { return tableName_; } - const 
velox::RowTypePtr taskSchema(); + const velox::RowTypePtr taskSchema() const; private: + const std::string name_; const std::string schemaName_; const std::string tableName_; }; @@ -61,11 +66,8 @@ class SystemDataSource : public velox::connector::DataSource { public: SystemDataSource( const velox::RowTypePtr& outputType, - const std::shared_ptr& - tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const velox::connector::ConnectorTableHandlePtr& tableHandle, + const velox::connector::ColumnHandleMap& columnHandles, const TaskManager* taskManager, velox::memory::MemoryPool* pool); @@ -90,7 +92,7 @@ class SystemDataSource : public velox::connector::DataSource { return completedBytes_; } - std::unordered_map runtimeStats() + std::unordered_map getRuntimeStats() override { return {}; } @@ -147,11 +149,8 @@ class SystemConnector : public velox::connector::Connector { std::unique_ptr createDataSource( const velox::RowTypePtr& outputType, - const std::shared_ptr& - tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const velox::connector::ConnectorTableHandlePtr& tableHandle, + const velox::connector::ColumnHandleMap& columnHandles, velox::connector::ConnectorQueryCtx* connectorQueryCtx) override final { VELOX_CHECK(taskManager_); return std::make_unique( @@ -164,9 +163,8 @@ class SystemConnector : public velox::connector::Connector { std::unique_ptr createDataSink( velox::RowTypePtr /*inputType*/, - std::shared_ptr< - velox::connector:: - ConnectorInsertTableHandle> /*connectorInsertTableHandle*/, + velox::connector:: + ConnectorInsertTableHandlePtr /*connectorInsertTableHandle*/, velox::connector::ConnectorQueryCtx* /*connectorQueryCtx*/, velox::connector::CommitStrategy /*commitStrategy*/) override final { VELOX_NYI("SystemConnector does not support data sink."); @@ -193,11 +191,7 @@ class SystemPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr 
toVeloxTableHandle( const protocol::TableHandle& tableHandle, const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) - const final; + const TypeParser& typeParser) const final; std::unique_ptr createConnectorProtocol() const final; diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp index c0900c6c41de4..18d637df02db8 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.cpp @@ -43,4 +43,14 @@ std::optional ArrowFlightConfig::serverSslCertificate() const { config_->get(kServerSslCertificate)); } +std::optional ArrowFlightConfig::clientSslCertificate() const { + return static_cast>( + config_->get(kClientSslCertificate)); +} + +std::optional ArrowFlightConfig::clientSslKey() const { + return static_cast>( + config_->get(kClientSslKey)); +} + } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h index 77ad8e9379cf3..3ccfe06038d5c 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConfig.h @@ -38,6 +38,11 @@ class ArrowFlightConfig { static constexpr const char* kServerSslCertificate = "arrow-flight.server-ssl-certificate"; + static constexpr const char* kClientSslCertificate = + "arrow-flight.client-ssl-certificate"; + + static constexpr const char* kClientSslKey = "arrow-flight.client-ssl-key"; + std::string authenticatorName() const; std::optional defaultServerHostname() const; @@ -50,6 +55,10 @@ class ArrowFlightConfig { 
std::optional serverSslCertificate() const; + std::optional clientSslCertificate() const; + + std::optional clientSslKey() const; + private: const std::shared_ptr config_; }; diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp index 6aacbd339228f..69083e8035fb9 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp @@ -71,13 +71,36 @@ ArrowFlightConnector::initClientOpts( clientOpts->tls_root_certs = cert; } + auto clientCertPath = config->clientSslCertificate(); + if (clientCertPath.has_value()) { + std::ifstream certFile(clientCertPath.value()); + VELOX_CHECK( + certFile.is_open(), + "Could not open client certificate at {}", + clientCertPath.value()); + clientOpts->cert_chain.assign( + (std::istreambuf_iterator(certFile)), + (std::istreambuf_iterator())); + } + + auto clientKeyPath = config->clientSslKey(); + if (clientKeyPath.has_value()) { + std::ifstream keyFile(clientKeyPath.value()); + VELOX_CHECK( + keyFile.is_open(), + "Could not open client key at {}", + clientKeyPath.value()); + clientOpts->private_key.assign( + (std::istreambuf_iterator(keyFile)), + (std::istreambuf_iterator())); + } + return clientOpts; } ArrowFlightDataSource::ArrowFlightDataSource( const velox::RowTypePtr& outputType, - const std::unordered_map>& - columnHandles, + const velox::connector::ColumnHandleMap& columnHandles, std::shared_ptr authenticator, const ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& flightConfig, @@ -103,7 +126,7 @@ ArrowFlightDataSource::ArrowFlightDataSource( columnName); auto handle = - std::dynamic_pointer_cast(it->second); + std::dynamic_pointer_cast(it->second); VELOX_CHECK_NOT_NULL( handle, "handle for column '{}' is not an ArrowFlightColumnHandle", @@ -164,7 
+187,7 @@ std::optional ArrowFlightDataSource::next( auto output = projectOutputColumns(chunk.data); completedRows_ += output->size(); - completedBytes_ += output->inMemoryBytes(); + completedBytes_ += output->estimateFlatSize(); return output; } @@ -195,10 +218,8 @@ velox::RowVectorPtr ArrowFlightDataSource::projectOutputColumns( std::unique_ptr ArrowFlightConnector::createDataSource( const velox::RowTypePtr& outputType, - const std::shared_ptr& tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const velox::connector::ConnectorTableHandlePtr& tableHandle, + const velox::connector::ColumnHandleMap& columnHandles, velox::connector::ConnectorQueryCtx* connectorQueryCtx) { return std::make_unique( outputType, diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h index 92d73a42420b4..3e893164a049a 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h @@ -31,7 +31,14 @@ namespace facebook::presto { class ArrowFlightTableHandle : public velox::connector::ConnectorTableHandle { public: explicit ArrowFlightTableHandle(const std::string& connectorId) - : ConnectorTableHandle(connectorId) {} + : ConnectorTableHandle(connectorId), name_("arrow_flight") {} + + const std::string& name() const override { + return name_; + } + + private: + const std::string name_; }; struct ArrowFlightSplit : public velox::connector::ConnectorSplit { @@ -51,7 +58,7 @@ class ArrowFlightColumnHandle : public velox::connector::ColumnHandle { explicit ArrowFlightColumnHandle(const std::string& columnName) : columnName_(columnName) {} - const std::string& name() { + const std::string& name() const { return columnName_; } @@ -63,9 +70,7 @@ class ArrowFlightDataSource : public 
velox::connector::DataSource { public: ArrowFlightDataSource( const velox::RowTypePtr& outputType, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const velox::connector::ColumnHandleMap& columnHandles, std::shared_ptr authenticator, const velox::connector::ConnectorQueryCtx* connectorQueryCtx, const std::shared_ptr& flightConfig, @@ -92,7 +97,7 @@ class ArrowFlightDataSource : public velox::connector::DataSource { return completedRows_; } - std::unordered_map runtimeStats() + std::unordered_map getRuntimeStats() override { return {}; } @@ -132,16 +137,13 @@ class ArrowFlightConnector : public velox::connector::Connector { std::unique_ptr createDataSource( const velox::RowTypePtr& outputType, - const std::shared_ptr& - tableHandle, - const std::unordered_map< - std::string, - std::shared_ptr>& columnHandles, + const velox::connector::ConnectorTableHandlePtr& tableHandle, + const velox::connector::ColumnHandleMap& columnHandles, velox::connector::ConnectorQueryCtx* connectorQueryCtx) override; std::unique_ptr createDataSink( velox::RowTypePtr inputType, - std::shared_ptr + velox::connector::ConnectorInsertTableHandlePtr connectorInsertTableHandle, velox::connector::ConnectorQueryCtx* connectorQueryCtx, velox::connector::CommitStrategy commitStrategy) override { diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp index 1ac5ab838f1db..bc324d3fc9119 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp @@ -47,10 +47,7 @@ std::unique_ptr ArrowPrestoToVeloxConnector::toVeloxTableHandle( const protocol::TableHandle& tableHandle, const VeloxExprConverter& /*exprConverter*/, - const TypeParser& /*typeParser*/, - std::unordered_map< - 
std::string, - std::shared_ptr>& assignments) const { + const TypeParser& /*typeParser*/) const { return std::make_unique( tableHandle.connectorId); } diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h index fa7ab67b9c0b7..ba2486e094c2f 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h @@ -34,11 +34,7 @@ class ArrowPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxTableHandle( const protocol::TableHandle& tableHandle, const VeloxExprConverter& exprConverter, - const TypeParser& typeParser, - std::unordered_map< - std::string, - std::shared_ptr>& assignments) - const final; + const TypeParser& typeParser) const final; std::unique_ptr createConnectorProtocol() const final; diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt index ed3e60f6be707..b4dc9a1829df1 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt @@ -18,16 +18,22 @@ add_library(presto_flight_connector_utils INTERFACE Macros.h) target_link_libraries(presto_flight_connector_utils INTERFACE velox_exception) add_library( - presto_flight_connector OBJECT - ArrowFlightConnector.cpp ArrowPrestoToVeloxConnector.cpp - ArrowFlightConfig.cpp) + presto_flight_connector + OBJECT + ArrowFlightConnector.cpp + ArrowPrestoToVeloxConnector.cpp + ArrowFlightConfig.cpp +) -target_compile_definitions(presto_flight_connector - PUBLIC PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) +target_compile_definitions(presto_flight_connector PUBLIC 
PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) target_link_libraries( - presto_flight_connector velox_connector ArrowFlight::arrow_flight_shared - presto_flight_connector_utils presto_flight_connector_auth) + presto_flight_connector + velox_connector + ArrowFlight::arrow_flight_shared + presto_flight_connector_utils + presto_flight_connector_auth +) if(PRESTO_ENABLE_TESTING) add_subdirectory(tests) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt index 1e7eba3154a0e..5e9564361ae47 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt @@ -11,5 +11,4 @@ # limitations under the License. add_library(presto_flight_connector_auth Authenticator.cpp) -target_link_libraries(presto_flight_connector_auth - presto_flight_connector_utils velox_exception) +target_link_libraries(presto_flight_connector_auth presto_flight_connector_utils velox_exception) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp index eb946f1fcae76..8ec912b417650 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConfigTest.cpp @@ -27,6 +27,8 @@ TEST(ArrowFlightConfigTest, defaultConfig) { ASSERT_EQ(config.defaultServerSslEnabled(), false); ASSERT_EQ(config.serverVerify(), true); ASSERT_EQ(config.serverSslCertificate(), std::nullopt); + ASSERT_EQ(config.clientSslCertificate(), std::nullopt); + ASSERT_EQ(config.clientSslKey(), std::nullopt); } TEST(ArrowFlightConfigTest, overrideConfig) { @@ -36,7 +38,9 @@ TEST(ArrowFlightConfigTest, overrideConfig) { 
{ArrowFlightConfig::kDefaultServerPort, "9000"}, {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, {ArrowFlightConfig::kServerVerify, "false"}, - {ArrowFlightConfig::kServerSslCertificate, "my-cert.crt"}}; + {ArrowFlightConfig::kServerSslCertificate, "my-cert.crt"}, + {ArrowFlightConfig::kClientSslCertificate, "/path/to/client.crt"}, + {ArrowFlightConfig::kClientSslKey, "/path/to/client.key"}}; auto config = ArrowFlightConfig( std::make_shared(std::move(configMap))); ASSERT_EQ(config.authenticatorName(), "my-authenticator"); @@ -45,4 +49,6 @@ TEST(ArrowFlightConfigTest, overrideConfig) { ASSERT_EQ(config.defaultServerSslEnabled(), true); ASSERT_EQ(config.serverVerify(), false); ASSERT_EQ(config.serverSslCertificate(), "my-cert.crt"); + ASSERT_EQ(config.clientSslCertificate(), "/path/to/client.crt"); + ASSERT_EQ(config.clientSslKey(), "/path/to/client.key"); } diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp index 257497caf224d..f1a489600b825 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp @@ -77,14 +77,16 @@ TEST_F(ArrowFlightConnectorDataTypeTest, integerTypes) { auto bigintVec = makeFlatVector(bigData); core::PlanNodePtr plan; - plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, - {velox::TINYINT(), - velox::SMALLINT(), - velox::INTEGER(), - velox::BIGINT()})) - .planNode(); + plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW( + {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, + {velox::TINYINT(), + velox::SMALLINT(), + velox::INTEGER(), + velox::BIGINT()})) + .planNode(); 
AssertQueryBuilder(plan) .splits(makeSplits({"sample-data"})) @@ -115,10 +117,12 @@ TEST_F(ArrowFlightConnectorDataTypeTest, realType) { auto doubleVec = makeFlatVector(doubleData); core::PlanNodePtr plan; - plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"real_col", "double_col"}, {velox::REAL(), velox::DOUBLE()})) - .planNode(); + plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW( + {"real_col", "double_col"}, {velox::REAL(), velox::DOUBLE()})) + .planNode(); AssertQueryBuilder(plan) .splits(makeSplits({"sample-data"})) @@ -243,9 +247,12 @@ TEST_F(ArrowFlightConnectorDataTypeTest, timestampType) { core::PlanNodePtr plan; plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"timestampsec_col", "timestampmilli_col", "timestampmicro_col"}, - {velox::TIMESTAMP(), velox::TIMESTAMP(), velox::TIMESTAMP()})) + .flightTableScan( + velox::ROW( + {"timestampsec_col", + "timestampmilli_col", + "timestampmicro_col"}, + {velox::TIMESTAMP(), velox::TIMESTAMP(), velox::TIMESTAMP()})) .planNode(); AssertQueryBuilder(plan) @@ -301,11 +308,13 @@ TEST_F(ArrowFlightConnectorDataTypeTest, decimalType) { auto decimalVecBigInt = makeFlatVector(decimalValuesBigInt); core::PlanNodePtr plan; - plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"decimal_col_bigint"}, - {velox::DECIMAL(18, 2)})) // precision can't be 0 and < scale - .planNode(); + plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW( + {"decimal_col_bigint"}, + {velox::DECIMAL(18, 2)})) // precision can't be 0 and < scale + .planNode(); // Execute the query and assert the results AssertQueryBuilder(plan) @@ -338,10 +347,11 @@ TEST_F(ArrowFlightConnectorDataTypeTest, arrayType) { updateTable("sample-data", makeArrowTable({"int_array_col"}, {listArray})); core::PlanNodePtr plan; - plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"int_array_col"}, {velox::ARRAY(velox::INTEGER())})) - .planNode(); + plan = + ArrowFlightPlanBuilder() + 
.flightTableScan( + velox::ROW({"int_array_col"}, {velox::ARRAY(velox::INTEGER())})) + .planNode(); auto expectedData = makeNullableArrayVector(data); AssertQueryBuilder(plan) @@ -396,10 +406,12 @@ TEST_F(ArrowFlightConnectorDataTypeTest, mapType) { updateTable("sample-data", makeArrowTable({"map_col"}, {mapArray})); core::PlanNodePtr plan; - plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"map_col"}, {velox::MAP(velox::INTEGER(), velox::BIGINT())})) - .planNode(); + plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW( + {"map_col"}, {velox::MAP(velox::INTEGER(), velox::BIGINT())})) + .planNode(); auto expectedData = makeNullableMapVector(data); AssertQueryBuilder(plan) @@ -422,13 +434,15 @@ TEST_F(ArrowFlightConnectorDataTypeTest, rowType) { updateTable("sample-data", makeArrowTable({"row_col"}, {structArray})); core::PlanNodePtr plan; - plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"row_col"}, - {velox::ROW( - {"int_col", "varchar_col", "double_col"}, - {velox::INTEGER(), velox::VARCHAR(), velox::DOUBLE()})})) - .planNode(); + plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW( + {"row_col"}, + {velox::ROW( + {"int_col", "varchar_col", "double_col"}, + {velox::INTEGER(), velox::VARCHAR(), velox::DOUBLE()})})) + .planNode(); auto expectedData = makeRowVector( {makeFlatVector(intData), @@ -482,19 +496,20 @@ TEST_F(ArrowFlightConnectorDataTypeTest, allTypes) { core::PlanNodePtr plan; plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"daydate_col", - "timestamp_col", - "varchar_col", - "real_col", - "int_col", - "bool_col"}, - {velox::DATE(), - velox::TIMESTAMP(), - velox::VARCHAR(), - velox::DOUBLE(), - velox::INTEGER(), - velox::BOOLEAN()})) + .flightTableScan( + velox::ROW( + {"daydate_col", + "timestamp_col", + "varchar_col", + "real_col", + "int_col", + "bool_col"}, + {velox::DATE(), + velox::TIMESTAMP(), + velox::VARCHAR(), + velox::DOUBLE(), + velox::INTEGER(), + 
velox::BOOLEAN()})) .planNode(); AssertQueryBuilder(plan) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorMTlsTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorMTlsTest.cpp new file mode 100644 index 0000000000000..9f17a2f1552cd --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorMTlsTest.cpp @@ -0,0 +1,168 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" + +using namespace arrow; +using namespace facebook::velox; +using namespace facebook::velox::exec::test; + +namespace facebook::presto::test { + +class ArrowFlightConnectorMtlsTestBase : public ArrowFlightConnectorTestBase { + protected: + explicit ArrowFlightConnectorMtlsTestBase( + std::shared_ptr config) + : ArrowFlightConnectorTestBase(std::move(config)) {} + + void setFlightServerOptions( + flight::FlightServerOptions* serverOptions) override { + flight::CertKeyPair tlsCertificate{ + .pem_cert = readFile("./data/certs/server.crt"), + .pem_key = readFile("./data/certs/server.key")}; + serverOptions->tls_certificates.push_back(tlsCertificate); + serverOptions->verify_client = true; + serverOptions->root_certificates = readFile("./data/certs/ca.crt"); + } + + void executeSuccessfulQuery() { + std::vector idData = { + 1, 12, 2, std::numeric_limits::max()}; + + updateTable( + "sample-data", + makeArrowTable({"id"}, {makeNumericArray(idData)})); + + auto idVec = makeFlatVector(idData); + + auto plan = ArrowFlightPlanBuilder() + .flightTableScan(ROW({"id"}, {BIGINT()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec})); + } + + std::function createQueryFunction() { + std::vector idData = { + 1, 12, 2, std::numeric_limits::max()}; + + updateTable( + "sample-data", + makeArrowTable({"id"}, {makeNumericArray(idData)})); + + auto idVec = makeFlatVector(idData); + + auto plan 
= ArrowFlightPlanBuilder() + .flightTableScan(ROW({"id"}, {BIGINT()})) + .planNode(); + + return [this, plan, idVec]() { + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec})); + }; + } +}; + +class ArrowFlightConnectorMtlsTest : public ArrowFlightConnectorMtlsTestBase { + protected: + explicit ArrowFlightConnectorMtlsTest() + : ArrowFlightConnectorMtlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, + {ArrowFlightConfig::kServerVerify, "true"}, + {ArrowFlightConfig::kServerSslCertificate, + "./data/certs/ca.crt"}, + {ArrowFlightConfig::kClientSslCertificate, + "./data/certs/client.crt"}, + {ArrowFlightConfig::kClientSslKey, + "./data/certs/client.key"}})) {} +}; + +TEST_F(ArrowFlightConnectorMtlsTest, successfulMtlsConnection) { + executeSuccessfulQuery(); +} + +class ArrowFlightMtlsNoClientCertTest + : public ArrowFlightConnectorMtlsTestBase { + protected: + ArrowFlightMtlsNoClientCertTest() + : ArrowFlightConnectorMtlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, + {ArrowFlightConfig::kServerVerify, "true"}, + {ArrowFlightConfig::kServerSslCertificate, + "./data/certs/ca.crt"}})) {} +}; + +TEST_F(ArrowFlightMtlsNoClientCertTest, mtlsFailsWithoutClientCert) { + auto queryFunction = createQueryFunction(); + VELOX_ASSERT_THROW(queryFunction(), "failed to connect"); +} + +class ArrowFlightConnectorImplicitSslTest + : public ArrowFlightConnectorMtlsTestBase { + protected: + ArrowFlightConnectorImplicitSslTest() + : ArrowFlightConnectorMtlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kServerVerify, "true"}, + {ArrowFlightConfig::kServerSslCertificate, + "./data/certs/ca.crt"}, + {ArrowFlightConfig::kClientSslCertificate, + "./data/certs/client.crt"}, + {ArrowFlightConfig::kClientSslKey, + "./data/certs/client.key"}})) {} +}; + 
+TEST_F(ArrowFlightConnectorImplicitSslTest, successfulImplicitSslConnection) { + executeSuccessfulQuery(); +} + +class ArrowFlightImplicitSslNoClientCertTest + : public ArrowFlightConnectorMtlsTestBase { + protected: + ArrowFlightImplicitSslNoClientCertTest() + : ArrowFlightConnectorMtlsTestBase( + std::make_shared( + std::unordered_map{ + {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, + {ArrowFlightConfig::kServerVerify, "true"}, + {ArrowFlightConfig::kServerSslCertificate, + "./data/certs/ca.crt"}})) {} +}; + +TEST_F( + ArrowFlightImplicitSslNoClientCertTest, + mtlsFailsWithoutClientCertOnImplicitSsl) { + auto queryFunction = createQueryFunction(); + VELOX_ASSERT_THROW(queryFunction(), "failed to connect"); +} + +} // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp index acb41c5087ea1..a8f9cb36a10a8 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp @@ -89,10 +89,11 @@ TEST_F(ArrowFlightConnectorTest, dataSource) { core::PlanNodePtr plan; // direct test - plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) - .planNode(); + plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW({"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); AssertQueryBuilder(plan) .splits(makeSplits({"sample-data"})) @@ -118,8 +119,9 @@ TEST_F(ArrowFlightConnectorTest, dataSource) { // invalid columnHandle test plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"ducks", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .flightTableScan( + velox::ROW( + {"ducks", "value"}, {velox::BIGINT(), velox::INTEGER()})) 
.planNode(); VELOX_ASSERT_THROW( @@ -156,9 +158,10 @@ TEST_F(ArrowFlightConnectorTest, multipleBatches) { core::PlanNodePtr plan; plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"int_col", "varchar_col", "double_col"}, - {velox::INTEGER(), velox::VARCHAR(), velox::DOUBLE()})) + .flightTableScan( + velox::ROW( + {"int_col", "varchar_col", "double_col"}, + {velox::INTEGER(), velox::VARCHAR(), velox::DOUBLE()})) .planNode(); auto intVec = makeFlatVector(intData); @@ -204,9 +207,10 @@ TEST_F(ArrowFlightConnectorTest, multipleSplits) { core::PlanNodePtr plan; plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"int_col", "varchar_col", "double_col"}, - {velox::INTEGER(), velox::VARCHAR(), velox::DOUBLE()})) + .flightTableScan( + velox::ROW( + {"int_col", "varchar_col", "double_col"}, + {velox::INTEGER(), velox::VARCHAR(), velox::DOUBLE()})) .planNode(); std::vector intDataAll(intData); @@ -259,10 +263,11 @@ TEST_F(ArrowFlightConnectorTestDefaultServer, dataSource) { core::PlanNodePtr plan; // direct test - plan = ArrowFlightPlanBuilder() - .flightTableScan(velox::ROW( - {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) - .planNode(); + plan = + ArrowFlightPlanBuilder() + .flightTableScan( + velox::ROW({"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); AssertQueryBuilder(plan) .splits(makeSplits({"sample-data"})) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp index 4453183a39412..0be7df437e5e4 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp @@ -39,8 +39,8 @@ class ArrowFlightConnectorTlsTestBase : public ArrowFlightConnectorTestBase { void setFlightServerOptions( 
flight::FlightServerOptions* serverOptions) override { flight::CertKeyPair tlsCertificate{ - .pem_cert = readFile("./data/tls_certs/server.crt"), - .pem_key = readFile("./data/tls_certs/server.key")}; + .pem_cert = readFile("./data/certs/server.crt"), + .pem_key = readFile("./data/certs/server.key")}; serverOptions->tls_certificates.push_back(tlsCertificate); } @@ -83,7 +83,7 @@ class ArrowFlightConnectorTlsTest : public ArrowFlightConnectorTlsTestBase { {ArrowFlightConfig::kDefaultServerSslEnabled, "true"}, {ArrowFlightConfig::kServerVerify, "true"}, {ArrowFlightConfig::kServerSslCertificate, - "./data/tls_certs/ca.crt"}})) {} + "./data/certs/ca.crt"}})) {} }; TEST_F(ArrowFlightConnectorTlsTest, tlsEnabled) { @@ -116,7 +116,7 @@ class ArrowFlightTlsNoCertTest : public ArrowFlightConnectorTlsTestBase { }; TEST_F(ArrowFlightTlsNoCertTest, tlsNoCert) { - executeTest(false, "handshake failed"); + executeTest(false, "failed to connect"); } } // namespace facebook::presto::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt index 9af596a913973..264aa6438a455 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt @@ -11,28 +11,39 @@ # limitations under the License. 
add_subdirectory(utils) -add_executable(presto_flight_connector_infra_test - TestingArrowFlightServerTest.cpp) +add_executable(presto_flight_connector_infra_test TestingArrowFlightServerTest.cpp) add_test(presto_flight_connector_infra_test presto_flight_connector_infra_test) target_link_libraries( - presto_flight_connector_infra_test presto_protocol - presto_flight_connector_test_lib GTest::gtest GTest::gtest_main ${GLOG}) + presto_flight_connector_infra_test + presto_protocol + presto_flight_connector_test_lib + GTest::gtest + GTest::gtest_main + ${GLOG} +) add_executable( presto_flight_connector_test - ArrowFlightConnectorTest.cpp ArrowFlightConnectorAuthTest.cpp - ArrowFlightConnectorTlsTest.cpp ArrowFlightConnectorDataTypeTest.cpp - ArrowFlightConfigTest.cpp) - -set(DATA_DIR "${CMAKE_CURRENT_SOURCE_DIR}/data/tls_certs") + ArrowFlightConnectorTest.cpp + ArrowFlightConnectorAuthTest.cpp + ArrowFlightConnectorMTlsTest.cpp + ArrowFlightConnectorTlsTest.cpp + ArrowFlightConnectorDataTypeTest.cpp + ArrowFlightConfigTest.cpp +) add_custom_target( - copy_flight_test_data ALL - COMMAND ${CMAKE_COMMAND} -E copy_directory ${DATA_DIR} - $/data/tls_certs) - + copy_flight_test_data + ALL + COMMAND + ${CMAKE_COMMAND} -E copy_directory_if_different "${CMAKE_CURRENT_SOURCE_DIR}/data" + "${CMAKE_CURRENT_BINARY_DIR}/data" + COMMENT "Copying test data files..." 
+) + +add_dependencies(presto_flight_connector_test copy_flight_test_data) add_test(presto_flight_connector_test presto_flight_connector_test) target_link_libraries( @@ -42,4 +53,5 @@ target_link_libraries( gtest gtest_main presto_flight_connector_test_lib - presto_protocol) + presto_protocol +) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md index 3a5f2e5786c67..17e778a84750c 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md @@ -1,7 +1,7 @@ -### Placeholder TLS Certificates for Arrow Flight Connector Unit Testing -The `tls_certs` directory contains placeholder TLS certificates generated for unit testing the Arrow Flight Connector with TLS enabled. These certificates are not intended for production use and should only be used in the context of unit tests. +### Placeholder TLS & mTLS Certificates for Arrow Flight Connector Unit Testing +The `certs/` directory contains placeholder certificates used for unit and integration testing of the Arrow Flight Connector with **TLS** and **mutual TLS (mTLS)** enabled. These certificates are **not intended for production use** and should only be used for local development or automated testing scenarios. 
-### Generating TLS Certificates -To create the TLS certificates and keys inside the `tls_certs` folder, run the following command: +### Generating Certificates for TLS and mTLS Testing +To generate the necessary certificates and keys for both TLS and mTLS testing, run the following script: -`./generate_tls_certs.sh` +`./generate_certs.sh` diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/ca.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/ca.crt new file mode 100644 index 0000000000000..bc6036c52aa12 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDmzCCAoOgAwIBAgIUVB6MDVAXGLccJY5XrssVNLKbH/gwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI1MDczMDE2NTU1M1oYDzMwMjQxMTMwMTY1NTUzWjBcMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxDTALBgNVBAMMBE15Q0EwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCQ5D8+R5/HesO9BrgWF92wGOesBcOPQL7Nemzh1qYT +fhxNV7HGN1W5fByWsVelZ9326CTa7Yr4TlRXC8GHt9YbRvU68LBU1kAGUqGmPdHY +LgPZJkQqFhYdPdD++Gw2+ecfYp+Ls9pZm/pNOeCWBV0F3RaoPZ19m/C/lDdZz8OX +V6t2to+Yh0GXJyyJfO+w7qG6B/j8UiRYtnnMq3ywTcWqsUmYp4+uwmkEEN9/eSo9 +JanjpEiv6o9Yb9J5StXRPmbAoXOl45o87A6qo0vzgYdP6uKkPQxo6wSdltb5qctM +CKHm4bYFT6IHoVzUGVHiq2iRna87OiDPqnFaX3ktWJB9AgMBAAGjUzBRMB0GA1Ud +DgQWBBRGccIvJxNjWYY8hGZ3HIlut0p17jAfBgNVHSMEGDAWgBRGccIvJxNjWYY8 +hGZ3HIlut0p17jAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAI +sUZOfL8ZJEu0Oq7AHgOmxMYhdRQmFr45C54fiNb6QLs5bPJZ9T1/nbPQZvF1hkFh +wB3pDn8rNntiYXkCkrQ6PAiQfn0WHl4jLYCoGYxFbSP1QViZNid7dPmpaxccjMhL +Zk7htfCS1HtHWWBZPMDDA8hsUvBf4qusVonO71XGL22Z2ZKtgvDJYAyoxm7xwIo2 +mqSH9TfOnHYE0hUpo3u4PdmVAfCXzSDRccLALnVlzt50ColmAQgzj3MnwWfXJmdv +kjBhIZ9Obt7Hf6FcBhX9/qQN1t3u6mLIjf2akokRFblcW5Xgv4X1c7pEtdtM9V+o 
+MVqiDfm49mpMcOdHDeb1 +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/client.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/client.crt new file mode 100644 index 0000000000000..7313fbf6babc6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/client.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDmTCCAoGgAwIBAgIUHZmKZs2+ejbJGEezgoSUTPKF42YwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI1MDczMDE2NTU1M1oYDzMwMjQxMTMwMTY1NTUzWjBrMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxHDAaBgNVBAMME2NsaWVudC5teWRvbWFpbi5jb20w +ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDESuInkBMIwJ02rNuO1DPN +3BfuNzF2SJq5/UgH9soWxVIYqLOwinBmWRsrRVOps9mJjN/GYbj27GT1LFXZHIwh +MYf5JLE/Y8CL/JwnMyxe3wkiBb1s2d+urOrIPMLQTVLGtClyPVT1PnmOd/h7vJ+k +I5ciK9G+8krAq5NoWM/OvhdTK2pOns1jl6Lq2c4IGDgQHt65uKRqPTPtc1xe/Vxd +4sTT8LJQEHu6zdl5nGFdGq4eIJpr+9LuM8ZCYkTAYpzHJ2XKP7jEj6ZRmsIgwwTf +VMON+yAF06Tmk9VCRkLUGjNMA4Xh24MejieE0JSq7mIAhSYWE+nalG8HsxtsQWDt +AgMBAAGjQjBAMB0GA1UdDgQWBBRPc2AWGMwuJEbnnMXvzgNAizik1jAfBgNVHSME +GDAWgBRGccIvJxNjWYY8hGZ3HIlut0p17jANBgkqhkiG9w0BAQsFAAOCAQEACJde +pXZvE04uw6tv+iGplmYNfasMQ9JXbvi0JMlnt9Y7ajf0F3g6yw9xWQfp8mCWVuof +aS/X8qw3loundeprxVq/2V6pFXStLFXJCXX+YL0Wl8AMv8VOxdZ8+hYlkfMoiKnx +ZKaWgrVtI/idFRUJLg8aHLRk5qVOwACBg9DAxMBC4V4MCQBfvDDdY6Y5qAM8o7PN +YGsPv5JIQI/3jsG2ZNQ/A8Ar+BNKWqnwRg2jjXysjPJPaU8TExvFXUypQaUn+a8X +/Y+CXhGabfbodEbEvBny6tiFQMb3YWhd0kHkjrYylY8GOpQ/ziEC8s7JYcaLSiu/ +ko1ErmRY/vhQIvrylw== +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/client.key b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/client.key new file mode 100644 index 
0000000000000..2c285f776ba2d --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/client.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDESuInkBMIwJ02 +rNuO1DPN3BfuNzF2SJq5/UgH9soWxVIYqLOwinBmWRsrRVOps9mJjN/GYbj27GT1 +LFXZHIwhMYf5JLE/Y8CL/JwnMyxe3wkiBb1s2d+urOrIPMLQTVLGtClyPVT1PnmO +d/h7vJ+kI5ciK9G+8krAq5NoWM/OvhdTK2pOns1jl6Lq2c4IGDgQHt65uKRqPTPt +c1xe/Vxd4sTT8LJQEHu6zdl5nGFdGq4eIJpr+9LuM8ZCYkTAYpzHJ2XKP7jEj6ZR +msIgwwTfVMON+yAF06Tmk9VCRkLUGjNMA4Xh24MejieE0JSq7mIAhSYWE+nalG8H +sxtsQWDtAgMBAAECggEAKjnqszSS56nZ2Bpw4+Wh3EnZywjMBuRBBrwmC/KK0EGz +8rKN7y8k1VubVOBlykayiBzKQciha9r4L+bQ8/LocTaQx+edCqQolmSp6ePgCmuj +8RH3iSxIanDv09IAXaOYqD63AMiRV22QZDXKOkIePIbcewEerps8Ofze6c5bK9/H +jWWScpphiJcTVSNKqTzVMJ27DcpTn/bQ+nPYQNdyMj51d1Vaebesmv1ieQfvdMTA +FFTCYUfNrWWjB9CEhZAg3sgr3dOVVTl8LC8VF6ug5EvNeuW2l8FX+fgVaZWKzGpf +veItiHIeMAEw6ZQPTumSvIO2fhg5Iue9fu4hR08twQKBgQDuqsRTJNG8DhR2CZrH +KWeW8ac7pOiHSHxqpERfwkxtpUEsLXib8foPRJYdEzvDDgL3SxaRmTnwtyZvhZgl +RE7ebRvAHkHy+nvim6wBgFYe2YwTTm9d3gCRxOHknEXfXSxzbwJznDlZrHlJ9XDQ +OFBgLbhheF694y1l605E/8s/QQKBgQDSjEztXmrzlQvXKhOp+eL6o+jRrfrv5hmC +ORy063ecVD/bA3OsgOKd5Q0v7dEyCC5IesQA0uY4GZ7zjsxcA4tUp20fZoTEeMpZ +DznTz5Lqkpa8DCeMLj6fPWKnb4zZeduvBE4BRzS4Mdpzp/OrLAGgE2hwWNIB56uq +pyloAkEirQKBgQDPZjg7JFjaOcYQGSKWheWOJyszSogC37u2lE8Sg+8UrTGoaU9Q +/QNXdzuXwpoBU9DCA092cRgHlbDh4s8nO2fqJBikZ+bZdlBnyO29VEACiPwP3u4q +PPxzsAq5NhAGHZq+KS6RNqYjxhyUZ6SEXRuDqNd8ZDS4gI137vZSQZLmwQKBgDcs +mQQjF/fY+Q9bcWe7miWASoSYCQhQziJ4APPQOLn4wfsMvoVYCQrDeV8z/PwVdLt9 +oFtu6PGOlT7SDu+V5i866Levz98EoFISUV8WKDPcUi/ZJ4vumm50UaP68XgUHOOS +RzbCiCg0uEBSpOIYWBywuU+nlvD02uGPiKQ+4v7JAoGAdyREhzfyz+WO5wdHLRIB +k9cz448Lp/Uh2pcgAkFvtBwCAGZn86YJZ7JFD/HkQbw7amVEn+nxoCeWHkbSnzt0 +8Gu0hdqo6SaxOdf1wKBel79r/GH4ZwYRzWJoO/Iy//JJJO7g9GylAo0rBdQR9n60 +ZPMIZJLfKmGQWahl2EqyMG8= +-----END PRIVATE KEY----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/server.crt 
b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/server.crt new file mode 100644 index 0000000000000..bdae53b6bd114 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/server.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIIDxDCCAqygAwIBAgIUHZmKZs2+ejbJGEezgoSUTPKF42UwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI1MDczMDE2NTU1M1oYDzMwMjQxMTMwMTY1NTUzWjBrMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxHDAaBgNVBAMME3NlcnZlci5teWRvbWFpbi5jb20w +ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDS9+yttKuydwcc4te8CtJS +yldr163zuSemazYUDY07ZMc/VFfSISYBX3yrW8uveR4zt8H8XhdkPHV6t3K3GxcY +28CpUXr9TkCH2xMd4xn2DuewaTuf+yCcM6TLFh00nyqcYPhZGn3Bd5jGXKxI8PsH +butswFKM/t9VTRtTpBgMsw8SEy0vyvsJyPTHm8aWY7tCDSI+vI8bHcG9sO8cuin8 +0JcZ4rRTgZpmDDlcY0OniRwowB5ph3eU0uaHIX+EWuht3+1trjyrSFuQ6y82f/0N +MchcA3vTG0qF4ulcZn0yng1wuC1YorInYiDUahxDXHvmsoYz+0yOkR5fVReHLfdp +AgMBAAGjbTBrMCkGA1UdEQQiMCCCE3NlcnZlci5teWRvbWFpbi5jb22CCWxvY2Fs +aG9zdDAdBgNVHQ4EFgQUuxBCAwXgk1cEC3myPm4Q/b8+oWMwHwYDVR0jBBgwFoAU +RnHCLycTY1mGPIRmdxyJbrdKde4wDQYJKoZIhvcNAQELBQADggEBABlwupm4YIBd +dPRoX6S/Ta8oAvkz+gv/s9vIekG7fTzfGDmh9hQEtVU4OVT0ifVAX7x5bf0v7KKp +jXDctZqyttXAtu6e1nQBE+MC9MXi68YU2hSApJF+z7WTPXEbcuIOQPpvXfiD7s2j +455tMF90iKWjcFGdgB1usiQeNeNDjBwcvkGJhLhKO6UqsHLh2BBUTPhDnpyBkt1y +zsZk3YDKr2ipaIj3OvBwueJNL8I3oU04eZMAFeSYbnJ49GdSDLTxglACoOkZrbju +4qwV1bbIB97AtSdhcg3OxBZTAgVep8kOl3N/spNfgq4N/jTdsl+hoEKlChH/7kI3 +oYZfv+nc6s0= +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/server.key b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/server.key new file mode 100644 index 0000000000000..1906728090bb7 --- /dev/null +++ 
b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDS9+yttKuydwcc +4te8CtJSyldr163zuSemazYUDY07ZMc/VFfSISYBX3yrW8uveR4zt8H8XhdkPHV6 +t3K3GxcY28CpUXr9TkCH2xMd4xn2DuewaTuf+yCcM6TLFh00nyqcYPhZGn3Bd5jG +XKxI8PsHbutswFKM/t9VTRtTpBgMsw8SEy0vyvsJyPTHm8aWY7tCDSI+vI8bHcG9 +sO8cuin80JcZ4rRTgZpmDDlcY0OniRwowB5ph3eU0uaHIX+EWuht3+1trjyrSFuQ +6y82f/0NMchcA3vTG0qF4ulcZn0yng1wuC1YorInYiDUahxDXHvmsoYz+0yOkR5f +VReHLfdpAgMBAAECggEAGOjA/zmH1EiNhHGcO02jy7asX8VVeqNv9QxPlEqNVGfv +xqB0xhC35g2aMLlj8VIBqOWXd+68IE+rJ1QlrUz7iynXM6a1ONdWczQAq9S2qgDU +hlXGfnsuPIM0f+4agK4SX+hrKkogcwll9nXWub4KRbRpA6wpkxA82luCUHvdgxIi +PW37HHYWUKoAu828PpG1wGUf6wDFnSEhEuYfgHGkcWyhpRw5QZP563QTMCKWPjcW +Nm+XJ10gpAB4Q+zcqxji7r+5uIbF9Zobkd9VuaWHavQfrgHpjfLnZjHr3r3BABi0 +U1Zz5x5R19r54aY93J0Fq4DOlJ1Gf80eBLJEs1b8CQKBgQD53IfUp4ScfIs1I3Vw +JOclRY3UY/uZDQ+edXLqT91UxRpmvq8pmL82Y+idWGrx1WtCk2uOFGAl37Y+8goW +bMHEZ8Wv8NtV29sqduuCE210miZN3q7EmTEj0AOOKd/Skwyte3rxvFw0didQvAhE ++uZEIZ+XaUNF102hT2BaTIm8fwKBgQDYJsgiqGRzGtP3sOeLKxYmGDE3fGB4JHYU +U9kZ0pTQWiYsl+F7lXdwqkUApgU1rFuA3oR7dV4a8zS7BbyLK2WYfWE4yAoRm52W +VnIGsdG4Z4wFGuNR7d+m+MouP4HYSJFtUoJFJxYXU4Kc3H88Ob5UmerVNFN7VSaU +W/jtek34FwKBgH+c4sL5zAEgmvjI43IjZuriW03ewuGoihGkaszBfYmOIa3YNh5I +pWBiJqw2PGjHV8DpCkXGolS1rZ74f650XYKyfYUevudbItTNZ/tHcN/c2zNqSFig +5TglRauWN3qVICR6rJBKY81niyzw3Ehe3Lxvb9MlL/a7wCpjIBL+hFqBAoGBAIax +jA+EzauosSPtWiwv+kpc0vaXi+nyFp7OLUBZKCC5vIYXUwxW9KoBgKRJ0H9E23Rv +tTDVz4GNwnM0vOwga9vdbaMbjKKyTT4sujuPvXdjFy7rNXKNf8wlxp+RNZGYjv8H +5mO/WpXIlWC4SpU2CnPfwiV/yPHW+waCVZlumH2bAoGAJscnMS4YQXiSoLQweiyQ ++mev9+i9h57/M0bk6WQHWzhAjxmZDtFUPO8P4vi3akAhCnt0yswh8e3G6h3F9TEz +PQ4AIFMajb0PWUEcGsZhBh7TCJwxPXo34SbXGkIB5RYogT/mDVMjEfduAMen/0sp +Te0BRozzkKHtIcGAqozt7ZU= +-----END PRIVATE KEY----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_certs.sh 
b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_certs.sh new file mode 100755 index 0000000000000..11875d6fe462a --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_certs.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +# Set directory for certificates and keys. +CERT_DIR="./certs" +mkdir -p $CERT_DIR + +# Dummy values for the certificates. +COUNTRY="US" +STATE="State" +LOCALITY="City" +ORGANIZATION="MyOrg" +ORG_UNIT="MyUnit" +COMMON_NAME="MyCA" +SERVER_CN="server.mydomain.com" +CLIENT_CN="client.mydomain.com" + +# Step 1: Generate CA private key and self-signed certificate. +openssl genpkey -algorithm RSA -out $CERT_DIR/ca.key +openssl req -key $CERT_DIR/ca.key -new -x509 -out $CERT_DIR/ca.crt -days 365000 \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$COMMON_NAME" + +# Step 2: Generate server private key. +openssl genpkey -algorithm RSA -out $CERT_DIR/server.key + +# Step 3: Generate server certificate signing request (CSR). +openssl req -new -key $CERT_DIR/server.key -out $CERT_DIR/server.csr \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$SERVER_CN" \ + -addext "subjectAltName=DNS:$SERVER_CN,DNS:localhost" + +# Step 4: Sign server CSR with the CA certificate to generate the server certificate. +openssl x509 -req -in $CERT_DIR/server.csr -CA $CERT_DIR/ca.crt -CAkey $CERT_DIR/ca.key \ + -CAcreateserial -out $CERT_DIR/server.crt -days 365000 \ + -extfile <(printf "subjectAltName=DNS:$SERVER_CN,DNS:localhost") + +# Step 5: Generate client private key. +openssl genpkey -algorithm RSA -out $CERT_DIR/client.key + +# Step 6: Generate client certificate signing request (CSR). +openssl req -new -key $CERT_DIR/client.key -out $CERT_DIR/client.csr \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$CLIENT_CN" + +# Step 7: Sign client CSR with the CA certificate to generate the client certificate. 
+openssl x509 -req -in $CERT_DIR/client.csr -CA $CERT_DIR/ca.crt -CAkey $CERT_DIR/ca.key \ + -CAcreateserial -out $CERT_DIR/client.crt -days 365000 + +# Step 8: Output summary. +echo "Certificate Authority (CA) certificate : $CERT_DIR/ca.crt" +echo "Server certificate : $CERT_DIR/server.crt" +echo "Server private key : $CERT_DIR/server.key" +echo "Client certificate : $CERT_DIR/client.crt" +echo "Client private key : $CERT_DIR/client.key" + +# Step 9: Remove unused files. +rm -f $CERT_DIR/*.csr $CERT_DIR/ca.srl $CERT_DIR/ca.key + diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh index 718f313c70a75..e69de29bb2d1d 100755 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh @@ -1,40 +0,0 @@ -#!/bin/bash - -# Set directory for certificates and keys. -CERT_DIR="./tls_certs" -mkdir -p $CERT_DIR - -# Dummy values for the certificates. -COUNTRY="US" -STATE="State" -LOCALITY="City" -ORGANIZATION="MyOrg" -ORG_UNIT="MyUnit" -COMMON_NAME="MyCA" -SERVER_CN="server.mydomain.com" - -# Step 1: Generate CA private key and self-signed certificate. -openssl genpkey -algorithm RSA -out $CERT_DIR/ca.key -openssl req -key $CERT_DIR/ca.key -new -x509 -out $CERT_DIR/ca.crt -days 365000 \ - -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$COMMON_NAME" - -# Step 2: Generate server private key. -openssl genpkey -algorithm RSA -out $CERT_DIR/server.key - -# Step 3: Generate server certificate signing request (CSR). 
-openssl req -new -key $CERT_DIR/server.key -out $CERT_DIR/server.csr \ - -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$SERVER_CN" \ - -addext "subjectAltName=DNS:$COMMON_NAME,DNS:localhost" \ - -# Step 4: Sign server CSR with the CA certificate to generate the server certificate. -openssl x509 -req -in $CERT_DIR/server.csr -CA $CERT_DIR/ca.crt -CAkey $CERT_DIR/ca.key \ - -CAcreateserial -out $CERT_DIR/server.crt -days 365000 \ - -extfile <(printf "subjectAltName=DNS:$COMMON_NAME,DNS:localhost") - -# Step 5: Output the generated files. -echo "Certificate Authority (CA) certificate: $CERT_DIR/ca.crt" -echo "Server certificate: $CERT_DIR/server.crt" -echo "Server private key: $CERT_DIR/server.key" - -# Step 6: Remove unused files. -rm -rf $CERT_DIR/server.csr $CERT_DIR/ca.srl $CERT_DIR/ca.key diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp index 1fecf9e31977a..2dbbaca394bcf 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightConnectorTestBase.cpp @@ -25,16 +25,9 @@ namespace facebook::presto::test { void ArrowFlightConnectorTestBase::SetUp() { OperatorTestBase::SetUp(); - - if (!velox::connector::hasConnectorFactory( - presto::ArrowFlightConnectorFactory::kArrowFlightConnectorName)) { - velox::connector::registerConnectorFactory( - std::make_shared()); - } + presto::ArrowFlightConnectorFactory factory; velox::connector::registerConnector( - velox::connector::getConnectorFactory( - ArrowFlightConnectorFactory::kArrowFlightConnectorName) - ->newConnector(kFlightConnectorId, config_)); + factory.newConnector(kFlightConnectorId, config_)); ArrowFlightConfig config(config_); if 
(config.defaultServerPort().has_value()) { @@ -85,8 +78,9 @@ ArrowFlightConnectorTestBase::makeSplits( AFC_ASSIGN_OR_RAISE( auto flightEndpointStr, flightEndpoint.SerializeToString()); auto flightEndpointBytes = folly::base64Encode(flightEndpointStr); - splits.push_back(std::make_shared( - kFlightConnectorId, flightEndpointBytes)); + splits.push_back( + std::make_shared( + kFlightConnectorId, flightEndpointBytes)); } return splits; } diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp index 5280f9c56832d..8ea2c50734165 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.cpp @@ -21,9 +21,7 @@ const std::string kFlightConnectorId = "test-flight"; velox::exec::test::PlanBuilder& ArrowFlightPlanBuilder::flightTableScan( const velox::RowTypePtr& outputType, - std::unordered_map< - std::string, - std::shared_ptr> assignments, + velox::connector::ColumnHandleMap assignments, bool createDefaultColumnHandles) { if (createDefaultColumnHandles) { for (const auto& name : outputType->names()) { diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h index 5eda2c60aac16..4615e4babafc2 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/ArrowFlightPlanBuilder.h @@ -26,9 +26,7 @@ class ArrowFlightPlanBuilder : public velox::exec::test::PlanBuilder { /// for the columns which don't have an entry in assignments velox::exec::test::PlanBuilder& 
flightTableScan( const velox::RowTypePtr& outputType, - std::unordered_map< - std::string, - std::shared_ptr> assignments = {}, + velox::connector::ColumnHandleMap assignments = {}, bool createDefaultColumnHandles = true); }; diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt index b6d2337a2d301..00881474b5604 100644 --- a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt @@ -11,9 +11,17 @@ # limitations under the License. add_library( presto_flight_connector_test_lib - TestingArrowFlightServer.cpp ArrowFlightConnectorTestBase.cpp Utils.cpp - ArrowFlightPlanBuilder.cpp) + TestingArrowFlightServer.cpp + ArrowFlightConnectorTestBase.cpp + Utils.cpp + ArrowFlightPlanBuilder.cpp +) target_link_libraries( - presto_flight_connector_test_lib arrow presto_flight_connector - velox_exception presto_flight_connector_utils velox_exec_test_lib) + presto_flight_connector_test_lib + arrow + presto_flight_connector + velox_exception + presto_flight_connector_utils + velox_exec_test_lib +) diff --git a/presto-native-execution/presto_cpp/main/connectors/hive/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/hive/CMakeLists.txt new file mode 100644 index 0000000000000..fb98894c302db --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/hive/CMakeLists.txt @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(functions) diff --git a/presto-native-execution/presto_cpp/main/connectors/hive/functions/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/hive/functions/CMakeLists.txt new file mode 100644 index 0000000000000..2435a6fd73e5f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/hive/functions/CMakeLists.txt @@ -0,0 +1,22 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +add_library(presto_hive_functions HiveFunctionRegistration.cpp) +target_link_libraries( + presto_hive_functions + presto_dynamic_function_registrar + velox_functions_string +) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.cpp b/presto-native-execution/presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.cpp new file mode 100644 index 0000000000000..ffe09361605ca --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.cpp @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.h" + +#include "presto_cpp/main/connectors/hive/functions/InitcapFunction.h" +#include "presto_cpp/main/functions/dynamic_registry/DynamicFunctionRegistrar.h" + +using namespace facebook::velox; +namespace facebook::presto::hive::functions { + +namespace { +void registerHiveFunctions() { + // Register functions under the 'hive.default' namespace. 
+ facebook::presto::registerPrestoFunction( + "initcap", "hive.default"); +} +} // namespace + +void registerHiveNativeFunctions() { + static std::once_flag once; + std::call_once(once, []() { registerHiveFunctions(); }); +} + +} // namespace facebook::presto::hive::functions diff --git a/presto-native-execution/presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.h b/presto-native-execution/presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.h new file mode 100644 index 0000000000000..338938f2bbb67 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.h @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace facebook::presto::hive::functions { + +// Registers Hive-specific native functions into the 'hive.default' namespace. +// This method is safe to call multiple times; it performs one-time registration +// guarded by an internal call_once. 
+void registerHiveNativeFunctions(); + +} // namespace facebook::presto::hive::functions diff --git a/presto-native-execution/presto_cpp/main/connectors/hive/functions/InitcapFunction.h b/presto-native-execution/presto_cpp/main/connectors/hive/functions/InitcapFunction.h new file mode 100644 index 0000000000000..c7bf162939b07 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/hive/functions/InitcapFunction.h @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/functions/Macros.h" +#include "velox/functions/lib/string/StringImpl.h" + +namespace facebook::presto::hive::functions { + +/// The InitCapFunction capitalizes the first character of each word in a +/// string, and lowercases the rest. +template +struct InitCapFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + // ASCII input always produces ASCII result. 
This is required for ASCII fast + // path + static constexpr bool is_default_ascii_behavior = true; + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& input) { + velox::functions::stringImpl::initcap< + /*strictSpace=*/false, + /*isAscii=*/false, + /*turkishCasing=*/true, + /*greekFinalSigma=*/true>(result, input); + } + + FOLLY_ALWAYS_INLINE void callAscii( + out_type& result, + const arg_type& input) { + velox::functions::stringImpl::initcap< + /*strictSpace=*/false, + /*isAscii=*/true, + /*turkishCasing=*/true, + /*greekFinalSigma=*/true>(result, input); + } +}; + +} // namespace facebook::presto::hive::functions diff --git a/presto-native-execution/presto_cpp/main/connectors/hive/functions/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/hive/functions/tests/CMakeLists.txt new file mode 100644 index 0000000000000..2089503c68181 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/hive/functions/tests/CMakeLists.txt @@ -0,0 +1,28 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +add_executable(presto_hive_functions_test InitcapTest.cpp) + +add_test( + NAME presto_hive_functions_test + COMMAND presto_hive_functions_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries( + presto_hive_functions_test + presto_hive_functions + presto_common + velox_functions_test_lib + GTest::gtest + GTest::gtest_main +) diff --git a/presto-native-execution/presto_cpp/main/connectors/hive/functions/tests/InitcapTest.cpp b/presto-native-execution/presto_cpp/main/connectors/hive/functions/tests/InitcapTest.cpp new file mode 100644 index 0000000000000..263d03aeb735b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/hive/functions/tests/InitcapTest.cpp @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "presto_cpp/main/connectors/hive/functions/HiveFunctionRegistration.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" + +namespace facebook::presto::functions::test { +class InitcapTest : public velox::functions::test::FunctionBaseTest { + protected: + static void SetUpTestCase() { + velox::functions::test::FunctionBaseTest::SetUpTestCase(); + facebook::presto::hive::functions::registerHiveNativeFunctions(); + } +}; + +TEST_F(InitcapTest, initcap) { + const auto initcap = [&](const std::optional& value) { + return evaluateOnce("\"hive.default.initcap\"(c0)", value); + }; + + // Unicode only. 
+ EXPECT_EQ( + initcap("àáâãäåæçèéêëìíîïðñòóôõöøùúûüýþ"), + "Àáâãäåæçèéêëìíîïðñòóôõöøùúûüýþ"); + EXPECT_EQ(initcap("αβγδεζηθικλμνξοπρςστυφχψ"), "Αβγδεζηθικλμνξοπρςστυφχψ"); + // Mix of ascii and unicode. + EXPECT_EQ(initcap("αβγδεζ world"), "Αβγδεζ World"); + EXPECT_EQ(initcap("αfoo wβ"), "Αfoo Wβ"); + // Ascii only. + EXPECT_EQ(initcap("hello world"), "Hello World"); + EXPECT_EQ(initcap("HELLO WORLD"), "Hello World"); + EXPECT_EQ(initcap("1234"), "1234"); + EXPECT_EQ(initcap("a b c d"), "A B C D"); + EXPECT_EQ(initcap("abcd"), "Abcd"); + // Numbers. + EXPECT_EQ(initcap("123"), "123"); + EXPECT_EQ(initcap("1abc"), "1abc"); + // Edge cases. + EXPECT_EQ(initcap(""), ""); + EXPECT_EQ(initcap(std::nullopt), std::nullopt); + + // Test with various whitespace characters + EXPECT_EQ(initcap("YQ\tY"), "Yq\tY"); + EXPECT_EQ(initcap("YQ\nY"), "Yq\nY"); + EXPECT_EQ(initcap("YQ\rY"), "Yq\rY"); + EXPECT_EQ(initcap("hello\tworld\ntest"), "Hello\tWorld\nTest"); + EXPECT_EQ(initcap("foo\r\nbar"), "Foo\r\nBar"); + + // Test with multiple consecutive whitespaces + EXPECT_EQ(initcap("hello world"), "Hello World"); + EXPECT_EQ(initcap("a b c"), "A B C"); + EXPECT_EQ(initcap("test\t\tvalue"), "Test\t\tValue"); + EXPECT_EQ(initcap("line\n\n\nbreak"), "Line\n\n\nBreak"); + + // Test with leading and trailing whitespaces + EXPECT_EQ(initcap(" hello"), " Hello"); + EXPECT_EQ(initcap("world "), "World "); + EXPECT_EQ(initcap(" spaces "), " Spaces "); + EXPECT_EQ(initcap("\thello"), "\tHello"); + EXPECT_EQ(initcap("\nworld"), "\nWorld"); + EXPECT_EQ(initcap("test\n"), "Test\n"); + + // Test with mixed whitespace types + EXPECT_EQ(initcap("hello \t\nworld"), "Hello \t\nWorld"); + EXPECT_EQ(initcap("a\tb\nc\rd"), "A\tB\nC\rD"); + EXPECT_EQ(initcap(" \t\n "), " \t\n "); +} +} // namespace facebook::presto::functions::test diff --git a/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt new file mode 100644 index 
0000000000000..20020ea182e5d --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/CMakeLists.txt @@ -0,0 +1,24 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_library(presto_function_metadata OBJECT FunctionMetadata.cpp) +target_link_libraries(presto_function_metadata presto_common velox_function_registry) + +add_subdirectory(dynamic_registry) + +if(PRESTO_ENABLE_REMOTE_FUNCTIONS) + add_subdirectory(remote) +endif() + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.cpp b/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.cpp new file mode 100644 index 0000000000000..6c7262a83b1b8 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.cpp @@ -0,0 +1,323 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/functions/FunctionMetadata.h" +#include "presto_cpp/main/common/Utils.h" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" +#include "velox/exec/Aggregate.h" +#include "velox/exec/WindowFunction.h" +#include "velox/expression/SimpleFunctionRegistry.h" +#include "velox/functions/FunctionRegistry.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; + +namespace facebook::presto { +namespace { + +// Check if the Velox type is supported in Presto. +bool isValidPrestoType(const TypeSignature& typeSignature) { + if (typeSignature.parameters().empty()) { + // Hugeint type is not supported in Presto. + auto kindName = boost::algorithm::to_upper_copy(typeSignature.baseName()); + if (auto typeKind = TypeKindName::tryToTypeKind(kindName)) { + return typeKind.value() != TypeKind::HUGEINT; + } + } else { + for (const auto& paramType : typeSignature.parameters()) { + if (!isValidPrestoType(paramType)) { + return false; + } + } + } + return true; +} + +const protocol::AggregationFunctionMetadata getAggregationFunctionMetadata( + const std::string& name, + const AggregateFunctionSignature& signature) { + protocol::AggregationFunctionMetadata metadata; + metadata.intermediateType = + boost::algorithm::to_lower_copy(signature.intermediateType().toString()); + metadata.isOrderSensitive = + getAggregateFunctionEntry(name)->metadata.orderSensitive; + return metadata; +} + +const exec::VectorFunctionMetadata getScalarMetadata(const std::string& name) { + auto simpleFunctionMetadata = + exec::simpleFunctions().getFunctionSignaturesAndMetadata(name); + if (simpleFunctionMetadata.size()) { + // Functions like abs are registered as simple functions for primitive + // types, and as a vector function for complex types like DECIMAL. So do not + // throw an error if function metadata is not found in simple function + // signature map. 
+ return simpleFunctionMetadata.back().first; + } + + auto vectorFunctionMetadata = exec::getVectorFunctionMetadata(name); + if (vectorFunctionMetadata.has_value()) { + return vectorFunctionMetadata.value(); + } + VELOX_UNREACHABLE("Metadata for function {} not found", name); +} + +const protocol::RoutineCharacteristics getRoutineCharacteristics( + const std::string& name, + const protocol::FunctionKind& kind) { + protocol::Determinism determinism; + protocol::NullCallClause nullCallClause; + if (kind == protocol::FunctionKind::SCALAR) { + auto metadata = getScalarMetadata(name); + determinism = metadata.deterministic + ? protocol::Determinism::DETERMINISTIC + : protocol::Determinism::NOT_DETERMINISTIC; + nullCallClause = metadata.defaultNullBehavior + ? protocol::NullCallClause::RETURNS_NULL_ON_NULL_INPUT + : protocol::NullCallClause::CALLED_ON_NULL_INPUT; + } else { + // Default metadata values of DETERMINISTIC and CALLED_ON_NULL_INPUT for + // non-scalar functions. + determinism = protocol::Determinism::DETERMINISTIC; + nullCallClause = protocol::NullCallClause::CALLED_ON_NULL_INPUT; + } + + protocol::RoutineCharacteristics routineCharacteristics; + routineCharacteristics.language = + std::make_shared(protocol::Language({"CPP"})); + routineCharacteristics.determinism = + std::make_shared(determinism); + routineCharacteristics.nullCallClause = + std::make_shared(nullCallClause); + return routineCharacteristics; +} + +const std::vector getTypeVariableConstraints( + const FunctionSignature& functionSignature) { + std::vector typeVariableConstraints; + const auto& functionVariables = functionSignature.variables(); + for (const auto& [name, signature] : functionVariables) { + if (signature.isTypeParameter()) { + protocol::TypeVariableConstraint typeVariableConstraint; + typeVariableConstraint.name = + boost::algorithm::to_lower_copy(signature.name()); + typeVariableConstraint.orderableRequired = signature.orderableTypesOnly(); + 
typeVariableConstraint.comparableRequired = + signature.comparableTypesOnly(); + typeVariableConstraints.emplace_back(typeVariableConstraint); + } + } + return typeVariableConstraints; +} + +const std::vector getLongVariableConstraints( + const FunctionSignature& functionSignature) { + std::vector longVariableConstraints; + const auto& functionVariables = functionSignature.variables(); + for (const auto& [name, signature] : functionVariables) { + if (signature.isIntegerParameter() && !signature.constraint().empty()) { + protocol::LongVariableConstraint longVariableConstraint; + longVariableConstraint.name = + boost::algorithm::to_lower_copy(signature.name()); + longVariableConstraint.expression = + boost::algorithm::to_lower_copy(signature.constraint()); + longVariableConstraints.emplace_back(longVariableConstraint); + } + } + return longVariableConstraints; +} + +std::optional buildFunctionMetadata( + const std::string& name, + const std::string& schema, + const protocol::FunctionKind& kind, + const FunctionSignature& signature, + const AggregateFunctionSignaturePtr& aggregateSignature = nullptr) { + protocol::JsonBasedUdfFunctionMetadata metadata; + metadata.docString = name; + metadata.functionKind = kind; + if (!isValidPrestoType(signature.returnType())) { + return std::nullopt; + } + metadata.outputType = + boost::algorithm::to_lower_copy(signature.returnType().toString()); + + const auto& argumentTypes = signature.argumentTypes(); + std::vector paramTypes(argumentTypes.size()); + for (auto i = 0; i < argumentTypes.size(); i++) { + if (!isValidPrestoType(argumentTypes.at(i))) { + return std::nullopt; + } + paramTypes[i] = + boost::algorithm::to_lower_copy(argumentTypes.at(i).toString()); + } + metadata.paramTypes = paramTypes; + metadata.schema = schema; + metadata.variableArity = signature.variableArity(); + metadata.routineCharacteristics = getRoutineCharacteristics(name, kind); + metadata.typeVariableConstraints = + std::make_shared>( + 
getTypeVariableConstraints(signature)); + metadata.longVariableConstraints = + std::make_shared>( + getLongVariableConstraints(signature)); + + if (aggregateSignature) { + metadata.aggregateMetadata = + std::make_shared( + getAggregationFunctionMetadata(name, *aggregateSignature)); + } + return metadata; +} + +json buildScalarMetadata( + const std::string& name, + const std::string& schema, + const std::vector& signatures) { + json j = json::array(); + json tj; + for (const auto& signature : signatures) { + if (auto functionMetadata = buildFunctionMetadata( + name, schema, protocol::FunctionKind::SCALAR, *signature)) { + protocol::to_json(tj, functionMetadata.value()); + j.push_back(tj); + } + } + return j; +} + +json buildAggregateMetadata( + const std::string& name, + const std::string& schema, + const std::vector& signatures) { + // All aggregate functions can be used as window functions. + VELOX_USER_CHECK( + getWindowFunctionSignatures(name).has_value(), + "Aggregate function {} not registered as a window function", + name); + + // The functions returned by this endpoint are stored as SqlInvokedFunction + // objects, with SqlFunctionId serving as the primary key. SqlFunctionId is + // derived from both the functionName and argumentTypes parameters. Returning + // the same function twice—once as an aggregate function and once as a window + // function introduces ambiguity, as functionKind is not a component of + // SqlFunctionId. For any aggregate function utilized as a window function, + // the function’s metadata can be obtained from the associated aggregate + // function implementation for further processing. 
For additional information, + // refer to the following: • + // https://github.com/prestodb/presto/blob/master/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionId.java + // • + // https://github.com/prestodb/presto/blob/master/presto-spi/src/main/java/com/facebook/presto/spi/function/SqlInvokedFunction.java + + const std::vector kinds = { + protocol::FunctionKind::AGGREGATE}; + json j = json::array(); + json tj; + for (const auto& kind : kinds) { + for (const auto& signature : signatures) { + if (auto functionMetadata = buildFunctionMetadata( + name, schema, kind, *signature, signature)) { + protocol::to_json(tj, functionMetadata.value()); + j.push_back(tj); + } + } + } + return j; +} + +json buildWindowMetadata( + const std::string& name, + const std::string& schema, + const std::vector& signatures) { + json j = json::array(); + json tj; + for (const auto& signature : signatures) { + if (auto functionMetadata = buildFunctionMetadata( + name, schema, protocol::FunctionKind::WINDOW, *signature)) { + protocol::to_json(tj, functionMetadata.value()); + j.push_back(tj); + } + } + return j; +} + +} // namespace + +json getFunctionsMetadata(const std::optional& catalog) { + json j; + + // Lambda to check if a function should be skipped based on catalog filter + auto skipCatalog = [&catalog](const std::string& functionCatalog) { + return catalog.has_value() && functionCatalog != catalog.value(); + }; + + // Get metadata for all registered scalar functions in velox. + const auto signatures = getFunctionSignatures(); + static const std::unordered_set kBlockList = { + "row_constructor", "in", "is_null"}; + // Exclude aggregate companion functions (extract aggregate companion + // functions are registered as vector functions). + const auto aggregateFunctions = exec::aggregateFunctions().copy(); + for (const auto& entry : signatures) { + const auto name = entry.first; + // Skip internal functions. They don't have any prefix. 
+ if (kBlockList.count(name) != 0 || + name.find("$internal$") != std::string::npos || + getScalarMetadata(name).companionFunction) { + continue; + } + + const auto parts = util::getFunctionNameParts(name); + if (skipCatalog(parts[0])) { + continue; + } + const auto schema = parts[1]; + const auto function = parts[2]; + j[function] = buildScalarMetadata(name, schema, entry.second); + } + + // Get metadata for all registered aggregate functions in velox. + for (const auto& entry : aggregateFunctions) { + if (!aggregateFunctions.at(entry.first).metadata.companionFunction) { + const auto name = entry.first; + const auto parts = util::getFunctionNameParts(name); + if (skipCatalog(parts[0])) { + continue; + } + const auto schema = parts[1]; + const auto function = parts[2]; + j[function] = + buildAggregateMetadata(name, schema, entry.second.signatures); + } + } + + // Get metadata for all registered window functions in velox. Skip aggregates + // as they have been processed. + const auto& functions = exec::windowFunctions(); + for (const auto& entry : functions) { + if (aggregateFunctions.count(entry.first) == 0) { + const auto name = entry.first; + const auto parts = util::getFunctionNameParts(entry.first); + if (skipCatalog(parts[0])) { + continue; + } + const auto schema = parts[1]; + const auto function = parts[2]; + j[function] = buildWindowMetadata(name, schema, entry.second.signatures); + } + } + + return j; +} + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.h b/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.h new file mode 100644 index 0000000000000..d2a2c66d7a489 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/FunctionMetadata.h @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "presto_cpp/external/json/nlohmann/json.hpp" + +namespace facebook::presto { + +// Returns metadata for all registered functions as json. +nlohmann::json getFunctionsMetadata( + const std::optional& catalog = std::nullopt); + +} // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/functions/dynamic_registry/CMakeLists.txt b/presto-native-execution/presto_cpp/main/functions/dynamic_registry/CMakeLists.txt new file mode 100644 index 0000000000000..bb13a5e2eeb17 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/dynamic_registry/CMakeLists.txt @@ -0,0 +1,17 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +if(${PRESTO_ENABLE_EXAMPLES}) + add_subdirectory(examples) +endif() + +add_library(presto_dynamic_function_registrar INTERFACE DynamicFunctionRegistrar.h) diff --git a/presto-native-execution/presto_cpp/main/functions/dynamic_registry/DynamicFunctionRegistrar.h b/presto-native-execution/presto_cpp/main/functions/dynamic_registry/DynamicFunctionRegistrar.h new file mode 100644 index 0000000000000..79b0a499a3553 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/functions/dynamic_registry/DynamicFunctionRegistrar.h @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/common/Configs.h" +#include "velox/functions/Macros.h" +#include "velox/functions/Registerer.h" + +namespace facebook::presto { +template