@@ -23,6 +23,13 @@ class NIRGraph(NIRNode):
23
23
edges.
24
24
25
25
A graph of computational nodes and identity edges.
26
+
27
+ Arguments:
28
+ nodes: Dictionary of nodes in the graph.
29
+ edges: List of edges in the graph.
30
+ metadata: Dictionary of metadata for the graph.
31
+ type_check: Whether to check that input and output types match for all nodes in the graph.
32
+ Will not be stored in the graph as an attribute. Defaults to True.
26
33
"""
27
34
28
35
nodes : Nodes # List of computational nodes
@@ -31,6 +38,28 @@ class NIRGraph(NIRNode):
31
38
output_type : Optional [Dict [str , np .ndarray ]] = None
32
39
metadata : Dict [str , Any ] = field (default_factory = dict )
33
40
41
+ def __init__ (
42
+ self ,
43
+ nodes : Nodes ,
44
+ edges : Edges ,
45
+ input_type : Optional [Dict [str , np .ndarray ]] = None ,
46
+ output_type : Optional [Dict [str , np .ndarray ]] = None ,
47
+ metadata : Dict [str , Any ] = dict ,
48
+ type_check : bool = True ,
49
+ ):
50
+ self .nodes = nodes
51
+ self .edges = edges
52
+ self .metadata = metadata
53
+ self .input_type = input_type
54
+ self .output_type = output_type
55
+
56
+ # Check that all nodes have input and output types, if requested (default)
57
+ if type_check :
58
+ self ._check_types ()
59
+
60
+ # Call post init to set input_type and output_type
61
+ self .__post_init__ ()
62
+
34
63
@property
35
64
def inputs (self ):
36
65
return {
@@ -44,7 +73,7 @@ def outputs(self):
44
73
}
45
74
46
75
@staticmethod
47
- def from_list (* nodes : NIRNode ) -> "NIRGraph" :
76
+ def from_list (* nodes : NIRNode , type_check : bool = True ) -> "NIRGraph" :
48
77
"""Create a sequential graph from a list of nodes by labelling them after
49
78
indices."""
50
79
@@ -81,80 +110,58 @@ def unique_node_name(node, counts):
81
110
return NIRGraph (
82
111
nodes = node_dict ,
83
112
edges = edges ,
113
+ type_check = type_check ,
84
114
)
85
115
86
116
def __post_init__ (self ):
87
117
input_node_keys = [
88
118
k for k , node in self .nodes .items () if isinstance (node , Input )
89
119
]
90
120
self .input_type = (
91
- {node_key : self .nodes [node_key ].input_type for node_key in input_node_keys }
121
+ {
122
+ node_key : self .nodes [node_key ].input_type ["input" ]
123
+ for node_key in input_node_keys
124
+ }
92
125
if len (input_node_keys ) > 0
93
126
else None
94
127
)
95
128
output_node_keys = [
96
129
k for k , node in self .nodes .items () if isinstance (node , Output )
97
130
]
98
131
self .output_type = {
99
- node_key : self .nodes [node_key ].output_type for node_key in output_node_keys
132
+ node_key : self .nodes [node_key ].output_type ["output" ]
133
+ for node_key in output_node_keys
100
134
}
135
+ # Assign the metadata attribute if left unset to avoid issues with serialization
136
+ if not isinstance (self .metadata , dict ):
137
+ self .metadata = {}
101
138
102
139
def to_dict (self ) -> Dict [str , Any ]:
103
140
ret = super ().to_dict ()
104
141
ret ["nodes" ] = {k : n .to_dict () for k , n in self .nodes .items ()}
105
142
return ret
106
143
107
144
@classmethod
108
- def from_dict (cls , node : Dict [str , Any ]) -> "NIRNode " :
145
+ def from_dict (cls , kwargs : Dict [str , Any ]) -> "NIRGraph " :
109
146
from . import dict2NIRNode
110
147
111
- node ["nodes" ] = {k : dict2NIRNode (n ) for k , n in node ["nodes" ].items ()}
112
- # h5py deserializes edges into a numpy array of type bytes and dtype=object,
113
- # hence using ensure_str here
114
- node ["edges" ] = [(ensure_str (a ), ensure_str (b )) for a , b in node ["edges" ]]
115
- return super ().from_dict (node )
148
+ kwargs_local = kwargs .copy () # Copy the input to avoid overwriting attributes
149
+
150
+ # Assert that we have nodes and edges
151
+ assert "nodes" in kwargs , "The incoming dictionary must hade a 'nodes' entry"
152
+ assert "edges" in kwargs , "The incoming dictionary must hade a 'edges' entry"
153
+ # Assert that the type is well-formed
154
+ if "type" in kwargs :
155
+ assert kwargs ["type" ] == "NIRGraph" , "You are calling NIRGraph.from_dict with a different type "
156
+ f"{ type } . Either remove the entry or use <Specific NIRNode>.from_dict, such as Input.from_dict"
157
+ kwargs_local ["type" ] = "NIRGraph"
116
158
117
- def _check_types (self ):
118
- """Check that all nodes in the graph have input and output types.
119
-
120
- Will raise ValueError if any node has no input or output type, or if the types
121
- are inconsistent.
122
- """
123
- for edge in self .edges :
124
- pre_node = self .nodes [edge [0 ]]
125
- post_node = self .nodes [edge [1 ]]
126
159
127
- # make sure all types are defined
128
- undef_out_type = pre_node .output_type is None or any (
129
- v is None for v in pre_node .output_type .values ()
130
- )
131
- if undef_out_type :
132
- raise ValueError (f"pre node { edge [0 ]} has no output type" )
133
- undef_in_type = post_node .input_type is None or any (
134
- v is None for v in post_node .input_type .values ()
135
- )
136
- if undef_in_type :
137
- raise ValueError (f"post node { edge [1 ]} has no input type" )
138
-
139
- # make sure the length of types is equal
140
- if len (pre_node .output_type ) != len (post_node .input_type ):
141
- pre_repr = f"len({ edge [0 ]} .output)={ len (pre_node .output_type )} "
142
- post_repr = f"len({ edge [1 ]} .input)={ len (post_node .input_type )} "
143
- raise ValueError (f"type length mismatch: { pre_repr } -> { post_repr } " )
144
-
145
- # make sure the type values match up
146
- if len (pre_node .output_type .keys ()) == 1 :
147
- post_input_type = list (post_node .input_type .values ())[0 ]
148
- pre_output_type = list (pre_node .output_type .values ())[0 ]
149
- if not np .array_equal (post_input_type , pre_output_type ):
150
- pre_repr = f"{ edge [0 ]} .output: { pre_output_type } "
151
- post_repr = f"{ edge [1 ]} .input: { post_input_type } "
152
- raise ValueError (f"type mismatch: { pre_repr } -> { post_repr } " )
153
- else :
154
- raise NotImplementedError (
155
- "multiple input/output types not supported yet"
156
- )
157
- return True
160
+ kwargs_local ["nodes" ] = {k : dict2NIRNode (n ) for k , n in kwargs_local ["nodes" ].items ()}
161
+ # h5py deserializes edges into a numpy array of type bytes and dtype=object,
162
+ # hence using ensure_str here
163
+ kwargs_local ["edges" ] = [(ensure_str (a ), ensure_str (b )) for a , b in kwargs_local ["edges" ]]
164
+ return super ().from_dict (kwargs_local )
158
165
159
166
def _forward_type_inference (self , debug = True ):
160
167
"""Infer the types of all nodes in this graph. Will modify the input_type and
@@ -497,12 +504,14 @@ def from_dict(cls, node: Dict[str, Any]) -> "NIRNode":
497
504
del node ["shape" ]
498
505
return super ().from_dict (node )
499
506
507
+
500
508
@dataclass (eq = False )
501
509
class Identity (NIRNode ):
502
510
"""Identity Node.
503
511
504
512
This is a virtual node, which allows for the identity operation.
505
513
"""
514
+
506
515
input_type : Types
507
516
508
517
def __post_init__ (self ):
@@ -515,4 +524,4 @@ def to_dict(self) -> Dict[str, Any]:
515
524
516
525
@classmethod
517
526
def from_dict (cls , node : Dict [str , Any ]) -> "NIRNode" :
518
- return super ().from_dict (node )
527
+ return super ().from_dict (node )
0 commit comments