@@ -56,13 +56,11 @@ def __init__(self, nodes: list, edges: list, entry_point: str,
5656 self .callback_manager = CustomLLMCallbackManager ()
5757
5858 if nodes [0 ].node_name != entry_point .node_name :
59- # raise a warning if the entry point is not the first node in the list
6059 warnings .warn (
6160 "Careful! The entry point node is different from the first node in the graph." )
6261
6362 self ._set_conditional_node_edges ()
6463
65- # Burr configuration
6664 self .use_burr = use_burr
6765 self .burr_config = burr_config or {}
6866
@@ -91,7 +89,8 @@ def _set_conditional_node_edges(self):
9189 if node .node_type == 'conditional_node' :
9290 outgoing_edges = [(from_node , to_node ) for from_node , to_node in self .raw_edges if from_node .node_name == node .node_name ]
9391 if len (outgoing_edges ) != 2 :
94- raise ValueError (f"ConditionalNode '{ node .node_name } ' must have exactly two outgoing edges." )
92+ raise ValueError (f"""ConditionalNode '{ node .node_name } '
93+ must have exactly two outgoing edges.""" )
9594 node .true_node_name = outgoing_edges [0 ][1 ].node_name
9695 try :
9796 node .false_node_name = outgoing_edges [1 ][1 ].node_name
@@ -151,14 +150,14 @@ def _get_schema(self, current_node):
151150 """Extracts schema information from the node configuration."""
152151 if not hasattr (current_node , "node_config" ):
153152 return None
154-
153+
155154 if not isinstance (current_node .node_config , dict ):
156155 return None
157-
156+
158157 schema_config = current_node .node_config .get ("schema" )
159158 if not schema_config or isinstance (schema_config , dict ):
160159 return None
161-
160+
162161 try :
163162 return schema_config .schema ()
164163 except Exception :
@@ -167,7 +166,7 @@ def _get_schema(self, current_node):
167166 def _execute_node (self , current_node , state , llm_model , llm_model_name ):
168167 """Executes a single node and returns execution information."""
169168 curr_time = time .time ()
170-
169+
171170 with self .callback_manager .exclusive_get_callback (llm_model , llm_model_name ) as cb :
172171 result = current_node .execute (state )
173172 node_exec_time = time .time () - curr_time
@@ -197,17 +196,17 @@ def _get_next_node(self, current_node, result):
197196 raise ValueError (
198197 f"Conditional Node returned a node name '{ result } ' that does not exist in the graph"
199198 )
200-
199+
201200 return self .edges .get (current_node .node_name )
202201
203202 def _execute_standard (self , initial_state : dict ) -> Tuple [dict , list ]:
204203 """
205- Executes the graph by traversing nodes starting from the entry point using the standard method.
204+ Executes the graph by traversing nodes
205+ starting from the entry point using the standard method.
206206 """
207207 current_node_name = self .entry_point
208208 state = initial_state
209-
210- # Tracking variables
209+
211210 total_exec_time = 0.0
212211 exec_info = []
213212 cb_total = {
@@ -230,16 +229,13 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
230229
231230 while current_node_name :
232231 current_node = self ._get_node_by_name (current_node_name )
233-
234- # Update source information if needed
232+
235233 if source_type is None :
236234 source_type , source , prompt = self ._update_source_info (current_node , state )
237-
238- # Get model information if needed
235+
239236 if llm_model is None :
240237 llm_model , llm_model_name , embedder_model = self ._get_model_info (current_node )
241-
242- # Get schema if needed
238+
243239 if schema is None :
244240 schema = self ._get_schema (current_node )
245241
@@ -273,7 +269,6 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
273269 )
274270 raise e
275271
276- # Add total results to execution info
277272 exec_info .append ({
278273 "node_name" : "TOTAL RESULT" ,
279274 "total_tokens" : cb_total ["total_tokens" ],
@@ -284,7 +279,6 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
284279 "exec_time" : total_exec_time ,
285280 })
286281
287- # Log final execution results
288282 graph_execution_time = time .time () - start_time
289283 response = state .get ("answer" , None ) if source_type == "url" else None
290284 content = state .get ("parsed_doc" , None ) if response is not None else None
@@ -343,4 +337,3 @@ def append_node(self, node):
343337 self .raw_edges .append ((last_node , node ))
344338 self .nodes .append (node )
345339 self .edges = self ._create_edges ({e for e in self .raw_edges })
346-
0 commit comments