Skip to content

Commit

Permalink
Update DFP join logic to account for group by metric source nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Apr 25, 2024
1 parent c9a4c19 commit c7681a3
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 53 deletions.
5 changes: 0 additions & 5 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,6 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
# then produce the linkable spec. See comments further below for more details.

for entity_spec_in_right_node in entity_specs_in_right_node:
# If an entity has links, what that means and whether it can be used is unclear at the moment,
# so skip it.
if len(entity_spec_in_right_node.entity_links) > 0:
continue

entity_instance_in_right_node = None
for instance in data_set_in_right_node.instance_set.entity_instances:
if instance.spec == entity_spec_in_right_node:
Expand Down
44 changes: 25 additions & 19 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from __future__ import annotations

from typing import Optional, Sequence
from typing import List, Optional, Sequence

from dbt_semantic_interfaces.references import SemanticModelReference
from typing_extensions import override

from metricflow.assert_one_arg import assert_exactly_one_arg_set
from metricflow.dataset.dataset import DataSet
from metricflow.instances import (
InstanceSet,
)
from metricflow.instances import EntityInstance, InstanceSet
from metricflow.specs.column_assoc import ColumnAssociation
from metricflow.specs.specs import DimensionSpec, EntitySpec, TimeDimensionSpec
from metricflow.sql.sql_plan import (
Expand Down Expand Up @@ -60,25 +58,33 @@ def column_associations_for_entity(
entity_spec: EntitySpec,
) -> Sequence[ColumnAssociation]:
"""Given the name of the entity, return the set of columns associated with it in the data set."""
matching_instances = 0
column_associations_to_return = None
matching_instances_with_same_entity_links: List[EntityInstance] = []
matching_instances_with_different_entity_links: List[EntityInstance] = []
for linkable_instance in self.instance_set.entity_instances:
if (
entity_spec.element_name == linkable_instance.spec.element_name
and entity_spec.entity_links == linkable_instance.spec.entity_links
):
column_associations_to_return = linkable_instance.associated_columns
matching_instances += 1

if matching_instances > 1:
if entity_spec.element_name == linkable_instance.spec.element_name:
if entity_spec.entity_links == linkable_instance.spec.entity_links:
matching_instances_with_same_entity_links.append(linkable_instance)
else:
matching_instances_with_different_entity_links.append(linkable_instance)

# Prioritize instances with matching entity links, but use mismatched links if matching links not found.
# Semantic model source data sets might have multiple instances of the same entity, in which case we want the one without
# links. But group by metric source data sets might only have an instance of the entity with links, and we can join to that.
matching_instances = matching_instances_with_same_entity_links or matching_instances_with_different_entity_links

if len(matching_instances) != 1:
raise RuntimeError(
f"More than one instance with spec {entity_spec} in " f"instance set: {self.instance_set}"
f"Expected exactly one matching instance for {entity_spec} in instance set, but found: {matching_instances}"
)
matching_instance = matching_instances[0]
if not matching_instance.associated_columns:
print("entity links to compare:", entity_spec.entity_links, linkable_instance.spec.entity_links)
raise RuntimeError(
f"No associated columns for entity instance {matching_instance} in data set."
"This indicates internal misconfiguration."
)

if not column_associations_to_return:
raise RuntimeError(f"No instances with spec {entity_spec} in instance set: {self.instance_set}")

return column_associations_to_return
return matching_instance.associated_columns

def column_association_for_dimension(
self,
Expand Down
37 changes: 28 additions & 9 deletions metricflow/model/semantics/semantic_model_join_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,15 +243,34 @@ def is_valid_instance_set_join(
left_instance_set: InstanceSet,
right_instance_set: InstanceSet,
on_entity_reference: EntityReference,
right_node_is_subquery: bool = False,
) -> bool:
"""Return true if the instance sets can be joined using the given entity."""
return self.is_valid_semantic_model_join(
left_semantic_model_reference=SemanticModelJoinEvaluator._semantic_model_of_entity_in_instance_set(
instance_set=left_instance_set, entity_reference=on_entity_reference
),
right_semantic_model_reference=SemanticModelJoinEvaluator._semantic_model_of_entity_in_instance_set(
instance_set=right_instance_set,
entity_reference=on_entity_reference,
),
on_entity_reference=on_entity_reference,
left_semantic_model_reference = SemanticModelJoinEvaluator._semantic_model_of_entity_in_instance_set(
instance_set=left_instance_set, entity_reference=on_entity_reference
)
if right_node_is_subquery:
left_entity = self._semantic_model_lookup.get_entity_in_semantic_model(
SemanticModelElementReference.create_from_references(left_semantic_model_reference, on_entity_reference)
)
if not left_entity:
return False
possible_right_entities = [
entity_instance
for entity_instance in right_instance_set.entity_instances
if entity_instance.spec.reference == on_entity_reference
]
if len(possible_right_entities) != 1:
return False

# No fan-out check needed since right subquery is aggregated to the entity level, ensuring uniqueness.
return True
else:
return self.is_valid_semantic_model_join(
left_semantic_model_reference=left_semantic_model_reference,
right_semantic_model_reference=SemanticModelJoinEvaluator._semantic_model_of_entity_in_instance_set(
instance_set=right_instance_set,
entity_reference=on_entity_reference,
),
on_entity_reference=on_entity_reference,
)
36 changes: 19 additions & 17 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
CreateSqlColumnReferencesForInstances,
FilterElements,
FilterLinkableInstancesWithLeadingLink,
InstanceSetTransform,
RemoveMeasures,
RemoveMetrics,
UpdateMeasureFillNullsWith,
Expand Down Expand Up @@ -608,8 +609,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:

# Add select columns that would compute the metrics to the select columns.
metric_select_columns = []
metric_instances = []
group_by_metric_instances = []
metric_instances: List[MetricInstance] = []
for metric_spec in node.metric_specs:
metric = self._metric_lookup.get_metric(metric_spec.reference)

Expand Down Expand Up @@ -723,23 +723,25 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
spec=metric_spec,
)
)
group_by_metric_instances.append(
GroupByMetricInstance(
associated_columns=(output_column_association,),
defined_from=MetricModelReference(metric_name=metric_spec.element_name),
spec=GroupByMetricSpec(
element_name=metric_spec.element_name,
entity_links=(),
metric_subquery_entity_links=(), # TODO
),
)

transform_func: InstanceSetTransform = AddMetrics(metric_instances)
if node.for_group_by_source_node:
assert (
len(metric_instances) == 1 and len(output_instance_set.entity_instances) == 1
), "Group by metrics currently only support exactly one metric grouped by exactly one entity."
metric_instance = metric_instances[0]
entity_instance = output_instance_set.entity_instances[0]
group_by_metric_instance = GroupByMetricInstance(
associated_columns=metric_instance.associated_columns,
defined_from=metric_instance.defined_from,
spec=GroupByMetricSpec(
element_name=metric_spec.element_name,
entity_links=(), # check this
metric_subquery_entity_links=entity_instance.spec.entity_links,
),
)
transform_func = AddGroupByMetrics([group_by_metric_instance])

transform_func = (
AddGroupByMetrics(group_by_metric_instances)
if node.for_group_by_source_node
else AddMetrics(metric_instances)
)
output_instance_set = output_instance_set.transform(transform_func)

combined_select_column_set = non_metric_select_column_set.merge(
Expand Down
6 changes: 3 additions & 3 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from metricflow.dataflow.dataflow_plan import (
BaseOutput,
)
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinToBaseOutputNode
Expand Down Expand Up @@ -130,9 +131,6 @@ def _node_contains_entity(
if entity_spec_in_first_node.reference != entity_reference:
continue

if len(entity_spec_in_first_node.entity_links) > 0:
continue

assert (
len(entity_instance_in_first_node.defined_from) == 1
), "Multiple items in defined_from not yet supported"
Expand Down Expand Up @@ -215,6 +213,8 @@ def _get_candidates_nodes_for_multi_hop(
left_instance_set=data_set_of_first_node_that_could_be_joined.instance_set,
right_instance_set=data_set_of_second_node_that_can_be_joined.instance_set,
on_entity_reference=entity_reference_to_join_first_and_second_nodes,
# TODO: make this check more substantial in V2
right_node_is_subquery=isinstance(second_node_that_could_be_joined, ComputeMetricsNode),
):
continue

Expand Down

0 comments on commit c7681a3

Please sign in to comment.