Skip to content

Commit 8d2cfaf

Browse files
marklysze and victordibia
authored and committed
should_hide_tools function added to client_utils (#2966)
1 parent 6ed662b commit 8d2cfaf

File tree

2 files changed

+230
-2
lines changed

2 files changed

+230
-2
lines changed

autogen/oai/client_utils.py

+54
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,57 @@ def validate_parameter(
9797
param_value = default_value
9898

9999
return param_value
100+
101+
102+
def should_hide_tools(messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], hide_tools_param: str) -> bool:
    """
    Determines if tools should be hidden. This function is used to hide tools when they have been run, minimising the chance of the LLM choosing them when they shouldn't.

    Parameters:
        messages (List[Dict[str, Any]]): List of messages
        tools (List[Dict[str, Any]]): List of tools
        hide_tools_param (str): "hide_tools" parameter value. Can be "if_all_run" (hide tools if all tools have been run), "if_any_run" (hide tools if any of the tools have been run), "never" (never hide tools). Default is "never".

    Returns:
        bool: Indicates whether the tools should be excluded from the response create request

    Raises:
        TypeError: If hide_tools_param is not one of "if_all_run", "if_any_run" or "never".
            The value is only checked when at least one tool is supplied; with no tools
            the function returns False before looking at it.

    Example Usage:
    ```python
        # Deciding whether to drop the tools from a request
        messages = params.get("messages", [])
        tools = params.get("tools", None)
        hide_tools = should_hide_tools(messages, tools, params["hide_tools"])
    ```
    """

    # Nothing to hide when hiding is disabled or no tools were supplied (None or empty).
    if hide_tools_param == "never" or not tools:
        return False

    if hide_tools_param == "if_any_run":
        # A message carrying a tool_call_id is a tool response, i.e. a tool call has been executed.
        return any("tool_call_id" in message for message in messages)

    if hide_tools_param == "if_all_run":
        # Names of the tools that have not yet been seen to run. A set means a name
        # listed more than once in `tools` only needs to be run once.
        pending_tool_names = {tool["function"]["name"] for tool in tools}

        # Maps each requested tool call id to the name of the function it requested.
        requested_names_by_id: Dict[str, str] = {}

        for message in messages:
            tool_calls = message.get("tool_calls")
            if tool_calls:
                # Register every call in the message, not just the first one, so that
                # responses to parallel tool calls can be resolved to a name.
                for tool_call in tool_calls:
                    requested_names_by_id[tool_call["id"]] = tool_call["function"]["name"]
            elif "tool_call_id" in message:
                # Tool response: mark the matching tool as run. A response whose request
                # is not in the history resolves to None and is simply ignored.
                pending_tool_names.discard(requested_names_by_id.get(message["tool_call_id"]))

        # True once every tool has been run at least once.
        return not pending_tool_names

    raise TypeError(
        f"hide_tools_param is not a valid value ['if_all_run','if_any_run','never'], got '{hide_tools_param}'"
    )

test/oai/test_client_utils.py

+176-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
import autogen
6-
from autogen.oai.client_utils import validate_parameter
6+
from autogen.oai.client_utils import should_hide_tools, validate_parameter
77

88

99
def test_validate_parameter():
@@ -132,5 +132,179 @@ def test_validate_parameter():
132132
assert validate_parameter({}, "max_tokens", int, True, 512, (0, None), None) == 512
133133

134134

135+
def test_should_hide_tools():
    """Exercise should_hide_tools for each hide_tools mode against chess-style tool-call histories."""

    def tool_request(call_id, name, arguments="{}"):
        # Assistant message asking for a single tool call.
        return {
            "tool_calls": [
                {
                    "id": call_id,
                    "function": {"arguments": arguments, "name": name},
                    "type": "function",
                }
            ],
            "content": None,
            "role": "assistant",
        }

    def tool_result(call_id, content):
        # Message carrying the result of the tool call with the given id.
        return {"tool_call_id": call_id, "role": "user", "content": content}

    # Shared conversation pieces
    system_message = {"content": "You are a chess program and are playing for player white.", "role": "system"}
    user_message = {"content": "Let's play chess! Make a move.", "role": "user"}
    legal_moves = "Possible moves are: g1h3,g1f3,b1c3,b1a3,h2h3,g2g3,f2f3,e2e3,d2d3,c2c3,b2b3,a2a3,h2h4,g2g4,f2f4,e2e4,d2d4,c2c4,b2b4,a2a4"

    first_legal_moves_id = "call_abcde56o5jlugh9uekgo84c6"
    second_legal_moves_id = "call_p1fla56o5jlugh9uekgo84c6"
    make_move_id = "call_lcow1j0ehuhrcr3aakdmd9ju"
    make_move_arguments = '{"move":"g1f3"}'

    # Tool calls were requested but none has a result yet
    no_tools_called_messages = [
        system_message,
        user_message,
        tool_request(first_legal_moves_id, "get_legal_moves"),
        tool_request(second_legal_moves_id, "get_legal_moves"),
        tool_request(make_move_id, "make_move", make_move_arguments),
    ]

    # Only get_legal_moves has a result
    one_tool_called_messages = [
        system_message,
        user_message,
        tool_request(first_legal_moves_id, "get_legal_moves"),
        tool_result(first_legal_moves_id, legal_moves),
        tool_request(make_move_id, "make_move", make_move_arguments),
    ]

    # Both tools have results (get_legal_moves twice, make_move once)
    messages = [
        system_message,
        user_message,
        tool_request(first_legal_moves_id, "get_legal_moves"),
        tool_result(first_legal_moves_id, legal_moves),
        tool_request(second_legal_moves_id, "get_legal_moves"),
        tool_result(second_legal_moves_id, legal_moves),
        tool_request(make_move_id, "make_move", make_move_arguments),
        tool_result(make_move_id, "Moved knight (♘) from g1 to f3."),
    ]

    # Tool definitions
    no_tools = []
    all_tools = [
        {
            "type": "function",
            "function": {
                "description": "Call this tool to make a move after you have the list of legal moves.",
                "name": "make_move",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "move": {"type": "string", "description": "A move in UCI format. (e.g. e2e4 or e7e5 or e7e8q)"}
                    },
                    "required": ["move"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "description": "Call this tool to make a move after you have the list of legal moves.",
                "name": "get_legal_moves",
                "parameters": {"type": "object", "properties": {}, "required": []},
            },
        },
    ]

    # No tools supplied: should not hide for any hide_tools value
    assert not should_hide_tools(messages, no_tools, "if_all_run")
    assert not should_hide_tools(messages, no_tools, "if_any_run")
    assert not should_hide_tools(messages, no_tools, "never")

    # Has run tools but never hide, should be false
    assert not should_hide_tools(messages, all_tools, "never")

    # Has run tools, should be true if all or any
    assert should_hide_tools(messages, all_tools, "if_all_run")
    assert should_hide_tools(messages, all_tools, "if_any_run")

    # Hasn't run any tools, should be false for all
    assert not should_hide_tools(no_tools_called_messages, all_tools, "if_all_run")
    assert not should_hide_tools(no_tools_called_messages, all_tools, "if_any_run")
    assert not should_hide_tools(no_tools_called_messages, all_tools, "never")

    # Has run one of the two tools, should be true only for 'if_any_run'
    assert not should_hide_tools(one_tool_called_messages, all_tools, "if_all_run")
    assert should_hide_tools(one_tool_called_messages, all_tools, "if_any_run")
    assert not should_hide_tools(one_tool_called_messages, all_tools, "never")

    # Parameter validation: the call itself must raise, so no assert wraps it
    with pytest.raises(TypeError):
        should_hide_tools(one_tool_called_messages, all_tools, "not_a_valid_value")
306+
307+
135308
if __name__ == "__main__":
    # Run every test in this module when the file is executed directly.
    test_validate_parameter()
    test_should_hide_tools()

0 commit comments

Comments
 (0)