Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Display config on visualization #833

Merged
merged 2 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,7 @@ def _visualize_execution_helper(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
config=fn_graph._config,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down Expand Up @@ -1035,6 +1036,7 @@ def display_downstream_of(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
config=self.graph._config,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down Expand Up @@ -1095,6 +1097,7 @@ def display_upstream_of(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
config=self.graph._config,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down Expand Up @@ -1252,6 +1255,7 @@ def visualize_path_between(
deduplicate_inputs=deduplicate_inputs,
display_fields=show_schema,
custom_style_function=custom_style_function,
config=self.graph._config,
)
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)
Expand Down
30 changes: 30 additions & 0 deletions hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def create_graphviz_graph(
deduplicate_inputs: bool = False,
display_fields: bool = True,
custom_style_function: Callable = None,
config: dict = None,
) -> "graphviz.Digraph": # noqa: F821
"""Helper function to create a graphviz graph.
Expand All @@ -225,6 +226,9 @@ def create_graphviz_graph(
:param deduplicate_inputs: If True, remove duplicate input nodes.
Can improve readability depending on the specifics of the DAG.
:param custom_style_function: A function that takes in node values and returns a dictionary of styles to apply to it.
:param config: The Driver config. This value is passed by the caller, e.g., driver.display_all_functions(),
and shouldn't be passed explicitly. Otherwise, it may not match the `nodes` argument leading to an
incorrect visualization.
:return: a graphviz.Digraph; use this to render/save a graph representation.
"""
PATH_COLOR = "red"
Expand All @@ -242,6 +246,12 @@ def create_graphviz_graph(
") -> Tuple[dict, Optional[str], Optional[str]]:"
)

if config is None:
raise ValueError(
"Received None for kwarg `config`. Make sure to pass a dictionary that matches the Driver config.\n"
"If you're seeing this error, you're likely using a non-public API."
)

def _get_node_label(
n: node.Node,
name: Optional[str] = None,
Expand Down Expand Up @@ -443,6 +453,11 @@ def _get_legend(
digraph = graphviz.Digraph(**digraph_attr)
extra_legend_nodes = {}

for config_key, config_value in config.items():
label = _get_node_label(n=None, name=config_key, type_string=str(config_value))
style = _get_node_style("config")
digraph.node(config_key, label=label, **style)

# create nodes
seen_node_types = set()
for n in nodes:
Expand All @@ -451,6 +466,15 @@ def _get_legend(
if node_type == "input":
seen_node_types.add(node_type)
continue
# config nodes are handled separately;
# only Driver.display_all_functions() passes config via the `nodes` arg
elif node_type == "config":
continue

# append config key to node label
config_key = n.tags.get("hamilton.config", None)
if config_key:
label = _get_node_label(n, name=f"{n.name}: {config_key}")

node_style = _get_node_style(node_type)

Expand Down Expand Up @@ -764,6 +788,7 @@ def display_all(
deduplicate_inputs=deduplicate_inputs,
display_fields=display_fields,
custom_style_function=custom_style_function,
config=self._config,
)

def has_cycles(self, nodes: Set[node.Node], user_nodes: Set[node.Node]) -> bool:
Expand Down Expand Up @@ -809,6 +834,7 @@ def display(
deduplicate_inputs: bool = False,
display_fields: bool = True,
custom_style_function: Callable = None,
config: dict = None,
) -> Optional["graphviz.Digraph"]: # noqa F821
"""Function to display the graph represented by the passed in nodes.
Expand Down Expand Up @@ -836,6 +862,9 @@ def display(
:param display_fields: If True, display fields in the graph if node has attached
schema metadata
:param custom_style_function: Optional. Custom style function.
:param config: The Driver config. This value is passed by the caller, e.g., driver.display_all_functions(),
and shouldn't be passed explicitly. Otherwise, it may not match the `nodes` argument leading to an
incorrect visualization.
:return: the graphviz graph object if it was created. None if not.
"""
# Check to see if optional dependencies have been installed.
Expand Down Expand Up @@ -863,6 +892,7 @@ def display(
deduplicate_inputs,
display_fields=display_fields,
custom_style_function=custom_style_function,
config=config,
)
kwargs = {"view": False, "format": "png"} # default format = png
if output_file_path: # infer format from path
Expand Down
91 changes: 77 additions & 14 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,8 @@ def test_function_graph_has_cycles_false():
def test_function_graph_display(tmp_path: pathlib.Path):
"""Tests that display saves a file"""
dot_file_path = tmp_path / "dag"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)
node_modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
all_nodes = set()
for n in fg.get_nodes():
Expand All @@ -709,6 +710,8 @@ def test_function_graph_display(tmp_path: pathlib.Path):
'\tA [label=<<b>A</b><br /><br /><i>int</i>> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
'\tB [label=<<b>B</b><br /><br /><i>int</i>> fillcolor="#FFC857" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
'\tC [label=<<b>C</b><br /><br /><i>int</i>> fillcolor="#b4d8e4" fontname=Helvetica margin=0.15 shape=rectangle style="rounded,filled"]\n',
"\tb [label=<<b>b</b><br /><br /><i>1</i>> fontname=Helvetica shape=note style=filled]\n",
"\tc [label=<<b>c</b><br /><br /><i>2</i>> fontname=Helvetica shape=note style=filled]\n",
"\t_A_inputs -> A\n",
# commenting out input node: '\t_A_inputs [label=<<table border="0"><tr><td>c</td><td>int</td></tr><tr><td>b</td><td>int</td></tr></table>> fontname=Helvetica margin=0.15 shape=rectangle style=dashed]\n',
"\tgraph [compound=true concentrate=true rankdir=LR ranksep=0.4 style=filled]\n",
Expand All @@ -726,6 +729,7 @@ def test_function_graph_display(tmp_path: pathlib.Path):
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
node_modifiers=node_modifiers,
config=config,
)
dot = dot_file_path.open("r").readlines()
dot_set = set(dot)
Expand All @@ -735,9 +739,10 @@ def test_function_graph_display(tmp_path: pathlib.Path):

def test_function_graph_display_no_dot_output(tmp_path: pathlib.Path):
dot_file_path = tmp_path / "dag"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

fg.display(set(fg.get_nodes()), output_file_path=None)
fg.display(set(fg.get_nodes()), output_file_path=None, config=config)

assert not dot_file_path.exists()

Expand All @@ -746,12 +751,14 @@ def test_function_graph_display_custom_style_node():
def _styling_function(*, node, node_class):
return dict(fill_color="aquamarine"), None, "legend_key"

fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

digraph = fg.display(
set(fg.get_nodes()),
output_file_path=None,
custom_style_function=_styling_function,
config=config,
)

key_found = False
Expand All @@ -767,12 +774,14 @@ def test_function_graph_display_custom_style_legend():
def _styling_function(*, node, node_class):
return dict(fill_color="aquamarine"), None, "legend_key"

fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

digraph = fg.display(
set(fg.get_nodes()),
output_file_path=None,
custom_style_function=_styling_function,
config=config,
)

key_found = False
Expand All @@ -794,12 +803,14 @@ def _styling_function(*, node, node_class):

nodes = create_testing_nodes()
nodes["A"].tags["some_key"] = "some_value"
fg = graph.FunctionGraph(nodes, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph(nodes, config=config)

digraph = fg.display(
set(fg.get_nodes()),
output_file_path=None,
custom_style_function=_styling_function,
config=config,
)

# check that style is only applied to tagged nodes
Expand All @@ -817,13 +828,15 @@ def _styling_function(*, node, node_class):
@pytest.mark.parametrize("show_legend", [(True), (False)])
def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path):
dot_file_path = tmp_path / "dag.png"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
show_legend=show_legend,
config=config,
)
dot_file = pathlib.Path(os.path.splitext(str(dot_file_path))[0])
dot = dot_file.open("r").read()
Expand All @@ -835,13 +848,15 @@ def test_function_graph_display_legend(show_legend: bool, tmp_path: pathlib.Path
@pytest.mark.parametrize("orient", [("LR"), ("TB"), ("RL"), ("BT")])
def test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path):
dot_file_path = tmp_path / "dag"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
orient=orient,
config=config,
)
dot = dot_file_path.open("r").read()

Expand All @@ -852,13 +867,15 @@ def test_function_graph_display_orient(orient: str, tmp_path: pathlib.Path):
@pytest.mark.parametrize("hide_inputs", [(True,), (False,)])
def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path):
dot_file_path = tmp_path / "dag"
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
hide_inputs=hide_inputs,
config=config,
)
dot_lines = dot_file_path.open("r").readlines()

