Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions reladiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def diff_tables(
If different values are needed per table, it's possible to omit them here, and instead set
them directly when creating each :class:`TableSegment`.

Note:
Column transformations using SQL expressions can be configured via the `transform_columns` attribute when creating
the :class:`TableSegment` instances for `table1` and `table2`. Because transformations are typically specific to either
the source or the target database, they cannot be set through `diff_tables` itself.

Note:
It is recommended to call .close() on the returned object when done, to release thread-pool. Alternatively, you may use it as a context manager.

Expand Down
25 changes: 20 additions & 5 deletions reladiff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from sqeleton.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable
from sqeleton.queries.extras import NormalizeAsString
from sqeleton.queries.ast_classes import IsDistinctFrom

from .info_tree import InfoTree

Expand Down Expand Up @@ -311,13 +312,27 @@ def _create_outer_join(self, table1, table2):
if len(cols1) != len(cols2):
raise ValueError("The provided columns are of a different count")

a = table1.make_select()
b = table2.make_select()
a = table1.make_select().alias("tbl_a")
b = table2.make_select().alias("tbl_b")

is_diff_cols = {f"is_diff_{c1}": bool_to_int(a[c1].is_distinct_from(b[c2])) for c1, c2 in safezip(cols1, cols2)}
# Create a compiler for transform_cols
compiler = Compiler(db, _is_root=False).add_table_context(a, b)

# Get the transformed expressions for both tables.
# Displayed output values are also transformed, to stay consistent with the HashDiffer's output.
is_diff_cols = {}
a_cols = {}
b_cols = {}
for c1, c2 in safezip(cols1, cols2):
# Compile the column reference so the transform expression targets the aliased table
expr_a = table1._get_column_transforms(c1, compiler.compile(a[c1])) or a[c1]
expr_b = table2._get_column_transforms(c2, compiler.compile(b[c2])) or b[c2]

# Normalization is only needed for the selected output columns (see issue #70)
is_diff_cols[f"is_diff_{c1}"] = bool_to_int(IsDistinctFrom(expr_a, expr_b))
a_cols[f"{c1}_a"] = NormalizeAsString(expr_a, table1._schema[c1])
b_cols[f"{c2}_b"] = NormalizeAsString(expr_b, table2._schema[c2])

a_cols = {f"{c}_a": NormalizeAsString(a[c]) for c in cols1}
b_cols = {f"{c}_b": NormalizeAsString(b[c]) for c in cols2}
# Order columns as col1_a, col1_b, col2_a, col2_b, etc.
cols = {k: v for k, v in chain(*zip(a_cols.items(), b_cols.items()))}

Expand Down
39 changes: 31 additions & 8 deletions reladiff/table_segment.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import time
from typing import List, Tuple
from typing import List, Tuple, Dict
import logging
from itertools import product

from runtype import dataclass
from dataclasses import field

from .utils import safezip, Vector
from sqeleton.utils import ArithString, split_space
from sqeleton.databases import Database, DbPath, DbKey, DbTime
from sqeleton.abcs.database_types import String_UUID
from sqeleton.schema import Schema, create_schema
from sqeleton.queries import Count, Checksum, SKIP, table, this, Expr, min_, max_, Code
from sqeleton.queries.ast_classes import BinBoolOp
from sqeleton.queries.extras import ApplyFuncAndNormalizeAsString, NormalizeAsString

logger = logging.getLogger("table_segment")
Expand Down Expand Up @@ -98,6 +100,11 @@ class TableSegment:
update_column (str, optional): Name of updated column, which signals that rows changed.
Usually updated_at or last_update. Used by `min_update` and `max_update`.
extra_columns (Tuple[str, ...], optional): Extra columns to compare
transform_columns (Dict[str, str], optional): A dictionary mapping column names to SQL transformation expressions.
These expressions are applied directly to the specified columns within the
comparison query, *before* the data is hashed or compared. Useful for
on-the-fly normalization (e.g., type casting, timezone conversions) without
requiring intermediate views or staging tables. Defaults to an empty dict.
min_key (:data:`Vector`, optional): Lowest key value, used to restrict the segment
max_key (:data:`Vector`, optional): Highest key value, used to restrict the segment
min_update (:data:`DbTime`, optional): Lowest update_column value, used to restrict the segment
Expand All @@ -116,6 +123,7 @@ class TableSegment:
key_columns: Tuple[str, ...]
update_column: str = None
extra_columns: Tuple[str, ...] = ()
transform_columns: Dict[str, str] = field(default_factory=dict)

# Restrict the segment
min_key: Vector = None
Expand Down Expand Up @@ -155,7 +163,7 @@ def _with_raw_schema(self, raw_schema: dict, refine: bool = True, allow_empty_ta
if is_empty_table and not allow_empty_table:
raise EmptyTable(f"Table {self.table_path} is empty. Use --allow-empty-tables to disable this protection.", self)

res = self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
res = self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive), transform_columns = self.transform_columns)

