Skip to content

Commit c40ee70

Browse files
Jegp and benkroehs
authored
Feature check types (#124)
* Add: Call _check_types in init of NIRGraph, NIRGraph.from_list, and NIRGraph.from_dict * Added pytest to nix flake Throws `AttributeError` on call `NIRNode.__init__`. * Added a vscode devcontainer --------- Co-authored-by: Ben Kroehs <[email protected]> Co-authored-by: Ben Kroehs <[email protected]>
1 parent c9af31f commit c40ee70

File tree

8 files changed

+183
-101
lines changed

8 files changed

+183
-101
lines changed

.devcontainer/Dockerfile

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
FROM mcr.microsoft.com/devcontainers/python:1-3.12-bullseye
2+
3+
RUN pip install numpy black ruff nir pytest
4+
RUN pip3 install torch --index-url https://download.pytorch.org/whl/cpu

.devcontainer/devcontainer.json

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
2+
// README at: https://github.com/devcontainers/templates/tree/main/src/python
3+
{
4+
"name": "Python 3",
5+
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
6+
"build": {
7+
"dockerfile": "Dockerfile"
8+
},
9+
"customizations": {
10+
"vscode": {
11+
"extensions": [
12+
"ms-python.black-formatter",
13+
"ms-azuretools.vscode-docker",
14+
"ms-toolsai.jupyter"
15+
]
16+
}
17+
}
18+
19+
// Use 'forwardPorts' to make a list of ports inside the container available locally.
20+
// "forwardPorts": [],
21+
22+
// Use 'postCreateCommand' to run commands after the container is created.
23+
// "postCreateCommand": "pip3 install -e .",
24+
25+
// Configure tool-specific properties.
26+
// "customizations": {}
27+
28+
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
29+
// "remoteUser": "root"
30+
}

flake.nix

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
pythonPackages.numpy
2020
pythonPackages.h5py
2121
pythonPackages.black
22+
pythonPackages.pytest
2223
pkgs.ruff
2324
pkgs.autoPatchelfHook
2425
];

nir/ir/graph.py

