Skip to content

Commit 6e0b25c

Browse files
authored
Merge pull request #86 from PathOnAI/search-algorithm-cleanup
Search algorithm cleanup
2 parents a6a814f + 3f04b54 commit 6e0b25c

File tree

9 files changed

+50
-33
lines changed

9 files changed

+50
-33
lines changed

visual-tree-search-app/components/LATSVisual.tsx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ interface TreeNode {
1212
value?: number;
1313
visits?: number;
1414
feedback?: string;
15-
reward?: number;
15+
// reward?: number;
1616
isSimulated?: boolean; // Flag to track newly simulated nodes
1717
}
1818

@@ -412,9 +412,9 @@ const LATSVisual: React.FC<SimpleSearchVisualProps> = ({ messages }) => {
412412
}
413413

414414
// Add reward info if available
415-
if (typeof d.data.reward === 'number') {
416-
tooltipContent += `<div class="mt-1">Reward: <span class="font-bold">${d.data.reward.toFixed(2)}</span></div>`;
417-
}
415+
// if (typeof d.data.reward === 'number') {
416+
// tooltipContent += `<div class="mt-1">Reward: <span class="font-bold">${d.data.reward.toFixed(2)}</span></div>`;
417+
// }
418418

419419
// Add value info if available
420420
if (typeof d.data.value === 'number') {

visual-tree-search-app/components/MessageLogPanelLATS.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ interface ParsedMessage {
7171
node_id?: string;
7272
value?: number;
7373
visits?: number;
74-
reward?: number;
7574
terminal_node_description?: string;
75+
reward?: number;
7676
step?: number;
7777
step_name?: string;
7878
iteration?: number;

visual-tree-search-app/components/SimpleSearchVisual.tsx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ interface TreeNode {
1212
value?: number;
1313
visits?: number;
1414
feedback?: string;
15-
reward?: number;
15+
// reward?: number;
1616
}
1717

1818
interface Message {
@@ -277,10 +277,10 @@ const SimpleSearchVisual: React.FC<SimpleSearchVisualProps> = ({ messages }) =>
277277
tooltipContent += `<div class="mt-2">${nodeInfo.join(' | ')}</div>`;
278278
}
279279

280-
// Add reward info if available
281-
if (typeof d.data.reward === 'number') {
282-
tooltipContent += `<div class="mt-1">Reward: <span class="font-bold">${d.data.reward.toFixed(2)}</span></div>`;
283-
}
280+
// // Add reward info if available
281+
// if (typeof d.data.reward === 'number') {
282+
// tooltipContent += `<div class="mt-1">Reward: <span class="font-bold">${d.data.reward.toFixed(2)}</span></div>`;
283+
// }
284284

285285
// Add value info if available
286286
if (typeof d.data.value === 'number') {

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/base_agent.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _get_tree_data(self):
102102
"value": node.value,
103103
"visits": node.visits,
104104
"feedback": node.feedback,
105-
"reward": node.reward
105+
# "reward": node.reward
106106
}
107107
tree_data.append(node_data)
108108

@@ -129,7 +129,7 @@ async def remove_simulated_trajectory(self, starting_node, terminal_node: LATSNo
129129
"description": node.natural_language_description,
130130
"visits": node.visits,
131131
"value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None,
132-
"reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
132+
# "reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
133133
"is_terminal": node.is_terminal,
134134
"feedback": node.feedback if hasattr(node, 'feedback') else None,
135135
"is_root": not hasattr(node, 'parent') or node.parent is None,
@@ -159,7 +159,7 @@ def _get_trajectory_data(self, terminal_node: LATSNode):
159159
"description": node.natural_language_description,
160160
"visits": node.visits,
161161
"value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None,
162-
"reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
162+
# "reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
163163
"is_terminal": node.is_terminal,
164164
"feedback": node.feedback if hasattr(node, 'feedback') else None,
165165
"is_root": not hasattr(node, 'parent') or node.parent is None,
@@ -424,15 +424,18 @@ async def node_children_evaluation(self, node: LATSNode) -> None:
424424
score = 0
425425
else:
426426
trajectory = child.get_trajectory()
427-
prompt = create_llm_prompt(trajectory, self.goal)
428-
# , child.observation.image
429-
result = score_trajectory_with_openai(prompt, openai_client, self.config.evaluation_model)
430-
score = result["overall_score"]
427+
if len(trajectory) == 0:
428+
score = 0
429+
else:
430+
prompt = create_llm_prompt(trajectory, self.goal)
431+
# , child.observation.image
432+
result = score_trajectory_with_openai(prompt, openai_client, self.config.evaluation_model)
433+
score = result["overall_score"]
431434
scores.append(score)
432435

433436
for child, score in zip(node.children, scores):
434437
child.value = score
435-
child.reward = score
438+
# child.reward = score
436439

437440
async def node_evaluation(self, node: LATSNode) -> None:
438441
"""Evaluate the current node and assign its score."""
@@ -454,13 +457,16 @@ async def node_evaluation(self, node: LATSNode) -> None:
454457
if node.is_terminal:
455458
score = 0
456459
else:
457-
prompt = create_llm_prompt(trajectory, self.goal)
458-
result = score_trajectory_with_openai(
459-
prompt,
460-
openai_client,
461-
model=self.config.evaluation_model
462-
)
463-
score = result["overall_score"]
460+
if len(trajectory) == 0:
461+
score = 0
462+
else:
463+
prompt = create_llm_prompt(trajectory, self.goal)
464+
result = score_trajectory_with_openai(
465+
prompt,
466+
openai_client,
467+
model=self.config.evaluation_model
468+
)
469+
score = result["overall_score"]
464470

465471
except Exception as e:
466472
error_msg = f"Error scoring node {id(node)}: {str(e)}"
@@ -469,7 +475,7 @@ async def node_evaluation(self, node: LATSNode) -> None:
469475

470476
# Assign the score to the node
471477
node.value = score
472-
node.reward = score
478+
# node.reward = score
473479

474480

475481
except Exception as e:

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ async def lats_search(self, websocket=None):
3737
await self.websocket_node_selection(node, websocket=websocket)
3838

3939
if node is None:
40-
print("All paths lead to terminal nodes with reward 0. Ending search.")
40+
print("All paths lead to terminal nodes with value 0. Ending search.")
4141
break
4242

4343
# Step 2: Node Expansion
@@ -76,8 +76,10 @@ async def lats_search(self, websocket=None):
7676
terminal_nodes.append(terminal_node)
7777
await self.websocket_simulation_result(reward, terminal_node, websocket=websocket)
7878

79-
if reward == 1:
79+
# simulation score threshold
80+
if reward >= self.config.simulation_score:
8081
await self.websocket_search_complete("success", reward, terminal_node.get_trajectory(), websocket=websocket)
82+
await self.playwright_manager.close()
8183
return terminal_node
8284

8385
# Step 5: Backpropagation
@@ -95,8 +97,8 @@ async def lats_search(self, websocket=None):
9597
all_nodes_list = collect_all_nodes(self.root_node)
9698
all_nodes_list.extend(terminal_nodes)
9799

98-
## temp change: if reward is the same, choose the deeper node
99-
best_child = max(all_nodes_list, key=lambda x: (x.reward, x.depth))
100+
## temp change: if value is the same, choose the deeper node
101+
best_child = max(all_nodes_list, key=lambda x: (x.value, x.depth))
100102

101103
if best_child.value >= 0.75:
102104
print("Successful trajectory found")

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888
self.value = 0.0
8989
self.depth = 0 if parent is None else parent.depth + 1
9090
self.is_terminal = False
91-
self.reward = 0.0
91+
# self.reward = 0.0
9292
self.exhausted = False # If all children are terminal
9393
self.em = 0.0 # Exact match, evaluation metric
9494
self.observation: Optional[Observation] = None
@@ -177,7 +177,7 @@ def to_dict(self) -> dict:
177177
'value': self.value,
178178
'depth': self.depth,
179179
'is_terminal': self.is_terminal,
180-
'reward': self.reward,
180+
# 'reward': self.reward,
181181
'em': self.em,
182182
}
183183

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
302302
# Convert path to serializable trajectory
303303
# trajectory = [node.action for node in path if node.action is not None]
304304
await self.websocket_search_complete("success", score, selected_node.get_trajectory(), websocket=websocket)
305+
await self.playwright_manager.close()
305306
return selected_node
306307

307308
print(f"path: {path}")
@@ -328,4 +329,5 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
328329
# Convert node to serializable trajectory
329330
# trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
330331
await self.websocket_search_complete("partial_success", best_node.value, best_node.get_trajectory(), websocket=websocket)
332+
await self.playwright_manager.close()
331333
return best_node

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ async def bfs(self, websocket=None):
110110

111111
# Send completion update if websocket is provided
112112
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=websocket)
113+
await self.playwright_manager.close()
113114

114115
return current_node
115116

@@ -120,6 +121,7 @@ async def bfs(self, websocket=None):
120121

121122
# Send completion update if websocket is provided
122123
await self.websocket_search_complete("partial_success", best_score, best_node.get_trajectory(), websocket=websocket)
124+
await self.playwright_manager.close()
123125

124126
return best_node
125127

@@ -128,6 +130,7 @@ async def bfs(self, websocket=None):
128130

129131
# Send failure update if websocket is provided
130132
await self.websocket_search_complete("failure", 0, None, websocket=websocket)
133+
await self.playwright_manager.close()
131134

132135
return None
133136

@@ -209,7 +212,8 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
209212
print(f"Found satisfactory solution with score {score}")
210213

211214
# Send completion update if websocket is provided
212-
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=websocket)
215+
await self.websocket_search_complete("success", score, current_node.get_trajectory(), websocket=websocket)
216+
await self.playwright_manager.close()
213217
return current_node
214218

215219
# Add non-terminal children to stack in reverse order
@@ -234,6 +238,7 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
234238

235239
# Send completion update if websocket is provided
236240
await self.websocket_search_complete("partial_success", best_score, best_node.get_trajectory(), websocket=websocket)
241+
await self.playwright_manager.close()
237242

238243
return best_node
239244

@@ -242,6 +247,7 @@ async def dfs(self, websocket=None) -> List[Dict[str, Any]]:
242247

243248
# Send failure update if websocket is provided
244249
await self.websocket_search_complete("failure", 0, None, websocket=websocket)
250+
await self.playwright_manager.close()
245251

246252
return None
247253

visual-tree-search-backend/app/api/lwats/core_async/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class AgentConfig:
2525
num_simulations: int = 1
2626
account_reset: bool = True
2727

28+
simulation_score: float = 0.75
2829
reflection_score: float = 0.75
2930

3031
# Features

0 commit comments

Comments
 (0)