Skip to content

Commit cfdb958

Browse files
Merge pull request #489 from vorburger:FunctionTool_🍝
PiperOrigin-RevId: 823581493
2 parents 091275f + 639b04a commit cfdb958

File tree

1 file changed

+32
-62
lines changed

1 file changed

+32
-62
lines changed

core/src/main/java/com/google/adk/tools/FunctionTool.java

Lines changed: 32 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -185,61 +185,7 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
185185
@SuppressWarnings("unchecked") // For tool parameter type casting.
186186
private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext toolContext)
187187
throws IllegalAccessException, InvocationTargetException {
188-
Parameter[] parameters = func.getParameters();
189-
Object[] arguments = new Object[parameters.length];
190-
for (int i = 0; i < parameters.length; i++) {
191-
String paramName =
192-
parameters[i].isAnnotationPresent(Annotations.Schema.class)
193-
&& !parameters[i].getAnnotation(Annotations.Schema.class).name().isEmpty()
194-
? parameters[i].getAnnotation(Annotations.Schema.class).name()
195-
: parameters[i].getName();
196-
if ("toolContext".equals(paramName)) {
197-
arguments[i] = toolContext;
198-
continue;
199-
}
200-
if ("inputStream".equals(paramName)) {
201-
arguments[i] = null;
202-
continue;
203-
}
204-
Annotations.Schema schema = parameters[i].getAnnotation(Annotations.Schema.class);
205-
if (!args.containsKey(paramName)) {
206-
if (schema != null && schema.optional()) {
207-
arguments[i] = null;
208-
continue;
209-
} else {
210-
throw new IllegalArgumentException(
211-
String.format(
212-
"The parameter '%s' was not found in the arguments provided by the model.",
213-
paramName));
214-
}
215-
}
216-
Class<?> paramType = parameters[i].getType();
217-
Object argValue = args.get(paramName);
218-
if (paramType.equals(List.class)) {
219-
if (argValue instanceof List) {
220-
Type type =
221-
((ParameterizedType) parameters[i].getParameterizedType())
222-
.getActualTypeArguments()[0];
223-
Class<?> typeArgClass;
224-
if (type instanceof Class) {
225-
// Case 1: The argument is a simple class like String, Integer, etc.
226-
typeArgClass = (Class<?>) type;
227-
} else if (type instanceof ParameterizedType pType) {
228-
// Case 2: The argument is another parameterized type like Map<String, Integer>
229-
typeArgClass = (Class<?>) pType.getRawType(); // Get the raw class (e.g., Map)
230-
} else {
231-
throw new IllegalArgumentException(
232-
String.format("Unsupported parameterized type %s for '%s'", type, paramName));
233-
}
234-
arguments[i] = createList((List<Object>) argValue, typeArgClass);
235-
continue;
236-
}
237-
} else if (argValue instanceof Map) {
238-
arguments[i] = OBJECT_MAPPER.convertValue(argValue, paramType);
239-
continue;
240-
}
241-
arguments[i] = castValue(argValue, paramType);
242-
}
188+
Object[] arguments = buildArguments(args, toolContext, null);
243189
Object result = func.invoke(instance, arguments);
244190
if (result == null) {
245191
return Maybe.empty();
@@ -263,6 +209,21 @@ private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext to
263209
public Flowable<Map<String, Object>> callLive(
264210
Map<String, Object> args, ToolContext toolContext, InvocationContext invocationContext)
265211
throws IllegalAccessException, InvocationTargetException {
212+
Object[] arguments = buildArguments(args, toolContext, invocationContext);
213+
Object result = func.invoke(instance, arguments);
214+
if (result instanceof Flowable) {
215+
return (Flowable<Map<String, Object>>) result;
216+
} else {
217+
throw new IllegalArgumentException(
218+
"callLive was called but the underlying function does not return a Flowable.");
219+
}
220+
}
221+
222+
@SuppressWarnings("unchecked") // For tool parameter type casting.
223+
private Object[] buildArguments(
224+
Map<String, Object> args,
225+
ToolContext toolContext,
226+
@Nullable InvocationContext invocationContext) {
266227
Parameter[] parameters = func.getParameters();
267228
Object[] arguments = new Object[parameters.length];
268229
for (int i = 0; i < parameters.length; i++) {
@@ -276,7 +237,8 @@ public Flowable<Map<String, Object>> callLive(
276237
continue;
277238
}
278239
if ("inputStream".equals(paramName)) {
279-
if (invocationContext.activeStreamingTools().containsKey(this.name())
240+
if (invocationContext != null
241+
&& invocationContext.activeStreamingTools().containsKey(this.name())
280242
&& invocationContext.activeStreamingTools().get(this.name()).stream() != null) {
281243
arguments[i] = invocationContext.activeStreamingTools().get(this.name()).stream();
282244
} else {
@@ -303,7 +265,8 @@ public Flowable<Map<String, Object>> callLive(
303265
Type type =
304266
((ParameterizedType) parameters[i].getParameterizedType())
305267
.getActualTypeArguments()[0];
306-
arguments[i] = createList((List<Object>) argValue, (Class) type);
268+
Class<?> typeArgClass = getTypeClass(type, paramName);
269+
arguments[i] = createList((List<Object>) argValue, typeArgClass);
307270
continue;
308271
}
309272
} else if (argValue instanceof Map) {
@@ -312,12 +275,19 @@ public Flowable<Map<String, Object>> callLive(
312275
}
313276
arguments[i] = castValue(argValue, paramType);
314277
}
315-
Object result = func.invoke(instance, arguments);
316-
if (result instanceof Flowable) {
317-
return (Flowable<Map<String, Object>>) result;
278+
return arguments;
279+
}
280+
281+
private static Class<?> getTypeClass(Type type, String paramName) {
282+
if (type instanceof Class) {
283+
// Case 1: The argument is a simple class like String, Integer, etc.
284+
return (Class<?>) type;
285+
} else if (type instanceof ParameterizedType pType) {
286+
// Case 2: The argument is another parameterized type like Map<String, Integer>
287+
return (Class<?>) pType.getRawType(); // Get the raw class (e.g., Map)
318288
} else {
319-
logger.warn("callLive was called but the underlying function does not return a Flowable.");
320-
return Flowable.empty();
289+
throw new IllegalArgumentException(
290+
String.format("Unsupported parameterized type %s for '%s'", type, paramName));
321291
}
322292
}
323293

0 commit comments

Comments
 (0)