5 changes: 5 additions & 0 deletions reladiff/__init__.py
@@ -128,6 +128,11 @@ def diff_tables(
If different values are needed per table, it's possible to omit them here, and instead set
them directly when creating each :class:`TableSegment`.

Note:
Per-column SQL transformations can be configured via the `transform_columns` attribute when creating
the :class:`TableSegment` instances for `table1` and `table2`. Because transformations are typically specific to
either the source or the target database, `diff_tables` does not offer an override for this parameter.

Note:
It is recommended to call .close() on the returned object when done, to release the thread pool. Alternatively, you may use it as a context manager.

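To illustrate the `transform_columns` note above, here is a minimal usage sketch (not part of this PR). The connection strings, table and column names, and the CAST/TO_DATE expressions are hypothetical; it also assumes that `connect_to_table` forwards extra keyword arguments such as `transform_columns` to :class:`TableSegment`, and that transform expressions use the `{column}` placeholder substituted by `TableSegment._relevant_columns_repr`.

# Hypothetical example: URIs, table/column names, and SQL expressions are made up.
from reladiff import connect_to_table, diff_tables

source = connect_to_table(
    "postgresql://user:password@source-host/appdb", "orders", ("id",),
    extra_columns=("amount", "created_at"),
    # Normalize the source column on the fly, before it is hashed/compared:
    transform_columns={"created_at": "CAST({column} AS DATE)"},
)
target = connect_to_table(
    "snowflake://user:password@account/db/schema?warehouse=WH", "ORDERS", ("ID",),
    extra_columns=("AMOUNT", "CREATED_AT"),
    transform_columns={"CREATED_AT": "TO_DATE({column})"},
)

# The diff object releases its thread pool when used as a context manager.
with diff_tables(source, target) as diff:
    for sign, row in diff:
        print(sign, row)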
24 changes: 21 additions & 3 deletions reladiff/joindiff_tables.py
@@ -29,6 +29,8 @@
)
from sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable
from sqeleton.queries.extras import NormalizeAsString
from sqeleton.queries.ast_classes import LazyOps, ExprNode
from typing import Optional, Any

from .info_tree import InfoTree

@@ -43,6 +45,12 @@

TABLE_WRITE_LIMIT = 1000
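
# OverrideNormalizeAsString wraps a (possibly transformed) column expression so it is
# normalized to a string for comparison, while its LazyOps/ExprNode bases allow chaining
# lazy operators such as .is_distinct_from() in _create_outer_join below.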

@dataclass
class OverrideNormalizeAsString(NormalizeAsString, LazyOps, ExprNode):
expr: ExprNode
expr_type: Optional[Any] = None # Match type hint of NormalizeAsString
type = str


def merge_dicts(dicts):
i = iter(dicts)
@@ -314,10 +322,20 @@ def _create_outer_join(self, table1, table2):
a = table1.make_select()
b = table2.make_select()

- is_diff_cols = {f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)}
# Build the (possibly transformed) comparison expressions for both tables.
# The displayed output values are also transformed, to match HashDiffer's output.
is_diff_cols = {}
a_cols = {}
b_cols = {}
for c1, c2 in safezip(cols1, cols2):
ovrrd_c1 = table1.transform_columns.get(c1)
ovrrd_c2 = table2.transform_columns.get(c2)
expr_a = OverrideNormalizeAsString(Code(ovrrd_c1) if ovrrd_c1 else a[c1], table1._schema[c1])
expr_b = OverrideNormalizeAsString(Code(ovrrd_c2) if ovrrd_c2 else b[c2], table2._schema[c2])
is_diff_cols[f"is_diff_{c1}"] = bool_to_int(expr_a.is_distinct_from(expr_b))
a_cols[f"{c1}_a"] = expr_a
b_cols[f"{c2}_b"] = expr_b

- a_cols = {f"{c}_a": NormalizeAsString(a[c]) for c in cols1}
- b_cols = {f"{c}_b": NormalizeAsString(b[c]) for c in cols2}
# Order columns as col1_a, col1_b, col2_a, col2_b, etc.
cols = {k: v for k, v in chain(*zip(a_cols.items(), b_cols.items()))}

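A side note on the `chain(*zip(...))` idiom kept above: it interleaves the two column dicts so that each compared pair ends up side by side in the joined SELECT. A self-contained sketch with made-up column names and placeholder values:

from itertools import chain

# Stand-ins for the a_cols / b_cols dicts built in the loop above (values elided).
a_cols = {"id_a": "...", "amount_a": "...", "created_at_a": "..."}
b_cols = {"id_b": "...", "amount_b": "...", "created_at_b": "..."}

# zip() pairs up the (key, value) items of both dicts; chain(*...) flattens the pairs
# in order, so the merged dict alternates the "_a" and "_b" columns.
cols = {k: v for k, v in chain(*zip(a_cols.items(), b_cols.items()))}
print(list(cols))
# ['id_a', 'id_b', 'amount_a', 'amount_b', 'created_at_a', 'created_at_b']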
24 changes: 19 additions & 5 deletions reladiff/table_segment.py
@@ -1,9 +1,10 @@
import time
- from typing import List, Tuple
from typing import List, Tuple, Dict
import logging
from itertools import product

from runtype import dataclass
from dataclasses import field

from .utils import safezip, Vector
from sqeleton.utils import ArithString, split_space
@@ -98,6 +99,11 @@ class TableSegment:
update_column (str, optional): Name of updated column, which signals that rows changed.
Usually updated_at or last_update. Used by `min_update` and `max_update`.
extra_columns (Tuple[str, ...], optional): Extra columns to compare
transform_columns (Dict[str, str], optional): A dictionary mapping column names to SQL transformation expressions.
These expressions are applied directly to the specified columns within the
comparison query, *before* the data is hashed or compared. Useful for
on-the-fly normalization (e.g., type casting, timezone conversions) without
requiring intermediate views or staging tables. Defaults to an empty dict.
min_key (:data:`Vector`, optional): Lowest key value, used to restrict the segment
max_key (:data:`Vector`, optional): Highest key value, used to restrict the segment
min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
@@ -116,6 +122,7 @@ class TableSegment:
key_columns: Tuple[str, ...]
update_column: str = None
extra_columns: Tuple[str, ...] = ()
transform_columns: Dict[str, str] = field(default_factory=dict)

# Restrict the segment
min_key: Vector = None
@@ -155,7 +162,7 @@ def _with_raw_schema(self, raw_schema: dict, refine: bool = True, allow_empty_ta
if is_empty_table and not allow_empty_table:
raise EmptyTable(f"Table {self.table_path} is empty. Use --allow-empty-tables to disable this protection.", self)

- res = self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
res = self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive), transform_columns=self.transform_columns)

return EmptyTableSegment(res) if is_empty_table else res

@@ -167,7 +174,7 @@ def with_schema(self, refine: bool = True, allow_empty_table: bool = False) -> "
return self._with_raw_schema(
self.database.query_table_schema(self.table_path), refine=refine, allow_empty_table=allow_empty_table
)

def _cast_col_value(self, col, value):
"""Cast the value to the right type, based on the type of the column

@@ -250,7 +257,14 @@ def relevant_columns(self) -> List[str]:

@property
def _relevant_columns_repr(self) -> List[Expr]:
- return [NormalizeAsString(this[c]) for c in self.relevant_columns]
expressions = []
for c in self.relevant_columns:
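# Wrap each column in NormalizeAsString, applying the user-supplied SQL transform
# (with "{column}" substituted by the column reference) when one is configured.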
if c in self.transform_columns:
transform_expr = self.transform_columns[c]
expressions.append(NormalizeAsString(Code(transform_expr.format(column=this[c])), self._schema[c]))
else:
expressions.append(NormalizeAsString(this[c], self._schema[c]))
return expressions

def count(self) -> int:
"""Count how many rows are in the segment, in one pass."""
@@ -330,7 +344,7 @@ def count_and_checksum(self) -> Tuple[int, int]:
return (0, None)

def __getattr__(self, attr):
- assert attr in ("database", "key_columns", "key_types", "relevant_columns", "_schema")
assert attr in ("database", "key_columns", "key_types", "relevant_columns", "_schema", "transform_columns")
return getattr(self._table_segment, attr)

@property
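Finally, a small sketch of the string substitution performed in `_relevant_columns_repr` above, using a hypothetical `transform_columns` mapping. In the real code the `{column}` placeholder is filled with the `this[c]` column expression and the result is wrapped in `Code(...)` and `NormalizeAsString(...)`; here a quoted column name stands in for that expression:

# Hypothetical mapping; "{column}" is replaced with the column reference.
transform_columns = {"created_at": "CAST({column} AS DATE)", "name": "TRIM({column})"}

for col in ("id", "created_at", "name"):
    expr = transform_columns.get(col, "{column}").format(column=f'"{col}"')
    print(expr)
# Output:
# "id"
# CAST("created_at" AS DATE)
# TRIM("name")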