Skip to content

Commit

Permalink
add classical 2D tree drawing mode
Browse files Browse the repository at this point in the history
  • Loading branch information
mattmilten committed Apr 25, 2021
1 parent 88749e3 commit 07c9f6b
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 67 deletions.
10 changes: 9 additions & 1 deletion bin/treed
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ desc = """Draw a visual representation of the branch-and-cut tree of SCIP for
parser = argparse.ArgumentParser(description=desc)

parser.add_argument("model", type=str, help="path to model")
parser.add_argument(
"--classic",
action="store_true",
help="draw classical 2D tree ignoring spatial information of node LP solutions",
)
parser.add_argument(
"--transformation",
"-t",
Expand Down Expand Up @@ -42,4 +47,7 @@ treed = TreeD(
)

treed.solve()
treed.draw()
if args.classic:
treed.draw2d()
else:
treed.draw()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="treed",
version="1.0.0",
version="1.1.0",
author="Matthias Miltenberger",
author_email="[email protected]",
description="3D Visualization of Branch-and-Cut Trees using PySCIPOpt",
Expand Down
317 changes: 252 additions & 65 deletions src/treed/treed.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def _generateEdges(self, separate_frames=False):
Ye = []
Ze = []

if not "x" in self.df or not "y" in self.df:
self.df["x"] = 0
self.df["y"] = 0

symbol = []

if not separate_frames:
Expand Down Expand Up @@ -327,77 +331,68 @@ def _create_nodes_frames(self):

return frames, sliders_dict

def draw(self):
"""Draw the tree, depending on the mode"""

self.transform()

if self.verbose:
print("generating 3D objects", end="...")
start = time()
self._generateEdges()
nodes, nodeprojs = self._create_nodes_and_projections()
frames, sliders = self._create_nodes_frames()

edges = go.Scatter3d(
x=self.Xe,
y=self.Ye,
z=self.Ze,
mode="lines",
line=go.scatter3d.Line(color="rgb(75,75,75)", width=2),
hoverinfo="none",
name="edges",
def _create_nodes_frames_2d(self):
colorbar = go.scatter.marker.ColorBar(title="", thickness=10, x=0)
marker = go.scatter.Marker(
symbol=self.df["symbol"],
size=self.nodesize * 2,
color=self.df["age"],
colorscale=self.colorscale,
colorbar=colorbar,
)

min_x = min(self.df["x"])
max_x = max(self.df["x"])
min_y = min(self.df["y"])
max_y = max(self.df["y"])

optval = go.Scatter3d(
x=[min_x, min_x, max_x, max_x, min_x],
y=[min_y, max_y, max_y, min_y, min_y],
z=[self.optval] * 5,
mode="lines",
line=go.scatter3d.Line(color="rgb(0,200,50)", width=5),
hoverinfo="name+z",
name="optimal value",
opacity=0.5,
frames = []
sliders_dict = dict(
active=0,
yanchor="top",
xanchor="left",
currentvalue={"prefix": "Age:", "visible": True, "xanchor": "right",},
len=0.9,
x=0.05,
y=0.1,
steps=[],
)

xaxis = go.layout.scene.XAxis(
showticklabels=False,
title="X",
backgroundcolor="white",
gridcolor="lightgray",
)
yaxis = go.layout.scene.YAxis(
showticklabels=False,
title="Y",
backgroundcolor="white",
gridcolor="lightgray",
)
zaxis = go.layout.scene.ZAxis(
title="objective value", backgroundcolor="white", gridcolor="lightgray"
)
scene = go.layout.Scene(xaxis=xaxis, yaxis=yaxis, zaxis=zaxis)
title = (
"TreeD: " + self.probname + ", " + self.scipversion if self.title else ""
)
filename = "TreeD_" + self.probname + ".html"
for a in self.df["age"]:
adf = self.df[self.df["age"] <= a]
node_object = go.Scatter(
x=[self.pos2d[k][0] for k in range(len(adf))],
y=[self.pos2d[k][1] for k in range(len(adf))],
mode="markers",
marker=marker,
hovertext=[
f"LP obj: {adf['objval'].iloc[i]:.3f}\
<br>node number: {adf['number'].iloc[i]}\
<br>node age: {adf['age'].iloc[i]}\
<br>depth: {adf['depth'].iloc[i]}\
<br>LP cond: {adf['condition'].iloc[i]:.1f}\
<br>iterations: {adf['iterations'].iloc[i]}"
for i in range(len(adf))
],
hoverinfo="text+name",
opacity=0.7,
name="LP solutions",
)
frames.append(go.Frame(data=node_object, name=str(a)))

layout = go.Layout(
title=title,
font=dict(size=self.fontsize),
autosize=True,
# width=900,
# height=600,
showlegend=self.showlegend,
hovermode="closest",
scene=scene,
)
slider_step = {
"args": [
[a],
{
"frame": {"redraw": True, "restyle": False},
"fromcurrent": True,
"mode": "immediate",
},
],
"label": a,
"method": "animate",
}
sliders_dict["steps"].append(slider_step)

return frames, sliders_dict

layout["updatemenus"] = list(
def updatemenus(self):
return list(
[
dict(
buttons=list(
Expand Down Expand Up @@ -487,6 +482,78 @@ def draw(self):
]
)


def draw(self):
"""Draw the tree, depending on the mode"""

self.transform()

if self.verbose:
print("generating 3D objects", end="...")
start = time()
self._generateEdges()
nodes, nodeprojs = self._create_nodes_and_projections()
frames, sliders = self._create_nodes_frames()

edges = go.Scatter3d(
x=self.Xe,
y=self.Ye,
z=self.Ze,
mode="lines",
line=go.scatter3d.Line(color="rgb(75,75,75)", width=2),
hoverinfo="none",
name="edges",
)

min_x = min(self.df["x"])
max_x = max(self.df["x"])
min_y = min(self.df["y"])
max_y = max(self.df["y"])

optval = go.Scatter3d(
x=[min_x, min_x, max_x, max_x, min_x],
y=[min_y, max_y, max_y, min_y, min_y],
z=[self.optval] * 5,
mode="lines",
line=go.scatter3d.Line(color="rgb(0,200,50)", width=5),
hoverinfo="name+z",
name="optimal value",
opacity=0.5,
)

xaxis = go.layout.scene.XAxis(
showticklabels=False,
title="X",
backgroundcolor="white",
gridcolor="lightgray",
)
yaxis = go.layout.scene.YAxis(
showticklabels=False,
title="Y",
backgroundcolor="white",
gridcolor="lightgray",
)
zaxis = go.layout.scene.ZAxis(
title="objective value", backgroundcolor="white", gridcolor="lightgray"
)
scene = go.layout.Scene(xaxis=xaxis, yaxis=yaxis, zaxis=zaxis)
title = (
"TreeD: " + self.probname + ", " + self.scipversion if self.title else ""
)
filename = "TreeD_" + self.probname + ".html"

layout = go.Layout(
title=title,
font=dict(size=self.fontsize),
autosize=True,
# width=900,
# height=600,
showlegend=self.showlegend,
hovermode="closest",
scene=scene,
)

layout["updatemenus"] = self.updatemenus()
layout["sliders"] = [sliders]

self.fig = go.Figure(
Expand All @@ -505,6 +572,84 @@ def draw(self):

return self.fig

def draw2d(self):
"""Draw the 2D tree"""
self._generateEdges()
self.hierarchy_pos()
frames, sliders = self._create_nodes_frames_2d()

Xv = [self.pos2d[k][0] for k in range(len(self.pos2d))]
Yv = [self.pos2d[k][1] for k in range(len(self.pos2d))]
Xed = []
Yed = []
for edge in self.nxgraph.edges:
Xed += [self.pos2d[edge[0]][0], self.pos2d[edge[1]][0], None]
Yed += [self.pos2d[edge[0]][1], self.pos2d[edge[1]][1], None]

colorbar = go.scatter.marker.ColorBar(title="", thickness=10, x=0)
marker = go.scatter.Marker(
symbol=self.df["symbol"],
size=self.nodesize * 2,
color=self.df["age"],
colorscale=self.colorscale,
colorbar=colorbar,
)

edges = go.Scatter(
x=Xed,
y=Yed,
mode="lines",
line=dict(color="rgb(75,75,75)", width=1),
hoverinfo="none",
name="edges",
)
nodes = go.Scatter(
x=Xv,
y=Yv,
name="LP solutions",
mode="markers",
marker=marker,
hovertext=[
f"LP obj: {self.df['objval'].iloc[i]:.3f}\
<br>node number: {self.df['number'].iloc[i]}\
<br>node age: {self.df['age'].iloc[i]}\
<br>depth: {self.df['depth'].iloc[i]}\
<br>LP cond: {self.df['condition'].iloc[i]:.1f}\
<br>iterations: {self.df['iterations'].iloc[i]}"
for i in range(len(self.df))
],
hoverinfo="text+name",
)

xaxis = go.layout.XAxis(title="", visible=False)
yaxis = go.layout.YAxis(title="", visible=False)

title = (
"Tree 2D: " + self.probname + ", " + self.scipversion if self.title else ""
)
filename = "Tree_2D_" + self.probname + ".html"

layout = go.Layout(
title=title,
font=dict(size=self.fontsize),
autosize=True,
template="simple_white",
showlegend=self.showlegend,
hovermode="closest",
xaxis=xaxis,
yaxis=yaxis,
)

layout["updatemenus"] = self.updatemenus()

layout["sliders"] = [sliders]

self.fig2d = go.Figure(data=[nodes, edges], layout=layout, frames=frames,)

self.fig2d.write_html(file=filename, include_plotlyjs=self.include_plotlyjs)

return self.fig2d

def solve(self):
"""Solve the instance and collect and generate the tree data"""

Expand Down Expand Up @@ -561,6 +706,48 @@ def solve(self):
.reset_index()
)

def hierarchy_pos(self, root=0, width=1.0, vert_gap=0.2, vert_loc=0, xcenter=0.5):
"""compute abstract node positions of the tree"""
G = self.nxgraph
if not nx.is_tree(G):
raise TypeError("cannot use hierarchy_pos on a graph that is not a tree")

def _hierarchy_pos(
G,
root,
width=1.0,
vert_gap=0.2,
vert_loc=0,
xcenter=0.5,
pos=None,
parent=None,
):
if pos is None:
pos = {root: (xcenter, vert_loc)}
else:
pos[root] = (xcenter, vert_loc)
children = list(G.neighbors(root))
if not isinstance(G, nx.DiGraph) and parent is not None:
children.remove(parent)
if len(children) != 0:
dx = width / len(children)
nextx = xcenter - width / 2 - dx / 2
for child in children:
nextx += dx
pos = _hierarchy_pos(
G,
child,
width=dx,
vert_gap=vert_gap,
vert_loc=vert_loc - vert_gap,
xcenter=nextx,
pos=pos,
parent=root,
)
return pos

self.pos2d = _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

def compute_distances(self):
"""compute all pairwise distances between the original LP solutions and the transformed points"""
if self.df is None:
Expand Down

0 comments on commit 07c9f6b

Please sign in to comment.