Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/_gettsim/arbeitslosengeld_2/kindergeldübertrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy

from ttsim import AggType, agg_by_p_id_function, join_numpy, policy_function
from ttsim import AggType, agg_by_p_id_function, join, policy_function


@agg_by_p_id_function(agg_type=AggType.SUM)
Expand Down Expand Up @@ -92,7 +92,7 @@ def kindergeld_zur_bedarfsdeckung_m(
-------

"""
return join_numpy(
return join(
kindergeld__p_id_empfänger,
p_id,
kindergeld_pro_kind_m,
Expand Down
4 changes: 2 additions & 2 deletions src/_gettsim/kindergeld/kindergeld.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy

from ttsim import AggType, agg_by_p_id_function, join_numpy, policy_function
from ttsim import AggType, agg_by_p_id_function, join, policy_function


@agg_by_p_id_function(agg_type=AggType.SUM)
Expand Down Expand Up @@ -203,7 +203,7 @@ def gleiche_fg_wie_empfänger(
-------

"""
fg_id_kindergeldempfänger = join_numpy(
fg_id_kindergeldempfänger = join(
p_id_empfänger,
p_id,
fg_id,
Expand Down
6 changes: 3 additions & 3 deletions src/_gettsim/unterhaltsvorschuss/unterhaltsvorschuss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
AggType,
RoundingSpec,
agg_by_p_id_function,
join_numpy,
join,
policy_function,
)

Expand Down Expand Up @@ -90,7 +90,7 @@ def elternteil_alleinerziehend(
-------

"""
return join_numpy(
return join(
foreign_key=kindergeld__p_id_empfänger,
primary_key=p_id,
target=familie__alleinerziehend,
Expand Down Expand Up @@ -363,7 +363,7 @@ def elternteil_mindesteinkommen_erreicht(
Returns
-------
"""
return join_numpy(
return join(
kindergeld__p_id_empfänger,
p_id,
mindesteinkommen_erreicht,
Expand Down
4 changes: 2 additions & 2 deletions src/ttsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ttsim.rounding import RoundingSpec
from ttsim.shared import (
insert_path_and_value,
join_numpy,
join,
merge_trees,
upsert_path_and_value,
upsert_tree,
Expand Down Expand Up @@ -58,7 +58,7 @@
"get_piecewise_parameters",
"group_creation_function",
"insert_path_and_value",
"join_numpy",
"join",
"load_objects_tree_for_date",
"merge_trees",
"piecewise_polynomial",
Expand Down
45 changes: 14 additions & 31 deletions src/ttsim/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from typing import TYPE_CHECKING, Any, TypeVar

import dags.tree as dt
import numpy
import optree

from ttsim.config import numpy_or_jax as np

if TYPE_CHECKING:
from ttsim.ttsim_objects import PolicyFunction
from ttsim.typing import (
Expand Down Expand Up @@ -399,23 +400,23 @@ def remove_group_suffix(col, groupings):
Out: TypeVar = TypeVar("Out")


def join_numpy(
foreign_key: numpy.ndarray[Key],
primary_key: numpy.ndarray[Key],
target: numpy.ndarray[Out],
def join(
foreign_key: np.ndarray,
primary_key: np.ndarray,
target: np.ndarray,
value_if_foreign_key_is_missing: Out,
) -> numpy.ndarray[Out]:
) -> np.ndarray:
"""
Given a foreign key, find the corresponding primary key, and return the target at
the same index as the primary key.
the same index as the primary key. When using JAX, this does not work on string arrays.

Parameters
----------
foreign_key : numpy.ndarray[Key]
foreign_key : np.ndarray[Key]
The foreign keys.
primary_key : numpy.ndarray[Key]
primary_key : np.ndarray[Key]
The primary keys.
target : numpy.ndarray[Out]
target : np.ndarray[Out]
The targets in the same order as the primary keys.
value_if_foreign_key_is_missing : Out
The value to return if no matching primary key is found.
Expand All @@ -424,38 +425,20 @@ def join_numpy(
-------
The joined array.
"""
if len(numpy.unique(primary_key)) != len(primary_key):
keys, counts = numpy.unique(primary_key, return_counts=True)
duplicate_primary_keys = keys[counts > 1]
msg = format_errors_and_warnings(
f"Duplicate primary keys: {duplicate_primary_keys}",
)
raise ValueError(msg)

invalid_foreign_keys = foreign_key[
(foreign_key >= 0) & (~numpy.isin(foreign_key, primary_key))
]

if len(invalid_foreign_keys) > 0:
msg = format_errors_and_warnings(
f"Invalid foreign keys: {invalid_foreign_keys}",
)
raise ValueError(msg)

# For each foreign key and for each primary key, check if they match
matches_foreign_key = foreign_key[:, None] == primary_key

# For each foreign key, add a column with True at the end, to later fall back to
# the value for unresolved foreign keys
padded_matches_foreign_key = numpy.pad(
padded_matches_foreign_key = np.pad(
matches_foreign_key, ((0, 0), (0, 1)), "constant", constant_values=True
)

# For each foreign key, compute the index of the first matching primary key
indices = numpy.argmax(padded_matches_foreign_key, axis=1)
indices = np.argmax(padded_matches_foreign_key, axis=1)

# Add the value for unresolved foreign keys at the end of the target array
padded_targets = numpy.pad(
padded_targets = np.pad(
target, (0, 1), "constant", constant_values=value_if_foreign_key_is_missing
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ttsim import (
AggType,
agg_by_p_id_function,
join_numpy,
join,
policy_function,
)

Expand Down Expand Up @@ -45,7 +45,7 @@ def in_same_household_as_recipient(
p_id_recipient: int,
) -> bool:
return (
join_numpy(
join(
foreign_key=p_id_recipient,
primary_key=p_id,
target=hh_id,
Expand Down
75 changes: 30 additions & 45 deletions tests/ttsim/test_join.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,50 @@
import numpy
import pytest

from ttsim.shared import join_numpy
from ttsim.config import numpy_or_jax as np
from ttsim.shared import join


@pytest.mark.parametrize(
"foreign_key, primary_key, target, value_if_foreign_key_is_missing, expected",
[
(
numpy.array([1, 2, 3]),
numpy.array([1, 2, 3]),
numpy.array(["a", "b", "c"]),
"d",
numpy.array(["a", "b", "c"]),
np.array([1, 2, 3]),
np.array([1, 2, 3]),
np.array([1, 2, 3]),
4,
np.array([1, 2, 3]),
),
(
numpy.array([3, 2, 1]),
numpy.array([1, 2, 3]),
numpy.array(["a", "b", "c"]),
"d",
numpy.array(["c", "b", "a"]),
np.array([3, 2, 1]),
np.array([1, 2, 3]),
np.array([1, 2, 3]),
4,
np.array([3, 2, 1]),
),
(
numpy.array([1, 1, 1]),
numpy.array([1, 2, 3]),
numpy.array(["a", "b", "c"]),
"d",
numpy.array(["a", "a", "a"]),
np.array([1, 1, 1]),
np.array([1, 2, 3]),
np.array([1, 2, 3]),
4,
np.array([1, 1, 1]),
),
(
numpy.array([-1]),
numpy.array([1]),
numpy.array(["a"]),
"d",
numpy.array(["d"]),
np.array([-1]),
np.array([1]),
np.array([1]),
4,
np.array([4]),
),
],
)
def test_join_numpy(
foreign_key: numpy.ndarray[int],
primary_key: numpy.ndarray[int],
target: numpy.ndarray[str],
value_if_foreign_key_is_missing: str,
expected: numpy.ndarray[str],
def test_join(
foreign_key: np.ndarray,
primary_key: np.ndarray,
target: np.ndarray,
value_if_foreign_key_is_missing: int,
expected: np.ndarray,
):
assert numpy.array_equal(
join_numpy(foreign_key, primary_key, target, value_if_foreign_key_is_missing),
assert np.array_equal(
join(foreign_key, primary_key, target, value_if_foreign_key_is_missing),
expected,
)


def test_join_numpy_raises_duplicate_primary_key():
with pytest.raises(ValueError, match="Duplicate primary keys:"):
join_numpy(
numpy.array([1, 1, 1]),
numpy.array([1, 1, 1]),
numpy.array(["a", "b", "c"]),
"default",
)


def test_join_numpy_raises_invalid_foreign_key():
with pytest.raises(ValueError, match="Invalid foreign keys:"):
join_numpy(numpy.array([2]), numpy.array([1]), numpy.array(["a"]), "d")
Loading