Expand All @@ -868,14 +885,17 @@ def test_function_graph_display_inputs(hide_inputs: bool, tmp_path: pathlib.Path

def test_function_graph_display_without_saving():
"""Tests that display works when None is passed in for path"""
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={"b": 1, "c": 2})
config = {"b": 1, "c": 2}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)
all_nodes = set()
node_modifiers = {"B": {graph.VisualizationNodeModifiers.IS_OUTPUT}}
for n in fg.get_nodes():
if n.user_defined:
node_modifiers[n.name] = {graph.VisualizationNodeModifiers.IS_USER_INPUT}
all_nodes.add(n)
digraph = fg.display(all_nodes, output_file_path=None, node_modifiers=node_modifiers)
digraph = fg.display(
all_nodes, output_file_path=None, node_modifiers=node_modifiers, config=config
)
assert digraph is not None
import graphviz

Expand All @@ -891,13 +911,15 @@ def df_with_schema() -> pd.DataFrame:
pass

mod = ad_hoc_utils.create_temporary_module(df_with_schema)
fg = graph.FunctionGraph.from_modules(mod, config={})
config = {}
fg = graph.FunctionGraph.from_modules(mod, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
display_fields=display_fields,
config=config,
)
dot_lines = dot_file_path.open("r").readlines()
if display_fields:
Expand Down Expand Up @@ -927,13 +949,15 @@ def df_2_with_schema() -> pd.DataFrame:
pass

