Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions src/_gettsim/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,10 @@ def remove_group_suffix(col):


def join_numpy(
foreign_key: np.ndarray[Key], primary_key: np.ndarray[Key], target: np.ndarray[Out]
foreign_key: np.ndarray[Key],
primary_key: np.ndarray[Key],
target: np.ndarray[Out],
value_for_unresolved_foreign_key: Out,
) -> np.ndarray[Out]:
"""
Given a foreign key, find the corresponding primary key, and return the target at
Expand All @@ -284,6 +287,8 @@ def join_numpy(
The primary keys.
target : np.ndarray[Out]
The targets in the same order as the primary keys.
value_for_unresolved_foreign_key : Out
The value to return if no matching primary key is found.

Returns
-------
Expand All @@ -295,10 +300,29 @@ def join_numpy(
duplicate_primary_keys = keys[counts > 1]
raise ValueError(f"Duplicate primary keys: {duplicate_primary_keys}")

try:
sorter = np.argsort(primary_key)
idx = np.searchsorted(primary_key, foreign_key, sorter=sorter)
return target[idx]
except IndexError as e:
invalid_foreign_keys = foreign_key[~np.isin(foreign_key, primary_key)]
raise ValueError(f"Invalid foreign keys: {invalid_foreign_keys}") from e
invalid_foreign_keys = foreign_key[
(foreign_key >= 0) & (~np.isin(foreign_key, primary_key))
]
print(invalid_foreign_keys)
if len(invalid_foreign_keys) > 0:
raise ValueError(f"Invalid foreign keys: {invalid_foreign_keys}")

# For each foreign key and for each primary key, check if they match
matches_foreign_key = foreign_key[:, None] == primary_key

# For each foreign key, add a column with True at the end, to later fall back to
# the value for unresolved foreign keys
padded_matches_foreign_key = np.pad(
matches_foreign_key, ((0, 0), (0, 1)), "constant", constant_values=True
)

# For each foreign key, compute the index of the first matching primary key
indices = np.argmax(padded_matches_foreign_key, axis=1)

# Add the value for unresolved foreign keys at the end of the target array
padded_targets = np.pad(
target, (0, 1), "constant", constant_values=value_for_unresolved_foreign_key
)

# Return the target at the index of the first matching primary key
return padded_targets.take(indices)
27 changes: 23 additions & 4 deletions src/_gettsim_tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,61 @@


@pytest.mark.parametrize(
"foreign_key, primary_key, target, expected",
"foreign_key, primary_key, target, value_for_unresolved_foreign_key, expected",
[
(
np.array([1, 2, 3]),
np.array([1, 2, 3]),
np.array(["a", "b", "c"]),
"d",
np.array(["a", "b", "c"]),
),
(
np.array([3, 2, 1]),
np.array([1, 2, 3]),
np.array(["a", "b", "c"]),
"d",
np.array(["c", "b", "a"]),
),
(
np.array([1, 1, 1]),
np.array([1, 2, 3]),
np.array(["a", "b", "c"]),
"d",
np.array(["a", "a", "a"]),
),
(
np.array([-1]),
np.array([1]),
np.array(["a"]),
"d",
np.array(["d"]),
),
],
)
def test_join_numpy(
foreign_key: np.ndarray[int],
primary_key: np.ndarray[int],
target: np.ndarray[str],
value_for_unresolved_foreign_key: str,
expected: np.ndarray[str],
):
assert np.array_equal(join_numpy(foreign_key, primary_key, target), expected)
assert np.array_equal(
join_numpy(foreign_key, primary_key, target, value_for_unresolved_foreign_key),
expected,
)


def test_join_numpy_raises_duplicate_primary_key():
with pytest.raises(ValueError, match="Duplicate primary keys:"):
join_numpy(np.array([1, 1, 1]), np.array([1, 1, 1]), np.array(["a", "b", "c"]))
join_numpy(
np.array([1, 1, 1]),
np.array([1, 1, 1]),
np.array(["a", "b", "c"]),
"default",
)


def test_join_numpy_raises_invalid_foreign_key():
with pytest.raises(ValueError, match="Invalid foreign keys:"):
join_numpy(np.array([2]), np.array([1]), np.array(["a"]))
join_numpy(np.array([2]), np.array([1]), np.array(["a"]), "d")