diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index f73267a95fa3..8671cff054bb 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -70,7 +70,7 @@ jobs: with: fetch-depth: 0 - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -81,7 +81,7 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: benchmark-coursier-${{ github.event.inputs.jdk }}-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} @@ -89,7 +89,7 @@ jobs: benchmark-coursier-${{ github.event.inputs.jdk }} - name: Cache TPC-DS generated data id: cache-tpcds-sf-1 - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ./tpcds-sf-1 key: tpcds-${{ hashFiles('.github/workflows/benchmark.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }} @@ -105,8 +105,9 @@ jobs: run: cd tpcds-kit/tools && make OS=LINUX - name: Install Java ${{ github.event.inputs.jdk }} if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true' - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ github.event.inputs.jdk }} - name: Generate TPC-DS (SF=1) table data if: steps.cache-tpcds-sf-1.outputs.cache-hit != 'true' @@ -138,7 +139,7 @@ jobs: with: fetch-depth: 0 - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -149,20 +150,21 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: benchmark-coursier-${{ github.event.inputs.jdk }}-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | benchmark-coursier-${{ github.event.inputs.jdk }} - name: Install Java ${{ github.event.inputs.jdk }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ github.event.inputs.jdk }} - name: Cache TPC-DS generated data if: contains(github.event.inputs.class, 'TPCDSQueryBenchmark') || contains(github.event.inputs.class, '*') id: cache-tpcds-sf-1 - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ./tpcds-sf-1 key: tpcds-${{ hashFiles('.github/workflows/benchmark.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }} @@ -186,7 +188,7 @@ jobs: echo "Preparing the benchmark results:" tar -cvf benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}.tar `git diff --name-only` `git ls-files --others --exclude=tpcds-sf-1 --exclude-standard` - name: Upload benchmark results - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}-${{ matrix.split }} path: benchmark-results-${{ github.event.inputs.jdk }}-${{ github.event.inputs.scala }}.tar diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 8411170e2d5c..357ea2e6126b 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -209,7 +209,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty # Cache local repositories. Note that GitHub Actions cache has a 2G limit. 
- name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -220,15 +220,16 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: ${{ matrix.java }}-${{ matrix.hadoop }}-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | ${{ matrix.java }}-${{ matrix.hadoop }}-coursier- - name: Install Java ${{ matrix.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ matrix.java }} - name: Install Python 3.8 uses: actions/setup-python@v2 @@ -254,13 +255,13 @@ jobs: ./dev/run-tests --parallelism 1 --modules "$MODULES_TO_TEST" --included-tags "$INCLUDED_TAGS" --excluded-tags "$EXCLUDED_TAGS" - name: Upload test results to report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: unit-tests-log-${{ matrix.modules }}-${{ matrix.comment }}-${{ matrix.java }}-${{ matrix.hadoop }}-${{ matrix.hive }} path: "**/target/unit-tests.log" @@ -366,7 +367,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty # Cache local repositories. Note that GitHub Actions cache has a 2G limit. - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -377,15 +378,16 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: pyspark-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | pyspark-coursier- - name: Install Java ${{ matrix.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ matrix.java }} - name: List Python packages (Python 3.9, PyPy3) run: | @@ -410,13 +412,13 @@ jobs: name: PySpark - name: Upload test results to report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: unit-tests-log-${{ matrix.modules }}--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/unit-tests.log" @@ -455,7 +457,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty # Cache local repositories. Note that GitHub Actions cache has a 2G limit. 
- name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -466,15 +468,16 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: sparkr-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | sparkr-coursier- - name: Install Java ${{ inputs.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ inputs.java }} - name: Run tests env: ${{ fromJSON(inputs.envs) }} @@ -486,7 +489,7 @@ jobs: ./dev/run-tests --parallelism 1 --modules sparkr - name: Upload test results to report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-sparkr--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" @@ -523,7 +526,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty # Cache local repositories. Note that GitHub Actions cache has a 2G limit. - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -534,14 +537,14 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: docs-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | docs-coursier- - name: Cache Maven local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.m2/repository key: docs-maven-${{ hashFiles('**/pom.xml') }} @@ -553,7 +556,7 @@ jobs: # See also https://github.com/sphinx-doc/sphinx/issues/7551. # Jinja2 3.0.0+ causes error when building with Sphinx. # See also https://issues.apache.org/jira/browse/SPARK-35375. 
- python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' + python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' python3.9 -m pip install 'pandas-stubs==1.2.0.53' - name: Install dependencies for Python code generation check run: | @@ -597,8 +600,9 @@ jobs: cd docs bundle install - name: Install Java 8 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Scala linter run: ./dev/lint-scala @@ -646,7 +650,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -657,15 +661,16 @@ jobs: restore-keys: | build- - name: Cache Maven local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.m2/repository key: java${{ matrix.java }}-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | java${{ matrix.java }}-maven- - name: Install Java ${{ matrix.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ matrix.java }} - name: Build with Maven run: | @@ -695,7 +700,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -706,15 +711,16 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: scala-213-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | scala-213-coursier- - name: Install Java 8 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Build with SBT run: | @@ -743,7 +749,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -754,19 +760,20 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: tpcds-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | tpcds-coursier- - name: Install Java 8 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Cache TPC-DS generated data id: cache-tpcds-sf-1 - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ./tpcds-sf-1 key: tpcds-${{ hashFiles('.github/workflows/build_and_test.yml', 'sql/core/src/test/scala/org/apache/spark/sql/TPCDSSchema.scala') }} @@ -808,13 +815,13 @@ jobs: spark.sql.join.forceApplyShuffledHashJoin=true - name: Upload test results to report if: 
always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-tpcds--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: unit-tests-log-tpcds--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/unit-tests.log" @@ -846,7 +853,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -857,28 +864,29 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: docker-integration-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | docker-integration-coursier- - name: Install Java 8 - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Run tests run: | ./dev/run-tests --parallelism 1 --modules docker-integration-tests --included-tags org.apache.spark.tags.DockerTest - name: Upload test results to report if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: test-results-docker-integration--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/test-reports/*.xml" - name: Upload unit tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: unit-tests-log-docker-integration--8-${{ inputs.hadoop }}-hive2.3 path: "**/target/unit-tests.log" @@ -903,7 +911,7 @@ jobs: git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' merge --no-commit --progress --squash FETCH_HEAD git -c user.name='Apache Spark Test Account' -c user.email='sparktestacc@gmail.com' commit -m "Merged commit" --allow-empty - name: Cache Scala, SBT and Maven - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | build/apache-maven-* @@ -914,15 +922,16 @@ jobs: restore-keys: | build- - name: Cache Coursier local repository - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: ~/.cache/coursier key: k8s-integration-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }} restore-keys: | k8s-integration-coursier- - name: Install Java ${{ inputs.java }} - uses: actions/setup-java@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: ${{ inputs.java }} - name: start minikube run: | @@ -948,7 +957,7 @@ jobs: build/sbt -Psparkr -Pkubernetes -Pkubernetes-integration-tests -Dspark.kubernetes.test.driverRequestCores=0.5 -Dspark.kubernetes.test.executorRequestCores=0.2 "kubernetes-integration-tests/test" - name: Upload Spark on K8S integration tests log files if: failure() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: spark-on-kubernetes-it-log path: "**/target/integration-tests.log" diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml index a8251aa5b67c..33de10b5b2e6 100644 --- a/.github/workflows/publish_snapshot.yml +++ b/.github/workflows/publish_snapshot.yml @@ -41,15 +41,16 @@ jobs: with: ref: ${{ matrix.branch }} - name: Cache Maven local repository - uses: actions/cache@c64c572235d810460d0d6876e9c705ad5002b353 # pin@v2 + 
uses: actions/cache@v3 with: path: ~/.m2/repository key: snapshot-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | snapshot-maven- - name: Install Java 8 - uses: actions/setup-java@d202f5dbf7256730fb690ec59f6381650114feb2 # pin@v1 + uses: actions/setup-java@v3 with: + distribution: temurin java-version: 8 - name: Publish snapshot env: diff --git a/bin/spark-class b/bin/spark-class index c1461a771228..fc343ca29fdd 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -77,7 +77,8 @@ set +o posix CMD=() DELIM=$'\n' CMD_START_FLAG="false" -while IFS= read -d "$DELIM" -r ARG; do +while IFS= read -d "$DELIM" -r _ARG; do + ARG=${_ARG//$'\r'} if [ "$CMD_START_FLAG" == "true" ]; then CMD+=("$ARG") else diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 68b271d1d05d..800ec0c02c22 100755 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -69,6 +69,8 @@ rem SPARK-28302: %RANDOM% would return the same number if we call it instantly a rem so we should make it sure to generate unique file to avoid process collision of writing into rem the same file concurrently. if exist %LAUNCHER_OUTPUT% goto :gen +rem unset SHELL to indicate non-bash environment to launcher/Main +set SHELL= "%RUNNER%" -Xmx128m -cp "%LAUNCH_CLASSPATH%" org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i diff --git a/build/spark-build-info b/build/spark-build-info index eb0e3d730e23..26157e8cf8cb 100755 --- a/build/spark-build-info +++ b/build/spark-build-info @@ -24,7 +24,7 @@ RESOURCE_DIR="$1" mkdir -p "$RESOURCE_DIR" -SPARK_BUILD_INFO="${RESOURCE_DIR}"/spark-version-info.properties +SPARK_BUILD_INFO="${RESOURCE_DIR%/}"/spark-version-info.properties echo_build_properties() { echo version=$1 diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index 390a8b156dc4..b376515bf1af 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -19,8 +19,10 @@ syntax = 'proto3'; package spark.connect; +import "google/protobuf/any.proto"; import "spark/connect/commands.proto"; import "spark/connect/relations.proto"; +import "spark/connect/types.proto"; option java_multiple_files = true; option java_package = "org.apache.spark.connect.proto"; @@ -51,6 +53,12 @@ message Request { message UserContext { string user_id = 1; string user_name = 2; + + // To extend the existing user context message that is used to identify incoming requests, + // Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other + // messages into this message. Extensions are stored as a `repeated` type to be able to + // handle multiple active extensions. + repeated google.protobuf.Any extensions = 999; } } @@ -109,11 +117,10 @@ message Response { // reason about the performance. message AnalyzeResponse { string client_id = 1; - repeated string column_names = 2; - repeated string column_types = 3; + DataType schema = 2; // The extended explain string as produced by Spark. - string explain_string = 4; + string explain_string = 3; } // Main interface for the SparkConnect service. 
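The `base.proto` changes above replace the repeated `column_names`/`column_types` fields of `AnalyzeResponse` with a single `DataType schema` field and add a `repeated google.protobuf.Any extensions` field to `UserContext`. Below is a minimal Scala sketch of how a caller might exercise both, assuming the generated `org.apache.spark.connect.proto` classes used elsewhere in this patch; the literal values are illustrative only.

```scala
import scala.collection.JavaConverters._

import com.google.protobuf.Any

import org.apache.spark.connect.proto

object ProtoUsageSketch {
  // Attach an arbitrary extension message (here an Expression literal) to the user context.
  def userContextWithExtension(): proto.Request.UserContext = {
    val lit = proto.Expression
      .newBuilder()
      .setLiteral(proto.Expression.Literal.newBuilder().setI32(32))
      .build()
    proto.Request.UserContext
      .newBuilder()
      .setUserId("1")
      .setUserName("Martin")
      .addExtensions(Any.pack(lit)) // extensions is a repeated google.protobuf.Any field
      .build()
  }

  // Read column names from the new schema field instead of the removed column_names field.
  def columnNames(response: proto.AnalyzeResponse): Seq[String] = {
    if (response.getSchema.hasStruct) {
      response.getSchema.getStruct.getFieldsList.asScala.map(_.getName).toSeq
    } else {
      Seq.empty
    }
  }
}
```

The same `Any.pack`/`unpack` round trip is covered by `ConnectProtoMessagesSuite` further down in this change.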
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index 7dbde775ee88..94010487ee5b 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -109,6 +109,12 @@ message Join { Relation right = 2; Expression join_condition = 3; JoinType join_type = 4; + // Optional. using_columns provides a list of columns that should present on both sides of + // the join inputs that this Join will join on. For example A JOIN B USING col_name is + // equivalent to A JOIN B on A.col_name = B.col_name. + // + // This field does not co-exist with join_condition. + repeated string using_columns = 5; enum JoinType { JOIN_TYPE_UNSPECIFIED = 0; diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 81c5328c9b29..76d159cfd159 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -26,4 +26,12 @@ private[spark] object Connect { .intConf .createWithDefault(15002) + val CONNECT_GRPC_INTERCEPTOR_CLASSES = + ConfigBuilder("spark.connect.grpc.interceptor.classes") + .doc( + "Comma separated list of class names that must " + + "implement the io.grpc.ServerInterceptor interface.") + .version("3.4.0") + .stringConf + .createOptional } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 4630c86049c5..6ae6dfa1577c 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -236,15 +236,45 @@ package object dsl { .build() def join( + otherPlan: proto.Relation, + joinType: JoinType, + condition: Option[proto.Expression]): proto.Relation = { + join(otherPlan, joinType, Seq(), condition) + } + + def join(otherPlan: proto.Relation, condition: Option[proto.Expression]): proto.Relation = { + join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), condition) + } + + def join(otherPlan: proto.Relation): proto.Relation = { + join(otherPlan, JoinType.JOIN_TYPE_INNER, Seq(), None) + } + + def join(otherPlan: proto.Relation, joinType: JoinType): proto.Relation = { + join(otherPlan, joinType, Seq(), None) + } + + def join( + otherPlan: proto.Relation, + joinType: JoinType, + usingColumns: Seq[String]): proto.Relation = { + join(otherPlan, joinType, usingColumns, None) + } + + private def join( otherPlan: proto.Relation, joinType: JoinType = JoinType.JOIN_TYPE_INNER, - condition: Option[proto.Expression] = None): proto.Relation = { + usingColumns: Seq[String], + condition: Option[proto.Expression]): proto.Relation = { val relation = proto.Relation.newBuilder() val join = proto.Join.newBuilder() join .setLeft(logicalPlan) .setRight(otherPlan) .setJoinType(joinType) + if (usingColumns.nonEmpty) { + join.addAllUsingColumns(usingColumns.asJava) + } if (condition.isDefined) { join.setJoinCondition(condition.get) } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala index da3adce43ba9..0ee90b5e8fbb 
100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala @@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._ import org.apache.spark.connect.proto import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType} /** * This object offers methods to convert to/from connect proto to catalyst types. @@ -50,11 +50,28 @@ object DataTypeProtoConverter { proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build() case StringType => proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build() + case LongType => + proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build() + case struct: StructType => + toConnectProtoStructType(struct) case _ => throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.") } } + def toConnectProtoStructType(schema: StructType): proto.DataType = { + val struct = proto.DataType.Struct.newBuilder() + for (structField <- schema.fields) { + struct.addFields( + proto.DataType.StructField + .newBuilder() + .setName(structField.name) + .setType(toConnectProtoType(structField.dataType)) + .setNullable(structField.nullable)) + } + proto.DataType.newBuilder().setStruct(struct).build() + } + def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = { mode match { case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 880618cc3338..9e3899f4a1a0 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LogicalPlan, Sample, SubqueryAlias} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.QueryExecution @@ -292,14 +292,23 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { private def transformJoin(rel: proto.Join): LogicalPlan = { assert(rel.hasLeft && rel.hasRight, "Both join sides must be present") + if (rel.hasJoinCondition && rel.getUsingColumnsCount > 0) { + throw InvalidPlanInput( + s"Using columns or join conditions cannot be set at the same time in Join") + } val joinCondition = if (rel.hasJoinCondition) Some(transformExpression(rel.getJoinCondition)) else None - + val catalystJointype = transformJoinType( + if (rel.getJoinType != null) rel.getJoinType 
else proto.Join.JoinType.JOIN_TYPE_INNER) + val joinType = if (rel.getUsingColumnsCount > 0) { + UsingJoin(catalystJointype, rel.getUsingColumnsList.asScala.toSeq) + } else { + catalystJointype + } logical.Join( left = transformRelation(rel.getLeft), right = transformRelation(rel.getRight), - joinType = transformJoinType( - if (rel.getJoinType != null) rel.getJoinType else proto.Join.JoinType.JOIN_TYPE_INNER), + joinType = joinType, condition = joinCondition, hint = logical.JoinHint.NONE) } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala new file mode 100644 index 000000000000..cddd4b976638 --- /dev/null +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterceptorRegistry.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import java.lang.reflect.InvocationTargetException + +import io.grpc.ServerInterceptor +import io.grpc.netty.NettyServerBuilder + +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.util.Utils + +/** + * This object provides a global list of configured interceptors for GRPC. The interceptors are + * added to the GRPC server in order of their position in the list. Once the statically compiled + * interceptors are added, dynamically configured interceptors are added. + */ +object SparkConnectInterceptorRegistry { + + // Contains the list of configured interceptors. + private lazy val interceptorChain: Seq[InterceptorBuilder] = Seq( + // Adding a new interceptor at compile time works like the eaxmple below with the dummy + // interceptor: + // interceptor[DummyInterceptor](classOf[DummyInterceptor]) + ) + + /** + * Given a NettyServerBuilder instance, will chain all interceptors to it in reverse order. + * @param sb + */ + def chainInterceptors(sb: NettyServerBuilder): Unit = { + interceptorChain.foreach(i => sb.intercept(i())) + createConfiguredInterceptors().foreach(sb.intercept(_)) + } + + // Type used to identify the closure responsible to instantiate a ServerInterceptor. + type InterceptorBuilder = () => ServerInterceptor + + /** + * Exposed for testing only. + */ + def createConfiguredInterceptors(): Seq[ServerInterceptor] = { + // Check all values from the Spark conf. 
+ val classes = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES) + if (classes.nonEmpty) { + classes.get + .split(",") + .map(_.trim) + .filter(_.nonEmpty) + .map(Utils.classForName[ServerInterceptor](_)) + .map(createInstance(_)) + } else { + Seq.empty + } + } + + /** + * Creates a new instance of T using the default constructor. + * @param cls + * @tparam T + * @return + */ + private def createInstance[T <: ServerInterceptor](cls: Class[T]): ServerInterceptor = { + val ctorOpt = cls.getConstructors.find(_.getParameterCount == 0) + if (ctorOpt.isEmpty) { + throw new SparkException( + errorClass = "CONNECT.INTERCEPTOR_CTOR_MISSING", + messageParameters = Map("cls" -> cls.getName), + cause = null) + } + try { + ctorOpt.get.newInstance().asInstanceOf[T] + } catch { + case e: InvocationTargetException => + throw new SparkException( + errorClass = "CONNECT.INTERCEPTOR_RUNTIME_ERROR", + messageParameters = Map("msg" -> e.getTargetException.getMessage), + cause = e) + case e: Exception => + throw new SparkException( + errorClass = "CONNECT.INTERCEPTOR_RUNTIME_ERROR", + messageParameters = Map("msg" -> e.getMessage), + cause = e) + } + } + + /** + * Creates a callable expression that instantiates the configured GPRC interceptor + * implementation. + */ + private def interceptor[T <: ServerInterceptor](cls: Class[T]): InterceptorBuilder = + () => createInstance(cls) +} diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 7c494e39a69a..5841017e5bb7 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ - import com.google.common.base.Ticker import com.google.common.cache.CacheBuilder import io.grpc.{Server, Status} @@ -35,7 +33,7 @@ import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, Spark import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT -import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner} import org.apache.spark.sql.execution.ExtendedMode /** @@ -89,29 +87,16 @@ class SparkConnectService(debug: Boolean) request: Request, responseObserver: StreamObserver[AnalyzeResponse]): Unit = { try { + if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) { + responseObserver.onError( + new UnsupportedOperationException( + s"${request.getPlan.getOpTypeCase} not supported for analysis.")) + } val session = SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session - - val logicalPlan = request.getPlan.getOpTypeCase match { - case proto.Plan.OpTypeCase.ROOT => - new SparkConnectPlanner(request.getPlan.getRoot, session).transform() - case _ => - responseObserver.onError( - new UnsupportedOperationException( - s"${request.getPlan.getOpTypeCase} not supported for analysis.")) - return - } - val ds = Dataset.ofRows(session, logicalPlan) - val explainString = ds.queryExecution.explainString(ExtendedMode) - - val resp = proto.AnalyzeResponse - .newBuilder() - 
.setExplainString(explainString) - .setClientId(request.getClientId) - - resp.addAllColumnTypes(ds.schema.fields.map(_.dataType.sql).toSeq.asJava) - resp.addAllColumnNames(ds.schema.fields.map(_.name).toSeq.asJava) - responseObserver.onNext(resp.build()) + val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session) + response.setClientId(request.getClientId) + responseObserver.onNext(response.build()) responseObserver.onCompleted() } catch { case e: Throwable => @@ -120,6 +105,20 @@ class SparkConnectService(debug: Boolean) Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException()) } } + + def handleAnalyzePlanRequest( + relation: proto.Relation, + session: SparkSession): proto.AnalyzeResponse.Builder = { + val logicalPlan = new SparkConnectPlanner(relation, session).transform() + + val ds = Dataset.ofRows(session, logicalPlan) + val explainString = ds.queryExecution.explainString(ExtendedMode) + + val response = proto.AnalyzeResponse + .newBuilder() + .setExplainString(explainString) + response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema)) + } } /** @@ -192,6 +191,9 @@ object SparkConnectService { .forPort(port) .addService(new SparkConnectService(debugMode)) + // Add all registered interceptors to the server builder. + SparkConnectInterceptorRegistry.chainInterceptors(sb) + // If debug mode is configured, load the ProtoReflection service so that tools like // grpcurl can introspect the API for debugging. if (debugMode) { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala new file mode 100644 index 000000000000..4132cca91086 --- /dev/null +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.messages + +import org.apache.spark.SparkFunSuite +import org.apache.spark.connect.proto + +class ConnectProtoMessagesSuite extends SparkFunSuite { + test("UserContext can deal with extensions") { + // Create the builder. + val builder = proto.Request.UserContext.newBuilder().setUserId("1").setUserName("Martin") + + // Create the extension value. + val lit = proto.Expression + .newBuilder() + .setLiteral(proto.Expression.Literal.newBuilder().setI32(32).build()) + // Pack the extension into Any. + val aval = com.google.protobuf.Any.pack(lit.build()) + // Add Any to the repeated field list. + builder.addExtensions(aval) + // Create serialized value. + val serialized = builder.build().toByteArray + + // Now, read the serialized value. 
+ val result = proto.Request.UserContext.parseFrom(serialized) + assert(result.getUserId.equals("1")) + assert(result.getUserName.equals("Martin")) + assert(result.getExtensionsCount == 1) + + val ext = result.getExtensions(0) + assert(ext.is(classOf[proto.Expression])) + val extLit = ext.unpack(classOf[proto.Expression]) + assert(extLit.hasLiteral) + assert(extLit.getLiteral.hasI32) + assert(extLit.getLiteral.getI32 == 32) + } +} diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 980e899c26ed..6fc47e07c598 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -220,6 +220,20 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { assert(res.nodeName == "Join") assert(res != null) + val e = intercept[InvalidPlanInput] { + val simpleJoin = proto.Relation.newBuilder + .setJoin( + proto.Join.newBuilder + .setLeft(readRel) + .setRight(readRel) + .addUsingColumns("test_col") + .setJoinCondition(joinCondition)) + .build() + transform(simpleJoin) + } + assert( + e.getMessage.contains( + "Using columns or join conditions cannot be set at the same time in Join")) } test("Simple Projection") { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index d8bb1684cb84..0325b6573bd3 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -20,7 +20,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Join.JoinType import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation /** @@ -32,11 +32,13 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { lazy val connectTestRelation = createLocalRelationProto(Seq($"id".int, $"name".string)) - lazy val connectTestRelation2 = createLocalRelationProto(Seq($"key".int, $"value".int)) + lazy val connectTestRelation2 = createLocalRelationProto( + Seq($"key".int, $"value".int, $"name".string)) lazy val sparkTestRelation: LocalRelation = LocalRelation($"id".int, $"name".string) - lazy val sparkTestRelation2: LocalRelation = LocalRelation($"key".int, $"value".int) + lazy val sparkTestRelation2: LocalRelation = + LocalRelation($"key".int, $"value".int, $"name".string) test("Basic select") { val connectPlan = { @@ -117,6 +119,14 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { val sparkPlan3 = sparkTestRelation.join(sparkTestRelation2, y) comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false) } + + val connectPlan4 = { + import org.apache.spark.sql.connect.dsl.plans._ + transform( + connectTestRelation.join(connectTestRelation2, JoinType.JOIN_TYPE_INNER, 
Seq("name"))) + } + val sparkPlan4 = sparkTestRelation.join(sparkTestRelation2, UsingJoin(Inner, Seq("name"))) + comparePlans(connectPlan4.analyze, sparkPlan4.analyze, false) } test("Test sample") { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala new file mode 100644 index 000000000000..4be8d1705b9e --- /dev/null +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.planner + +import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Testing Connect Service implementation. + */ +class SparkConnectServiceSuite extends SharedSparkSession { + + test("Test schema in analyze response") { + withTable("test") { + spark.sql(""" + | CREATE TABLE test (col1 INT, col2 STRING) + | USING parquet + |""".stripMargin) + + val instance = new SparkConnectService(false) + val relation = proto.Relation + .newBuilder() + .setRead( + proto.Read + .newBuilder() + .setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("test").build()) + .build()) + .build() + + val response = instance.handleAnalyzePlanRequest(relation, spark) + + assert(response.getSchema.hasStruct) + val schema = response.getSchema.getStruct + assert(schema.getFieldsCount == 2) + assert( + schema.getFields(0).getName == "col1" + && schema.getFields(0).getType.getKindCase == proto.DataType.KindCase.I32) + assert( + schema.getFields(1).getName == "col2" + && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING) + } + } +} diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala new file mode 100644 index 000000000000..bac02ec7af69 --- /dev/null +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/service/InterceptorRegistrySuite.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import io.grpc.{Metadata, ServerCall, ServerCallHandler, ServerInterceptor} +import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener +import io.grpc.netty.NettyServerBuilder + +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Used for testing only, does not do anything. + */ +class DummyInterceptor extends ServerInterceptor { + override def interceptCall[ReqT, RespT]( + call: ServerCall[ReqT, RespT], + headers: Metadata, + next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = { + val listener = next.startCall(call, headers) + new SimpleForwardingServerCallListener[ReqT](listener) { + override def onMessage(message: ReqT): Unit = { + delegate().onMessage(message) + } + } + } +} + +/** + * Used for testing only. + */ +class TestingInterceptorNoTrivialCtor(id: Int) extends ServerInterceptor { + override def interceptCall[ReqT, RespT]( + call: ServerCall[ReqT, RespT], + headers: Metadata, + next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = { + val listener = next.startCall(call, headers) + new SimpleForwardingServerCallListener[ReqT](listener) { + override def onMessage(message: ReqT): Unit = { + delegate().onMessage(message) + } + } + } +} + +/** + * Used for testing only. 
+ */ +class TestingInterceptorInstantiationError extends ServerInterceptor { + throw new ArrayIndexOutOfBoundsException("Bad Error") + + override def interceptCall[ReqT, RespT]( + call: ServerCall[ReqT, RespT], + headers: Metadata, + next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] = { + val listener = next.startCall(call, headers) + new SimpleForwardingServerCallListener[ReqT](listener) { + override def onMessage(message: ReqT): Unit = { + delegate().onMessage(message) + } + } + } +} + +class InterceptorRegistrySuite extends SharedSparkSession { + + override def beforeEach(): Unit = { + if (SparkEnv.get.conf.contains(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES)) { + SparkEnv.get.conf.remove(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES) + } + } + + def withSparkConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SparkEnv.get.conf + pairs.foreach { kv => conf.set(kv._1, kv._2) } + try f + finally { + pairs.foreach { kv => conf.remove(kv._1) } + } + } + + test("Check that the empty registry works") { + val sb = NettyServerBuilder.forPort(9999) + SparkConnectInterceptorRegistry.chainInterceptors(sb) + } + + test("Test server builder and configured interceptor") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.DummyInterceptor") { + val sb = NettyServerBuilder.forPort(9999) + SparkConnectInterceptorRegistry.chainInterceptors(sb) + } + } + + test("Test server build throws when using bad configured interceptor") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.TestingInterceptorNoTrivialCtor") { + val sb = NettyServerBuilder.forPort(9999) + assertThrows[SparkException] { + SparkConnectInterceptorRegistry.chainInterceptors(sb) + } + } + } + + test("Exception handling for interceptor classes") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.TestingInterceptorNoTrivialCtor") { + assertThrows[SparkException] { + SparkConnectInterceptorRegistry.createConfiguredInterceptors + } + } + + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + "org.apache.spark.sql.connect.service.TestingInterceptorInstantiationError") { + assertThrows[SparkException] { + SparkConnectInterceptorRegistry.createConfiguredInterceptors + } + } + } + + test("No configured interceptors returns empty list") { + // Not set. 
+ assert(SparkConnectInterceptorRegistry.createConfiguredInterceptors.isEmpty) + // Set to empty string + withSparkConf(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> "") { + assert(SparkConnectInterceptorRegistry.createConfiguredInterceptors.isEmpty) + } + } + + test("Configured classes can have multiple entries") { + withSparkConf( + Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> + (" org.apache.spark.sql.connect.service.DummyInterceptor," + + " org.apache.spark.sql.connect.service.DummyInterceptor ")) { + assert(SparkConnectInterceptorRegistry.createConfiguredInterceptors.size == 2) + } + } + + test("Configured class not found is properly thrown") { + withSparkConf(Connect.CONNECT_GRPC_INTERCEPTOR_CLASSES.key -> "this.class.does.not.exist") { + assertThrows[ClassNotFoundException] { + SparkConnectInterceptorRegistry.createConfiguredInterceptors + } + } + } + +} diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 0f9b665718ca..6f5b3b5a1347 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -76,6 +76,23 @@ "Another instance of this query was just started by a concurrent session." ] }, + "CONNECT" : { + "message" : [ + "Generic Spark Connect error." + ], + "subClass" : { + "INTERCEPTOR_CTOR_MISSING" : { + "message" : [ + "Cannot instantiate GRPC interceptor because is missing a default constructor without arguments." + ] + }, + "INTERCEPTOR_RUNTIME_ERROR" : { + "message" : [ + "Error instantiating GRPC interceptor: " + ] + } + } + }, "CONVERSION_INVALID_INPUT" : { "message" : [ "The value () cannot be converted to because it is malformed. Correct the value as per the syntax, or change its format. Use to tolerate malformed input and return NULL instead." @@ -143,6 +160,11 @@ "Offset expression must be a literal." ] }, + "HASH_MAP_TYPE" : { + "message" : [ + "Input to the function cannot contain elements of the \"MAP\" type. In Spark, same maps may have different hashcode, thus hash expressions are prohibited on \"MAP\" elements. To restore previous behavior set \"spark.sql.legacy.allowHashOnMapType\" to \"true\"." + ] + }, "INVALID_JSON_MAP_KEY_TYPE" : { "message" : [ "Input schema can only contain STRING as a key type for a MAP." @@ -4291,4 +4313,4 @@ "Not enough memory to build and broadcast the table to all worker nodes. As a workaround, you can either disable broadcast by setting to -1 or increase the spark driver memory by setting to a higher value" ] } -} +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index d45dc937910d..99b4e894bf0a 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -162,9 +162,9 @@ private[spark] object ThreadUtils { /** * Wrapper over newSingleThreadExecutor. 
*/ - def newDaemonSingleThreadExecutor(threadName: String): ExecutorService = { + def newDaemonSingleThreadExecutor(threadName: String): ThreadPoolExecutor = { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() - Executors.newSingleThreadExecutor(threadFactory) + Executors.newFixedThreadPool(1, threadFactory).asInstanceOf[ThreadPoolExecutor] } /** diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3 index 1d1061aaadbd..2422b003d6d1 100644 --- a/dev/deps/spark-deps-hadoop-2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2-hive-2.3 @@ -113,7 +113,7 @@ ivy/2.5.0//ivy-2.5.0.jar jackson-annotations/2.13.4//jackson-annotations-2.13.4.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar jackson-core/2.13.4//jackson-core-2.13.4.jar -jackson-databind/2.13.4.1//jackson-databind-2.13.4.1.jar +jackson-databind/2.13.4.2//jackson-databind-2.13.4.2.jar jackson-dataformat-cbor/2.13.4//jackson-dataformat-cbor-2.13.4.jar jackson-dataformat-yaml/2.13.4//jackson-dataformat-yaml-2.13.4.jar jackson-datatype-jsr310/2.13.4//jackson-datatype-jsr310-2.13.4.jar @@ -160,30 +160,30 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/2.0.3//jul-to-slf4j-2.0.3.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client-api/6.1.1//kubernetes-client-api-6.1.1.jar -kubernetes-client/6.1.1//kubernetes-client-6.1.1.jar -kubernetes-httpclient-okhttp/6.1.1//kubernetes-httpclient-okhttp-6.1.1.jar -kubernetes-model-admissionregistration/6.1.1//kubernetes-model-admissionregistration-6.1.1.jar -kubernetes-model-apiextensions/6.1.1//kubernetes-model-apiextensions-6.1.1.jar -kubernetes-model-apps/6.1.1//kubernetes-model-apps-6.1.1.jar -kubernetes-model-autoscaling/6.1.1//kubernetes-model-autoscaling-6.1.1.jar -kubernetes-model-batch/6.1.1//kubernetes-model-batch-6.1.1.jar -kubernetes-model-certificates/6.1.1//kubernetes-model-certificates-6.1.1.jar -kubernetes-model-common/6.1.1//kubernetes-model-common-6.1.1.jar -kubernetes-model-coordination/6.1.1//kubernetes-model-coordination-6.1.1.jar -kubernetes-model-core/6.1.1//kubernetes-model-core-6.1.1.jar -kubernetes-model-discovery/6.1.1//kubernetes-model-discovery-6.1.1.jar -kubernetes-model-events/6.1.1//kubernetes-model-events-6.1.1.jar -kubernetes-model-extensions/6.1.1//kubernetes-model-extensions-6.1.1.jar -kubernetes-model-flowcontrol/6.1.1//kubernetes-model-flowcontrol-6.1.1.jar -kubernetes-model-gatewayapi/6.1.1//kubernetes-model-gatewayapi-6.1.1.jar -kubernetes-model-metrics/6.1.1//kubernetes-model-metrics-6.1.1.jar -kubernetes-model-networking/6.1.1//kubernetes-model-networking-6.1.1.jar -kubernetes-model-node/6.1.1//kubernetes-model-node-6.1.1.jar -kubernetes-model-policy/6.1.1//kubernetes-model-policy-6.1.1.jar -kubernetes-model-rbac/6.1.1//kubernetes-model-rbac-6.1.1.jar -kubernetes-model-scheduling/6.1.1//kubernetes-model-scheduling-6.1.1.jar -kubernetes-model-storageclass/6.1.1//kubernetes-model-storageclass-6.1.1.jar +kubernetes-client-api/6.2.0//kubernetes-client-api-6.2.0.jar +kubernetes-client/6.2.0//kubernetes-client-6.2.0.jar +kubernetes-httpclient-okhttp/6.2.0//kubernetes-httpclient-okhttp-6.2.0.jar +kubernetes-model-admissionregistration/6.2.0//kubernetes-model-admissionregistration-6.2.0.jar +kubernetes-model-apiextensions/6.2.0//kubernetes-model-apiextensions-6.2.0.jar +kubernetes-model-apps/6.2.0//kubernetes-model-apps-6.2.0.jar +kubernetes-model-autoscaling/6.2.0//kubernetes-model-autoscaling-6.2.0.jar 
+kubernetes-model-batch/6.2.0//kubernetes-model-batch-6.2.0.jar +kubernetes-model-certificates/6.2.0//kubernetes-model-certificates-6.2.0.jar +kubernetes-model-common/6.2.0//kubernetes-model-common-6.2.0.jar +kubernetes-model-coordination/6.2.0//kubernetes-model-coordination-6.2.0.jar +kubernetes-model-core/6.2.0//kubernetes-model-core-6.2.0.jar +kubernetes-model-discovery/6.2.0//kubernetes-model-discovery-6.2.0.jar +kubernetes-model-events/6.2.0//kubernetes-model-events-6.2.0.jar +kubernetes-model-extensions/6.2.0//kubernetes-model-extensions-6.2.0.jar +kubernetes-model-flowcontrol/6.2.0//kubernetes-model-flowcontrol-6.2.0.jar +kubernetes-model-gatewayapi/6.2.0//kubernetes-model-gatewayapi-6.2.0.jar +kubernetes-model-metrics/6.2.0//kubernetes-model-metrics-6.2.0.jar +kubernetes-model-networking/6.2.0//kubernetes-model-networking-6.2.0.jar +kubernetes-model-node/6.2.0//kubernetes-model-node-6.2.0.jar +kubernetes-model-policy/6.2.0//kubernetes-model-policy-6.2.0.jar +kubernetes-model-rbac/6.2.0//kubernetes-model-rbac-6.2.0.jar +kubernetes-model-scheduling/6.2.0//kubernetes-model-scheduling-6.2.0.jar +kubernetes-model-storageclass/6.2.0//kubernetes-model-storageclass-6.2.0.jar lapack/3.0.2//lapack-3.0.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar @@ -240,7 +240,7 @@ pickle/1.2//pickle-1.2.jar protobuf-java/2.5.0//protobuf-java-2.5.0.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar -rocksdbjni/7.6.0//rocksdbjni-7.6.0.jar +rocksdbjni/7.7.3//rocksdbjni-7.7.3.jar scala-collection-compat_2.12/2.7.0//scala-collection-compat_2.12-2.7.0.jar scala-compiler/2.12.17//scala-compiler-2.12.17.jar scala-library/2.12.17//scala-library-2.12.17.jar diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 39a0e6170586..ecaf4293f247 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -101,7 +101,7 @@ ivy/2.5.0//ivy-2.5.0.jar jackson-annotations/2.13.4//jackson-annotations-2.13.4.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar jackson-core/2.13.4//jackson-core-2.13.4.jar -jackson-databind/2.13.4.1//jackson-databind-2.13.4.1.jar +jackson-databind/2.13.4.2//jackson-databind-2.13.4.2.jar jackson-dataformat-cbor/2.13.4//jackson-dataformat-cbor-2.13.4.jar jackson-dataformat-yaml/2.13.4//jackson-dataformat-yaml-2.13.4.jar jackson-datatype-jsr310/2.13.4//jackson-datatype-jsr310-2.13.4.jar @@ -144,30 +144,30 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/2.0.3//jul-to-slf4j-2.0.3.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client-api/6.1.1//kubernetes-client-api-6.1.1.jar -kubernetes-client/6.1.1//kubernetes-client-6.1.1.jar -kubernetes-httpclient-okhttp/6.1.1//kubernetes-httpclient-okhttp-6.1.1.jar -kubernetes-model-admissionregistration/6.1.1//kubernetes-model-admissionregistration-6.1.1.jar -kubernetes-model-apiextensions/6.1.1//kubernetes-model-apiextensions-6.1.1.jar -kubernetes-model-apps/6.1.1//kubernetes-model-apps-6.1.1.jar -kubernetes-model-autoscaling/6.1.1//kubernetes-model-autoscaling-6.1.1.jar -kubernetes-model-batch/6.1.1//kubernetes-model-batch-6.1.1.jar -kubernetes-model-certificates/6.1.1//kubernetes-model-certificates-6.1.1.jar -kubernetes-model-common/6.1.1//kubernetes-model-common-6.1.1.jar -kubernetes-model-coordination/6.1.1//kubernetes-model-coordination-6.1.1.jar -kubernetes-model-core/6.1.1//kubernetes-model-core-6.1.1.jar 
-kubernetes-model-discovery/6.1.1//kubernetes-model-discovery-6.1.1.jar -kubernetes-model-events/6.1.1//kubernetes-model-events-6.1.1.jar -kubernetes-model-extensions/6.1.1//kubernetes-model-extensions-6.1.1.jar -kubernetes-model-flowcontrol/6.1.1//kubernetes-model-flowcontrol-6.1.1.jar -kubernetes-model-gatewayapi/6.1.1//kubernetes-model-gatewayapi-6.1.1.jar -kubernetes-model-metrics/6.1.1//kubernetes-model-metrics-6.1.1.jar -kubernetes-model-networking/6.1.1//kubernetes-model-networking-6.1.1.jar -kubernetes-model-node/6.1.1//kubernetes-model-node-6.1.1.jar -kubernetes-model-policy/6.1.1//kubernetes-model-policy-6.1.1.jar -kubernetes-model-rbac/6.1.1//kubernetes-model-rbac-6.1.1.jar -kubernetes-model-scheduling/6.1.1//kubernetes-model-scheduling-6.1.1.jar -kubernetes-model-storageclass/6.1.1//kubernetes-model-storageclass-6.1.1.jar +kubernetes-client-api/6.2.0//kubernetes-client-api-6.2.0.jar +kubernetes-client/6.2.0//kubernetes-client-6.2.0.jar +kubernetes-httpclient-okhttp/6.2.0//kubernetes-httpclient-okhttp-6.2.0.jar +kubernetes-model-admissionregistration/6.2.0//kubernetes-model-admissionregistration-6.2.0.jar +kubernetes-model-apiextensions/6.2.0//kubernetes-model-apiextensions-6.2.0.jar +kubernetes-model-apps/6.2.0//kubernetes-model-apps-6.2.0.jar +kubernetes-model-autoscaling/6.2.0//kubernetes-model-autoscaling-6.2.0.jar +kubernetes-model-batch/6.2.0//kubernetes-model-batch-6.2.0.jar +kubernetes-model-certificates/6.2.0//kubernetes-model-certificates-6.2.0.jar +kubernetes-model-common/6.2.0//kubernetes-model-common-6.2.0.jar +kubernetes-model-coordination/6.2.0//kubernetes-model-coordination-6.2.0.jar +kubernetes-model-core/6.2.0//kubernetes-model-core-6.2.0.jar +kubernetes-model-discovery/6.2.0//kubernetes-model-discovery-6.2.0.jar +kubernetes-model-events/6.2.0//kubernetes-model-events-6.2.0.jar +kubernetes-model-extensions/6.2.0//kubernetes-model-extensions-6.2.0.jar +kubernetes-model-flowcontrol/6.2.0//kubernetes-model-flowcontrol-6.2.0.jar +kubernetes-model-gatewayapi/6.2.0//kubernetes-model-gatewayapi-6.2.0.jar +kubernetes-model-metrics/6.2.0//kubernetes-model-metrics-6.2.0.jar +kubernetes-model-networking/6.2.0//kubernetes-model-networking-6.2.0.jar +kubernetes-model-node/6.2.0//kubernetes-model-node-6.2.0.jar +kubernetes-model-policy/6.2.0//kubernetes-model-policy-6.2.0.jar +kubernetes-model-rbac/6.2.0//kubernetes-model-rbac-6.2.0.jar +kubernetes-model-scheduling/6.2.0//kubernetes-model-scheduling-6.2.0.jar +kubernetes-model-storageclass/6.2.0//kubernetes-model-storageclass-6.2.0.jar lapack/3.0.2//lapack-3.0.2.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar @@ -227,7 +227,7 @@ pickle/1.2//pickle-1.2.jar protobuf-java/2.5.0//protobuf-java-2.5.0.jar py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar -rocksdbjni/7.6.0//rocksdbjni-7.6.0.jar +rocksdbjni/7.7.3//rocksdbjni-7.7.3.jar scala-collection-compat_2.12/2.7.0//scala-collection-compat_2.12-2.7.0.jar scala-compiler/2.12.17//scala-compiler-2.12.17.jar scala-library/2.12.17//scala-library-2.12.17.jar diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 2a427139148a..a439b4cbbed0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -484,6 +484,7 @@ def __hash__(self): "pyspark.sql.tests.pandas.test_pandas_udf_typehints", "pyspark.sql.tests.pandas.test_pandas_udf_typehints_with_future_annotations", "pyspark.sql.tests.pandas.test_pandas_udf_window", + "pyspark.sql.tests.test_pandas_sqlmetrics", 
"pyspark.sql.tests.test_readwriter", "pyspark.sql.tests.test_serde", "pyspark.sql.tests.test_session", diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 18cc579e4f9e..aaad2a328091 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -34,6 +34,7 @@ license: | - Valid hexadecimal strings should include only allowed symbols (0-9A-Fa-f). - Valid values for `fmt` are case-insensitive `hex`, `base64`, `utf-8`, `utf8`. - Since Spark 3.4, Spark throws only `PartitionsAlreadyExistException` when it creates partitions but some of them exist already. In Spark 3.3 or earlier, Spark can throw either `PartitionsAlreadyExistException` or `PartitionAlreadyExistsException`. + - Since Spark 3.4, Spark will do validation for partition spec in ALTER PARTITION to follow the behavior of `spark.sql.storeAssignmentPolicy` which may cause an exception if type conversion fails, e.g. `ALTER TABLE .. ADD PARTITION(p='a')` if column `p` is int type. To restore the legacy behavior, set `spark.sql.legacy.skipPartitionSpecTypeValidation` to `true`. ## Upgrading from Spark SQL 3.2 to 3.3 diff --git a/docs/web-ui.md b/docs/web-ui.md index d3356ec5a43f..e228d7fe2a98 100644 --- a/docs/web-ui.md +++ b/docs/web-ui.md @@ -406,6 +406,8 @@ Here is the list of SQL metrics: time to build hash map the time spent on building hash map ShuffledHashJoin task commit time the time spent on committing the output of a task after the writes succeed any write operation on a file-based table job commit time the time spent on committing the output of a job after the writes succeed any write operation on a file-based table + data sent to Python workers the number of bytes of serialized data sent to the Python workers ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapsCoGroupsInPandas, FlatMapsCoGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowsInPandas + data returned from Python workers the number of bytes of serialized data received back from the Python workers ArrowEvalPython, AggregateInPandas, BatchEvalPython, FlatMapGroupsInPandas, FlatMapsCoGroupsInPandas, FlatMapsCoGroupsInPandasWithState, MapInPandas, PythonMapInArrow, WindowsInPandas ## Structured Streaming Tab diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index e1054c7060f1..6501fc1764c2 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -87,7 +87,9 @@ public static void main(String[] argsArray) throws Exception { cmd = buildCommand(builder, env, printLaunchCommand); } - if (isWindows()) { + // test for shell environments, to enable non-Windows treatment of command line prep + boolean shellflag = !isEmpty(System.getenv("SHELL")); + if (isWindows() && !shellflag) { System.out.println(prepareWindowsCommand(cmd, env)); } else { // A sequence of NULL character and newline separates command-strings and others. @@ -96,7 +98,7 @@ public static void main(String[] argsArray) throws Exception { // In bash, use NULL as the arg separator since it cannot be used in an argument. 
List bashCmd = prepareBashCommand(cmd, env); for (String c : bashCmd) { - System.out.print(c); + System.out.print(c.replaceFirst("\r$","")); System.out.print('\0'); } } diff --git a/pom.xml b/pom.xml index d933c1c6f6dd..707aed043796 100644 --- a/pom.xml +++ b/pom.xml @@ -176,7 +176,7 @@ true 1.9.13 2.13.4 - 2.13.4.1 + 2.13.4.2 1.1.8.4 3.0.2 1.15 @@ -219,7 +219,7 @@ 9.0.0 org.fusesource.leveldbjni - 6.1.1 + 6.2.0 ${java.home} @@ -682,7 +682,7 @@ org.rocksdb rocksdbjni - 7.6.0 + 7.7.3 ${leveldbjni.group} @@ -990,18 +990,10 @@ jackson-datatype-jsr310 ${fasterxml.jackson.version} - com.fasterxml.jackson.module jackson-module-scala_${scala.binary.version} ${fasterxml.jackson.version} - - - com.google.guava - guava - - com.fasterxml.jackson.module diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index cc103e4ab00a..33883a2efaa5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -599,11 +599,18 @@ object SparkParallelTestGrouping { object Core { import scala.sys.process.Process + def buildenv = Process(Seq("uname")).!!.trim.replaceFirst("[^A-Za-z0-9].*", "").toLowerCase + def bashpath = Process(Seq("where", "bash")).!!.split("[\r\n]+").head.replace('\\', '/') lazy val settings = Seq( (Compile / resourceGenerators) += Def.task { val buildScript = baseDirectory.value + "/../build/spark-build-info" val targetDir = baseDirectory.value + "/target/extra-resources/" - val command = Seq("bash", buildScript, targetDir, version.value) + // support Windows build under cygwin/mingw64, etc + val bash = buildenv match { + case "cygwin" | "msys2" | "mingw64" | "clang64" => bashpath + case _ => "bash" + } + val command = Seq(bash, buildScript, targetDir, version.value) Process(command).!! val propsFile = baseDirectory.value / "target" / "extra-resources" / "spark-version-info.properties" Seq(propsFile) diff --git a/python/docs/source/reference/pyspark.rst b/python/docs/source/reference/pyspark.rst index c3afae10ddb6..ec3df0716392 100644 --- a/python/docs/source/reference/pyspark.rst +++ b/python/docs/source/reference/pyspark.rst @@ -262,10 +262,12 @@ Management StorageLevel.DISK_ONLY_3 StorageLevel.MEMORY_AND_DISK StorageLevel.MEMORY_AND_DISK_2 + StorageLevel.MEMORY_AND_DISK_DESER StorageLevel.MEMORY_ONLY StorageLevel.MEMORY_ONLY_2 StorageLevel.OFF_HEAP TaskContext.attemptNumber + TaskContext.cpus TaskContext.get TaskContext.getLocalProperty TaskContext.partitionId @@ -277,6 +279,7 @@ Management BarrierTaskContext.allGather BarrierTaskContext.attemptNumber BarrierTaskContext.barrier + BarrierTaskContext.cpus BarrierTaskContext.get BarrierTaskContext.getLocalProperty BarrierTaskContext.getTaskInfos diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 5a64845598ea..37ddbaf1673d 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -142,6 +142,7 @@ Datetime Functions window session_window timestamp_seconds + window_time Collection Functions diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 166c18ba4e9d..b5466b467d84 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -6044,6 +6044,19 @@ def test_mode(self): with self.assertRaises(ValueError): psdf.mode(axis=2) + def f(index, iterator): + return ["3", "3", "3", "3", "4"] if index == 3 else ["0", "1", "2", "3", "4"] + + rdd = 
self.spark.sparkContext.parallelize( + [ + 1, + ], + 4, + ).mapPartitionsWithIndex(f) + df = self.spark.createDataFrame(rdd, schema="string") + psdf = df.pandas_api() + self.assert_eq(psdf.mode(), psdf._to_pandas().mode()) + def test_abs(self): pdf = pd.DataFrame({"a": [-2, -1, 0, 1]}) psdf = ps.from_pandas(pdf) diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 0ae075521c63..f4b6d2ec302d 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -33,6 +33,7 @@ from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.plan import SQL +from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType from typing import Optional, Any, Union @@ -91,14 +92,13 @@ def metrics(self) -> typing.List[MetricValue]: class AnalyzeResult: - def __init__(self, cols: typing.List[str], types: typing.List[str], explain: str): - self.cols = cols - self.types = types + def __init__(self, schema: pb2.DataType, explain: str): + self.schema = schema self.explain_string = explain @classmethod def fromProto(cls, pb: typing.Any) -> "AnalyzeResult": - return AnalyzeResult(pb.column_names, pb.column_types, pb.explain_string) + return AnalyzeResult(pb.schema, pb.explain_string) class RemoteSparkSession(object): @@ -151,7 +151,44 @@ def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]: req.plan.CopyFrom(plan) return self._execute_and_fetch(req) - def analyze(self, plan: pb2.Plan) -> AnalyzeResult: + def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType: + if schema.HasField("struct"): + structFields = [] + for proto_field in schema.struct.fields: + structFields.append( + StructField( + proto_field.name, + self._proto_schema_to_pyspark_schema(proto_field.type), + proto_field.nullable, + ) + ) + return StructType(structFields) + elif schema.HasField("i64"): + return LongType() + elif schema.HasField("string"): + return StringType() + else: + raise Exception("Only support long, string, struct conversion") + + def schema(self, plan: pb2.Plan) -> StructType: + proto_schema = self._analyze(plan).schema + # Server side should populate the struct field which is the schema. + assert proto_schema.HasField("struct") + structFields = [] + for proto_field in proto_schema.struct.fields: + structFields.append( + StructField( + proto_field.name, + self._proto_schema_to_pyspark_schema(proto_field.type), + proto_field.nullable, + ) + ) + return StructType(structFields) + + def explain_string(self, plan: pb2.Plan) -> str: + return self._analyze(plan).explain_string + + def _analyze(self, plan: pb2.Plan) -> AnalyzeResult: req = pb2.Request() req.user_context.user_id = self._user_id req.plan.CopyFrom(plan) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index eabcf433ae9b..bf9ed83615b6 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -34,6 +34,7 @@ Expression, LiteralExpression, ) +from pyspark.sql.types import StructType if TYPE_CHECKING: from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString @@ -96,7 +97,7 @@ class DataFrame(object): of the DataFrame with the changes applied. 
""" - def __init__(self, data: Optional[List[Any]] = None, schema: Optional[List[str]] = None): + def __init__(self, data: Optional[List[Any]] = None, schema: Optional[StructType] = None): """Creates a new data frame""" self._schema = schema self._plan: Optional[plan.LogicalPlan] = None @@ -157,11 +158,44 @@ def coalesce(self, num_partitions: int) -> "DataFrame": def describe(self, cols: List[ColumnRef]) -> Any: ... + def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": + """Return a new :class:`DataFrame` with duplicate rows removed, + optionally only deduplicating based on certain columns. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + subset : List of column names, optional + List of columns to use for duplicate comparison (default All columns). + + Returns + ------- + :class:`DataFrame` + DataFrame without duplicated rows. + """ + if subset is None: + return DataFrame.withPlan( + plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session + ) + else: + return DataFrame.withPlan( + plan.Deduplicate(child=self._plan, column_names=subset), session=self._session + ) + def distinct(self) -> "DataFrame": - """Returns all distinct rows.""" - all_cols = self.columns - gf = self.groupBy(*all_cols) - return gf.agg() + """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. + + .. versionadded:: 3.4.0 + + Returns + ------- + :class:`DataFrame` + DataFrame with distinct rows. + """ + return DataFrame.withPlan( + plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session + ) def drop(self, *cols: "ColumnOrString") -> "DataFrame": all_cols = self.columns @@ -282,11 +316,32 @@ def toPandas(self) -> Optional["pandas.DataFrame"]: query = self._plan.to_proto(self._session) return self._session._to_pandas(query) + def schema(self) -> StructType: + """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`. + + .. 
versionadded:: 3.4.0 + + Returns + ------- + :class:`StructType` + """ + if self._schema is None: + if self._plan is not None: + query = self._plan.to_proto(self._session) + if self._session is None: + raise Exception("Cannot analyze without RemoteSparkSession.") + self._schema = self._session.schema(query) + return self._schema + else: + raise Exception("Empty plan.") + else: + return self._schema + def explain(self) -> str: if self._plan is not None: query = self._plan.to_proto(self._session) if self._session is None: raise Exception("Cannot analyze without RemoteSparkSession.") - return self._session.analyze(query).explain_string + return self._session.explain_string(query) else: return "" diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 297b15994d3b..d6b6f9e3b67d 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -327,6 +327,45 @@ def _repr_html_(self) -> str: """ +class Deduplicate(LogicalPlan): + def __init__( + self, + child: Optional["LogicalPlan"], + all_columns_as_keys: bool = False, + column_names: Optional[List[str]] = None, + ) -> None: + super().__init__(child) + self.all_columns_as_keys = all_columns_as_keys + self.column_names = column_names + + def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: + assert self._child is not None + plan = proto.Relation() + plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys + if self.column_names is not None: + plan.deduplicate.column_names.extend(self.column_names) + return plan + + def print(self, indent: int = 0) -> str: + c_buf = self._child.print(indent + LogicalPlan.INDENT) if self._child else "" + return ( + f"{' ' * indent}\n{c_buf}" + ) + + def _repr_html_(self) -> str: + return f""" +
+        <ul>
+           <li>
+              <b>Deduplicate</b><br />
+              all_columns_as_keys: {self.all_columns_as_keys} <br />
+              column_names: {self.column_names} <br />
+              {self._child_repr_()}
+           </li>
+        </ul>
+ """ + + class Sort(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], *columns: Union[SortOrder, ColumnRef, str] diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 8de6565bae86..eb9ecc9157f2 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -28,12 +28,14 @@ _sym_db = _symbol_database.Default() +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2 +from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xdb\x01\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x43\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x9b\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12!\n\x0c\x63olumn_types\x18\x03 \x03(\tR\x0b\x63olumnTypes\x12%\n\x0e\x65xplain_string\x18\x04 
\x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -44,28 +46,28 @@ DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None 
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001" - _PLAN._serialized_start = 104 - _PLAN._serialized_end = 220 - _REQUEST._serialized_start = 223 - _REQUEST._serialized_end = 442 - _REQUEST_USERCONTEXT._serialized_start = 375 - _REQUEST_USERCONTEXT._serialized_end = 442 - _RESPONSE._serialized_start = 445 - _RESPONSE._serialized_end = 1413 - _RESPONSE_ARROWBATCH._serialized_start = 674 - _RESPONSE_ARROWBATCH._serialized_end = 849 - _RESPONSE_JSONBATCH._serialized_start = 851 - _RESPONSE_JSONBATCH._serialized_end = 911 - _RESPONSE_METRICS._serialized_start = 914 - _RESPONSE_METRICS._serialized_end = 1398 - _RESPONSE_METRICS_METRICOBJECT._serialized_start = 998 - _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1308 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1196 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1308 - _RESPONSE_METRICS_METRICVALUE._serialized_start = 1310 - _RESPONSE_METRICS_METRICVALUE._serialized_end = 1398 - _ANALYZERESPONSE._serialized_start = 1416 - _ANALYZERESPONSE._serialized_end = 1571 - _SPARKCONNECTSERVICE._serialized_start = 1574 - _SPARKCONNECTSERVICE._serialized_end = 1736 + _PLAN._serialized_start = 158 + _PLAN._serialized_end = 274 + _REQUEST._serialized_start = 277 + _REQUEST._serialized_end = 551 + _REQUEST_USERCONTEXT._serialized_start = 429 + _REQUEST_USERCONTEXT._serialized_end = 551 + _RESPONSE._serialized_start = 554 + _RESPONSE._serialized_end = 1522 + _RESPONSE_ARROWBATCH._serialized_start = 783 + _RESPONSE_ARROWBATCH._serialized_end = 958 + _RESPONSE_JSONBATCH._serialized_start = 960 + _RESPONSE_JSONBATCH._serialized_end = 1020 + _RESPONSE_METRICS._serialized_start = 1023 + _RESPONSE_METRICS._serialized_end = 1507 + _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1107 + _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1417 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1305 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1417 + _RESPONSE_METRICS_METRICVALUE._serialized_start = 1419 + _RESPONSE_METRICS_METRICVALUE._serialized_end = 1507 + _ANALYZERESPONSE._serialized_start = 1525 + _ANALYZERESPONSE._serialized_end = 1659 + _SPARKCONNECTSERVICE._serialized_start = 1662 + _SPARKCONNECTSERVICE._serialized_end = 1824 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index bcac7d11a808..5ffd7701b440 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -35,11 +35,13 @@ limitations under the License. 
""" import builtins import collections.abc +import google.protobuf.any_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message import pyspark.sql.connect.proto.commands_pb2 import pyspark.sql.connect.proto.relations_pb2 +import pyspark.sql.connect.proto.types_pb2 import sys if sys.version_info >= (3, 8): @@ -102,17 +104,32 @@ class Request(google.protobuf.message.Message): USER_ID_FIELD_NUMBER: builtins.int USER_NAME_FIELD_NUMBER: builtins.int + EXTENSIONS_FIELD_NUMBER: builtins.int user_id: builtins.str user_name: builtins.str + @property + def extensions( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + google.protobuf.any_pb2.Any + ]: + """To extend the existing user context message that is used to identify incoming requests, + Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other + messages into this message. Extensions are stored as a `repeated` type to be able to + handle multiple active extensions. + """ def __init__( self, *, user_id: builtins.str = ..., user_name: builtins.str = ..., + extensions: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ..., ) -> None: ... def ClearField( self, - field_name: typing_extensions.Literal["user_id", b"user_id", "user_name", b"user_name"], + field_name: typing_extensions.Literal[ + "extensions", b"extensions", "user_id", b"user_id", "user_name", b"user_name" + ], ) -> None: ... CLIENT_ID_FIELD_NUMBER: builtins.int @@ -385,39 +402,27 @@ class AnalyzeResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor CLIENT_ID_FIELD_NUMBER: builtins.int - COLUMN_NAMES_FIELD_NUMBER: builtins.int - COLUMN_TYPES_FIELD_NUMBER: builtins.int + SCHEMA_FIELD_NUMBER: builtins.int EXPLAIN_STRING_FIELD_NUMBER: builtins.int client_id: builtins.str @property - def column_names( - self, - ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... - @property - def column_types( - self, - ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... + def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... explain_string: builtins.str """The extended explain string as produced by Spark.""" def __init__( self, *, client_id: builtins.str = ..., - column_names: collections.abc.Iterable[builtins.str] | None = ..., - column_types: collections.abc.Iterable[builtins.str] | None = ..., + schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., explain_string: builtins.str = ..., ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["schema", b"schema"] + ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "client_id", - b"client_id", - "column_names", - b"column_names", - "column_types", - b"column_types", - "explain_string", - b"explain_string", + "client_id", b"client_id", "explain_string", b"explain_string", "schema", b"schema" ], ) -> None: ... 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index d9a596fba8c6..2a38a014926e 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x9a\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xbf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x9d\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 
\x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"\x8e\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12-\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08R\x10\x61llColumnsAsKeys"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xf0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12.\n\x04seed\x18\x05 \x01(\x0b\x32\x1a.spark.connect.Sample.SeedR\x04seed\x1a\x1a\n\x04Seed\x12\x12\n\x04seed\x18\x01 \x01(\x03R\x04seedB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 
\x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x9a\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xbf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xc2\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"\x8e\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12-\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08R\x10\x61llColumnsAsKeys"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xf0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12.\n\x04seed\x18\x05 \x01(\x0b\x32\x1a.spark.connect.Sample.SeedR\x04seed\x1a\x1a\n\x04Seed\x12\x12\n\x04seed\x18\x01 \x01(\x03R\x04seedB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -64,35 +64,35 @@ _FILTER._serialized_start = 1512 _FILTER._serialized_end = 1624 _JOIN._serialized_start = 1627 - _JOIN._serialized_end = 2040 - _JOIN_JOINTYPE._serialized_start = 1853 - _JOIN_JOINTYPE._serialized_end = 2040 - _UNION._serialized_start = 2043 - _UNION._serialized_end = 2248 - _UNION_UNIONTYPE._serialized_start = 2164 - _UNION_UNIONTYPE._serialized_end = 2248 - _LIMIT._serialized_start = 2250 - _LIMIT._serialized_end = 2326 - _OFFSET._serialized_start = 2328 - _OFFSET._serialized_end = 2407 - _AGGREGATE._serialized_start = 2410 - _AGGREGATE._serialized_end = 2735 - _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2639 - _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2735 - _SORT._serialized_start = 2738 - _SORT._serialized_end = 3240 - _SORT_SORTFIELD._serialized_start = 2858 - _SORT_SORTFIELD._serialized_end = 3046 - _SORT_SORTDIRECTION._serialized_start = 3048 - _SORT_SORTDIRECTION._serialized_end = 3156 - _SORT_SORTNULLS._serialized_start = 3158 - _SORT_SORTNULLS._serialized_end = 3240 - _DEDUPLICATE._serialized_start = 3243 - _DEDUPLICATE._serialized_end = 3385 - _LOCALRELATION._serialized_start = 3387 - _LOCALRELATION._serialized_end = 3480 - _SAMPLE._serialized_start = 3483 - _SAMPLE._serialized_end = 3723 - _SAMPLE_SEED._serialized_start = 3697 - _SAMPLE_SEED._serialized_end = 3723 + _JOIN._serialized_end = 2077 + _JOIN_JOINTYPE._serialized_start = 1890 + _JOIN_JOINTYPE._serialized_end = 2077 + 
_UNION._serialized_start = 2080 + _UNION._serialized_end = 2285 + _UNION_UNIONTYPE._serialized_start = 2201 + _UNION_UNIONTYPE._serialized_end = 2285 + _LIMIT._serialized_start = 2287 + _LIMIT._serialized_end = 2363 + _OFFSET._serialized_start = 2365 + _OFFSET._serialized_end = 2444 + _AGGREGATE._serialized_start = 2447 + _AGGREGATE._serialized_end = 2772 + _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2676 + _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2772 + _SORT._serialized_start = 2775 + _SORT._serialized_end = 3277 + _SORT_SORTFIELD._serialized_start = 2895 + _SORT_SORTFIELD._serialized_end = 3083 + _SORT_SORTDIRECTION._serialized_start = 3085 + _SORT_SORTDIRECTION._serialized_end = 3193 + _SORT_SORTNULLS._serialized_start = 3195 + _SORT_SORTNULLS._serialized_end = 3277 + _DEDUPLICATE._serialized_start = 3280 + _DEDUPLICATE._serialized_end = 3422 + _LOCALRELATION._serialized_start = 3424 + _LOCALRELATION._serialized_end = 3517 + _SAMPLE._serialized_start = 3520 + _SAMPLE._serialized_end = 3760 + _SAMPLE_SEED._serialized_start = 3734 + _SAMPLE_SEED._serialized_end = 3760 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index df179df1480e..d3186c4e3df0 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -467,6 +467,7 @@ class Join(google.protobuf.message.Message): RIGHT_FIELD_NUMBER: builtins.int JOIN_CONDITION_FIELD_NUMBER: builtins.int JOIN_TYPE_FIELD_NUMBER: builtins.int + USING_COLUMNS_FIELD_NUMBER: builtins.int @property def left(self) -> global___Relation: ... @property @@ -474,6 +475,16 @@ class Join(google.protobuf.message.Message): @property def join_condition(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression: ... join_type: global___Join.JoinType.ValueType + @property + def using_columns( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Optional. using_columns provides a list of columns that should present on both sides of + the join inputs that this Join will join on. For example A JOIN B USING col_name is + equivalent to A JOIN B on A.col_name = B.col_name. + + This field does not co-exist with join_condition. + """ def __init__( self, *, @@ -481,6 +492,7 @@ class Join(google.protobuf.message.Message): right: global___Relation | None = ..., join_condition: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ..., join_type: global___Join.JoinType.ValueType = ..., + using_columns: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def HasField( self, @@ -499,6 +511,8 @@ class Join(google.protobuf.message.Message): b"left", "right", b"right", + "using_columns", + b"using_columns", ], ) -> None: ... diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f01379afd6ef..ad1bc488e876 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -4884,6 +4884,52 @@ def check_string_field(field, fieldName): # type: ignore[no-untyped-def] return _invoke_function("window", time_col, windowDuration) +def window_time( + windowColumn: "ColumnOrName", +) -> Column: + """Computes the event time from a window column. The column window values are produced + by window aggregating operators and are of type `STRUCT` + where start is inclusive and end is exclusive. 
The event time of records produced by window + aggregating operators can be computed as ``window_time(window)`` and are + ``window.end - lit(1).alias("microsecond")`` (as microsecond is the minimal supported event + time precision). The window column must be one produced by a window aggregating operator. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + windowColumn : :class:`~pyspark.sql.Column` + The window column of a window aggregate records. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> import datetime + >>> df = spark.createDataFrame( + ... [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)], + ... ).toDF("date", "val") + + Group the data into 5 second time windows and aggregate as sum. + + >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) + + Extract the window event time using the window_time function. + + >>> w.select( + ... w.window.end.cast("string").alias("end"), + ... window_time(w.window).cast("string").alias("window_time"), + ... "sum" + ... ).collect() + [Row(end='2016-03-11 09:00:10', window_time='2016-03-11 09:00:09.999999', sum=1)] + """ + window_col = _to_java_column(windowColumn) + return _invoke_function("window_time", window_col) + + def session_window(timeColumn: "ColumnOrName", gapDuration: Union[Column, str]) -> Column: """ Generates session window given a timestamp specifying column. diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index f6988a1d1200..459b05cc37aa 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -22,6 +22,7 @@ import pandas from pyspark.sql import SparkSession, Row +from pyspark.sql.types import StructType, StructField, LongType, StringType from pyspark.sql.connect.client import RemoteSparkSession from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit @@ -97,6 +98,15 @@ def test_simple_explain_string(self): result = df.explain() self.assertGreater(len(result), 0) + def test_schema(self): + schema = self.connect.read.table(self.tbl_name).schema() + self.assertEqual( + StructType( + [StructField("id", LongType(), True), StructField("name", StringType(), True)] + ), + schema, + ) + def test_simple_binary_expressions(self): """Test complex expression""" df = self.connect.read.table(self.tbl_name) diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 3b609db7a028..450f5c70faba 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -72,6 +72,25 @@ def test_sample(self): self.assertEqual(plan.root.sample.with_replacement, True) self.assertEqual(plan.root.sample.seed.seed, -1) + def test_deduplicate(self): + df = self.connect.readTable(table_name=self.tbl_name) + + distinct_plan = df.distinct()._plan.to_proto(self.connect) + self.assertEqual(distinct_plan.root.deduplicate.all_columns_as_keys, True) + self.assertEqual(len(distinct_plan.root.deduplicate.column_names), 0) + + deduplicate_on_all_columns_plan = df.dropDuplicates()._plan.to_proto(self.connect) + self.assertEqual(deduplicate_on_all_columns_plan.root.deduplicate.all_columns_as_keys, True) + self.assertEqual(len(deduplicate_on_all_columns_plan.root.deduplicate.column_names), 0) + + deduplicate_on_subset_columns_plan = 
df.dropDuplicates(["name", "height"])._plan.to_proto( + self.connect + ) + self.assertEqual( + deduplicate_on_subset_columns_plan.root.deduplicate.all_columns_as_keys, False + ) + self.assertEqual(len(deduplicate_on_subset_columns_plan.root.deduplicate.column_names), 2) + def test_relation_alias(self): df = self.connect.readTable(table_name=self.tbl_name) plan = df.alias("table_alias")._plan.to_proto(self.connect) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 32cc77e11155..55ef012b6d02 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -894,6 +894,22 @@ def test_window_functions_cumulative_sum(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[: len(r)]) + def test_window_time(self): + df = self.spark.createDataFrame( + [(datetime.datetime(2016, 3, 11, 9, 0, 7), 1)], ["date", "val"] + ) + from pyspark.sql import functions as F + + w = df.groupBy(F.window("date", "5 seconds")).agg(F.sum("val").alias("sum")) + r = w.select( + w.window.end.cast("string").alias("end"), + F.window_time(w.window).cast("string").alias("window_time"), + "sum", + ).collect() + self.assertEqual( + r[0], Row(end="2016-03-11 09:00:10", window_time="2016-03-11 09:00:09.999999", sum=1) + ) + def test_collect_functions(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql import functions diff --git a/python/pyspark/sql/tests/test_pandas_sqlmetrics.py b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py new file mode 100644 index 000000000000..d182bafd8b54 --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_sqlmetrics.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest +from typing import cast + +from pyspark.sql.functions import pandas_udf +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), +) +class PandasSQLMetrics(ReusedSQLTestCase): + def test_pandas_sql_metrics_basic(self): + # SPARK-34265: Instrument Python UDFs using SQL metrics + + python_sql_metrics = [ + "data sent to Python workers", + "data returned from Python workers", + "number of output rows", + ] + + @pandas_udf("long") + def test_pandas(col1): + return col1 * col1 + + self.spark.range(10).select(test_pandas("id")).collect() + + statusStore = self.spark._jsparkSession.sharedState().statusStore() + lastExecId = statusStore.executionsList().last().executionId() + executionMetrics = statusStore.execution(lastExecId).get().metrics().mkString() + + for metric in python_sql_metrics: + self.assertIn(metric, executionMetrics) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_sqlmetrics import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 780cfdfba8e9..0e05b523ccc9 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -42,7 +42,11 @@ fi SPARK_CLASSPATH="$SPARK_CLASSPATH:${SPARK_HOME}/jars/*" env | grep SPARK_JAVA_OPT_ | sort -t_ -k4 -n | sed 's/[^=]*=\(.*\)/\1/g' > /tmp/java_opts.txt -readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt +if [ "$(command -v readarray)" ]; then + readarray -t SPARK_EXECUTOR_JAVA_OPTS < /tmp/java_opts.txt +else + SPARK_EXECUTOR_JAVA_OPTS=("${(@f)$(< /tmp/java_opts.txt)}") +fi if [ -n "$SPARK_EXTRA_CLASSPATH" ]; then SPARK_CLASSPATH="$SPARK_CLASSPATH:$SPARK_EXTRA_CLASSPATH" diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 9a1862d32dc1..102dd4b76d23 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.util.{Collection, Collections, Date} +import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ @@ -40,6 +41,19 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi private var driver: SchedulerDriver = _ private var scheduler: MesosClusterScheduler = _ + private val submissionTime = new AtomicLong(System.currentTimeMillis()) + + // Queued drivers in MesosClusterScheduler are ordered based on MesosDriverDescription + // The default ordering checks for priority, followed by submission time. 
For two driver + // submissions with the same priority made in quick succession (such that the submission + // time is the same due to millisecond granularity), this results in dropping the + // second MesosDriverDescription from the queuedDrivers - as driverOrdering + // returns 0 when comparing the descriptions. Ensure two separate submissions + // have different dates. + private def getDate: Date = { + new Date(submissionTime.incrementAndGet()) + } + private def setScheduler(sparkConfVars: Map[String, String] = null): Unit = { val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") @@ -68,7 +82,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map[String, String](), submissionId, - new Date()) + getDate) } test("can queue drivers") { @@ -108,7 +122,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi Map((config.EXECUTOR_HOME.key, "test"), ("spark.app.name", "test"), (config.DRIVER_MEMORY_OVERHEAD.key, "0")), "s1", - new Date())) + getDate)) assert(response.success) val offer = Offer.newBuilder() .addResources( @@ -213,7 +227,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi Map("spark.mesos.executor.home" -> "test", "spark.app.name" -> "test"), "s1", - new Date())) + getDate)) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem*2, cpu) @@ -240,7 +254,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi Map("spark.mesos.executor.home" -> "test", "spark.app.name" -> "test"), "s1", - new Date())) + getDate)) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem*2, cpu) @@ -270,7 +284,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.DRIVER_MEMORY_OVERHEAD.key -> "0" ), "s1", - new Date())) + getDate)) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem, cpu) @@ -296,7 +310,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.NETWORK_LABELS.key -> "key1:val1,key2:val2", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -327,7 +341,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi "spark.app.name" -> "test", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -352,7 +366,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi "spark.app.name" -> "test", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -378,7 +392,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi "spark.app.name" -> "test", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -413,7 +427,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.DRIVER_CONSTRAINTS.key -> driverConstraints, config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) } @@ -452,7 +466,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.DRIVER_LABELS.key -> "key:value", config.DRIVER_MEMORY_OVERHEAD.key -> "0"), "s1", - new Date())) + getDate)) assert(response.success) @@ -474,7 +488,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val response = scheduler.submitDriver( new MesosDriverDescription("d1",
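The getDate helper above works around an easy-to-reproduce collision: an ordering over (priority, submission date) compares two drivers submitted in the same millisecond as equal, so a sorted collection silently keeps only one of them. A minimal standalone sketch of the effect, with hypothetical names and no assumptions about the scheduler's actual internals:

```scala
// Illustrative sketch only (not part of the patch): two submissions that compare
// equal under a (priority, date) ordering collapse into one entry in a sorted set.
import java.util.Date
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable

case class Submission(id: String, priority: Float, date: Date)

val ordering: Ordering[Submission] =
  Ordering.by((s: Submission) => (s.priority, s.date.getTime))

val now = System.currentTimeMillis()
val collided = mutable.TreeSet.empty[Submission](ordering)
collided += Submission("s1", 0.0f, new Date(now))
collided += Submission("s2", 0.0f, new Date(now)) // compares equal to s1 and is dropped
assert(collided.size == 1)

// A monotonically increasing timestamp, as getDate does above, keeps both entries.
val clock = new AtomicLong(now)
val distinct = mutable.TreeSet.empty[Submission](ordering)
distinct += Submission("s1", 0.0f, new Date(clock.incrementAndGet()))
distinct += Submission("s2", 0.0f, new Date(clock.incrementAndGet()))
assert(distinct.size == 2)
```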
"jar", 100, 1, true, command, - Map((config.EXECUTOR_HOME.key, "test"), ("spark.app.name", "test")), "s1", new Date())) + Map((config.EXECUTOR_HOME.key, "test"), ("spark.app.name", "test")), "s1", getDate)) assert(response.success) val agentId = SlaveID.newBuilder().setValue("s1").build() val offer = Offer.newBuilder() @@ -533,7 +547,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val response = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map(("spark.mesos.executor.home", "test"), ("spark.app.name", "test")), "sub1", new Date())) + Map(("spark.mesos.executor.home", "test"), ("spark.app.name", "test")), "sub1", getDate)) assert(response.success) // Offer a resource to launch the submitted driver @@ -651,7 +665,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.EXECUTOR_URI.key -> "s3a://bucket/spark-version.tgz", "another.conf" -> "\\value"), "s1", - new Date()) + getDate) val expectedCmd = "cd spark-version*; " + "bin/spark-submit --name \"app name\" --master mesos://mesos://localhost:5050 " + @@ -691,7 +705,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map("spark.mesos.dispatcher.queue" -> "dummy"), "s1", - new Date()) + getDate) assertThrows[NoSuchElementException] { scheduler.getDriverPriority(desc) @@ -702,7 +716,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map[String, String](), "s2", - new Date()) + getDate) assert(scheduler.getDriverPriority(desc) == 0.0f) @@ -711,7 +725,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map("spark.mesos.dispatcher.queue" -> "default"), "s3", - new Date()) + getDate) assert(scheduler.getDriverPriority(desc) == 0.0f) @@ -720,7 +734,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map("spark.mesos.dispatcher.queue" -> "ROUTINE"), "s4", - new Date()) + getDate) assert(scheduler.getDriverPriority(desc) == 1.0f) @@ -729,7 +743,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi command, Map("spark.mesos.dispatcher.queue" -> "URGENT"), "s5", - new Date()) + getDate) assert(scheduler.getDriverPriority(desc) == 2.0f) } @@ -746,22 +760,22 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val response0 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map("spark.mesos.dispatcher.queue" -> "ROUTINE"), "s0", new Date())) + Map("spark.mesos.dispatcher.queue" -> "ROUTINE"), "s0", getDate)) assert(response0.success) val response1 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map[String, String](), "s1", new Date())) + Map[String, String](), "s1", getDate)) assert(response1.success) val response2 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map("spark.mesos.dispatcher.queue" -> "EXCEPTIONAL"), "s2", new Date())) + Map("spark.mesos.dispatcher.queue" -> "EXCEPTIONAL"), "s2", getDate)) assert(response2.success) val response3 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map("spark.mesos.dispatcher.queue" -> "URGENT"), "s3", new Date())) + Map("spark.mesos.dispatcher.queue" -> "URGENT"), "s3", getDate)) assert(response3.success) val state = scheduler.getSchedulerState() @@ -782,12 +796,12 @@ class 
MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi val response0 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map("spark.mesos.dispatcher.queue" -> "LOWER"), "s0", new Date())) + Map("spark.mesos.dispatcher.queue" -> "LOWER"), "s0", getDate)) assert(response0.success) val response1 = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, - Map[String, String](), "s1", new Date())) + Map[String, String](), "s1", getDate)) assert(response1.success) val state = scheduler.getSchedulerState() @@ -812,7 +826,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi config.DRIVER_MEMORY_OVERHEAD.key -> "0") ++ addlSparkConfVars, "s1", - new Date()) + getDate) val response = scheduler.submitDriver(driverDesc) assert(response.success) val offer = Utils.createOffer("o1", "s1", mem, cpu) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b185b38797bb..fc12b6522b41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -56,7 +56,6 @@ import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} -import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils import org.apache.spark.util.collection.{Utils => CUtils} @@ -313,6 +312,7 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveAggregateFunctions :: TimeWindowing :: SessionWindowing :: + ResolveWindowTime :: ResolveDefaultColumns(v1SessionCatalog) :: ResolveInlineTables :: ResolveLambdaVariables :: @@ -3965,242 +3965,6 @@ object EliminateEventTimeWatermark extends Rule[LogicalPlan] { } } -/** - * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to - * figure out how many windows a time column can map to, we over-estimate the number of windows and - * filter out the rows where the time column is not inside the time window. - */ -object TimeWindowing extends Rule[LogicalPlan] { - import org.apache.spark.sql.catalyst.dsl.expressions._ - - private final val WINDOW_COL_NAME = "window" - private final val WINDOW_START = "start" - private final val WINDOW_END = "end" - - /** - * Generates the logical plan for generating window ranges on a timestamp column. Without - * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many - * window ranges a timestamp will map to given all possible combinations of a window duration, - * slide duration and start time (offset). Therefore, we express and over-estimate the number of - * windows there may be, and filter the valid windows. We use last Project operator to group - * the window columns into a struct so they can be accessed as `window.start` and `window.end`. 
- * - * The windows are calculated as below: - * maxNumOverlapping <- ceil(windowDuration / slideDuration) - * for (i <- 0 until maxNumOverlapping) - * lastStart <- timestamp - (timestamp - startTime + slideDuration) % slideDuration - * windowStart <- lastStart - i * slideDuration - * windowEnd <- windowStart + windowDuration - * return windowStart, windowEnd - * - * This behaves as follows for the given parameters for the time: 12:05. The valid windows are - * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the - * Filter operator. - * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m - * 11:55 - 12:07 + 11:52 - 12:04 x - * 12:00 - 12:12 + 11:57 - 12:09 + - * 12:05 - 12:17 + 12:02 - 12:14 + - * - * @param plan The logical plan - * @return the logical plan that will generate the time windows using the Expand operator, with - * the Filter operator for correctness and Project for usability. - */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsPattern(TIME_WINDOW), ruleId) { - case p: LogicalPlan if p.children.size == 1 => - val child = p.children.head - val windowExpressions = - p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet - - val numWindowExpr = p.expressions.flatMap(_.collect { - case s: SessionWindow => s - case t: TimeWindow => t - }).toSet.size - - // Only support a single window expression for now - if (numWindowExpr == 1 && windowExpressions.nonEmpty && - windowExpressions.head.timeColumn.resolved && - windowExpressions.head.checkInputDataTypes().isSuccess) { - - val window = windowExpressions.head - - val metadata = window.timeColumn match { - case a: Attribute => a.metadata - case _ => Metadata.empty - } - - def getWindow(i: Int, dataType: DataType): Expression = { - val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType) - val lastStart = timestamp - (timestamp - window.startTime - + window.slideDuration) % window.slideDuration - val windowStart = lastStart - i * window.slideDuration - val windowEnd = windowStart + window.windowDuration - - // We make sure value fields are nullable since the dataType of TimeWindow defines them - // as nullable. 
- CreateNamedStruct( - Literal(WINDOW_START) :: - PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() :: - Literal(WINDOW_END) :: - PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() :: - Nil) - } - - val windowAttr = AttributeReference( - WINDOW_COL_NAME, window.dataType, metadata = metadata)() - - if (window.windowDuration == window.slideDuration) { - val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)( - exprId = windowAttr.exprId, explicitMetadata = Some(metadata)) - - val replacedPlan = p transformExpressions { - case t: TimeWindow => windowAttr - } - - // For backwards compatibility we add a filter to filter out nulls - val filterExpr = IsNotNull(window.timeColumn) - - replacedPlan.withNewChildren( - Project(windowStruct +: child.output, - Filter(filterExpr, child)) :: Nil) - } else { - val overlappingWindows = - math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt - val windows = - Seq.tabulate(overlappingWindows)(i => - getWindow(i, window.timeColumn.dataType)) - - val projections = windows.map(_ +: child.output) - - // When the condition windowDuration % slideDuration = 0 is fulfilled, - // the estimation of the number of windows becomes exact one, - // which means all produced windows are valid. - val filterExpr = - if (window.windowDuration % window.slideDuration == 0) { - IsNotNull(window.timeColumn) - } else { - window.timeColumn >= windowAttr.getField(WINDOW_START) && - window.timeColumn < windowAttr.getField(WINDOW_END) - } - - val substitutedPlan = Filter(filterExpr, - Expand(projections, windowAttr +: child.output, child)) - - val renamedPlan = p transformExpressions { - case t: TimeWindow => windowAttr - } - - renamedPlan.withNewChildren(substitutedPlan :: Nil) - } - } else if (numWindowExpr > 1) { - throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) - } else { - p // Return unchanged. Analyzer will throw exception later - } - } -} - -/** Maps a time column to a session window. */ -object SessionWindowing extends Rule[LogicalPlan] { - import org.apache.spark.sql.catalyst.dsl.expressions._ - - private final val SESSION_COL_NAME = "session_window" - private final val SESSION_START = "start" - private final val SESSION_END = "end" - - /** - * Generates the logical plan for generating session window on a timestamp column. - * Each session window is initially defined as [timestamp, timestamp + gap). - * - * This also adds a marker to the session column so that downstream can easily find the column - * on session window. 
- */ - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { - case p: LogicalPlan if p.children.size == 1 => - val child = p.children.head - val sessionExpressions = - p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet - - val numWindowExpr = p.expressions.flatMap(_.collect { - case s: SessionWindow => s - case t: TimeWindow => t - }).toSet.size - - // Only support a single session expression for now - if (numWindowExpr == 1 && sessionExpressions.nonEmpty && - sessionExpressions.head.timeColumn.resolved && - sessionExpressions.head.checkInputDataTypes().isSuccess) { - - val session = sessionExpressions.head - - val metadata = session.timeColumn match { - case a: Attribute => a.metadata - case _ => Metadata.empty - } - - val newMetadata = new MetadataBuilder() - .withMetadata(metadata) - .putBoolean(SessionWindow.marker, true) - .build() - - val sessionAttr = AttributeReference( - SESSION_COL_NAME, session.dataType, metadata = newMetadata)() - - val sessionStart = - PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType) - val gapDuration = session.gapDuration match { - case expr if Cast.canCast(expr.dataType, CalendarIntervalType) => - Cast(expr, CalendarIntervalType) - case other => - throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType) - } - val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration, - session.timeColumn.dataType, LongType) - - // We make sure value fields are nullable since the dataType of SessionWindow defines them - // as nullable. - val literalSessionStruct = CreateNamedStruct( - Literal(SESSION_START) :: - PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) - .castNullable() :: - Literal(SESSION_END) :: - PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) - .castNullable() :: - Nil) - - val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( - exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) - - val replacedPlan = p transformExpressions { - case s: SessionWindow => sessionAttr - } - - val filterByTimeRange = session.gapDuration match { - case Literal(interval: CalendarInterval, CalendarIntervalType) => - interval == null || interval.months + interval.days + interval.microseconds <= 0 - case _ => true - } - - // As same as tumbling window, we add a filter to filter out nulls. - // And we also filter out events with negative or zero or invalid gap duration. - val filterExpr = if (filterByTimeRange) { - IsNotNull(session.timeColumn) && - (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START)) - } else { - IsNotNull(session.timeColumn) - } - - replacedPlan.withNewChildren( - Filter(filterExpr, - Project(sessionStruct +: child.output, child)) :: Nil) - } else if (numWindowExpr > 1) { - throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) - } else { - p // Return unchanged. Analyzer will throw exception later - } - } -} - /** * Resolve expressions if they contains [[NamePlaceholder]]s. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4346f51b613a..cad036a34e97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1066,7 +1066,12 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { // 1 | 2 | 4 // and the plan after rewrite will give the original query incorrect results. def failOnUnsupportedCorrelatedPredicate(predicates: Seq[Expression], p: LogicalPlan): Unit = { - if (predicates.nonEmpty) { + // Correlated non-equality predicates are only supported with the decorrelate + // inner query framework. Currently we only use this new framework for scalar + // and lateral subqueries. + val allowNonEqualityPredicates = + SQLConf.get.decorrelateInnerQueryEnabled && (isScalar || isLateral) + if (!allowNonEqualityPredicates && predicates.nonEmpty) { // Report a non-supported case as an exception p.failAnalysis( errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ef8ce3f48d5a..f5e494e90967 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -639,6 +639,7 @@ object FunctionRegistry { expression[Year]("year"), expression[TimeWindow]("window"), expression[SessionWindow]("session_window"), + expression[WindowTime]("window_time"), expression[MakeDate]("make_date"), expression[MakeTimestamp]("make_timestamp"), // We keep the 2 expression builders below to have different function docs. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala new file mode 100644 index 000000000000..fd5da3ff13d8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, CreateNamedStruct, Expression, GetStructField, IsNotNull, Literal, PreciseTimestampConversion, SessionWindow, Subtract, TimeWindow, WindowTime} +import org.apache.spark.sql.catalyst.plans.logical.{Expand, Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.TIME_WINDOW +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.{CalendarIntervalType, DataType, LongType, Metadata, MetadataBuilder, StructType} +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to + * figure out how many windows a time column can map to, we over-estimate the number of windows and + * filter out the rows where the time column is not inside the time window. + */ +object TimeWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val WINDOW_COL_NAME = "window" + private final val WINDOW_START = "start" + private final val WINDOW_END = "end" + + /** + * Generates the logical plan for generating window ranges on a timestamp column. Without + * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many + * window ranges a timestamp will map to given all possible combinations of a window duration, + * slide duration and start time (offset). Therefore, we express and over-estimate the number of + * windows there may be, and filter the valid windows. We use last Project operator to group + * the window columns into a struct so they can be accessed as `window.start` and `window.end`. + * + * The windows are calculated as below: + * maxNumOverlapping <- ceil(windowDuration / slideDuration) + * for (i <- 0 until maxNumOverlapping) + * lastStart <- timestamp - (timestamp - startTime + slideDuration) % slideDuration + * windowStart <- lastStart - i * slideDuration + * windowEnd <- windowStart + windowDuration + * return windowStart, windowEnd + * + * This behaves as follows for the given parameters for the time: 12:05. The valid windows are + * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the + * Filter operator. + * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m + * 11:55 - 12:07 + 11:52 - 12:04 x + * 12:00 - 12:12 + 11:57 - 12:09 + + * 12:05 - 12:17 + 12:02 - 12:14 + + * + * @param plan The logical plan + * @return the logical plan that will generate the time windows using the Expand operator, with + * the Filter operator for correctness and Project for usability. 
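The enumeration documented above can be sketched directly over microsecond values. The helper names below are made up, and the real rule emits Catalyst expressions over an Expand operator rather than computing numbers, but the arithmetic is the same:

```scala
// Standalone sketch of the window enumeration described above; the real rule
// builds Catalyst expressions instead of evaluating values directly.
def candidateWindows(
    timestamp: Long,       // event time, e.g. in microseconds
    windowDuration: Long,
    slideDuration: Long,
    startTime: Long): Seq[(Long, Long)] = {
  val maxNumOverlapping = math.ceil(windowDuration.toDouble / slideDuration).toInt
  val lastStart = timestamp - (timestamp - startTime + slideDuration) % slideDuration
  (0 until maxNumOverlapping).map { i =>
    val windowStart = lastStart - i * slideDuration
    (windowStart, windowStart + windowDuration)
  }
}

// The Filter keeps only windows that actually contain the timestamp; when
// windowDuration % slideDuration == 0 every candidate is already valid.
def validWindows(ts: Long, win: Long, slide: Long, start: Long): Seq[(Long, Long)] =
  candidateWindows(ts, win, slide, start).filter { case (s, e) => ts >= s && ts < e }
```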
+ */ + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( + _.containsPattern(TIME_WINDOW), ruleId) { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val windowExpressions = + p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet + + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + + // Only support a single window expression for now + if (numWindowExpr == 1 && windowExpressions.nonEmpty && + windowExpressions.head.timeColumn.resolved && + windowExpressions.head.checkInputDataTypes().isSuccess) { + + val window = windowExpressions.head + + if (StructType.acceptsType(window.timeColumn.dataType)) { + return p.transformExpressions { + case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn)) + } + } + + val metadata = window.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(TimeWindow.marker, true) + .build() + + def getWindow(i: Int, dataType: DataType): Expression = { + val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType) + val lastStart = timestamp - (timestamp - window.startTime + + window.slideDuration) % window.slideDuration + val windowStart = lastStart - i * window.slideDuration + val windowEnd = windowStart + window.windowDuration + + // We make sure value fields are nullable since the dataType of TimeWindow defines them + // as nullable. + CreateNamedStruct( + Literal(WINDOW_START) :: + PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() :: + Literal(WINDOW_END) :: + PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() :: + Nil) + } + + val windowAttr = AttributeReference( + WINDOW_COL_NAME, window.dataType, metadata = newMetadata)() + + if (window.windowDuration == window.slideDuration) { + val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)( + exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + // For backwards compatibility we add a filter to filter out nulls + val filterExpr = IsNotNull(window.timeColumn) + + replacedPlan.withNewChildren( + Project(windowStruct +: child.output, + Filter(filterExpr, child)) :: Nil) + } else { + val overlappingWindows = + math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt + val windows = + Seq.tabulate(overlappingWindows)(i => + getWindow(i, window.timeColumn.dataType)) + + val projections = windows.map(_ +: child.output) + + // When the condition windowDuration % slideDuration = 0 is fulfilled, + // the estimation of the number of windows becomes exact one, + // which means all produced windows are valid. + val filterExpr = + if (window.windowDuration % window.slideDuration == 0) { + IsNotNull(window.timeColumn) + } else { + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) + } + + val substitutedPlan = Filter(filterExpr, + Expand(projections, windowAttr +: child.output, child)) + + val renamedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + renamedPlan.withNewChildren(substitutedPlan :: Nil) + } + } else if (numWindowExpr > 1) { + throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) + } else { + p // Return unchanged. 
Analyzer will throw exception later + } + } +} + +/** Maps a time column to a session window. */ +object SessionWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val SESSION_COL_NAME = "session_window" + private final val SESSION_START = "start" + private final val SESSION_END = "end" + + /** + * Generates the logical plan for generating session window on a timestamp column. + * Each session window is initially defined as [timestamp, timestamp + gap). + * + * This also adds a marker to the session column so that downstream can easily find the column + * on session window. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val sessionExpressions = + p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet + + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + + // Only support a single session expression for now + if (numWindowExpr == 1 && sessionExpressions.nonEmpty && + sessionExpressions.head.timeColumn.resolved && + sessionExpressions.head.checkInputDataTypes().isSuccess) { + + val session = sessionExpressions.head + + if (StructType.acceptsType(session.timeColumn.dataType)) { + return p transformExpressions { + case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn)) + } + } + + val metadata = session.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(SessionWindow.marker, true) + .build() + + val sessionAttr = AttributeReference( + SESSION_COL_NAME, session.dataType, metadata = newMetadata)() + + val sessionStart = + PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType) + val gapDuration = session.gapDuration match { + case expr if Cast.canCast(expr.dataType, CalendarIntervalType) => + Cast(expr, CalendarIntervalType) + case other => + throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType) + } + val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration, + session.timeColumn.dataType, LongType) + + // We make sure value fields are nullable since the dataType of SessionWindow defines them + // as nullable. + val literalSessionStruct = CreateNamedStruct( + Literal(SESSION_START) :: + PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType) + .castNullable() :: + Literal(SESSION_END) :: + PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType) + .castNullable() :: + Nil) + + val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( + exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case s: SessionWindow => sessionAttr + } + + val filterByTimeRange = session.gapDuration match { + case Literal(interval: CalendarInterval, CalendarIntervalType) => + interval == null || interval.months + interval.days + interval.microseconds <= 0 + case _ => true + } + + // As same as tumbling window, we add a filter to filter out nulls. + // And we also filter out events with negative or zero or invalid gap duration. 
+ val filterExpr = if (filterByTimeRange) { + IsNotNull(session.timeColumn) && + (sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START)) + } else { + IsNotNull(session.timeColumn) + } + + replacedPlan.withNewChildren( + Filter(filterExpr, + Project(sessionStruct +: child.output, child)) :: Nil) + } else if (numWindowExpr > 1) { + throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} + +/** + * Resolves the window_time expression which extracts the correct window time from the + * window column generated as the output of the window aggregating operators. The + * window column is of type struct { start: TimestampType, end: TimestampType }. + * The correct representative event time of a window is ``window.end - 1``. + */ +object ResolveWindowTime extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val windowTimeExpressions = + p.expressions.flatMap(_.collect { case w: WindowTime => w }).toSet + + if (windowTimeExpressions.size == 1 && + windowTimeExpressions.head.windowColumn.resolved && + windowTimeExpressions.head.checkInputDataTypes().isSuccess) { + + val windowTime = windowTimeExpressions.head + + val metadata = windowTime.windowColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + if (!metadata.contains(TimeWindow.marker) && + !metadata.contains(SessionWindow.marker)) { + // FIXME: error framework? + throw new AnalysisException( + s"The input is not a correct window column: $windowTime", plan = Some(p)) + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .remove(TimeWindow.marker) + .remove(SessionWindow.marker) + .build() + + val attr = AttributeReference( + "window_time", windowTime.dataType, metadata = newMetadata)() + + // NOTE: "window.end" is the "exclusive" upper bound of the window, so if we use this value + // as it is, it is going to be bound to a different window even if we apply the same window + // spec. Subtract 1 microsecond from window.end so that window_time is bound to the + // correct window range. + val subtractExpr = + PreciseTimestampConversion( + Subtract(PreciseTimestampConversion( + GetStructField(windowTime.windowColumn, 1), + windowTime.dataType, LongType), Literal(1L)), + LongType, + windowTime.dataType) + + val newColumn = Alias(subtractExpr, "window_time")( + exprId = attr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case w: WindowTime => attr + } + + replacedPlan.withNewChildren(Project(newColumn +: child.output, child) :: Nil) + } else { + p // Return unchanged.
Analyzer will throw exception later + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 1b1d5514b3f2..fa52e6cd8517 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -22,8 +22,8 @@ import java.lang.reflect.{Method, Modifier} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -56,7 +56,9 @@ import org.apache.spark.util.Utils since = "2.0.0", group = "misc_funcs") case class CallMethodViaReflection(children: Seq[Expression]) - extends Nondeterministic with CodegenFallback { + extends Nondeterministic + with CodegenFallback + with QueryErrorsBase { override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("reflect") @@ -65,7 +67,7 @@ case class CallMethodViaReflection(children: Seq[Expression]) DataTypeMismatch( errorSubClass = "WRONG_NUM_PARAMS", messageParameters = Map( - "functionName" -> prettyName, + "functionName" -> toSQLId(prettyName), "expectedNum" -> "> 1", "actualNum" -> children.length.toString)) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index d7deca2f7b76..53c79d1fd54b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -136,6 +136,8 @@ case class TimeWindow( } object TimeWindow { + val marker = "spark.timeWindow" + /** * Parses the interval string for a valid time duration. CalendarInterval expects interval * strings to start with the string `interval`. For usability, we prepend `interval` to the string diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala new file mode 100644 index 000000000000..1bb934cb2023 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowTime.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types._ + +// scalastyle:off line.size.limit line.contains.tab +@ExpressionDescription( + usage = """ + _FUNC_(window_column) - Extract the time value from a time/session window column which can be used as the event time value of the window. + The extracted time is (window.end - 1) which reflects the fact that the aggregating + windows have an exclusive upper bound - [start, end) + See 'Window Operations on Event Time' in the Structured Streaming guide doc for a detailed explanation and examples. + """, + arguments = """ + Arguments: + * window_column - The column representing a time/session window. + """, + examples = """ + Examples: + > SELECT a, window.start as start, window.end as end, _FUNC_(window), cnt FROM (SELECT a, window, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, window.start); + A1 2021-01-01 00:00:00 2021-01-01 00:05:00 2021-01-01 00:04:59.999999 2 + A1 2021-01-01 00:05:00 2021-01-01 00:10:00 2021-01-01 00:09:59.999999 1 + A2 2021-01-01 00:00:00 2021-01-01 00:05:00 2021-01-01 00:04:59.999999 1 + """, + group = "datetime_funcs", + since = "3.4.0") +// scalastyle:on line.size.limit line.contains.tab +case class WindowTime(windowColumn: Expression) + extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + override def child: Expression = windowColumn + override def inputTypes: Seq[AbstractDataType] = Seq(StructType) + + override def dataType: DataType = child.dataType.asInstanceOf[StructType].head.dataType + + override def prettyName: String = "window_time" + + // This expression is replaced in the analyzer.
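The value this expression ultimately resolves to, via the ResolveWindowTime rule earlier in this patch, is the window's exclusive end minus one microsecond. A quick java.time sketch (UTC chosen purely for illustration) reproduces the value asserted by the PySpark test_window_time test above:

```scala
// Illustrative arithmetic only: window_time(window) == window.end - 1 microsecond,
// so the value falls inside the half-open interval [start, end).
import java.time.Instant
import java.time.temporal.ChronoUnit

val windowEnd = Instant.parse("2016-03-11T09:00:10Z")        // exclusive upper bound
val windowTime = windowEnd.minus(1, ChronoUnit.MICROS)
assert(windowTime.toString == "2016-03-11T09:00:09.999999Z") // matches the test above
```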
+ override lazy val resolved = false + + override protected def withNewChildInternal(newChild: Expression): WindowTime = + copy(windowColumn = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index b802678ec046..902f53309de4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -41,7 +41,7 @@ case class Max(child: Expression) extends DeclarativeAggregate with UnaryLike[Ex override def dataType: DataType = child.dataType override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function max") + TypeUtils.checkForOrderingExpr(child.dataType, prettyName) private lazy val max = AttributeReference("max", child.dataType)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala index 664bc32ccc46..096a42686a36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/MaxByAndMinBy.scala @@ -47,7 +47,7 @@ abstract class MaxMinBy extends DeclarativeAggregate with BinaryLike[Expression] override def dataType: DataType = valueExpr.dataType override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(orderingExpr.dataType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(orderingExpr.dataType, prettyName) // The attributes used to keep extremum (max or min) and associated aggregated values. 
private lazy val extremumOrdering = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 9c5c7bbda4dc..7a9588808dbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -41,7 +41,7 @@ case class Min(child: Expression) extends DeclarativeAggregate with UnaryLike[Ex override def dataType: DataType = child.dataType override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function min") + TypeUtils.checkForOrderingExpr(child.dataType, prettyName) private lazy val min = AttributeReference("min", child.dataType)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index 380289ba5fee..cd6e1a5a18e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -59,10 +59,10 @@ case class Mode( override def update( buffer: OpenHashMap[AnyRef, Long], input: InternalRow): OpenHashMap[AnyRef, Long] = { - val key = child.eval(input).asInstanceOf[AnyRef] + val key = child.eval(input) if (key != null) { - buffer.changeValue(key, 1L, _ + 1L) + buffer.changeValue(InternalRow.copyValue(key).asInstanceOf[AnyRef], 1L, _ + 1L) } buffer } @@ -121,10 +121,12 @@ case class PandasMode( override def update( buffer: OpenHashMap[AnyRef, Long], input: InternalRow): OpenHashMap[AnyRef, Long] = { - val key = child.eval(input).asInstanceOf[AnyRef] + val key = child.eval(input) - if (key != null || !ignoreNA) { - buffer.changeValue(key, 1L, _ + 1L) + if (key != null) { + buffer.changeValue(InternalRow.copyValue(key).asInstanceOf[AnyRef], 1L, _ + 1L) + } else if (!ignoreNA) { + buffer.changeValue(null, 1L, _ + 1L) } buffer } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d82108aa3c9f..3e8ec94c33ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -1203,7 +1203,7 @@ case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression DataTypeMismatch( errorSubClass = "WRONG_NUM_PARAMS", messageParameters = Map( - "functionName" -> prettyName, + "functionName" -> toSQLId(prettyName), "expectedNum" -> "> 1", "actualNum" -> children.length.toString)) } else if (!TypeCoercion.haveSameType(inputTypesForMerging)) { @@ -1215,7 +1215,7 @@ case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression ) ) } else { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(dataType, prettyName) } } @@ -1294,7 +1294,7 @@ case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpress DataTypeMismatch( errorSubClass = "WRONG_NUM_PARAMS", messageParameters = Map( - "functionName" -> prettyName, + "functionName" -> toSQLId(prettyName), "expectedNum" -> "> 1", "actualNum" -> children.length.toString)) } else if 
(!TypeCoercion.haveSameType(inputTypesForMerging)) { @@ -1306,7 +1306,7 @@ case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpress ) ) } else { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(dataType, prettyName) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index efaadac6ed1c..256139aca014 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder -import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike} @@ -34,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, Tree import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SQLOpenHashSet @@ -47,8 +46,10 @@ import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String} * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit * casting. 
*/ -trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression - with ImplicitCastInputTypes { +trait BinaryArrayExpressionWithImplicitCast + extends BinaryExpression + with ImplicitCastInputTypes + with QueryErrorsBase { @transient protected lazy val elementType: DataType = inputTypes.head.asInstanceOf[ArrayType].elementType @@ -72,7 +73,7 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression DataTypeMismatch( errorSubClass = "BINARY_ARRAY_DIFF_TYPES", messageParameters = Map( - "functionName" -> prettyName, + "functionName" -> toSQLId(prettyName), "arrayType" -> toSQLType(ArrayType), "leftType" -> toSQLType(left.dataType), "rightType" -> toSQLType(right.dataType) @@ -219,7 +220,10 @@ case class MapKeys(child: Expression) group = "map_funcs", since = "3.3.0") case class MapContainsKey(left: Expression, right: Expression) - extends RuntimeReplaceable with BinaryLike[Expression] with ImplicitCastInputTypes { + extends RuntimeReplaceable + with BinaryLike[Expression] + with ImplicitCastInputTypes + with QueryErrorsBase { override lazy val replacement: Expression = ArrayContains(MapKeys(left), right) @@ -240,14 +244,14 @@ case class MapContainsKey(left: Expression, right: Expression) case (_, NullType) => DataTypeMismatch( errorSubClass = "NULL_TYPE", - Map("functionName" -> prettyName)) + Map("functionName" -> toSQLId(prettyName))) case (MapType(kt, _, _), dt) if kt.sameType(dt) => - TypeUtils.checkForOrderingExpr(kt, s"function $prettyName") + TypeUtils.checkForOrderingExpr(kt, prettyName) case _ => DataTypeMismatch( errorSubClass = "MAP_CONTAINS_KEY_DIFF_TYPES", messageParameters = Map( - "functionName" -> prettyName, + "functionName" -> toSQLId(prettyName), "dataType" -> toSQLType(MapType), "leftType" -> toSQLType(left.dataType), "rightType" -> toSQLType(right.dataType) @@ -676,20 +680,21 @@ case class MapEntries(child: Expression) """, group = "map_funcs", since = "2.4.0") -case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression { +case class MapConcat(children: Seq[Expression]) + extends ComplexTypeMergingExpression + with QueryErrorsBase { override def checkInputDataTypes(): TypeCheckResult = { - val funcName = s"function $prettyName" if (children.exists(!_.dataType.isInstanceOf[MapType])) { DataTypeMismatch( errorSubClass = "MAP_CONCAT_DIFF_TYPES", messageParameters = Map( - "functionName" -> funcName, + "functionName" -> toSQLId(prettyName), "dataType" -> children.map(_.dataType).map(toSQLType).mkString("[", ", ", "]") ) ) } else { - val sameTypeCheck = TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName) + val sameTypeCheck = TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), prettyName) if (sameTypeCheck.isFailure) { sameTypeCheck } else { @@ -802,7 +807,10 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """, group = "map_funcs", since = "2.4.0") -case class MapFromEntries(child: Expression) extends UnaryExpression with NullIntolerant { +case class MapFromEntries(child: Expression) + extends UnaryExpression + with NullIntolerant + with QueryErrorsBase { @transient private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { @@ -827,7 +835,7 @@ case class MapFromEntries(child: Expression) extends UnaryExpression with NullIn DataTypeMismatch( errorSubClass = "MAP_FROM_ENTRIES_WRONG_TYPE", messageParameters = Map( - "functionName" -> prettyName, + "functionName" -> toSQLId(prettyName), "childExpr" -> toSQLExpr(child), "childType" 
-> toSQLType(child.dataType) ) @@ -1290,7 +1298,7 @@ case class ArrayContains(left: Expression, right: Expression) case (_, NullType) => TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(e2, s"function $prettyName") + TypeUtils.checkForOrderingExpr(e2, prettyName) case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") @@ -1373,7 +1381,7 @@ case class ArraysOverlap(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => - TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(elementType, prettyName) case failure => failure } @@ -1901,7 +1909,7 @@ case class ArrayMin(child: Expression) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(dataType, prettyName) } else { typeCheckResult } @@ -1974,7 +1982,7 @@ case class ArrayMax(child: Expression) override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { - TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(dataType, prettyName) } else { typeCheckResult } @@ -2063,7 +2071,7 @@ case class ArrayPosition(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(e2, s"function $prettyName") + TypeUtils.checkForOrderingExpr(e2, prettyName) case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") @@ -2419,7 +2427,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " + childTypes.map(_.catalogString).mkString("[", ", ", "]")) } - TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") + TypeUtils.checkForSameTypeInputExpr(childTypes, prettyName) } } @@ -3473,7 +3481,7 @@ case class ArrayRemove(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { case (ArrayType(e1, _), e2) if e1.sameType(e2) => - TypeUtils.checkForOrderingExpr(e2, s"function $prettyName") + TypeUtils.checkForOrderingExpr(e2, prettyName) case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " + s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " + s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") @@ -3673,7 +3681,7 @@ case class ArrayDistinct(child: Expression) super.checkInputDataTypes() match { case f if f.isFailure => f case TypeCheckResult.TypeCheckSuccess => - TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(elementType, 
prettyName) } } @@ -3828,8 +3836,7 @@ trait ArrayBinaryLike override def checkInputDataTypes(): TypeCheckResult = { val typeCheckResult = super.checkInputDataTypes() if (typeCheckResult.isSuccess) { - TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, - s"function $prettyName") + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, prettyName) } else { typeCheckResult } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index c6ae14e5e3c9..27d4f506ac86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -68,7 +68,7 @@ case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolea override def stringArgs: Iterator[Any] = super.stringArgs.take(1) override def checkInputDataTypes(): TypeCheckResult = { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), prettyName) } private val defaultElementType: DataType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 274de47ee752..d0ef5365bc94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -444,7 +444,7 @@ case class GetMapValue(child: Expression, key: Expression) super.checkInputDataTypes() match { case f if f.isFailure => f case TypeCheckResult.TypeCheckSuccess => - TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(keyType, prettyName) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 7ac486f05af1..4f8ed1953f40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,8 @@ import org.apache.commons.codec.digest.MessageDigestAlgorithms import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} @@ -268,15 +270,17 @@ abstract class HashExpression[E] extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.length < 1) { - TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName requires at least one argument") + DataTypeMismatch( + errorSubClass = "WRONG_NUM_PARAMS", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "expectedNum" -> "> 0", + "actualNum" -> children.length.toString)) } else if (children.exists(child => hasMapType(child.dataType)) && 
!SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE)) { - TypeCheckResult.TypeCheckFailure( - s"input to function $prettyName cannot contain elements of MapType. In Spark, same maps " + - "may have different hashcode, thus hash expressions are prohibited on MapType elements." + - s" To restore previous behavior set ${SQLConf.LEGACY_ALLOW_HASH_ON_MAPTYPE.key} " + - "to true.") + DataTypeMismatch( + errorSubClass = "HASH_MAP_TYPE", + messageParameters = Map("functionName" -> toSQLId(prettyName))) } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 5b8b4b3f621e..98513fb5dddf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -1023,7 +1023,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) super.checkArgumentDataTypes() match { case TypeCheckResult.TypeCheckSuccess => if (leftKeyType.sameType(rightKeyType)) { - TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(leftKeyType, prettyName) } else { TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + s"been two ${MapType.simpleString}s with compatible key types, but the key types are " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 959edbd1c5ae..3529644aeeac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -355,7 +355,9 @@ case class GetJsonObject(json: Expression, path: Expression) since = "1.6.0") // scalastyle:on line.size.limit line.contains.tab case class JsonTuple(children: Seq[Expression]) - extends Generator with CodegenFallback { + extends Generator + with CodegenFallback + with QueryErrorsBase { import SharedFactory._ @@ -396,7 +398,7 @@ case class JsonTuple(children: Seq[Expression]) DataTypeMismatch( errorSubClass = "WRONG_NUM_PARAMS", messageParameters = Map( - "functionName" -> prettyName, + "functionName" -> toSQLId(prettyName), "expectedNum" -> "> 1", "actualNum" -> children.length.toString)) } else if (children.forall(child => StringType.acceptsType(child.dataType))) { @@ -404,7 +406,7 @@ case class JsonTuple(children: Seq[Expression]) } else { DataTypeMismatch( errorSubClass = "NON_STRING_TYPE", - messageParameters = Map("funcName" -> prettyName)) + messageParameters = Map("funcName" -> toSQLId(prettyName))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 5643598b4bd5..f69ece52d858 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -22,7 +22,8 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} 
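The hash.scala hunk above is one instance of a pattern this patch applies repeatedly: checkInputDataTypes stops returning free-form TypeCheckFailure strings and instead returns a structured DataTypeMismatch whose functionName parameter is quoted with toSQLId. A minimal sketch of that shape, using a hypothetical MinimumArityCheck trait rather than the real classes (the actual changes live in HashExpression, JsonTuple, Elt, RoundBase and the TypeUtils helpers):

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.errors.QueryErrorsBase

// Illustrative stand-in only: shows the structured error shape used throughout the patch.
trait MinimumArityCheck extends QueryErrorsBase { self: Expression =>
  def minArgs: Int

  def checkArity(): TypeCheckResult = {
    if (children.length < minArgs) {
      DataTypeMismatch(
        errorSubClass = "WRONG_NUM_PARAMS",
        messageParameters = Map(
          "functionName" -> toSQLId(prettyName), // rendered quoted, e.g. `hash`, `elt`
          "expectedNum" -> s">= $minArgs",
          "actualNum" -> children.length.toString))
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }
}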
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils} @@ -1481,7 +1482,12 @@ abstract class RoundBase(child: Expression, scale: Expression, if (scale.foldable) { TypeCheckSuccess } else { - TypeCheckFailure("Only foldable Expression is allowed for scale arguments") + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> "scala", + "inputType" -> toSQLType(scale.dataType), + "inputExpr" -> toSQLExpr(scale))) } case f => f } @@ -1788,7 +1794,7 @@ case class WidthBucket( TypeCheckSuccess case _ => val types = Seq(value.dataType, minValue.dataType, maxValue.dataType) - TypeUtils.checkForSameTypeInputExpr(types, s"function $prettyName") + TypeUtils.checkForSameTypeInputExpr(types, prettyName) } case f => f } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 8d171c2c6631..1e6cc356173e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -60,7 +60,7 @@ case class Coalesce(children: Seq[Expression]) TypeCheckResult.TypeCheckFailure( s"input to function $prettyName requires at least one argument") } else { - TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), prettyName) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 21f65cb3402e..899ece6f5297 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -404,7 +404,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery) |Right side: |[${query.childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin) } else { - TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(value.dataType, prettyName) } } @@ -453,7 +453,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + s"${value.dataType.catalogString} != ${mismatchOpt.get.dataType.catalogString}") } else { - TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") + TypeUtils.checkForOrderingExpr(value.dataType, prettyName) } } @@ -934,7 +934,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => - TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName) + TypeUtils.checkForOrderingExpr(left.dataType, symbol) case failure => failure } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 6927c4cfa3c9..8ae4bb9c29c0 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -278,7 +278,7 @@ case class Elt( DataTypeMismatch( errorSubClass = "WRONG_NUM_PARAMS", messageParameters = Map( - "functionName" -> "elt", + "functionName" -> toSQLId(prettyName), "expectedNum" -> "> 1", "actualNum" -> children.length.toString ) @@ -305,7 +305,7 @@ case class Elt( ) ) } - TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") + TypeUtils.checkForSameTypeInputExpr(inputTypes, prettyName) } } @@ -782,7 +782,7 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: val inputTypeCheck = super.checkInputDataTypes() if (inputTypeCheck.isSuccess) { TypeUtils.checkForSameTypeInputExpr( - input.dataType :: replace.dataType :: Nil, s"function $prettyName") + input.dataType :: replace.dataType :: Nil, prettyName) } else { inputTypeCheck } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2664fd638062..afbf73027277 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -989,7 +989,10 @@ object ColumnPruning extends Rule[LogicalPlan] { object CollapseProject extends Rule[LogicalPlan] with AliasHelper { def apply(plan: LogicalPlan): LogicalPlan = { - val alwaysInline = conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) + apply(plan, conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)) + } + + def apply(plan: LogicalPlan, alwaysInline: Boolean): LogicalPlan = { plan.transformUpWithPruning(_.containsPattern(PROJECT), ruleId) { case p1 @ Project(_, p2: Project) if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 9a1d20ed9b21..6665d885554f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -730,7 +730,9 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { object OneRowSubquery { def unapply(plan: LogicalPlan): Option[Seq[NamedExpression]] = { - CollapseProject(EliminateSubqueryAliases(plan)) match { + // SPARK-40800: always inline expressions to support a broader range of correlated + // subqueries and avoid expensive domain joins. 
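As context for the call on the next line: the CollapseProject change above adds a two-argument apply, so callers can force inlining independently of the SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE setting. A minimal sketch of the two call shapes, assuming plan is any analyzed LogicalPlan:

// Default path: inlining behaviour is driven by the session conf.
val collapsedByConf = CollapseProject(plan)

// New path used by OneRowSubquery.unapply below: always inline project lists.
val collapsedForced = CollapseProject(plan, alwaysInline = true)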
+ CollapseProject(EliminateSubqueryAliases(plan), alwaysInline = true) match { case Project(projectList, _: OneRowRelation) => Some(stripOuterReferences(projectList)) case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 0bb5d29c5c47..de1460eb2ea3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -35,7 +35,7 @@ object TypeUtils extends QueryErrorsBase { DataTypeMismatch( errorSubClass = "INVALID_ORDERING_TYPE", Map( - "functionName" -> caller, + "functionName" -> toSQLId(caller), "dataType" -> toSQLType(dt) ) ) @@ -49,7 +49,7 @@ object TypeUtils extends QueryErrorsBase { DataTypeMismatch( errorSubClass = "DATA_DIFF_TYPES", messageParameters = Map( - "functionName" -> caller, + "functionName" -> toSQLId(caller), "dataType" -> types.map(toSQLType).mkString("(", " or ", ")") ) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 72eb420de374..0a60c6b0265a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1982,6 +1982,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ASYNC_LOG_PURGE = + buildConf("spark.sql.streaming.asyncLogPurge.enabled") + .internal() + .doc("When true, purging the offset log and " + + "commit log of old entries will be done asynchronously.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + val VARIABLE_SUBSTITUTE_ENABLED = buildConf("spark.sql.variable.substitute") .doc("This enables substitution using syntax like `${var}`, `${system:var}`, " + @@ -2995,6 +3004,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val SKIP_TYPE_VALIDATION_ON_ALTER_PARTITION = + buildConf("spark.sql.legacy.skipTypeValidationOnAlterPartition") + .internal() + .doc("When true, skip validation for partition spec in ALTER PARTITION. E.g., " + + "`ALTER TABLE .. ADD PARTITION(p='a')` would work even the partition type is int. 
" + + s"When false, the behavior follows ${STORE_ASSIGNMENT_POLICY.key}") + .version("3.4.0") + .booleanConf + .createWithDefault(false) + val SORT_BEFORE_REPARTITION = buildConf("spark.sql.execution.sortBeforeRepartition") .internal() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala index 1f5e225324ef..87f140cb3c4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/PartitioningUtils.scala @@ -20,14 +20,47 @@ package org.apache.spark.sql.util import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{CharType, StructType, VarcharType} +import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy +import org.apache.spark.sql.types.{CharType, DataType, StringType, StructField, StructType, VarcharType} import org.apache.spark.unsafe.types.UTF8String private[sql] object PartitioningUtils { + + def castPartitionSpec(value: String, dt: DataType, conf: SQLConf): Expression = { + conf.storeAssignmentPolicy match { + // SPARK-30844: try our best to follow StoreAssignmentPolicy for static partition + // values but not completely follow because we can't do static type checking due to + // the reason that the parser has erased the type info of static partition values + // and converted them to string. + case StoreAssignmentPolicy.ANSI | StoreAssignmentPolicy.STRICT => + val cast = Cast(Literal(value), dt, Option(conf.sessionLocalTimeZone), + ansiEnabled = true) + cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) + cast + case _ => + Cast(Literal(value), dt, Option(conf.sessionLocalTimeZone), + ansiEnabled = false) + } + } + + private def normalizePartitionStringValue(value: String, field: StructField): String = { + val casted = Cast( + castPartitionSpec(value, field.dataType, SQLConf.get), + StringType, + Option(SQLConf.get.sessionLocalTimeZone) + ).eval() + if (casted != null) { + casted.asInstanceOf[UTF8String].toString + } else { + null + } + } + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. 
e.g., if the partition spec has a column named `monTh`, and there is a @@ -61,6 +94,14 @@ private[sql] object PartitioningUtils { case other => other } v.asInstanceOf[T] + case _ if !SQLConf.get.getConf(SQLConf.SKIP_TYPE_VALIDATION_ON_ALTER_PARTITION) && + value != null && value != DEFAULT_PARTITION_NAME => + val v = value match { + case Some(str: String) => Some(normalizePartitionStringValue(str, normalizedFiled)) + case str: String => normalizePartitionStringValue(str, normalizedFiled) + case other => other + } + v.asInstanceOf[T] case _ => value } normalizedFiled.name -> normalizedVal diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8b71bb05550a..ecd5b9e22fb2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -725,7 +725,7 @@ class AnalysisErrorSuite extends AnalysisTest { inputPlan = plan2, expectedErrorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", expectedMessageParameters = Map( - "functionName" -> "EqualTo", + "functionName" -> "`=`", "dataType" -> "\"MAP\"", "sqlExpr" -> "\"(b = d)\"" ), @@ -917,7 +917,7 @@ class AnalysisErrorSuite extends AnalysisTest { (And($"a" === $"c", Cast($"d", IntegerType) === $"c"), "CAST(d#x AS INT) = outer(c#x)")) conditions.foreach { case (cond, msg) => val plan = Project( - ScalarSubquery( + Exists( Aggregate(Nil, count(Literal(1)).as("cnt") :: Nil, Filter(cond, t1)) ).as("sub") :: Nil, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index f2697d4ca3b0..ec2cd79dee18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -223,8 +223,7 @@ trait AnalysisTest extends PlanTest { } } - protected def parseException(parser: String => Any)( - sqlText: String): ParseException = { + protected def parseException(parser: String => Any)(sqlText: String): ParseException = { intercept[ParseException](parser(sqlText)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index b41f627bac94..e3829311e2dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -297,7 +298,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = EqualTo($"mapField", $"mapField"), messageParameters = Map( "sqlExpr" -> "\"(mapField = mapField)\"", - "functionName" -> "EqualTo", + "functionName" -> "`=`", "dataType" -> "\"MAP\"" ) ) @@ -305,7 +306,7 @@ class 
ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = EqualTo($"mapField", $"mapField"), messageParameters = Map( "sqlExpr" -> "\"(mapField = mapField)\"", - "functionName" -> "EqualTo", + "functionName" -> "`=`", "dataType" -> "\"MAP\"" ) ) @@ -313,7 +314,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = EqualNullSafe($"mapField", $"mapField"), messageParameters = Map( "sqlExpr" -> "\"(mapField <=> mapField)\"", - "functionName" -> "EqualNullSafe", + "functionName" -> "`<=>`", "dataType" -> "\"MAP\"" ) ) @@ -321,7 +322,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = LessThan($"mapField", $"mapField"), messageParameters = Map( "sqlExpr" -> "\"(mapField < mapField)\"", - "functionName" -> "LessThan", + "functionName" -> "`<`", "dataType" -> "\"MAP\"" ) ) @@ -329,7 +330,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = LessThanOrEqual($"mapField", $"mapField"), messageParameters = Map( "sqlExpr" -> "\"(mapField <= mapField)\"", - "functionName" -> "LessThanOrEqual", + "functionName" -> "`<=`", "dataType" -> "\"MAP\"" ) ) @@ -337,7 +338,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = GreaterThan($"mapField", $"mapField"), messageParameters = Map( "sqlExpr" -> "\"(mapField > mapField)\"", - "functionName" -> "GreaterThan", + "functionName" -> "`>`", "dataType" -> "\"MAP\"" ) ) @@ -345,7 +346,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = GreaterThanOrEqual($"mapField", $"mapField"), messageParameters = Map( "sqlExpr" -> "\"(mapField >= mapField)\"", - "functionName" -> "GreaterThanOrEqual", + "functionName" -> "`>=`", "dataType" -> "\"MAP\"" ) ) @@ -384,7 +385,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = Min($"mapField"), messageParameters = Map( "sqlExpr" -> "\"min(mapField)\"", - "functionName" -> "function min", + "functionName" -> "`min`", "dataType" -> "\"MAP\"" ) ) @@ -392,7 +393,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = Max($"mapField"), messageParameters = Map( "sqlExpr" -> "\"max(mapField)\"", - "functionName" -> "function max", + "functionName" -> "`max`", "dataType" -> "\"MAP\"" ) ) @@ -426,7 +427,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = CreateArray(Seq($"intField", $"booleanField")), messageParameters = Map( "sqlExpr" -> "\"array(intField, booleanField)\"", - "functionName" -> "function array", + "functionName" -> "`array`", "dataType" -> "(\"INT\" or \"BOOLEAN\")" ) ) @@ -434,14 +435,37 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = Coalesce(Seq($"intField", $"booleanField")), messageParameters = Map( "sqlExpr" -> "\"coalesce(intField, booleanField)\"", - "functionName" -> "function coalesce", + "functionName" -> "`coalesce`", "dataType" -> "(\"INT\" or \"BOOLEAN\")" ) ) assertError(Coalesce(Nil), "function coalesce requires at least one argument") - assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") - assertError(new XxHash64(Nil), "function xxhash64 requires at least one argument") + + val murmur3Hash = new Murmur3Hash(Nil) + checkError( + exception = intercept[AnalysisException] { + assertSuccess(murmur3Hash) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + parameters = Map( + "sqlExpr" -> 
"\"hash()\"", + "functionName" -> toSQLId(murmur3Hash.prettyName), + "expectedNum" -> "> 0", + "actualNum" -> "0")) + + val xxHash64 = new XxHash64(Nil) + checkError( + exception = intercept[AnalysisException] { + assertSuccess(xxHash64) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + parameters = Map( + "sqlExpr" -> "\"xxhash64()\"", + "functionName" -> toSQLId(xxHash64.prettyName), + "expectedNum" -> "> 0", + "actualNum" -> "0")) + assertError(Explode($"intField"), "input to function explode should be array or map type") assertError(PosExplode($"intField"), @@ -478,8 +502,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer assertSuccess(Round(Literal(null), Literal(null))) assertSuccess(Round($"intField", Literal(1))) - assertError(Round($"intField", $"intField"), - "Only foldable Expression is allowed") + checkError( + exception = intercept[AnalysisException] { + assertSuccess(Round($"intField", $"intField")) + }, + errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "sqlExpr" -> "\"round(intField, intField)\"", + "inputName" -> "scala", + "inputType" -> "\"INT\"", + "inputExpr" -> "\"intField\"")) + checkError( exception = intercept[AnalysisException] { assertSuccess(Round($"intField", $"booleanField")) @@ -516,9 +549,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer assertSuccess(BRound(Literal(null), Literal(null))) assertSuccess(BRound($"intField", Literal(1))) - - assertError(BRound($"intField", $"intField"), - "Only foldable Expression is allowed") + checkError( + exception = intercept[AnalysisException] { + assertSuccess(BRound($"intField", $"intField")) + }, + errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "sqlExpr" -> "\"bround(intField, intField)\"", + "inputName" -> "scala", + "inputType" -> "\"INT\"", + "inputExpr" -> "\"intField\"")) checkError( exception = intercept[AnalysisException] { assertSuccess(BRound($"intField", $"booleanField")) @@ -561,7 +601,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = expr1, messageParameters = Map( "sqlExpr" -> toSQLExpr(expr1), - "functionName" -> expr1.prettyName, + "functionName" -> toSQLId(expr1.prettyName), "expectedNum" -> "> 1", "actualNum" -> "1") ) @@ -581,7 +621,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer expr = expr3, messageParameters = Map( "sqlExpr" -> toSQLExpr(expr3), - "functionName" -> s"function ${expr3.prettyName}", + "functionName" -> s"`${expr3.prettyName}`", "dataType" -> "\"MAP\"" ) ) @@ -602,4 +642,15 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer assert(Literal.create(Map(42L -> null), MapType(LongType, NullType)).sql == "MAP(42L, NULL)") } + + test("hash expressions are prohibited on MapType elements") { + val argument = Literal.create(Map(42L -> true), MapType(LongType, BooleanType)) + val murmur3Hash = new Murmur3Hash(Seq(argument)) + assert(murmur3Hash.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "HASH_MAP_TYPE", + messageParameters = Map("functionName" -> toSQLId(murmur3Hash.prettyName)) + ) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 5e5d0f7445e3..73cc9aca5682 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -242,7 +242,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { case TypeCheckResult.DataTypeMismatch(errorSubClass, messageParameters) => assert(errorSubClass == "INVALID_ORDERING_TYPE") assert(messageParameters === Map( - "functionName" -> "function in", "dataType" -> "\"MAP\"")) + "functionName" -> "`in`", "dataType" -> "\"MAP\"")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index fce94bf02a0b..94ae774070c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -1594,7 +1594,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { DataTypeMismatch( errorSubClass = "WRONG_NUM_PARAMS", messageParameters = Map( - "functionName" -> "elt", + "functionName" -> "`elt`", "expectedNum" -> "> 1", "actualNum" -> "1" ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala index 1d9965548a20..f3d3f8690064 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -142,11 +142,12 @@ class ParserUtilsSuite extends SparkFunSuite { test("operationNotAllowed") { val errorMessage = "parse.fail.operation.not.allowed.error.message" - val e = intercept[ParseException] { - operationNotAllowed(errorMessage, showFuncContext) - }.getMessage - assert(e.contains("Operation not allowed")) - assert(e.contains(errorMessage)) + checkError( + exception = intercept[ParseException] { + operationNotAllowed(errorMessage, showFuncContext) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> errorMessage)) } test("checkDuplicateKeys") { @@ -154,10 +155,12 @@ class ParserUtilsSuite extends SparkFunSuite { checkDuplicateKeys[String](properties, createDbContext) val properties2 = Seq(("a", "a"), ("b", "b"), ("a", "c")) - val e = intercept[ParseException] { - checkDuplicateKeys(properties2, createDbContext) - }.getMessage - assert(e.contains("Found duplicate keys")) + checkError( + exception = intercept[ParseException] { + checkDuplicateKeys(properties2, createDbContext) + }, + errorClass = "DUPLICATE_KEY", + parameters = Map("keyColumn" -> "`a`")) } test("source") { @@ -201,10 +204,12 @@ class ParserUtilsSuite extends SparkFunSuite { val message = "ParserRuleContext should not be empty." 
validate(f1(showFuncContext), message, showFuncContext) - val e = intercept[ParseException] { - validate(f1(emptyContext), message, emptyContext) - }.getMessage - assert(e.contains(message)) + checkError( + exception = intercept[ParseException] { + validate(f1(emptyContext), message, emptyContext) + }, + errorClass = "_LEGACY_ERROR_TEMP_0064", + parameters = Map("msg" -> message)) } test("withOrigin") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index c2b240b3c496..62557ead1d2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -290,8 +290,17 @@ class TableIdentifierParserSuite extends SQLKeywordUtils { assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) // Illegal names. - Seq("", "d.q.g", "t:", "${some.var.x}", "tab:1").foreach { identifier => - intercept[ParseException](parseTableIdentifier(identifier)) + Seq( + "" -> ("PARSE_EMPTY_STATEMENT", Map.empty[String, String]), + "d.q.g" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "'.'", "hint" -> "")), + "t:" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "':'", "hint" -> ": extra input ':'")), + "${some.var.x}" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "'$'", "hint" -> "")), + "tab:1" -> ("PARSE_SYNTAX_ERROR", Map("error" -> "':'", "hint" -> "")) + ).foreach { case (identifier, (errorClass, parameters)) => + checkError( + exception = intercept[ParseException](parseTableIdentifier(identifier)), + errorClass = errorClass, + parameters = parameters) } } @@ -307,10 +316,10 @@ class TableIdentifierParserSuite extends SQLKeywordUtils { withSQLConf(SQLConf.ANSI_ENABLED.key -> "true", SQLConf.ENFORCE_RESERVED_KEYWORDS.key -> "true") { reservedKeywordsInAnsiMode.foreach { keyword => - val errMsg = intercept[ParseException] { - parseTableIdentifier(keyword) - }.getMessage - assert(errMsg.contains("Syntax error at or near")) + checkError( + exception = intercept[ParseException](parseTableIdentifier(keyword)), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> s"'$keyword'", "hint" -> "")) assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) } @@ -363,7 +372,10 @@ class TableIdentifierParserSuite extends SQLKeywordUtils { val complexName = TableIdentifier("`weird`table`name", Some("`d`b`1")) assert(complexName === parseTableIdentifier("```d``b``1`.```weird``table``name`")) assert(complexName === parseTableIdentifier(complexName.quotedString)) - intercept[ParseException](parseTableIdentifier(complexName.unquotedString)) + checkError( + exception = intercept[ParseException](parseTableIdentifier(complexName.unquotedString)), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'b'", "hint" -> "")) // Table identifier contains continuous backticks should be treated correctly. 
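The test migration around here always follows the same shape: intercept the ParseException and hand it to checkError together with the expected error class and message parameters, instead of grepping getMessage. A condensed sketch with an illustrative three-part identifier (the concrete inputs and expected parameters used by the patch are the ones listed in the hunks above):

// Sketch of the assertion style this patch standardises on.
checkError(
  exception = intercept[ParseException](parseTableIdentifier("a.b.c")), // illustrative input
  errorClass = "PARSE_SYNTAX_ERROR",
  parameters = Map("error" -> "'.'", "hint" -> ""))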
val complexName2 = TableIdentifier("x``y", Some("d``b")) assert(complexName2 === parseTableIdentifier(complexName2.quotedString)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala index 5519f016e48d..a7e2054dfaf8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.parser -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, SparkThrowable} import org.apache.spark.sql.types._ class TableSchemaParserSuite extends SparkFunSuite { @@ -30,9 +30,6 @@ class TableSchemaParserSuite extends SparkFunSuite { } } - def assertError(sql: String): Unit = - intercept[ParseException](CatalystSqlParser.parseTableSchema(sql)) - checkTableSchema("a int", new StructType().add("a", "int")) checkTableSchema("A int", new StructType().add("A", "int")) checkTableSchema("a INT", new StructType().add("a", "int")) @@ -73,11 +70,31 @@ class TableSchemaParserSuite extends SparkFunSuite { // Negative cases test("Negative cases") { - assertError("") - assertError("a") - assertError("a INT b long") - assertError("a INT,, b long") - assertError("a INT, b long,,") - assertError("a INT, b long, c int,") + def parseException(sql: String): SparkThrowable = + intercept[ParseException](CatalystSqlParser.parseTableSchema(sql)) + + checkError( + exception = parseException(""), + errorClass = "PARSE_EMPTY_STATEMENT") + checkError( + exception = parseException("a"), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "end of input", "hint" -> "")) + checkError( + exception = parseException("a INT b long"), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'b'", "hint" -> "")) + checkError( + exception = parseException("a INT,, b long"), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "','", "hint" -> ": extra input ','")) + checkError( + exception = parseException("a INT, b long,,"), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "','", "hint" -> "")) + checkError( + exception = parseException("a INT, b long, c int,"), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "end of input", "hint" -> "")) } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 7203fc591081..cfcf7455ad03 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -193,6 +193,11 @@ mockito-core test
+ + org.mockito + mockito-inline + test + org.seleniumhq.selenium selenium-java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 0216503fba0f..8b985e82963e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -48,9 +48,9 @@ import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.streaming.StreamingRelation -import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.{PartitioningUtils => CatalystPartitioningUtils} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String @@ -106,22 +106,8 @@ case class DataSourceAnalysis(analyzer: Analyzer) extends Rule[LogicalPlan] { None } else if (potentialSpecs.size == 1) { val partValue = potentialSpecs.head._2 - conf.storeAssignmentPolicy match { - // SPARK-30844: try our best to follow StoreAssignmentPolicy for static partition - // values but not completely follow because we can't do static type checking due to - // the reason that the parser has erased the type info of static partition values - // and converted them to string. - case StoreAssignmentPolicy.ANSI | StoreAssignmentPolicy.STRICT => - val cast = Cast(Literal(partValue), field.dataType, Option(conf.sessionLocalTimeZone), - ansiEnabled = true) - cast.setTagValue(Cast.BY_TABLE_INSERTION, ()) - Some(Alias(cast, field.name)()) - case _ => - val castExpression = - Cast(Literal(partValue), field.dataType, Option(conf.sessionLocalTimeZone), - ansiEnabled = false) - Some(Alias(castExpression, field.name)()) - } + Some(Alias(CatalystPartitioningUtils.castPartitionSpec( + partValue, field.dataType, conf), field.name)()) } else { throw QueryCompilationErrors.multiplePartitionColumnValuesSpecifiedError( field, potentialSpecs) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 2f85149ee8e1..6a8b197742d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -46,7 +46,7 @@ case class AggregateInPandasExec( udfExpressions: Seq[PythonUDF], resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with PythonSQLMetrics { override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -163,7 +163,8 @@ case class AggregateInPandasExec( argOffsets, aggInputSchema, sessionLocalTimeZone, - pythonRunnerConf).compute(projectedRowIter, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index bd8c72029dcb..f3531668c8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA import org.apache.spark.sql.execution.streaming.GroupStateImpl @@ -58,7 +59,8 @@ class ApplyInPandasWithStatePythonRunner( stateEncoder: ExpressionEncoder[Row], keySchema: StructType, outputSchema: StructType, - stateValueSchema: StructType) + stateValueSchema: StructType, + val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) with PythonArrowInput[InType] with PythonArrowOutput[OutType] { @@ -116,6 +118,7 @@ class ApplyInPandasWithStatePythonRunner( val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) while (inputIterator.hasNext) { + val startData = dataOut.size() val (keyRow, groupState, dataIter) = inputIterator.next() assert(dataIter.hasNext, "should have at least one data row!") w.startNewGroup(keyRow, groupState) @@ -126,6 +129,8 @@ class ApplyInPandasWithStatePythonRunner( } w.finalizeGroup() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } w.finalizeData() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 096712cf9352..b11dd4947af6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -61,7 +61,7 @@ private[spark] class BatchIterator[T](iter: Iterator[T], batchSize: Int) */ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan, evalType: Int) - extends EvalPythonExec { + extends EvalPythonExec with PythonSQLMetrics { private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone @@ -85,7 +85,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] argOffsets, schema, sessionLocalTimeZone, - pythonRunnerConf).compute(batchIter, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(batchIter, context.partitionId(), context) columnarBatchIter.flatMap { batch => val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 8467feb91d14..dbafc444281e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -32,7 +33,8 @@ class ArrowPythonRunner( argOffsets: Array[Array[Int]], protected override val schema: StructType, protected override val timeZoneId: String, - protected override val workerConf: Map[String, String]) + protected override val workerConf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets) with BasicPythonArrowInput with BasicPythonArrowOutput { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 10f7966b93d1..ca7ca2e2f80a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{StructField, StructType} * A physical plan that evaluates a [[PythonUDF]] */ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan) - extends EvalPythonExec { + extends EvalPythonExec with PythonSQLMetrics { protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -77,7 +77,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results from Python. 
- val outputIterator = new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets) + val outputIterator = + new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler @@ -94,6 +95,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => + pythonMetrics("pythonNumRowsReceived") += 1 if (udfs.length == 1) { // fast path for single UDF mutableRow(0) = fromJava(result) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index 2661896ececc..1df9f37188a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -27,6 +27,7 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -45,7 +46,8 @@ class CoGroupedArrowPythonRunner( leftSchema: StructType, rightSchema: StructType, timeZoneId: String, - conf: Map[String, String]) + conf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[ (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets) with BasicPythonArrowOutput { @@ -77,10 +79,14 @@ class CoGroupedArrowPythonRunner( // For each we first send the number of dataframes in each group then send // first df, then send second df. End of data is marked by sending 0. 
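The loop below is one of several places where this patch wires byte counters into the Python runners; the recurring recipe is to snapshot DataOutputStream.size() before writing a group and add the delta to the pythonDataSent metric. A stripped-down sketch, where writeOneGroup stands in for the real writeGroup calls:

// Byte-accounting pattern shared by the Arrow and UDF runners in this patch.
while (inputIterator.hasNext) {
  val startData = dataOut.size()          // bytes written to the Python worker so far
  writeOneGroup(inputIterator.next())     // hypothetical stand-in for the real writes
  pythonMetrics("pythonDataSent") += dataOut.size() - startData
}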
while (inputIterator.hasNext) { + val startData = dataOut.size() dataOut.writeInt(2) val (nextLeft, nextRight) = inputIterator.next() writeGroup(nextLeft, leftSchema, dataOut, "left") writeGroup(nextRight, rightSchema, dataOut, "right") + + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } dataOut.writeInt(0) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index b39787b12a48..629df51e18ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -54,7 +54,7 @@ case class FlatMapCoGroupsInPandasExec( output: Seq[Attribute], left: SparkPlan, right: SparkPlan) - extends SparkPlan with BinaryExecNode { + extends SparkPlan with BinaryExecNode with PythonSQLMetrics { private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -77,7 +77,6 @@ case class FlatMapCoGroupsInPandasExec( } override protected def doExecute(): RDD[InternalRow] = { - val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) @@ -97,7 +96,8 @@ case class FlatMapCoGroupsInPandasExec( StructType.fromAttributes(leftDedup), StructType.fromAttributes(rightDedup), sessionLocalTimeZone, - pythonRunnerConf) + pythonRunnerConf, + pythonMetrics) executePython(data, output, runner) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index f0e815e966e7..271ccdb6b271 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -50,7 +50,7 @@ case class FlatMapGroupsInPandasExec( func: Expression, output: Seq[Attribute], child: SparkPlan) - extends SparkPlan with UnaryExecNode { + extends SparkPlan with UnaryExecNode with PythonSQLMetrics { private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) @@ -89,7 +89,8 @@ case class FlatMapGroupsInPandasExec( Array(argOffsets), StructType.fromAttributes(dedupAttributes), sessionLocalTimeZone, - pythonRunnerConf) + pythonRunnerConf, + pythonMetrics) executePython(data, output, runner) }} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 09123344c2e2..3b096f07241f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -62,7 +62,8 @@ case class FlatMapGroupsInPandasWithStateExec( timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], eventTimeWatermark: Option[Long], - child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + child: SparkPlan) + extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase { // 
TODO(SPARK-40444): Add the support of initial state. override protected val initialStateDeserializer: Expression = null @@ -166,7 +167,8 @@ case class FlatMapGroupsInPandasWithStateExec( stateEncoder.asInstanceOf[ExpressionEncoder[Row]], groupingAttributes.toStructType, outAttributes.toStructType, - stateType) + stateType, + pythonMetrics) val context = TaskContext.get() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala index d25c13835407..450891c69483 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} * This is somewhat similar with [[FlatMapGroupsInPandasExec]] and * `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow` */ -trait MapInBatchExec extends UnaryExecNode { +trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { protected val func: Expression protected val pythonEvalType: Int @@ -75,7 +75,8 @@ trait MapInBatchExec extends UnaryExecNode { argOffsets, StructType(StructField("struct", outputTypes) :: Nil), sessionLocalTimeZone, - pythonRunnerConf).compute(batchIter, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(batchIter, context.partitionId(), context) val unsafeProj = UnsafeProjection.create(output, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index bf66791183ec..5a0541d11cbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -26,6 +26,7 @@ import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils @@ -41,6 +42,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected val timeZoneId: String + protected def pythonMetrics: Map[String, SQLMetric] + protected def writeIteratorToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, @@ -115,6 +118,7 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In val arrowWriter = ArrowWriter.create(root) while (inputIterator.hasNext) { + val startData = dataOut.size() val nextBatch = inputIterator.next() while (nextBatch.hasNext) { @@ -124,6 +128,8 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In arrowWriter.finish() writer.writeBatch() arrowWriter.reset() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 339f114539c2..c12c690f776a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths} +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} @@ -37,6 +38,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column */ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] => + protected def pythonMetrics: Map[String, SQLMetric] + protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT @@ -82,10 +85,15 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ } try { if (reader != null && batchLoaded) { + val bytesReadStart = reader.bytesRead() batchLoaded = reader.loadNextBatch() if (batchLoaded) { val batch = new ColumnarBatch(vectors) + val rowCount = root.getRowCount batch.setNumRows(root.getRowCount) + val bytesReadEnd = reader.bytesRead() + pythonMetrics("pythonNumRowsReceived") += rowCount + pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart deserializeColumnarBatch(batch, schema) } else { reader.close(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala new file mode 100644 index 000000000000..a748c1bc1008 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonSQLMetrics.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.metric.SQLMetrics + +private[sql] trait PythonSQLMetrics { self: SparkPlan => + + val pythonMetrics = Map( + "pythonDataSent" -> SQLMetrics.createSizeMetric(sparkContext, + "data sent to Python workers"), + "pythonDataReceived" -> SQLMetrics.createSizeMetric(sparkContext, + "data returned from Python workers"), + "pythonNumRowsReceived" -> SQLMetrics.createMetric(sparkContext, + "number of output rows") + ) + + override lazy val metrics = pythonMetrics +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index d1109d251c28..09e06b55df3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark._ import org.apache.spark.api.python._ +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf /** @@ -31,7 +32,8 @@ import org.apache.spark.sql.internal.SQLConf class PythonUDFRunner( funcs: Seq[ChainedPythonFunctions], evalType: Int, - argOffsets: Array[Array[Int]]) + argOffsets: Array[Array[Int]], + pythonMetrics: Map[String, SQLMetric]) extends BasePythonRunner[Array[Byte], Array[Byte]]( funcs, evalType, argOffsets) { @@ -50,8 +52,13 @@ class PythonUDFRunner( } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val startData = dataOut.size() + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData } } } @@ -77,6 +84,7 @@ class PythonUDFRunner( case length if length > 0 => val obj = new Array[Byte](length) stream.readFully(obj) + pythonMetrics("pythonDataReceived") += length obj case 0 => Array.emptyByteArray case SpecialLengths.TIMING_DATA => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index ccb1ed92525d..dcaffed89cca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -84,7 +84,7 @@ case class WindowInPandasExec( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], child: SparkPlan) - extends WindowExecBase { + extends WindowExecBase with PythonSQLMetrics { /** * Helper functions and data structures for window bounds @@ -375,7 +375,8 @@ case class WindowInPandasExec( argOffsets, pythonInputSchema, sessionLocalTimeZone, - pythonRunnerConf).compute(pythonInput, context.partitionId(), context) + pythonRunnerConf, + pythonMetrics).compute(pythonInput, context.partitionId(), context) val joined = new JoinedRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index bcd226f95f82..50092571e856 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -17,55 
+17,22 @@ package org.apache.spark.sql.execution.stat -import scala.collection.mutable.{Map => MutableMap} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.{functions, Column, DataFrame} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object FrequentItems extends Logging { - /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */ - private class FreqItemCounter(size: Int) extends Serializable { - val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long] - /** - * Add a new example to the counts if it exists, otherwise deduct the count - * from existing items. - */ - def add(key: Any, count: Long): this.type = { - if (baseMap.contains(key)) { - baseMap(key) += count - } else { - if (baseMap.size < size) { - baseMap += key -> count - } else { - val minCount = if (baseMap.values.isEmpty) 0 else baseMap.values.min - val remainder = count - minCount - if (remainder >= 0) { - baseMap += key -> count // something will get kicked out, so we can add this - baseMap.retain((k, v) => v > minCount) - baseMap.transform((k, v) => v - minCount) - } else { - baseMap.transform((k, v) => v - count) - } - } - } - this - } - - /** - * Merge two maps of counts. - * @param other The map containing the counts for that partition - */ - def merge(other: FreqItemCounter): this.type = { - other.baseMap.foreach { case (k, v) => - add(k, v) - } - this - } - } - /** * Finding frequent items for columns, possibly with false positives. 
Using the * frequent element count algorithm described in @@ -85,42 +52,142 @@ object FrequentItems extends Logging { cols: Seq[String], support: Double): DataFrame = { require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.") - val numCols = cols.length + // number of max items to keep counts for val sizeOfMap = (1 / support).toInt - val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap)) - - val freqItems = df.select(cols.map(Column(_)) : _*).rdd.treeAggregate(countMaps)( - seqOp = (counts, row) => { - var i = 0 - while (i < numCols) { - val thisMap = counts(i) - val key = row.get(i) - thisMap.add(key, 1L) - i += 1 - } - counts - }, - combOp = (baseCounts, counts) => { - var i = 0 - while (i < numCols) { - baseCounts(i).merge(counts(i)) - i += 1 + + val frequentItemCols = cols.map { col => + val aggExpr = new CollectFrequentItems(functions.col(col).expr, sizeOfMap) + Column(aggExpr.toAggregateExpression(isDistinct = false)).as(s"${col}_freqItems") + } + + df.select(frequentItemCols: _*) + } +} + +case class CollectFrequentItems( + child: Expression, + size: Int, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[mutable.Map[Any, Long]] + with UnaryLike[Expression] { + require(size > 0) + + def this(child: Expression, size: Int) = this(child, size, 0, 0) + + // Returns empty array for empty inputs + override def nullable: Boolean = false + + override def dataType: DataType = ArrayType(child.dataType, containsNull = child.nullable) + + override def prettyName: String = "collect_frequent_items" + + override def createAggregationBuffer(): mutable.Map[Any, Long] = + mutable.Map.empty[Any, Long] + + private def add(map: mutable.Map[Any, Long], key: Any, count: Long): mutable.Map[Any, Long] = { + if (map.contains(key)) { + map(key) += count + } else { + if (map.size < size) { + map += key -> count + } else { + val minCount = if (map.values.isEmpty) 0 else map.values.min + val remainder = count - minCount + if (remainder >= 0) { + map += key -> count // something will get kicked out, so we can add this + map.retain((k, v) => v > minCount) + map.transform((k, v) => v - minCount) + } else { + map.transform((k, v) => v - count) } - baseCounts } - ) - val justItems = freqItems.map(m => m.baseMap.keys.toArray) - val resultRow = Row(justItems : _*) + } + map + } + + override def update( + buffer: mutable.Map[Any, Long], + input: InternalRow): mutable.Map[Any, Long] = { + val key = child.eval(input) + if (key != null) { + this.add(buffer, InternalRow.copyValue(key), 1L) + } else { + this.add(buffer, key, 1L) + } + } + + override def merge( + buffer: mutable.Map[Any, Long], + input: mutable.Map[Any, Long]): mutable.Map[Any, Long] = { + val otherIter = input.iterator + while (otherIter.hasNext) { + val (key, count) = otherIter.next + add(buffer, key, count) + } + buffer + } - val outputCols = cols.map { name => - val originalField = df.resolve(name) + override def eval(buffer: mutable.Map[Any, Long]): Any = + new GenericArrayData(buffer.keys.toArray) - // append frequent Items to the column name for easy debugging - StructField(name + "_freqItems", ArrayType(originalField.dataType, originalField.nullable)) - }.toArray + private lazy val projection = + UnsafeProjection.create(Array[DataType](child.dataType, LongType)) - val schema = StructType(outputCols).toAttributes - Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow))) + override def serialize(map: 
mutable.Map[Any, Long]): Array[Byte] = { + val buffer = new Array[Byte](4 << 10) // 4K + val bos = new ByteArrayOutputStream() + val out = new DataOutputStream(bos) + Utils.tryWithSafeFinally { + // Write pairs in counts map to byte buffer. + map.foreach { case (key, count) => + val row = InternalRow.apply(key, count) + val unsafeRow = projection.apply(row) + out.writeInt(unsafeRow.getSizeInBytes) + unsafeRow.writeToStream(out, buffer) + } + out.writeInt(-1) + out.flush() + + bos.toByteArray + } { + out.close() + bos.close() + } } + + override def deserialize(bytes: Array[Byte]): mutable.Map[Any, Long] = { + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(bis) + Utils.tryWithSafeFinally { + val map = mutable.Map.empty[Any, Long] + // Read unsafeRow size and content in bytes. + var sizeOfNextRow = ins.readInt() + while (sizeOfNextRow >= 0) { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(2) + row.pointTo(bs, sizeOfNextRow) + // Insert the pairs into counts map. + val key = row.get(0, child.dataType) + val count = row.get(1, LongType).asInstanceOf[Long] + map.update(key, count) + sizeOfNextRow = ins.readInt() + } + + map + } { + ins.close() + bis.close() + } + } + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index c9d3b9999083..80e8f6d73420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -21,16 +21,12 @@ import java.util.Locale import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EvalMode, Expression, GenericInternalRow, GetArrayItem, Literal} +import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode} import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries} +import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.collection.Utils object StatFunctions extends Logging { @@ -188,54 +184,23 @@ object StatFunctions extends Logging { /** Generate a table of frequencies for the elements of two columns. */ def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = { - val tableName = s"${col1}_$col2" - val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt) - if (counts.length == 1e6.toInt) { - logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " + - "the pairs. 
Please try reducing the amount of distinct items in your columns.") - } - def cleanElement(element: Any): String = { - if (element == null) "null" else element.toString - } - // get the distinct sorted values of column 2, so that we can make them the column names - val distinctCol2: Map[Any, Int] = - Utils.toMapWithIndex(counts.map(e => cleanElement(e.get(1))).distinct.sorted) - val columnSize = distinctCol2.size - require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + - s"exceed 1e4. Currently $columnSize") - val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => - val countsRow = new GenericInternalRow(columnSize + 1) - rows.foreach { (row: Row) => - // row.get(0) is column 1 - // row.get(1) is column 2 - // row.get(2) is the frequency - val columnIndex = distinctCol2(cleanElement(row.get(1))) - countsRow.setLong(columnIndex + 1, row.getLong(2)) - } - // the value of col1 is the first value, the rest are the counts - countsRow.update(0, UTF8String.fromString(cleanElement(col1Item))) - countsRow - }.toSeq - // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept - // special keywords and `.`, wrap the column names in ``. - def cleanColumnName(name: String): String = { - name.replace("`", "") - } - // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in - // SPARK-8681. We need to explicitly sort by the column index and assign the column names. - val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r => - StructField(cleanColumnName(r._1.toString), LongType) - } - val schema = StructType(StructField(tableName, StringType) +: headerNames) - - Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0) + df.groupBy( + when(isnull(col(col1)), "null") + .otherwise(col(col1).cast("string")) + .as(s"${col1}_$col2") + ).pivot( + when(isnull(col(col2)), "null") + .otherwise(regexp_replace(col(col2).cast("string"), "`", "")) + ).count().na.fill(0L) } /** Calculate selected summary statistics for a dataset */ def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = { - - val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") - val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics + val selectedStatistics = if (statistics.nonEmpty) { + statistics.toArray + } else { + Array("count", "mean", "stddev", "min", "25%", "50%", "75%", "max") + } val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p => try { @@ -247,71 +212,66 @@ object StatFunctions extends Logging { } require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]") - def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) { - Cast(e, DoubleType, evalMode = EvalMode.TRY) - } else { - e - } - var percentileIndex = 0 - val statisticFns = selectedStatistics.map { stats => - if (stats.endsWith("%")) { - val index = percentileIndex - percentileIndex += 1 - (child: Expression) => - GetArrayItem( - new ApproximatePercentile(castAsDoubleIfNecessary(child), - Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false))) - .toAggregateExpression(), - Literal(index)) - } else { - stats.toLowerCase(Locale.ROOT) match { - case "count" => (child: Expression) => Count(child).toAggregateExpression() - case "count_distinct" => (child: Expression) => - Count(child).toAggregateExpression(isDistinct = true) - case "approx_count_distinct" => (child: 
Expression) => - HyperLogLogPlusPlus(child).toAggregateExpression() - case "mean" => (child: Expression) => - Average(castAsDoubleIfNecessary(child)).toAggregateExpression() - case "stddev" => (child: Expression) => - StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression() - case "min" => (child: Expression) => Min(child).toAggregateExpression() - case "max" => (child: Expression) => Max(child).toAggregateExpression() - case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats) + var mapColumns = Seq.empty[Column] + var columnNames = Seq.empty[String] + + ds.schema.fields.foreach { field => + if (field.dataType.isInstanceOf[NumericType] || field.dataType.isInstanceOf[StringType]) { + val column = col(field.name) + var casted = column + if (field.dataType.isInstanceOf[StringType]) { + casted = new Column(Cast(column.expr, DoubleType, evalMode = EvalMode.TRY)) } - } - } - val selectedCols = ds.logicalPlan.output - .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType]) + val percentilesCol = if (percentiles.nonEmpty) { + percentile_approx(casted, lit(percentiles), + lit(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) + } else null - val aggExprs = statisticFns.flatMap { func => - selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name)) - } + var aggColumns = Seq.empty[Column] + var percentileIndex = 0 + selectedStatistics.foreach { stats => + aggColumns :+= lit(stats) - // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val. - lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head + stats.toLowerCase(Locale.ROOT) match { + case "count" => aggColumns :+= count(column) - // We will have one row for each selected statistic in the result. - val result = Array.fill[InternalRow](selectedStatistics.length) { - // each row has the statistic name, and statistic values of each selected column. - new GenericInternalRow(selectedCols.length + 1) - } + case "count_distinct" => aggColumns :+= count_distinct(column) + + case "approx_count_distinct" => aggColumns :+= approx_count_distinct(column) + + case "mean" => aggColumns :+= avg(casted) + + case "stddev" => aggColumns :+= stddev(casted) + + case "min" => aggColumns :+= min(column) + + case "max" => aggColumns :+= max(column) - var rowIndex = 0 - while (rowIndex < result.length) { - val statsName = selectedStatistics(rowIndex) - result(rowIndex).update(0, UTF8String.fromString(statsName)) - for (colIndex <- selectedCols.indices) { - val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex) - result(rowIndex).update(colIndex + 1, statsValue) + case percentile if percentile.endsWith("%") => + aggColumns :+= get(percentilesCol, lit(percentileIndex)) + percentileIndex += 1 + + case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats) + } + } + + // map { "count" -> "1024", "min" -> "1.0", ... 
} + mapColumns :+= map(aggColumns.map(_.cast(StringType)): _*).as(field.name) + columnNames :+= field.name } - rowIndex += 1 } - // All columns are string type - val output = AttributeReference("summary", StringType)() +: - selectedCols.map(c => AttributeReference(c.name, StringType)()) - - Dataset.ofRows(ds.sparkSession, LocalRelation(output, result)) + if (mapColumns.isEmpty) { + ds.sparkSession.createDataFrame(selectedStatistics.map(Tuple1.apply)) + .withColumnRenamed("_1", "summary") + } else { + val valueColumns = columnNames.map { columnName => + new Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName) + } + ds.select(mapColumns: _*) + .withColumn("summary", explode(lit(selectedStatistics))) + .select(Array(col("summary")) ++ valueColumns: _*) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala new file mode 100644 index 000000000000..b3729dbc7b45 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ThreadUtils + +/** + * Used to enable the capability to allow log purges to be done asynchronously + */ +trait AsyncLogPurge extends Logging { + + protected var currentBatchId: Long + + protected val minLogEntriesToMaintain: Int + + + protected[sql] val errorNotifier: ErrorNotifier + + protected val sparkSession: SparkSession + + private val asyncPurgeExecutorService + = ThreadUtils.newDaemonSingleThreadExecutor("async-log-purge") + + private val purgeRunning = new AtomicBoolean(false) + + protected def purge(threshold: Long): Unit + + protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE) + + protected def purgeAsync(): Unit = { + if (purgeRunning.compareAndSet(false, true)) { + // save local copy because currentBatchId may get updated. 
There are not really + // any concurrency issues here in regard to calculating the purge threshold + // but for the sake of defensive coding let's make a copy + val currentBatchIdCopy: Long = currentBatchId + asyncPurgeExecutorService.execute(() => { + try { + purge(currentBatchIdCopy - minLogEntriesToMaintain) + } catch { + case throwable: Throwable => + logError("Encountered error while performing async log purge", throwable) + errorNotifier.markError(throwable) + } finally { + purgeRunning.set(false) + } + }) + } else { + log.debug("Skipped log purging since there is already one in progress.") + } + } + + protected def asyncLogPurgeShutdown(): Unit = { + ThreadUtils.shutdown(asyncPurgeExecutorService) + } + + // used for testing + private[sql] def arePendingAsyncPurge: Boolean = { + purgeRunning.get() || + asyncPurgeExecutorService.getQueue.size() > 0 || + asyncPurgeExecutorService.getActiveCount > 0 + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala new file mode 100644 index 000000000000..0f25d0667a0e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.atomic.AtomicReference + +import org.apache.spark.internal.Logging + +/** + * Class to notify of any errors that might have occurred out of band + */ +class ErrorNotifier extends Logging { + + private val error = new AtomicReference[Throwable] + + /** To indicate any errors that have occurred */ + def markError(th: Throwable): Unit = { + logError("A fatal error has occurred.", th) + error.set(th) + } + + /** Get any errors that have occurred */ + def getError(): Option[Throwable] = { + Option(error.get()) + } + + /** Throw errors that have occurred */ + def throwErrorIfExists(): Unit = { + getError().foreach({th => throw th}) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 153bc82f8928..5f8fb93827b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -46,7 +46,9 @@ class MicroBatchExecution( plan: WriteToStream) extends StreamExecution( sparkSession, plan.name, plan.resolvedCheckpointLocation, plan.inputQuery, plan.sink, trigger, - triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) { + triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) with AsyncLogPurge { + + protected[sql] val errorNotifier = new ErrorNotifier() @volatile protected var sources: Seq[SparkDataStream] = Seq.empty @@ -210,6 +212,14 @@ class MicroBatchExecution( logInfo(s"Query $prettyIdString was stopped") } + override def cleanup(): Unit = { + super.cleanup() + + // shutdown and cleanup required for async log purge mechanism + asyncLogPurgeShutdown() + logInfo(s"Async log purge executor pool for query ${prettyIdString} has been shutdown") + } + /** Begins recording statistics about query progress for a given trigger. */ override protected def startTrigger(): Unit = { super.startTrigger() @@ -226,6 +236,10 @@ class MicroBatchExecution( triggerExecutor.execute(() => { if (isActive) { + + // check if there are any previous errors and bubble up any existing async operations + errorNotifier.throwErrorIfExists + var currentBatchHasNewData = false // Whether the current batch had new data startTrigger() @@ -536,7 +550,11 @@ class MicroBatchExecution( // It is now safe to discard the metadata beyond the minimum number to retain. // Note that purge is exclusive, i.e. it purges everything before the target ID. 
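As an aside before the call-site change that follows: the AsyncLogPurge trait added above boils down to a submit-once pattern, namely a compare-and-set guard so at most one purge runs at a time, a single worker thread, and an error holder that the query loop re-checks on the next trigger. The sketch below is a self-contained, illustrative reimplementation of that pattern; the class and method names are made up for this example and are not Spark APIs.

import java.util.concurrent.Executors
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

// Illustrative, standalone version of the guard-and-submit pattern in AsyncLogPurge.
// None of these names are Spark APIs; purge() is a stand-in for the real metadata-log purge.
class AsyncPurgerSketch(minLogEntriesToMaintain: Int) {
  private val executor = Executors.newSingleThreadExecutor()
  private val purgeRunning = new AtomicBoolean(false)
  private val lastError = new AtomicReference[Throwable]()

  // Stand-in for the real purge; the threshold is exclusive, matching the comment above.
  private def purge(thresholdBatchId: Long): Unit =
    println(s"purging log entries below batch $thresholdBatchId")

  def purgeAsync(currentBatchId: Long): Unit = {
    // At most one purge in flight; overlapping requests are skipped rather than queued.
    if (purgeRunning.compareAndSet(false, true)) {
      val batchIdSnapshot = currentBatchId // mirrors the defensive copy in AsyncLogPurge
      executor.execute(() => {
        try {
          purge(batchIdSnapshot - minLogEntriesToMaintain)
        } catch {
          case t: Throwable => lastError.set(t) // surfaced later by throwErrorIfExists()
        } finally {
          purgeRunning.set(false)
        }
      })
    }
  }

  // Called at the top of each trigger so that background failures fail the query.
  def throwErrorIfExists(): Unit = Option(lastError.get()).foreach(e => throw e)

  def shutdown(): Unit = executor.shutdown()
}

A caller would invoke purgeAsync(batchId) when a batch commits and throwErrorIfExists() at the start of the next trigger, which matches how the MicroBatchExecution changes below wire in the real trait.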
if (minLogEntriesToMaintain < currentBatchId) { - purge(currentBatchId - minLogEntriesToMaintain) + if (useAsyncPurge) { + purgeAsync() + } else { + purge(currentBatchId - minLogEntriesToMaintain) + } } } noNewData = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index eeaa37aa7ffb..5afd744f5e9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -347,6 +347,7 @@ abstract class StreamExecution( try { stopSources() + cleanup() state.set(TERMINATED) currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false) @@ -410,6 +411,12 @@ abstract class StreamExecution( } } + + /** + * Any clean up that needs to happen when the query is stopped or exits + */ + protected def cleanup(): Unit = {} + /** * Interrupts the query execution thread and awaits its termination until until it exceeds the * timeout. The timeout can be set on "spark.sql.streaming.stopTimeout". diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 2b8fc6515618..b540f9f00939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.python.PythonSQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress} import org.apache.spark.sql.types._ @@ -93,7 +94,7 @@ trait StateStoreReader extends StatefulOperator { } /** An operator that writes to a StateStore. */ -trait StateStoreWriter extends StatefulOperator { self: SparkPlan => +trait StateStoreWriter extends StatefulOperator with PythonSQLMetrics { self: SparkPlan => override lazy val metrics = statefulOperatorCustomMetrics ++ Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), @@ -109,7 +110,7 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan => "numShufflePartitions" -> SQLMetrics.createMetric(sparkContext, "number of shuffle partitions"), "numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext, "number of state store instances") - ) ++ stateStoreCustomMetrics + ) ++ stateStoreCustomMetrics ++ pythonMetrics /** * Get the progress made by this stateful operator after execution. This should be called in diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 620e1c607217..f38f24920faf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3777,6 +3777,23 @@ object functions { window(timeColumn, windowDuration, windowDuration, "0 second") } + /** + * Extracts the event time from the window column. 
+ * + * The window column is of StructType { start: Timestamp, end: Timestamp } where start is + * inclusive and end is exclusive. Since event time can support microsecond precision, + * window_time(window) = window.end - 1 microsecond. + * + * @param windowColumn The window column (typically produced by window aggregation) of type + * StructType { start: Timestamp, end: Timestamp } + * + * @group datetime_funcs + * @since 3.4.0 + */ + def window_time(windowColumn: Column): Column = withExpr { + WindowTime(windowColumn.expr) + } + /** * Generates session window given a timestamp specifying column. * diff --git a/sql/core/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/sql/core/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker deleted file mode 100644 index eb074c6ae3fc..000000000000 --- a/sql/core/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker +++ /dev/null @@ -1,18 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -mock-maker-inline diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 4ce4f1225ce6..6f111b777a6d 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -345,6 +345,7 @@ | org.apache.spark.sql.catalyst.expressions.WeekDay | weekday | SELECT weekday('2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.WeekOfYear | weekofyear | SELECT weekofyear('2008-02-20') | struct | | org.apache.spark.sql.catalyst.expressions.WidthBucket | width_bucket | SELECT width_bucket(5.3, 0.2, 10.6, 5) | struct | +| org.apache.spark.sql.catalyst.expressions.WindowTime | window_time | SELECT a, window.start as start, window.end as end, window_time(window), cnt FROM (SELECT a, window, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, window.start) | struct | | org.apache.spark.sql.catalyst.expressions.XxHash64 | xxhash64 | SELECT xxhash64('Spark', array(123), 2) | struct | | org.apache.spark.sql.catalyst.expressions.Year | year | SELECT year('2016-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.ZipWith | zip_with | SELECT zip_with(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)) | struct>> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql index fc5776c46afd..dc1a35072728 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/join-lateral.sql @@ -44,6 
+44,9 @@ SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c1 = t2.c1); -- lateral join with correlated non-equality predicates SELECT * FROM t1, LATERAL (SELECT c2 FROM t2 WHERE t1.c2 < t2.c2); +-- SPARK-36114: lateral join with aggregation and correlated non-equality predicates +SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2); + -- lateral join can reference preceding FROM clause items SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2); -- expect error: cannot resolve `t2.c1` diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql index b999d1723c91..6d673f149cc9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql @@ -190,3 +190,48 @@ SELECT c1, ( -- Multi-value subquery error SELECT (SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t) AS b; + +-- SPARK-36114: Support correlated non-equality predicates +CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2)); +CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3)); + +-- Neumann example Q2 +CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES + (0, 'A', 'CS', 2022), + (1, 'B', 'CS', 2022), + (2, 'C', 'Math', 2022)); +CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES + (0, 'C1', 'CS', 4, 2020), + (0, 'C2', 'CS', 3, 2021), + (1, 'C1', 'CS', 2, 2020), + (1, 'C2', 'CS', 1, 2021)); + +SELECT students.name, exams.course +FROM students, exams +WHERE students.id = exams.sid + AND (students.major = 'CS' OR students.major = 'Games Eng') + AND exams.grade >= ( + SELECT avg(exams.grade) + 1 + FROM exams + WHERE students.id = exams.sid + OR (exams.curriculum = students.major AND students.year > exams.date)); + +-- Correlated non-equality predicates +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1; +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1; + +-- Correlated non-equality predicates with the COUNT bug. 
+SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1; + +-- Correlated equality predicates that are not supported after SPARK-35080 +SELECT c, ( + SELECT count(*) + FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c) + WHERE t1.c = substring(t2.c, 1, 1) +) FROM (VALUES ('a'), ('b')) t1(c); + +SELECT c, ( + SELECT count(*) + FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b) + WHERE a + b = c +) FROM (VALUES (6)) t2(c); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 2078d3d8eb68..18ba4fb0ab7d 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -3609,7 +3609,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", "messageParameters" : { "dataType" : "(\"INTERVAL MONTH\" or \"INTERVAL DAY\")", - "functionName" : "function array", + "functionName" : "`array`", "sqlExpr" : "\"array(INTERVAL '1' MONTH, INTERVAL '20' DAY)\"" }, "queryContext" : [ { @@ -3648,7 +3648,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", "messageParameters" : { "dataType" : "(\"INTERVAL MONTH\" or \"INTERVAL DAY\")", - "functionName" : "function coalesce", + "functionName" : "`coalesce`", "sqlExpr" : "\"coalesce(INTERVAL '1' MONTH, INTERVAL '20' DAY)\"" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out index a550dbbec882..a9b577dd4c37 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/map.sql.out @@ -73,7 +73,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.MAP_CONTAINS_KEY_DIFF_TYPES", "messageParameters" : { "dataType" : "\"MAP\"", - "functionName" : "map_contains_key", + "functionName" : "`map_contains_key`", "leftType" : "\"MAP\"", "rightType" : "\"INT\"", "sqlExpr" : "\"map_contains_key(map(1, a, 2, b), 1)\"" @@ -98,7 +98,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.MAP_CONTAINS_KEY_DIFF_TYPES", "messageParameters" : { "dataType" : "\"MAP\"", - "functionName" : "map_contains_key", + "functionName" : "`map_contains_key`", "leftType" : "\"MAP\"", "rightType" : "\"STRING\"", "sqlExpr" : "\"map_contains_key(map(1, a, 2, b), 1)\"" diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 6eb5fb4ce844..bdb9ba81ff31 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -3422,7 +3422,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", "messageParameters" : { "dataType" : "(\"INTERVAL MONTH\" or \"INTERVAL DAY\")", - "functionName" : "function array", + "functionName" : "`array`", "sqlExpr" : "\"array(INTERVAL '1' MONTH, INTERVAL '20' DAY)\"" }, "queryContext" : [ { @@ -3461,7 +3461,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", "messageParameters" : { "dataType" : "(\"INTERVAL MONTH\" or \"INTERVAL DAY\")", - "functionName" : "function coalesce", + "functionName" : "`coalesce`", "sqlExpr" : "\"coalesce(INTERVAL '1' MONTH, INTERVAL '20' DAY)\"" }, "queryContext" : [ { diff --git 
a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out index be07ba7bd9a1..34c0543dfdda 100644 --- a/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/join-lateral.sql.out @@ -272,6 +272,15 @@ struct 1 2 3 +-- !query +SELECT * FROM t1, LATERAL (SELECT max(c2) AS m FROM t2 WHERE t1.c2 < t2.c2) +-- !query schema +struct +-- !query output +0 1 3 +1 2 3 + + -- !query SELECT * FROM t1 JOIN t2 JOIN LATERAL (SELECT t1.c2 + t2.c2) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/map.sql.out b/sql/core/src/test/resources/sql-tests/results/map.sql.out index a550dbbec882..a9b577dd4c37 100644 --- a/sql/core/src/test/resources/sql-tests/results/map.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/map.sql.out @@ -73,7 +73,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.MAP_CONTAINS_KEY_DIFF_TYPES", "messageParameters" : { "dataType" : "\"MAP\"", - "functionName" : "map_contains_key", + "functionName" : "`map_contains_key`", "leftType" : "\"MAP\"", "rightType" : "\"INT\"", "sqlExpr" : "\"map_contains_key(map(1, a, 2, b), 1)\"" @@ -98,7 +98,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.MAP_CONTAINS_KEY_DIFF_TYPES", "messageParameters" : { "dataType" : "\"MAP\"", - "functionName" : "map_contains_key", + "functionName" : "`map_contains_key`", "leftType" : "\"MAP\"", "rightType" : "\"STRING\"", "sqlExpr" : "\"map_contains_key(map(1, a, 2, b), 1)\"" diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out index d1e56786207e..38ab365ef694 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out @@ -433,3 +433,110 @@ org.apache.spark.SparkException "fragment" : "(SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t)" } ] } + + +-- !query +CREATE OR REPLACE TEMP VIEW t1(c1, c2) AS (VALUES (0, 1), (1, 2)) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMP VIEW t2(c1, c2) AS (VALUES (0, 2), (0, 3)) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMP VIEW students(id, name, major, year) AS (VALUES + (0, 'A', 'CS', 2022), + (1, 'B', 'CS', 2022), + (2, 'C', 'Math', 2022)) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE OR REPLACE TEMP VIEW exams(sid, course, curriculum, grade, date) AS (VALUES + (0, 'C1', 'CS', 4, 2020), + (0, 'C2', 'CS', 3, 2021), + (1, 'C1', 'CS', 2, 2020), + (1, 'C2', 'CS', 1, 2021)) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT students.name, exams.course +FROM students, exams +WHERE students.id = exams.sid + AND (students.major = 'CS' OR students.major = 'Games Eng') + AND exams.grade >= ( + SELECT avg(exams.grade) + 1 + FROM exams + WHERE students.id = exams.sid + OR (exams.curriculum = students.major AND students.year > exams.date)) +-- !query schema +struct +-- !query output +A C1 + + +-- !query +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 > t2.c1) FROM t1 +-- !query schema +struct +-- !query output +2 +NULL + + +-- !query +SELECT (SELECT min(c2) FROM t2 WHERE t1.c1 >= t2.c1 AND t1.c2 < t2.c2) FROM t1 +-- 
!query schema +struct +-- !query output +2 +3 + + +-- !query +SELECT (SELECT count(*) FROM t2 WHERE t1.c1 > t2.c1) FROM t1 +-- !query schema +struct +-- !query output +0 +2 + + +-- !query +SELECT c, ( + SELECT count(*) + FROM (VALUES ('ab'), ('abc'), ('bc')) t2(c) + WHERE t1.c = substring(t2.c, 1, 1) +) FROM (VALUES ('a'), ('b')) t1(c) +-- !query schema +struct +-- !query output +a 2 +b 1 + + +-- !query +SELECT c, ( + SELECT count(*) + FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b) + WHERE a + b = c +) FROM (VALUES (6)) t2(c) +-- !query schema +struct +-- !query output +6 4 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out index 0e2b0cf2789e..726356b7896d 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -95,7 +95,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", "messageParameters" : { "dataType" : "(\"MAP\" or \"MAP, ARRAY>\")", - "functionName" : "function map_concat", + "functionName" : "`map_concat`", "sqlExpr" : "\"map_concat(tinyint_map1, array_map1)\"" }, "queryContext" : [ { @@ -120,7 +120,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", "messageParameters" : { "dataType" : "(\"MAP\" or \"MAP\")", - "functionName" : "function map_concat", + "functionName" : "`map_concat`", "sqlExpr" : "\"map_concat(boolean_map1, int_map2)\"" }, "queryContext" : [ { @@ -145,7 +145,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", "messageParameters" : { "dataType" : "(\"MAP\" or \"MAP, STRUCT>\")", - "functionName" : "function map_concat", + "functionName" : "`map_concat`", "sqlExpr" : "\"map_concat(int_map1, struct_map2)\"" }, "queryContext" : [ { @@ -170,7 +170,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", "messageParameters" : { "dataType" : "(\"MAP, STRUCT>\" or \"MAP, ARRAY>\")", - "functionName" : "function map_concat", + "functionName" : "`map_concat`", "sqlExpr" : "\"map_concat(struct_map1, array_map2)\"" }, "queryContext" : [ { @@ -195,7 +195,7 @@ org.apache.spark.sql.AnalysisException "errorClass" : "DATATYPE_MISMATCH.DATA_DIFF_TYPES", "messageParameters" : { "dataType" : "(\"MAP\" or \"MAP, ARRAY>\")", - "functionName" : "function map_concat", + "functionName" : "`map_concat`", "sqlExpr" : "\"map_concat(int_map1, array_map2)\"" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out index f532b0d41e34..14ecf98c7a83 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-except.sql.out @@ -97,19 +97,6 @@ WHERE udf(t1.v) >= (SELECT min(udf(t2.v)) FROM t2 WHERE t2.k = t1.k) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - "messageParameters" : { - "treeNode" : "(cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\nFilter (cast(udf(cast(k#x as string)) as string) = cast(udf(cast(outer(k#x) as string)) as string))\n+- SubqueryAlias t2\n 
+- View (`t2`, [k#x,v#x])\n +- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]\n +- Project [k#x, v#x]\n +- SubqueryAlias t2\n +- LocalRelation [k#x, v#x]\n" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 39, - "stopIndex" : 141, - "fragment" : "SELECT udf(max(udf(t2.v)))\n FROM t2\n WHERE udf(t2.k) = udf(t1.k)" - } ] -} +two diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 54911d2a6fb6..ff8dd596ebe1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -918,7 +918,7 @@ class DataFrameAggregateSuite extends QueryTest errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", sqlState = None, parameters = Map( - "functionName" -> "function max_by", + "functionName" -> "`max_by`", "dataType" -> "\"MAP\"", "sqlExpr" -> "\"max_by(x, y)\"" ), @@ -988,7 +988,7 @@ class DataFrameAggregateSuite extends QueryTest errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", sqlState = None, parameters = Map( - "functionName" -> "function min_by", + "functionName" -> "`min_by`", "dataType" -> "\"MAP\"", "sqlExpr" -> "\"min_by(x, y)\"" ), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7dea7799b666..85877c97ed59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1012,7 +1012,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", "dataType" -> "(\"MAP, INT>\" or \"MAP\")", - "functionName" -> "function map_concat"), + "functionName" -> "`map_concat`"), context = ExpectedContext( fragment = "map_concat(map1, map2)", start = 0, @@ -1028,7 +1028,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", "dataType" -> "(\"MAP, INT>\" or \"MAP\")", - "functionName" -> "function map_concat") + "functionName" -> "`map_concat`") ) checkError( @@ -1040,7 +1040,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, 12)\"", "dataType" -> "[\"MAP, INT>\", \"INT\"]", - "functionName" -> "function map_concat"), + "functionName" -> "`map_concat`"), context = ExpectedContext( fragment = "map_concat(map1, 12)", start = 0, @@ -1056,7 +1056,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, 12)\"", "dataType" -> "[\"MAP, INT>\", \"INT\"]", - "functionName" -> "function map_concat") + "functionName" -> "`map_concat`") ) } @@ -3651,7 +3651,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_zip_with(mmi, mmi, lambdafunction(x, x, y, z))\"", "dataType" -> "\"MAP\"", - "functionName" -> "function map_zip_with"), + "functionName" -> "`map_zip_with`"), context = ExpectedContext( fragment = "map_zip_with(mmi, mmi, (x, y, z) -> x)", start = 0, @@ -4219,16 +4219,68 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val funcsMustHaveAtLeastOneArg = ("coalesce", (df: DataFrame) => df.select(coalesce())) :: - 
("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: - ("hash", (df: DataFrame) => df.select(hash())) :: - ("hash", (df: DataFrame) => df.selectExpr("hash()")) :: - ("xxhash64", (df: DataFrame) => df.select(xxhash64())) :: - ("xxhash64", (df: DataFrame) => df.selectExpr("xxhash64()")) :: Nil + ("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) :: Nil funcsMustHaveAtLeastOneArg.foreach { case (name, func) => val errMsg = intercept[AnalysisException] { func(df) }.getMessage assert(errMsg.contains(s"input to function $name requires at least one argument")) } + checkError( + exception = intercept[AnalysisException] { + df.select(hash()) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"hash()\"", + "functionName" -> "`hash`", + "expectedNum" -> "> 0", + "actualNum" -> "0")) + + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("hash()") + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"hash()\"", + "functionName" -> "`hash`", + "expectedNum" -> "> 0", + "actualNum" -> "0"), + context = ExpectedContext( + fragment = "hash()", + start = 0, + stop = 5)) + + checkError( + exception = intercept[AnalysisException] { + df.select(xxhash64()) + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"xxhash64()\"", + "functionName" -> "`xxhash64`", + "expectedNum" -> "> 0", + "actualNum" -> "0")) + + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("xxhash64()") + }, + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_PARAMS", + sqlState = None, + parameters = Map( + "sqlExpr" -> "\"xxhash64()\"", + "functionName" -> "`xxhash64`", + "expectedNum" -> "> 0", + "actualNum" -> "0"), + context = ExpectedContext( + fragment = "xxhash64()", + start = 0, + stop = 9)) + checkError( exception = intercept[AnalysisException] { df.select(greatest()) @@ -4237,7 +4289,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sqlState = None, parameters = Map( "sqlExpr" -> "\"greatest()\"", - "functionName" -> "greatest", + "functionName" -> "`greatest`", "expectedNum" -> "> 1", "actualNum" -> "0") ) @@ -4250,7 +4302,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sqlState = None, parameters = Map( "sqlExpr" -> "\"greatest()\"", - "functionName" -> "greatest", + "functionName" -> "`greatest`", "expectedNum" -> "> 1", "actualNum" -> "0"), context = ExpectedContext( @@ -4267,7 +4319,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sqlState = None, parameters = Map( "sqlExpr" -> "\"least()\"", - "functionName" -> "least", + "functionName" -> "`least`", "expectedNum" -> "> 1", "actualNum" -> "0") ) @@ -4280,7 +4332,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { sqlState = None, parameters = Map( "sqlExpr" -> "\"least()\"", - "functionName" -> "least", + "functionName" -> "`least`", "expectedNum" -> "> 1", "actualNum" -> "0"), context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index bd39453f5120..f775eb9ecfc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -575,4 +575,66 @@ class 
DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { validateWindowColumnInSchema(schema2, "window") } } + + test("window_time function on raw window column") { + val df = Seq( + ("2016-03-27 19:38:18"), ("2016-03-27 19:39:25") + ).toDF("time") + + checkAnswer( + df.select(window($"time", "10 seconds").as("window")) + .select( + $"window.end".cast("string"), + window_time($"window").cast("string") + ), + Seq( + Row("2016-03-27 19:38:20", "2016-03-27 19:38:19.999999"), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:29.999999") + ) + ) + } + + test("2 window_time functions on raw window column") { + val df = Seq( + ("2016-03-27 19:38:18"), ("2016-03-27 19:39:25") + ).toDF("time") + + val e = intercept[AnalysisException] { + df + .withColumn("time2", expr("time - INTERVAL 5 minutes")) + .select( + window($"time", "10 seconds").as("window1"), + window($"time2", "10 seconds").as("window2") + ) + .select( + $"window1.end".cast("string"), + window_time($"window1").cast("string"), + $"window2.end".cast("string"), + window_time($"window2").cast("string") + ) + } + assert(e.getMessage.contains( + "Multiple time/session window expressions would result in a cartesian product of rows, " + + "therefore they are currently not supported")) + } + + test("window_time function on agg output") { + val df = Seq( + ("2016-03-27 19:38:19", 1), ("2016-03-27 19:39:25", 2) + ).toDF("time", "value") + checkAnswer( + df.groupBy(window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select( + $"window.start".cast("string"), + $"window.end".cast("string"), + window_time($"window").cast("string"), + $"counts"), + Seq( + Row("2016-03-27 19:38:10", "2016-03-27 19:38:20", "2016-03-27 19:38:19.999999", 1), + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", "2016-03-27 19:39:29.999999", 1) + ) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index a22abd505ca0..e9daa825dd46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -418,4 +418,18 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession { assert(err.contains("cannot be passed in untyped `select` API. 
" + "Use the typed `Dataset.select` API instead.")) } + + test("SPARK-40906: Mode should copy keys before inserting into Map") { + val df = spark.sparkContext.parallelize(Seq.empty[Int], 4) + .mapPartitionsWithIndex { (idx, iter) => + if (idx == 3) { + Iterator("3", "3", "3", "3", "4") + } else { + Iterator("0", "1", "2", "3", "4") + } + }.toDF("a") + + val agg = df.select(mode(col("a"))).as[String] + checkDataset(agg, "3") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala index e1b7f7f57b65..45ae3e549775 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala @@ -54,10 +54,10 @@ class MiscFunctionsSuite extends QueryTest with SharedSparkSession { SQLConf.ENFORCE_RESERVED_KEYWORDS.key -> "true") { Seq("user", "current_user").foreach { func => checkAnswer(sql(s"select $func"), Row(user)) - } - Seq("user()", "current_user()").foreach { func => - val e = intercept[ParseException](sql(s"select $func")) - assert(e.getMessage.contains(func)) + checkError( + exception = intercept[ParseException](sql(s"select $func()")), + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> s"'$func'", "hint" -> "")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index ecb4bfd0ec41..4b5863563677 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -66,6 +66,11 @@ class SubquerySuite extends QueryTest t.createOrReplaceTempView("t") } + private def checkNumJoins(plan: LogicalPlan, numJoins: Int): Unit = { + val joins = plan.collect { case j: Join => j } + assert(joins.size == numJoins) + } + test("SPARK-18854 numberedTreeString for subquery") { val df = sql("select * from range(10) where id not in " + "(select id from range(2) union all select id from range(2))") @@ -562,17 +567,10 @@ class SubquerySuite extends QueryTest } test("non-equal correlated scalar subquery") { - val exception = intercept[AnalysisException] { - sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1") - } - checkErrorMatchPVals( - exception, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - parameters = Map("treeNode" -> "(?s).*"), - sqlState = None, - context = ExpectedContext( - fragment = "select sum(b) from l l2 where l2.a < l1.a", start = 11, stop = 51)) + checkAnswer( + sql("select a, (select sum(b) from l l2 where l2.a < l1.a) sum_b from l l1"), + Seq(Row(1, null), Row(1, null), Row(2, 4), Row(2, 4), Row(3, 6), Row(null, null), + Row(null, null), Row(6, 9))) } test("disjunctive correlated scalar subquery") { @@ -2105,25 +2103,17 @@ class SubquerySuite extends QueryTest } } - test("SPARK-38155: disallow distinct aggregate in lateral subqueries") { + test("SPARK-36114: distinct aggregate in lateral subqueries") { withTempView("t1", "t2") { Seq((0, 1)).toDF("c1", "c2").createOrReplaceTempView("t1") Seq((1, 2), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") - val exception = intercept[AnalysisException] { - sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)") - } - checkErrorMatchPVals( - exception, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ - "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - parameters = Map("treeNode" -> "(?s).*"), - sqlState = None, - context = ExpectedContext( - fragment = "SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1", start = 31, stop = 73)) + checkAnswer( + sql("SELECT * FROM t1 JOIN LATERAL (SELECT DISTINCT c2 FROM t2 WHERE c1 > t1.c1)"), + Row(0, 1, 2) :: Nil) } } - test("SPARK-38180: allow safe cast expressions in correlated equality conditions") { + test("SPARK-38180, SPARK-36114: allow safe cast expressions in correlated equality conditions") { withTempView("t1", "t2") { Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") Seq((0, 2), (0, 3)).toDF("c1", "c2").createOrReplaceTempView("t2") @@ -2139,19 +2129,14 @@ class SubquerySuite extends QueryTest |FROM (SELECT CAST(c1 AS STRING) a FROM t1) |""".stripMargin), Row(5) :: Row(null) :: Nil) - val exception1 = intercept[AnalysisException] { - sql( - """SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a) - |FROM (SELECT CAST(c1 AS SHORT) a FROM t1)""".stripMargin) - } - checkErrorMatchPVals( - exception1, - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "CORRELATED_COLUMN_IS_NOT_ALLOWED_IN_PREDICATE", - parameters = Map("treeNode" -> "(?s).*"), - sqlState = None, - context = ExpectedContext( - fragment = "SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a", start = 8, stop = 57)) + // SPARK-36114: we now allow non-safe cast expressions in correlated predicates. + val df = sql( + """SELECT (SELECT SUM(c2) FROM t2 WHERE CAST(c1 AS SHORT) = a) + |FROM (SELECT CAST(c1 AS SHORT) a FROM t1) + |""".stripMargin) + checkAnswer(df, Row(5) :: Row(null) :: Nil) + // The optimized plan should have one left outer join and one domain (inner) join. + checkNumJoins(df.queryExecution.optimizedPlan, 2) } } @@ -2469,10 +2454,41 @@ class SubquerySuite extends QueryTest Row(2)) // Cannot use non-orderable data type in one row subquery that cannot be collapsed. - val error = intercept[AnalysisException] { - sql("select (select concat(a, a) from (select upper(x['a']) as a)) from v1").collect() - } - assert(error.getMessage.contains("Correlated column reference 'v1.x' cannot be map type")) + val error = intercept[AnalysisException] { + sql( + """ + |select ( + | select concat(a, a) from + | (select upper(x['a'] + rand()) as a) + |) from v1 + |""".stripMargin).collect() + } + assert(error.getMessage.contains("Correlated column reference 'v1.x' cannot be map type")) + } + } + + test("SPARK-40800: always inline expressions in OptimizeOneRowRelationSubquery") { + withTempView("t1") { + sql("CREATE TEMP VIEW t1 AS SELECT ARRAY('a', 'b') a") + // Scalar subquery. + checkAnswer(sql( + """ + |SELECT ( + | SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] AS sorted + | FROM (SELECT MAP('a', 1, 'b', 2) rank) + |) FROM t1 + |""".stripMargin), + Row("a")) + // Lateral subquery. 
+ checkAnswer( + sql(""" + |SELECT sorted[0] FROM t1 + |JOIN LATERAL ( + | SELECT array_sort(a, (i, j) -> rank[i] - rank[j]) AS sorted + | FROM (SELECT MAP('a', 1, 'b', 2) rank) + |) + |""".stripMargin), + Row("a")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala index e113499ec685..f414de1b87c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableAddPartitionSuiteBase.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command import java.time.{Duration, Period} +import org.apache.spark.SparkNumberFormatException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -40,6 +41,7 @@ import org.apache.spark.sql.internal.SQLConf */ trait AlterTableAddPartitionSuiteBase extends QueryTest with DDLCommandTestUtils { override val command = "ALTER TABLE .. ADD PARTITION" + def defaultPartitionName: String test("one partition") { withNamespaceAndTable("ns", "tbl") { t => @@ -213,4 +215,46 @@ trait AlterTableAddPartitionSuiteBase extends QueryTest with DDLCommandTestUtils Row(Period.ofYears(1), Duration.ofDays(-1), "bbb"))) } } + + test("SPARK-40798: Alter partition should verify partition value") { + def shouldThrowException(policy: SQLConf.StoreAssignmentPolicy.Value): Boolean = policy match { + case SQLConf.StoreAssignmentPolicy.ANSI | SQLConf.StoreAssignmentPolicy.STRICT => + true + case SQLConf.StoreAssignmentPolicy.LEGACY => + false + } + + SQLConf.StoreAssignmentPolicy.values.foreach { policy => + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (c int) $defaultUsing PARTITIONED BY (p int)") + + withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> policy.toString) { + if (shouldThrowException(policy)) { + checkError( + exception = intercept[SparkNumberFormatException] { + sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") + }, + errorClass = "CAST_INVALID_INPUT", + parameters = Map( + "ansiConfig" -> "\"spark.sql.ansi.enabled\"", + "expression" -> "'aaa'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"INT\""), + context = ExpectedContext( + fragment = s"ALTER TABLE $t ADD PARTITION (p='aaa')", + start = 0, + stop = 35 + t.length)) + } else { + sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") + checkPartitions(t, Map("p" -> defaultPartitionName)) + sql(s"ALTER TABLE $t DROP PARTITION (p=null)") + } + + sql(s"ALTER TABLE $t ADD PARTITION (p=null)") + checkPartitions(t, Map("p" -> defaultPartitionName)) + sql(s"ALTER TABLE $t DROP PARTITION (p=null)") + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala index 11df5ede8bbf..d41fd6b00f8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableAddPartitionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.command.v1 import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.analysis.PartitionsAlreadyExistException +import 
org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME import org.apache.spark.sql.execution.command import org.apache.spark.sql.internal.SQLConf @@ -33,6 +34,8 @@ import org.apache.spark.sql.internal.SQLConf * `org.apache.spark.sql.hive.execution.command.AlterTableAddPartitionSuite` */ trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuiteBase { + override def defaultPartitionName: String = DEFAULT_PARTITION_NAME + test("empty string as partition value") { withNamespaceAndTable("ns", "tbl") { t => sql(s"CREATE TABLE $t (col1 INT, p1 STRING) $defaultUsing PARTITIONED BY (p1)") @@ -157,6 +160,18 @@ trait AlterTableAddPartitionSuiteBase extends command.AlterTableAddPartitionSuit checkPartitions(t, Map("id" -> "1"), Map("id" -> "2")) } } + + test("SPARK-40798: Alter partition should verify partition value - legacy") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (c int) $defaultUsing PARTITIONED BY (p int)") + + withSQLConf(SQLConf.SKIP_TYPE_VALIDATION_ON_ALTER_PARTITION.key -> "true") { + sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") + checkPartitions(t, Map("p" -> "aaa")) + sql(s"ALTER TABLE $t DROP PARTITION (p='aaa')") + } + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala index 835be8573fdc..a9ab11e483fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableAddPartitionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.analysis.PartitionsAlreadyExistException import org.apache.spark.sql.execution.command +import org.apache.spark.sql.internal.SQLConf /** * The class contains tests for the `ALTER TABLE .. 
ADD PARTITION` command @@ -28,6 +29,8 @@ import org.apache.spark.sql.execution.command class AlterTableAddPartitionSuite extends command.AlterTableAddPartitionSuiteBase with CommandSuiteBase { + override def defaultPartitionName: String = "null" + test("SPARK-33650: add partition into a table which doesn't support partition management") { withNamespaceAndTable("ns", "tbl", s"non_part_$catalog") { t => sql(s"CREATE TABLE $t (id bigint, data string) $defaultUsing") @@ -121,4 +124,16 @@ class AlterTableAddPartitionSuite checkPartitions(t, Map("id" -> "1"), Map("id" -> "2")) } } + + test("SPARK-40798: Alter partition should verify partition value - legacy") { + withNamespaceAndTable("ns", "tbl") { t => + sql(s"CREATE TABLE $t (c int) $defaultUsing PARTITIONED BY (p int)") + + withSQLConf(SQLConf.SKIP_TYPE_VALIDATION_ON_ALTER_PARTITION.key -> "true") { + sql(s"ALTER TABLE $t ADD PARTITION (p='aaa')") + checkPartitions(t, Map("p" -> defaultPartitionName)) + sql(s"ALTER TABLE $t DROP PARTITION (p=null)") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 70784c20a8eb..7850b2d79b04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -84,4 +84,23 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { checkAnswer(actual, expected) } + + test("SPARK-34265: Instrument Python UDF execution using SQL Metrics") { + + val pythonSQLMetrics = List( + "data sent to Python workers", + "data returned from Python workers", + "number of output rows") + + val df = base.groupBy(pythonTestUDF(base("a") + 1)) + .agg(pythonTestUDF(pythonTestUDF(base("a") + 1))) + df.count() + + val statusStore = spark.sharedState.statusStore + val lastExecId = statusStore.executionsList.last.executionId + val executionMetrics = statusStore.execution(lastExecId).get.metrics.mkString + for (metric <- pythonSQLMetrics) { + assert(executionMetrics.contains(metric)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala index 749ca9d06eaf..0ddd48420ef3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -21,17 +21,20 @@ import java.io.File import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.should._ +import org.scalatest.time.{Seconds, Span} import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.connector.read.streaming import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.functions.{count, timestamp_seconds, window} -import org.apache.spark.sql.streaming.{StreamTest, Trigger} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest, Trigger} import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.util.Utils -class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { +class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter 
with Matchers { import testImplicits._ @@ -39,6 +42,84 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { sqlContext.streams.active.foreach(_.stop()) } + def getListOfFiles(dir: String): List[File] = { + val d = new File(dir) + if (d.exists && d.isDirectory) { + d.listFiles.filter(_.isFile).toList + } else { + List[File]() + } + } + + test("async log purging") { + withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") { + withTempDir { checkpointLocation => + val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val ds = inputData.toDS() + testStream(ds)( + StartStream(checkpointLocation = checkpointLocation.getCanonicalPath), + AddData(inputData, 0), + CheckNewAnswer(0), + AddData(inputData, 1), + CheckNewAnswer(1), + Execute { q => + getListOfFiles(checkpointLocation + "/offsets") + .filter(file => !file.isHidden) + .map(file => file.getName.toInt) + .sorted should equal(Array(0, 1)) + getListOfFiles(checkpointLocation + "/commits") + .filter(file => !file.isHidden) + .map(file => file.getName.toInt) + .sorted should equal(Array(0, 1)) + }, + AddData(inputData, 2), + CheckNewAnswer(2), + AddData(inputData, 3), + CheckNewAnswer(3), + Execute { q => + eventually(timeout(Span(5, Seconds))) { + q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false) + } + + getListOfFiles(checkpointLocation + "/offsets") + .filter(file => !file.isHidden) + .map(file => file.getName.toInt) + .sorted should equal(Array(1, 2, 3)) + getListOfFiles(checkpointLocation + "/commits") + .filter(file => !file.isHidden) + .map(file => file.getName.toInt) + .sorted should equal(Array(1, 2, 3)) + }, + StopStream + ) + } + } + } + + test("error notifier test") { + withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") { + withTempDir { checkpointLocation => + val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val ds = inputData.toDS() + val e = intercept[StreamingQueryException] { + + testStream(ds)( + StartStream(checkpointLocation = checkpointLocation.getCanonicalPath), + AddData(inputData, 0), + CheckNewAnswer(0), + AddData(inputData, 1), + CheckNewAnswer(1), + Execute { q => + q.asInstanceOf[MicroBatchExecution].errorNotifier.markError(new Exception("test")) + }, + AddData(inputData, 2), + CheckNewAnswer(2)) + } + e.getCause.getMessage should include("test") + } + } + } + test("SPARK-24156: do not plan a no-data batch again after it has already been planned") { val inputData = MemoryStream[Int] val df = inputData.toDF() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index 3de6d375149b..7e9052fb530e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -706,40 +706,35 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter test("insert overwrite to dir with mixed syntax") { withTempView("test_insert_table") { spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table") - - val e = intercept[ParseException] { - sql( - s""" - |INSERT OVERWRITE DIRECTORY 'file://tmp' + checkError( + exception = intercept[ParseException] { sql( + s"""INSERT OVERWRITE DIRECTORY 'file://tmp' |USING json |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' - |SELECT * FROM test_insert_table - """.stripMargin) - }.getMessage - - 
assert(e.contains("Syntax error at or near 'ROW'")) + |SELECT * FROM test_insert_table""".stripMargin) + }, + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'ROW'", "hint" -> "")) } } test("insert overwrite to dir with multi inserts") { withTempView("test_insert_table") { spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table") - - val e = intercept[ParseException] { - sql( - s""" - |INSERT OVERWRITE DIRECTORY 'file://tmp2' - |USING json - |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' - |SELECT * FROM test_insert_table - |INSERT OVERWRITE DIRECTORY 'file://tmp2' - |USING json - |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' - |SELECT * FROM test_insert_table - """.stripMargin) - }.getMessage - - assert(e.contains("Syntax error at or near 'ROW'")) + checkError( + exception = intercept[ParseException] { + sql( + s"""INSERT OVERWRITE DIRECTORY 'file://tmp2' + |USING json + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |SELECT * FROM test_insert_table + |INSERT OVERWRITE DIRECTORY 'file://tmp2' + |USING json + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |SELECT * FROM test_insert_table""".stripMargin) + }, + errorClass = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'ROW'", "hint" -> "")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 99aef0e47de9..653906366b3a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -608,7 +608,7 @@ class HiveDDLSuite } test("SPARK-19129: drop partition with a empty string will drop the whole table") { - val df = spark.createDataFrame(Seq((0, "a"), (1, "b"))).toDF("partCol1", "name") + val df = spark.createDataFrame(Seq(("0", "a"), ("1", "b"))).toDF("partCol1", "name") df.write.mode("overwrite").partitionBy("partCol1").saveAsTable("partitionedTable") assertAnalysisError( "alter table partitionedTable drop partition(partCol1='')", @@ -2678,27 +2678,30 @@ class HiveDDLSuite } test("Hive CTAS can't create partitioned table by specifying schema") { - val err1 = intercept[ParseException] { - spark.sql( - s""" - |CREATE TABLE t (a int) - |PARTITIONED BY (b string) - |STORED AS parquet - |AS SELECT 1 as a, "a" as b - """.stripMargin) - }.getMessage - assert(err1.contains("Schema may not be specified in a Create Table As Select")) - - val err2 = intercept[ParseException] { - spark.sql( - s""" - |CREATE TABLE t - |PARTITIONED BY (b string) - |STORED AS parquet - |AS SELECT 1 as a, "a" as b - """.stripMargin) - }.getMessage - assert(err2.contains("Partition column types may not be specified in Create Table As Select")) + val sql1 = + s"""CREATE TABLE t (a int) + |PARTITIONED BY (b string) + |STORED AS parquet + |AS SELECT 1 as a, "a" as b""".stripMargin + checkError( + exception = intercept[ParseException](sql(sql1)), + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map( + "message" -> "Schema may not be specified in a Create Table As Select (CTAS) statement"), + context = ExpectedContext(sql1, 0, 92)) + + val sql2 = + s"""CREATE TABLE t + |PARTITIONED BY (b string) + |STORED AS parquet + |AS SELECT 1 as a, "a" as b""".stripMargin + checkError( + exception = intercept[ParseException](sql(sql2)), + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map( + "message" -> + "Partition column types may not be specified in Create Table As 
Select (CTAS)"), + context = ExpectedContext(sql2, 0, 84)) } test("Hive CTAS with dynamic partition") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index e80c41401227..01c8d6ffe1be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -21,7 +21,6 @@ import java.io.File import java.net.URI import java.nio.file.Files import java.sql.Timestamp -import java.util.Locale import scala.util.Try @@ -73,9 +72,17 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd } } - private def assertUnsupportedFeature(body: => Unit): Unit = { - val e = intercept[ParseException] { body } - assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) + private def assertUnsupportedFeature( + body: => Unit, + message: String, + expectedContext: ExpectedContext): Unit = { + checkError( + exception = intercept[ParseException] { + body + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> message), + context = expectedContext) } // Testing the Broadcast based join for cartesian join (cross join) @@ -155,13 +162,25 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd """.stripMargin) test("multiple generators in projection") { - intercept[AnalysisException] { - sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() - } - - intercept[AnalysisException] { - sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() - } + checkError( + exception = intercept[AnalysisException] { + sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() + }, + errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", + parameters = Map( + "clause" -> "SELECT", + "num" -> "2", + "generators" -> "\"explode(array(key, key))\", \"explode(array(key, key))\"")) + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() + }, + errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", + parameters = Map( + "clause" -> "SELECT", + "num" -> "2", + "generators" -> "\"explode(array(key, key))\", \"explode(array(key, key))\"")) } createQueryTest("! operator", @@ -686,9 +705,12 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd // TODO: adopt this test when Spark SQL has the functionality / framework to report errors. // See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion. 
ignore("non-boolean conditions in a CaseWhen are illegal") { - intercept[Exception] { - sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() - } + checkError( + exception = intercept[AnalysisException] { + sql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() + }, + errorClass = null, + parameters = Map.empty) } createQueryTest("case sensitivity when query Hive table", @@ -807,14 +829,15 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd } test("ADD JAR command") { - val testJar = TestHive.getHiveFile("data/files/TestSerDe.jar").getCanonicalPath sql("CREATE TABLE alter1(a INT, b INT)") - intercept[Exception] { - sql( - """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' - |WITH serdeproperties('s1'='9') - """.stripMargin) - } + checkError( + exception = intercept[AnalysisException] { + sql( + """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' + |WITH serdeproperties('s1'='9')""".stripMargin) + }, + errorClass = null, + parameters = Map.empty) sql("DROP TABLE alter1") } @@ -1229,22 +1252,30 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd sql("SET hive.exec.dynamic.partition.mode=strict") // Should throw when using strict dynamic partition mode without any static partition - intercept[AnalysisException] { - sql( - """INSERT INTO TABLE dp_test PARTITION(dp) - |SELECT key, value, key % 5 FROM src - """.stripMargin) - } + checkError( + exception = intercept[AnalysisException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp) + |SELECT key, value, key % 5 FROM src""".stripMargin) + }, + errorClass = "_LEGACY_ERROR_TEMP_1168", + parameters = Map( + "tableName" -> "`spark_catalog`.`default`.`dp_test`", + "targetColumns" -> "4", + "insertedColumns" -> "3", + "staticPartCols" -> "0")) sql("SET hive.exec.dynamic.partition.mode=nonstrict") // Should throw when a static partition appears after a dynamic partition - intercept[AnalysisException] { - sql( - """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) - |SELECT key, value, key % 5 FROM src - """.stripMargin) - } + checkError( + exception = intercept[AnalysisException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) + |SELECT key, value, key % 5 FROM src""".stripMargin) + }, + errorClass = null, + parameters = Map.empty) } test("SPARK-3414 regression: should store analyzed logical plan when creating a temporary view") { @@ -1338,15 +1369,30 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd s2.sql("create table test_b(key INT, value STRING)") sql("select * from test_a") - intercept[AnalysisException] { - sql("select * from test_b") - } + checkError( + exception = intercept[AnalysisException] { + sql("select * from test_b") + }, + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + parameters = Map("relationName" -> "`test_b`"), + context = ExpectedContext( + fragment = "test_b", + start = 14, + stop = 19)) + sql("select * from b.test_b") s2.sql("select * from test_b") - intercept[AnalysisException] { - s2.sql("select * from test_a") - } + checkError( + exception = intercept[AnalysisException] { + s2.sql("select * from test_a") + }, + errorClass = "TABLE_OR_VIEW_NOT_FOUND", + parameters = Map("relationName" -> "`test_a`"), + context = ExpectedContext( + fragment = "test_a", + start = 14, + stop = 19)) s2.sql("select * from a.test_a") } finally { sql("DROP TABLE IF EXISTS test_a") @@ -1362,28 +1408,48 @@ class HiveQuerySuite extends 
HiveComparisonTest with SQLTestUtils with BeforeAnd sql("USE hive_test_db") assert("hive_test_db" == sql("select current_database()").first().getString(0)) - intercept[AnalysisException] { - sql("USE not_existing_db") - } + checkError( + exception = intercept[AnalysisException] { + sql("USE not_existing_db") + }, + errorClass = "SCHEMA_NOT_FOUND", + parameters = Map("schemaName" -> "`not_existing_db`")) sql(s"USE $currentDatabase") assert(currentDatabase == sql("select current_database()").first().getString(0)) } test("lookup hive UDF in another thread") { - val e = intercept[AnalysisException] { - range(1).selectExpr("not_a_udf()") - } - assert(e.getMessage.contains("Undefined function")) - assert(e.getMessage.contains("not_a_udf")) + checkErrorMatchPVals( + exception = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + }, + errorClass = "_LEGACY_ERROR_TEMP_1242", + sqlState = None, + parameters = Map( + "rawName" -> "not_a_udf", + "fullName" -> "spark_catalog.[a-z]+.not_a_udf"), + context = ExpectedContext( + fragment = "not_a_udf()", + start = 0, + stop = 10)) + var success = false val t = new Thread("test") { override def run(): Unit = { - val e = intercept[AnalysisException] { - range(1).selectExpr("not_a_udf()") - } - assert(e.getMessage.contains("Undefined function")) - assert(e.getMessage.contains("not_a_udf")) + checkErrorMatchPVals( + exception = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + }, + errorClass = "_LEGACY_ERROR_TEMP_1242", + sqlState = None, + parameters = Map( + "rawName" -> "not_a_udf", + "fullName" -> "spark_catalog.[a-z]+.not_a_udf"), + context = ExpectedContext( + fragment = "not_a_udf()", + start = 0, + stop = 10)) success = true } } @@ -1399,50 +1465,129 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd // since they modify /clear stuff. 
test("role management commands are not supported") { - assertUnsupportedFeature { sql("CREATE ROLE my_role") } - assertUnsupportedFeature { sql("DROP ROLE my_role") } - assertUnsupportedFeature { sql("SHOW CURRENT ROLES") } - assertUnsupportedFeature { sql("SHOW ROLES") } - assertUnsupportedFeature { sql("SHOW GRANT") } - assertUnsupportedFeature { sql("SHOW ROLE GRANT USER my_principal") } - assertUnsupportedFeature { sql("SHOW PRINCIPALS my_role") } - assertUnsupportedFeature { sql("SET ROLE my_role") } - assertUnsupportedFeature { sql("GRANT my_role TO USER my_user") } - assertUnsupportedFeature { sql("GRANT ALL ON my_table TO USER my_user") } - assertUnsupportedFeature { sql("REVOKE my_role FROM USER my_user") } - assertUnsupportedFeature { sql("REVOKE ALL ON my_table FROM USER my_user") } + assertUnsupportedFeature( + sql("CREATE ROLE my_role"), + "CREATE ROLE", + ExpectedContext(fragment = "CREATE ROLE my_role", start = 0, stop = 18)) + assertUnsupportedFeature( + sql("DROP ROLE my_role"), + "DROP ROLE", + ExpectedContext(fragment = "DROP ROLE my_role", start = 0, stop = 16)) + assertUnsupportedFeature( + sql("SHOW CURRENT ROLES"), + "SHOW CURRENT ROLES", + ExpectedContext(fragment = "SHOW CURRENT ROLES", start = 0, stop = 17)) + assertUnsupportedFeature( + sql("SHOW ROLES"), + "SHOW ROLES", + ExpectedContext(fragment = "SHOW ROLES", start = 0, stop = 9)) + assertUnsupportedFeature( + sql("SHOW GRANT"), + "SHOW GRANT", + ExpectedContext(fragment = "SHOW GRANT", start = 0, stop = 9)) + assertUnsupportedFeature( + sql("SHOW ROLE GRANT USER my_principal"), + "SHOW ROLE GRANT", + ExpectedContext(fragment = "SHOW ROLE GRANT USER my_principal", start = 0, stop = 32)) + assertUnsupportedFeature( + sql("SHOW PRINCIPALS my_role"), + "SHOW PRINCIPALS", + ExpectedContext(fragment = "SHOW PRINCIPALS my_role", start = 0, stop = 22)) + assertUnsupportedFeature( + sql("SET ROLE my_role"), + "SET ROLE", + ExpectedContext(fragment = "SET ROLE my_role", start = 0, stop = 15)) + assertUnsupportedFeature( + sql("GRANT my_role TO USER my_user"), + "GRANT", + ExpectedContext(fragment = "GRANT my_role TO USER my_user", start = 0, stop = 28)) + assertUnsupportedFeature( + sql("GRANT ALL ON my_table TO USER my_user"), + "GRANT", + ExpectedContext(fragment = "GRANT ALL ON my_table TO USER my_user", start = 0, stop = 36)) + assertUnsupportedFeature( + sql("REVOKE my_role FROM USER my_user"), + "REVOKE", + ExpectedContext(fragment = "REVOKE my_role FROM USER my_user", start = 0, stop = 31)) + assertUnsupportedFeature( + sql("REVOKE ALL ON my_table FROM USER my_user"), + "REVOKE", + ExpectedContext(fragment = "REVOKE ALL ON my_table FROM USER my_user", start = 0, stop = 39)) } test("import/export commands are not supported") { - assertUnsupportedFeature { sql("IMPORT TABLE my_table FROM 'my_path'") } - assertUnsupportedFeature { sql("EXPORT TABLE my_table TO 'my_path'") } + assertUnsupportedFeature( + sql("IMPORT TABLE my_table FROM 'my_path'"), + "IMPORT TABLE", + ExpectedContext(fragment = "IMPORT TABLE my_table FROM 'my_path'", start = 0, stop = 35)) + assertUnsupportedFeature( + sql("EXPORT TABLE my_table TO 'my_path'"), + "EXPORT TABLE", + ExpectedContext(fragment = "EXPORT TABLE my_table TO 'my_path'", start = 0, stop = 33)) } test("some show commands are not supported") { - assertUnsupportedFeature { sql("SHOW COMPACTIONS") } - assertUnsupportedFeature { sql("SHOW TRANSACTIONS") } - assertUnsupportedFeature { sql("SHOW INDEXES ON my_table") } - assertUnsupportedFeature { sql("SHOW LOCKS my_table") } + 
assertUnsupportedFeature( + sql("SHOW COMPACTIONS"), + "SHOW COMPACTIONS", + ExpectedContext(fragment = "SHOW COMPACTIONS", start = 0, stop = 15)) + assertUnsupportedFeature( + sql("SHOW TRANSACTIONS"), + "SHOW TRANSACTIONS", + ExpectedContext(fragment = "SHOW TRANSACTIONS", start = 0, stop = 16)) + assertUnsupportedFeature( + sql("SHOW INDEXES ON my_table"), + "SHOW INDEXES", + ExpectedContext(fragment = "SHOW INDEXES ON my_table", start = 0, stop = 23)) + assertUnsupportedFeature( + sql("SHOW LOCKS my_table"), + "SHOW LOCKS", + ExpectedContext(fragment = "SHOW LOCKS my_table", start = 0, stop = 18)) } test("lock/unlock table and database commands are not supported") { - assertUnsupportedFeature { sql("LOCK TABLE my_table SHARED") } - assertUnsupportedFeature { sql("UNLOCK TABLE my_table") } - assertUnsupportedFeature { sql("LOCK DATABASE my_db SHARED") } - assertUnsupportedFeature { sql("UNLOCK DATABASE my_db") } + assertUnsupportedFeature( + sql("LOCK TABLE my_table SHARED"), + "LOCK TABLE", + ExpectedContext(fragment = "LOCK TABLE my_table SHARED", start = 0, stop = 25)) + assertUnsupportedFeature( + sql("UNLOCK TABLE my_table"), + "UNLOCK TABLE", + ExpectedContext(fragment = "UNLOCK TABLE my_table", start = 0, stop = 20)) + assertUnsupportedFeature( + sql("LOCK DATABASE my_db SHARED"), + "LOCK DATABASE", + ExpectedContext(fragment = "LOCK DATABASE my_db SHARED", start = 0, stop = 25)) + assertUnsupportedFeature( + sql("UNLOCK DATABASE my_db"), + "UNLOCK DATABASE", + ExpectedContext(fragment = "UNLOCK DATABASE my_db", start = 0, stop = 20)) } test("alter index command is not supported") { - assertUnsupportedFeature { sql("ALTER INDEX my_index ON my_table REBUILD")} - assertUnsupportedFeature { - sql("ALTER INDEX my_index ON my_table set IDXPROPERTIES (\"prop1\"=\"val1_new\")")} + val sql1 = "ALTER INDEX my_index ON my_table REBUILD" + assertUnsupportedFeature( + sql(sql1), + "ALTER INDEX", + ExpectedContext(fragment = sql1, start = 0, stop = 39)) + val sql2 = "ALTER INDEX my_index ON my_table set IDXPROPERTIES (\"prop1\"=\"val1_new\")" + assertUnsupportedFeature( + sql(sql2), + "ALTER INDEX", + ExpectedContext(fragment = sql2, start = 0, stop = 70)) } test("create/drop macro commands are not supported") { - assertUnsupportedFeature { - sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))") - } - assertUnsupportedFeature { sql("DROP TEMPORARY MACRO SIGMOID") } + val sql1 = "CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))" + assertUnsupportedFeature( + sql(sql1), + "CREATE TEMPORARY MACRO", + ExpectedContext(fragment = sql1, start = 0, stop = 62)) + val sql2 = "DROP TEMPORARY MACRO SIGMOID" + assertUnsupportedFeature( + sql(sql2), + "DROP TEMPORARY MACRO", + ExpectedContext(fragment = sql2, start = 0, stop = 27)) } test("dynamic partitioning is allowed when hive.exec.dynamic.partition.mode is nonstrict") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 7eeff8116490..a8745b2946bb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1744,17 +1744,21 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi test("SPARK-14981: DESC not supported for sorting columns") { withTable("t") { - val cause = intercept[ParseException] { - sql( - """CREATE TABLE t USING PARQUET - 
|OPTIONS (PATH '/path/to/file') - |CLUSTERED BY (a) SORTED BY (b DESC) INTO 2 BUCKETS - |AS SELECT 1 AS a, 2 AS b - """.stripMargin - ) - } - - assert(cause.getMessage.contains("Column ordering must be ASC, was 'DESC'")) + checkError( + exception = intercept[ParseException] { + sql( + """CREATE TABLE t USING PARQUET + |OPTIONS (PATH '/path/to/file') + |CLUSTERED BY (a) SORTED BY (b DESC) INTO 2 BUCKETS + |AS SELECT 1 AS a, 2 AS b + """.stripMargin) + }, + errorClass = "_LEGACY_ERROR_TEMP_0035", + parameters = Map("message" -> "Column ordering must be ASC, was 'DESC'"), + context = ExpectedContext( + fragment = "CLUSTERED BY (a) SORTED BY (b DESC) INTO 2 BUCKETS", + start = 60, + stop = 109)) } }
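The hunks above repeatedly swap message-substring assertions for `checkError`, which pins a thrown `SparkThrowable` to a stable error class and its message parameters instead of free-form text. Below is a minimal sketch of that pattern, assuming Spark's test utilities (`QueryTest`, `SharedSparkSession`) are on the classpath; the suite name and the missing-table query are illustrative only and are not part of the change above.

import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.test.SharedSparkSession

// Hypothetical example suite; not part of the diff above.
class ErrorClassAssertionExampleSuite extends QueryTest with SharedSparkSession {
  test("assert on a stable error class rather than a message substring") {
    checkError(
      // The query references a table that does not exist, so analysis throws.
      exception = intercept[AnalysisException] {
        sql("SELECT * FROM table_that_does_not_exist")
      },
      // Match the error class and its parameters rather than the rendered message.
      errorClass = "TABLE_OR_VIEW_NOT_FOUND",
      parameters = Map("relationName" -> "`table_that_does_not_exist`"))
  }
}

Asserting on the error class and parameters keeps the test valid even if the human-readable message template is reworded later, which is the motivation behind the conversions in these suites.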