From 4b7d91b76333baa308ef159a118ee21e09c5ac78 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Thu, 8 Sep 2022 19:46:44 +0800
Subject: [PATCH 1/2] init support

multi index

refine tests

refine tests

simplify

simplify

simplify
---
 python/pyspark/pandas/frame.py            | 203 +++++++++++++++++++++-
 python/pyspark/pandas/tests/test_stats.py |  26 +++
 2 files changed, 224 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 2a7fda2d527c0..badab5ff1af40 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -1417,15 +1417,23 @@ def aggregate(self, func: Union[List[str], Dict[Name, List[str]]]) -> "DataFrame
 
     agg = aggregate
 
-    def corr(self, method: str = "pearson") -> "DataFrame":
+    def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "DataFrame":
         """
         Compute pairwise correlation of columns, excluding NA/null values.
 
+        .. versionadded:: 3.3.0
+
         Parameters
         ----------
         method : {'pearson', 'spearman'}
             * pearson : standard correlation coefficient
             * spearman : Spearman rank correlation
+        min_periods : int, optional
+            Minimum number of observations required per pair of columns
+            to have a valid result. Currently only available for Pearson
+            correlation.
+
+            .. versionadded:: 3.4.0
 
         Returns
         -------
@@ -1454,11 +1462,196 @@ def corr(self, method: str = "pearson") -> "DataFrame":
         There are behavior differences between pandas-on-Spark and pandas.
 
         * the `method` argument only accepts 'pearson', 'spearman'
-        * the data should not contain NaNs. pandas-on-Spark will return an error.
-        * pandas-on-Spark doesn't support the following argument(s).
+        * if the `method` is `spearman`, the data should not contain NaNs.
+        * if the `method` is `spearman`, the `min_periods` argument is not supported.
+        """
+        if method not in ["pearson", "spearman", "kendall"]:
+            raise ValueError(f"Invalid method {method}")
+        if method == "kendall":
+            raise NotImplementedError("method doesn't support kendall for now")
+        if min_periods is not None and not isinstance(min_periods, int):
+            raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}")
+        if min_periods is not None and method == "spearman":
+            raise NotImplementedError("min_periods doesn't support spearman for now")
+
+        if method == "pearson":
+            min_periods = 1 if min_periods is None else min_periods
+            internal = self._internal.resolved_copy
+            numeric_labels = [
+                label
+                for label in internal.column_labels
+                if isinstance(internal.spark_type_for(label), (NumericType, BooleanType))
+            ]
+            numeric_scols: List[Column] = [
+                internal.spark_column_for(label).cast("double") for label in numeric_labels
+            ]
+            numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels]
+            num_scols = len(numeric_scols)
+
+            sdf = internal.spark_frame
+            tmp_index_1_col = verify_temp_column_name(sdf, "__tmp_index_1_col__")
+            tmp_index_2_col = verify_temp_column_name(sdf, "__tmp_index_2_col__")
+            tmp_value_1_col = verify_temp_column_name(sdf, "__tmp_value_1_col__")
+            tmp_value_2_col = verify_temp_column_name(sdf, "__tmp_value_2_col__")
+
+            # simple dataset
+            # +---+---+----+
+            # |  A|  B|   C|
+            # +---+---+----+
+            # |  1|  2| 3.0|
+            # |  4|  1|null|
+            # +---+---+----+
+
+            pair_scols: List[Column] = []
+            for i in range(0, num_scols):
+                for j in range(i, num_scols):
+                    pair_scols.append(
+                        F.struct(
+                            F.lit(i).alias(tmp_index_1_col),
+                            F.lit(j).alias(tmp_index_2_col),
+                            numeric_scols[i].alias(tmp_value_1_col),
+                            numeric_scols[j].alias(tmp_value_2_col),
+                        )
+                    )
+
+            # +-------------------+-------------------+-------------------+-------------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__|
+            # +-------------------+-------------------+-------------------+-------------------+
+            # |                  0|                  0|                1.0|                1.0|
+            # |                  0|                  1|                1.0|                2.0|
+            # |                  0|                  2|                1.0|                3.0|
+            # |                  1|                  1|                2.0|                2.0|
+            # |                  1|                  2|                2.0|                3.0|
+            # |                  2|                  2|                3.0|                3.0|
+            # |                  0|                  0|                4.0|                4.0|
+            # |                  0|                  1|                4.0|                1.0|
+            # |                  0|                  2|                4.0|               null|
+            # |                  1|                  1|                1.0|                1.0|
+            # |                  1|                  2|                1.0|               null|
+            # |                  2|                  2|               null|               null|
+            # +-------------------+-------------------+-------------------+-------------------+
+            tmp_tuple_col = verify_temp_column_name(sdf, "__tmp_tuple_col__")
+            sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col)).select(
+                F.col(f"{tmp_tuple_col}.{tmp_index_1_col}").alias(tmp_index_1_col),
+                F.col(f"{tmp_tuple_col}.{tmp_index_2_col}").alias(tmp_index_2_col),
+                F.col(f"{tmp_tuple_col}.{tmp_value_1_col}").alias(tmp_value_1_col),
+                F.col(f"{tmp_tuple_col}.{tmp_value_2_col}").alias(tmp_value_2_col),
+            )
+
+            # +-------------------+-------------------+------------------------+-----------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|__tmp_count_col__|
+            # +-------------------+-------------------+------------------------+-----------------+
+            # |                  2|                  2|                    null|                1|
+            # |                  1|                  2|                    null|                1|
+            # |                  1|                  1|                     1.0|                2|
+            # |                  0|                  0|                     1.0|                2|
+            # |                  0|                  1|                    -1.0|                2|
+            # |                  0|                  2|                    null|                1|
+            # +-------------------+-------------------+------------------------+-----------------+
+            tmp_corr_col = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__")
+            tmp_count_col = verify_temp_column_name(sdf, "__tmp_count_col__")
+            sdf = sdf.groupby(tmp_index_1_col, tmp_index_2_col).agg(
+                F.corr(tmp_value_1_col, tmp_value_2_col).alias(tmp_corr_col),
+                F.count(
+                    F.when(
+                        F.col(tmp_value_1_col).isNotNull() & F.col(tmp_value_2_col).isNotNull(), 1
+                    )
+                ).alias(tmp_count_col),
+            )
+
+            # +-------------------+-------------------+------------------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|
+            # +-------------------+-------------------+------------------------+
+            # |                  2|                  2|                    null|
+            # |                  1|                  2|                    null|
+            # |                  2|                  1|                    null|
+            # |                  1|                  1|                     1.0|
+            # |                  0|                  0|                     1.0|
+            # |                  0|                  1|                    -1.0|
+            # |                  1|                  0|                    -1.0|
+            # |                  0|                  2|                    null|
+            # |                  2|                  0|                    null|
+            # +-------------------+-------------------+------------------------+
+            sdf = (
+                sdf.withColumn(
+                    tmp_corr_col,
+                    F.when(F.col(tmp_count_col) >= min_periods, F.col(tmp_corr_col)).otherwise(
+                        F.lit(None)
+                    ),
+                )
+                .withColumn(
+                    tmp_tuple_col,
+                    F.explode(
+                        F.when(
+                            F.col(tmp_index_1_col) == F.col(tmp_index_2_col),
+                            F.lit([0]),
+                        ).otherwise(F.lit([0, 1]))
+                    ),
+                )
+                .select(
+                    F.when(F.col(tmp_tuple_col) == 0, F.col(tmp_index_1_col))
+                    .otherwise(F.col(tmp_index_2_col))
+                    .alias(tmp_index_1_col),
+                    F.when(F.col(tmp_tuple_col) == 0, F.col(tmp_index_2_col))
+                    .otherwise(F.col(tmp_index_1_col))
+                    .alias(tmp_index_2_col),
+                    F.col(tmp_corr_col),
+                )
+            )
+
+            # +-------------------+--------------------+
+            # |__tmp_index_1_col__|   __tmp_array_col__|
+            # +-------------------+--------------------+
+            # |                  0|[{0, 1.0}, {1, -1...|
+            # |                  1|[{0, -1.0}, {1, 1...|
+            # |                  2|[{0, null}, {1, n...|
+            # +-------------------+--------------------+
+            tmp_array_col = verify_temp_column_name(sdf, "__tmp_array_col__")
+            sdf = (
+                sdf.groupby(tmp_index_1_col)
+                .agg(
+                    F.array_sort(
+                        F.collect_list(F.struct(F.col(tmp_index_2_col), F.col(tmp_corr_col)))
+                    ).alias(tmp_array_col)
+                )
+                .orderBy(tmp_index_1_col)
+            )
+
+            for i in range(0, num_scols):
+                sdf = sdf.withColumn(tmp_tuple_col, F.get(F.col(tmp_array_col), i)).withColumn(
+                    numeric_col_names[i],
+                    F.col(f"{tmp_tuple_col}.{tmp_corr_col}"),
+                )
+
+            index_col_names: List[str] = []
+            if internal.column_labels_level > 1:
+                for level in range(0, internal.column_labels_level):
+                    index_col_name = SPARK_INDEX_NAME_FORMAT(level)
+                    indices = [label[level] for label in numeric_labels]
+                    sdf = sdf.withColumn(
+                        index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col))
+                    )
+                    index_col_names.append(index_col_name)
+            else:
+                sdf = sdf.withColumn(
+                    SPARK_DEFAULT_INDEX_NAME,
+                    F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col)),
+                )
+                index_col_names = [SPARK_DEFAULT_INDEX_NAME]
+
+            sdf = sdf.select(*index_col_names, *numeric_col_names)
+
+            return DataFrame(
+                InternalFrame(
+                    spark_frame=sdf,
+                    index_spark_columns=[
+                        scol_for(sdf, index_col_name) for index_col_name in index_col_names
+                    ],
+                    column_labels=numeric_labels,
+                    column_label_names=internal.column_label_names,
+                )
+            )
 
-          * `min_periods` argument is not supported
-        """
         return cast(DataFrame, ps.from_pandas(corr(self, method)))
 
     # TODO: add axis parameter and support more methods
diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py
index e8f5048033b47..9d51244adac64 100644
--- a/python/pyspark/pandas/tests/test_stats.py
+++ b/python/pyspark/pandas/tests/test_stats.py
@@ -257,6 +257,32 @@ def test_skew_kurt_numerical_stability(self):
         self.assert_eq(psdf.skew(), pdf.skew(), almost=True)
         self.assert_eq(psdf.kurt(), pdf.kurt(), almost=True)
 
+    def test_dataframe_corr(self):
+        pdf = makeMissingDataframe(0.3, 42)
+        psdf = ps.from_pandas(pdf)
+
+        with self.assertRaisesRegex(ValueError, "Invalid method"):
+            psdf.corr("std")
+        with self.assertRaisesRegex(NotImplementedError, "kendall for now"):
+            psdf.corr("kendall")
+        with self.assertRaisesRegex(TypeError, "Invalid min_periods type"):
+            psdf.corr(min_periods="3")
+        with self.assertRaisesRegex(NotImplementedError, "spearman for now"):
+            psdf.corr(method="spearman", min_periods=3)
+
+        self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+
+        # multi-index columns
+        columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Z", "D")])
+        pdf.columns = columns
+        psdf.columns = columns
+
+        self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
+        self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+
     def test_corr(self):
         # Disable arrow execution since corr() is using UDT internally which is not supported.
         with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):

From 6c31ce56293584761a82ab136258a0851906ac17 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 13 Sep 2022 12:56:49 +0800
Subject: [PATCH 2/2] address comments

---
 python/pyspark/pandas/frame.py            | 90 ++++++++++++-----------
 python/pyspark/pandas/tests/test_stats.py |  8 ++
 2 files changed, 56 insertions(+), 42 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index badab5ff1af40..b85247580c5ff 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -1489,10 +1489,10 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
             num_scols = len(numeric_scols)
 
             sdf = internal.spark_frame
-            tmp_index_1_col = verify_temp_column_name(sdf, "__tmp_index_1_col__")
-            tmp_index_2_col = verify_temp_column_name(sdf, "__tmp_index_2_col__")
-            tmp_value_1_col = verify_temp_column_name(sdf, "__tmp_value_1_col__")
-            tmp_value_2_col = verify_temp_column_name(sdf, "__tmp_value_2_col__")
+            tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__")
+            tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__")
+            tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__")
+            tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__")
 
             # simple dataset
             # +---+---+----+
@@ -1507,10 +1507,10 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
                 for j in range(i, num_scols):
                     pair_scols.append(
                         F.struct(
-                            F.lit(i).alias(tmp_index_1_col),
-                            F.lit(j).alias(tmp_index_2_col),
-                            numeric_scols[i].alias(tmp_value_1_col),
-                            numeric_scols[j].alias(tmp_value_2_col),
+                            F.lit(i).alias(tmp_index_1_col_name),
+                            F.lit(j).alias(tmp_index_2_col_name),
+                            numeric_scols[i].alias(tmp_value_1_col_name),
+                            numeric_scols[j].alias(tmp_value_2_col_name),
                         )
                     )
 
@@ -1530,12 +1530,12 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
             # |                  1|                  2|                1.0|               null|
             # |                  2|                  2|               null|               null|
             # +-------------------+-------------------+-------------------+-------------------+
-            tmp_tuple_col = verify_temp_column_name(sdf, "__tmp_tuple_col__")
-            sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col)).select(
-                F.col(f"{tmp_tuple_col}.{tmp_index_1_col}").alias(tmp_index_1_col),
-                F.col(f"{tmp_tuple_col}.{tmp_index_2_col}").alias(tmp_index_2_col),
-                F.col(f"{tmp_tuple_col}.{tmp_value_1_col}").alias(tmp_value_1_col),
-                F.col(f"{tmp_tuple_col}.{tmp_value_2_col}").alias(tmp_value_2_col),
+            tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__")
+            sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select(
+                F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name),
+                F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name),
+                F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}").alias(tmp_value_1_col_name),
+                F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}").alias(tmp_value_2_col_name),
             )
 
@@ -1548,15 +1548,17 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
             # |                  0|                  1|                    -1.0|                2|
             # |                  0|                  2|                    null|                1|
             # +-------------------+-------------------+------------------------+-----------------+
-            tmp_corr_col = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__")
-            tmp_count_col = verify_temp_column_name(sdf, "__tmp_count_col__")
-            sdf = sdf.groupby(tmp_index_1_col, tmp_index_2_col).agg(
-                F.corr(tmp_value_1_col, tmp_value_2_col).alias(tmp_corr_col),
+            tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__")
+            tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__")
+            sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg(
+                F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name),
                 F.count(
                     F.when(
-                        F.col(tmp_value_1_col).isNotNull() & F.col(tmp_value_2_col).isNotNull(), 1
+                        F.col(tmp_value_1_col_name).isNotNull()
+                        & F.col(tmp_value_2_col_name).isNotNull(),
+                        1,
                     )
-                ).alias(tmp_count_col),
+                ).alias(tmp_count_col_name),
             )
 
@@ -1574,28 +1576,28 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
             sdf = (
                 sdf.withColumn(
-                    tmp_corr_col,
-                    F.when(F.col(tmp_count_col) >= min_periods, F.col(tmp_corr_col)).otherwise(
-                        F.lit(None)
-                    ),
+                    tmp_corr_col_name,
+                    F.when(
+                        F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name)
+                    ).otherwise(F.lit(None)),
                 )
                 .withColumn(
-                    tmp_tuple_col,
+                    tmp_tuple_col_name,
                     F.explode(
                         F.when(
-                            F.col(tmp_index_1_col) == F.col(tmp_index_2_col),
+                            F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name),
                             F.lit([0]),
                         ).otherwise(F.lit([0, 1]))
                     ),
                 )
                 .select(
-                    F.when(F.col(tmp_tuple_col) == 0, F.col(tmp_index_1_col))
-                    .otherwise(F.col(tmp_index_2_col))
-                    .alias(tmp_index_1_col),
-                    F.when(F.col(tmp_tuple_col) == 0, F.col(tmp_index_2_col))
-                    .otherwise(F.col(tmp_index_1_col))
-                    .alias(tmp_index_2_col),
-                    F.col(tmp_corr_col),
+                    F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name))
+                    .otherwise(F.col(tmp_index_2_col_name))
+                    .alias(tmp_index_1_col_name),
+                    F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name))
+                    .otherwise(F.col(tmp_index_1_col_name))
+                    .alias(tmp_index_2_col_name),
+                    F.col(tmp_corr_col_name),
                 )
             )
 
@@ -1606,21 +1608,25 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
             # |                  1|[{0, -1.0}, {1, 1...|
             # |                  2|[{0, null}, {1, n...|
             # +-------------------+--------------------+
-            tmp_array_col = verify_temp_column_name(sdf, "__tmp_array_col__")
+            tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__")
             sdf = (
-                sdf.groupby(tmp_index_1_col)
+                sdf.groupby(tmp_index_1_col_name)
                 .agg(
                     F.array_sort(
-                        F.collect_list(F.struct(F.col(tmp_index_2_col), F.col(tmp_corr_col)))
-                    ).alias(tmp_array_col)
+                        F.collect_list(
+                            F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name))
+                        )
+                    ).alias(tmp_array_col_name)
                 )
-                .orderBy(tmp_index_1_col)
+                .orderBy(tmp_index_1_col_name)
             )
 
             for i in range(0, num_scols):
-                sdf = sdf.withColumn(tmp_tuple_col, F.get(F.col(tmp_array_col), i)).withColumn(
+                sdf = sdf.withColumn(
+                    tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i)
+                ).withColumn(
                     numeric_col_names[i],
-                    F.col(f"{tmp_tuple_col}.{tmp_corr_col}"),
+                    F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"),
                 )
 
             index_col_names: List[str] = []
@@ -1629,13 +1635,13 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D
                     index_col_name = SPARK_INDEX_NAME_FORMAT(level)
                     indices = [label[level] for label in numeric_labels]
                     sdf = sdf.withColumn(
-                        index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col))
+                        index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name))
                     )
                     index_col_names.append(index_col_name)
             else:
                 sdf = sdf.withColumn(
                     SPARK_DEFAULT_INDEX_NAME,
-                    F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col)),
+                    F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)),
                 )
                 index_col_names = [SPARK_DEFAULT_INDEX_NAME]
 
diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py
index 9d51244adac64..7e2ca96e60ff1 100644
--- a/python/pyspark/pandas/tests/test_stats.py
+++ b/python/pyspark/pandas/tests/test_stats.py
@@ -258,6 +258,8 @@ def test_skew_kurt_numerical_stability(self):
         self.assert_eq(psdf.kurt(), pdf.kurt(), almost=True)
 
     def test_dataframe_corr(self):
+        # the existing 'test_corr' mixes df.corr and ser.corr; it will be removed
+        # once there are separate tests for df.corr and ser.corr
         pdf = makeMissingDataframe(0.3, 42)
         psdf = ps.from_pandas(pdf)
 
@@ -273,6 +275,9 @@ def test_dataframe_corr(self):
         self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
         self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
         self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+        self.assert_eq(
+            (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
+        )
 
         # multi-index columns
         columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Z", "D")])
@@ -282,6 +287,9 @@ def test_dataframe_corr(self):
         self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False)
         self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False)
         self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False)
+        self.assert_eq(
+            (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False
+        )
 
     def test_corr(self):
         # Disable arrow execution since corr() is using UDT internally which is not supported.
         with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
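
Taken out of the diff, the core of the pearson branch reduces to the following standalone
sketch: build one struct per upper-triangle column pair, explode the structs into long
format, then aggregate F.corr together with a count of complete observations so that
min_periods can null out weakly supported coefficients. This is an illustration, not part
of the patch: the toy frame mirrors the "simple dataset" in the code comments, and the
short names (i, j, x, y, stats) plus the min_periods value are hypothetical stand-ins.

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()

    # The same "simple dataset" as in the patch comments.
    sdf = spark.createDataFrame([(1, 2, 3.0), (4, 1, None)], ["A", "B", "C"])

    cols = ["A", "B", "C"]

    # One struct per upper-triangle pair (i <= j) of columns, built for every row.
    pair_scols = [
        F.struct(
            F.lit(i).alias("i"),
            F.lit(j).alias("j"),
            F.col(cols[i]).cast("double").alias("x"),
            F.col(cols[j]).cast("double").alias("y"),
        )
        for i in range(len(cols))
        for j in range(i, len(cols))
    ]

    # Explode to one row per (input row, column pair), then aggregate per pair:
    # the Pearson coefficient plus the number of complete (both non-null) pairs.
    stats = (
        sdf.select(F.explode(F.array(*pair_scols)).alias("p"))
        .select("p.*")
        .groupby("i", "j")
        .agg(
            F.corr("x", "y").alias("corr"),
            F.count(F.when(F.col("x").isNotNull() & F.col("y").isNotNull(), 1)).alias("n"),
        )
    )

    # Null out coefficients backed by fewer than min_periods complete observations;
    # F.when(...) without otherwise() yields null for the non-matching rows.
    min_periods = 2
    stats = stats.withColumn("corr", F.when(F.col("n") >= min_periods, F.col("corr")))
    stats.orderBy("i", "j").show()

The struct/explode shape is what lets every pairwise coefficient come out of a single
aggregation over one exploded DataFrame, rather than one Spark job per column pair.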
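
At the user level, the behavior the new tests pin down looks like this (a minimal usage
sketch, assuming a running Spark session; the tiny frame and the threshold are
illustrative):

    import numpy as np
    import pandas as pd
    import pyspark.pandas as ps

    pdf = pd.DataFrame(
        {
            "A": [1.0, 4.0, 2.0],
            "B": [2.0, 1.0, np.nan],
            "C": [3.0, np.nan, np.nan],
        }
    )
    psdf = ps.from_pandas(pdf)

    # (A, C) shares only one complete row, so with min_periods=2 its coefficient
    # becomes NaN; (A, B) has two complete rows and keeps its value.
    print(psdf.corr(method="pearson", min_periods=2))

    # As the new tests assert (with check_exact=False), the result matches pandas
    # up to floating-point differences.
    print(pdf.corr(method="pearson", min_periods=2))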