Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ public abstract class BaseAgent {
private final Optional<List<BeforeAgentCallback>> beforeAgentCallback;
private final Optional<List<AfterAgentCallback>> afterAgentCallback;

/**
* Creates a new BaseAgent.
*
* @param name Unique agent name. Cannot be "user" (reserved).
* @param description Agent purpose.
* @param subAgents Agents managed by this agent.
* @param beforeAgentCallback Callbacks before agent execution. Invoked in order until one doesn't
* return null.
* @param afterAgentCallback Callbacks after agent execution. Invoked in order until one doesn't
* return null.
*/
public BaseAgent(
String name,
String description,
Expand Down Expand Up @@ -164,6 +175,13 @@ public Optional<List<AfterAgentCallback>> afterAgentCallback() {
return afterAgentCallback;
}

/**
* Creates a shallow copy of the parent context with the agent properly being set to this
* instance.
*
* @param parentContext Parent context to copy.
* @return new context with updated branch name.
*/
private InvocationContext createInvocationContext(InvocationContext parentContext) {
InvocationContext invocationContext = InvocationContext.copyOf(parentContext);
invocationContext.agent(this);
Expand All @@ -174,6 +192,12 @@ private InvocationContext createInvocationContext(InvocationContext parentContex
return invocationContext;
}

/**
* Runs the agent asynchronously.
*
* @param parentContext Parent context to inherit.
* @return stream of agent-generated events.
*/
public Flowable<Event> runAsync(InvocationContext parentContext) {
Tracer tracer = Telemetry.getTracer();
return Flowable.defer(
Expand Down Expand Up @@ -216,20 +240,39 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
});
}

/**
* Converts before-agent callbacks to functions.
*
* @param callbacks Before-agent callbacks.
* @return callback functions.
*/
private ImmutableList<Function<CallbackContext, Maybe<Content>>> beforeCallbacksToFunctions(
List<BeforeAgentCallback> callbacks) {
return callbacks.stream()
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call)
.collect(toImmutableList());
}

/**
* Converts after-agent callbacks to functions.
*
* @param callbacks After-agent callbacks.
* @return callback functions.
*/
private ImmutableList<Function<CallbackContext, Maybe<Content>>> afterCallbacksToFunctions(
List<AfterAgentCallback> callbacks) {
return callbacks.stream()
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call)
.collect(toImmutableList());
}

/**
* Calls agent callbacks and returns the first produced event, if any.
*
* @param agentCallbacks Callback functions.
* @param invocationContext Current invocation context.
* @return single emitting first event, or empty if none.
*/
private Single<Optional<Event>> callCallback(
List<Function<CallbackContext, Maybe<Content>>> agentCallbacks,
InvocationContext invocationContext) {
Expand Down Expand Up @@ -282,6 +325,12 @@ private Single<Optional<Event>> callCallback(
}));
}

/**
* Runs the agent synchronously.
*
* @param parentContext Parent context to inherit.
* @return stream of agent-generated events.
*/
public Flowable<Event> runLive(InvocationContext parentContext) {
Tracer tracer = Telemetry.getTracer();
return Flowable.defer(
Expand All @@ -295,7 +344,19 @@ public Flowable<Event> runLive(InvocationContext parentContext) {
});
}

/**
* Agent-specific asynchronous logic.
*
* @param invocationContext Current invocation context.
* @return stream of agent-generated events.
*/
protected abstract Flowable<Event> runAsyncImpl(InvocationContext invocationContext);

