Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -440,4 +440,116 @@ describe('executeAsReasoningAgent', () => {
// Without maxDurationMs we are not forced to complete, so toolChoice is auto (continuation)
expect(secondCall.toolChoice).toEqual('auto');
});

test('throws when abortSignal is already aborted before starting', async () => {
const prompt = makePrompt();
const inferenceClient = {
prompt: jest.fn().mockResolvedValue({ content: 'final', toolCalls: [], tokens: 1 }),
} as Partial<jest.Mocked<BoundInferenceClient>> as jest.Mocked<BoundInferenceClient>;

const abortController = new AbortController();
abortController.abort();

await expect(
executeAsReasoningAgent({
inferenceClient,
prompt,
maxSteps: 1,
abortSignal: abortController.signal,
toolCallbacks: { fetch_data: jest.fn(), complete: jest.fn() },
input: { foo: '' },
})
).rejects.toThrow('Request was aborted');

expect(inferenceClient.prompt).not.toHaveBeenCalled();
});

test('throws when abortSignal is aborted between steps', async () => {
const prompt = makePrompt();
const abortController = new AbortController();

let callCount = 0;
const inferenceClient = {
prompt: jest.fn().mockImplementation(() => {
callCount++;
if (callCount === 1) {
return {
content: 'call tool',
toolCalls: [
{
type: 'function',
function: { name: 'fetch_data', arguments: {} },
toolCallId: 'x',
},
],
tokens: 1,
};
}
abortController.abort();
return { content: 'final', toolCalls: [], tokens: 1 };
}),
} as Partial<jest.Mocked<BoundInferenceClient>> as jest.Mocked<BoundInferenceClient>;

const fetchData = jest.fn().mockResolvedValue({ response: { result: 'ok' } });

await expect(
executeAsReasoningAgent({
inferenceClient,
prompt,
maxSteps: 3,
abortSignal: abortController.signal,
toolCallbacks: { fetch_data: fetchData, complete: jest.fn() },
input: { foo: '' },
})
).rejects.toThrow('Request was aborted');

expect(callCount).toBeGreaterThanOrEqual(1);
});

test('skips tool callbacks when abortSignal is already aborted', async () => {
const prompt = makePrompt();
const abortController = new AbortController();
abortController.abort();

const inferenceClient = {
prompt: jest.fn().mockResolvedValue({ content: 'final', toolCalls: [], tokens: 1 }),
} as Partial<jest.Mocked<BoundInferenceClient>> as jest.Mocked<BoundInferenceClient>;

const fetchData = jest.fn();

await expect(
executeAsReasoningAgent({
inferenceClient,
prompt,
maxSteps: 2,
abortSignal: abortController.signal,
toolCallbacks: { fetch_data: fetchData, complete: jest.fn() },
input: { foo: '' },
})
).rejects.toThrow('Request was aborted');

expect(fetchData).not.toHaveBeenCalled();
expect(inferenceClient.prompt).not.toHaveBeenCalled();
});

