3737
3838openai_client = OpenAI ()
3939
40+ ## TODO: add best_path_update
41+
4042class LATSAgent :
4143 """
4244 Language-based Action Tree Search Agent implementation.
@@ -117,6 +119,13 @@ async def run(self, websocket=None) -> list[LATSNode]:
117119 print_trajectory (best_node )
118120
119121 if websocket :
122+ # trajectory_data = self._get_trajectory_data(best_node)
123+ # await websocket.send_json({
124+ # "type": "trajectory_update",
125+ # "trajectory": trajectory_data,
126+ # "timestamp": datetime.utcnow().isoformat()
127+ # })
128+ # TODO: use score instead of reward to determine success
120129 await websocket .send_json ({
121130 "type" : "search_complete" ,
122131 "status" : "success" if best_node .reward == 1 else "partial_success" ,
@@ -158,12 +167,19 @@ async def lats_search(self, websocket=None) -> LATSNode:
158167 if websocket :
159168 await websocket .send_json ({
160169 "type" : "step_start" ,
161- "step" : "selection" ,
170+ "step" : 1 ,
171+ "step_name" : "selection" ,
162172 "iteration" : i + 1 ,
163173 "timestamp" : datetime .utcnow ().isoformat ()
164174 })
165175
166176 node = self .select_node (self .root_node )
177+ if websocket :
178+ await websocket .send_json ({
179+ "type" : "node_selected" ,
180+ "node_id" : id (node ),
181+ "timestamp" : datetime .utcnow ().isoformat ()
182+ })
167183
168184 if node is None :
169185 print ("All paths lead to terminal nodes with reward 0. Ending search." )
@@ -177,7 +193,8 @@ async def lats_search(self, websocket=None) -> LATSNode:
177193 if websocket :
178194 await websocket .send_json ({
179195 "type" : "step_start" ,
180- "step" : "expansion" ,
196+ "step" : 2 ,
197+ "step_name" : "expansion" ,
181198 "iteration" : i + 1 ,
182199 "timestamp" : datetime .utcnow ().isoformat ()
183200 })
@@ -202,28 +219,65 @@ async def lats_search(self, websocket=None) -> LATSNode:
202219 print (f"{ GREEN } Tree:{ RESET } " )
203220 better_print (self .root_node )
204221 print (f"" )
222+ tree_data = self ._get_tree_data ()
223+ await websocket .send_json ({
224+ "type" : "tree_update" ,
225+ "tree" : tree_data ,
226+ "timestamp" : datetime .utcnow ().isoformat ()
227+ })
205228
206229 # Step 3: Evaluation
207230 print (f"" )
208231 print (f"{ GREEN } Step 3: evaluation{ RESET } " )
232+ if websocket :
233+ await websocket .send_json ({
234+ "type" : "step_start" ,
235+ "step" : 3 ,
236+ "step_name" : "evaluation" ,
237+ "iteration" : i + 1 ,
238+ "timestamp" : datetime .utcnow ().isoformat ()
239+ })
209240 await self .evaluate_node (node )
210241
211242 print (f"{ GREEN } Tree:{ RESET } " )
212243 better_print (self .root_node )
213244 print (f"" )
245+ ## send tree update, since evaluation is added to the tree
246+ if websocket :
247+ tree_data = self ._get_tree_data ()
248+ await websocket .send_json ({
249+ "type" : "tree_update" ,
250+ "tree" : tree_data ,
251+ "timestamp" : datetime .utcnow ().isoformat ()
252+ })
253+
214254
215255 # Step 4: Simulation
216256 print (f"{ GREEN } Step 4: simulation{ RESET } " )
217- # # Find the child with the highest value
257+ if websocket :
258+ await websocket .send_json ({
259+ "type" : "step_start" ,
260+ "step" : 4 ,
261+ "step_name" : "simulation" ,
262+ "iteration" : i + 1 ,
263+ "timestamp" : datetime .utcnow ().isoformat ()
264+ })
218265 ## always = 1
219- reward , terminal_node = await self .simulate (max (node .children , key = lambda child : child .value ), max_depth = self .config .max_depth , num_simulations = 1 )
266+ reward , terminal_node = await self .simulate (max (node .children , key = lambda child : child .value ), max_depth = self .config .max_depth , num_simulations = 1 , websocket = websocket )
220267 terminal_nodes .append (terminal_node )
221268
222269 if reward == 1 :
223270 return terminal_node
224271
225272 # Step 5: Backpropagation
226273 print (f"{ GREEN } Step 5: backpropagation{ RESET } " )
274+ if websocket :
275+ await websocket .send_json ({
276+ "type" : "step_start" ,
277+ "step" : 5 ,
278+ "step_name" : "backpropagation" ,
279+ "timestamp" : datetime .utcnow ().isoformat ()
280+ })
227281 self .backpropagate (terminal_node , reward )
228282 print (f"{ GREEN } Tree:{ RESET } " )
229283 better_print (self .root_node )
@@ -335,7 +389,8 @@ async def evaluate_node(self, node: LATSNode) -> None:
335389 child .value = score
336390 child .reward = score
337391
338- async def simulate (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 ) -> tuple [float , LATSNode ]:
392+ ## TODO: make number of simulations configurable
393+ async def simulate (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 , websocket = None ) -> tuple [float , LATSNode ]:
339394 """
340395 Perform a rollout simulation from a node.
341396
@@ -351,13 +406,39 @@ async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1)
351406 print_trajectory (node )
352407 print ("print the entire tree" )
353408 print_entire_tree (self .root_node )
354- return await self .rollout (node , max_depth = max_depth )
409+ if websocket :
410+ tree_data = self ._get_tree_data ()
411+ await websocket .send_json ({
412+ "type" : "tree_update" ,
413+ "tree" : tree_data ,
414+ "timestamp" : datetime .utcnow ().isoformat ()
415+ })
416+ trajectory_data = self ._get_trajectory_data (node )
417+ await websocket .send_json ({
418+ "type" : "trajectory_update" ,
419+ "trajectory" : trajectory_data ,
420+ "timestamp" : datetime .utcnow ().isoformat ()
421+ })
422+ return await self .rollout (node , max_depth = max_depth , websocket = websocket )
355423
356- async def send_completion_request (self , plan , depth , node , trajectory = []):
424+ async def send_completion_request (self , plan , depth , node , trajectory = [], websocket = None ):
357425 print ("print the trajectory" )
358426 print_trajectory (node )
359427 print ("print the entire tree" )
360428 print_entire_tree (self .root_node )
429+ if websocket :
430+ # tree_data = self._get_tree_data()
431+ # await websocket.send_json({
432+ # "type": "tree_update",
433+ # "tree": tree_data,
434+ # "timestamp": datetime.utcnow().isoformat()
435+ # })
436+ trajectory_data = self ._get_trajectory_data (node )
437+ await websocket .send_json ({
438+ "type" : "trajectory_update" ,
439+ "trajectory" : trajectory_data ,
440+ "timestamp" : datetime .utcnow ().isoformat ()
441+ })
361442
362443 if depth >= self .config .max_depth :
363444 return trajectory , node
@@ -420,20 +501,20 @@ async def send_completion_request(self, plan, depth, node, trajectory=[]):
420501 if goal_finished :
421502 return trajectory , new_node
422503
423- return await self .send_completion_request (plan , depth + 1 , new_node , trajectory )
504+ return await self .send_completion_request (plan , depth + 1 , new_node , trajectory , websocket )
424505
425506 except Exception as e :
426507 print (f"Attempt { attempt + 1 } failed with error: { e } " )
427508 if attempt + 1 == retry_count :
428509 print ("Max retries reached. Skipping this step and retrying the whole request." )
429510 # Retry the entire request from the same state
430- return await self .send_completion_request (plan , depth , node , trajectory )
511+ return await self .send_completion_request (plan , depth , node , trajectory , websocket )
431512
432513 # If all retries and retries of retries fail, return the current trajectory and node
433514 return trajectory , node
434515
435516
436- async def rollout (self , node : LATSNode , max_depth : int = 2 )-> tuple [float , LATSNode ]:
517+ async def rollout (self , node : LATSNode , max_depth : int = 2 , websocket = None )-> tuple [float , LATSNode ]:
437518 # Reset browser state
438519 await self ._reset_browser ()
439520 path = self .get_path_to_root (node )
@@ -467,11 +548,24 @@ async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSN
467548 ## call the prompt agent
468549 print ("current depth: " , len (path ) - 1 )
469550 print ("max depth: " , self .config .max_depth )
470- trajectory , node = await self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory )
551+ trajectory , node = await self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory , websocket = websocket )
471552 print ("print the trajectory" )
472553 print_trajectory (node )
473554 print ("print the entire tree" )
474555 print_entire_tree (self .root_node )
556+ if websocket :
557+ # tree_data = self._get_tree_data()
558+ # await websocket.send_json({
559+ # "type": "tree_update",
560+ # "tree": tree_data,
561+ # "timestamp": datetime.utcnow().isoformat()
562+ # })
563+ trajectory_data = self ._get_trajectory_data (node )
564+ await websocket .send_json ({
565+ "type" : "trajectory_update" ,
566+ "trajectory" : trajectory_data ,
567+ "timestamp" : datetime .utcnow ().isoformat ()
568+ })
475569
476570 page = await self .playwright_manager .get_page ()
477571 page_info = await extract_page_info (page , self .config .fullpage , self .config .log_folder )
@@ -769,8 +863,46 @@ def _get_tree_data(self):
769863 "is_terminal" : node .is_terminal ,
770864 "value" : node .value ,
771865 "visits" : node .visits ,
866+ "feedback" : node .feedback ,
772867 "reward" : node .reward
773868 }
774869 tree_data .append (node_data )
775870
776871 return tree_data
872+
873+ def _get_trajectory_data (self , terminal_node : LATSNode ):
874+ """Get trajectory data in a format suitable for visualization
875+
876+ Args:
877+ terminal_node: The leaf node to start the trajectory from
878+
879+ Returns:
880+ list: List of node data dictionaries representing the trajectory
881+ """
882+ trajectory_data = []
883+ path = []
884+
885+ # Collect path from terminal to root
886+ current = terminal_node
887+ while current is not None :
888+ path .append (current )
889+ current = current .parent
890+
891+ # Process nodes in order from root to terminal
892+ for level , node in enumerate (reversed (path )):
893+ node_data = {
894+ "id" : id (node ),
895+ "level" : level ,
896+ "action" : node .action if node .action else "ROOT" ,
897+ "description" : node .natural_language_description ,
898+ "visits" : node .visits ,
899+ "value" : float (f"{ node .value :.3f} " ) if hasattr (node , 'value' ) else None ,
900+ "reward" : float (f"{ node .reward :.3f} " ) if hasattr (node , 'reward' ) else None ,
901+ "is_terminal" : node .is_terminal ,
902+ "feedback" : node .feedback if hasattr (node , 'feedback' ) else None ,
903+ "is_root" : not hasattr (node , 'parent' ) or node .parent is None ,
904+ "is_terminal_node" : node == terminal_node
905+ }
906+ trajectory_data .append (node_data )
907+
908+ return trajectory_data
0 commit comments