From 12e48527846d993a78b159fbba3e900a4feb7b55 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 14 Sep 2022 09:28:04 -0700 Subject: [PATCH 01/11] [SPARK-40423][K8S][TESTS] Add explicit YuniKorn queue submission test coverage ### What changes were proposed in this pull request? This PR aims to add explicit Yunikorn queue submission test coverage instead of implicit assignment by admission controller. ### Why are the changes needed? - To provide a proper test coverage. - To prevent the side effect of YuniKorn admission controller which overrides all Spark's scheduler settings by default (if we do not edit the rule explicitly). This breaks Apache Spark's default scheduler K8s IT test coverage. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually run the CI and check the YuniKorn queue UI. ``` $ build/sbt -Psparkr -Pkubernetes -Pkubernetes-integration-tests -Dspark.kubernetes.test.deployMode=docker-desktop "kubernetes-integration-tests/test" -Dtest.exclude.tags=minikube,local,decom -Dtest.default.exclude.tags= ``` Screen Shot 2022-09-14 at 2 07 38 AM Closes #37877 from dongjoon-hyun/SPARK-40423. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- docs/running-on-kubernetes.md | 5 +++-- .../spark/deploy/k8s/integrationtest/YuniKornSuite.scala | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index ad02f04bed90e..028a482d51c29 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -1824,8 +1824,7 @@ Install Apache YuniKorn: ```bash helm repo add yunikorn https://apache.github.io/yunikorn-release helm repo update -kubectl create namespace yunikorn -helm install yunikorn yunikorn/yunikorn --namespace yunikorn --version 1.1.0 +helm install yunikorn yunikorn/yunikorn --namespace yunikorn --version 1.1.0 --create-namespace --set embedAdmissionController=false ``` The above steps will install YuniKorn v1.1.0 on an existing Kubernetes cluster. @@ -1836,6 +1835,8 @@ Submit Spark jobs with the following extra options: ```bash --conf spark.kubernetes.scheduler.name=yunikorn +--conf spark.kubernetes.driver.label.queue=root.default +--conf spark.kubernetes.executor.label.queue=root.default --conf spark.kubernetes.driver.annotation.yunikorn.apache.org/app-id={{APP_ID}} --conf spark.kubernetes.executor.annotation.yunikorn.apache.org/app-id={{APP_ID}} ``` diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/YuniKornSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/YuniKornSuite.scala index 5a3c063efa14b..0dfb88b259e21 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/YuniKornSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/YuniKornSuite.scala @@ -21,8 +21,11 @@ class YuniKornSuite extends KubernetesSuite { override protected def setUpTest(): Unit = { super.setUpTest() + val namespace = sparkAppConf.get("spark.kubernetes.namespace") sparkAppConf .set("spark.kubernetes.scheduler.name", "yunikorn") + .set("spark.kubernetes.driver.label.queue", "root." + namespace) + .set("spark.kubernetes.executor.label.queue", "root." 
+ namespace) .set("spark.kubernetes.driver.annotation.yunikorn.apache.org/app-id", "{{APP_ID}}") .set("spark.kubernetes.executor.annotation.yunikorn.apache.org/app-id", "{{APP_ID}}") } From 40590e6d911ba0615edab445a56fd98ff620afea Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 15 Sep 2022 10:42:46 +0900 Subject: [PATCH 02/11] [SPARK-40397][BUILD] Upgrade `org.scalatestplus:selenium` to 3.12.13 ### What changes were proposed in this pull request? The main change of this pr as follows: - Upgrade `org.scalatestplus:selenium` from `org.scalatestplus:selenium-3-141:3.2.10.0` to `org.scalatestplus:selenium-4-2:3.2.13.0` and upgrade selenium-java from `3.141.59` to `4.2.2`, `htmlunit-driver` from `2.62.0` to `3.62.0` - okio upgrade from `1.14.0` to `1.15.0` due to both selenium-java and kubernetes-client depends on okio 1.15.0 and maven's nearby choice has also changed from 1.14.0 to 1.15.0 ### Why are the changes needed? Use the same version as other `org.scalatestplus` series dependencies, the release notes as follows: - https://github.com/scalatest/scalatestplus-selenium/releases/tag/release-3.2.11.0-for-selenium-4.1 - https://github.com/scalatest/scalatestplus-selenium/releases/tag/release-3.2.12.0-for-selenium-4.1 - https://github.com/scalatest/scalatestplus-selenium/releases/tag/release-3.2.12.1-for-selenium-4.1 - https://github.com/scalatest/scalatestplus-selenium/releases/tag/release-3.2.13.0-for-selenium-4.2 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - Manual test: - ChromeUISeleniumSuite ``` build/sbt -Dguava.version=31.1-jre -Dspark.test.webdriver.chrome.driver=/path/to/chromedriver -Dtest.default.exclude.tags="" -Phive -Phive-thriftserver "core/testOnly org.apache.spark.ui.ChromeUISeleniumSuite" ``` ``` [info] ChromeUISeleniumSuite: Starting ChromeDriver 105.0.5195.52 (412c95e518836d8a7d97250d62b29c2ae6a26a85-refs/branch-heads/5195{#853}) on port 53917 Only local connections are allowed. Please see https://chromedriver.chromium.org/security-considerations for suggestions on keeping ChromeDriver safe. ChromeDriver was started successfully. [info] - SPARK-31534: text for tooltip should be escaped (4 seconds, 447 milliseconds) [info] - SPARK-31882: Link URL for Stage DAGs should not depend on paged table. (841 milliseconds) [info] - SPARK-31886: Color barrier execution mode RDD correctly (297 milliseconds) [info] - Search text for paged tables should not be saved (1 second, 676 milliseconds) [info] Run completed in 11 seconds, 819 milliseconds. [info] Total number of tests run: 4 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. [success] Total time: 25 s, completed 2022-9-14 20:12:28 ``` - ChromeUIHistoryServerSuite ``` build/sbt -Dguava.version=31.1-jre -Dspark.test.webdriver.chrome.driver=/path/to/chromedriver -Dtest.default.exclude.tags="" -Phive -Phive-thriftserver "core/testOnly org.apache.spark.deploy.history.ChromeUIHistoryServerSuite" ``` ``` [info] ChromeUIHistoryServerSuite: Starting ChromeDriver 105.0.5195.52 (412c95e518836d8a7d97250d62b29c2ae6a26a85-refs/branch-heads/5195{#853}) on port 58567 Only local connections are allowed. Please see https://chromedriver.chromium.org/security-considerations for suggestions on keeping ChromeDriver safe. ChromeDriver was started successfully. 
[info] - ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase) (2 seconds, 416 milliseconds) [info] Run completed in 8 seconds, 936 milliseconds. [info] Total number of tests run: 1 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 1, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. [success] Total time: 30 s, completed 2022-9-14 20:11:34 ``` Closes #37868 from LuciferYang/SPARK-40397. Authored-by: yangjie01 Signed-off-by: Kousuke Saruta --- dev/deps/spark-deps-hadoop-2-hive-2.3 | 2 +- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 18 +++++++++++------- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3 index faf86376e9488..88a77c74ec4fa 100644 --- a/dev/deps/spark-deps-hadoop-2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2-hive-2.3 @@ -217,7 +217,7 @@ netty-transport-native-unix-common/4.1.80.Final//netty-transport-native-unix-com netty-transport/4.1.80.Final//netty-transport-4.1.80.Final.jar objenesis/3.2//objenesis-3.2.jar okhttp/3.12.12//okhttp-3.12.12.jar -okio/1.14.0//okio-1.14.0.jar +okio/1.15.0//okio-1.15.0.jar opencsv/2.3//opencsv-2.3.jar orc-core/1.8.0/shaded-protobuf/orc-core-1.8.0-shaded-protobuf.jar orc-mapreduce/1.8.0/shaded-protobuf/orc-mapreduce-1.8.0-shaded-protobuf.jar diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 7a73deb019ac4..0201df37b87df 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -201,7 +201,7 @@ netty-transport-native-unix-common/4.1.80.Final//netty-transport-native-unix-com netty-transport/4.1.80.Final//netty-transport-4.1.80.Final.jar objenesis/3.2//objenesis-3.2.jar okhttp/3.12.12//okhttp-3.12.12.jar -okio/1.14.0//okio-1.14.0.jar +okio/1.15.0//okio-1.15.0.jar opencsv/2.3//opencsv-2.3.jar opentracing-api/0.33.0//opentracing-api-0.33.0.jar opentracing-noop/0.33.0//opentracing-noop-0.33.0.jar diff --git a/pom.xml b/pom.xml index 0350ff793d2dd..7bb803228fe17 100644 --- a/pom.xml +++ b/pom.xml @@ -195,7 +195,8 @@ 4.9.3 1.1 - 3.141.59 + 4.2.2 + 3.62.0 2.62.0 1.8 1.1.0 @@ -408,7 +409,7 @@ org.scalatestplus - selenium-3-141_${scala.binary.version} + selenium-4-2_${scala.binary.version} test @@ -693,9 +694,13 @@ com.google.guava guava + + com.google.auto.service + * + io.netty - netty + * net.bytebuddy @@ -706,7 +711,7 @@ org.seleniumhq.selenium htmlunit-driver - ${htmlunit.version} + ${htmlunit-driver.version} test @@ -1160,11 +1165,10 @@ 3.2.13.0 test - org.scalatestplus - selenium-3-141_${scala.binary.version} - 3.2.10.0 + selenium-4-2_${scala.binary.version} + 3.2.13.0 test From c134c7597d19df783c085bb79b6deb01e21c769a Mon Sep 17 00:00:00 2001 From: Yikun Jiang Date: Thu, 15 Sep 2022 10:47:25 +0900 Subject: [PATCH 03/11] [SPARK-40339][SPARK-40342][SPARK-40345][SPARK-40348][PS] Implement quantile in Rolling/RollingGroupby/Expanding/ExpandingGroupby ### What changes were proposed in this pull request? Implement quantile in Rolling/RollingGroupby/Expanding/ExpandingGroupby ### Why are the changes needed? 
Improve PS api coverage ```python >>> s = ps.Series([4, 3, 5, 2, 6]) >>> s.rolling(2).quantile(0.5) 0 NaN 1 3.0 2 3.0 3 2.0 4 2.0 dtype: float64 >>> s = ps.Series([2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5]) >>> s.groupby(s).rolling(3).quantile(0.5).sort_index() 2 0 NaN 1 NaN 3 2 NaN 3 NaN 4 3.0 4 5 NaN 6 NaN 7 4.0 8 4.0 5 9 NaN 10 NaN dtype: float64 >>> s = ps.Series([1, 2, 3, 4]) >>> s.expanding(2).quantile(0.5) 0 NaN 1 1.0 2 2.0 3 2.0 dtype: float64 >>> s = ps.Series([2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5]) >>> s.groupby(s).expanding(3).quantile(0.5).sort_index() 2 0 NaN 1 NaN 3 2 NaN 3 NaN 4 3.0 4 5 NaN 6 NaN 7 4.0 8 4.0 5 9 NaN 10 NaN dtype: float64 ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed Closes #37836 from Yikun/SPARK-40339. Authored-by: Yikun Jiang Signed-off-by: Hyukjin Kwon --- python/pyspark/pandas/groupby.py | 2 +- python/pyspark/pandas/missing/window.py | 4 - python/pyspark/pandas/tests/test_expanding.py | 8 + python/pyspark/pandas/tests/test_rolling.py | 8 + python/pyspark/pandas/window.py | 309 ++++++++++++++++++ 5 files changed, 326 insertions(+), 5 deletions(-) diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 2e2e5540bd4bd..05dd6d15eec6c 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -607,7 +607,7 @@ def quantile(self, q: float = 0.5, accuracy: int = 10000) -> FrameLike: ------- `quantile` in pandas-on-Spark are using distributed percentile approximation algorithm unlike pandas, the result might different with pandas, also - `interpolation` parameters are not supported yet. + `interpolation` parameter is not supported yet. See Also -------- diff --git a/python/pyspark/pandas/missing/window.py b/python/pyspark/pandas/missing/window.py index 31684e43ccf25..a6d423d08f1c7 100644 --- a/python/pyspark/pandas/missing/window.py +++ b/python/pyspark/pandas/missing/window.py @@ -82,7 +82,6 @@ class MissingPandasLikeExpanding: corr = _unsupported_function_expanding("corr") cov = _unsupported_function_expanding("cov") median = _unsupported_function_expanding("median") - quantile = _unsupported_function_expanding("quantile") validate = _unsupported_function_expanding("validate") exclusions = _unsupported_property_expanding("exclusions") @@ -101,7 +100,6 @@ class MissingPandasLikeRolling: corr = _unsupported_function_rolling("corr") cov = _unsupported_function_rolling("cov") median = _unsupported_function_rolling("median") - quantile = _unsupported_function_rolling("quantile") validate = _unsupported_function_rolling("validate") exclusions = _unsupported_property_rolling("exclusions") @@ -120,7 +118,6 @@ class MissingPandasLikeExpandingGroupby: corr = _unsupported_function_expanding("corr") cov = _unsupported_function_expanding("cov") median = _unsupported_function_expanding("median") - quantile = _unsupported_function_expanding("quantile") validate = _unsupported_function_expanding("validate") exclusions = _unsupported_property_expanding("exclusions") @@ -139,7 +136,6 @@ class MissingPandasLikeRollingGroupby: corr = _unsupported_function_rolling("corr") cov = _unsupported_function_rolling("cov") median = _unsupported_function_rolling("median") - quantile = _unsupported_function_rolling("quantile") validate = _unsupported_function_rolling("validate") exclusions = _unsupported_property_rolling("exclusions") diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/test_expanding.py index aeb0e9f297bce..77ced41eb8cb0 100644 --- 
a/python/pyspark/pandas/tests/test_expanding.py +++ b/python/pyspark/pandas/tests/test_expanding.py @@ -82,6 +82,9 @@ def test_expanding_max(self): def test_expanding_mean(self): self._test_expanding_func("mean") + def test_expanding_quantile(self): + self._test_expanding_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")) + def test_expanding_sum(self): self._test_expanding_func("sum") @@ -212,6 +215,11 @@ def test_groupby_expanding_max(self): def test_groupby_expanding_mean(self): self._test_groupby_expanding_func("mean") + def test_groupby_expanding_quantile(self): + self._test_groupby_expanding_func( + lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower") + ) + def test_groupby_expanding_sum(self): self._test_groupby_expanding_func("sum") diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py index 3f92eba79ce99..be21bf16d409b 100644 --- a/python/pyspark/pandas/tests/test_rolling.py +++ b/python/pyspark/pandas/tests/test_rolling.py @@ -79,6 +79,9 @@ def test_rolling_max(self): def test_rolling_mean(self): self._test_rolling_func("mean") + def test_rolling_quantile(self): + self._test_rolling_func(lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower")) + def test_rolling_sum(self): self._test_rolling_func("sum") @@ -212,6 +215,11 @@ def test_groupby_rolling_max(self): def test_groupby_rolling_mean(self): self._test_groupby_rolling_func("mean") + def test_groupby_rolling_quantile(self): + self._test_groupby_rolling_func( + lambda x: x.quantile(0.5), lambda x: x.quantile(0.5, "lower") + ) + def test_groupby_rolling_sum(self): self._test_groupby_rolling_func("sum") diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py index 2808f72fd3c12..274000cbb2cd0 100644 --- a/python/pyspark/pandas/window.py +++ b/python/pyspark/pandas/window.py @@ -40,6 +40,9 @@ from pyspark.pandas.spark import functions as SF from pyspark.pandas.utils import scol_for from pyspark.sql.column import Column +from pyspark.sql.types import ( + DoubleType, +) from pyspark.sql.window import WindowSpec @@ -101,6 +104,15 @@ def mean(scol: Column) -> Column: return self._apply_as_series_or_frame(mean) + def quantile(self, q: float, accuracy: int = 10000) -> FrameLike: + def quantile(scol: Column) -> Column: + return F.when( + F.row_number().over(self._unbounded_window) >= self._min_periods, + F.percentile_approx(scol.cast(DoubleType()), q, accuracy).over(self._window), + ).otherwise(SF.lit(None)) + + return self._apply_as_series_or_frame(quantile) + def std(self) -> FrameLike: def std(scol: Column) -> Column: return F.when( @@ -561,6 +573,101 @@ def mean(self) -> FrameLike: """ return super().mean() + def quantile(self, quantile: float, accuracy: int = 10000) -> FrameLike: + """ + Calculate the rolling quantile of the values. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + quantile : float + Value between 0 and 1 providing the quantile to compute. + accuracy : int, optional + Default accuracy of approximation. Larger value means better accuracy. + The relative error can be deduced by 1.0 / accuracy. + This is a panda-on-Spark specific parameter. + + Returns + ------- + Series or DataFrame + Returned object type is determined by the caller of the rolling + calculation. + + Notes + ----- + `quantile` in pandas-on-Spark are using distributed percentile approximation + algorithm unlike pandas, the result might different with pandas, also `interpolation` + parameter is not supported yet. 
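To make these Notes concrete, here is a minimal standalone PySpark sketch of the expression this patch builds in `RollingAndExpanding.quantile`: `percentile_approx` evaluated over the rolling window, masked to null until `min_periods` rows have been seen. The five-row DataFrame, the column names, and the hard-coded two-row window are illustrative assumptions, not part of the patch; the ordered-but-unpartitioned window is also why the note that follows warns about all data moving into a single partition.

```python
# Standalone sketch with assumed names/data -- not the patch's internals.
from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(0, 4.0), (1, 3.0), (2, 5.0), (3, 2.0), (4, 6.0)], "idx INT, value DOUBLE"
)

min_periods = 2
# two-row rolling window: current row plus one preceding row
rolling = Window.orderBy("idx").rowsBetween(-1, Window.currentRow)
# running window used only to count how many rows have been seen so far
counting = Window.orderBy("idx").rowsBetween(Window.unboundedPreceding, Window.currentRow)

df.withColumn(
    "rolling_q50",
    F.when(
        F.row_number().over(counting) >= min_periods,
        F.percentile_approx("value", 0.5, 10000).over(rolling),
    ).otherwise(F.lit(None)),
).show()
# idx 0 -> null, then 3.0, 3.0, 2.0, 2.0 -- the same values as
# s.rolling(2).quantile(0.5) in the doctest above, up to the
# approximation error of percentile_approx.
```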
+ + the current implementation of this API uses Spark's Window without + specifying partition specification. This leads to move all data into + single partition in single machine and could cause serious + performance degradation. Avoid this method against very large dataset. + + See Also + -------- + pyspark.pandas.Series.rolling : Calling rolling with Series data. + pyspark.pandas.DataFrame.rolling : Calling rolling with DataFrames. + pyspark.pandas.Series.quantile : Aggregating quantile for Series. + pyspark.pandas.DataFrame.quantile : Aggregating quantile for DataFrame. + + Examples + -------- + >>> s = ps.Series([4, 3, 5, 2, 6]) + >>> s + 0 4 + 1 3 + 2 5 + 3 2 + 4 6 + dtype: int64 + + >>> s.rolling(2).quantile(0.5) + 0 NaN + 1 3.0 + 2 3.0 + 3 2.0 + 4 2.0 + dtype: float64 + + >>> s.rolling(3).quantile(0.5) + 0 NaN + 1 NaN + 2 4.0 + 3 3.0 + 4 5.0 + dtype: float64 + + For DataFrame, each rolling quantile is computed column-wise. + + >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) + >>> df + A B + 0 4 16 + 1 3 9 + 2 5 25 + 3 2 4 + 4 6 36 + + >>> df.rolling(2).quantile(0.5) + A B + 0 NaN NaN + 1 3.0 9.0 + 2 3.0 9.0 + 3 2.0 4.0 + 4 2.0 4.0 + + >>> df.rolling(3).quantile(0.5) + A B + 0 NaN NaN + 1 NaN NaN + 2 4.0 16.0 + 3 3.0 9.0 + 4 5.0 25.0 + """ + return super().quantile(quantile, accuracy) + def std(self) -> FrameLike: """ Calculate rolling standard deviation. @@ -1136,6 +1243,77 @@ def mean(self) -> FrameLike: """ return super().mean() + def quantile(self, quantile: float, accuracy: int = 10000) -> FrameLike: + """ + Calculate rolling quantile. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + quantile : float + Value between 0 and 1 providing the quantile to compute. + accuracy : int, optional + Default accuracy of approximation. Larger value means better accuracy. + The relative error can be deduced by 1.0 / accuracy. + This is a panda-on-Spark specific parameter. + + Returns + ------- + Series or DataFrame + Returned object type is determined by the caller of the rolling + calculation. + + Notes + ----- + `quantile` in pandas-on-Spark are using distributed percentile approximation + algorithm unlike pandas, the result might different with pandas, also `interpolation` + parameter is not supported yet. + + See Also + -------- + pyspark.pandas.Series.rolling : Calling rolling with Series data. + pyspark.pandas.DataFrame.rolling : Calling rolling with DataFrames. + pyspark.pandas.Series.quantile : Aggregating quantile for Series. + pyspark.pandas.DataFrame.quantile : Aggregating quantile for DataFrame. + + Examples + -------- + >>> s = ps.Series([2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5]) + >>> s.groupby(s).rolling(3).quantile(0.5).sort_index() + 2 0 NaN + 1 NaN + 3 2 NaN + 3 NaN + 4 3.0 + 4 5 NaN + 6 NaN + 7 4.0 + 8 4.0 + 5 9 NaN + 10 NaN + dtype: float64 + + For DataFrame, each rolling quantile is computed column-wise. + + >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) + >>> df.groupby(df.A).rolling(2).quantile(0.5).sort_index() + B + A + 2 0 NaN + 1 4.0 + 3 2 NaN + 3 9.0 + 4 9.0 + 4 5 NaN + 6 16.0 + 7 16.0 + 8 16.0 + 5 9 NaN + 10 25.0 + """ + return super().quantile(quantile, accuracy) + def std(self) -> FrameLike: """ Calculate rolling standard deviation. @@ -1483,6 +1661,66 @@ def mean(self) -> FrameLike: """ return super().mean() + def quantile(self, quantile: float, accuracy: int = 10000) -> FrameLike: + """ + Calculate the expanding quantile of the values. 
+ + Returns + ------- + Series or DataFrame + Returned object type is determined by the caller of the expanding + calculation. + + Parameters + ---------- + quantile : float + Value between 0 and 1 providing the quantile to compute. + accuracy : int, optional + Default accuracy of approximation. Larger value means better accuracy. + The relative error can be deduced by 1.0 / accuracy. + This is a panda-on-Spark specific parameter. + + Notes + ----- + `quantile` in pandas-on-Spark are using distributed percentile approximation + algorithm unlike pandas, the result might different with pandas (the result is + similar to the interpolation set to `lower`), also `interpolation` parameter is + not supported yet. + + the current implementation of this API uses Spark's Window without + specifying partition specification. This leads to move all data into + single partition in single machine and could cause serious + performance degradation. Avoid this method against very large dataset. + + See Also + -------- + pyspark.pandas.Series.expanding : Calling expanding with Series data. + pyspark.pandas.DataFrame.expanding : Calling expanding with DataFrames. + pyspark.pandas.Series.quantile : Aggregating quantile for Series. + pyspark.pandas.DataFrame.quantile : Aggregating quantile for DataFrame. + + Examples + -------- + The below examples will show expanding quantile calculations with window sizes of + two and three, respectively. + + >>> s = ps.Series([1, 2, 3, 4]) + >>> s.expanding(2).quantile(0.5) + 0 NaN + 1 1.0 + 2 2.0 + 3 2.0 + dtype: float64 + + >>> s.expanding(3).quantile(0.5) + 0 NaN + 1 NaN + 2 2.0 + 3 2.0 + dtype: float64 + """ + return super().quantile(quantile, accuracy) + def std(self) -> FrameLike: """ Calculate expanding standard deviation. @@ -1978,6 +2216,77 @@ def mean(self) -> FrameLike: """ return super().mean() + def quantile(self, quantile: float, accuracy: int = 10000) -> FrameLike: + """ + Calculate the expanding quantile of the values. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + quantile : float + Value between 0 and 1 providing the quantile to compute. + accuracy : int, optional + Default accuracy of approximation. Larger value means better accuracy. + The relative error can be deduced by 1.0 / accuracy. + This is a panda-on-Spark specific parameter. + + Returns + ------- + Series or DataFrame + Returned object type is determined by the caller of the expanding + calculation. + + Notes + ----- + `quantile` in pandas-on-Spark are using distributed percentile approximation + algorithm unlike pandas, the result might different with pandas, also `interpolation` + parameter is not supported yet. + + See Also + -------- + pyspark.pandas.Series.expanding : Calling expanding with Series data. + pyspark.pandas.DataFrame.expanding : Calling expanding with DataFrames. + pyspark.pandas.Series.quantile : Aggregating quantile for Series. + pyspark.pandas.DataFrame.quantile : Aggregating quantile for DataFrame. + + Examples + -------- + >>> s = ps.Series([2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5]) + >>> s.groupby(s).expanding(3).quantile(0.5).sort_index() + 2 0 NaN + 1 NaN + 3 2 NaN + 3 NaN + 4 3.0 + 4 5 NaN + 6 NaN + 7 4.0 + 8 4.0 + 5 9 NaN + 10 NaN + dtype: float64 + + For DataFrame, each expanding quantile is computed column-wise. 
+ + >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) + >>> df.groupby(df.A).expanding(2).quantile(0.5).sort_index() + B + A + 2 0 NaN + 1 4.0 + 3 2 NaN + 3 9.0 + 4 9.0 + 4 5 NaN + 6 16.0 + 7 16.0 + 8 16.0 + 5 9 NaN + 10 25.0 + """ + return super().quantile(quantile, accuracy) + def std(self) -> FrameLike: """ Calculate expanding standard deviation. From ea6857abff8e93ebd7dfcb536976278d5d8e10d7 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Thu, 15 Sep 2022 10:48:48 +0900 Subject: [PATCH 04/11] [SPARK-40426][SQL] Return a map from SparkThrowable.getMessageParameters ### What changes were proposed in this pull request? In the PR, I propose to change the `SparkThrowable` interface: 1. Return a map of parameters names to their values from `getMessageParameters()` 2. Remove `getParameterNames()` because the names can be retrieved from `getMessageParameters()`. ### Why are the changes needed? To simplifies implementation and improve code maintenance. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By running affected test suites: ``` $ build/sbt "core/testOnly *SparkThrowableSuite" $ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite" ``` Closes #37871 from MaxGekk/getMessageParameters-map. Authored-by: Max Gekk Signed-off-by: Hyukjin Kwon --- .../java/org/apache/spark/SparkThrowable.java | 12 ++-- .../spark/memory/SparkOutOfMemoryError.java | 4 +- .../org/apache/spark/SparkException.scala | 68 +++++-------------- .../apache/spark/SparkThrowableHelper.scala | 22 +----- .../org/apache/spark/SparkFunSuite.scala | 3 +- .../apache/spark/SparkThrowableSuite.scala | 2 +- .../apache/spark/sql/AnalysisException.scala | 8 +-- 7 files changed, 34 insertions(+), 85 deletions(-) diff --git a/core/src/main/java/org/apache/spark/SparkThrowable.java b/core/src/main/java/org/apache/spark/SparkThrowable.java index 52fd64135a95c..0fe3b51d3772f 100644 --- a/core/src/main/java/org/apache/spark/SparkThrowable.java +++ b/core/src/main/java/org/apache/spark/SparkThrowable.java @@ -19,6 +19,9 @@ import org.apache.spark.annotation.Evolving; +import java.util.HashMap; +import java.util.Map; + /** * Interface mixed into Throwables thrown from Spark. * @@ -51,13 +54,8 @@ default boolean isInternalError() { return SparkThrowableHelper.isInternalError(this.getErrorClass()); } - default String[] getMessageParameters() { - return new String[]{}; - } - - // Returns a string array of all parameters that need to be passed to this error message. 
- default String[] getParameterNames() { - return SparkThrowableHelper.getParameterNames(this.getErrorClass(), this.getErrorSubClass()); + default Map getMessageParameters() { + return new HashMap(); } default QueryContext[] getQueryContext() { return new QueryContext[0]; } diff --git a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java index cfbf2e574787d..0916d9e6fe307 100644 --- a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java +++ b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java @@ -47,8 +47,8 @@ public SparkOutOfMemoryError(String errorClass, Map messageParam } @Override - public String[] getMessageParameters() { - return SparkThrowableHelper.getMessageParameters(errorClass, null, messageParameters); + public Map getMessageParameters() { + return messageParameters; } @Override diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index dba6ef9347ff0..ff59cfb3455f1 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -22,6 +22,8 @@ import java.sql.{SQLException, SQLFeatureNotSupportedException} import java.time.DateTimeException import java.util.ConcurrentModificationException +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.FileAlreadyExistsException class SparkException( @@ -86,11 +88,7 @@ class SparkException( errorSubClass = Some(errorSubClass), messageParameters = messageParameters) - override def getMessageParameters: Array[String] = { - errorClass.map { ec => - SparkThrowableHelper.getMessageParameters(ec, errorSubClass.orNull, messageParameters) - }.getOrElse(Array.empty) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass.orNull override def getErrorSubClass: String = errorSubClass.orNull @@ -146,9 +144,7 @@ private[spark] class SparkUpgradeException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters), cause) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull} @@ -166,9 +162,7 @@ private[spark] class SparkArithmeticException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull @@ -195,9 +189,7 @@ private[spark] class SparkUnsupportedOperationException( errorSubClass = Some(errorSubClass), messageParameters = messageParameters) - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def 
getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull @@ -215,9 +207,7 @@ private[spark] class SparkClassNotFoundException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters), cause) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull} @@ -234,9 +224,7 @@ private[spark] class SparkConcurrentModificationException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters), cause) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull} @@ -254,9 +242,7 @@ private[spark] class SparkDateTimeException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull @@ -274,9 +260,7 @@ private[spark] class SparkFileAlreadyExistsException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters)) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull} @@ -292,9 +276,7 @@ private[spark] class SparkFileNotFoundException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters)) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull} @@ -312,9 +294,7 @@ private[spark] class SparkNumberFormatException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull @@ -334,9 +314,7 @@ private[spark] class SparkIllegalArgumentException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { - 
override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull @@ -379,9 +357,7 @@ private[spark] class SparkRuntimeException( cause = null, context = Array.empty[QueryContext]) - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull @@ -399,9 +375,7 @@ private[spark] class SparkSecurityException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters)) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull @@ -420,9 +394,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException( SparkThrowableHelper.getMessage(errorClass, errorSubClass.orNull, messageParameters, summary)) with SparkThrowable { - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull @@ -446,9 +418,7 @@ private[spark] class SparkSQLException( errorSubClass = None, messageParameters = messageParameters) - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull @@ -474,9 +444,7 @@ private[spark] class SparkSQLFeatureNotSupportedException( errorSubClass = Some(errorSubClass), messageParameters = messageParameters) - override def getMessageParameters: Array[String] = { - SparkThrowableHelper.getMessageParameters(errorClass, errorSubClass.orNull, messageParameters) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass override def getErrorSubClass: String = errorSubClass.orNull diff --git a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index 81c5c6cb043b3..497fd91d77ac7 100644 --- a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -101,20 +101,6 @@ private[spark] object SparkThrowableHelper { parameterNames } - def getMessageParameters( - errorClass: String, - errorSubCLass: String, - params: Map[String, String]): Array[String] = { - getParameterNames(errorClass, errorSubCLass).map(params.getOrElse(_, "?")) - } - - def getMessageParameters( 
- errorClass: String, - errorSubCLass: String, - params: java.util.Map[String, String]): Array[String] = { - getParameterNames(errorClass, errorSubCLass).map(params.getOrDefault(_, "?")) - } - def getMessage( errorClass: String, errorSubClass: String, @@ -185,8 +171,6 @@ private[spark] object SparkThrowableHelper { } case MINIMAL | STANDARD => val errorClass = e.getErrorClass - assert(e.getParameterNames.size == e.getMessageParameters.size, - "Number of message parameter names and values must be the same") toJsonString { generator => val g = generator.useDefaultPrettyPrinter() g.writeStartObject() @@ -200,10 +184,10 @@ private[spark] object SparkThrowableHelper { } val sqlState = e.getSqlState if (sqlState != null) g.writeStringField("sqlState", sqlState) - val parameterNames = e.getParameterNames - if (!parameterNames.isEmpty) { + val messageParameters = e.getMessageParameters + if (!messageParameters.isEmpty) { g.writeObjectFieldStart("messageParameters") - (parameterNames zip e.getMessageParameters).sortBy(_._1).foreach { case (name, value) => + messageParameters.asScala.toSeq.sortBy(_._1).foreach { case (name, value) => g.writeStringField(name, value) } g.writeEndObject() diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index ade20dbff83ff..10ebbe76d6c74 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -23,6 +23,7 @@ import java.nio.file.{Files, Path} import java.util.{Locale, TimeZone} import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.commons.io.FileUtils @@ -305,7 +306,7 @@ abstract class SparkFunSuite assert(exception.getErrorSubClass === errorSubClass.get) } sqlState.foreach(state => assert(exception.getSqlState === state)) - val expectedParameters = (exception.getParameterNames zip exception.getMessageParameters).toMap + val expectedParameters = exception.getMessageParameters.asScala if (matchPVals == true) { assert(expectedParameters.size === parameters.size) expectedParameters.foreach( diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index cbf273dc5c857..6eb66cd8f18fd 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -161,7 +161,7 @@ class SparkThrowableSuite extends SparkFunSuite { getMessage("UNRESOLVED_COLUMN", "WITHOUT_SUGGESTION", Map.empty[String, String]) } assert(e.getErrorClass === "INTERNAL_ERROR") - assert(e.getMessageParameters.head.contains("Undefined an error message parameter")) + assert(e.getMessageParameters().get("message").contains("Undefined an error message parameter")) // Does not fail with too many args (expects 0 args) assert(getMessage("DIVIDE_BY_ZERO", null, Map("config" -> "foo", "a" -> "bar")) == diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index 4b27dfc00c2aa..9dbdaa3d59a2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper} 
import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -165,11 +167,7 @@ class AnalysisException protected[sql] ( message } - override def getMessageParameters: Array[String] = { - errorClass.map { ec => - SparkThrowableHelper.getMessageParameters(ec, errorSubClass.orNull, messageParameters) - }.getOrElse(Array.empty) - } + override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava override def getErrorClass: String = errorClass.orNull override def getErrorSubClass: String = errorSubClass.orNull From 1c46c87ddb1991230ddfee8f9f6205df0318e056 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 15 Sep 2022 09:52:50 +0800 Subject: [PATCH 05/11] [SPARK-40421][PS] Make `spearman` correlation in `DataFrame.corr` support missing values and `min_periods` ### What changes were proposed in this pull request? refactor `spearman` correlation in `DataFrame.corr` to: 1. support missing values; 2. add parameter min_periods; 3. enable arrow execution since no longer depend on VectorUDT; 4. support lazy evaluation; ### Why are the changes needed? to make its behavior same as Pandas ### Does this PR introduce _any_ user-facing change? yes, API change, new parameter supported ### How was this patch tested? added UT Closes #37874 from zhengruifeng/ps_df_spearman. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/pandas/frame.py | 389 ++++++++++++---------- python/pyspark/pandas/tests/test_stats.py | 66 +++- 2 files changed, 275 insertions(+), 180 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index cf14a5482660c..f6aead93f9eec 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -121,7 +121,6 @@ SPARK_INDEX_NAME_PATTERN, ) from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame -from pyspark.pandas.ml import corr from pyspark.pandas.typedef.typehints import ( as_spark_type, infer_return_type, @@ -1430,8 +1429,7 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D * spearman : Spearman rank correlation min_periods : int, optional Minimum number of observations required per pair of columns - to have a valid result. Currently only available for Pearson - correlation. + to have a valid result. .. versionadded:: 3.4.0 @@ -1462,8 +1460,6 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D There are behavior differences between pandas-on-Spark and pandas. * the `method` argument only accepts 'pearson', 'spearman' - * if the `method` is `spearman`, the data should not contain NaNs. - * if the `method` is `spearman`, `min_periods` argument is not supported. 
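The essence of the refactored `spearman` path that follows is that Spearman correlation is Pearson correlation applied to average ranks, where tied values share the mean of their row numbers. Below is a minimal standalone sketch of that idea on a toy two-column DataFrame; the data, the column names, and the simpler `partitionBy`-based tie handling are assumptions for illustration only. The patch itself derives the same average ranks with a `row_number`/`dense_rank` window trick so that nulls stay null, and it masks any column pair with fewer than `min_periods` non-null observations.

```python
# Standalone sketch with assumed toy data -- not the patch's code.
from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(2.0, 4.0), (2.0, 1.0), (3.0, 9.0), (4.0, 16.0), (5.0, 25.0)], "a DOUBLE, b DOUBLE"
)

ranked = df
for c in ["a", "b"]:
    # position of each row when sorted by this column (nulls would sort last)
    ranked = ranked.withColumn(
        f"rn_{c}", F.row_number().over(Window.orderBy(F.asc_nulls_last(c)))
    )
    # tied values share the mean of their row numbers -> average rank
    ranked = ranked.withColumn(f"rank_{c}", F.avg(f"rn_{c}").over(Window.partitionBy(c)))

# Pearson correlation of the ranks is the Spearman correlation of the raw values.
ranked.select(F.corr("rank_a", "rank_b").alias("spearman_a_b")).show()
```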
""" if method not in ["pearson", "spearman", "kendall"]: raise ValueError(f"Invalid method {method}") @@ -1471,194 +1467,251 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D raise NotImplementedError("method doesn't support kendall for now") if min_periods is not None and not isinstance(min_periods, int): raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}") - if min_periods is not None and method == "spearman": - raise NotImplementedError("min_periods doesn't support spearman for now") - - if method == "pearson": - min_periods = 1 if min_periods is None else min_periods - internal = self._internal.resolved_copy - numeric_labels = [ - label - for label in internal.column_labels - if isinstance(internal.spark_type_for(label), (NumericType, BooleanType)) - ] - numeric_scols: List[Column] = [ - internal.spark_column_for(label).cast("double") for label in numeric_labels - ] - numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels] - num_scols = len(numeric_scols) - sdf = internal.spark_frame - tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__") - tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__") - tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__") - tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__") - - # simple dataset - # +---+---+----+ - # | A| B| C| - # +---+---+----+ - # | 1| 2| 3.0| - # | 4| 1|null| - # +---+---+----+ - - pair_scols: List[Column] = [] - for i in range(0, num_scols): - for j in range(i, num_scols): - pair_scols.append( - F.struct( - F.lit(i).alias(tmp_index_1_col_name), - F.lit(j).alias(tmp_index_2_col_name), - numeric_scols[i].alias(tmp_value_1_col_name), - numeric_scols[j].alias(tmp_value_2_col_name), - ) + min_periods = 1 if min_periods is None else min_periods + internal = self._internal.resolved_copy + numeric_labels = [ + label + for label in internal.column_labels + if isinstance(internal.spark_type_for(label), (NumericType, BooleanType)) + ] + numeric_scols: List[Column] = [ + internal.spark_column_for(label).cast("double") for label in numeric_labels + ] + numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels] + num_scols = len(numeric_scols) + + sdf = internal.spark_frame + tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__") + tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__") + tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__") + tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__") + + # simple dataset + # +---+---+----+ + # | A| B| C| + # +---+---+----+ + # | 1| 2| 3.0| + # | 4| 1|null| + # +---+---+----+ + + pair_scols: List[Column] = [] + for i in range(0, num_scols): + for j in range(i, num_scols): + pair_scols.append( + F.struct( + F.lit(i).alias(tmp_index_1_col_name), + F.lit(j).alias(tmp_index_2_col_name), + numeric_scols[i].alias(tmp_value_1_col_name), + numeric_scols[j].alias(tmp_value_2_col_name), ) + ) - # +-------------------+-------------------+-------------------+-------------------+ - # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__| - # +-------------------+-------------------+-------------------+-------------------+ - # | 0| 0| 1.0| 1.0| - # | 0| 1| 1.0| 2.0| - # | 0| 2| 1.0| 3.0| - # | 1| 1| 2.0| 2.0| - # | 1| 2| 2.0| 3.0| - # | 2| 2| 3.0| 3.0| - # | 0| 0| 4.0| 4.0| - # | 0| 1| 4.0| 1.0| - # | 0| 2| 4.0| null| - 
# | 1| 1| 1.0| 1.0| - # | 1| 2| 1.0| null| - # | 2| 2| null| null| - # +-------------------+-------------------+-------------------+-------------------+ - tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__") - sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select( - F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name), - F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name), - F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}").alias(tmp_value_1_col_name), - F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}").alias(tmp_value_2_col_name), - ) + # +-------------------+-------------------+-------------------+-------------------+ + # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__| + # +-------------------+-------------------+-------------------+-------------------+ + # | 0| 0| 1.0| 1.0| + # | 0| 1| 1.0| 2.0| + # | 0| 2| 1.0| 3.0| + # | 1| 1| 2.0| 2.0| + # | 1| 2| 2.0| 3.0| + # | 2| 2| 3.0| 3.0| + # | 0| 0| 4.0| 4.0| + # | 0| 1| 4.0| 1.0| + # | 0| 2| null| null| + # | 1| 1| 1.0| 1.0| + # | 1| 2| null| null| + # | 2| 2| null| null| + # +-------------------+-------------------+-------------------+-------------------+ + tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__") + null_cond = F.isnull(F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}")) | F.isnull( + F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}") + ) + sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select( + F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name), + F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name), + F.when(null_cond, F.lit(None)) + .otherwise(F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}")) + .alias(tmp_value_1_col_name), + F.when(null_cond, F.lit(None)) + .otherwise(F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}")) + .alias(tmp_value_2_col_name), + ) - # +-------------------+-------------------+------------------------+-----------------+ - # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|__tmp_count_col__| - # +-------------------+-------------------+------------------------+-----------------+ - # | 2| 2| null| 1| - # | 1| 2| null| 1| - # | 1| 1| 1.0| 2| - # | 0| 0| 1.0| 2| - # | 0| 1| -1.0| 2| - # | 0| 2| null| 1| - # +-------------------+-------------------+------------------------+-----------------+ - tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__") - tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__") - sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg( - F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name), - F.count( - F.when( - F.col(tmp_value_1_col_name).isNotNull() - & F.col(tmp_value_2_col_name).isNotNull(), - 1, - ) - ).alias(tmp_count_col_name), - ) + # convert values to avg ranks for spearman correlation + if method == "spearman": + tmp_row_number_col_name = verify_temp_column_name(sdf, "__tmp_row_number_col__") + tmp_dense_rank_col_name = verify_temp_column_name(sdf, "__tmp_dense_rank_col__") + window = Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name) - # +-------------------+-------------------+------------------------+ - # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__| - # +-------------------+-------------------+------------------------+ - # | 2| 2| null| - # | 1| 2| null| - # | 2| 1| null| - # | 1| 1| 
1.0| - # | 0| 0| 1.0| - # | 0| 1| -1.0| - # | 1| 0| -1.0| - # | 0| 2| null| - # | 2| 0| null| - # +-------------------+-------------------+------------------------+ + # tmp_value_1_col_name: value -> avg rank + # for example: + # values: 3, 4, 5, 7, 7, 7, 9, 9, 10 + # avg ranks: 1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0 sdf = ( sdf.withColumn( - tmp_corr_col_name, - F.when( - F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name) - ).otherwise(F.lit(None)), + tmp_row_number_col_name, + F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))), ) .withColumn( - tmp_tuple_col_name, - F.explode( - F.when( - F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name), - F.lit([0]), - ).otherwise(F.lit([0, 1])) - ), + tmp_dense_rank_col_name, + F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))), ) - .select( - F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name)) - .otherwise(F.col(tmp_index_2_col_name)) - .alias(tmp_index_1_col_name), - F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name)) - .otherwise(F.col(tmp_index_1_col_name)) - .alias(tmp_index_2_col_name), - F.col(tmp_corr_col_name), + .withColumn( + tmp_value_1_col_name, + F.when(F.isnull(F.col(tmp_value_1_col_name)), F.lit(None)).otherwise( + F.avg(tmp_row_number_col_name).over( + window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0) + ) + ), ) ) - # +-------------------+--------------------+ - # |__tmp_index_1_col__| __tmp_array_col__| - # +-------------------+--------------------+ - # | 0|[{0, 1.0}, {1, -1...| - # | 1|[{0, -1.0}, {1, 1...| - # | 2|[{0, null}, {1, n...| - # +-------------------+--------------------+ - tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__") + # tmp_value_2_col_name: value -> avg rank sdf = ( - sdf.groupby(tmp_index_1_col_name) - .agg( - F.array_sort( - F.collect_list( - F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name)) + sdf.withColumn( + tmp_row_number_col_name, + F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))), + ) + .withColumn( + tmp_dense_rank_col_name, + F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))), + ) + .withColumn( + tmp_value_2_col_name, + F.when(F.isnull(F.col(tmp_value_2_col_name)), F.lit(None)).otherwise( + F.avg(tmp_row_number_col_name).over( + window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0) ) - ).alias(tmp_array_col_name) + ), ) - .orderBy(tmp_index_1_col_name) ) - for i in range(0, num_scols): - sdf = sdf.withColumn( - tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i) - ).withColumn( - numeric_col_names[i], - F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"), - ) + sdf = sdf.select( + tmp_index_1_col_name, + tmp_index_2_col_name, + tmp_value_1_col_name, + tmp_value_2_col_name, + ) - index_col_names: List[str] = [] - if internal.column_labels_level > 1: - for level in range(0, internal.column_labels_level): - index_col_name = SPARK_INDEX_NAME_FORMAT(level) - indices = [label[level] for label in numeric_labels] - sdf = sdf.withColumn( - index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name)) - ) - index_col_names.append(index_col_name) - else: - sdf = sdf.withColumn( - SPARK_DEFAULT_INDEX_NAME, - F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)), + # +-------------------+-------------------+----------------+-----------------+ + # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__|__tmp_count_col__| + # 
+-------------------+-------------------+----------------+-----------------+ + # | 2| 2| null| 1| + # | 1| 2| null| 1| + # | 1| 1| 1.0| 2| + # | 0| 0| 1.0| 2| + # | 0| 1| -1.0| 2| + # | 0| 2| null| 1| + # +-------------------+-------------------+----------------+-----------------+ + tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_corr_col__") + tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__") + + sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg( + F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name), + F.count( + F.when( + F.col(tmp_value_1_col_name).isNotNull() + & F.col(tmp_value_2_col_name).isNotNull(), + 1, ) - index_col_names = [SPARK_DEFAULT_INDEX_NAME] + ).alias(tmp_count_col_name), + ) - sdf = sdf.select(*index_col_names, *numeric_col_names) + # +-------------------+-------------------+----------------+ + # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__| + # +-------------------+-------------------+----------------+ + # | 2| 2| null| + # | 1| 2| null| + # | 2| 1| null| + # | 1| 1| 1.0| + # | 0| 0| 1.0| + # | 0| 1| -1.0| + # | 1| 0| -1.0| + # | 0| 2| null| + # | 2| 0| null| + # +-------------------+-------------------+----------------+ + sdf = ( + sdf.withColumn( + tmp_corr_col_name, + F.when( + F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name) + ).otherwise(F.lit(None)), + ) + .withColumn( + tmp_tuple_col_name, + F.explode( + F.when( + F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name), + F.lit([0]), + ).otherwise(F.lit([0, 1])) + ), + ) + .select( + F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name)) + .otherwise(F.col(tmp_index_2_col_name)) + .alias(tmp_index_1_col_name), + F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name)) + .otherwise(F.col(tmp_index_1_col_name)) + .alias(tmp_index_2_col_name), + F.col(tmp_corr_col_name), + ) + ) - return DataFrame( - InternalFrame( - spark_frame=sdf, - index_spark_columns=[ - scol_for(sdf, index_col_name) for index_col_name in index_col_names - ], - column_labels=numeric_labels, - column_label_names=internal.column_label_names, + # +-------------------+--------------------+ + # |__tmp_index_1_col__| __tmp_array_col__| + # +-------------------+--------------------+ + # | 0|[{0, 1.0}, {1, -1...| + # | 1|[{0, -1.0}, {1, 1...| + # | 2|[{0, null}, {1, n...| + # +-------------------+--------------------+ + tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__") + sdf = ( + sdf.groupby(tmp_index_1_col_name) + .agg( + F.array_sort( + F.collect_list(F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name))) + ).alias(tmp_array_col_name) + ) + .orderBy(tmp_index_1_col_name) + ) + + for i in range(0, num_scols): + sdf = sdf.withColumn( + tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i) + ).withColumn( + numeric_col_names[i], + F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"), + ) + + index_col_names: List[str] = [] + if internal.column_labels_level > 1: + for level in range(0, internal.column_labels_level): + index_col_name = SPARK_INDEX_NAME_FORMAT(level) + indices = [label[level] for label in numeric_labels] + sdf = sdf.withColumn( + index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name)) ) + index_col_names.append(index_col_name) + else: + sdf = sdf.withColumn( + SPARK_DEFAULT_INDEX_NAME, + F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)), ) + index_col_names = [SPARK_DEFAULT_INDEX_NAME] - return cast(DataFrame, ps.from_pandas(corr(self, method))) + sdf = 
sdf.select(*index_col_names, *numeric_col_names) + + return DataFrame( + InternalFrame( + spark_frame=sdf, + index_spark_columns=[ + scol_for(sdf, index_col_name) for index_col_name in index_col_names + ], + column_labels=numeric_labels, + column_label_names=internal.column_label_names, + ) + ) # TODO: add axis parameter and support more methods def corrwith( diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py index 7e2ca96e60ff1..fbe16146ff296 100644 --- a/python/pyspark/pandas/tests/test_stats.py +++ b/python/pyspark/pandas/tests/test_stats.py @@ -269,26 +269,68 @@ def test_dataframe_corr(self): psdf.corr("kendall") with self.assertRaisesRegex(TypeError, "Invalid min_periods type"): psdf.corr(min_periods="3") - with self.assertRaisesRegex(NotImplementedError, "spearman for now"): - psdf.corr(method="spearman", min_periods=3) - self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False) - self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False) - self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False) - self.assert_eq( - (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False - ) + for method in ["pearson", "spearman"]: + self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False) + self.assert_eq( + psdf.corr(method=method, min_periods=1), + pdf.corr(method=method, min_periods=1), + check_exact=False, + ) + self.assert_eq( + psdf.corr(method=method, min_periods=3), + pdf.corr(method=method, min_periods=3), + check_exact=False, + ) + self.assert_eq( + (psdf + 1).corr(method=method, min_periods=2), + (pdf + 1).corr(method=method, min_periods=2), + check_exact=False, + ) # multi-index columns columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Z", "D")]) pdf.columns = columns psdf.columns = columns - self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False) - self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False) - self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False) + for method in ["pearson", "spearman"]: + self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False) + self.assert_eq( + psdf.corr(method=method, min_periods=1), + pdf.corr(method=method, min_periods=1), + check_exact=False, + ) + self.assert_eq( + psdf.corr(method=method, min_periods=3), + pdf.corr(method=method, min_periods=3), + check_exact=False, + ) + self.assert_eq( + (psdf + 1).corr(method=method, min_periods=2), + (pdf + 1).corr(method=method, min_periods=2), + check_exact=False, + ) + + # test spearman with identical values + pdf = pd.DataFrame( + { + "a": [0, 1, 1, 1, 0], + "b": [2, 2, -1, 1, np.nan], + "c": [3, 3, 3, 3, 3], + "d": [np.nan, np.nan, np.nan, np.nan, np.nan], + } + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(psdf.corr(method="spearman"), pdf.corr(method="spearman"), check_exact=False) self.assert_eq( - (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False + psdf.corr(method="spearman", min_periods=1), + pdf.corr(method="spearman", min_periods=1), + check_exact=False, + ) + self.assert_eq( + psdf.corr(method="spearman", min_periods=3), + pdf.corr(method="spearman", min_periods=3), + check_exact=False, ) def test_corr(self): From 0ea17c4d3c317667965dcf2c72b9881727a71bde Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 15 Sep 2022 12:26:46 +0900 Subject: [PATCH 06/11] 
[SPARK-40339][SPARK-40342][PS][DOCS][FOLLOW-UP] Add Rolling.quantile and Expanding.quantile into PySpark documentation ### What changes were proposed in this pull request? This PR adds `Rolling.quantile` and `Expanding.quantile` into documentation. ### Why are the changes needed? To show the documentation about the new features to end users. ### Does this PR introduce _any_ user-facing change? No to end users because the original PR is not released yet. ### How was this patch tested? CI in this PR should test it out. Closes #37890 from HyukjinKwon/followup-window. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/docs/source/reference/pyspark.pandas/window.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/docs/source/reference/pyspark.pandas/window.rst b/python/docs/source/reference/pyspark.pandas/window.rst index 550036537f279..b5d74f2029ac8 100644 --- a/python/docs/source/reference/pyspark.pandas/window.rst +++ b/python/docs/source/reference/pyspark.pandas/window.rst @@ -36,6 +36,7 @@ Standard moving window functions Rolling.min Rolling.max Rolling.mean + Rolling.quantile Standard expanding window functions ----------------------------------- @@ -48,6 +49,7 @@ Standard expanding window functions Expanding.min Expanding.max Expanding.mean + Expanding.quantile Exponential moving window functions ----------------------------------- From 034e48fd47f49a603c1cad507608958f5beeddc8 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Wed, 14 Sep 2022 23:06:22 -0700 Subject: [PATCH 07/11] [SPARK-40429][SQL] Only set KeyGroupedPartitioning when the referenced column is in the output ### What changes were proposed in this pull request? Only set `KeyGroupedPartitioning` when the referenced column is in the output ### Why are the changes needed? bug fixing ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new test Closes #37886 from huaxingao/keyGroupedPartitioning. 
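
To make the terse answers above concrete, here is a minimal sketch (not part of the patch) of the scenario this change guards against. It assumes a DSv2 catalog named `testcat` that reports `KeyGroupedPartitioning` keyed on `id` and exposes the `index` and `_partition` metadata columns; the actual regression test added by this patch is the one in `MetadataColumnSuite` in the diff below.

```
// Hypothetical reproduction sketch, assuming a DSv2 catalog registered as `testcat`.
val tbl = "testcat.t"
spark.sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (id)")
spark.sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')")

// Only metadata columns are selected, so the partition key `id` is not in the
// scan output. With this fix the rule reports no partitioning in that case,
// instead of keeping a KeyGroupedPartitioning that references a missing column.
spark.table(tbl).select("index", "_partition").show()
```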
Authored-by: huaxingao Signed-off-by: Dongjoon Hyun --- .../v2/V2ScanPartitioningAndOrdering.scala | 14 ++++++++++++-- .../sql/connector/MetadataColumnSuite.scala | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala index 8ab0dc7072664..5c8c7cf420d65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanPartitioningAndOrdering.scala @@ -41,8 +41,18 @@ object V2ScanPartitioningAndOrdering extends Rule[LogicalPlan] with SQLConfHelpe private def partitioning(plan: LogicalPlan) = plan.transformDown { case d @ DataSourceV2ScanRelation(relation, scan: SupportsReportPartitioning, _, None, _) => val catalystPartitioning = scan.outputPartitioning() match { - case kgp: KeyGroupedPartitioning => sequenceToOption(kgp.keys().map( - V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog))) + case kgp: KeyGroupedPartitioning => + val partitioning = sequenceToOption( + kgp.keys().map(V2ExpressionUtils.toCatalystOpt(_, relation, relation.funCatalog))) + if (partitioning.isEmpty) { + None + } else { + if (partitioning.get.forall(p => p.references.subsetOf(d.outputSet))) { + partitioning + } else { + None + } + } case _: UnknownPartitioning => None case p => throw new IllegalArgumentException("Unsupported data source V2 partitioning " + "type: " + p.getClass.getSimpleName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala index 9b90ee43657f5..8454b9f85ecdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MetadataColumnSuite.scala @@ -216,4 +216,20 @@ class MetadataColumnSuite extends DatasourceV2SQLBase { .withColumn("right_all", struct($"right.*")) checkAnswer(dfQuery, Row(1, "a", "b", Row(1, "a"), Row(1, "b"))) } + + test("SPARK-40429: Only set KeyGroupedPartitioning when the referenced column is in the output") { + withTable(tbl) { + sql(s"CREATE TABLE $tbl (id bigint, data string) PARTITIONED BY (id)") + sql(s"INSERT INTO $tbl VALUES (1, 'a'), (2, 'b'), (3, 'c')") + checkAnswer( + spark.table(tbl).select("index", "_partition"), + Seq(Row(0, "3"), Row(0, "2"), Row(0, "1")) + ) + + checkAnswer( + spark.table(tbl).select("id", "index", "_partition"), + Seq(Row(3, 0, "3"), Row(2, 0, "2"), Row(1, 0, "1")) + ) + } + } } From 6d067d059f3d2a62035d1b5f71ea5d87e1705643 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Thu, 15 Sep 2022 10:51:56 +0300 Subject: [PATCH 08/11] [SPARK-40370][SQL] Migrate type check fails to error classes in CAST ### What changes were proposed in this pull request? In the PR, I propose to use error classes in the case of type check failure in the `CAST` expression. ### Why are the changes needed? Migration onto error classes unifies Spark SQL error messages, and improves search-ability of errors. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
By running the modified test suites: ``` $ build/sbt "test:testOnly *CastWithAnsiOnSuite" $ build/sbt "test:testOnly *DatasetSuite" $ build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z cast.sql" ``` Closes #37869 from MaxGekk/datatype-mismatch-in-cast. Authored-by: Max Gekk Signed-off-by: Max Gekk --- .../main/resources/error/error-classes.json | 17 ++ .../apache/spark/SparkThrowableHelper.scala | 7 +- .../spark/sql/catalyst/expressions/Cast.scala | 76 ++++---- .../catalyst/expressions/CastSuiteBase.scala | 33 ++-- .../expressions/CastWithAnsiOnSuite.scala | 23 ++- .../sql-tests/results/ansi/cast.sql.out | 176 ++++++++++++++---- .../native/stringCastAndExpressions.sql.out | 51 ++++- .../org/apache/spark/sql/DatasetSuite.scala | 14 +- 8 files changed, 300 insertions(+), 97 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 999a9b0a4aec6..22b47c979c65a 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -95,6 +95,23 @@ "message" : [ "the binary operator requires the input type , not ." ] + }, + "CAST_WITHOUT_SUGGESTION" : { + "message" : [ + "cannot cast to ." + ] + }, + "CAST_WITH_CONF_SUGGESTION" : { + "message" : [ + "cannot cast to with ANSI mode on.", + "If you have to cast to , you can set as ." + ] + }, + "CAST_WITH_FUN_SUGGESTION" : { + "message" : [ + "cannot cast to .", + "To convert values from to , you can use the functions instead." + ] } } }, diff --git a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index 497fd91d77ac7..86337205d31c8 100644 --- a/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/core/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -187,9 +187,10 @@ private[spark] object SparkThrowableHelper { val messageParameters = e.getMessageParameters if (!messageParameters.isEmpty) { g.writeObjectFieldStart("messageParameters") - messageParameters.asScala.toSeq.sortBy(_._1).foreach { case (name, value) => - g.writeStringField(name, value) - } + messageParameters.asScala + .toMap // To remove duplicates + .toSeq.sortBy(_._1) + .foreach { case (name, value) => g.writeStringField(name, value) } g.writeEndObject() } val queryContext = e.getQueryContext diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d8642d22af002..78cac4143d3e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkArithmeticException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNodeTag} @@ -33,14 +34,14 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import 
org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToDecimal, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort} -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper} -object Cast { +object Cast extends QueryErrorsBase { /** * As per section 6.13 "cast specification" in "Information technology — Database languages " + * "- SQL — Part 2: Foundation (SQL/Foundation)": @@ -412,47 +413,48 @@ object Cast { } } - // Show suggestion on how to complete the disallowed explicit casting with built-in type - // conversion functions. - private def suggestionOnConversionFunctions ( - from: DataType, - to: DataType, - functionNames: String): String = { - // scalastyle:off line.size.limit - s"""cannot cast ${from.catalogString} to ${to.catalogString}. - |To convert values from ${from.catalogString} to ${to.catalogString}, you can use $functionNames instead. - |""".stripMargin - // scalastyle:on line.size.limit - } - def typeCheckFailureMessage( from: DataType, to: DataType, - fallbackConf: Option[(String, String)]): String = + fallbackConf: Option[(String, String)]): DataTypeMismatch = { + def withFunSuggest(names: String*): DataTypeMismatch = { + DataTypeMismatch( + errorSubClass = "CAST_WITH_FUN_SUGGESTION", + messageParameters = Map( + "srcType" -> toSQLType(from), + "targetType" -> toSQLType(to), + "functionNames" -> names.map(toSQLId).mkString("/"))) + } (from, to) match { case (_: NumericType, TimestampType) => - suggestionOnConversionFunctions(from, to, - "functions TIMESTAMP_SECONDS/TIMESTAMP_MILLIS/TIMESTAMP_MICROS") + withFunSuggest("TIMESTAMP_SECONDS", "TIMESTAMP_MILLIS", "TIMESTAMP_MICROS") case (TimestampType, _: NumericType) => - suggestionOnConversionFunctions(from, to, "functions UNIX_SECONDS/UNIX_MILLIS/UNIX_MICROS") + withFunSuggest("UNIX_SECONDS", "UNIX_MILLIS", "UNIX_MICROS") case (_: NumericType, DateType) => - suggestionOnConversionFunctions(from, to, "function DATE_FROM_UNIX_DATE") + withFunSuggest("DATE_FROM_UNIX_DATE") case (DateType, _: NumericType) => - suggestionOnConversionFunctions(from, to, "function UNIX_DATE") + withFunSuggest("UNIX_DATE") - // scalastyle:off line.size.limit case _ if fallbackConf.isDefined && Cast.canCast(from, to) => - s""" - | cannot cast ${from.catalogString} to ${to.catalogString} with ANSI mode on. - | If you have to cast ${from.catalogString} to ${to.catalogString}, you can set ${fallbackConf.get._1} as ${fallbackConf.get._2}. 
- |""".stripMargin - // scalastyle:on line.size.limit + DataTypeMismatch( + errorSubClass = "CAST_WITH_CONF_SUGGESTION", + messageParameters = Map( + "srcType" -> toSQLType(from), + "targetType" -> toSQLType(to), + "config" -> toSQLConf(fallbackConf.get._1), + "configVal" -> toSQLValue(fallbackConf.get._2, StringType))) - case _ => s"cannot cast ${from.catalogString} to ${to.catalogString}" + case _ => + DataTypeMismatch( + errorSubClass = "CAST_WITHOUT_SUGGESTION", + messageParameters = Map( + "srcType" -> toSQLType(from), + "targetType" -> toSQLType(to))) } + } def apply( child: Expression, @@ -487,8 +489,12 @@ case class Cast( child: Expression, dataType: DataType, timeZoneId: Option[String] = None, - evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends UnaryExpression - with TimeZoneAwareExpression with NullIntolerant with SupportQueryContext { + evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) + extends UnaryExpression + with TimeZoneAwareExpression + with NullIntolerant + with SupportQueryContext + with QueryErrorsBase { def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) = this(child, dataType, timeZoneId, evalMode = EvalMode.fromSQLConf(SQLConf.get)) @@ -509,7 +515,7 @@ case class Cast( evalMode == EvalMode.TRY } - private def typeCheckFailureMessage: String = evalMode match { + private def typeCheckFailureInCast: DataTypeMismatch = evalMode match { case EvalMode.ANSI => if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) { Cast.typeCheckFailureMessage(child.dataType, dataType, @@ -522,7 +528,11 @@ case class Cast( case EvalMode.TRY => Cast.typeCheckFailureMessage(child.dataType, dataType, None) case _ => - s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}" + DataTypeMismatch( + errorSubClass = "CAST_WITHOUT_SUGGESTION", + messageParameters = Map( + "srcType" -> toSQLType(child.dataType), + "targetType" -> toSQLType(dataType))) } override def checkInputDataTypes(): TypeCheckResult = { @@ -535,7 +545,7 @@ case class Cast( if (canCast) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure(typeCheckFailureMessage) + typeCheckFailureInCast } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index f627a2a835a4e..a60491b0ab8c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -27,7 +27,7 @@ import scala.collection.parallel.immutable.ParVector import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -66,21 +66,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null) } - protected def verifyCastFailure(c: Cast, optionalExpectedMsg: Option[String] = None): Unit = { + protected def verifyCastFailure(c: Cast, expected: DataTypeMismatch): Unit = { val typeCheckResult = 
c.checkInputDataTypes() assert(typeCheckResult.isFailure) - assert(typeCheckResult.isInstanceOf[TypeCheckFailure]) - val message = typeCheckResult.asInstanceOf[TypeCheckFailure].message - - if (optionalExpectedMsg.isDefined) { - assert(message.contains(optionalExpectedMsg.get)) - } else { - assert("cannot cast [a-zA-Z]+ to [a-zA-Z]+".r.findFirstIn(message).isDefined) - if (evalMode == EvalMode.ANSI) { - assert(message.contains("with ANSI mode on")) - assert(message.contains(s"set ${SQLConf.ANSI_ENABLED.key} as false")) - } - } + assert(typeCheckResult.isInstanceOf[DataTypeMismatch]) + val mismatch = typeCheckResult.asInstanceOf[DataTypeMismatch] + assert(mismatch === expected) } test("null cast") { @@ -936,13 +927,19 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { test("disallow type conversions between Numeric types and Timestamp without time zone type") { import DataTypeTestUtils.numericTypes checkInvalidCastFromNumericType(TimestampNTZType) - var errorMsg = "cannot cast bigint to timestamp_ntz" - verifyCastFailure(cast(Literal(0L), TimestampNTZType), Some(errorMsg)) + verifyCastFailure( + cast(Literal(0L), TimestampNTZType), + DataTypeMismatch( + "CAST_WITHOUT_SUGGESTION", + Map("srcType" -> "\"BIGINT\"", "targetType" -> "\"TIMESTAMP_NTZ\""))) val timestampNTZLiteral = Literal.create(LocalDateTime.now(), TimestampNTZType) - errorMsg = "cannot cast timestamp_ntz to" numericTypes.foreach { numericType => - verifyCastFailure(cast(timestampNTZLiteral, numericType), Some(errorMsg)) + verifyCastFailure( + cast(timestampNTZLiteral, numericType), + DataTypeMismatch( + "CAST_WITHOUT_SUGGESTION", + Map("srcType" -> "\"TIMESTAMP_NTZ\"", "targetType" -> s""""${numericType.sql}""""))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala index 94d466617862e..3de4eb6815984 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastWithAnsiOnSuite.scala @@ -23,6 +23,7 @@ import java.time.DateTimeException import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.catalyst.util.DateTimeTestUtils import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC} @@ -141,12 +142,26 @@ class CastWithAnsiOnSuite extends CastSuiteBase with QueryErrorsBase { test("ANSI mode: disallow type conversions between Numeric types and Date type") { import DataTypeTestUtils.numericTypes checkInvalidCastFromNumericType(DateType) - var errorMsg = "you can use function DATE_FROM_UNIX_DATE instead" - verifyCastFailure(cast(Literal(0L), DateType), Some(errorMsg)) + verifyCastFailure( + cast(Literal(0L), DateType), + DataTypeMismatch( + "CAST_WITH_FUN_SUGGESTION", + Map( + "srcType" -> "\"BIGINT\"", + "targetType" -> "\"DATE\"", + "functionNames" -> "`DATE_FROM_UNIX_DATE`"))) val dateLiteral = Literal(1, DateType) - errorMsg = "you can use function UNIX_DATE instead" numericTypes.foreach { numericType => - verifyCastFailure(cast(dateLiteral, numericType), Some(errorMsg)) + withClue(s"numericType = 
${numericType.sql}") { + verifyCastFailure( + cast(dateLiteral, numericType), + DataTypeMismatch( + "CAST_WITH_FUN_SUGGESTION", + Map( + "srcType" -> "\"DATE\"", + "targetType" -> s""""${numericType.sql}"""", + "functionNames" -> "`UNIX_DATE`"))) + } } } diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out index deaece6e7e1ac..35d60255aba26 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/cast.sql.out @@ -611,10 +611,24 @@ SELECT HEX(CAST(CAST(123 AS byte) AS binary)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(CAST(123 AS TINYINT) AS BINARY)' due to data type mismatch: - cannot cast tinyint to binary with ANSI mode on. - If you have to cast tinyint to binary, you can set spark.sql.ansi.enabled as false. -; line 1 pos 11 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITH_CONF_SUGGESTION", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"", + "configVal" : "'false'", + "sqlExpr" : "\"CAST(CAST(123 AS TINYINT) AS BINARY)\"", + "srcType" : "\"TINYINT\"", + "targetType" : "\"BINARY\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 44, + "fragment" : "CAST(CAST(123 AS byte) AS binary)" + } ] +} -- !query @@ -623,10 +637,24 @@ SELECT HEX(CAST(CAST(-123 AS byte) AS binary)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(CAST(-123 AS TINYINT) AS BINARY)' due to data type mismatch: - cannot cast tinyint to binary with ANSI mode on. - If you have to cast tinyint to binary, you can set spark.sql.ansi.enabled as false. -; line 1 pos 11 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITH_CONF_SUGGESTION", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"", + "configVal" : "'false'", + "sqlExpr" : "\"CAST(CAST(-123 AS TINYINT) AS BINARY)\"", + "srcType" : "\"TINYINT\"", + "targetType" : "\"BINARY\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 45, + "fragment" : "CAST(CAST(-123 AS byte) AS binary)" + } ] +} -- !query @@ -635,10 +663,24 @@ SELECT HEX(CAST(123S AS binary)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(123S AS BINARY)' due to data type mismatch: - cannot cast smallint to binary with ANSI mode on. - If you have to cast smallint to binary, you can set spark.sql.ansi.enabled as false. -; line 1 pos 11 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITH_CONF_SUGGESTION", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"", + "configVal" : "'false'", + "sqlExpr" : "\"CAST(123 AS BINARY)\"", + "srcType" : "\"SMALLINT\"", + "targetType" : "\"BINARY\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 31, + "fragment" : "CAST(123S AS binary)" + } ] +} -- !query @@ -647,10 +689,24 @@ SELECT HEX(CAST(-123S AS binary)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(-123S AS BINARY)' due to data type mismatch: - cannot cast smallint to binary with ANSI mode on. - If you have to cast smallint to binary, you can set spark.sql.ansi.enabled as false. 
-; line 1 pos 11 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITH_CONF_SUGGESTION", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"", + "configVal" : "'false'", + "sqlExpr" : "\"CAST(-123 AS BINARY)\"", + "srcType" : "\"SMALLINT\"", + "targetType" : "\"BINARY\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 32, + "fragment" : "CAST(-123S AS binary)" + } ] +} -- !query @@ -659,10 +715,24 @@ SELECT HEX(CAST(123 AS binary)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(123 AS BINARY)' due to data type mismatch: - cannot cast int to binary with ANSI mode on. - If you have to cast int to binary, you can set spark.sql.ansi.enabled as false. -; line 1 pos 11 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITH_CONF_SUGGESTION", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"", + "configVal" : "'false'", + "sqlExpr" : "\"CAST(123 AS BINARY)\"", + "srcType" : "\"INT\"", + "targetType" : "\"BINARY\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 30, + "fragment" : "CAST(123 AS binary)" + } ] +} -- !query @@ -671,10 +741,24 @@ SELECT HEX(CAST(-123 AS binary)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(-123 AS BINARY)' due to data type mismatch: - cannot cast int to binary with ANSI mode on. - If you have to cast int to binary, you can set spark.sql.ansi.enabled as false. -; line 1 pos 11 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITH_CONF_SUGGESTION", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"", + "configVal" : "'false'", + "sqlExpr" : "\"CAST(-123 AS BINARY)\"", + "srcType" : "\"INT\"", + "targetType" : "\"BINARY\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 31, + "fragment" : "CAST(-123 AS binary)" + } ] +} -- !query @@ -683,10 +767,24 @@ SELECT HEX(CAST(123L AS binary)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(123L AS BINARY)' due to data type mismatch: - cannot cast bigint to binary with ANSI mode on. - If you have to cast bigint to binary, you can set spark.sql.ansi.enabled as false. -; line 1 pos 11 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITH_CONF_SUGGESTION", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"", + "configVal" : "'false'", + "sqlExpr" : "\"CAST(123 AS BINARY)\"", + "srcType" : "\"BIGINT\"", + "targetType" : "\"BINARY\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 31, + "fragment" : "CAST(123L AS binary)" + } ] +} -- !query @@ -695,10 +793,24 @@ SELECT HEX(CAST(-123L AS binary)) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'CAST(-123L AS BINARY)' due to data type mismatch: - cannot cast bigint to binary with ANSI mode on. - If you have to cast bigint to binary, you can set spark.sql.ansi.enabled as false. 
-; line 1 pos 11 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITH_CONF_SUGGESTION", + "messageParameters" : { + "config" : "\"spark.sql.ansi.enabled\"", + "configVal" : "'false'", + "sqlExpr" : "\"CAST(-123 AS BINARY)\"", + "srcType" : "\"BIGINT\"", + "targetType" : "\"BINARY\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 32, + "fragment" : "CAST(-123L AS binary)" + } ] +} -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out index bd5e33ef9f7f9..c9ff8087042a1 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -101,7 +101,22 @@ select cast(a as array) from t struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 't.a' due to data type mismatch: cannot cast string to array; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITHOUT_SUGGESTION", + "messageParameters" : { + "sqlExpr" : "\"a\"", + "srcType" : "\"STRING\"", + "targetType" : "\"ARRAY\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 31, + "fragment" : "cast(a as array)" + } ] +} -- !query @@ -110,7 +125,22 @@ select cast(a as struct) from t struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 't.a' due to data type mismatch: cannot cast string to struct; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITHOUT_SUGGESTION", + "messageParameters" : { + "sqlExpr" : "\"a\"", + "srcType" : "\"STRING\"", + "targetType" : "\"STRUCT\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 34, + "fragment" : "cast(a as struct)" + } ] +} -- !query @@ -119,7 +149,22 @@ select cast(a as map) from t struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 't.a' due to data type mismatch: cannot cast string to map; line 1 pos 7 +{ + "errorClass" : "DATATYPE_MISMATCH", + "errorSubClass" : "CAST_WITHOUT_SUGGESTION", + "messageParameters" : { + "sqlExpr" : "\"a\"", + "srcType" : "\"STRING\"", + "targetType" : "\"MAP\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 37, + "fragment" : "cast(a as map)" + } ] +} -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5be6d53f6e10c..7420ef32d4d9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -900,10 +900,16 @@ class DatasetSuite extends QueryTest test("Kryo encoder: check the schema mismatch when converting DataFrame to Dataset") { implicit val kryoEncoder = Encoders.kryo[KryoData] val df = Seq((1.0)).toDF("a") - val e = intercept[AnalysisException] { - df.as[KryoData] - }.message - assert(e.contains("cannot cast double to binary")) + checkError( + exception = intercept[AnalysisException] { + df.as[KryoData] + }, + errorClass = "DATATYPE_MISMATCH", + errorSubClass = Some("CAST_WITHOUT_SUGGESTION"), + parameters = Map( + "sqlExpr" -> "\"a\"", + "srcType" -> "\"DOUBLE\"", + 
"targetType" -> "\"BINARY\"")) } test("Java encoder") { From 3d14b745773f66d50c5ee5b3d7835b5f11132ec8 Mon Sep 17 00:00:00 2001 From: Yikun Jiang Date: Thu, 15 Sep 2022 17:14:21 +0800 Subject: [PATCH 09/11] [SPARK-40440][PS][DOCS] Fix wrong reference and content in PS windows related doc ### What changes were proposed in this pull request? Fix wrong reference and content in PS windows related doc: - Add `pyspark.pandas.` for window function doc - Change `pandas_on_spark.DataFrame` to `pyspark.pandas.DataFrame` to make sure link generate correctly. - Fix `Returns` and `See Also` for `Rolling.count` - Add ewm doc for `Dataframe` and `series` ### Why are the changes needed? ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? ``` cd ~/spark/python/docs make html ``` ![image](https://user-images.githubusercontent.com/1736354/190328623-4c3250af-3968-430e-adf3-890d1bda850e.png) ![image](https://user-images.githubusercontent.com/1736354/190346064-b359e217-dac0-41df-8033-8477bc94b09b.png) ![image](https://user-images.githubusercontent.com/1736354/190346138-10bfc3ad-0a37-4bdb-b4aa-cc1ee2a3b018.png) Closes #37895 from Yikun/add-doc. Authored-by: Yikun Jiang Signed-off-by: Ruifeng Zheng --- .../source/reference/pyspark.pandas/frame.rst | 1 + .../reference/pyspark.pandas/series.rst | 1 + .../reference/pyspark.pandas/window.rst | 8 +- python/pyspark/pandas/window.py | 301 +++++++++--------- 4 files changed, 160 insertions(+), 151 deletions(-) diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst index ff743371320bc..9c69ca647c499 100644 --- a/python/docs/source/reference/pyspark.pandas/frame.rst +++ b/python/docs/source/reference/pyspark.pandas/frame.rst @@ -151,6 +151,7 @@ Computations / Descriptive Stats DataFrame.count DataFrame.cov DataFrame.describe + DataFrame.ewm DataFrame.kurt DataFrame.kurtosis DataFrame.mad diff --git a/python/docs/source/reference/pyspark.pandas/series.rst b/python/docs/source/reference/pyspark.pandas/series.rst index 1cf63c1a8ae2c..5ed6df6b2a13f 100644 --- a/python/docs/source/reference/pyspark.pandas/series.rst +++ b/python/docs/source/reference/pyspark.pandas/series.rst @@ -145,6 +145,7 @@ Computations / Descriptive Stats Series.cumsum Series.cumprod Series.describe + Series.ewm Series.filter Series.kurt Series.mad diff --git a/python/docs/source/reference/pyspark.pandas/window.rst b/python/docs/source/reference/pyspark.pandas/window.rst index b5d74f2029ac8..c840be357fa75 100644 --- a/python/docs/source/reference/pyspark.pandas/window.rst +++ b/python/docs/source/reference/pyspark.pandas/window.rst @@ -21,9 +21,11 @@ Window ====== .. currentmodule:: pyspark.pandas.window -Rolling objects are returned by ``.rolling`` calls: :func:`pandas_on_spark.DataFrame.rolling`, :func:`pandas_on_spark.Series.rolling`, etc. -Expanding objects are returned by ``.expanding`` calls: :func:`pandas_on_spark.DataFrame.expanding`, :func:`pandas_on_spark.Series.expanding`, etc. -ExponentialMoving objects are returned by ``.ewm`` calls: :func:`pandas_on_spark.DataFrame.ewm`, :func:`pandas_on_spark.Series.ewm`, etc. +Rolling objects are returned by ``.rolling`` calls: :func:`pyspark.pandas.DataFrame.rolling`, :func:`pyspark.pandas.Series.rolling`, etc. + +Expanding objects are returned by ``.expanding`` calls: :func:`pyspark.pandas.DataFrame.expanding`, :func:`pyspark.pandas.Series.expanding`, etc. 
+ +ExponentialMoving objects are returned by ``.ewm`` calls: :func:`pyspark.pandas.DataFrame.ewm`, :func:`pyspark.pandas.Series.ewm`, etc. Standard moving window functions -------------------------------- diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py index 274000cbb2cd0..7ff7a6a8eb684 100644 --- a/python/pyspark/pandas/window.py +++ b/python/pyspark/pandas/window.py @@ -224,10 +224,15 @@ def count(self) -> FrameLike: Returns ------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.count : Count of the full Series. - DataFrame.count : Count of the full DataFrame. + Series or DataFrame + Return type is the same as the original object with `np.float64` dtype. + + See Also + -------- + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.count : Count of the full Series. + pyspark.pandas.DataFrame.count : Count of the full DataFrame. Examples -------- @@ -279,10 +284,10 @@ def sum(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.sum : Reducing sum for Series. - DataFrame.sum : Reducing sum for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.sum : Reducing sum for Series. + pyspark.pandas.DataFrame.sum : Reducing sum for DataFrame. Examples -------- @@ -357,10 +362,10 @@ def min(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with a Series. - DataFrame.rolling : Calling object with a DataFrame. - Series.min : Similar method for Series. - DataFrame.min : Similar method for DataFrame. + pyspark.pandas.Series.rolling : Calling object with a Series. + pyspark.pandas.DataFrame.rolling : Calling object with a DataFrame. + pyspark.pandas.Series.min : Similar method for Series. + pyspark.pandas.DataFrame.min : Similar method for DataFrame. Examples -------- @@ -434,10 +439,10 @@ def max(self) -> FrameLike: See Also -------- - Series.rolling : Series rolling. - DataFrame.rolling : DataFrame rolling. - Series.max : Similar method for Series. - DataFrame.max : Similar method for DataFrame. + pyspark.pandas.Series.rolling : Series rolling. + pyspark.pandas.DataFrame.rolling : DataFrame rolling. + pyspark.pandas.Series.max : Similar method for Series. + pyspark.pandas.DataFrame.max : Similar method for DataFrame. Examples -------- @@ -512,10 +517,10 @@ def mean(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.mean : Equivalent method for Series. - DataFrame.mean : Equivalent method for DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.mean : Equivalent method for Series. + pyspark.pandas.DataFrame.mean : Equivalent method for DataFrame. Examples -------- @@ -684,10 +689,10 @@ def std(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.std : Equivalent method for Series. - DataFrame.std : Equivalent method for DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. 
+ pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.std : Equivalent method for Series. + pyspark.pandas.DataFrame.std : Equivalent method for DataFrame. numpy.std : Equivalent method for Numpy array. Examples @@ -784,10 +789,10 @@ def skew(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.std : Equivalent method for Series. - DataFrame.std : Equivalent method for DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.std : Equivalent method for Series. + pyspark.pandas.DataFrame.std : Equivalent method for DataFrame. numpy.std : Equivalent method for Numpy array. Examples @@ -836,10 +841,10 @@ def kurt(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.var : Equivalent method for Series. - DataFrame.var : Equivalent method for DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.var : Equivalent method for Series. + pyspark.pandas.DataFrame.var : Equivalent method for DataFrame. numpy.var : Equivalent method for Numpy array. Examples @@ -985,10 +990,10 @@ def count(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.count : Count of the full Series. - DataFrame.count : Count of the full DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.count : Count of the full Series. + pyspark.pandas.DataFrame.count : Count of the full DataFrame. Examples -------- @@ -1039,10 +1044,10 @@ def sum(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.sum : Sum of the full Series. - DataFrame.sum : Sum of the full DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.sum : Sum of the full Series. + pyspark.pandas.DataFrame.sum : Sum of the full DataFrame. Examples -------- @@ -1093,10 +1098,10 @@ def min(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.min : Min of the full Series. - DataFrame.min : Min of the full DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.min : Min of the full Series. + pyspark.pandas.DataFrame.min : Min of the full DataFrame. Examples -------- @@ -1147,10 +1152,10 @@ def max(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.max : Max of the full Series. - DataFrame.max : Max of the full DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.max : Max of the full Series. + pyspark.pandas.DataFrame.max : Max of the full DataFrame. 
Examples -------- @@ -1201,10 +1206,10 @@ def mean(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.mean : Mean of the full Series. - DataFrame.mean : Mean of the full DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.mean : Mean of the full Series. + pyspark.pandas.DataFrame.mean : Mean of the full DataFrame. Examples -------- @@ -1325,10 +1330,10 @@ def std(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.std : Equivalent method for Series. - DataFrame.std : Equivalent method for DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.std : Equivalent method for Series. + pyspark.pandas.DataFrame.std : Equivalent method for DataFrame. numpy.std : Equivalent method for Numpy array. """ return super().std() @@ -1344,10 +1349,10 @@ def var(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.var : Equivalent method for Series. - DataFrame.var : Equivalent method for DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.var : Equivalent method for Series. + pyspark.pandas.DataFrame.var : Equivalent method for DataFrame. numpy.var : Equivalent method for Numpy array. """ return super().var() @@ -1363,10 +1368,10 @@ def skew(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.std : Equivalent method for Series. - DataFrame.std : Equivalent method for DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.std : Equivalent method for Series. + pyspark.pandas.DataFrame.std : Equivalent method for DataFrame. numpy.std : Equivalent method for Numpy array. """ return super().skew() @@ -1382,10 +1387,10 @@ def kurt(self) -> FrameLike: See Also -------- - Series.rolling : Calling object with Series data. - DataFrame.rolling : Calling object with DataFrames. - Series.var : Equivalent method for Series. - DataFrame.var : Equivalent method for DataFrame. + pyspark.pandas.Series.rolling : Calling object with Series data. + pyspark.pandas.DataFrame.rolling : Calling object with DataFrames. + pyspark.pandas.Series.var : Equivalent method for Series. + pyspark.pandas.DataFrame.var : Equivalent method for DataFrame. numpy.var : Equivalent method for Numpy array. """ return super().kurt() @@ -1458,10 +1463,10 @@ def count(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.count : Count of the full Series. - DataFrame.count : Count of the full DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.count : Count of the full Series. + pyspark.pandas.DataFrame.count : Count of the full DataFrame. 
Examples -------- @@ -1499,10 +1504,10 @@ def sum(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.sum : Reducing sum for Series. - DataFrame.sum : Reducing sum for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.sum : Reducing sum for Series. + pyspark.pandas.DataFrame.sum : Reducing sum for DataFrame. Examples -------- @@ -1561,10 +1566,10 @@ def min(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with a Series. - DataFrame.expanding : Calling object with a DataFrame. - Series.min : Similar method for Series. - DataFrame.min : Similar method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with a Series. + pyspark.pandas.DataFrame.expanding : Calling object with a DataFrame. + pyspark.pandas.Series.min : Similar method for Series. + pyspark.pandas.DataFrame.min : Similar method for DataFrame. Examples -------- @@ -1597,10 +1602,10 @@ def max(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.max : Similar method for Series. - DataFrame.max : Similar method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.max : Similar method for Series. + pyspark.pandas.DataFrame.max : Similar method for DataFrame. Examples -------- @@ -1634,10 +1639,10 @@ def mean(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.mean : Equivalent method for Series. - DataFrame.mean : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.mean : Equivalent method for Series. + pyspark.pandas.DataFrame.mean : Equivalent method for DataFrame. Examples -------- @@ -1737,10 +1742,10 @@ def std(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.std : Equivalent method for Series. - DataFrame.std : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.std : Equivalent method for Series. + pyspark.pandas.DataFrame.std : Equivalent method for DataFrame. numpy.std : Equivalent method for Numpy array. Examples @@ -1787,10 +1792,10 @@ def var(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.var : Equivalent method for Series. - DataFrame.var : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.var : Equivalent method for Series. + pyspark.pandas.DataFrame.var : Equivalent method for DataFrame. numpy.var : Equivalent method for Numpy array. Examples @@ -1837,10 +1842,10 @@ def skew(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. 
- DataFrame.expanding : Calling object with DataFrames. - Series.std : Equivalent method for Series. - DataFrame.std : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.std : Equivalent method for Series. + pyspark.pandas.DataFrame.std : Equivalent method for DataFrame. numpy.std : Equivalent method for Numpy array. Examples @@ -1889,10 +1894,10 @@ def kurt(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.var : Equivalent method for Series. - DataFrame.var : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.var : Equivalent method for Series. + pyspark.pandas.DataFrame.var : Equivalent method for DataFrame. numpy.var : Equivalent method for Numpy array. Examples @@ -1959,10 +1964,10 @@ def count(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.count : Count of the full Series. - DataFrame.count : Count of the full DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.count : Count of the full Series. + pyspark.pandas.DataFrame.count : Count of the full DataFrame. Examples -------- @@ -2013,10 +2018,10 @@ def sum(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.sum : Reducing sum for Series. - DataFrame.sum : Reducing sum for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.sum : Reducing sum for Series. + pyspark.pandas.DataFrame.sum : Reducing sum for DataFrame. Examples -------- @@ -2067,10 +2072,10 @@ def min(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with a Series. - DataFrame.expanding : Calling object with a DataFrame. - Series.min : Similar method for Series. - DataFrame.min : Similar method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with a Series. + pyspark.pandas.DataFrame.expanding : Calling object with a DataFrame. + pyspark.pandas.Series.min : Similar method for Series. + pyspark.pandas.DataFrame.min : Similar method for DataFrame. Examples -------- @@ -2120,10 +2125,10 @@ def max(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.max : Similar method for Series. - DataFrame.max : Similar method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.max : Similar method for Series. + pyspark.pandas.DataFrame.max : Similar method for DataFrame. Examples -------- @@ -2174,10 +2179,10 @@ def mean(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.mean : Equivalent method for Series. - DataFrame.mean : Equivalent method for DataFrame. 
+ pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.mean : Equivalent method for Series. + pyspark.pandas.DataFrame.mean : Equivalent method for DataFrame. Examples -------- @@ -2299,10 +2304,10 @@ def std(self) -> FrameLike: See Also -------- - Series.expanding: Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.std : Equivalent method for Series. - DataFrame.std : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding: Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.std : Equivalent method for Series. + pyspark.pandas.DataFrame.std : Equivalent method for DataFrame. numpy.std : Equivalent method for Numpy array. """ return super().std() @@ -2318,10 +2323,10 @@ def var(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.var : Equivalent method for Series. - DataFrame.var : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.var : Equivalent method for Series. + pyspark.pandas.DataFrame.var : Equivalent method for DataFrame. numpy.var : Equivalent method for Numpy array. """ return super().var() @@ -2338,10 +2343,10 @@ def skew(self) -> FrameLike: See Also -------- - Series.expanding: Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.std : Equivalent method for Series. - DataFrame.std : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding: Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.std : Equivalent method for Series. + pyspark.pandas.DataFrame.std : Equivalent method for DataFrame. numpy.std : Equivalent method for Numpy array. """ return super().skew() @@ -2357,10 +2362,10 @@ def kurt(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.var : Equivalent method for Series. - DataFrame.var : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.var : Equivalent method for Series. + pyspark.pandas.DataFrame.var : Equivalent method for DataFrame. numpy.var : Equivalent method for Numpy array. """ return super().kurt() @@ -2512,10 +2517,10 @@ def mean(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.mean : Equivalent method for Series. - DataFrame.mean : Equivalent method for DataFrame. + pyspark.pandas.Series.expanding : Calling object with Series data. + pyspark.pandas.DataFrame.expanding : Calling object with DataFrames. + pyspark.pandas.Series.mean : Equivalent method for Series. + pyspark.pandas.DataFrame.mean : Equivalent method for DataFrame. Examples -------- @@ -2613,10 +2618,10 @@ def mean(self) -> FrameLike: See Also -------- - Series.expanding : Calling object with Series data. - DataFrame.expanding : Calling object with DataFrames. - Series.mean : Equivalent method for Series. 
- DataFrame.mean : Equivalent method for DataFrame.
+ pyspark.pandas.Series.expanding : Calling object with Series data.
+ pyspark.pandas.DataFrame.expanding : Calling object with DataFrames.
+ pyspark.pandas.Series.mean : Equivalent method for Series.
+ pyspark.pandas.DataFrame.mean : Equivalent method for DataFrame.
Examples
--------

From 5496d99241f1063766cf5954f754e870fbabcbe7 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Thu, 15 Sep 2022 21:49:56 +0900
Subject: [PATCH 10/11] [SPARK-40433][SS][PYTHON] Add toJVMRow in PythonSQLUtils to convert pickled PySpark Row to JVM Row

### What changes were proposed in this pull request?
This PR adds toJVMRow in PythonSQLUtils to convert a pickled PySpark Row to a JVM Row. Co-authored with HyukjinKwon. This is a breakdown PR of https://github.com/apache/spark/pull/37863.
### Why are the changes needed?
This change will be leveraged in [SPARK-40434](https://issues.apache.org/jira/browse/SPARK-40434).
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
N/A. We will make sure test suites are constructed in an E2E manner under [SPARK-40431](https://issues.apache.org/jira/browse/SPARK-40431).
Closes #37891 from HeartSaVioR/SPARK-40433.
Lead-authored-by: Jungtaek Lim
Co-authored-by: Hyukjin Kwon
Signed-off-by: Jungtaek Lim
---
 .../spark/sql/api/python/PythonSQLUtils.scala | 42 +++++++++++++++----
 1 file changed, 35 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 2b74bcc38501a..c495b145dc678 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -22,14 +22,15 @@ import java.net.Socket
import java.nio.channels.Channels
import java.util.Locale
-import net.razorvine.pickle.Pickler
+import net.razorvine.pickle.{Pickler, Unpickler}
import org.apache.spark.api.python.DechunkedInputStream
import org.apache.spark.internal.Logging
import org.apache.spark.security.SocketAuthServer
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -37,12 +38,29 @@ import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, StructType}
private[sql] object PythonSQLUtils extends Logging {
- private lazy val internalRowPickler = {
+ private def withInternalRowPickler(f: Pickler => Array[Byte]): Array[Byte] = {
EvaluatePython.registerPicklers()
- new Pickler(true, false)
+ val pickler = new Pickler(true, false)
+ val ret = try {
+ f(pickler)
+ } finally {
+ pickler.close()
+ }
+ ret
+ }
+
+ private def withInternalRowUnpickler(f: Unpickler => Any): Any = {
+ EvaluatePython.registerPicklers()
+ val unpickler = new Unpickler
+ val ret = try {
+ f(unpickler)
+ } finally {
+ unpickler.close()
+ }
+ ret
+ }
def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText)
@@ -94,8 +112,18 @@ private[sql] object PythonSQLUtils extends Logging {
def toPyRow(row: Row): Array[Byte] = {
assert(row.isInstanceOf[GenericRowWithSchema])
- internalRowPickler.dumps(EvaluatePython.toJava(
- CatalystTypeConverters.convertToCatalyst(row), row.schema))
+ withInternalRowPickler(_.dumps(EvaluatePython.toJava(
+ CatalystTypeConverters.convertToCatalyst(row), row.schema)))
+ }
+
+ def toJVMRow(
+ arr: Array[Byte],
+ returnType: StructType,
+ deserializer: ExpressionEncoder.Deserializer[Row]): Row = {
+ val fromJava = EvaluatePython.makeFromJava(returnType)
+ val internalRow =
+ fromJava(withInternalRowUnpickler(_.loads(arr))).asInstanceOf[InternalRow]
+ deserializer(internalRow)
+ }
+
def castTimestampNTZToLong(c: Column): Column = Column(CastTimestampNTZToLong(c.expr))

From 193b5b229c72d1a7a5cc19e7973cb5f02f54293f Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Thu, 15 Sep 2022 20:51:02 +0800
Subject: [PATCH 11/11] [SPARK-40387][SQL] Improve the implementation of Spark Decimal

### What changes were proposed in this pull request?
This PR aims to improve the implementation of Spark `Decimal`. The improvement points are as follows:
1. Use `toJavaBigDecimal` instead of `toBigDecimal.bigDecimal`.
2. Extract `longVal / POW_10(_scale)` into a new method `def actualLongVal: Long`.
3. Remove `BIG_DEC_ZERO` and use `decimalVal.signum` to check whether the value is zero.
4. Use `<` instead of `compare`.
5. Correct some code style.
### Why are the changes needed?
Improve the implementation of Spark Decimal.
### Does this PR introduce _any_ user-facing change?
No. It only updates the internal implementation.
### How was this patch tested?
N/A
Closes #37830 from beliefer/SPARK-40387.
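The improvement list above is easier to follow with Decimal's dual representation in mind: small values are held as an unscaled Long plus a scale, and only larger or irregular values fall back to a BigDecimal. The following is a minimal standalone sketch of that idea — a hypothetical CompactDecimal, not the real org.apache.spark.sql.types.Decimal — illustrating what extracting actualLongVal (point 2) and switching the zero test to signum (point 3) amount to.

```scala
// A minimal sketch, assuming a simplified two-field representation; this is NOT
// org.apache.spark.sql.types.Decimal, only an illustration of the refactoring idea.
object CompactDecimalSketch {
  // Powers of ten that fit in a Long (10^0 .. 10^18), as an unscaled-Long codec needs.
  private val POW_10 = Array.tabulate[Long](19)(i => math.pow(10, i).toLong)

  final case class CompactDecimal(longVal: Long, scale: Int, decimalVal: BigDecimal = null) {
    // Point 2: the integral part of the compact form, factored out once instead of
    // repeating `longVal / POW_10(scale)` at every call site.
    private def actualLongVal: Long = longVal / POW_10(scale)

    def toLongValue: Long =
      if (decimalVal == null) actualLongVal else decimalVal.longValue

    // Point 3: signum == 0 is true for zero at any scale, so no BigDecimal(0) constant
    // is needed for the comparison.
    def isZero: Boolean =
      if (decimalVal != null) decimalVal.signum == 0 else longVal == 0
  }

  def main(args: Array[String]): Unit = {
    println(CompactDecimal(123, 2).toLongValue)                // 1.23 stored as (123, 2) -> 1
    println(CompactDecimal(-123, 2).toLongValue)               // -1.23 -> -1 (truncates toward zero)
    println(CompactDecimal(0, 0, BigDecimal("0.000")).isZero)  // true
  }
}
```

Because signum ignores scale, a BigDecimal-backed 0.000 still reports as zero, which is why a dedicated zero constant can be dropped in this style of check.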
Authored-by: Jiaan Geng
Signed-off-by: Wenchen Fan
---
 .../org/apache/spark/sql/types/Decimal.scala | 28 +++++++++----------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 57e8fc060a291..cedf4440aabf8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -204,7 +204,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (decimalVal.ne(null)) {
decimalVal.toBigInt
} else {
- BigInt(toLong)
+ BigInt(actualLongVal)
}
}
@@ -212,7 +212,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (decimalVal.ne(null)) {
decimalVal.underlying().toBigInteger()
} else {
- java.math.BigInteger.valueOf(toLong)
+ java.math.BigInteger.valueOf(actualLongVal)
}
}
@@ -226,7 +226,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
override def toString: String = toBigDecimal.toString()
- def toPlainString: String = toBigDecimal.bigDecimal.toPlainString
+ def toPlainString: String = toJavaBigDecimal.toPlainString
def toDebugString: String = {
if (decimalVal.ne(null)) {
@@ -240,9 +240,11 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def toFloat: Float = toBigDecimal.floatValue
+ private def actualLongVal: Long = longVal / POW_10(_scale)
+
def toLong: Long = {
if (decimalVal.eq(null)) {
- longVal / POW_10(_scale)
+ actualLongVal
} else {
decimalVal.longValue
}
@@ -278,7 +280,6 @@ final class Decimal extends Ordered[Decimal] with Serializable {
private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int)
(f1: Long => T) (f2: Double => T): T = {
if (decimalVal.eq(null)) {
- val actualLongVal = longVal / POW_10(_scale)
val numericVal = f1(actualLongVal)
if (actualLongVal == numericVal) {
numericVal
@@ -303,7 +304,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
*/
private[sql] def roundToLong(): Long = {
if (decimalVal.eq(null)) {
- longVal / POW_10(_scale)
+ actualLongVal
} else {
try {
// We cannot store Long.MAX_VALUE as a Double without losing precision.
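Before the arithmetic hunks that follow, a short aside on point 1 (preferring toJavaBigDecimal over toBigDecimal.bigDecimal): when a java.math.BigDecimal is what the caller ultimately needs, routing through the Scala wrapper builds an object only to unwrap it again. A hedged, standalone illustration of the pattern — plain JDK and Scala standard library, no Spark types:

```scala
import java.math.{BigDecimal => JavaBigDecimal}

// Not Spark code: a standalone illustration of why `toJavaBigDecimal` is preferable to
// `toBigDecimal.bigDecimal` when the Java value is what the caller needs anyway.
object ToJavaBigDecimalSketch {
  def main(args: Array[String]): Unit = {
    val javaValue = new JavaBigDecimal("12.50")

    // Round trip through the Scala wrapper -- roughly what `toBigDecimal.bigDecimal` amounts to.
    val viaWrapper: JavaBigDecimal = scala.math.BigDecimal(javaValue).bigDecimal

    // Direct use of the Java value -- the `toJavaBigDecimal` path.
    val direct: JavaBigDecimal = javaValue

    println(viaWrapper == direct) // true: same value, one fewer wrapper allocation
  }
}
```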
@@ -455,7 +456,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
override def hashCode(): Int = toBigDecimal.hashCode()
- def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
+ def isZero: Boolean = if (decimalVal.ne(null)) decimalVal.signum == 0 else longVal == 0
// We should follow DecimalPrecision promote if use longVal for add and subtract:
// Operation Result Precision Result Scale
@@ -466,7 +467,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
Decimal(longVal + that.longVal, Math.max(precision, that.precision) + 1, scale)
} else {
- Decimal(toBigDecimal.bigDecimal.add(that.toBigDecimal.bigDecimal))
+ Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal))
}
}
@@ -474,7 +475,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
Decimal(longVal - that.longVal, Math.max(precision, that.precision) + 1, scale)
} else {
- Decimal(toBigDecimal.bigDecimal.subtract(that.toBigDecimal.bigDecimal))
+ Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal))
}
}
@@ -504,7 +505,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}
- def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
+ def abs: Decimal = if (this < Decimal.ZERO) this.unary_- else this
def floor: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
@@ -532,8 +533,6 @@ object Decimal {
val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
- private val BIG_DEC_ZERO = BigDecimal(0)
-
private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
private[sql] val ZERO = Decimal(0)
@@ -575,9 +574,8 @@ object Decimal {
}
}
- private def numDigitsInIntegralPart(bigDecimal: JavaBigDecimal): Int = {
- bigDecimal.precision - bigDecimal.scale
- }
+ private def numDigitsInIntegralPart(bigDecimal: JavaBigDecimal): Int =
+ bigDecimal.precision - bigDecimal.scale
private def stringToJavaBigDecimal(str: UTF8String): JavaBigDecimal = {
// According the benchmark test, `s.toString.trim` is much faster than `s.trim.toString`.
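The last hunk above collapses numDigitsInIntegralPart into a single expression; the arithmetic is just precision minus scale on a java.math.BigDecimal, which the string-parsing path can use to spot values whose integral part is too wide. A small standalone check of that arithmetic — plain JDK API and a hypothetical object name, no Spark types:

```scala
import java.math.{BigDecimal => JavaBigDecimal}

// Standalone check of the precision-minus-scale arithmetic behind numDigitsInIntegralPart;
// the object name is hypothetical and nothing here depends on Spark.
object IntegralDigitsSketch {
  private def numDigitsInIntegralPart(bd: JavaBigDecimal): Int = bd.precision - bd.scale

  def main(args: Array[String]): Unit = {
    println(numDigitsInIntegralPart(new JavaBigDecimal("123.45"))) // precision 5 - scale 2 = 3
    println(numDigitsInIntegralPart(new JavaBigDecimal("12345")))  // 5 - 0 = 5
    println(numDigitsInIntegralPart(new JavaBigDecimal("0.001")))  // 1 - 3 = -2 (no integral digits)
  }
}
```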