+59-50
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ class NIRGraph(NIRNode):
2323
edges.
2424
2525
A graph of computational nodes and identity edges.
26+
27+
Arguments:
28+
nodes: Dictionary of nodes in the graph.
29+
edges: List of edges in the graph.
30+
metadata: Dictionary of metadata for the graph.
31+
type_check: Whether to check that input and output types match for all nodes in the graph.
32+
Will not be stored in the graph as an attribute. Defaults to True.
2633
"""
2734

2835
nodes: Nodes # List of computational nodes
@@ -31,6 +38,28 @@ class NIRGraph(NIRNode):
3138
output_type: Optional[Dict[str, np.ndarray]] = None
3239
metadata: Dict[str, Any] = field(default_factory=dict)
3340

41+
def __init__(
42+
self,
43+
nodes: Nodes,
44+
edges: Edges,
45+
input_type: Optional[Dict[str, np.ndarray]] = None,
46+
output_type: Optional[Dict[str, np.ndarray]] = None,
47+
metadata: Dict[str, Any] = dict,
48+
type_check: bool = True,
49+
):
50+
self.nodes = nodes
51+
self.edges = edges
52+
self.metadata = metadata
53+
self.input_type = input_type
54+
self.output_type = output_type
55+
56+
# Check that all nodes have input and output types, if requested (default)
57+
if type_check:
58+
self._check_types()
59+
60+
# Call post init to set input_type and output_type
61+
self.__post_init__()
62+
3463
@property
3564
def inputs(self):
3665
return {
@@ -44,7 +73,7 @@ def outputs(self):
4473
}
4574

4675
@staticmethod
47-
def from_list(*nodes: NIRNode) -> "NIRGraph":
76+
def from_list(*nodes: NIRNode, type_check: bool = True) -> "NIRGraph":
4877
"""Create a sequential graph from a list of nodes by labelling them after
4978
indices."""
5079

@@ -81,80 +110,58 @@ def unique_node_name(node, counts):
81110
return NIRGraph(
82111
nodes=node_dict,
83112
edges=edges,
113+
type_check=type_check,
84114
)
85115

86116
def __post_init__(self):
87117
input_node_keys = [
88118
k for k, node in self.nodes.items() if isinstance(node, Input)
89119
]
90120
self.input_type = (
91-
{node_key: self.nodes[node_key].input_type for node_key in input_node_keys}
121+
{
122+
node_key: self.nodes[node_key].input_type["input"]
123+
for node_key in input_node_keys
124+
}
92125
if len(input_node_keys) > 0
93126
else None
94127
)
95128
output_node_keys = [
96129
k for k, node in self.nodes.items() if isinstance(node, Output)
97130
]
98131
self.output_type = {
99-
node_key: self.nodes[node_key].output_type for node_key in output_node_keys
132+
node_key: self.nodes[node_key].output_type["output"]
133+
for node_key in output_node_keys
100134
}
135+
# Assign the metadata attribute if left unset to avoid issues with serialization
136+
if not isinstance(self.metadata, dict):
137+
self.metadata = {}
101138

102139
def to_dict(self) -> Dict[str, Any]:
103140
ret = super().to_dict()
104141
ret["nodes"] = {k: n.to_dict() for k, n in self.nodes.items()}
105142
return ret
106143

107144
@classmethod
108-
def from_dict(cls, node: Dict[str, Any]) -> "NIRNode":
145+
def from_dict(cls, kwargs: Dict[str, Any]) -> "NIRGraph":
109146
from . import dict2NIRNode
110147

111-
node["nodes"] = {k: dict2NIRNode(n) for k, n in node["nodes"].items()}
112-
# h5py deserializes edges into a numpy array of type bytes and dtype=object,
113-
# hence using ensure_str here
114-
node["edges"] = [(ensure_str(a), ensure_str(b)) for a, b in node["edges"]]
115-
return super().from_dict(node)
148+
kwargs_local = kwargs.copy() # Copy the input to avoid overwriting attributes
149+
150+
# Assert that we have nodes and edges
151+
assert "nodes" in kwargs, "The incoming dictionary must hade a 'nodes' entry"
152+
assert "edges" in kwargs, "The incoming dictionary must hade a 'edges' entry"
153+
# Assert that the type is well-formed
154+
if "type" in kwargs:
155+
assert kwargs["type"] == "NIRGraph", "You are calling NIRGraph.from_dict with a different type "
156+
f"{type}. Either remove the entry or use <Specific NIRNode>.from_dict, such as Input.from_dict"
157+
kwargs_local["type"] = "NIRGraph"
116158

117-
def _check_types(self):
118-
"""Check that all nodes in the graph have input and output types.
119-
120-
Will raise ValueError if any node has no input or output type, or if the types
121-
are inconsistent.
122-
"""
123-
for edge in self.edges:
124-
pre_node = self.nodes[edge[0]]
125-
post_node = self.nodes[edge[1]]
126159

127-
# make sure all types are defined
128-
undef_out_type = pre_node.output_type is None or any(
129-
v is None for v in pre_node.output_type.values()
130-
)
131-
if undef_out_type:
132-
raise ValueError(f"pre node {edge[0]} has no output type")
133-
undef_in_type = post_node.input_type is None or any(
134-
v is None for v in post_node.input_type.values()
135-
)
136-
if undef_in_type:
137-
raise ValueError(f"post node {edge[1]} has no input type")
138-
139-
# make sure the length of types is equal
140-
if len(pre_node.output_type) != len(post_node.input_type):
141-
pre_repr = f"len({edge[0]}.output)={len(pre_node.output_type)}"
142-
post_repr = f"len({edge[1]}.input)={len(post_node.input_type)}"
143-
raise ValueError(f"type length mismatch: {pre_repr} -> {post_repr}")
144-
145-
# make sure the type values match up
146-
if len(pre_node.output_type.keys()) == 1:
147-
post_input_type = list(post_node.input_type.values())[0]
148-
pre_output_type = list(pre_node.output_type.values())[0]
149-
if not np.array_equal(post_input_type, pre_output_type):
150-
pre_repr = f"{edge[0]}.output: {pre_output_type}"
151-
post_repr = f"{edge[1]}.input: {post_input_type}"
152-
raise ValueError(f"type mismatch: {pre_repr} -> {post_repr}")
153-
else:
154-
raise NotImplementedError(
155-
"multiple input/output types not supported yet"
156-
)
157-
return True
160+
kwargs_local["nodes"] = {k: dict2NIRNode(n) for k, n in kwargs_local["nodes"].items()}
161+
# h5py deserializes edges into a numpy array of type bytes and dtype=object,
162+
# hence using ensure_str here
163+
kwargs_local["edges"] = [(ensure_str(a), ensure_str(b)) for a, b in kwargs_local["edges"]]
164+
return super().from_dict(kwargs_local)
158165

159166
def _forward_type_inference(self, debug=True):
160167
"""Infer the types of all nodes in this graph. Will modify the input_type and
@@ -497,12 +504,14 @@ def from_dict(cls, node: Dict[str, Any]) -> "NIRNode":
497504
del node["shape"]
498505
return super().from_dict(node)
499506

507+
500508
@dataclass(eq=False)
501509
class Identity(NIRNode):
502510
"""Identity Node.
503511
504512
This is a virtual node, which allows for the identity operation.
505513
"""
514+
506515
input_type: Types
507516

508517
def __post_init__(self):
@@ -515,4 +524,4 @@ def to_dict(self) -> Dict[str, Any]:
515524

516525
@classmethod
517526
def from_dict(cls, node: Dict[str, Any]) -> "NIRNode":
518-
return super().from_dict(node)
527+
return super().from_dict(node)

nir/ir/node.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def to_dict(self) -> Dict[str, Any]:
3737
return ret
3838

3939
@classmethod
40-
def from_dict(cls, node: Dict[str, Any]) -> "NIRNode":
41-
assert node["type"] == cls.__name__
42-
del node["type"]
40+
def from_dict(cls, kwargs: Dict[str, Any]) -> "NIRNode":
41+
assert kwargs["type"] == cls.__name__
42+
del kwargs["type"]
4343

44-
return cls(**node)
44+
return cls(**kwargs)

0 commit comments

Comments
 (0)