Skip to content

Commit

Permalink
Display config on visualization (#833)
Browse files Browse the repository at this point in the history
Visualizations now display the config value instead of type Any. Also, nodes affected by configuration have the suffix node_name: config_key.

While working on this feature, I realized that all high-level viz functions except display_all_functions() failed to show config nodes (#832). This is because display_all_functions() is the only function that passes the config via the nodes argument (type hamilton.node.Node), while the other functions first filter a path of nodes, which excludes the config.

Instead of a large refactoring, hamilton.graph.create_graphviz_graph() now has a kwarg config which raises an exception if left as None. It's the caller's responsibility to have this config match the actual config value that led to the set of nodes passed. This is unlikely to cause problems unless people dig into internal APIs.

TL;DR:
* visualizes config values
* updated tests; added guardrail

---------

Co-authored-by: zilto <tjean@DESKTOP-V6JDCS2>
  • Loading branch information
zilto and zilto authored Apr 22, 2024
1 parent 26bc1cc commit 0772e04
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 14 deletions.
4 changes: 4 additions & 0 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,7 @@ def _visualize_execution_helper(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
config=fn_graph._config,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down Expand Up @@ -1035,6 +1036,7 @@ def display_downstream_of(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
config=self.graph._config,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down Expand Up @@ -1095,6 +1097,7 @@ def display_upstream_of(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
config=self.graph._config,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down Expand Up @@ -1252,6 +1255,7 @@ def visualize_path_between(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
config=self.graph._config,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down
30 changes: 30 additions & 0 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def create_graphviz_graph(
deduplicate_inputs: bool = False,
display_fields: bool = True,
custom_style_function: Callable = None,
config: dict = None,
) -> "graphviz.Digraph": # noqa: F821
"""Helper function to create a graphviz graph.
Expand All @@ -225,6 +226,9 @@ def create_graphviz_graph(
:param deduplicate_inputs: If True, remove duplicate input nodes.
Can improve readability depending on the specifics of the DAG.
:param custom_style_function: A function that takes in node values and returns a dictionary of styles to apply to it.
:param config: The Driver config. This value is passed by the caller, e.g., driver.display_all_functions(),
and shouldn't be passed explicitly. Otherwise, it may not match the `nodes` argument leading to an
incorrect visualization.
:return: a graphviz.Digraph; use this to render/save a graph representation.
"""
PATH_COLOR = "red"
Expand All @@ -242,6 +246,12 @@ def create_graphviz_graph(
") -> Tuple[dict, Optional[str], Optional[str]]:"
)

if config is None:
raise ValueError(
"Received None for kwarg `config`. Make sure to pass a dictionary that matches the Driver config.\n"
"If you're seeing this error, you're likely using a non-public API."
)

def _get_node_label(
n: node.Node,
name: Optional[str] = None,
Expand Down Expand Up @@ -443,6 +453,11 @@ def _get_legend(
digraph = graphviz.Digraph(**digraph_attr)
extra_legend_nodes = {}

for config_key, config_value in config.items():
label = _get_node_label(n=None, name=config_key, type_string=str(config_value))
style = _get_node_style("config")
digraph.node(config_key, label=label, **style)

# create nodes
seen_node_types = set()
for n in nodes:
Expand All @@ -451,6 +466,15 @@ def _get_legend(
if node_type == "input":
seen_node_types.add(node_type)
continue
# config nodes are handled separately;
# only Driver.display_all_functions() passes config via the `nodes` arg
elif node_type == "config":
continue

# append config key to node label
config_key = n.tags.get("hamilton.config", None)
if config_key:
label = _get_node_label(n, name=f"{n.name}: {config_key}")

node_style = _get_node_style(node_type)

Expand Down Expand Up @@ -764,6 +788,7 @@ def display_all(
deduplicate_inputs=deduplicate_inputs,
display_fields=display_fields,
custom_style_function=custom_style_function,
config=self._config,
)

def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool:
Expand Down Expand Up @@ -809,6 +834,7 @@ def display(
deduplicate_inputs: bool = False,
display_fields: bool = True,
custom_style_function: Callable = None,
config: dict = None,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Function to display the graph represented by the passed in nodes.
Expand Down Expand Up @@ -836,6 +862,9 @@ def display(
:param display_fields: If True, display fields in the graph if node has attached
schema metadata
:param custom_style_function: Optional. Custom style function.
:param config: The Driver config. This value is passed by the caller, e.g., driver.display_all_functions(),
and shouldn't be passed explicitly. Otherwise, it may not match the `nodes` argument leading to an
incorrect visualization.
:return: the graphviz graph object if it was created. None if not.
"""
# Check to see if optional dependencies have been installed.
Expand Down Expand Up @@ -863,6 +892,7 @@ def display(
deduplicate_inputs,
display_fields=display_fields,
custom_style_function=custom_style_function,
config=config,
)
kwargs = {"view": False, "format": "png"} # default format = png
if output_file_path: # infer format from path
Expand Down
91 changes: 77 additions & 14 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,8 @@ def test_function_graph_has_cycles_false():
def test_function_graph_display(tmp_path: pathlib.Path):
"""Tests that display saves a file"""
dot_file_path = tmp_path / "dag"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)
node_modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
all_nodes = set()
for n in fg.get_nodes():
Expand All @@ -709,6 +710,8 @@ def test_function_graph_display(tmp_path: pathlib.Path):
'\tA [label=<<b>A</b><br /><br /><i>int</i>> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
'\tB [label=<<b>B</b><br /><br /><i>int</i>> fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
'\tC [label=<<b>C</b><br /><br /><i>int</i>> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
"\tb [label=<<b>b</b><br /><br /><i>1</i>> fontname=Helvetica shape=note style=filled]\n",
"\tc [label=<<b>c</b><br /><br /><i>2</i>> fontname=Helvetica shape=note style=filled]\n",
"\t_A_inputs -> A\n",
# commenting out input node: '\t_A_inputs [label=<<table border="0"><tr><td>c</td><td>int</td></tr><tr><td>b</td><td>int</td></tr></table>> fontname=Helvetica margin=0.15 shape=rectangle style=dashed]\n',
"\tgraph [compound=true concentrate=true rankdir=LR ranksep=0.4 style=filled]\n",
Expand All @@ -726,6 +729,7 @@ def test_function_graph_display(tmp_path: pathlib.Path):
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
node_modifiers=node_modifiers,
config=config,
)
dot = dot_file_path.open("r").readlines()
dot_set = set(dot)
Expand All @@ -735,9 +739,10 @@ def test_function_graph_display(tmp_path: pathlib.Path):

def test_function_graph_display_no_dot_output(tmp_path: pathlib.Path):
dot_file_path = tmp_path / "dag"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

fg.display(set(fg.get_nodes()), output_file_path=None)
fg.display(set(fg.get_nodes()), output_file_path=None, config=config)

assert not dot_file_path.exists()

Expand All @@ -746,12 +751,14 @@ def test_function_graph_display_custom_style_node():
def _styling_function(*, node, node_class):
return dict(fill_color="aquamarine"), None, "legend_key"

fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

digraph = fg.display(
set(fg.get_nodes()),
output_file_path=None,
custom_style_function=_styling_function,
config=config,
)

key_found = False
Expand All @@ -767,12 +774,14 @@ def test_function_graph_display_custom_style_legend():
def _styling_function(*, node, node_class):
return dict(fill_color="aquamarine"), None, "legend_key"

fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

digraph = fg.display(
set(fg.get_nodes()),
output_file_path=None,
custom_style_function=_styling_function,
config=config,
)

key_found = False
Expand All @@ -794,12 +803,14 @@ def _styling_function(*, node, node_class):

nodes = create_testing_nodes()
nodes["A"].tags["some_key"] = "some_value"
fg = graph.FunctionGraph(nodes, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph(nodes, config=config)

digraph = fg.display(
set(fg.get_nodes()),
output_file_path=None,
custom_style_function=_styling_function,
config=config,
)

# check that style is only applied to tagged nodes
Expand All @@ -817,13 +828,15 @@ def _styling_function(*, node, node_class):
@pytest.mark.parametrize("show_legend", [(True), (False)])
def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path):
dot_file_path = tmp_path / "dag.png"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
show_legend=show_legend,
config=config,
)
dot_file = pathlib.Path(os.path.splitext(str(dot_file_path))[0])
dot = dot_file.open("r").read()
Expand All @@ -835,13 +848,15 @@ def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path
@pytest.mark.parametrize("orient", [("LR"), ("TB"), ("RL"), ("BT")])
def test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path):
dot_file_path = tmp_path / "dag"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
orient=orient,
config=config,
)
dot = dot_file_path.open("r").read()

Expand All @@ -852,13 +867,15 @@ def test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path):
@pytest.mark.parametrize("hide_inputs", [(True,), (False,)])
def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path):
dot_file_path = tmp_path / "dag"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
hide_inputs=hide_inputs,
config=config,
)
dot_lines = dot_file_path.open("r").readlines()

Expand All @@ -868,14 +885,17 @@ def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path

def test_function_graph_display_without_saving():
"""Tests that display works when None is passed in for path"""
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)
all_nodes = set()
node_modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
for n in fg.get_nodes():
if n.user_defined:
node_modifiers[n.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}
all_nodes.add(n)
digraph = fg.display(all_nodes, output_file_path=None, node_modifiers=node_modifiers)
digraph = fg.display(
all_nodes, output_file_path=None, node_modifiers=node_modifiers, config=config
)
assert digraph is not None
import graphviz

Expand All @@ -891,13 +911,15 @@ def df_with_schema() -> pd.DataFrame:
pass

mod = ad_hoc_utils.create_temporary_module(df_with_schema)
fg = graph.FunctionGraph.from_modules(mod, config={})
config = {}
fg = graph.FunctionGraph.from_modules(mod, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
display_fields=display_fields,
config=config,
)
dot_lines = dot_file_path.open("r").readlines()
if display_fields:
Expand Down Expand Up @@ -927,13 +949,15 @@ def df_2_with_schema() -> pd.DataFrame:
pass

mod = ad_hoc_utils.create_temporary_module(df_1_with_schema, df_2_with_schema)
fg = graph.FunctionGraph.from_modules(mod, config={})
config = {}
fg = graph.FunctionGraph.from_modules(mod, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
display_fields=True,
config=config,
)
dot_lines = dot_file_path.open("r").readlines()

Expand All @@ -946,9 +970,47 @@ def _get_occurances(var: str):
assert len(_get_occurances("baz=")) == 2


def test_function_graph_display_config_node():
    """Check that the low-level hamilton.graph.FunctionGraph.display renders config keys.

    Builds a FunctionGraph with the single config entry ``{"X": 1}``, displays
    all of its nodes with that same config, and asserts that the resulting
    graphviz object contains a node statement for the config key ``X``.
    """
    config = {"X": 1}
    fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

    dot = fg.display(set(fg.get_nodes()), config=config)

    # dot.body is a list of strings; node statements start with a tab followed
    # by the node name, so a line starting with "\tX" means the config key "X"
    # was emitted as a node in the graphviz object.
    assert any(
        line.startswith("\tX") for line in dot.body
    ), "config key 'X' was not rendered as a node by FunctionGraph.display()"


# TODO use high-level visualization dot as fixtures for reuse across tests
def test_display_config_node(tmp_path: pathlib.Path):
    """Check that every high-level hamilton.driver visualization renders config keys.

    A Driver is built with the single config entry ``{"X": 1}``; each
    visualization entry point is then called once and its graphviz object is
    checked for a node statement named after the config key ``X``.

    :param tmp_path: pytest temporary directory; used as the target directory of
        the JSON saver passed to ``visualize_materialization``.
    """
    from hamilton import driver
    from hamilton.io.materialization import to

    config = {"X": 1}
    dr = driver.Builder().with_modules(tests.resources.dummy_functions).with_config(config).build()

    # Keyed by the Driver method that produced the digraph so that a failure
    # reports which visualization dropped the config node. The dict literal is
    # evaluated top to bottom, so the calls run in the order listed.
    dots = {
        "display_all_functions": dr.display_all_functions(),
        "display_downstream_of": dr.display_downstream_of("A"),
        "display_upstream_of": dr.display_upstream_of("C"),
        "visualize_path_between": dr.visualize_path_between("A", "C"),
        "visualize_execution": dr.visualize_execution(["C"], inputs={"b": 1, "c": 2}),
        "visualize_materialization": dr.visualize_materialization(
            to.json(
                id="saver",
                dependencies=["C"],
                combine=base.DictResult(),
                path=f"{tmp_path}/saver.json",
            ),
            inputs={"b": 1, "c": 2},
        ),
    }

    for viz_name, dot in dots.items():
        # node statements in dot.body start with a tab followed by the node name
        assert any(
            line.startswith("\tX") for line in dot.body
        ), f"config key 'X' was not rendered as a node by Driver.{viz_name}()"


def test_create_graphviz_graph():
"""Tests that we create a graphviz graph"""
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
config = {}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)
nodes, user_nodes = fg.get_upstream_nodes(["A", "B", "C"])
nodez = nodes.union(user_nodes)
node_modifiers = {
Expand Down Expand Up @@ -990,6 +1052,7 @@ def test_create_graphviz_graph():
graphviz_kwargs=dict(graph_attr={"ratio": "1"}),
node_modifiers=node_modifiers,
strictly_display_only_nodes_passed_in=False,
config=config,
)
# the HTML table isn't deterministic. Replace the value in it with a single one.
dot_set = set(str(digraph).replace("<td>c</td>", "<td>b</td>").split("\n"))
Expand Down

0 comments on commit 0772e04

Please sign in to comment.