/**
* Agent-specific synchronous logic.
*
* @param invocationContext Current invocation context.
* @return stream of agent-generated events.
*/
protected abstract Flowable<Event> runLiveImpl(InvocationContext invocationContext);
}
23 changes: 21 additions & 2 deletions core/src/main/java/com/google/adk/agents/CallbackContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ public class CallbackContext extends ReadonlyContext {
protected EventActions eventActions;
private final State state;

/**
* Initializes callback context.
*
* @param invocationContext Current invocation context.
* @param eventActions Callback event actions.
*/
public CallbackContext(InvocationContext invocationContext, EventActions eventActions) {
super(invocationContext);
this.eventActions = eventActions != null ? eventActions : EventActions.builder().build();
Expand All @@ -51,7 +57,14 @@ public EventActions eventActions() {
return eventActions;
}

/** Loads an artifact from the artifact service associated with the current session. */
/**
* Loads an artifact from the artifact service associated with the current session.
*
* @param filename Artifact file name.
* @param version Artifact version (optional).
* @return loaded part, or empty if not found.
* @throws IllegalStateException if the artifact service is not initialized.
*/
public Maybe<Part> loadArtifact(String filename, Optional<Integer> version) {
if (invocationContext.artifactService() == null) {
throw new IllegalStateException("Artifact service is not initialized.");
Expand All @@ -66,7 +79,13 @@ public Maybe<Part> loadArtifact(String filename, Optional<Integer> version) {
version);
}

/** Saves an artifact and records it as a delta for the current session. */
/**
* Saves an artifact and records it as a delta for the current session.
*
* @param filename Artifact file name.
* @param artifact Artifact content to save.
* @throws IllegalStateException if the artifact service is not initialized.
*/
public void saveArtifact(String filename, Part artifact) {
if (invocationContext.artifactService() == null) {
throw new IllegalStateException("Artifact service is not initialized.");
Expand Down
17 changes: 15 additions & 2 deletions core/src/main/java/com/google/adk/agents/CallbackUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
package com.google.adk.agents;

import com.google.adk.agents.Callbacks.AfterAgentCallback;
import com.google.adk.agents.Callbacks.AfterAgentCallbackBase;
import com.google.adk.agents.Callbacks.AfterAgentCallbackSync;
import com.google.adk.agents.Callbacks.BeforeAgentCallback;
import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync;
import com.google.adk.agents.Callbacks.BeforeAgentCallbackBase;
import com.google.adk.agents.Callbacks.AfterAgentCallbackBase;
import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.reactivex.rxjava3.core.Maybe;
Expand All @@ -30,9 +30,16 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Utility methods for normalizing agent callbacks. */
public final class CallbackUtil {
private static final Logger logger = LoggerFactory.getLogger(CallbackUtil.class);

/**
* Normalizes before-agent callbacks.
*
* @param beforeAgentCallback Callback list (sync or async).
* @return normalized async callbacks, or null if input is null.
*/
@CanIgnoreReturnValue
public static @Nullable ImmutableList<BeforeAgentCallback> getBeforeAgentCallbacks(
List<BeforeAgentCallbackBase> beforeAgentCallback) {
Expand Down Expand Up @@ -60,6 +67,12 @@ public final class CallbackUtil {
}
}

/**
* Normalizes after-agent callbacks.
*
* @param afterAgentCallback Callback list (sync or async).
* @return normalized async callbacks, or null if input is null.
*/
@CanIgnoreReturnValue
public static @Nullable ImmutableList<AfterAgentCallback> getAfterAgentCallbacks(
List<AfterAgentCallbackBase> afterAgentCallback) {
Expand Down
82 changes: 70 additions & 12 deletions core/src/main/java/com/google/adk/agents/Callbacks.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,27 @@
import java.util.Map;
import java.util.Optional;

/** Functional interfaces for agent lifecycle callbacks. */
public final class Callbacks {

interface BeforeModelCallbackBase {}

@FunctionalInterface
public interface BeforeModelCallback extends BeforeModelCallbackBase {
/**
* Async callback before LLM invocation.
*
* @param callbackContext Callback context.
* @param llmRequest LLM request.
* @return response override, or empty to continue.
*/
Maybe<LlmResponse> call(CallbackContext callbackContext, LlmRequest llmRequest);
}

// Helper interface to allow for sync beforeModelCallback. The function is wrapped into an async
// one before being processed further.
/**
* Helper interface to allow for sync beforeModelCallback. The function is wrapped into an async
* one before being processed further.
*/
@FunctionalInterface
public interface BeforeModelCallbackSync extends BeforeModelCallbackBase {
Optional<LlmResponse> call(CallbackContext callbackContext, LlmRequest llmRequest);
Expand All @@ -45,11 +55,20 @@ interface AfterModelCallbackBase {}

@FunctionalInterface
public interface AfterModelCallback extends AfterModelCallbackBase {
/**
* Async callback after LLM response.
*
* @param callbackContext Callback context.
* @param llmResponse LLM response.
* @return modified response, or empty to keep original.
*/
Maybe<LlmResponse> call(CallbackContext callbackContext, LlmResponse llmResponse);
}

// Helper interface to allow for sync afterModelCallback. The function is wrapped into an async
// one before being processed further.
/**
* Helper interface to allow for sync afterModelCallback. The function is wrapped into an async
* one before being processed further.
*/
@FunctionalInterface
public interface AfterModelCallbackSync extends AfterModelCallbackBase {
Optional<LlmResponse> call(CallbackContext callbackContext, LlmResponse llmResponse);
Expand All @@ -59,11 +78,19 @@ interface BeforeAgentCallbackBase {}

@FunctionalInterface
public interface BeforeAgentCallback extends BeforeAgentCallbackBase {
/**
* Async callback before agent runs.
*
* @param callbackContext Callback context.
* @return content override, or empty to continue.
*/
Maybe<Content> call(CallbackContext callbackContext);
}

// Helper interface to allow for sync beforeAgentCallback. The function is wrapped into an async
// one before being processed further.
/**
* Helper interface to allow for sync beforeAgentCallback. The function is wrapped into an async
* one before being processed further.
*/
@FunctionalInterface
public interface BeforeAgentCallbackSync extends BeforeAgentCallbackBase {
Optional<Content> call(CallbackContext callbackContext);
Expand All @@ -73,11 +100,19 @@ interface AfterAgentCallbackBase {}

@FunctionalInterface
public interface AfterAgentCallback extends AfterAgentCallbackBase {
/**
* Async callback after agent runs.
*
* @param callbackContext Callback context.
* @return modified content, or empty to keep original.
*/
Maybe<Content> call(CallbackContext callbackContext);
}

// Helper interface to allow for sync afterAgentCallback. The function is wrapped into an async
// one before being processed further.
/**
* Helper interface to allow for sync afterAgentCallback. The function is wrapped into an async
* one before being processed further.
*/
@FunctionalInterface
public interface AfterAgentCallbackSync extends AfterAgentCallbackBase {
Optional<Content> call(CallbackContext callbackContext);
Expand All @@ -87,15 +122,26 @@ interface BeforeToolCallbackBase {}

@FunctionalInterface
public interface BeforeToolCallback extends BeforeToolCallbackBase {
/**
* Async callback before tool runs.
*
* @param invocationContext Invocation context.
* @param baseTool Tool instance.
* @param input Tool input arguments.
* @param toolContext Tool context.
* @return override result, or empty to continue.
*/
Maybe<Map<String, Object>> call(
InvocationContext invocationContext,
BaseTool baseTool,
Map<String, Object> input,
ToolContext toolContext);
}

// Helper interface to allow for sync beforeToolCallback. The function is wrapped into an async
// one before being processed further.
/**
* Helper interface to allow for sync beforeToolCallback. The function is wrapped into an async
* one before being processed further.
*/
@FunctionalInterface
public interface BeforeToolCallbackSync extends BeforeToolCallbackBase {
Optional<Map<String, Object>> call(
Expand All @@ -109,6 +155,16 @@ interface AfterToolCallbackBase {}

@FunctionalInterface
public interface AfterToolCallback extends AfterToolCallbackBase {
/**
* Async callback after tool runs.
*
* @param invocationContext Invocation context.
* @param baseTool Tool instance.
* @param input Tool input arguments.
* @param toolContext Tool context.
* @param response Raw tool response.
* @return processed result, or empty to keep original.
*/
Maybe<Map<String, Object>> call(
InvocationContext invocationContext,
BaseTool baseTool,
Expand All @@ -117,8 +173,10 @@ Maybe<Map<String, Object>> call(
Object response);
}

// Helper interface to allow for sync afterToolCallback. The function is wrapped into an async
// one before being processed further.
/**
* Helper interface to allow for sync afterToolCallback. The function is wrapped into an async one
* before being processed further.
*/
@FunctionalInterface
public interface AfterToolCallbackSync extends AfterToolCallbackBase {
Optional<Map<String, Object>> call(
Expand Down
11 changes: 9 additions & 2 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -761,14 +761,21 @@ public Model resolvedModel() {
return resolvedModel;
}

/**
* Resolves the model for this agent, checking first if it is defined locally, then searching
* through ancestors.
*
* <p>This method is only for use by Agent Development Kit.
*
* @return The resolved {@link Model} for this agent.
* @throws IllegalStateException if no model is found for this agent or its ancestors.
*/
private Model resolveModelInternal() {
// 1. Check if the model is defined locally for this agent.
if (this.model.isPresent()) {
if (this.model().isPresent()) {
return this.model.get();
}
}
// 2. If not defined locally, search ancestors.
BaseAgent current = this.parentAgent();
while (current != null) {
if (current instanceof LlmAgent) {
Expand Down
Loading