diff --git a/dataframely/functional.py b/dataframely/functional.py index 475bb67..89931ee 100644 --- a/dataframely/functional.py +++ b/dataframely/functional.py @@ -23,22 +23,43 @@ # --------------------------------- RELATIONSHIP 1:1 --------------------------------- # +LEN_COLUMN = "__len__" + def filter_relationship_one_to_one( lhs: LazyFrame[S] | pl.LazyFrame, rhs: LazyFrame[T] | pl.LazyFrame, /, on: str | list[str], + *, + keep_only_unique: bool = False, ) -> pl.LazyFrame: """Express a 1:1 mapping between data frames for a collection filter. Args: lhs: The first data frame in the 1:1 mapping. rhs: The second data frame in the 1:1 mapping. - on: The columns to join the data frames on. If not provided, the join columns - are inferred from the mutual primary keys of the provided data frames. + on: The columns to join the data frames on. + keep_only_unique: If `True`, drop non-unique rows from both data frames. + This is useful when the join columns do not already uniquely identify rows. """ - return lhs.join(rhs, on=on) + if keep_only_unique: + return ( + lhs.group_by(on) + .len(LEN_COLUMN) + .filter(pl.col(LEN_COLUMN) == 1) + .drop(LEN_COLUMN) + .join( + rhs.group_by(on) + .len(LEN_COLUMN) + .filter(pl.col(LEN_COLUMN) == 1) + .drop(LEN_COLUMN), + on=on, + how="inner", + ) + ) + else: + return lhs.join(rhs, on=on) # ------------------------------- RELATIONSHIP 1:{1,N} ------------------------------- # @@ -55,8 +76,7 @@ def filter_relationship_one_to_at_least_one( Args: lhs: The data frame with exactly one occurrence for a set of key columns. rhs: The data frame with at least one occurrence for a set of key columns. - on: The columns to join the data frames on. If not provided, the join columns - are inferred from the joint primary keys of the provided data frames. + on: The columns to join the data frames on. """ return lhs.join(rhs.unique(on), on=on) diff --git a/tests/functional/test_relationships.py b/tests/functional/test_relationships.py index b650e11..b72e250 100644 --- a/tests/functional/test_relationships.py +++ b/tests/functional/test_relationships.py @@ -29,7 +29,7 @@ class EmployeeSchema(dy.Schema): @pytest.fixture() def departments() -> dy.LazyFrame[DepartmentSchema]: - return DepartmentSchema.cast(pl.LazyFrame({"department_id": [1, 2]})) + return DepartmentSchema.cast(pl.LazyFrame({"department_id": [1, 2, 3]})) @pytest.fixture() @@ -44,9 +44,9 @@ def employees() -> dy.LazyFrame[EmployeeSchema]: return EmployeeSchema.cast( pl.LazyFrame( { - "department_id": [2, 2, 2], - "employee_number": [101, 102, 103], - "name": ["Huey", "Dewey", "Louie"], + "department_id": [2, 2, 2, 3], + "employee_number": [101, 102, 103, 104], + "name": ["Huey", "Dewey", "Louie", "Daisy"], } ) ) @@ -67,6 +67,19 @@ def test_one_to_one( assert actual.select("department_id").collect().to_series().to_list() == [1] +def test_one_to_one_keep_only_unique( + departments: dy.LazyFrame[DepartmentSchema], + employees: dy.LazyFrame[EmployeeSchema], +) -> None: + actual = dy.filter_relationship_one_to_one( + departments, + employees, + on="department_id", + keep_only_unique=True, + ) + assert actual.select("department_id").collect().to_series().to_list() == [3] + + def test_one_to_at_least_one( departments: dy.LazyFrame[DepartmentSchema], employees: dy.LazyFrame[EmployeeSchema], @@ -74,4 +87,4 @@ def test_one_to_at_least_one( actual = dy.filter_relationship_one_to_at_least_one( departments, employees, on="department_id" ) - assert actual.select("department_id").collect().to_series().to_list() == [2] + assert actual.select("department_id").collect().to_series().to_list() == [2, 3]