From 1013ca8a82075b7619bc91af5c70dd8b5370d8a8 Mon Sep 17 00:00:00 2001
From: John Fieldman <1295319+yozzo@users.noreply.github.com>
Date: Tue, 29 Oct 2024 13:40:44 +0000
Subject: [PATCH 1/3] Fix comparison report when one column is all NAs (#343)
* Fix compariosn report when one column is all NAs
With situations when one of the column in the comparison was
all NA's then this would break the reporting. For some reason
when the matching (boolean) of the actual and expected columns happened when there was a categorical value compared to a NA, the result was NA, rather than False, as it would happen for the other elements in the cols compared.
This has now been addressed at the intersect rows level which doesn't seem to break the reporting anymore.
* Fix column_equal to work with StringArrays with pd.NA values not returning booleans
* Add test for fn column_equal to work with StringArrays with pd.NA
Add test for fn column_equal to work with StringArrays containing pd.NA values not returning booleans when compared with other df's with rows of StringArrays
* Fix linter error
Printing out the report would've been useful for this test, but looks like it makes the linter fail the build. This has now been fixed.
* Fix column_equal to work with StringArrays with pd.NA values
Fix column_equal to work with StringArrays with pd.NA values not returning booleans, and update formatting to match the linter expectation
* Add test for fn column_equal to work with StringArrays with pd.NA
Add test for fn column_equal to work with StringArrays containing pd.NA values not returning booleans when compared with other df's with rows of StringArrays, and format test to match the linter.
---
datacompy/core.py | 4 +++-
tests/test_core.py | 32 ++++++++++++++++++++++++++++++++
2 files changed, 35 insertions(+), 1 deletion(-)
diff --git a/datacompy/core.py b/datacompy/core.py
index 889fb901..7d0da3d8 100644
--- a/datacompy/core.py
+++ b/datacompy/core.py
@@ -799,6 +799,7 @@ def columns_equal(
A series of Boolean values. True == the values match, False == the
values don't match.
"""
+ default_value = "DATACOMPY_NULL"
compare: pd.Series[bool]
# short circuit if comparing mixed type columns. We don't want to support this moving forward.
@@ -842,7 +843,8 @@ def columns_equal(
compare = compare_string_and_date_columns(col_1, col_2)
else:
compare = pd.Series(
- (col_1 == col_2) | (col_1.isnull() & col_2.isnull())
+ (col_1.fillna(default_value) == col_2.fillna(default_value))
+ | (col_1.isnull() & col_2.isnull())
)
except Exception:
# Blanket exception should just return all False
diff --git a/tests/test_core.py b/tests/test_core.py
index 482a12f4..b0c3647f 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -958,6 +958,38 @@ def test_sample_mismatch():
assert (output.name_df1 != output.name_df2).all()
+def test_sample_mismatch_with_nans():
+ """Checks that comparison of StringArrays with pd.NA values returns booleans
+
+ When comparing pd.NA with a string the result is pd.NA, this breaks the compare
+ report with the following error:
+ "E ValueError: a must be greater than 0 unless no samples are taken"
+
+ Dataframes with StringArray type rows come when using pd.Dataframes created from
+ parquet files using the pyarrow engine.
+ """
+ df1 = pd.DataFrame(
+ {
+ "acct_id": [10000001221, 10000001222, 10000001223],
+ "name": pd.array([pd.NA, pd.NA, pd.NA], dtype="string"),
+ }
+ )
+ df1.set_index("acct_id", inplace=True)
+
+ df2 = pd.DataFrame(
+ {
+ "acct_id": [10000001221, 10000001222, 10000001223],
+ "name": pd.array([pd.NA, "Tobias Funke", pd.NA], dtype="string"),
+ }
+ )
+
+ df2.set_index("acct_id", inplace=True)
+
+ report = datacompy.Compare(df1=df1, df2=df2, on_index=True).report()
+
+ assert "Tobias Funke" in report
+
+
def test_all_mismatch_not_ignore_matching_cols_no_cols_matching():
data1 = """acct_id,dollar_amt,name,float_fld,date_fld
10000001234,123.45,George Maharis,14530.1555,2017-01-01
From 1ea649ae9ca134c68225193f7fb64c0da251a9e2 Mon Sep 17 00:00:00 2001
From: rhaffar <141745338+rhaffar@users.noreply.github.com>
Date: Tue, 29 Oct 2024 17:01:22 -0400
Subject: [PATCH 2/3] Adding SnowflakeCompare (Snowflake/Snowpark compare)
(#333)
* adding snowflake/snowpark compare
* undo change
* add partial case sensitive support
* update doc
* pr comments
* remainder PR comments
* update readme
* mocking abs, trim
* remove local testing, fix text
* ignore snowflake tests
* catch snowpark import in test config
* python 3.12 actions without snowflake
* clean
* fix
* catch snowflake imports, fix snowflake type annotations
* conditional install
* conditional install
* conditional install
* fix conditional
---
.github/workflows/test-package.yml | 18 +-
README.md | 2 +
datacompy/__init__.py | 2 +
datacompy/snowflake.py | 1202 +++++++++++++++++++++++++
docs/source/index.rst | 3 +-
docs/source/snowflake_usage.rst | 268 ++++++
pyproject.toml | 5 +-
tests/conftest.py | 24 +
tests/test_snowflake.py | 1326 ++++++++++++++++++++++++++++
9 files changed, 2844 insertions(+), 6 deletions(-)
create mode 100644 datacompy/snowflake.py
create mode 100644 docs/source/snowflake_usage.rst
create mode 100644 tests/conftest.py
create mode 100644 tests/test_snowflake.py
diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml
index 26b18def..dc50777d 100644
--- a/.github/workflows/test-package.yml
+++ b/.github/workflows/test-package.yml
@@ -64,17 +64,27 @@ jobs:
java-version: '8'
distribution: 'adopt'
- - name: Install Spark and datacompy
+ - name: Install Spark, Pandas, and Numpy
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-spark pypandoc
python -m pip install pyspark[connect]==${{ matrix.spark-version }}
python -m pip install pandas==${{ matrix.pandas-version }}
python -m pip install numpy==${{ matrix.numpy-version }}
+
+ - name: Install Datacompy without Snowflake/Snowpark if Python 3.12
+ if: ${{ matrix.python-version == '3.12' }}
+ run: |
+ python -m pip install .[dev_no_snowflake]
+
+ - name: Install Datacompy with all dev dependencies if Python 3.9, 3.10, or 3.11
+ if: ${{ matrix.python-version != '3.12' }}
+ run: |
python -m pip install .[dev]
+
- name: Test with pytest
run: |
- python -m pytest tests/
+ python -m pytest tests/ --ignore=tests/test_snowflake.py
test-bare-install:
@@ -101,7 +111,7 @@ jobs:
python -m pip install .[tests]
- name: Test with pytest
run: |
- python -m pytest tests/
+ python -m pytest tests/ --ignore=tests/test_snowflake.py
test-fugue-install-no-spark:
@@ -127,4 +137,4 @@ jobs:
python -m pip install .[tests,duckdb,polars,dask,ray]
- name: Test with pytest
run: |
- python -m pytest tests/
+ python -m pytest tests/ --ignore=tests/test_snowflake.py
diff --git a/README.md b/README.md
index b60457ef..cf6070df 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,7 @@ pip install datacompy[spark]
pip install datacompy[dask]
pip install datacompy[duckdb]
pip install datacompy[ray]
+pip install datacompy[snowflake]
```
@@ -95,6 +96,7 @@ with the Pandas on Spark implementation. Spark plans to support Pandas 2 in [Spa
- Pandas: ([See documentation](https://capitalone.github.io/datacompy/pandas_usage.html))
- Spark: ([See documentation](https://capitalone.github.io/datacompy/spark_usage.html))
- Polars: ([See documentation](https://capitalone.github.io/datacompy/polars_usage.html))
+- Snowflake/Snowpark: ([See documentation](https://capitalone.github.io/datacompy/snowflake_usage.html))
- Fugue is a Python library that provides a unified interface for data processing on Pandas, DuckDB, Polars, Arrow,
Spark, Dask, Ray, and many other backends. DataComPy integrates with Fugue to provide a simple way to compare data
across these backends. Please note that Fugue will use the Pandas (Native) logic at its lowest level
diff --git a/datacompy/__init__.py b/datacompy/__init__.py
index a6d331e8..74154839 100644
--- a/datacompy/__init__.py
+++ b/datacompy/__init__.py
@@ -43,12 +43,14 @@
unq_columns,
)
from datacompy.polars import PolarsCompare
+from datacompy.snowflake import SnowflakeCompare
from datacompy.spark.sql import SparkSQLCompare
__all__ = [
"BaseCompare",
"Compare",
"PolarsCompare",
+ "SnowflakeCompare",
"SparkSQLCompare",
"all_columns_match",
"all_rows_overlap",
diff --git a/datacompy/snowflake.py b/datacompy/snowflake.py
new file mode 100644
index 00000000..19f63978
--- /dev/null
+++ b/datacompy/snowflake.py
@@ -0,0 +1,1202 @@
+#
+# Copyright 2024 Capital One Services, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Compare two Snowpark SQL DataFrames and Snowflake tables.
+
+Originally this package was meant to provide similar functionality to
+PROC COMPARE in SAS - i.e. human-readable reporting on the difference between
+two dataframes.
+"""
+
+import logging
+import os
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Union, cast
+
+import pandas as pd
+from ordered_set import OrderedSet
+
+try:
+ import snowflake.snowpark as sp
+ from snowflake.snowpark import Window
+ from snowflake.snowpark.exceptions import SnowparkSQLException
+ from snowflake.snowpark.functions import (
+ abs,
+ col,
+ concat,
+ contains,
+ is_null,
+ lit,
+ monotonically_increasing_id,
+ row_number,
+ trim,
+ when,
+ )
+except ImportError:
+ pass # for non-snowflake users
+from datacompy.base import BaseCompare
+from datacompy.spark.sql import decimal_comparator
+
+LOG = logging.getLogger(__name__)
+
+
+NUMERIC_SNOWPARK_TYPES = [
+ "tinyint",
+ "smallint",
+ "int",
+ "bigint",
+ "float",
+ "double",
+ decimal_comparator(),
+]
+
+
+class SnowflakeCompare(BaseCompare):
+ """Comparison class to be used to compare whether two Snowpark dataframes are equal.
+
+ df1 and df2 can refer to either a Snowpark dataframe or the name of a valid Snowflake table.
+ The data structures which df1 and df2 represent must contain all of the join_columns,
+ with unique column names. Differences between values are compared to
+ abs_tol + rel_tol * abs(df2['value']).
+
+ Parameters
+ ----------
+ session: snowflake.snowpark.session
+ Session with the required connection session info for user and targeted tables
+ df1 : Union[str, sp.Dataframe]
+ First table to check, provided either as the table's name or as a Snowpark DF.
+ df2 : Union[str, sp.Dataframe]
+ Second table to check, provided either as the table's name or as a Snowpark DF.
+ join_columns : list or str, optional
+ Column(s) to join dataframes on. If a string is passed in, that one
+ column will be used.
+ abs_tol : float, optional
+ Absolute tolerance between two values.
+ rel_tol : float, optional
+ Relative tolerance between two values.
+ df1_name : str, optional
+ A string name for the first dataframe. If used alongside a snowflake table,
+ overrides the default convention of naming the dataframe after the table.
+ df2_name : str, optional
+ A string name for the second dataframe.
+ ignore_spaces : bool, optional
+ Flag to strip whitespace (including newlines) from string columns (including any join
+ columns).
+
+ Attributes
+ ----------
+ df1_unq_rows : sp.DataFrame
+ All records that are only in df1 (based on a join on join_columns)
+ df2_unq_rows : sp.DataFrame
+ All records that are only in df2 (based on a join on join_columns)
+ """
+
+ def __init__(
+ self,
+ session: "sp.Session",
+ df1: Union[str, "sp.DataFrame"],
+ df2: Union[str, "sp.DataFrame"],
+ join_columns: Optional[Union[List[str], str]],
+ abs_tol: float = 0,
+ rel_tol: float = 0,
+ df1_name: Optional[str] = None,
+ df2_name: Optional[str] = None,
+ ignore_spaces: bool = False,
+ ) -> None:
+ if join_columns is None:
+ errmsg = "join_columns cannot be None"
+ raise ValueError(errmsg)
+ elif not join_columns:
+ errmsg = "join_columns is empty"
+ raise ValueError(errmsg)
+ elif isinstance(join_columns, (str, int, float)):
+ self.join_columns = [str(join_columns).replace('"', "").upper()]
+ else:
+ self.join_columns = [
+ str(col).replace('"', "").upper()
+ for col in cast(List[str], join_columns)
+ ]
+
+ self._any_dupes: bool = False
+ self.session = session
+ self.df1 = (df1, df1_name)
+ self.df2 = (df2, df2_name)
+ self.abs_tol = abs_tol
+ self.rel_tol = rel_tol
+ self.ignore_spaces = ignore_spaces
+ self.df1_unq_rows: sp.DataFrame
+ self.df2_unq_rows: sp.DataFrame
+ self.intersect_rows: sp.DataFrame
+ self.column_stats: List[Dict[str, Any]] = []
+ self._compare(ignore_spaces=ignore_spaces)
+
+ @property
+ def df1(self) -> "sp.DataFrame":
+ """Get the first dataframe."""
+ return self._df1
+
+ @df1.setter
+ def df1(self, df1: tuple[Union[str, "sp.DataFrame"], Optional[str]]) -> None:
+ """Check that df1 is either a Snowpark DF or the name of a valid Snowflake table."""
+ (df, df_name) = df1
+ if isinstance(df, str):
+ table_name = [table_comp.upper() for table_comp in df.split(".")]
+ if len(table_name) != 3:
+ errmsg = f"{df} is not a valid table name. Be sure to include the target db and schema."
+ raise ValueError(errmsg)
+ self.df1_name = df_name.upper() if df_name else table_name[2]
+ self._df1 = self.session.table(df)
+ else:
+ self._df1 = df
+ self.df1_name = df_name.upper() if df_name else "DF1"
+ self._validate_dataframe(self.df1_name, "df1")
+
+ @property
+ def df2(self) -> "sp.DataFrame":
+ """Get the second dataframe."""
+ return self._df2
+
+ @df2.setter
+ def df2(self, df2: tuple[Union[str, "sp.DataFrame"], Optional[str]]) -> None:
+ """Check that df2 is either a Snowpark DF or the name of a valid Snowflake table."""
+ (df, df_name) = df2
+ if isinstance(df, str):
+ table_name = [table_comp.upper() for table_comp in df.split(".")]
+ if len(table_name) != 3:
+ errmsg = f"{df} is not a valid table name. Be sure to include the target db and schema."
+ raise ValueError(errmsg)
+ self.df2_name = df_name.upper() if df_name else table_name[2]
+ self._df2 = self.session.table(df)
+ else:
+ self._df2 = df
+ self.df2_name = df_name.upper() if df_name else "DF2"
+ self._validate_dataframe(self.df2_name, "df2")
+
+ def _validate_dataframe(self, df_name: str, index: str) -> None:
+ """Validate the provided Snowpark dataframe.
+
+ The dataframe can either be a standalone Snowpark dataframe or a representative
+ of a Snowflake table - in the latter case we check that the table it represents
+ is a valid table by forcing a collection.
+
+ Parameters
+ ----------
+ df_name : str
+ Name of the Snowflake table / Snowpark dataframe
+ index : str
+ The "index" of the dataframe - df1 or df2.
+ """
+ df = getattr(self, index)
+ if not isinstance(df, "sp.DataFrame"):
+ raise TypeError(f"{df_name} must be a valid sp.Dataframe")
+
+ # force all columns to be non-case-sensitive
+ if index == "df1":
+ col_map = dict(
+ zip(
+ self._df1.columns,
+ [str(c).replace('"', "").upper() for c in self._df1.columns],
+ )
+ )
+ self._df1 = self._df1.rename(col_map)
+ if index == "df2":
+ col_map = dict(
+ zip(
+ self._df2.columns,
+ [str(c).replace('"', "").upper() for c in self._df2.columns],
+ )
+ )
+ self._df2 = self._df2.rename(dict(col_map))
+
+ df = getattr(self, index) # refresh
+ if not set(self.join_columns).issubset(set(df.columns)):
+ raise ValueError(f"{df_name} must have all columns from join_columns")
+ if len(set(df.columns)) < len(df.columns):
+ raise ValueError(f"{df_name} must have unique column names")
+
+ if df.drop_duplicates(self.join_columns).count() < df.count():
+ self._any_dupes = True
+
+ def _compare(self, ignore_spaces: bool) -> None:
+ """Actually run the comparison.
+
+ This method will log out information about what is different between
+ the two dataframes.
+ """
+ LOG.info(f"Number of columns in common: {len(self.intersect_columns())}")
+ LOG.debug("Checking column overlap")
+ for column in self.df1_unq_columns():
+ LOG.info(f"Column in df1 and not in df2: {column}")
+ LOG.info(
+ f"Number of columns in df1 and not in df2: {len(self.df1_unq_columns())}"
+ )
+ for column in self.df2_unq_columns():
+ LOG.info(f"Column in df2 and not in df1: {column}")
+ LOG.info(
+ f"Number of columns in df2 and not in df1: {len(self.df2_unq_columns())}"
+ )
+ LOG.debug("Merging dataframes")
+ self._dataframe_merge(ignore_spaces)
+ self._intersect_compare(ignore_spaces)
+ if self.matches():
+ LOG.info("df1 matches df2")
+ else:
+ LOG.info("df1 does not match df2")
+
+ def df1_unq_columns(self) -> OrderedSet[str]:
+ """Get columns that are unique to df1."""
+ return cast(
+ OrderedSet[str], OrderedSet(self.df1.columns) - OrderedSet(self.df2.columns)
+ )
+
+ def df2_unq_columns(self) -> OrderedSet[str]:
+ """Get columns that are unique to df2."""
+ return cast(
+ OrderedSet[str], OrderedSet(self.df2.columns) - OrderedSet(self.df1.columns)
+ )
+
+ def intersect_columns(self) -> OrderedSet[str]:
+ """Get columns that are shared between the two dataframes."""
+ return OrderedSet(self.df1.columns) & OrderedSet(self.df2.columns)
+
+ def _dataframe_merge(self, ignore_spaces: bool) -> None:
+ """Merge df1 to df2 on the join columns.
+
+ Gets df1 - df2, df2 - df1, and df1 & df2
+ joining on the ``join_columns``.
+ """
+ LOG.debug("Outer joining")
+
+ df1 = self.df1
+ df2 = self.df2
+ temp_join_columns = deepcopy(self.join_columns)
+
+ if self._any_dupes:
+ LOG.debug("Duplicate rows found, deduping by order of remaining fields")
+ # setting internal index
+ LOG.info("Adding internal index to dataframes")
+ df1 = df1.withColumn("__index", monotonically_increasing_id())
+ df2 = df2.withColumn("__index", monotonically_increasing_id())
+
+ # Create order column for uniqueness of match
+ order_column = temp_column_name(df1, df2)
+ df1 = df1.join(
+ _generate_id_within_group(df1, temp_join_columns, order_column),
+ on="__index",
+ how="inner",
+ ).drop("__index")
+ df2 = df2.join(
+ _generate_id_within_group(df2, temp_join_columns, order_column),
+ on="__index",
+ how="inner",
+ ).drop("__index")
+ temp_join_columns.append(order_column)
+
+ # drop index
+ LOG.info("Dropping internal index")
+ df1 = df1.drop("__index")
+ df2 = df2.drop("__index")
+
+ if ignore_spaces:
+ for column in self.join_columns:
+ if "string" in next(
+ dtype for name, dtype in df1.dtypes if name == column
+ ):
+ df1 = df1.withColumn(column, trim(col(column)))
+ if "string" in next(
+ dtype for name, dtype in df2.dtypes if name == column
+ ):
+ df2 = df2.withColumn(column, trim(col(column)))
+
+ df1 = df1.withColumn("merge", lit(True))
+ df2 = df2.withColumn("merge", lit(True))
+
+ for c in df1.columns:
+ df1 = df1.withColumnRenamed(c, c + "_" + self.df1_name)
+ for c in df2.columns:
+ df2 = df2.withColumnRenamed(c, c + "_" + self.df2_name)
+
+ # NULL SAFE Outer join, not possible with Snowpark Dataframe join
+ df1.createOrReplaceTempView("df1")
+ df2.createOrReplaceTempView("df2")
+ on = " and ".join(
+ [
+ f"EQUAL_NULL(df1.{c}_{self.df1_name}, df2.{c}_{self.df2_name})"
+ for c in temp_join_columns
+ ]
+ )
+ outer_join = self.session.sql(
+ """
+ SELECT * FROM
+ df1 FULL OUTER JOIN df2
+ ON
+ """
+ + on
+ )
+ # Create join indicator
+ outer_join = outer_join.withColumn(
+ "_merge",
+ when(
+ outer_join[f"MERGE_{self.df1_name}"]
+ & outer_join[f"MERGE_{self.df2_name}"],
+ lit("BOTH"),
+ )
+ .when(
+ outer_join[f"MERGE_{self.df1_name}"]
+ & outer_join[f"MERGE_{self.df2_name}"].is_null(),
+ lit("LEFT_ONLY"),
+ )
+ .when(
+ outer_join[f"MERGE_{self.df1_name}"].is_null()
+ & outer_join[f"MERGE_{self.df2_name}"],
+ lit("RIGHT_ONLY"),
+ ),
+ )
+
+ df1 = df1.drop(f"MERGE_{self.df1_name}")
+ df2 = df2.drop(f"MERGE_{self.df2_name}")
+
+ # Clean up temp columns for duplicate row matching
+ if self._any_dupes:
+ outer_join = outer_join.select_expr(
+ f"* EXCLUDE ({order_column}_{self.df1_name}, {order_column}_{self.df2_name})"
+ )
+ df1 = df1.drop(f"{order_column}_{self.df1_name}")
+ df2 = df2.drop(f"{order_column}_{self.df2_name}")
+
+ # Capitalization required - clean up
+ df1_cols = get_merged_columns(df1, outer_join, self.df1_name)
+ df2_cols = get_merged_columns(df2, outer_join, self.df2_name)
+
+ LOG.debug("Selecting df1 unique rows")
+ self.df1_unq_rows = outer_join[outer_join["_merge"] == "LEFT_ONLY"][df1_cols]
+
+ LOG.debug("Selecting df2 unique rows")
+ self.df2_unq_rows = outer_join[outer_join["_merge"] == "RIGHT_ONLY"][df2_cols]
+ LOG.info(f"Number of rows in df1 and not in df2: {self.df1_unq_rows.count()}")
+ LOG.info(f"Number of rows in df2 and not in df1: {self.df2_unq_rows.count()}")
+
+ LOG.debug("Selecting intersecting rows")
+ self.intersect_rows = outer_join[outer_join["_merge"] == "BOTH"]
+ LOG.info(
+ f"Number of rows in df1 and df2 (not necessarily equal): {self.intersect_rows.count()}"
+ )
+ self.intersect_rows = self.intersect_rows.cache_result()
+
+ def _intersect_compare(self, ignore_spaces: bool) -> None:
+ """Run the comparison on the intersect dataframe.
+
+ This loops through all columns that are shared between df1 and df2, and
+ creates a column column_match which is True for matches, False
+ otherwise.
+ """
+ LOG.debug("Comparing intersection")
+ max_diff: float
+ null_diff: int
+ row_cnt = self.intersect_rows.count()
+ for column in self.intersect_columns():
+ if column in self.join_columns:
+ match_cnt = row_cnt
+ col_match = ""
+ max_diff = 0
+ null_diff = 0
+ else:
+ col_1 = column + "_" + self.df1_name
+ col_2 = column + "_" + self.df2_name
+ col_match = column + "_MATCH"
+ self.intersect_rows = columns_equal(
+ self.intersect_rows,
+ col_1,
+ col_2,
+ col_match,
+ self.rel_tol,
+ self.abs_tol,
+ ignore_spaces,
+ )
+ match_cnt = (
+ self.intersect_rows.select(col_match)
+ .where(col(col_match) == True) # noqa: E712
+ .count()
+ )
+ max_diff = calculate_max_diff(
+ self.intersect_rows,
+ col_1,
+ col_2,
+ )
+ null_diff = calculate_null_diff(self.intersect_rows, col_1, col_2)
+
+ if row_cnt > 0:
+ match_rate = float(match_cnt) / row_cnt
+ else:
+ match_rate = 0
+ LOG.info(f"{column}: {match_cnt} / {row_cnt} ({match_rate:.2%}) match")
+
+ col1_dtype, _ = _get_column_dtypes(self.df1, column, column)
+ col2_dtype, _ = _get_column_dtypes(self.df2, column, column)
+
+ self.column_stats.append(
+ {
+ "column": column,
+ "match_column": col_match,
+ "match_cnt": match_cnt,
+ "unequal_cnt": row_cnt - match_cnt,
+ "dtype1": str(col1_dtype),
+ "dtype2": str(col2_dtype),
+ "all_match": all(
+ (
+ col1_dtype == col2_dtype,
+ row_cnt == match_cnt,
+ )
+ ),
+ "max_diff": max_diff,
+ "null_diff": null_diff,
+ }
+ )
+
+ def all_columns_match(self) -> bool:
+ """Whether the columns all match in the dataframes.
+
+ Returns
+ -------
+ bool
+ True if all columns in df1 are in df2 and vice versa
+ """
+ return self.df1_unq_columns() == self.df2_unq_columns() == set()
+
+ def all_rows_overlap(self) -> bool:
+ """Whether the rows are all present in both dataframes.
+
+ Returns
+ -------
+ bool
+ True if all rows in df1 are in df2 and vice versa (based on
+ existence for join option)
+ """
+ return self.df1_unq_rows.count() == self.df2_unq_rows.count() == 0
+
+ def count_matching_rows(self) -> int:
+ """Count the number of rows match (on overlapping fields).
+
+ Returns
+ -------
+ int
+ Number of matching rows
+ """
+ conditions = []
+ match_columns = []
+ for column in self.intersect_columns():
+ if column not in self.join_columns:
+ match_columns.append(column + "_MATCH")
+ conditions.append(f"{column}_MATCH = True")
+ if len(conditions) > 0:
+ match_columns_count = self.intersect_rows.filter(
+ " and ".join(conditions)
+ ).count()
+ else:
+ match_columns_count = 0
+ return match_columns_count
+
+ def intersect_rows_match(self) -> bool:
+ """Check whether the intersect rows all match."""
+ actual_length = self.intersect_rows.count()
+ return self.count_matching_rows() == actual_length
+
+ def matches(self, ignore_extra_columns: bool = False) -> bool:
+ """Return True or False if the dataframes match.
+
+ Parameters
+ ----------
+ ignore_extra_columns : bool
+ Ignores any columns in one dataframe and not in the other.
+
+ Returns
+ -------
+ bool
+ True or False if the dataframes match.
+ """
+ return not (
+ (not ignore_extra_columns and not self.all_columns_match())
+ or not self.all_rows_overlap()
+ or not self.intersect_rows_match()
+ )
+
+ def subset(self) -> bool:
+ """Return True if dataframe 2 is a subset of dataframe 1.
+
+ Dataframe 2 is considered a subset if all of its columns are in
+ dataframe 1, and all of its rows match rows in dataframe 1 for the
+ shared columns.
+
+ Returns
+ -------
+ bool
+ True if dataframe 2 is a subset of dataframe 1.
+ """
+ return not (
+ self.df2_unq_columns() != set()
+ or self.df2_unq_rows.count() != 0
+ or not self.intersect_rows_match()
+ )
+
+ def sample_mismatch(
+ self, column: str, sample_count: int = 10, for_display: bool = False
+ ) -> "sp.DataFrame":
+ """Return sample mismatches.
+
+ Gets a sub-dataframe which contains the identifying
+ columns, and df1 and df2 versions of the column.
+
+ Parameters
+ ----------
+ column : str
+ The raw column name (i.e. without ``_df1`` appended)
+ sample_count : int, optional
+ The number of sample records to return. Defaults to 10.
+ for_display : bool, optional
+ Whether this is just going to be used for display (overwrite the
+ column names)
+
+ Returns
+ -------
+ sp.DataFrame
+ A sample of the intersection dataframe, containing only the
+ "pertinent" columns, for rows that don't match on the provided
+ column.
+ """
+ row_cnt = self.intersect_rows.count()
+ col_match = self.intersect_rows.select(column + "_MATCH")
+ match_cnt = col_match.where(
+ col(column + "_MATCH") == True # noqa: E712
+ ).count()
+ sample_count = min(sample_count, row_cnt - match_cnt)
+ sample = (
+ self.intersect_rows.where(col(column + "_MATCH") == False) # noqa: E712
+ .drop(column + "_MATCH")
+ .limit(sample_count)
+ )
+
+ for c in self.join_columns:
+ sample = sample.withColumnRenamed(c + "_" + self.df1_name, c)
+
+ return_cols = [
+ *self.join_columns,
+ column + "_" + self.df1_name,
+ column + "_" + self.df2_name,
+ ]
+ to_return = sample.select(return_cols)
+
+ if for_display:
+ return to_return.toDF(
+ *[
+ *self.join_columns,
+ column + " (" + self.df1_name + ")",
+ column + " (" + self.df2_name + ")",
+ ]
+ )
+ return to_return
+
+ def all_mismatch(self, ignore_matching_cols: bool = False) -> "sp.DataFrame":
+ """Get all rows with any columns that have a mismatch.
+
+ Returns all df1 and df2 versions of the columns and join
+ columns.
+
+ Parameters
+ ----------
+ ignore_matching_cols : bool, optional
+ Whether showing the matching columns in the output or not. The default is False.
+
+ Returns
+ -------
+ sp.DataFrame
+ All rows of the intersection dataframe, containing any columns, that don't match.
+ """
+ match_list = []
+ return_list = []
+ for c in self.intersect_rows.columns:
+ if c.endswith("_MATCH"):
+ orig_col_name = c[:-6]
+
+ col_comparison = columns_equal(
+ self.intersect_rows,
+ orig_col_name + "_" + self.df1_name,
+ orig_col_name + "_" + self.df2_name,
+ c,
+ self.rel_tol,
+ self.abs_tol,
+ self.ignore_spaces,
+ )
+
+ if not ignore_matching_cols or (
+ ignore_matching_cols
+ and col_comparison.select(c)
+ .where(col(c) == False) # noqa: E712
+ .count()
+ > 0
+ ):
+ LOG.debug(f"Adding column {orig_col_name} to the result.")
+ match_list.append(c)
+ return_list.extend(
+ [
+ orig_col_name + "_" + self.df1_name,
+ orig_col_name + "_" + self.df2_name,
+ ]
+ )
+ elif ignore_matching_cols:
+ LOG.debug(
+ f"Column {orig_col_name} is equal in df1 and df2. It will not be added to the result."
+ )
+
+ mm_rows = self.intersect_rows.withColumn(
+ "match_array", concat(*match_list)
+ ).where(contains(col("match_array"), lit("false")))
+
+ for c in self.join_columns:
+ mm_rows = mm_rows.withColumnRenamed(c + "_" + self.df1_name, c)
+
+ return mm_rows.select(self.join_columns + return_list)
+
+ def report(
+ self,
+ sample_count: int = 10,
+ column_count: int = 10,
+ html_file: Optional[str] = None,
+ ) -> str:
+ """Return a string representation of a report.
+
+ The representation can
+ then be printed or saved to a file.
+
+ Parameters
+ ----------
+ sample_count : int, optional
+ The number of sample records to return. Defaults to 10.
+
+ column_count : int, optional
+ The number of columns to display in the sample records output. Defaults to 10.
+
+ html_file : str, optional
+ HTML file name to save report output to. If ``None`` the file creation will be skipped.
+
+ Returns
+ -------
+ str
+ The report, formatted kinda nicely.
+ """
+ # Header
+ report = render("header.txt")
+ df_header = pd.DataFrame(
+ {
+ "DataFrame": [self.df1_name, self.df2_name],
+ "Columns": [len(self.df1.columns), len(self.df2.columns)],
+ "Rows": [self.df1.count(), self.df2.count()],
+ }
+ )
+ report += df_header[["DataFrame", "Columns", "Rows"]].to_string()
+ report += "\n\n"
+
+ # Column Summary
+ report += render(
+ "column_summary.txt",
+ len(self.intersect_columns()),
+ len(self.df1_unq_columns()),
+ len(self.df2_unq_columns()),
+ self.df1_name,
+ self.df2_name,
+ )
+
+ # Row Summary
+ match_on = ", ".join(self.join_columns)
+ report += render(
+ "row_summary.txt",
+ match_on,
+ self.abs_tol,
+ self.rel_tol,
+ self.intersect_rows.count(),
+ self.df1_unq_rows.count(),
+ self.df2_unq_rows.count(),
+ self.intersect_rows.count() - self.count_matching_rows(),
+ self.count_matching_rows(),
+ self.df1_name,
+ self.df2_name,
+ "Yes" if self._any_dupes else "No",
+ )
+
+ # Column Matching
+ report += render(
+ "column_comparison.txt",
+ len([col for col in self.column_stats if col["unequal_cnt"] > 0]),
+ len([col for col in self.column_stats if col["unequal_cnt"] == 0]),
+ sum(col["unequal_cnt"] for col in self.column_stats),
+ )
+
+ match_stats = []
+ match_sample = []
+ any_mismatch = False
+ for column in self.column_stats:
+ if not column["all_match"]:
+ any_mismatch = True
+ match_stats.append(
+ {
+ "Column": column["column"],
+ f"{self.df1_name} dtype": column["dtype1"],
+ f"{self.df2_name} dtype": column["dtype2"],
+ "# Unequal": column["unequal_cnt"],
+ "Max Diff": column["max_diff"],
+ "# Null Diff": column["null_diff"],
+ }
+ )
+ if column["unequal_cnt"] > 0:
+ match_sample.append(
+ self.sample_mismatch(
+ column["column"], sample_count, for_display=True
+ )
+ )
+
+ if any_mismatch:
+ report += "Columns with Unequal Values or Types\n"
+ report += "------------------------------------\n"
+ report += "\n"
+ df_match_stats = pd.DataFrame(match_stats)
+ df_match_stats.sort_values("Column", inplace=True)
+ # Have to specify again for sorting
+ report += df_match_stats[
+ [
+ "Column",
+ f"{self.df1_name} dtype",
+ f"{self.df2_name} dtype",
+ "# Unequal",
+ "Max Diff",
+ "# Null Diff",
+ ]
+ ].to_string()
+ report += "\n\n"
+
+ if sample_count > 0:
+ report += "Sample Rows with Unequal Values\n"
+ report += "-------------------------------\n"
+ report += "\n"
+ for sample in match_sample:
+ report += sample.toPandas().to_string()
+ report += "\n\n"
+
+ if min(sample_count, self.df1_unq_rows.count()) > 0:
+ report += (
+ f"Sample Rows Only in {self.df1_name} (First {column_count} Columns)\n"
+ )
+ report += (
+ f"---------------------------------------{'-' * len(self.df1_name)}\n"
+ )
+ report += "\n"
+ columns = self.df1_unq_rows.columns[:column_count]
+ unq_count = min(sample_count, self.df1_unq_rows.count())
+ report += (
+ self.df1_unq_rows.limit(unq_count)
+ .select(columns)
+ .toPandas()
+ .to_string()
+ )
+ report += "\n\n"
+
+ if min(sample_count, self.df2_unq_rows.count()) > 0:
+ report += (
+ f"Sample Rows Only in {self.df2_name} (First {column_count} Columns)\n"
+ )
+ report += (
+ f"---------------------------------------{'-' * len(self.df2_name)}\n"
+ )
+ report += "\n"
+ columns = self.df2_unq_rows.columns[:column_count]
+ unq_count = min(sample_count, self.df2_unq_rows.count())
+ report += (
+ self.df2_unq_rows.limit(unq_count)
+ .select(columns)
+ .toPandas()
+ .to_string()
+ )
+ report += "\n\n"
+
+ if html_file:
+ html_report = report.replace("\n", "
").replace(" ", " ")
+ html_report = f"
{html_report}
"
+ with open(html_file, "w") as f:
+ f.write(html_report)
+
+ return report
+
+
+def render(filename: str, *fields: Union[int, float, str]) -> str:
+ """Render out an individual template.
+
+ This basically just reads in a
+ template file, and applies ``.format()`` on the fields.
+
+ Parameters
+ ----------
+ filename : str
+ The file that contains the template. Will automagically prepend the
+ templates directory before opening
+ fields : list
+ Fields to be rendered out in the template
+
+ Returns
+ -------
+ str
+ The fully rendered out file.
+ """
+ this_dir = os.path.dirname(os.path.realpath(__file__))
+ with open(os.path.join(this_dir, "templates", filename)) as file_open:
+ return file_open.read().format(*fields)
+
+
+def columns_equal(
+ dataframe: "sp.DataFrame",
+ col_1: str,
+ col_2: str,
+ col_match: str,
+ rel_tol: float = 0,
+ abs_tol: float = 0,
+ ignore_spaces: bool = False,
+) -> "sp.DataFrame":
+ """Compare two columns from a dataframe.
+
+ Returns a True/False series with the same index as column 1.
+
+ - Two nulls (np.nan) will evaluate to True.
+ - A null and a non-null value will evaluate to False.
+ - Numeric values will use the relative and absolute tolerances.
+ - Decimal values (decimal.Decimal) will attempt to be converted to floats
+ before comparing
+ - Non-numeric values (i.e. where np.isclose can't be used) will just
+ trigger True on two nulls or exact matches.
+
+ Parameters
+ ----------
+ dataframe: sp.DataFrame
+ DataFrame to do comparison on
+ col_1 : str
+ The first column to look at
+ col_2 : str
+ The second column
+ col_match : str
+ The matching column denoting if the compare was a match or not
+ rel_tol : float, optional
+ Relative tolerance
+ abs_tol : float, optional
+ Absolute tolerance
+ ignore_spaces : bool, optional
+ Flag to strip whitespace (including newlines) from string columns
+
+ Returns
+ -------
+ sp.DataFrame
+ A column of boolean values are added. True == the values match, False == the
+ values don't match.
+ """
+ base_dtype, compare_dtype = _get_column_dtypes(dataframe, col_1, col_2)
+ if _is_comparable(base_dtype, compare_dtype):
+ if (base_dtype in NUMERIC_SNOWPARK_TYPES) and (
+ compare_dtype in NUMERIC_SNOWPARK_TYPES
+ ): # numeric tolerance comparison
+ dataframe = dataframe.withColumn(
+ col_match,
+ when(
+ (col(col_1).eqNullSafe(col(col_2)))
+ | (
+ abs(col(col_1) - col(col_2))
+ <= lit(abs_tol) + (lit(rel_tol) * abs(col(col_2)))
+ ),
+ # corner case of col1 != NaN and col2 == Nan returns True incorrectly
+ when(
+ (is_null(col(col_1)) == False) # noqa: E712
+ & (is_null(col(col_2)) == True), # noqa: E712
+ lit(False),
+ ).otherwise(lit(True)),
+ ).otherwise(lit(False)),
+ )
+ else: # non-numeric comparison
+ if ignore_spaces:
+ when_clause = trim(col(col_1)).eqNullSafe(trim(col(col_2)))
+ else:
+ when_clause = col(col_1).eqNullSafe(col(col_2))
+
+ dataframe = dataframe.withColumn(
+ col_match,
+ when(when_clause, lit(True)).otherwise(lit(False)),
+ )
+ else:
+ LOG.debug(
+ f"Skipping {col_1}({base_dtype}) and {col_2}({compare_dtype}), columns are not comparable"
+ )
+ dataframe = dataframe.withColumn(col_match, lit(False))
+ return dataframe
+
+
+def get_merged_columns(
+ original_df: "sp.DataFrame", merged_df: "sp.DataFrame", suffix: str
+) -> List[str]:
+ """Get the columns from an original dataframe, in the new merged dataframe.
+
+ Parameters
+ ----------
+ original_df : sp.DataFrame
+ The original, pre-merge dataframe
+ merged_df : sp.DataFrame
+ Post-merge with another dataframe, with suffixes added in.
+ suffix : str
+ What suffix was used to distinguish when the original dataframe was
+ overlapping with the other merged dataframe.
+
+ Returns
+ -------
+ List[str]
+ Column list of the original dataframe pre suffix
+ """
+ columns = []
+ for column in original_df.columns:
+ if column in merged_df.columns:
+ columns.append(column)
+ elif f"{column}_{suffix}" in merged_df.columns:
+ columns.append(f"{column}_{suffix}")
+ else:
+ raise ValueError("Column not found: %s", column)
+ return columns
+
+
+def calculate_max_diff(dataframe: "sp.DataFrame", col_1: str, col_2: str) -> float:
+ """Get a maximum difference between two columns.
+
+ Parameters
+ ----------
+ dataframe: sp.DataFrame
+ DataFrame to do comparison on
+ col_1 : str
+ The first column to look at
+ col_2 : str
+ The second column
+
+ Returns
+ -------
+ float
+ max diff
+ """
+ # Attempting to coalesce maximum diff for non-numeric results in error, if error return 0 max diff.
+ try:
+ diff = dataframe.select(
+ (col(col_1).astype("float") - col(col_2).astype("float")).alias("diff")
+ )
+ abs_diff = diff.select(abs(col("diff")).alias("abs_diff"))
+ max_diff: float = (
+ abs_diff.where(is_null(col("abs_diff")) == False) # noqa: E712
+ .agg({"abs_diff": "max"})
+ .collect()[0][0]
+ )
+ except SnowparkSQLException:
+ return None
+
+ if pd.isna(max_diff) or pd.isnull(max_diff) or max_diff is None:
+ return 0
+ else:
+ return max_diff
+
+
+def calculate_null_diff(dataframe: "sp.DataFrame", col_1: str, col_2: str) -> int:
+ """Get the null differences between two columns.
+
+ Parameters
+ ----------
+ dataframe: sp.DataFrame
+ DataFrame to do comparison on
+ col_1 : str
+ The first column to look at
+ col_2 : str
+ The second column
+
+ Returns
+ -------
+ int
+ null diff
+ """
+ nulls_df = dataframe.withColumn(
+ "col_1_null",
+ when(col(col_1).isNull() == True, lit(True)).otherwise( # noqa: E712
+ lit(False)
+ ),
+ )
+ nulls_df = nulls_df.withColumn(
+ "col_2_null",
+ when(col(col_2).isNull() == True, lit(True)).otherwise( # noqa: E712
+ lit(False)
+ ),
+ ).select(["col_1_null", "col_2_null"])
+
+ # (not a and b) or (a and not b)
+ null_diff = nulls_df.where(
+ ((col("col_1_null") == False) & (col("col_2_null") == True)) # noqa: E712
+ | ((col("col_1_null") == True) & (col("col_2_null") == False)) # noqa: E712
+ ).count()
+
+ if pd.isna(null_diff) or pd.isnull(null_diff) or null_diff is None:
+ return 0
+ else:
+ return null_diff
+
+
+def _generate_id_within_group(
+ dataframe: "sp.DataFrame", join_columns: List[str], order_column_name: str
+) -> "sp.DataFrame":
+ """Generate an ID column that can be used to deduplicate identical rows.
+
+ The series generated
+ is the order within a unique group, and it handles nulls. Requires a ``__index`` column.
+
+ Parameters
+ ----------
+ dataframe : sp.DataFrame
+ The dataframe to operate on
+ join_columns : list
+ List of strings which are the join columns
+ order_column_name: str
+ The name of the ``row_number`` column name
+
+ Returns
+ -------
+ sp.DataFrame
+ Original dataframe with the ID column that's unique in each group
+ """
+ default_value = "DATACOMPY_NULL"
+ null_check = False
+ default_check = False
+ for c in join_columns:
+ if dataframe.where(col(c).isNull()).limit(1).collect():
+ null_check = True
+ break
+ for c in [
+ column for column, type in dataframe[join_columns].dtypes if "string" in type
+ ]:
+ if dataframe.where(col(c).isin(default_value)).limit(1).collect():
+ default_check = True
+ break
+
+ if null_check:
+ if default_check:
+ raise ValueError(f"{default_value} was found in your join columns")
+
+ return (
+ dataframe.select(
+ *(col(c).cast("string").alias(c) for c in join_columns + ["__index"]) # noqa: RUF005
+ )
+ .fillna(default_value)
+ .withColumn(
+ order_column_name,
+ row_number().over(Window.orderBy("__index").partitionBy(join_columns))
+ - 1,
+ )
+ .select(["__index", order_column_name])
+ )
+ else:
+ return (
+ dataframe.select(join_columns + ["__index"]) # noqa: RUF005
+ .withColumn(
+ order_column_name,
+ row_number().over(Window.orderBy("__index").partitionBy(join_columns))
+ - 1,
+ )
+ .select(["__index", order_column_name])
+ )
+
+
+def _get_column_dtypes(
+ dataframe: "sp.DataFrame", col_1: "str", col_2: "str"
+) -> tuple[str, str]:
+ """Get the dtypes of two columns.
+
+ Parameters
+ ----------
+ dataframe: sp.DataFrame
+ DataFrame to do comparison on
+ col_1 : str
+ The first column to look at
+ col_2 : str
+ The second column
+
+ Returns
+ -------
+ Tuple(str, str)
+ Tuple of base and compare datatype
+ """
+ base_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_1)
+ compare_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_2)
+ return base_dtype, compare_dtype
+
+
+def _is_comparable(type1: str, type2: str) -> bool:
+ """Check if two SnowPark data types can be safely compared.
+
+ Two data types are considered comparable if any of the following apply:
+ 1. Both data types are the same
+ 2. Both data types are numeric
+
+ Parameters
+ ----------
+ type1 : str
+ A string representation of a Snowpark data type
+ type2 : str
+ A string representation of a Snowpark data type
+
+ Returns
+ -------
+ bool
+ True if both data types are comparable
+ """
+ return (
+ type1 == type2
+ or (type1 in NUMERIC_SNOWPARK_TYPES and type2 in NUMERIC_SNOWPARK_TYPES)
+ or ("string" in type1 and type2 == "date")
+ or (type1 == "date" and "string" in type2)
+ or ("string" in type1 and type2 == "timestamp")
+ or (type1 == "timestamp" and "string" in type2)
+ )
+
+
+def temp_column_name(*dataframes) -> str:
+ """Get a temp column name that isn't included in columns of any dataframes.
+
+ Parameters
+ ----------
+ dataframes : list of DataFrames
+ The DataFrames to create a temporary column name for
+
+ Returns
+ -------
+ str
+ String column name that looks like '_temp_x' for some integer x
+ """
+ i = 0
+ columns = []
+ for dataframe in dataframes:
+ columns = columns + list(dataframe.columns)
+ columns = set(columns)
+
+ while True:
+ temp_column = f"_TEMP_{i}"
+ unique = True
+
+ if temp_column in columns:
+ i += 1
+ unique = False
+ if unique:
+ return temp_column
diff --git a/docs/source/index.rst b/docs/source/index.rst
index e6f77d96..b1cb4c4a 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -11,6 +11,7 @@ Contents
Installation
Pandas Usage
Spark Usage
+ Snowflake Usage
Polars Usage
Fugue Usage
Benchmarks
@@ -28,4 +29,4 @@ Indices and tables
* :ref:`genindex`
* :ref:`modindex`
-* :ref:`search`
\ No newline at end of file
+* :ref:`search`
diff --git a/docs/source/snowflake_usage.rst b/docs/source/snowflake_usage.rst
new file mode 100644
index 00000000..3c2687e3
--- /dev/null
+++ b/docs/source/snowflake_usage.rst
@@ -0,0 +1,268 @@
+Snowpark/Snowflake Usage
+========================
+
+For ``SnowflakeCompare``
+
+- ``on_index`` is not supported.
+- Joining is done using ``EQUAL_NULL`` which is the equality test that is safe for null values.
+- Compares ``snowflake.snowpark.DataFrame``, which can be provided as either raw Snowflake dataframes
+or the as the names of full names of valid snowflake tables, which we will process into Snowpark dataframes.
+
+
+SnowflakeCompare Object Setup
+---------------------------------------------------
+There are two ways to specify input dataframes for ``SnowflakeCompare``
+
+Provide Snowpark dataframes
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: python
+
+ from snowflake.snowpark import Session
+ from snowflake.snowpark import Row
+ import datetime
+ import datacompy.snowflake as sp
+
+ connection_parameters = {
+ ...
+ }
+ session = Session.builder.configs(connection_parameters).create()
+
+ data1 = [
+ Row(acct_id=10000001234, dollar_amt=123.45, name='George Maharis', float_fld=14530.1555,
+ date_fld=datetime.date(2017, 1, 1)),
+ Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=1.0,
+ date_fld=datetime.date(2017, 1, 1)),
+ Row(acct_id=10000001236, dollar_amt=1345.0, name='George Bluth', float_fld=None,
+ date_fld=datetime.date(2017, 1, 1)),
+ Row(acct_id=10000001237, dollar_amt=123456.0, name='Bob Loblaw', float_fld=345.12,
+ date_fld=datetime.date(2017, 1, 1)),
+ Row(acct_id=10000001239, dollar_amt=1.05, name='Lucille Bluth', float_fld=None,
+ date_fld=datetime.date(2017, 1, 1)),
+ ]
+
+ data2 = [
+ Row(acct_id=10000001234, dollar_amt=123.4, name='George Michael Bluth', float_fld=14530.155),
+ Row(acct_id=10000001235, dollar_amt=0.45, name='Michael Bluth', float_fld=None),
+ Row(acct_id=None, dollar_amt=1345.0, name='George Bluth', float_fld=1.0),
+ Row(acct_id=10000001237, dollar_amt=123456.0, name='Robert Loblaw', float_fld=345.12),
+ Row(acct_id=10000001238, dollar_amt=1.05, name='Loose Seal Bluth', float_fld=111.0),
+ ]
+
+ df_1 = session.createDataFrame(data1)
+ df_2 = session.createDataFrame(data2)
+
+ compare = sp.SnowflakeCompare(
+ session,
+ df_1,
+ df_2,
+ join_columns=['acct_id'],
+ rel_tol=1e-03,
+ abs_tol=1e-04,
+ )
+ compare.matches(ignore_extra_columns=False)
+
+ # This method prints out a human-readable report summarizing and sampling differences
+ print(compare.report())
+
+
+Provide the full name (``{db}.{schema}.{table_name}``) of valid Snowflake tables
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Given the dataframes from the prior examples...
+
+.. code-block:: python
+ df_1.write.mode("overwrite").save_as_table("toy_table_1")
+ df_2.write.mode("overwrite").save_as_table("toy_table_2")
+
+ compare = sp.SnowflakeCompare(
+ session,
+ f"{db}.{schema}.toy_table_1",
+ f"{db}.{schema}.toy_table_2",
+ join_columns=['acct_id'],
+ rel_tol=1e-03,
+ abs_tol=1e-04,
+ )
+ compare.matches(ignore_extra_columns=False)
+
+ # This method prints out a human-readable report summarizing and sampling differences
+ print(compare.report())
+
+Reports
+-------
+
+A report is generated by calling ``report()``, which returns a string.
+Here is a sample report generated by ``datacompy`` for the two tables above,
+joined on ``acct_id`` (Note: the names for your dataframes are extracted from
+the name of the provided Snowflake table. If you chose to directly use Snowpark
+dataframes, then the names will default to ``DF1`` and ``DF2``.)::
+
+ DataComPy Comparison
+ --------------------
+
+ DataFrame Summary
+ -----------------
+
+ DataFrame Columns Rows
+ 0 DF1 5 5
+ 1 DF2 4 5
+
+ Column Summary
+ --------------
+
+ Number of columns in common: 4
+ Number of columns in DF1 but not in DF2: 1
+ Number of columns in DF2 but not in DF1: 0
+
+ Row Summary
+ -----------
+
+ Matched on: ACCT_ID
+ Any duplicates on match values: No
+ Absolute Tolerance: 0
+ Relative Tolerance: 0
+ Number of rows in common: 4
+ Number of rows in DF1 but not in DF2: 1
+ Number of rows in DF2 but not in DF1: 1
+
+ Number of rows with some compared columns unequal: 4
+ Number of rows with all compared columns equal: 0
+
+ Column Comparison
+ -----------------
+
+ Number of columns compared with some values unequal: 3
+ Number of columns compared with all values equal: 1
+ Total number of values which compare unequal: 6
+
+ Columns with Unequal Values or Types
+ ------------------------------------
+
+ Column DF1 dtype DF2 dtype # Unequal Max Diff # Null Diff
+ 0 DOLLAR_AMT double double 1 0.0500 0
+ 2 FLOAT_FLD double double 3 0.0005 2
+ 1 NAME string(16777216) string(16777216) 2 NaN 0
+
+ Sample Rows with Unequal Values
+ -------------------------------
+
+ ACCT_ID DOLLAR_AMT (DF1) DOLLAR_AMT (DF2)
+ 0 10000001234 123.45 123.4
+
+ ACCT_ID NAME (DF1) NAME (DF2)
+ 0 10000001234 George Maharis George Michael Bluth
+ 1 10000001237 Bob Loblaw Robert Loblaw
+
+ ACCT_ID FLOAT_FLD (DF1) FLOAT_FLD (DF2)
+ 0 10000001234 14530.1555 14530.155
+ 1 10000001235 1.0000 NaN
+ 2 10000001236 NaN 1.000
+
+ Sample Rows Only in DF1 (First 10 Columns)
+ ------------------------------------------
+
+ ACCT_ID_DF1 DOLLAR_AMT_DF1 NAME_DF1 FLOAT_FLD_DF1 DATE_FLD_DF1
+ 0 10000001239 1.05 Lucille Bluth NaN 2017-01-01
+
+ Sample Rows Only in DF2 (First 10 Columns)
+ ------------------------------------------
+
+ ACCT_ID_DF2 DOLLAR_AMT_DF2 NAME_DF2 FLOAT_FLD_DF2
+ 0 10000001238 1.05 Loose Seal Bluth 111.0
+
+
+Convenience Methods
+-------------------
+
+There are a few convenience methods and attributes available after the comparison has been run:
+
+.. code-block:: python
+
+ compare.intersect_rows[['name_df1', 'name_df2', 'name_match']].show()
+ # --------------------------------------------------------
+ # |"NAME_DF1" |"NAME_DF2" |"NAME_MATCH" |
+ # --------------------------------------------------------
+ # |George Maharis |George Michael Bluth |False |
+ # |Michael Bluth |Michael Bluth |True |
+ # |George Bluth |George Bluth |True |
+ # |Bob Loblaw |Robert Loblaw |False |
+ # --------------------------------------------------------
+
+ compare.df1_unq_rows.show()
+ # ---------------------------------------------------------------------------------------
+ # |"ACCT_ID_DF1" |"DOLLAR_AMT_DF1" |"NAME_DF1" |"FLOAT_FLD_DF1" |"DATE_FLD_DF1" |
+ # ---------------------------------------------------------------------------------------
+ # |10000001239 |1.05 |Lucille Bluth |NULL |2017-01-01 |
+ # ---------------------------------------------------------------------------------------
+
+ compare.df2_unq_rows.show()
+ # -------------------------------------------------------------------------
+ # |"ACCT_ID_DF2" |"DOLLAR_AMT_DF2" |"NAME_DF2" |"FLOAT_FLD_DF2" |
+ # -------------------------------------------------------------------------
+ # |10000001238 |1.05 |Loose Seal Bluth |111.0 |
+ # -------------------------------------------------------------------------
+
+ print(compare.intersect_columns())
+ # OrderedSet(['acct_id', 'dollar_amt', 'name', 'float_fld'])
+
+ print(compare.df1_unq_columns())
+ # OrderedSet(['date_fld'])
+
+ print(compare.df2_unq_columns())
+ # OrderedSet()
+
+Duplicate rows
+--------------
+
+Datacompy will try to handle rows that are duplicate in the join columns. It does this behind the
+scenes by generating a unique ID within each unique group of the join columns. For example, if you
+have two dataframes you're trying to join on acct_id:
+
+=========== ================
+acct_id name
+=========== ================
+1 George Maharis
+1 Michael Bluth
+2 George Bluth
+=========== ================
+
+=========== ================
+acct_id name
+=========== ================
+1 George Maharis
+1 Michael Bluth
+1 Tony Wonder
+2 George Bluth
+=========== ================
+
+Datacompy will generate a unique temporary ID for joining:
+
+=========== ================ ========
+acct_id name temp_id
+=========== ================ ========
+1 George Maharis 0
+1 Michael Bluth 1
+2 George Bluth 0
+=========== ================ ========
+
+=========== ================ ========
+acct_id name temp_id
+=========== ================ ========
+1 George Maharis 0
+1 Michael Bluth 1
+1 Tony Wonder 2
+2 George Bluth 0
+=========== ================ ========
+
+And then merge the two dataframes on a combination of the join_columns you specified and the temporary
+ID, before dropping the temp_id again. So the first two rows in the first dataframe will match the
+first two rows in the second dataframe, and the third row in the second dataframe will be recognized
+as uniquely in the second.
+
+Additional considerations
+-------------------------
+- It is strongly recommended against joining on float columns (or any column with floating point precision).
+Columns joining tables are compared on the basis of an exact comparison, therefore if the values comparing
+your float columns are not exact, you will likely get unexpected results.
+- Case-sensitive columns are only partially supported. We essentially treat case-sensitive
+columns as if they are case-insensitive. Therefore you may use case-sensitive columns as long as
+you don't have several columns with the same name differentiated only be case sensitivity.
diff --git a/pyproject.toml b/pyproject.toml
index ed9c0f9b..9c86c82f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,16 +56,19 @@ python-tag = "py3"
[project.optional-dependencies]
duckdb = ["fugue[duckdb]"]
spark = ["pyspark[connect]>=3.1.1; python_version < \"3.11\"", "pyspark[connect]>=3.4; python_version >= \"3.11\""]
+snowflake = ["snowflake-connector-python", "snowflake-snowpark-python"]
dask = ["fugue[dask]"]
ray = ["fugue[ray]"]
docs = ["sphinx", "furo", "myst-parser"]
tests = ["pytest", "pytest-cov"]
tests-spark = ["pytest", "pytest-cov", "pytest-spark"]
+tests-snowflake = ["snowflake-snowpark-python[localtest]"]
qa = ["pre-commit", "ruff==0.5.7", "mypy", "pandas-stubs"]
build = ["build", "twine", "wheel"]
edgetest = ["edgetest", "edgetest-conda"]
-dev = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[qa]", "datacompy[build]"]
+dev_no_snowflake = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[qa]", "datacompy[build]"]
+dev = ["datacompy[duckdb]", "datacompy[spark]", "datacompy[snowflake]", "datacompy[docs]", "datacompy[tests]", "datacompy[tests-spark]", "datacompy[tests-snowflake]", "datacompy[qa]", "datacompy[build]"]
# Linters, formatters and type checkers
[tool.ruff]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..54532b4b
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,24 @@
+"""Testing configuration file, currently used for generating a Snowpark local session for testing."""
+
+import os
+
+import pytest
+
+try:
+ from snowflake.snowpark.session import Session
+except ModuleNotFoundError:
+ pass
+
+CONNECTION_PARAMETERS = {
+ "account": os.environ.get("SF_ACCOUNT"),
+ "user": os.environ.get("SF_UID"),
+ "password": os.environ.get("SF_PWD"),
+ "warehouse": os.environ.get("SF_WAREHOUSE"),
+ "database": os.environ.get("SF_DATABASE"),
+ "schema": os.environ.get("SF_SCHEMA"),
+}
+
+
+@pytest.fixture(scope="module")
+def snowpark_session() -> "Session":
+ return Session.builder.configs(CONNECTION_PARAMETERS).create()
diff --git a/tests/test_snowflake.py b/tests/test_snowflake.py
new file mode 100644
index 00000000..548f6043
--- /dev/null
+++ b/tests/test_snowflake.py
@@ -0,0 +1,1326 @@
+#
+# Copyright 2024 Capital One Services, LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "LICENSE");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Testing out the datacompy functionality
+"""
+
+import io
+import logging
+import sys
+from datetime import datetime
+from decimal import Decimal
+from io import StringIO
+from unittest import mock
+
+import numpy as np
+import pandas as pd
+import pytest
+from pytest import raises
+
+pytest.importorskip("pyspark")
+
+
+from datacompy.snowflake import (
+ SnowflakeCompare,
+ _generate_id_within_group,
+ calculate_max_diff,
+ columns_equal,
+ temp_column_name,
+)
+from pandas.testing import assert_series_equal
+from snowflake.snowpark.exceptions import SnowparkSQLException
+
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+
+pd.DataFrame.iteritems = pd.DataFrame.items # Pandas 2+ compatability
+np.bool = np.bool_ # Numpy 1.24.3+ comptability
+
+
+def test_numeric_columns_equal_abs(snowpark_session):
+ data = """A|B|EXPECTED
+1|1|True
+2|2.1|True
+3|4|False
+4|NULL|False
+NULL|4|False
+NULL|NULL|True"""
+
+ df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", abs_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_numeric_columns_equal_rel(snowpark_session):
+ data = """A|B|EXPECTED
+1|1|True
+2|2.1|True
+3|4|False
+4|NULL|False
+NULL|4|False
+NULL|NULL|True"""
+ df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_string_columns_equal(snowpark_session):
+ data = """A|B|EXPECTED
+Hi|Hi|True
+Yo|Yo|True
+Hey|Hey |False
+résumé|resume|False
+résumé|résumé|True
+💩|💩|True
+💩|🤔|False
+ | |True
+ | |False
+datacompy|DataComPy|False
+something||False
+|something|False
+||True"""
+ df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_string_columns_equal_with_ignore_spaces(snowpark_session):
+ data = """A|B|EXPECTED
+Hi|Hi|True
+Yo|Yo|True
+Hey|Hey |True
+résumé|resume|False
+résumé|résumé|True
+💩|💩|True
+💩|🤔|False
+ | |True
+ | |True
+datacompy|DataComPy|False
+something||False
+|something|False
+||True"""
+ df = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep="|"))
+ actual_out = columns_equal(
+ df, "A", "B", "ACTUAL", rel_tol=0.2, ignore_spaces=True
+ ).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_date_columns_equal(snowpark_session):
+ data = """A|B|EXPECTED
+2017-01-01|2017-01-01|True
+2017-01-02|2017-01-02|True
+2017-10-01|2017-10-10|False
+2017-01-01||False
+|2017-01-01|False
+||True"""
+ pdf = pd.read_csv(io.StringIO(data), sep="|")
+ df = snowpark_session.createDataFrame(pdf)
+ # First compare just the strings
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+ # Then compare converted to datetime objects
+ pdf["A"] = pd.to_datetime(pdf["A"])
+ pdf["B"] = pd.to_datetime(pdf["B"])
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", rel_tol=0.2).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+ # and reverse
+ actual_out_rev = columns_equal(df, "B", "A", "ACTUAL", rel_tol=0.2).toPandas()[
+ "ACTUAL"
+ ]
+ assert_series_equal(expect_out, actual_out_rev, check_names=False)
+
+
+def test_date_columns_equal_with_ignore_spaces(snowpark_session):
+ data = """A|B|EXPECTED
+2017-01-01|2017-01-01 |True
+2017-01-02 |2017-01-02|True
+2017-10-01 |2017-10-10 |False
+2017-01-01||False
+|2017-01-01|False
+||True"""
+ pdf = pd.read_csv(io.StringIO(data), sep="|")
+ df = snowpark_session.createDataFrame(pdf)
+ # First compare just the strings
+ actual_out = columns_equal(
+ df, "A", "B", "ACTUAL", rel_tol=0.2, ignore_spaces=True
+ ).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+ # Then compare converted to datetime objects
+ try: # pandas 2
+ pdf["A"] = pd.to_datetime(pdf["A"], format="mixed")
+ pdf["B"] = pd.to_datetime(pdf["B"], format="mixed")
+ except ValueError: # pandas 1.5
+ pdf["A"] = pd.to_datetime(pdf["A"])
+ pdf["B"] = pd.to_datetime(pdf["B"])
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(
+ df, "A", "B", "ACTUAL", rel_tol=0.2, ignore_spaces=True
+ ).toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+ # and reverse
+ actual_out_rev = columns_equal(
+ df, "B", "A", "ACTUAL", rel_tol=0.2, ignore_spaces=True
+ ).toPandas()["ACTUAL"]
+ assert_series_equal(expect_out, actual_out_rev, check_names=False)
+
+
+def test_date_columns_unequal(snowpark_session):
+ """I want datetime fields to match with dates stored as strings"""
+ data = [{"A": "2017-01-01", "B": "2017-01-02"}, {"A": "2017-01-01"}]
+ pdf = pd.DataFrame(data)
+ pdf["A_DT"] = pd.to_datetime(pdf["A"])
+ pdf["B_DT"] = pd.to_datetime(pdf["B"])
+ df = snowpark_session.createDataFrame(pdf)
+ assert columns_equal(df, "A", "A_DT", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert columns_equal(df, "B", "B_DT", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert columns_equal(df, "A_DT", "A", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert columns_equal(df, "B_DT", "B", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert not columns_equal(df, "B_DT", "A", "ACTUAL").toPandas()["ACTUAL"].any()
+ assert not columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"].any()
+ assert not columns_equal(df, "A", "B_DT", "ACTUAL").toPandas()["ACTUAL"].any()
+ assert not columns_equal(df, "B", "A_DT", "ACTUAL").toPandas()["ACTUAL"].any()
+
+
+def test_bad_date_columns(snowpark_session):
+ """If strings can't be coerced into dates then it should be false for the
+ whole column.
+ """
+ data = [
+ {"A": "2017-01-01", "B": "2017-01-01"},
+ {"A": "2017-01-01", "B": "217-01-01"},
+ ]
+ pdf = pd.DataFrame(data)
+ pdf["A_DT"] = pd.to_datetime(pdf["A"])
+ df = snowpark_session.createDataFrame(pdf)
+ assert not columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"].all()
+ assert columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"].any()
+
+
+def test_rounded_date_columns(snowpark_session):
+ """If strings can't be coerced into dates then it should be false for the
+ whole column.
+ """
+ data = [
+ {"A": "2017-01-01", "B": "2017-01-01 00:00:00.000000", "EXP": True},
+ {"A": "2017-01-01", "B": "2017-01-01 00:00:00.123456", "EXP": False},
+ {"A": "2017-01-01", "B": "2017-01-01 00:00:01.000000", "EXP": False},
+ {"A": "2017-01-01", "B": "2017-01-01 00:00:00", "EXP": True},
+ ]
+ pdf = pd.DataFrame(data)
+ pdf["A_DT"] = pd.to_datetime(pdf["A"])
+ df = snowpark_session.createDataFrame(pdf)
+ actual = columns_equal(df, "A_DT", "B", "ACTUAL").toPandas()["ACTUAL"]
+ expected = df.select("EXP").toPandas()["EXP"]
+ assert_series_equal(actual, expected, check_names=False)
+
+
+def test_decimal_float_columns_equal(snowpark_session):
+ data = [
+ {"A": Decimal("1"), "B": 1, "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": 1.3, "EXPECTED": True},
+ {"A": Decimal("1.000003"), "B": 1.000003, "EXPECTED": True},
+ {"A": Decimal("1.000000004"), "B": 1.000000003, "EXPECTED": False},
+ {"A": Decimal("1.3"), "B": 1.2, "EXPECTED": False},
+ {"A": np.nan, "B": np.nan, "EXPECTED": True},
+ {"A": np.nan, "B": 1, "EXPECTED": False},
+ {"A": Decimal("1"), "B": np.nan, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL").toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_decimal_float_columns_equal_rel(snowpark_session):
+ data = [
+ {"A": Decimal("1"), "B": 1, "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": 1.3, "EXPECTED": True},
+ {"A": Decimal("1.000003"), "B": 1.000003, "EXPECTED": True},
+ {"A": Decimal("1.000000004"), "B": 1.000000003, "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": 1.2, "EXPECTED": False},
+ {"A": np.nan, "B": np.nan, "EXPECTED": True},
+ {"A": np.nan, "B": 1, "EXPECTED": False},
+ {"A": Decimal("1"), "B": np.nan, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", abs_tol=0.001).toPandas()[
+ "ACTUAL"
+ ]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_decimal_columns_equal(snowpark_session):
+ data = [
+ {"A": Decimal("1"), "B": Decimal("1"), "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": Decimal("1.3"), "EXPECTED": True},
+ {"A": Decimal("1.000003"), "B": Decimal("1.000003"), "EXPECTED": True},
+ {
+ "A": Decimal("1.000000004"),
+ "B": Decimal("1.000000003"),
+ "EXPECTED": False,
+ },
+ {"A": Decimal("1.3"), "B": Decimal("1.2"), "EXPECTED": False},
+ {"A": np.nan, "B": np.nan, "EXPECTED": True},
+ {"A": np.nan, "B": Decimal("1"), "EXPECTED": False},
+ {"A": Decimal("1"), "B": np.nan, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL").toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_decimal_columns_equal_rel(snowpark_session):
+ data = [
+ {"A": Decimal("1"), "B": Decimal("1"), "EXPECTED": True},
+ {"A": Decimal("1.3"), "B": Decimal("1.3"), "EXPECTED": True},
+ {"A": Decimal("1.000003"), "B": Decimal("1.000003"), "EXPECTED": True},
+ {
+ "A": Decimal("1.000000004"),
+ "B": Decimal("1.000000003"),
+ "EXPECTED": True,
+ },
+ {"A": Decimal("1.3"), "B": Decimal("1.2"), "EXPECTED": False},
+ {"A": np.nan, "B": np.nan, "EXPECTED": True},
+ {"A": np.nan, "B": Decimal("1"), "EXPECTED": False},
+ {"A": Decimal("1"), "B": np.nan, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL", abs_tol=0.001).toPandas()[
+ "ACTUAL"
+ ]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_infinity_and_beyond(snowpark_session):
+ # https://spark.apache.org/docs/latest/sql-ref-datatypes.html#positivenegative-infinity-semantics
+ # Positive/negative infinity multiplied by 0 returns NaN.
+ # Positive infinity sorts lower than NaN and higher than any other values.
+ # Negative infinity sorts lower than any other values.
+ data = [
+ {"A": np.inf, "B": np.inf, "EXPECTED": True},
+ {"A": -np.inf, "B": -np.inf, "EXPECTED": True},
+ {"A": -np.inf, "B": np.inf, "EXPECTED": True},
+ {"A": np.inf, "B": -np.inf, "EXPECTED": True},
+ {"A": 1, "B": 1, "EXPECTED": True},
+ {"A": 1, "B": 0, "EXPECTED": False},
+ ]
+ pdf = pd.DataFrame(data)
+ df = snowpark_session.createDataFrame(pdf)
+ actual_out = columns_equal(df, "A", "B", "ACTUAL").toPandas()["ACTUAL"]
+ expect_out = df.select("EXPECTED").toPandas()["EXPECTED"]
+ assert_series_equal(expect_out, actual_out, check_names=False)
+
+
+def test_compare_table_setter_bad(snowpark_session):
+ # Invalid table name
+ with raises(ValueError, match="invalid_table_name_1 is not a valid table name."):
+ SnowflakeCompare(
+ snowpark_session, "invalid_table_name_1", "invalid_table_name_2", ["A"]
+ )
+ # Valid table name but table does not exist
+ with raises(SnowparkSQLException):
+ SnowflakeCompare(
+ snowpark_session, "non.existant.table_1", "non.existant.table_2", ["A"]
+ )
+
+
+def test_compare_table_setter_good(snowpark_session):
+ data = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df = pd.read_csv(StringIO(data), sep=",")
+ database = snowpark_session.get_current_database().replace('"', "")
+ schema = snowpark_session.get_current_schema().replace('"', "")
+ full_table_name = f"{database}.{schema}"
+ toy_table_name_1 = "DC_TOY_TABLE_1"
+ toy_table_name_2 = "DC_TOY_TABLE_2"
+ full_toy_table_name_1 = f"{full_table_name}.{toy_table_name_1}"
+ full_toy_table_name_2 = f"{full_table_name}.{toy_table_name_2}"
+
+ snowpark_session.write_pandas(
+ df, toy_table_name_1, table_type="temp", auto_create_table=True, overwrite=True
+ )
+ snowpark_session.write_pandas(
+ df, toy_table_name_2, table_type="temp", auto_create_table=True, overwrite=True
+ )
+
+ compare = SnowflakeCompare(
+ snowpark_session,
+ full_toy_table_name_1,
+ full_toy_table_name_2,
+ join_columns=["ACCT_ID"],
+ )
+ assert compare.df1.toPandas().equals(df)
+ assert compare.join_columns == ["ACCT_ID"]
+
+
+def test_compare_df_setter_bad(snowpark_session):
+ pdf = pd.DataFrame([{"A": 1, "C": 2}, {"A": 2, "C": 2}])
+ df = snowpark_session.createDataFrame(pdf)
+ with raises(TypeError, match="DF1 must be a valid sp.Dataframe"):
+ SnowflakeCompare(snowpark_session, 3, 2, ["A"])
+ with raises(ValueError, match="DF1 must have all columns from join_columns"):
+ SnowflakeCompare(snowpark_session, df, df.select("*"), ["B"])
+ pdf = pd.DataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 3}])
+ df_dupe = snowpark_session.createDataFrame(pdf)
+ pd.testing.assert_frame_equal(
+ SnowflakeCompare(
+ snowpark_session, df_dupe, df_dupe.select("*"), ["A", "B"]
+ ).df1.toPandas(),
+ pdf,
+ check_dtype=False,
+ )
+
+
+def test_compare_df_setter_good(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert compare.df1.toPandas().equals(df1.toPandas())
+ assert compare.join_columns == ["A"]
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A", "B"])
+ assert compare.df1.toPandas().equals(df1.toPandas())
+ assert compare.join_columns == ["A", "B"]
+
+
+def test_compare_df_setter_different_cases(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert compare.df1.toPandas().equals(df1.toPandas())
+
+
+def test_columns_overlap(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 3}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert compare.df1_unq_columns() == set()
+ assert compare.df2_unq_columns() == set()
+ assert compare.intersect_columns() == {"A", "B"}
+
+
+def test_columns_no_overlap(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "YO"}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "D": "OH"}, {"A": 2, "B": 3, "D": "YA"}]
+ )
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert compare.df1_unq_columns() == {"C"}
+ assert compare.df2_unq_columns() == {"D"}
+ assert compare.intersect_columns() == {"A", "B"}
+
+
+def test_columns_maintain_order_through_set_operations(snowpark_session):
+ pdf1 = pd.DataFrame(
+ {
+ "JOIN": ["A", "B"],
+ "F": [0, 0],
+ "G": [1, 2],
+ "B": [2, 2],
+ "H": [3, 3],
+ "A": [4, 4],
+ "C": [-2, -3],
+ }
+ )
+ pdf2 = pd.DataFrame(
+ {
+ "JOIN": ["A", "B"],
+ "E": [0, 1],
+ "H": [1, 2],
+ "B": [2, 3],
+ "A": [-1, -1],
+ "G": [4, 4],
+ "D": [-3, -2],
+ }
+ )
+ df1 = snowpark_session.createDataFrame(pdf1)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ compare = SnowflakeCompare(snowpark_session, df1, df2, ["JOIN"])
+ assert list(compare.df1_unq_columns()) == ["F", "C"]
+ assert list(compare.df2_unq_columns()) == ["E", "D"]
+ assert list(compare.intersect_columns()) == ["JOIN", "G", "B", "H", "A"]
+
+
+def test_10k_rows(snowpark_session):
+ rng = np.random.default_rng()
+ pdf = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["B", "C"])
+ pdf.reset_index(inplace=True)
+ pdf.columns = ["A", "B", "C"]
+ pdf2 = pdf.copy()
+ pdf2["B"] = pdf2["B"] + 0.1
+ df1 = snowpark_session.createDataFrame(pdf)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ compare_tol = SnowflakeCompare(snowpark_session, df1, df2, ["A"], abs_tol=0.2)
+ assert compare_tol.matches()
+ assert compare_tol.df1_unq_rows.count() == 0
+ assert compare_tol.df2_unq_rows.count() == 0
+ assert compare_tol.intersect_columns() == {"A", "B", "C"}
+ assert compare_tol.all_columns_match()
+ assert compare_tol.all_rows_overlap()
+ assert compare_tol.intersect_rows_match()
+
+ compare_no_tol = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert not compare_no_tol.matches()
+ assert compare_no_tol.df1_unq_rows.count() == 0
+ assert compare_no_tol.df2_unq_rows.count() == 0
+ assert compare_no_tol.intersect_columns() == {"A", "B", "C"}
+ assert compare_no_tol.all_columns_match()
+ assert compare_no_tol.all_rows_overlap()
+ assert not compare_no_tol.intersect_rows_match()
+
+
+def test_subset(snowpark_session, caplog):
+ caplog.set_level(logging.DEBUG)
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "YO"}]
+ )
+ df2 = snowpark_session.createDataFrame([{"A": 1, "C": "HI"}])
+ comp = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert comp.subset()
+
+
+def test_not_subset(snowpark_session, caplog):
+ caplog.set_level(logging.INFO)
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "YO"}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "C": "HI"}, {"A": 2, "B": 2, "C": "GREAT"}]
+ )
+ comp = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert not comp.subset()
+ assert "C: 1 / 2 (50.00%) match" in caplog.text
+
+
+def test_large_subset(snowpark_session):
+ rng = np.random.default_rng()
+ pdf = pd.DataFrame(rng.integers(0, 100, size=(10000, 2)), columns=["B", "C"])
+ pdf.reset_index(inplace=True)
+ pdf.columns = ["A", "B", "C"]
+ pdf2 = pdf[["A", "B"]].head(50).copy()
+ df1 = snowpark_session.createDataFrame(pdf)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ comp = SnowflakeCompare(snowpark_session, df1, df2, ["A"])
+ assert not comp.matches()
+ assert comp.subset()
+
+
+def test_string_joiner(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"AB": 1, "BC": 2}, {"AB": 2, "BC": 2}])
+ df2 = snowpark_session.createDataFrame([{"AB": 1, "BC": 2}, {"AB": 2, "BC": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "AB")
+ assert compare.matches()
+
+
+def test_decimal_with_joins(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": Decimal("1"), "B": 2}, {"A": Decimal("2"), "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_decimal_with_nulls(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": Decimal("2")}, {"A": 2, "B": Decimal("2")}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2}, {"A": 2, "B": 2}, {"A": 3, "B": 2}]
+ )
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A")
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert not compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_strings_with_joins(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_temp_column_name(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}])
+ df2 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}, {"A": "back fo mo", "B": 3}]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_0"
+
+
+def test_temp_column_name_one_has(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"_TEMP_0": "HI", "B": 2}, {"_TEMP_0": "BYE", "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": 2}, {"A": "BYE", "B": 2}, {"A": "back fo mo", "B": 3}]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_1"
+
+
+def test_temp_column_name_both_have_temp_1(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"_TEMP_0": "HI", "B": 2}, {"_TEMP_0": "BYE", "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [
+ {"_TEMP_0": "HI", "B": 2},
+ {"_TEMP_0": "BYE", "B": 2},
+ {"A": "back fo mo", "B": 3},
+ ]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_1"
+
+
+def test_temp_column_name_both_have_temp_2(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"_TEMP_0": "HI", "B": 2}, {"_TEMP_0": "BYE", "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [
+ {"_TEMP_0": "HI", "B": 2},
+ {"_TEMP_1": "BYE", "B": 2},
+ {"A": "back fo mo", "B": 3},
+ ]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_2"
+
+
+def test_temp_column_name_one_already(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"_TEMP_1": "HI", "B": 2}, {"_TEMP_1": "BYE", "B": 2}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [
+ {"_TEMP_1": "HI", "B": 2},
+ {"_TEMP_1": "BYE", "B": 2},
+ {"A": "back fo mo", "B": 3},
+ ]
+ )
+ actual = temp_column_name(df1, df2)
+ assert actual == "_TEMP_0"
+
+
+# Duplicate testing!
+
+
+def test_simple_dupes_one_field(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+ assert compare.matches()
+ # Just render the report to make sure it renders.
+ compare.report()
+
+
+def test_simple_dupes_two_fields(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2, "C": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2, "C": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A", "B"])
+ assert compare.matches()
+ # Just render the report to make sure it renders.
+ compare.report()
+
+
+def test_simple_dupes_one_field_two_vals_1(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+ assert compare.matches()
+ # Just render the report to make sure it renders.
+ compare.report()
+
+
+def test_simple_dupes_one_field_two_vals_2(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 2, "B": 0}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+ assert not compare.matches()
+ assert compare.df1_unq_rows.count() == 1
+ assert compare.df2_unq_rows.count() == 1
+ assert compare.intersect_rows.count() == 1
+ # Just render the report to make sure it renders.
+ compare.report()
+
+
+def test_simple_dupes_one_field_three_to_two_vals(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2}, {"A": 1, "B": 0}, {"A": 1, "B": 0}]
+ )
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 0}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+ assert not compare.matches()
+ assert compare.df1_unq_rows.count() == 1
+ assert compare.df2_unq_rows.count() == 0
+ assert compare.intersect_rows.count() == 2
+ # Just render the report to make sure it renders.
+ compare.report()
+ assert "(First 1 Columns)" in compare.report(column_count=1)
+ assert "(First 2 Columns)" in compare.report(column_count=2)
+
+
+def test_dupes_from_real_data(snowpark_session):
+ data = """ACCT_ID,ACCT_SFX_NUM,TRXN_POST_DT,TRXN_POST_SEQ_NUM,TRXN_AMT,TRXN_DT,DEBIT_CR_CD,CASH_ADV_TRXN_COMN_CNTRY_CD,MRCH_CATG_CD,MRCH_PSTL_CD,VISA_MAIL_PHN_CD,VISA_RQSTD_PMT_SVC_CD,MC_PMT_FACILITATOR_IDN_NUM
+100,0,2017-06-17,1537019,30.64,2017-06-15,D,CAN,5812,M2N5P5,,,0.0
+200,0,2017-06-24,1022477,485.32,2017-06-22,D,USA,4511,7114,7.0,1,
+100,0,2017-06-17,1537039,2.73,2017-06-16,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-29,1049223,22.41,2017-06-28,D,USA,4789,21211,,A,
+100,0,2017-06-17,1537029,34.05,2017-06-16,D,CAN,5812,M4E 2C7,,,0.0
+200,0,2017-06-29,1049213,9.12,2017-06-28,D,CAN,5814,0,,,
+100,0,2017-06-19,1646426,165.21,2017-06-17,D,CAN,5411,M4M 3H9,,,0.0
+200,0,2017-06-30,1233082,28.54,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646436,17.87,2017-06-18,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-30,1233092,24.39,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646446,5.27,2017-06-17,D,CAN,5200,M4M 3G6,,,0.0
+200,0,2017-06-30,1233102,61.8,2017-06-30,D,CAN,4121,0,,,
+100,0,2017-06-20,1607573,41.99,2017-06-19,D,CAN,5661,M4C1M9,,,0.0
+200,0,2017-07-01,1009403,2.31,2017-06-29,D,USA,5814,22102,,F,
+100,0,2017-06-20,1607553,86.88,2017-06-19,D,CAN,4812,H2R3A8,,,0.0
+200,0,2017-07-01,1009423,5.5,2017-06-29,D,USA,5812,2903,,F,
+100,0,2017-06-20,1607563,25.17,2017-06-19,D,CAN,5641,M4C 1M9,,,0.0
+200,0,2017-07-01,1009433,214.12,2017-06-29,D,USA,3640,20170,,A,
+100,0,2017-06-20,1607593,1.67,2017-06-19,D,CAN,5814,M2N 6L7,,,0.0
+200,0,2017-07-01,1009393,2.01,2017-06-29,D,USA,5814,22102,,F,"""
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data), sep=","))
+ df2 = df1.select("*")
+ compare_acct = SnowflakeCompare(
+ snowpark_session, df1, df2, join_columns=["ACCT_ID"]
+ )
+ assert compare_acct.matches()
+ compare_acct.report()
+
+ compare_unq = SnowflakeCompare(
+ snowpark_session,
+ df1,
+ df2,
+ join_columns=["ACCT_ID", "ACCT_SFX_NUM", "TRXN_POST_DT", "TRXN_POST_SEQ_NUM"],
+ )
+ assert compare_unq.matches()
+ compare_unq.report()
+
+
+def test_table_compare_from_real_data(snowpark_session):
+ data = """ACCT_ID,ACCT_SFX_NUM,TRXN_POST_DT,TRXN_POST_SEQ_NUM,TRXN_AMT,TRXN_DT,DEBIT_CR_CD,CASH_ADV_TRXN_COMN_CNTRY_CD,MRCH_CATG_CD,MRCH_PSTL_CD,VISA_MAIL_PHN_CD,VISA_RQSTD_PMT_SVC_CD,MC_PMT_FACILITATOR_IDN_NUM
+100,0,2017-06-17,1537019,30.64,2017-06-15,D,CAN,5812,M2N5P5,,,0.0
+200,0,2017-06-24,1022477,485.32,2017-06-22,D,USA,4511,7114,7.0,1,
+100,0,2017-06-17,1537039,2.73,2017-06-16,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-29,1049223,22.41,2017-06-28,D,USA,4789,21211,,A,
+100,0,2017-06-17,1537029,34.05,2017-06-16,D,CAN,5812,M4E 2C7,,,0.0
+200,0,2017-06-29,1049213,9.12,2017-06-28,D,CAN,5814,0,,,
+100,0,2017-06-19,1646426,165.21,2017-06-17,D,CAN,5411,M4M 3H9,,,0.0
+200,0,2017-06-30,1233082,28.54,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646436,17.87,2017-06-18,D,CAN,5812,M4J 1M9,,,0.0
+200,0,2017-06-30,1233092,24.39,2017-06-29,D,USA,4121,94105,7.0,G,
+100,0,2017-06-19,1646446,5.27,2017-06-17,D,CAN,5200,M4M 3G6,,,0.0
+200,0,2017-06-30,1233102,61.8,2017-06-30,D,CAN,4121,0,,,
+100,0,2017-06-20,1607573,41.99,2017-06-19,D,CAN,5661,M4C1M9,,,0.0
+200,0,2017-07-01,1009403,2.31,2017-06-29,D,USA,5814,22102,,F,
+100,0,2017-06-20,1607553,86.88,2017-06-19,D,CAN,4812,H2R3A8,,,0.0
+200,0,2017-07-01,1009423,5.5,2017-06-29,D,USA,5812,2903,,F,
+100,0,2017-06-20,1607563,25.17,2017-06-19,D,CAN,5641,M4C 1M9,,,0.0
+200,0,2017-07-01,1009433,214.12,2017-06-29,D,USA,3640,20170,,A,
+100,0,2017-06-20,1607593,1.67,2017-06-19,D,CAN,5814,M2N 6L7,,,0.0
+200,0,2017-07-01,1009393,2.01,2017-06-29,D,USA,5814,22102,,F,"""
+ df = pd.read_csv(StringIO(data), sep=",")
+ database = snowpark_session.get_current_database().replace('"', "")
+ schema = snowpark_session.get_current_schema().replace('"', "")
+ full_table_name = f"{database}.{schema}"
+ toy_table_name_1 = "DC_TOY_TABLE_1"
+ toy_table_name_2 = "DC_TOY_TABLE_2"
+ full_toy_table_name_1 = f"{full_table_name}.{toy_table_name_1}"
+ full_toy_table_name_2 = f"{full_table_name}.{toy_table_name_2}"
+
+ snowpark_session.write_pandas(
+ df, toy_table_name_1, table_type="temp", auto_create_table=True, overwrite=True
+ )
+ snowpark_session.write_pandas(
+ df, toy_table_name_2, table_type="temp", auto_create_table=True, overwrite=True
+ )
+
+ compare_acct = SnowflakeCompare(
+ snowpark_session,
+ full_toy_table_name_1,
+ full_toy_table_name_2,
+ join_columns=["ACCT_ID"],
+ )
+ assert compare_acct.matches()
+ compare_acct.report()
+
+ compare_unq = SnowflakeCompare(
+ snowpark_session,
+ full_toy_table_name_1,
+ full_toy_table_name_2,
+ join_columns=["ACCT_ID", "ACCT_SFX_NUM", "TRXN_POST_DT", "TRXN_POST_SEQ_NUM"],
+ )
+ assert compare_unq.matches()
+ compare_unq.report()
+
+
+def test_strings_with_joins_with_ignore_spaces(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": " A"}, {"A": "BYE", "B": "A"}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": "A"}, {"A": "BYE", "B": "A "}]
+ )
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert not compare.intersect_rows_match()
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_decimal_with_joins_with_ignore_spaces(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": " A"}, {"A": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A "}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert not compare.intersect_rows_match()
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_joins_with_ignore_spaces(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": " A"}, {"A": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A "}])
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_joins_with_insensitive_lowercase_cols(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"a": 1, "B": "A"}, {"a": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "a")
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_joins_with_sensitive_lowercase_cols(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{'"a"': 1, "B": "A"}, {'"a"': 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{'"a"': 1, "B": "A"}, {'"a"': 2, "B": "A"}])
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, '"a"')
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+
+
+def test_strings_with_ignore_spaces_and_join_columns(snowpark_session):
+ df1 = snowpark_session.createDataFrame(
+ [{"A": "HI", "B": "A"}, {"A": "BYE", "B": "A"}]
+ )
+ df2 = snowpark_session.createDataFrame(
+ [{"A": " HI ", "B": "A"}, {"A": " BYE ", "B": "A"}]
+ )
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False)
+ assert not compare.matches()
+ assert compare.all_columns_match()
+ assert not compare.all_rows_overlap()
+ assert compare.count_matching_rows() == 0
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+ assert compare.count_matching_rows() == 2
+
+
+def test_integers_with_ignore_spaces_and_join_columns(snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": "A"}, {"A": 2, "B": "A"}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=False)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+ assert compare.count_matching_rows() == 2
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "A", ignore_spaces=True)
+ assert compare.matches()
+ assert compare.all_columns_match()
+ assert compare.all_rows_overlap()
+ assert compare.intersect_rows_match()
+ assert compare.count_matching_rows() == 2
+
+
+def test_sample_mismatch(snowpark_session):
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.sample_mismatch(column="NAME", sample_count=1).toPandas()
+ assert output.shape[0] == 1
+ assert (output.NAME_DF1 != output.NAME_DF2).all()
+
+ output = compare.sample_mismatch(column="NAME", sample_count=2).toPandas()
+ assert output.shape[0] == 2
+ assert (output.NAME_DF1 != output.NAME_DF2).all()
+
+ output = compare.sample_mismatch(column="NAME", sample_count=3).toPandas()
+ assert output.shape[0] == 2
+ assert (output.NAME_DF1 != output.NAME_DF2).all()
+
+
+def test_all_mismatch_not_ignore_matching_cols_no_cols_matching(snowpark_session):
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch().toPandas()
+ assert output.shape[0] == 4
+ assert output.shape[1] == 9
+
+ assert (output.NAME_DF1 != output.NAME_DF2).values.sum() == 2
+ assert (~(output.NAME_DF1 != output.NAME_DF2)).values.sum() == 2
+
+ assert (output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2).values.sum() == 1
+ assert (~(output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2)).values.sum() == 3
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+
+def test_all_mismatch_not_ignore_matching_cols_some_cols_matching(snowpark_session):
+ # Columns dollar_amt and name are matching
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Bob Loblaw,345.12,
+ 10000001238,1.05,Lucille Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch().toPandas()
+ assert output.shape[0] == 4
+ assert output.shape[1] == 9
+
+ assert (output.NAME_DF1 != output.NAME_DF2).values.sum() == 0
+ assert (~(output.NAME_DF1 != output.NAME_DF2)).values.sum() == 4
+
+ assert (output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2).values.sum() == 0
+ assert (~(output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2)).values.sum() == 4
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+
+def test_all_mismatch_ignore_matching_cols_some_cols_matching_diff_rows(
+ snowpark_session,
+):
+ # Case where there are rows on either dataset which don't match up.
+ # Columns dollar_amt and name are matching
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ 10000001241,1111.05,Lucille Bluth,
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Bob Loblaw,345.12,
+ 10000001238,1.05,Lucille Bluth,111,
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch(ignore_matching_cols=True).toPandas()
+
+ assert output.shape[0] == 4
+ assert output.shape[1] == 5
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+ assert not ("NAME_DF1" in output and "NAME_DF2" in output)
+ assert not ("DOLLAR_AMT_DF1" in output and "DOLLAR_AMT_DF1" in output)
+
+
+def test_all_mismatch_ignore_matching_cols_some_cols_matching(snowpark_session):
+ # Columns dollar_amt and name are matching
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Bob Loblaw,345.12,
+ 10000001238,1.05,Lucille Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch(ignore_matching_cols=True).toPandas()
+
+ assert output.shape[0] == 4
+ assert output.shape[1] == 5
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+ assert not ("NAME_DF1" in output and "NAME_DF2" in output)
+ assert not ("DOLLAR_AMT_DF1" in output and "DOLLAR_AMT_DF1" in output)
+
+
+def test_all_mismatch_ignore_matching_cols_no_cols_matching(snowpark_session):
+ data1 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.45,George Maharis,14530.1555,2017-01-01
+ 10000001235,0.45,Michael Bluth,1,2017-01-01
+ 10000001236,1345,George Bluth,,2017-01-01
+ 10000001237,123456,Bob Loblaw,345.12,2017-01-01
+ 10000001239,1.05,Lucille Bluth,,2017-01-01
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+
+ data2 = """ACCT_ID,DOLLAR_AMT,NAME,FLOAT_FLD,DATE_FLD
+ 10000001234,123.4,George Michael Bluth,14530.155,
+ 10000001235,0.45,Michael Bluth,,
+ 10000001236,1345,George Bluth,1,
+ 10000001237,123456,Robert Loblaw,345.12,
+ 10000001238,1.05,Loose Seal Bluth,111,
+ 10000001240,123.45,George Maharis,14530.1555,2017-01-02
+ """
+ df1 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data1), sep=","))
+ df2 = snowpark_session.createDataFrame(pd.read_csv(StringIO(data2), sep=","))
+ compare = SnowflakeCompare(snowpark_session, df1, df2, "ACCT_ID")
+
+ output = compare.all_mismatch().toPandas()
+ assert output.shape[0] == 4
+ assert output.shape[1] == 9
+
+ assert (output.NAME_DF1 != output.NAME_DF2).values.sum() == 2
+ assert (~(output.NAME_DF1 != output.NAME_DF2)).values.sum() == 2
+
+ assert (output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2).values.sum() == 1
+ assert (~(output.DOLLAR_AMT_DF1 != output.DOLLAR_AMT_DF2)).values.sum() == 3
+
+ assert (output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2).values.sum() == 3
+ assert (~(output.FLOAT_FLD_DF1 != output.FLOAT_FLD_DF2)).values.sum() == 1
+
+ assert (output.DATE_FLD_DF1 != output.DATE_FLD_DF2).values.sum() == 4
+ assert (~(output.DATE_FLD_DF1 != output.DATE_FLD_DF2)).values.sum() == 0
+
+
+@pytest.mark.parametrize(
+ "column, expected",
+ [
+ ("BASE", 0),
+ ("FLOATS", 0.2),
+ ("DECIMALS", 0.1),
+ ("NULL_FLOATS", 0.1),
+ ("STRINGS", 0.1),
+ ("INFINITY", np.inf),
+ ],
+)
+def test_calculate_max_diff(snowpark_session, column, expected):
+ pdf = pd.DataFrame(
+ {
+ "BASE": [1, 1, 1, 1, 1],
+ "FLOATS": [1.1, 1.1, 1.1, 1.2, 0.9],
+ "DECIMALS": [
+ Decimal("1.1"),
+ Decimal("1.1"),
+ Decimal("1.1"),
+ Decimal("1.1"),
+ Decimal("1.1"),
+ ],
+ "NULL_FLOATS": [np.nan, 1.1, 1, 1, 1],
+ "STRINGS": ["1", "1", "1", "1.1", "1"],
+ "INFINITY": [1, 1, 1, 1, np.inf],
+ }
+ )
+ MAX_DIFF_DF = snowpark_session.createDataFrame(pdf)
+ assert np.isclose(
+ calculate_max_diff(MAX_DIFF_DF, "BASE", column),
+ expected,
+ )
+
+
+def test_dupes_with_nulls_strings(snowpark_session):
+ pdf1 = pd.DataFrame(
+ {
+ "FLD_1": [1, 2, 2, 3, 3, 4, 5, 5],
+ "FLD_2": ["A", np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
+ "FLD_3": [1, 2, 2, 3, 3, 4, 5, 5],
+ }
+ )
+ pdf2 = pd.DataFrame(
+ {
+ "FLD_1": [1, 2, 3, 4, 5],
+ "FLD_2": ["A", np.nan, np.nan, np.nan, np.nan],
+ "FLD_3": [1, 2, 3, 4, 5],
+ }
+ )
+ df1 = snowpark_session.createDataFrame(pdf1)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ comp = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["FLD_1", "FLD_2"])
+ assert comp.subset()
+
+
+def test_dupes_with_nulls_ints(snowpark_session):
+ pdf1 = pd.DataFrame(
+ {
+ "FLD_1": [1, 2, 2, 3, 3, 4, 5, 5],
+ "FLD_2": [1, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
+ "FLD_3": [1, 2, 2, 3, 3, 4, 5, 5],
+ }
+ )
+ pdf2 = pd.DataFrame(
+ {
+ "FLD_1": [1, 2, 3, 4, 5],
+ "FLD_2": [1, np.nan, np.nan, np.nan, np.nan],
+ "FLD_3": [1, 2, 3, 4, 5],
+ }
+ )
+ df1 = snowpark_session.createDataFrame(pdf1)
+ df2 = snowpark_session.createDataFrame(pdf2)
+ comp = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["FLD_1", "FLD_2"])
+ assert comp.subset()
+
+
+def test_generate_id_within_group(snowpark_session):
+ matrix = [
+ (
+ pd.DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "__INDEX": [1, 2, 3]}),
+ pd.Series([0, 0, 0]),
+ ),
+ (
+ pd.DataFrame(
+ {
+ "A": ["A", "A", "DATACOMPY_NULL"],
+ "B": [1, 1, 2],
+ "__INDEX": [1, 2, 3],
+ }
+ ),
+ pd.Series([0, 1, 0]),
+ ),
+ (
+ pd.DataFrame({"A": [-999, 2, 3], "B": [1, 2, 3], "__INDEX": [1, 2, 3]}),
+ pd.Series([0, 0, 0]),
+ ),
+ (
+ pd.DataFrame(
+ {"A": [1, np.nan, np.nan], "B": [1, 2, 2], "__INDEX": [1, 2, 3]}
+ ),
+ pd.Series([0, 0, 1]),
+ ),
+ (
+ pd.DataFrame(
+ {"A": ["1", np.nan, np.nan], "B": ["1", "2", "2"], "__INDEX": [1, 2, 3]}
+ ),
+ pd.Series([0, 0, 1]),
+ ),
+ (
+ pd.DataFrame(
+ {
+ "A": [datetime(2018, 1, 1), np.nan, np.nan],
+ "B": ["1", "2", "2"],
+ "__INDEX": [1, 2, 3],
+ }
+ ),
+ pd.Series([0, 0, 1]),
+ ),
+ ]
+ for i in matrix:
+ dataframe = i[0]
+ expected = i[1]
+ actual = (
+ _generate_id_within_group(
+ snowpark_session.createDataFrame(dataframe), ["A", "B"], "_TEMP_0"
+ )
+ .orderBy("__INDEX")
+ .select("_TEMP_0")
+ .toPandas()
+ )
+ assert (actual["_TEMP_0"] == expected).all()
+
+
+def test_generate_id_within_group_single_join(snowpark_session):
+ dataframe = snowpark_session.createDataFrame(
+ [{"A": 1, "B": 2, "__INDEX": 1}, {"A": 1, "B": 2, "__INDEX": 2}]
+ )
+ expected = pd.Series([0, 1])
+ actual = (
+ _generate_id_within_group(dataframe, ["A"], "_TEMP_0")
+ .orderBy("__INDEX")
+ .select("_TEMP_0")
+ ).toPandas()
+ assert (actual["_TEMP_0"] == expected).all()
+
+
+@mock.patch("datacompy.snowflake.render")
+def test_save_html(mock_render, snowpark_session):
+ df1 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}])
+ df2 = snowpark_session.createDataFrame([{"A": 1, "B": 2}, {"A": 1, "B": 2}])
+ compare = SnowflakeCompare(snowpark_session, df1, df2, join_columns=["A"])
+
+ m = mock.mock_open()
+ with mock.patch("datacompy.snowflake.open", m, create=True):
+ # assert without HTML call
+ compare.report()
+ assert mock_render.call_count == 4
+ m.assert_not_called()
+
+ mock_render.reset_mock()
+ m = mock.mock_open()
+ with mock.patch("datacompy.snowflake.open", m, create=True):
+ # assert with HTML call
+ compare.report(html_file="test.html")
+ assert mock_render.call_count == 4
+ m.assert_called_with("test.html", "w")
From 223b61cd9670fd7dc5bc48f187195cf962352603 Mon Sep 17 00:00:00 2001
From: rhaffar <141745338+rhaffar@users.noreply.github.com>
Date: Wed, 30 Oct 2024 09:28:46 -0400
Subject: [PATCH 3/3] Snowflake compare annotation fix, add docs (#344)
* changes
* doc fix
* version bump
---
datacompy/__init__.py | 2 +-
datacompy/snowflake.py | 2 +-
docs/source/developer_instructions.rst | 20 +++++++++++++++++++-
3 files changed, 21 insertions(+), 3 deletions(-)
diff --git a/datacompy/__init__.py b/datacompy/__init__.py
index 74154839..8ca604ac 100644
--- a/datacompy/__init__.py
+++ b/datacompy/__init__.py
@@ -18,7 +18,7 @@
Then extended to carry that functionality over to Spark Dataframes.
"""
-__version__ = "0.14.1"
+__version__ = "0.14.2"
import platform
from warnings import warn
diff --git a/datacompy/snowflake.py b/datacompy/snowflake.py
index 19f63978..4441e5a0 100644
--- a/datacompy/snowflake.py
+++ b/datacompy/snowflake.py
@@ -200,7 +200,7 @@ def _validate_dataframe(self, df_name: str, index: str) -> None:
The "index" of the dataframe - df1 or df2.
"""
df = getattr(self, index)
- if not isinstance(df, "sp.DataFrame"):
+ if not isinstance(df, sp.DataFrame):
raise TypeError(f"{df_name} must be a valid sp.Dataframe")
# force all columns to be non-case-sensitive
diff --git a/docs/source/developer_instructions.rst b/docs/source/developer_instructions.rst
index 9f6becbe..29c8cdc2 100644
--- a/docs/source/developer_instructions.rst
+++ b/docs/source/developer_instructions.rst
@@ -43,6 +43,24 @@ Run ``python -m pytest`` to run all unittests defined in the subfolder
`pytest-runner `_.
+Snowflake testing
+-----------------
+Testing the Snowflake compare requires the use of a Snowflake cluster, as Snowflake does not support local running.
+This means that Snowflake tests do not get run in CICD, and changes to the Snowflake Compare must be validated by
+the process of running these tests locally.
+
+Note that you must have the following environment variables set in order to instantiate a Snowflake Connection (for testing purposes):
+
+- "SF_ACCOUNT": with your SF account
+- "SF_UID": with your SF username
+- "SF_PWD": with your SF password
+- "SF_WAREHOUSE": with your desired SF warehouse
+- "SF_DATABASE": with a valid database with which you have access
+- "SF_SCHEMA": with a valid schema belonging to the provided database
+
+Once these are set, you are free to run the suite of Snowflake tests.
+
+
Management of Requirements
--------------------------
@@ -130,4 +148,4 @@ Finally upload to PyPi::
twine upload --repository-url https://test.pypi.org/legacy/ dist/*
# real pypi
- twine upload dist/*
\ No newline at end of file
+ twine upload dist/*