test('forwards abortSignal to inferenceClient.prompt calls', async () => {
const prompt = makePrompt();
const abortController = new AbortController();

const inferenceClient = {
prompt: jest.fn().mockResolvedValue({ content: 'final', toolCalls: [], tokens: 1 }),
} as Partial<jest.Mocked<BoundInferenceClient>> as jest.Mocked<BoundInferenceClient>;

await executeAsReasoningAgent({
inferenceClient,
prompt,
maxSteps: 1,
abortSignal: abortController.signal,
toolCallbacks: { fetch_data: jest.fn(), complete: jest.fn() },
input: { foo: '' },
});

const callArgs = inferenceClient.prompt.mock.calls[0][0];
expect(callArgs.abortSignal).toBe(abortController.signal);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ export async function executeAsReasoningAgent(
maxSteps = 10,
power = 'medium',
toolCallbacks,
abortSignal,
} = options;
const startTime = Date.now();

Expand All @@ -104,6 +105,15 @@ export async function executeAsReasoningAgent(
throw new Error(`Unexpected planning tool call ${toolCall.function.name}`);
}

if (abortSignal?.aborted) {
return {
response: { error: new Error('Request was aborted'), data: undefined },
name: toolCall.function.name,
toolCallId: toolCall.toolCallId,
role: MessageRole.Tool,
};
}

const callback = toolCallbacks[toolCall.function.name];

const response = await withExecuteToolSpan(
Expand Down Expand Up @@ -142,6 +152,10 @@ export async function executeAsReasoningAgent(
stepsLeft: number;
temperature?: number;
}): Promise<ReasoningPromptResponse> {
if (abortSignal?.aborted) {
throw new Error('Request was aborted');
}

const lastAssistantMessage = givenMessages.findLast(
(msg): msg is AssistantMessage => msg.role === MessageRole.Assistant
);
Expand Down Expand Up @@ -241,6 +255,7 @@ export async function executeAsReasoningAgent(
stream: false,
temperature,
toolChoice,
abortSignal,
prevMessages: formatMessages({
messages: prevMessages,
power,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ export interface ReasoningPromptOptions {
maxSteps?: number;
prevMessages?: undefined;
power?: ReasoningPower;
/**
* An optional AbortSignal that allows cancellation of the reasoning agent loop.
* When the signal is aborted, the agent will stop between reasoning steps
* and forward the signal to inference calls so in-progress requests are cancelled.
*/
abortSignal?: AbortSignal;
}

export type ReasoningPromptResponseOf<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export const discoveryCompletedTrigger: MemoryUpdateTrigger = {
description:
'Fires after insights discovery completes. Synthesizes discovery insights into wiki pages organized by categories.',
execute: async (context) => {
const { memory, logger, trigger, inferenceClient } = context;
const { memory, logger, trigger, inferenceClient, abortSignal } = context;
const { insights } = trigger.payload as {
insights: DiscoveryInsight[];
};
Expand Down Expand Up @@ -62,6 +62,10 @@ export const discoveryCompletedTrigger: MemoryUpdateTrigger = {
const allEntries = await memory.listAll();

for (const { streamName, streamInsights } of streamGroups) {
if (abortSignal?.aborted) {
throw new Error('Request was aborted');
}

try {
const existingPages = formatExistingPages(allEntries);

Expand All @@ -76,6 +80,7 @@ export const discoveryCompletedTrigger: MemoryUpdateTrigger = {
existingPages,
},
maxSteps: 10,
abortSignal,
toolCallbacks: {
read_memory_page: createReadMemoryPageCallback({ memory }),
write_memory_page: createWriteMemoryPageCallback({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ export interface MemoryUpdateContext {
* Insight client for triggers that need to read insights/KIs.
*/
insightClient?: InsightClient;
/**
* An optional AbortSignal that triggers can check to support task cancellation.
*/
abortSignal?: AbortSignal;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export function cancellableTask(
);

if (task.status === TaskStatus.BeingCanceled) {
runContext.abortController.abort();
resolve('canceled' as const);
}
}, 5000);
Expand All @@ -55,14 +56,17 @@ export function cancellableTask(

/**
* Here the task can be in BeingCanceled state in two scenarios:
* 1. cancellationPromise was resolved
* 1. cancellationPromise was resolved — abort was already called in the
* polling loop before resolving the promise.
* 2. run() exited early in response to cancellation. This might
* happen for multi-step tasks, like onboarding, in order to prevent
* scheduling the next sub-task while the parent task was already
* canceled.
* happen for multi-step tasks, like onboarding, in order to prevent
* scheduling the next sub-task while the parent task was already
* canceled. In this case, the signal may not have been aborted yet.
*/
if (task.status === TaskStatus.BeingCanceled) {
runContext.abortController.abort();
if (!runContext.abortController.signal.aborted) {
runContext.abortController.abort();
}
await taskClient.markCanceled(task);
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ export function createStreamsConversationScraperTask(taskContext: TaskContext) {
existingPages,
},
maxSteps: 20,
abortSignal: runContext.abortController.signal,
toolCallbacks: {
get_conversation_details: async (toolCall) => {
const { index } = toolCall.function.arguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ export function createStreamsMemoryConsolidationTask(taskContext: TaskContext) {
existingPages,
},
maxSteps: 30,
abortSignal: runContext.abortController.signal,
toolCallbacks: {
read_memory_page: createReadMemoryPageCallback({ memory }),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ export function createStreamsMemoryGenerationTask(taskContext: TaskContext) {
taskLogger.info(`Found ${allEntries.length} existing memory entries total`);

for (const { streamName, indicators } of streamGroups) {
if (runContext.abortController.signal.aborted) {
throw new Error('Request was aborted');
}

taskLogger.info(
`Processing stream "${streamName}" with ${indicators.length} indicator(s) via reasoning agent`
);
Expand All @@ -136,6 +140,7 @@ export function createStreamsMemoryGenerationTask(taskContext: TaskContext) {
existingPages,
},
maxSteps: 10,
abortSignal: runContext.abortController.signal,
toolCallbacks: {
get_indicator_details: async (toolCall) => {
const { index } = toolCall.function.arguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ export function createStreamsMemoryUpdateTask(taskContext: TaskContext) {
esClient: scopedClusterClient.asCurrentUser,
insightClient,
payload,
abortSignal: runContext.abortController.signal,
});

taskLogger.info(`Memory update trigger "${triggerId}" completed successfully`);
Expand Down
Loading