Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions api/py/ai/chronon/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,33 +57,14 @@ def JoinPart(
JoinPart specifies how the left side of a join, or the query in online setting, would join with the right side
components like GroupBys.
"""
# used for reset for next run
import_copy = __builtins__["__import__"]
# get group_by's module info from garbage collector
gc.collect()
group_by_module_name = None
for ref in gc.get_referrers(group_by):
if "__name__" in ref and ref["__name__"].startswith("group_bys"):
group_by_module_name = ref["__name__"]
break
if group_by_module_name:
logging.debug("group_by's module info from garbage collector {}".format(group_by_module_name))
group_by_module = importlib.import_module(group_by_module_name)
__builtins__["__import__"] = eo.import_module_set_name(group_by_module, api.GroupBy)
else:
if not group_by.metaData.name:
logging.error("No group_by file or custom group_by name found")
raise ValueError(
"[GroupBy] Must specify a group_by name if group_by is not defined in separate file. "
"You may pass it in via GroupBy.name. \n"
)
# Automatically set the GroupBy name if not already set
_auto_set_group_by_name(group_by, context="JoinPart")

if key_mapping:
utils.check_contains(key_mapping.values(), group_by.keyColumns, "key", group_by.metaData.name)

join_part = api.JoinPart(groupBy=group_by, keyMapping=key_mapping, prefix=prefix)
join_part.tags = tags
# reset before next run
__builtins__["__import__"] = import_copy
return join_part


Expand Down Expand Up @@ -140,6 +121,7 @@ def ExternalSource(
custom_json: Optional[str] = None,
factory_name: Optional[str] = None,
factory_params: Optional[Dict[str, str]] = None,
offline_group_by: Optional[api.GroupBy] = None,
) -> api.ExternalSource:
"""
External sources are online only data sources. During fetching, using
Expand Down Expand Up @@ -180,10 +162,17 @@ def ExternalSource(
creating the external source handler.
:param factory_params: Optional parameters to pass to the factory when
creating the handler.
:param offline_group_by: Optional GroupBy configuration to be used for
offline backfill computation. When provided, enables point-in-time
correct (PITC) offline computation for the external source.

"""
assert name != "contextual", "Please use `ContextualSource`"

# Automatically set the name for offline_group_by if not already set
if offline_group_by is not None:
_auto_set_group_by_name(offline_group_by, context="ExternalSource")

factory_config = None
if factory_name is not None or factory_params is not None:
factory_config = api.ExternalSourceFactoryConfig(factoryName=factory_name, factoryParams=factory_params)
Expand All @@ -193,6 +182,7 @@ def ExternalSource(
keySchema=DataType.STRUCT(f"ext_{name}_keys", *key_fields),
valueSchema=DataType.STRUCT(f"ext_{name}_values", *value_fields),
factoryConfig=factory_config,
offlineGroupBy=offline_group_by,
)


Expand Down Expand Up @@ -654,3 +644,42 @@ def Join(
derivations=derivations,
modelTransforms=model_transforms,
)

def _auto_set_group_by_name(group_by: api.GroupBy, context: str = "GroupBy") -> None:
    """
    Automatically set the GroupBy name by finding its source module using garbage collection.
    This is used by both JoinPart and ExternalSource to automatically name GroupBys.

    If the GroupBy already has a name, this is a no-op. Otherwise the function
    scans the GroupBy's referrers for a module dict under the ``group_bys``
    package, re-imports that module, and installs a patched ``__import__`` from
    ``eo.import_module_set_name`` that assigns names to GroupBy objects defined
    in it. The original ``__import__`` is always restored before returning.

    :param group_by: The GroupBy object to set the name for
    :param context: Context string for error messages (e.g., "JoinPart", "ExternalSource")
    :raises ValueError: if no defining module can be located and no name was set manually.
    """
    if group_by.metaData.name:
        # Name already set, nothing to do
        return

    # Save and restore __import__ to preserve original behavior.
    # NOTE(review): assumes __builtins__ is a dict here (true when this module is
    # imported, not run as __main__) — confirm this file is never executed directly.
    import_copy = __builtins__["__import__"]

    try:
        # Use garbage collector to find the module where this GroupBy was defined.
        # Referrers that are module __dict__s expose "__name__"; we only accept
        # modules under the "group_bys" package.
        gc.collect()
        group_by_module_name = None
        for ref in gc.get_referrers(group_by):
            # NOTE(review): presumes non-dict referrers either support `in` or are
            # never matched; a referrer list containing "__name__" would misfire.
            if "__name__" in ref and ref["__name__"].startswith("group_bys"):
                group_by_module_name = ref["__name__"]
                break

        if group_by_module_name:
            logging.debug(f"{context}: group_by's module info from garbage collector {group_by_module_name}")
            group_by_module = importlib.import_module(group_by_module_name)
            # The patched __import__ is what actually assigns metaData.name to
            # GroupBy objects defined in group_by_module.
            __builtins__["__import__"] = eo.import_module_set_name(group_by_module, api.GroupBy)
        else:
            # Re-check is defensive only: nothing above sets the name on this path.
            if not group_by.metaData.name:
                logging.error(f"{context}: No group_by file or custom group_by name found")
                raise ValueError(
                    f"[{context}] Must specify a group_by name if group_by is not defined in separate file. "
                    "You can set it via GroupBy(metaData=MetaData(name='team.file.variable_name'))"
                )
    finally:
        # Reset before next run so other imports are unaffected.
        __builtins__["__import__"] = import_copy
26 changes: 26 additions & 0 deletions api/py/ai/chronon/repo/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def extract_and_convert(chronon_root, input_path, output_root, debug, force_over
# In case of join, we need to materialize the following underlying group_bys
# 1. group_bys whose online param is set
# 2. group_bys whose backfill_start_date param is set.
# 3. offline group_bys from external parts (always materialized)
if obj_class is Join:
online_group_bys = {}
offline_backfill_enabled_group_bys = {}
Expand All @@ -148,6 +149,10 @@ def extract_and_convert(chronon_root, input_path, output_root, debug, force_over
else:
offline_gbs.append(jp.groupBy.metaData.name)

# Extract and always materialize online GroupBys from external parts
external_offline_gbs = _extract_external_part_offline_group_bys(obj, team_name, teams_path)
online_group_bys.update(external_offline_gbs)

_print_debug_info(list(online_group_bys.keys()), "Online Groupbys", log_level)
_print_debug_info(
list(offline_backfill_enabled_group_bys.keys()), "Offline Groupbys With Backfill Enabled", log_level
Expand Down Expand Up @@ -445,6 +450,27 @@ def _set_templated_values(obj, cls, teams_path, team_name):
obj.metaData.dependencies = [__fill_template(dep, obj, namespace) for dep in obj.metaData.dependencies]


def _extract_external_part_offline_group_bys(join_obj: api.Join, team_name: str, teams_path: str):
"""
Extract offline GroupBys from external parts in a Join.
Sets proper metadata (name, team, namespace) for each offline GroupBy.
Returns a dictionary of {groupby_name: groupby_object}.
"""
external_offline_gbs = {}

if not join_obj.onlineExternalParts:
return external_offline_gbs

for external_part in join_obj.onlineExternalParts:
if not external_part.source or not external_part.source.offlineGroupBy:
continue

offline_gb = external_part.source.offlineGroupBy
external_offline_gbs[offline_gb.metaData.name] = offline_gb

return external_offline_gbs


def _write_obj(
full_output_root: str,
validator: ChrononRepoValidator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@
name="test_external_source",
team="chronon",
key_fields=[
("key", DataType.LONG)
("group_by_subject", DataType.STRING)
],
value_fields=[
("value_str", DataType.STRING),
("value_long", DataType.LONG),
("value_bool", DataType.BOOLEAN)
]
)
],
offline_group_by=event_sample_group_by.v1,
),
),
ExternalPart(
ContextualSource(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@
"kind": 13,
"params": [
{
"name": "key",
"name": "group_by_subject",
"dataType": {
"kind": 4
"kind": 7
}
}
],
Expand Down Expand Up @@ -203,6 +203,67 @@
}
],
"name": "ext_test_external_source_values"
},
"offlineGroupBy": {
"metaData": {
"name": "sample_team.event_sample_group_by.v1",
"online": 1,
"customJson": "{\"lag\": 0, \"groupby_tags\": {\"TO_DEPRECATE\": true}, \"column_tags\": {\"event_sum_7d\": {\"DETAILED_TYPE\": \"CONTINUOUS\"}}}",
"dependencies": [
"{\"name\": \"wait_for_sample_namespace.sample_table_group_by_ds\", \"spec\": \"sample_namespace.sample_table_group_by/ds={{ ds }}\", \"start\": \"2021-04-09\", \"end\": null}"
],
"tableProperties": {
"source": "chronon"
},
"outputNamespace": "sample_namespace",
"team": "sample_team",
"offlineSchedule": "@daily"
},
"sources": [
{
"events": {
"table": "sample_namespace.sample_table_group_by",
"query": {
"selects": {
"event": "event_expr",
"group_by_subject": "group_by_expr"
},
"startPartition": "2021-04-09",
"timeColumn": "ts",
"setups": []
}
}
}
],
"keyColumns": [
"group_by_subject"
],
"aggregations": [
{
"inputColumn": "event",
"operation": 7,
"argMap": {},
"windows": [
{
"length": 7,
"timeUnit": 1
}
]
},
{
"inputColumn": "event",
"operation": 7,
"argMap": {}
},
{
"inputColumn": "event",
"operation": 12,
"argMap": {
"k": "200",
"percentiles": "[0.99, 0.95, 0.5]"
}
}
]
}
}
},
Expand Down
45 changes: 45 additions & 0 deletions api/py/test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,48 @@ def test_compile_inline_group_by():
join = json2thrift(file.read(), Join)
assert len(join.joinParts) == 1
assert join.joinParts[0].groupBy.metaData.team == "unit_test"


def test_compile_external_source_with_offline_group_by():
    """
    Compiling a join whose external source carries an offlineGroupBy must
    materialize that offlineGroupBy inside the compiled external source.
    """
    result = _invoke_cli_with_params(
        CliRunner(), "joins/sample_team/sample_join_with_derivations_on_external_parts.py"
    )
    assert result.exit_code == 0

    # The compiled join artifact must exist on disk.
    materialized = "sample/production/joins/sample_team/sample_join_with_derivations_on_external_parts.v1"
    output_file = _get_full_file_path(materialized)
    _assert_file_exists(
        output_file, f"Expected {os.path.basename(materialized)} to be materialized, but it was not."
    )

    with open(output_file, "r") as fh:
        compiled_join = json2thrift(fh.read(), Join)

    # The join must carry external parts.
    assert compiled_join.onlineExternalParts is not None, "Expected onlineExternalParts to be present"
    assert len(compiled_join.onlineExternalParts) > 0, "Expected at least one external part"

    # Locate the external source that declares an offlineGroupBy.
    source = next(
        (
            part.source
            for part in compiled_join.onlineExternalParts
            if part.source.metadata.name == "test_external_source"
        ),
        None,
    )
    assert source is not None, "Expected to find test_external_source"
    assert source.offlineGroupBy is not None, (
        "Expected offlineGroupBy to be present in test_external_source"
    )

    # Spot-check the materialized GroupBy's contents.
    gb = source.offlineGroupBy
    assert gb.keyColumns == ["group_by_subject"], f"Expected key columns to be ['group_by_subject'], got {gb.keyColumns}"
    assert gb.aggregations is not None, "Expected aggregations to be present"
    assert len(gb.aggregations) == 3, f"Expected 3 aggregations, got {len(gb.aggregations)}"
    assert gb.metaData.outputNamespace == "sample_namespace", (
        f"Expected output namespace to be 'sample_namespace', got {gb.metaData.outputNamespace}"
    )
Loading