Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 17 additions & 17 deletions src/_gettsim/combine_functions_in_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
format_errors_and_warnings,
format_list_linewise,
get_names_of_arguments_without_defaults,
insert_path_and_value,
partition_tree_by_reference_tree,
remove_group_suffix,
rename_arguments_and_add_annotations,
upsert_path_and_value,
upsert_tree,
)
from _gettsim.time_conversion import create_time_conversion_functions
Expand Down Expand Up @@ -87,8 +87,8 @@ def combine_policy_functions_and_derived_functions(
aggregation_type="p_id",
)
current_functions_tree = upsert_tree(
base=environment.functions_tree,
to_upsert=aggregate_by_p_id_functions,
base=aggregate_by_p_id_functions,
to_upsert=environment.functions_tree,
)

# Create functions for different time units
Expand All @@ -97,8 +97,8 @@ def combine_policy_functions_and_derived_functions(
data_tree=data_tree,
)
current_functions_tree = upsert_tree(
base=current_functions_tree,
to_upsert=time_conversion_functions,
base=time_conversion_functions,
to_upsert=current_functions_tree,
)

# Create aggregation functions
Expand All @@ -109,15 +109,15 @@ def combine_policy_functions_and_derived_functions(
aggregations_tree_provided_by_env=environment.aggregation_specs_tree,
)
current_functions_tree = upsert_tree(
base=current_functions_tree,
to_upsert=aggregate_by_group_functions,
base=aggregate_by_group_functions,
to_upsert=current_functions_tree,
)

# Create groupings
groupings = create_groupings()
current_functions_tree = upsert_tree(
base=current_functions_tree,
to_upsert=groupings,
base=groupings,
to_upsert=current_functions_tree,
)

_fail_if_targets_not_in_functions_tree(current_functions_tree, targets_tree)
Expand Down Expand Up @@ -200,10 +200,10 @@ def _create_aggregation_functions(
annotations=annotations,
)

out_tree = upsert_path_and_value(
out_tree = insert_path_and_value(
base=out_tree,
path_to_upsert=tree_path,
value_to_upsert=derived_func,
path_to_insert=tree_path,
value_to_insert=derived_func,
)

return out_tree
Expand Down Expand Up @@ -269,10 +269,10 @@ def _create_derived_aggregations_tree(
) and tree_path not in optree.tree_paths(aggregation_source_tree)

if aggregation_specs_needed:
derived_aggregations_tree = upsert_path_and_value(
derived_aggregations_tree = insert_path_and_value(
base=derived_aggregations_tree,
path_to_upsert=tree_path,
value_to_upsert=AggregateByGroupSpec(
path_to_insert=tree_path,
value_to_insert=AggregateByGroupSpec(
aggr="sum",
source_col=remove_group_suffix(leaf_name),
),
Expand Down Expand Up @@ -310,9 +310,9 @@ def _get_potential_aggregation_function_names_from_function_arguments(
name=name,
namespace=tree_path[:-1],
)
current_tree = upsert_path_and_value(
current_tree = insert_path_and_value(
base=current_tree,
path_to_upsert=path_of_function_argument,
path_to_insert=path_of_function_argument,
)
return current_tree

Expand Down
Loading
Loading