mod = ad_hoc_utils.create_temporary_module(df_1_with_schema, df_2_with_schema)
fg = graph.FunctionGraph.from_modules(mod, config={})
config = {}
fg = graph.FunctionGraph.from_modules(mod, config=config)

fg.display(
set(fg.get_nodes()),
output_file_path=str(dot_file_path),
render_kwargs={"view": False},
display_fields=True,
config=config,
)
dot_lines = dot_file_path.open("r").readlines()

Expand All @@ -946,9 +970,47 @@ def _get_occurances(var: str):
assert len(_get_occurances("baz=")) == 2


def test_function_graph_display_config_node():
"""Check if config is displayed by low-level hamilton.graph.FunctionGraph.display"""
config = {"X": 1}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)

dot = fg.display(set(fg.get_nodes()), config=config)

# dot.body is a list of string
# lines start tab then node name; check if "b" is a node in the graphviz object
assert any(line.startswith("\tX") for line in dot.body)


# TODO use high-level visualization dot as fixtures for reuse across tests
def test_display_config_node(tmp_path: pathlib.Path):
"""Check if config is displayed by high-level hamilton.driver.display..."""
from hamilton import driver
from hamilton.io.materialization import to

config = {"X": 1}
dr = driver.Builder().with_modules(tests.resources.dummy_functions).with_config(config).build()

all_dot = dr.display_all_functions()
down_dot = dr.display_downstream_of("A")
up_dot = dr.display_upstream_of("C")
between_dot = dr.visualize_path_between("A", "C")
exec_dot = dr.visualize_execution(["C"], inputs={"b": 1, "c": 2})
materialize_dot = dr.visualize_materialization(
to.json(
id="saver", dependencies=["C"], combine=base.DictResult(), path=f"{tmp_path}/saver.json"
),
inputs={"b": 1, "c": 2},
)

for dot in [all_dot, down_dot, up_dot, between_dot, exec_dot, materialize_dot]:
assert any(line.startswith("\tX") for line in dot.body)


def test_create_graphviz_graph():
"""Tests that we create a graphviz graph"""
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config={})
config = {}
fg = graph.FunctionGraph.from_modules(tests.resources.dummy_functions, config=config)
nodes, user_nodes = fg.get_upstream_nodes(["A", "B", "C"])
nodez = nodes.union(user_nodes)
node_modifiers = {
Expand Down Expand Up @@ -990,6 +1052,7 @@ def test_create_graphviz_graph():
graphviz_kwargs=dict(graph_attr={"ratio": "1"}),
node_modifiers=node_modifiers,
strictly_display_only_nodes_passed_in=False,
config=config,
)
# the HTML table isn't deterministic. Replace the value in it with a single one.
dot_set = set(str(digraph).replace("<td>c</td>", "<td>b</td>").split("\n"))
Expand Down