From 6bc405d26ddbacc34afcebb8ed33024d9745bc34 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Tue, 30 Apr 2024 17:25:46 +0200 Subject: [PATCH 1/5] fix: fall back to default value if foreign key is unresolved --- src/_gettsim/shared.py | 31 ++++++++++++++++++++------- src/_gettsim_tests/test_join.py | 37 ++++++++++++++++++++++++++------- 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/src/_gettsim/shared.py b/src/_gettsim/shared.py index b058a84731..1258c70f81 100644 --- a/src/_gettsim/shared.py +++ b/src/_gettsim/shared.py @@ -270,7 +270,10 @@ def remove_group_suffix(col): def join_numpy( - foreign_key: np.ndarray[Key], primary_key: np.ndarray[Key], target: np.ndarray[Out] + foreign_key: np.ndarray[Key], + primary_key: np.ndarray[Key], + target: np.ndarray[Out], + default: Out, ) -> np.ndarray[Out]: """ Given a foreign key, find the corresponding primary key, and return the target at @@ -284,6 +287,8 @@ def join_numpy( The primary keys. target : np.ndarray[Out] The targets in the same order as the primary keys. + default : Out + The default value to return if no match is found. Returns ------- @@ -295,10 +300,20 @@ def join_numpy( duplicate_primary_keys = keys[counts > 1] raise ValueError(f"Duplicate primary keys: {duplicate_primary_keys}") - try: - sorter = np.argsort(primary_key) - idx = np.searchsorted(primary_key, foreign_key, sorter=sorter) - return target[idx] - except IndexError as e: - invalid_foreign_keys = foreign_key[~np.isin(foreign_key, primary_key)] - raise ValueError(f"Invalid foreign keys: {invalid_foreign_keys}") from e + # For each foreign key and for each primary key, check if they match + matches_foreign_key = foreign_key[:, None] == primary_key + + # For each foreign key, add a column with True at the end, to later fall back to + # the default + padded_matches_foreign_key = np.pad( + matches_foreign_key, ((0, 0), (0, 1)), "constant", constant_values=True + ) + + # For ech foreign key, compute the index of the first matching primary key + indices = np.argmax(padded_matches_foreign_key, axis=1) + + # Add the default target at the end + padded_targets = np.pad(target, (0, 1), "constant", constant_values=default) + + # Return the target at the index of the first matching primary key + return padded_targets.take(indices) diff --git a/src/_gettsim_tests/test_join.py b/src/_gettsim_tests/test_join.py index e7c876c6e5..ec91ad744d 100644 --- a/src/_gettsim_tests/test_join.py +++ b/src/_gettsim_tests/test_join.py @@ -4,42 +4,63 @@ @pytest.mark.parametrize( - "foreign_key, primary_key, target, expected", + "foreign_key, primary_key, target, default, expected", [ ( np.array([1, 2, 3]), np.array([1, 2, 3]), np.array(["a", "b", "c"]), + "d", np.array(["a", "b", "c"]), ), ( np.array([3, 2, 1]), np.array([1, 2, 3]), np.array(["a", "b", "c"]), + "d", np.array(["c", "b", "a"]), ), ( np.array([1, 1, 1]), np.array([1, 2, 3]), np.array(["a", "b", "c"]), + "d", np.array(["a", "a", "a"]), ), + ( + np.array([-1]), + np.array([1]), + np.array(["a"]), + "d", + np.array(["d"]), + ), + ( + np.array([2]), + np.array([1]), + np.array(["a"]), + "d", + np.array(["d"]), + ), ], ) def test_join_numpy( foreign_key: np.ndarray[int], primary_key: np.ndarray[int], target: np.ndarray[str], + default: str, expected: np.ndarray[str], ): - assert np.array_equal(join_numpy(foreign_key, primary_key, target), expected) + assert np.array_equal( + join_numpy(foreign_key, primary_key, target, default), + expected, + ) def test_join_numpy_raises_duplicate_primary_key(): with pytest.raises(ValueError, match="Duplicate primary keys:"): - join_numpy(np.array([1, 1, 1]), np.array([1, 1, 1]), np.array(["a", "b", "c"])) - - -def test_join_numpy_raises_invalid_foreign_key(): - with pytest.raises(ValueError, match="Invalid foreign keys:"): - join_numpy(np.array([2]), np.array([1]), np.array(["a"])) + join_numpy( + np.array([1, 1, 1]), + np.array([1, 1, 1]), + np.array(["a", "b", "c"]), + "default", + ) From ae3c2af8b0ca3fa9164e136c02e8fdba306f8f03 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Tue, 30 Apr 2024 17:39:42 +0200 Subject: [PATCH 2/5] feat: check validity of non-negative foreign keys --- src/_gettsim/shared.py | 7 +++++++ src/_gettsim_tests/test_join.py | 12 +++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/_gettsim/shared.py b/src/_gettsim/shared.py index 1258c70f81..11ae0abbf2 100644 --- a/src/_gettsim/shared.py +++ b/src/_gettsim/shared.py @@ -300,6 +300,13 @@ def join_numpy( duplicate_primary_keys = keys[counts > 1] raise ValueError(f"Duplicate primary keys: {duplicate_primary_keys}") + invalid_foreign_keys = foreign_key[ + (foreign_key >= 0) & (~np.isin(foreign_key, primary_key)) + ] + print(invalid_foreign_keys) + if len(invalid_foreign_keys) > 0: + raise ValueError(f"Invalid foreign keys: {invalid_foreign_keys}") + # For each foreign key and for each primary key, check if they match matches_foreign_key = foreign_key[:, None] == primary_key diff --git a/src/_gettsim_tests/test_join.py b/src/_gettsim_tests/test_join.py index ec91ad744d..28f5704974 100644 --- a/src/_gettsim_tests/test_join.py +++ b/src/_gettsim_tests/test_join.py @@ -34,13 +34,6 @@ "d", np.array(["d"]), ), - ( - np.array([2]), - np.array([1]), - np.array(["a"]), - "d", - np.array(["d"]), - ), ], ) def test_join_numpy( @@ -64,3 +57,8 @@ def test_join_numpy_raises_duplicate_primary_key(): np.array(["a", "b", "c"]), "default", ) + + +def test_join_numpy_raises_invalid_foreign_key(): + with pytest.raises(ValueError, match="Invalid foreign keys:"): + join_numpy(np.array([2]), np.array([1]), np.array(["a"]), "d") From 867d6a8e8c87d7140044fc1a7ed04c7754af452a Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Tue, 30 Apr 2024 18:37:33 +0200 Subject: [PATCH 3/5] refactor: rename new parameter --- src/_gettsim/shared.py | 16 +++++++++------- src/_gettsim_tests/test_join.py | 6 +++--- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/_gettsim/shared.py b/src/_gettsim/shared.py index 11ae0abbf2..3600f4e6f8 100644 --- a/src/_gettsim/shared.py +++ b/src/_gettsim/shared.py @@ -273,7 +273,7 @@ def join_numpy( foreign_key: np.ndarray[Key], primary_key: np.ndarray[Key], target: np.ndarray[Out], - default: Out, + value_for_unresolved_foreign_key: Out, ) -> np.ndarray[Out]: """ Given a foreign key, find the corresponding primary key, and return the target at @@ -287,8 +287,8 @@ def join_numpy( The primary keys. target : np.ndarray[Out] The targets in the same order as the primary keys. - default : Out - The default value to return if no match is found. + value_for_unresolved_foreign_key : Out + The value to return if no matching primary key is found. Returns ------- @@ -311,16 +311,18 @@ def join_numpy( matches_foreign_key = foreign_key[:, None] == primary_key # For each foreign key, add a column with True at the end, to later fall back to - # the default + # the value for unresolved foreign keys padded_matches_foreign_key = np.pad( matches_foreign_key, ((0, 0), (0, 1)), "constant", constant_values=True ) - # For ech foreign key, compute the index of the first matching primary key + # For each foreign key, compute the index of the first matching primary key indices = np.argmax(padded_matches_foreign_key, axis=1) - # Add the default target at the end - padded_targets = np.pad(target, (0, 1), "constant", constant_values=default) + # Add the value for unresolved foreign keys at the end of the target array + padded_targets = np.pad( + target, (0, 1), "constant", constant_values=value_for_unresolved_foreign_key + ) # Return the target at the index of the first matching primary key return padded_targets.take(indices) diff --git a/src/_gettsim_tests/test_join.py b/src/_gettsim_tests/test_join.py index 28f5704974..cd37653bf0 100644 --- a/src/_gettsim_tests/test_join.py +++ b/src/_gettsim_tests/test_join.py @@ -4,7 +4,7 @@ @pytest.mark.parametrize( - "foreign_key, primary_key, target, default, expected", + "foreign_key, primary_key, target, value_for_unresolved_foreign_key, expected", [ ( np.array([1, 2, 3]), @@ -40,11 +40,11 @@ def test_join_numpy( foreign_key: np.ndarray[int], primary_key: np.ndarray[int], target: np.ndarray[str], - default: str, + value_for_unresolved_foreign_key: str, expected: np.ndarray[str], ): assert np.array_equal( - join_numpy(foreign_key, primary_key, target, default), + join_numpy(foreign_key, primary_key, target, value_for_unresolved_foreign_key), expected, ) From 84ca429b0da007de5304bed72bdc3642d306eb21 Mon Sep 17 00:00:00 2001 From: Marvin Immesberger Date: Tue, 30 Apr 2024 19:00:35 +0200 Subject: [PATCH 4/5] Rename argument and specify it for Unterhaltsvorschuss. --- src/_gettsim/shared.py | 6 +++--- src/_gettsim/transfers/unterhaltsvors.py | 5 ++++- src/_gettsim_tests/test_join.py | 6 +++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/_gettsim/shared.py b/src/_gettsim/shared.py index 3600f4e6f8..6376b03623 100644 --- a/src/_gettsim/shared.py +++ b/src/_gettsim/shared.py @@ -273,7 +273,7 @@ def join_numpy( foreign_key: np.ndarray[Key], primary_key: np.ndarray[Key], target: np.ndarray[Out], - value_for_unresolved_foreign_key: Out, + value_if_foreign_key_is_missing: Out, ) -> np.ndarray[Out]: """ Given a foreign key, find the corresponding primary key, and return the target at @@ -287,7 +287,7 @@ def join_numpy( The primary keys. target : np.ndarray[Out] The targets in the same order as the primary keys. - value_for_unresolved_foreign_key : Out + value_if_foreign_key_is_missing : Out The value to return if no matching primary key is found. Returns @@ -321,7 +321,7 @@ def join_numpy( # Add the value for unresolved foreign keys at the end of the target array padded_targets = np.pad( - target, (0, 1), "constant", constant_values=value_for_unresolved_foreign_key + target, (0, 1), "constant", constant_values=value_if_foreign_key_is_missing ) # Return the target at the index of the first matching primary key diff --git a/src/_gettsim/transfers/unterhaltsvors.py b/src/_gettsim/transfers/unterhaltsvors.py index 55ce5c022e..0b45883fbe 100644 --- a/src/_gettsim/transfers/unterhaltsvors.py +++ b/src/_gettsim/transfers/unterhaltsvors.py @@ -172,7 +172,10 @@ def _unterhaltsvorschuss_empf_eink_above_income_threshold( ------- """ return join_numpy( - p_id_kindergeld_empf, p_id, _unterhaltsvorschuss_eink_above_income_threshold + p_id_kindergeld_empf, + p_id, + _unterhaltsvorschuss_eink_above_income_threshold, + value_if_foreign_key_is_missing=False, ) diff --git a/src/_gettsim_tests/test_join.py b/src/_gettsim_tests/test_join.py index cd37653bf0..ebeade1130 100644 --- a/src/_gettsim_tests/test_join.py +++ b/src/_gettsim_tests/test_join.py @@ -4,7 +4,7 @@ @pytest.mark.parametrize( - "foreign_key, primary_key, target, value_for_unresolved_foreign_key, expected", + "foreign_key, primary_key, target, value_if_foreign_key_is_missing, expected", [ ( np.array([1, 2, 3]), @@ -40,11 +40,11 @@ def test_join_numpy( foreign_key: np.ndarray[int], primary_key: np.ndarray[int], target: np.ndarray[str], - value_for_unresolved_foreign_key: str, + value_if_foreign_key_is_missing: str, expected: np.ndarray[str], ): assert np.array_equal( - join_numpy(foreign_key, primary_key, target, value_for_unresolved_foreign_key), + join_numpy(foreign_key, primary_key, target, value_if_foreign_key_is_missing), expected, ) From 42c5ecac423da9d85bcb8b843a70ec3307accd2a Mon Sep 17 00:00:00 2001 From: Marvin Immesberger Date: Tue, 30 Apr 2024 19:21:12 +0200 Subject: [PATCH 5/5] Remove print statement. --- src/_gettsim/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/_gettsim/shared.py b/src/_gettsim/shared.py index 6376b03623..e90a7b53fb 100644 --- a/src/_gettsim/shared.py +++ b/src/_gettsim/shared.py @@ -303,7 +303,7 @@ def join_numpy( invalid_foreign_keys = foreign_key[ (foreign_key >= 0) & (~np.isin(foreign_key, primary_key)) ] - print(invalid_foreign_keys) + if len(invalid_foreign_keys) > 0: raise ValueError(f"Invalid foreign keys: {invalid_foreign_keys}")