return EmptyTableSegment(res) if is_empty_table else res

Expand All @@ -167,7 +175,7 @@ def with_schema(self, refine: bool = True, allow_empty_table: bool = False) -> "
return self._with_raw_schema(
self.database.query_table_schema(self.table_path), refine=refine, allow_empty_table=allow_empty_table
)

def _cast_col_value(self, col, value):
"""Cast the value to the right type, based on the type of the column

Expand All @@ -179,15 +187,27 @@ def _cast_col_value(self, col, value):
return str(value)
return value

def _get_column_transforms(self, col_name: str, aliased_col=None) -> Expr:
    """Return the SQL expression to use for *col_name*, applying any
    transformation configured in ``self.transform_columns``.

    Args:
        col_name: Name of the column being selected or compared.
        aliased_col: Compiled, table-qualified column reference (joindiff path).
            When given, occurrences of ``col_name`` inside the transform
            expression are rewritten to this reference. ``None`` on the
            hashdiff path.

    Returns:
        - No transform configured: ``this[col_name]`` when ``aliased_col`` is
          falsy, else ``None`` so the caller can fall back to the plain
          aliased column.
        - Transform configured: a ``Code`` expression, with the column name
          rewritten to ``aliased_col`` when one was supplied.
    """
    import re

    transform_expr = self.transform_columns.get(col_name)

    if aliased_col:
        if not transform_expr:
            return None
        # Replace only whole-word occurrences of the column name, so that a
        # column such as "id" is not rewritten inside e.g. "order_id".
        # (A plain str.replace would corrupt such expressions.)
        aliased_expr = re.sub(
            rf"\b{re.escape(col_name)}\b",
            lambda _m: aliased_col,  # lambda avoids backref escaping of the replacement
            transform_expr,
        )
        return Code(aliased_expr)

    return Code(transform_expr) if transform_expr else this[col_name]

def _make_key_range(self):
if self.min_key is not None:
for mn, k in safezip(self.min_key, self.key_columns):
mn = self._cast_col_value(k, mn)
yield mn <= this[k]
yield BinBoolOp(">=", [self._get_column_transforms(k), mn])
if self.max_key is not None:
for k, mx in safezip(self.key_columns, self.max_key):
mx = self._cast_col_value(k, mx)
yield this[k] < mx
yield BinBoolOp("<", [self._get_column_transforms(k), mx])

def _make_update_range(self):
if self.min_update is not None:
Expand Down Expand Up @@ -250,7 +270,10 @@ def relevant_columns(self) -> List[str]:

@property
def _relevant_columns_repr(self) -> List[Expr]:
return [NormalizeAsString(this[c]) for c in self.relevant_columns]
expressions = []
for c in self.relevant_columns:
expressions.append(NormalizeAsString(self._get_column_transforms(c), self._schema[c]))
return expressions

def count(self) -> int:
"""Count how many rows are in the segment, in one pass."""
Expand All @@ -277,7 +300,7 @@ def query_key_range(self) -> Tuple[tuple, tuple]:
"""Query database for minimum and maximum key. This is used for setting the initial bounds."""
# Normalizes the result (needed for UUIDs) after the min/max computation
select = self.make_select().select(
ApplyFuncAndNormalizeAsString(this[k], f) for k in self.key_columns for f in (min_, max_)
ApplyFuncAndNormalizeAsString(self._get_column_transforms(k), f) for k in self.key_columns for f in (min_, max_)
)
result = tuple(self.database.query(select, tuple))

Expand Down Expand Up @@ -330,7 +353,7 @@ def count_and_checksum(self) -> Tuple[int, int]:
return (0, None)

def __getattr__(self, attr):
assert attr in ("database", "key_columns", "key_types", "relevant_columns", "_schema")
assert attr in ("database", "key_columns", "key_types", "relevant_columns", "_schema", "transform_columns", "_get_column_transforms")
return getattr(self._table_segment, attr)

@property
Expand Down