Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
e37116f
Make id functions jittable
mj023 May 12, 2025
18bd688
Fix fg_id error
mj023 May 12, 2025
bf47bb7
Make METTSIM group creation jittable
mj023 May 13, 2025
f0608e9
Fix Bugs + Adjust Test to new IDs
mj023 May 13, 2025
a23ba8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 13, 2025
ebee806
Make order of p_ids irrelevant
mj023 May 14, 2025
8c9cdf5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2025
4fbb227
Fix wrong test cases
mj023 May 14, 2025
70f3ff8
Remove exception test
mj023 May 14, 2025
d0d103a
Make p_id irrelevant in Mettsim
mj023 May 14, 2025
ca4b11e
Stop using self created p_ids
mj023 May 14, 2025
894dfe5
Remove limit on p_id value
mj023 May 14, 2025
b37cb98
Fix remaining tests
mj023 May 14, 2025
813b050
Merge branch 'collect-unify-parsing-of-params' into groupings_jax
hmgaudecker May 16, 2025
4f54a5d
Make Ids consecutive numbers
mj023 May 19, 2025
a31b845
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2025
4b572f0
Update Tests
mj023 May 19, 2025
c30c885
Merge branch 'groupings_jax' of https://github.com/iza-institute-of-l…
mj023 May 19, 2025
50cedb4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2025
b90c150
Remove ID precomputation in tests
mj023 May 20, 2025
a4262ac
Actually Jit the tests
mj023 May 21, 2025
c562cc2
Merge branch 'move-gettsim-params-files' into groupings_jax
hmgaudecker May 22, 2025
02b7262
Revert fake change.
hmgaudecker May 22, 2025
df6421d
Revert Tests, Add comments
mj023 May 22, 2025
4a3637b
Revert missed test
mj023 May 22, 2025
2614701
Make functions private
mj023 May 22, 2025
81a6979
Re-order.
hmgaudecker May 22, 2025
c7565be
Skip tests that fail because of pre-defined group ids, created #924 f…
hmgaudecker May 22, 2025
b93f706
Small renamings; | easier to read than + for Booleans.
hmgaudecker May 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 90 additions & 190 deletions src/_gettsim/ids.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Input columns."""

from collections import Counter

from ttsim import group_creation_function, policy_input
from ttsim.config import numpy_or_jax as np

Expand All @@ -22,27 +20,20 @@ def ehe_id(
familie__p_id_ehepartner: np.ndarray,
) -> np.ndarray:
"""Couples that are either married or in a civil union."""
p_id_to_ehe_id: dict[int, int] = {}
next_ehe_id = 0
result: list[int] = []

for index, current_p_id in enumerate(map(int, p_id)):
current_p_id_ehepartner = int(familie__p_id_ehepartner[index])

if current_p_id_ehepartner >= 0 and current_p_id_ehepartner in p_id_to_ehe_id:
result.append(p_id_to_ehe_id[current_p_id_ehepartner])
continue
n = np.max(p_id) + 1
familie__p_id_ehepartner = np.where(
familie__p_id_ehepartner < 0, p_id, familie__p_id_ehepartner
)
result = (
np.maximum(p_id, familie__p_id_ehepartner)
+ np.minimum(p_id, familie__p_id_ehepartner) * n
)

# New married couple
result.append(next_ehe_id)
p_id_to_ehe_id[current_p_id] = next_ehe_id
next_ehe_id += 1

return np.array(result)
return result


@group_creation_function()
def fg_id( # noqa: PLR0912
def fg_id(
arbeitslosengeld_2__p_id_einstandspartner: np.ndarray,
p_id: np.ndarray,
hh_id: np.ndarray,
Expand All @@ -55,103 +46,53 @@ def fg_id( # noqa: PLR0912
Maximum of two generations, the relevant base unit for Bürgergeld / Arbeitslosengeld
2, before excluding children who have enough income to fend for themselves.
"""
# Build indexes
p_id_to_index: dict[int, int] = {}
p_id_to_p_ids_children: dict[int, list[int]] = {}

for index, current_p_id in enumerate(map(int, p_id)):
# Fast access from p_id to index
p_id_to_index[current_p_id] = index

# Fast access from p_id to p_ids of children
current_familie__p_id_elternteil_1 = int(familie__p_id_elternteil_1[index])
current_familie__p_id_elternteil_2 = int(familie__p_id_elternteil_2[index])

if current_familie__p_id_elternteil_1 >= 0:
if current_familie__p_id_elternteil_1 not in p_id_to_p_ids_children:
p_id_to_p_ids_children[current_familie__p_id_elternteil_1] = []
p_id_to_p_ids_children[current_familie__p_id_elternteil_1].append(
current_p_id
)

if current_familie__p_id_elternteil_2 >= 0:
if current_familie__p_id_elternteil_2 not in p_id_to_p_ids_children:
p_id_to_p_ids_children[current_familie__p_id_elternteil_2] = []
p_id_to_p_ids_children[current_familie__p_id_elternteil_2].append(
current_p_id
)
n = np.max(p_id) + 1

p_id_to_fg_id = {}
next_fg_id = 0

for index, current_p_id in enumerate(map(int, p_id)):
# Already assigned a fg_id to this p_id via einstandspartner /
# parent
if current_p_id in p_id_to_fg_id:
continue

p_id_to_fg_id[current_p_id] = next_fg_id

current_hh_id = int(hh_id[index])
current_p_id_einstandspartner = int(
arbeitslosengeld_2__p_id_einstandspartner[index]
familie__p_id_elternteil_1_loc = familie__p_id_elternteil_1
familie__p_id_elternteil_2_loc = familie__p_id_elternteil_2
for i in range(p_id.shape[0]):
familie__p_id_elternteil_1_loc = np.where(
familie__p_id_elternteil_1 == p_id[i], i, familie__p_id_elternteil_1_loc
)
familie__p_id_elternteil_2_loc = np.where(
familie__p_id_elternteil_2 == p_id[i], i, familie__p_id_elternteil_2_loc
)
current_p_id_children = p_id_to_p_ids_children.get(current_p_id, [])

# Assign fg to children
for current_p_id_child in current_p_id_children:
child_index = p_id_to_index[current_p_id_child]
child_hh_id = int(hh_id[child_index])
child_alter = int(alter[child_index])
child_p_id_children = p_id_to_p_ids_children.get(current_p_id_child, [])

if (
child_hh_id == current_hh_id
# TODO (@MImmesberger): Check correct conditions for grown up children
# https://github.com/iza-institute-of-labor-economics/gettsim/pull/509
# TODO(@MImmesberger): Remove hard-coded number
# https://github.com/iza-institute-of-labor-economics/gettsim/issues/668
and child_alter < 25
and len(child_p_id_children) == 0
):
p_id_to_fg_id[current_p_id_child] = next_fg_id

# Assign fg to einstandspartner
if current_p_id_einstandspartner >= 0:
p_id_to_fg_id[current_p_id_einstandspartner] = next_fg_id
current_p_id_einstandspartner_children = p_id_to_p_ids_children.get(
current_p_id_einstandspartner, []
)
# Assign fg to children of einstandspartner
for current_p_id_child in current_p_id_einstandspartner_children:
if current_p_id_child in p_id_to_fg_id:
continue
child_index = p_id_to_index[current_p_id_child]
child_hh_id = int(hh_id[child_index])
child_alter = int(alter[child_index])
child_p_id_children = p_id_to_p_ids_children.get(current_p_id_child, [])

if (
child_hh_id == current_hh_id
# TODO (@MImmesberger): Check correct conditions for grown up children
# https://github.com/iza-institute-of-labor-economics/gettsim/pull/509
# TODO(@MImmesberger): Remove hard-coded number
# https://github.com/iza-institute-of-labor-economics/gettsim/issues/668
and child_alter < 25
and len(child_p_id_children) == 0
):
p_id_to_fg_id[current_p_id_child] = next_fg_id

next_fg_id += 1

# Compute result vector
result = [p_id_to_fg_id[current_p_id] for current_p_id in map(int, p_id)]
return np.array(result)
children = np.isin(p_id, familie__p_id_elternteil_1) + np.isin(
p_id, familie__p_id_elternteil_2
)
fg_id = np.where(
arbeitslosengeld_2__p_id_einstandspartner < 0,
p_id + p_id * n,
np.maximum(p_id, arbeitslosengeld_2__p_id_einstandspartner)
+ np.minimum(p_id, arbeitslosengeld_2__p_id_einstandspartner) * n,
)
fg_id = np.where(
(familie__p_id_elternteil_1_loc >= 0)
* (fg_id == p_id + p_id * n)
* (hh_id == hh_id[familie__p_id_elternteil_1_loc])
* (alter < 25)
* (1 - children),
fg_id[familie__p_id_elternteil_1_loc],
fg_id,
)
fg_id = np.where(
(familie__p_id_elternteil_2_loc >= 0)
* (fg_id == p_id + p_id * n)
* (hh_id == hh_id[familie__p_id_elternteil_2_loc])
* (alter < 25)
* (1 - children),
fg_id[familie__p_id_elternteil_2_loc],
fg_id,
)

return fg_id


@group_creation_function()
def bg_id(
fg_id: np.ndarray,
p_id: np.ndarray,
arbeitslosengeld_2__eigenbedarf_gedeckt: np.ndarray,
alter: np.ndarray,
) -> np.ndarray:
Expand All @@ -163,23 +104,16 @@ def bg_id(
# TODO(@MImmesberger): Remove input variable eigenbedarf_gedeckt
# once Bedarfsgemeinschaften are fully endogenous
# https://github.com/iza-institute-of-labor-economics/gettsim/issues/763
counter: Counter[int] = Counter()
result: list[int] = []

for index, current_fg_id in enumerate(map(int, fg_id)):
current_alter = int(alter[index])
current_eigenbedarf_gedeckt = bool(
arbeitslosengeld_2__eigenbedarf_gedeckt[index]
)
# TODO(@MImmesberger): Remove hard-coded number
# https://github.com/iza-institute-of-labor-economics/gettsim/issues/668
if current_alter < 25 and current_eigenbedarf_gedeckt:
counter[current_fg_id] += 1
result.append(current_fg_id * 100 + counter[current_fg_id])
else:
result.append(current_fg_id * 100)

return np.array(result)
# TODO(@MImmesberger): Remove hard-coded number
# https://github.com/iza-institute-of-labor-economics/gettsim/issues/668
n = np.max(p_id) + 1
hh_id = np.where(
np.logical_and(arbeitslosengeld_2__eigenbedarf_gedeckt, alter < 25),
p_id + p_id * n,
fg_id,
)
return hh_id


@group_creation_function()
Expand All @@ -191,26 +125,16 @@ def eg_id(

A couple whose members are deemed to be responsible for each other.
"""
p_id_to_eg_id: dict[int, int] = {}
next_eg_id = 0
result: list[int] = []

for index, current_p_id in enumerate(map(int, p_id)):
current_p_id_einstandspartner = int(
arbeitslosengeld_2__p_id_einstandspartner[index]
)

if (
current_p_id_einstandspartner >= 0
and current_p_id_einstandspartner in p_id_to_eg_id
):
result.append(p_id_to_eg_id[current_p_id_einstandspartner])
continue

# New Einstandsgemeinschaft
result.append(next_eg_id)
p_id_to_eg_id[current_p_id] = next_eg_id
next_eg_id += 1
n = np.max(p_id) + 1
arbeitslosengeld_2__p_id_einstandspartner = np.where(
arbeitslosengeld_2__p_id_einstandspartner < 0,
p_id,
arbeitslosengeld_2__p_id_einstandspartner,
)
result = (
np.maximum(p_id, arbeitslosengeld_2__p_id_einstandspartner)
+ np.minimum(p_id, arbeitslosengeld_2__p_id_einstandspartner) * n
)

return np.array(result)

Expand All @@ -226,20 +150,14 @@ def wthh_id(
The relevant unit for Wohngeld. Members of a household for whom the Wohngeld
priority check compared to Bürgergeld yields the same result ∈ {True, False}.
"""
result: list[int] = []
for index, current_hh_id in enumerate(map(int, hh_id)):
if bool(
vorrangprüfungen__wohngeld_vorrang_vor_arbeitslosengeld_2_bg[index]
) or bool(
vorrangprüfungen__wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg[
index
]
):
result.append(current_hh_id * 100 + 1)
else:
result.append(current_hh_id * 100)

return np.array(result)
offset = np.max(hh_id) + 1
hh_id = np.where(
vorrangprüfungen__wohngeld_vorrang_vor_arbeitslosengeld_2_bg
| vorrangprüfungen__wohngeld_und_kinderzuschlag_vorrang_vor_arbeitslosengeld_2_bg,
hh_id + offset,
hh_id,
)
return hh_id


@group_creation_function()
Expand All @@ -252,36 +170,18 @@ def sn_id(

Spouses filing taxes jointly or individuals.
"""
p_id_to_sn_id: dict[int, int] = {}
p_id_to_gemeinsam_veranlagt: dict[int, bool] = {}
next_sn_id = 0
result: list[int] = []

for index, current_p_id in enumerate(map(int, p_id)):
current_p_id_ehepartner = int(familie__p_id_ehepartner[index])
current_gemeinsam_veranlagt = bool(einkommensteuer__gemeinsam_veranlagt[index])

if current_p_id_ehepartner >= 0 and current_p_id_ehepartner in p_id_to_sn_id:
gemeinsam_veranlagt_ehepartner = p_id_to_gemeinsam_veranlagt[
current_p_id_ehepartner
]

if current_gemeinsam_veranlagt != gemeinsam_veranlagt_ehepartner:
message = (
f"{current_p_id_ehepartner} and {current_p_id} are "
"married, but have different values for "
"gemeinsam_veranlagt."
)
raise ValueError(message)

if current_gemeinsam_veranlagt:
result.append(p_id_to_sn_id[current_p_id_ehepartner])
continue

# New Steuersubjekt
result.append(next_sn_id)
p_id_to_sn_id[current_p_id] = next_sn_id
p_id_to_gemeinsam_veranlagt[current_p_id] = current_gemeinsam_veranlagt
next_sn_id += 1

return np.array(result)
n = np.max(p_id) + 1
familie__p_id_ehepartner = np.where(
np.logical_and(
familie__p_id_ehepartner >= 0, einkommensteuer__gemeinsam_veranlagt
),
familie__p_id_ehepartner,
p_id,
)
result = (
np.maximum(p_id, familie__p_id_ehepartner)
+ np.minimum(p_id, familie__p_id_ehepartner) * n
)

return result
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,23 @@ inputs:
- false
outputs:
fg_id:
- 1
- 1
- 1
- 1
- 1
- 1
- 11
- 11
- 11
- 11
- 11
- 11
bg_id:
- 100
- 100
- 100
- 100
- 100
- 100
- 11
- 11
- 11
- 11
- 11
- 11
eg_id:
- 0
- 1
- 2
- 3
- 4
- 1
- 11
- 14
- 21
- 28
- 11
Loading
Loading