diff --git a/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/execute_as_reasoning_agent.test.ts b/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/execute_as_reasoning_agent.test.ts index 1aea0408482c8..10728fbc73a9c 100644 --- a/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/execute_as_reasoning_agent.test.ts +++ b/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/execute_as_reasoning_agent.test.ts @@ -440,4 +440,116 @@ describe('executeAsReasoningAgent', () => { // Without maxDurationMs we are not forced to complete, so toolChoice is auto (continuation) expect(secondCall.toolChoice).toEqual('auto'); }); + + test('throws when abortSignal is already aborted before starting', async () => { + const prompt = makePrompt(); + const inferenceClient = { + prompt: jest.fn().mockResolvedValue({ content: 'final', toolCalls: [], tokens: 1 }), + } as Partial> as jest.Mocked; + + const abortController = new AbortController(); + abortController.abort(); + + await expect( + executeAsReasoningAgent({ + inferenceClient, + prompt, + maxSteps: 1, + abortSignal: abortController.signal, + toolCallbacks: { fetch_data: jest.fn(), complete: jest.fn() }, + input: { foo: '' }, + }) + ).rejects.toThrow('Request was aborted'); + + expect(inferenceClient.prompt).not.toHaveBeenCalled(); + }); + + test('throws when abortSignal is aborted between steps', async () => { + const prompt = makePrompt(); + const abortController = new AbortController(); + + let callCount = 0; + const inferenceClient = { + prompt: jest.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + return { + content: 'call tool', + toolCalls: [ + { + type: 'function', + function: { name: 'fetch_data', arguments: {} }, + toolCallId: 'x', + }, + ], + tokens: 1, + }; + } + abortController.abort(); + return { content: 'final', toolCalls: [], tokens: 1 }; + }), + } as Partial> as jest.Mocked; + + const fetchData = jest.fn().mockResolvedValue({ response: { result: 'ok' } }); + + await expect( + executeAsReasoningAgent({ + inferenceClient, + prompt, + maxSteps: 3, + abortSignal: abortController.signal, + toolCallbacks: { fetch_data: fetchData, complete: jest.fn() }, + input: { foo: '' }, + }) + ).rejects.toThrow('Request was aborted'); + + expect(callCount).toBeGreaterThanOrEqual(1); + }); + + test('skips tool callbacks when abortSignal is already aborted', async () => { + const prompt = makePrompt(); + const abortController = new AbortController(); + abortController.abort(); + + const inferenceClient = { + prompt: jest.fn().mockResolvedValue({ content: 'final', toolCalls: [], tokens: 1 }), + } as Partial> as jest.Mocked; + + const fetchData = jest.fn(); + + await expect( + executeAsReasoningAgent({ + inferenceClient, + prompt, + maxSteps: 2, + abortSignal: abortController.signal, + toolCallbacks: { fetch_data: fetchData, complete: jest.fn() }, + input: { foo: '' }, + }) + ).rejects.toThrow('Request was aborted'); + + expect(fetchData).not.toHaveBeenCalled(); + expect(inferenceClient.prompt).not.toHaveBeenCalled(); + }); + + test('forwards abortSignal to inferenceClient.prompt calls', async () => { + const prompt = makePrompt(); + const abortController = new AbortController(); + + const inferenceClient = { + prompt: jest.fn().mockResolvedValue({ content: 'final', toolCalls: [], tokens: 1 }), + } as Partial> as jest.Mocked; + + await executeAsReasoningAgent({ + inferenceClient, + prompt, + maxSteps: 1, + abortSignal: abortController.signal, + toolCallbacks: { fetch_data: jest.fn(), complete: jest.fn() }, + input: { foo: '' }, + }); + + const callArgs = inferenceClient.prompt.mock.calls[0][0]; + expect(callArgs.abortSignal).toBe(abortController.signal); + }); }); diff --git a/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/execute_as_reasoning_agent.ts b/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/execute_as_reasoning_agent.ts index 00b733cc19c2c..77b34b37f59b4 100644 --- a/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/execute_as_reasoning_agent.ts +++ b/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/execute_as_reasoning_agent.ts @@ -94,6 +94,7 @@ export async function executeAsReasoningAgent( maxSteps = 10, power = 'medium', toolCallbacks, + abortSignal, } = options; const startTime = Date.now(); @@ -104,6 +105,15 @@ export async function executeAsReasoningAgent( throw new Error(`Unexpected planning tool call ${toolCall.function.name}`); } + if (abortSignal?.aborted) { + return { + response: { error: new Error('Request was aborted'), data: undefined }, + name: toolCall.function.name, + toolCallId: toolCall.toolCallId, + role: MessageRole.Tool, + }; + } + const callback = toolCallbacks[toolCall.function.name]; const response = await withExecuteToolSpan( @@ -142,6 +152,10 @@ export async function executeAsReasoningAgent( stepsLeft: number; temperature?: number; }): Promise { + if (abortSignal?.aborted) { + throw new Error('Request was aborted'); + } + const lastAssistantMessage = givenMessages.findLast( (msg): msg is AssistantMessage => msg.role === MessageRole.Assistant ); @@ -241,6 +255,7 @@ export async function executeAsReasoningAgent( stream: false, temperature, toolChoice, + abortSignal, prevMessages: formatMessages({ messages: prevMessages, power, diff --git a/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/types.ts b/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/types.ts index 8bddab9575ed0..c0fa8def6dd35 100644 --- a/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/types.ts +++ b/x-pack/platform/packages/shared/kbn-inference-prompt-utils/src/flows/reasoning/types.ts @@ -33,6 +33,12 @@ export interface ReasoningPromptOptions { maxSteps?: number; prevMessages?: undefined; power?: ReasoningPower; + /** + * An optional AbortSignal that allows cancellation of the reasoning agent loop. + * When the signal is aborted, the agent will stop between reasoning steps + * and forward the signal to inference calls so in-progress requests are cancelled. + */ + abortSignal?: AbortSignal; } export type ReasoningPromptResponseOf< diff --git a/x-pack/platform/plugins/shared/streams/server/lib/memory/triggers/discovery_completed_trigger.ts b/x-pack/platform/plugins/shared/streams/server/lib/memory/triggers/discovery_completed_trigger.ts index 3a9586eeb7a14..c8345c6251bdb 100644 --- a/x-pack/platform/plugins/shared/streams/server/lib/memory/triggers/discovery_completed_trigger.ts +++ b/x-pack/platform/plugins/shared/streams/server/lib/memory/triggers/discovery_completed_trigger.ts @@ -34,7 +34,7 @@ export const discoveryCompletedTrigger: MemoryUpdateTrigger = { description: 'Fires after insights discovery completes. Synthesizes discovery insights into wiki pages organized by categories.', execute: async (context) => { - const { memory, logger, trigger, inferenceClient } = context; + const { memory, logger, trigger, inferenceClient, abortSignal } = context; const { insights } = trigger.payload as { insights: DiscoveryInsight[]; }; @@ -62,6 +62,10 @@ export const discoveryCompletedTrigger: MemoryUpdateTrigger = { const allEntries = await memory.listAll(); for (const { streamName, streamInsights } of streamGroups) { + if (abortSignal?.aborted) { + throw new Error('Request was aborted'); + } + try { const existingPages = formatExistingPages(allEntries); @@ -76,6 +80,7 @@ export const discoveryCompletedTrigger: MemoryUpdateTrigger = { existingPages, }, maxSteps: 10, + abortSignal, toolCallbacks: { read_memory_page: createReadMemoryPageCallback({ memory }), write_memory_page: createWriteMemoryPageCallback({ diff --git a/x-pack/platform/plugins/shared/streams/server/lib/memory/triggers/types.ts b/x-pack/platform/plugins/shared/streams/server/lib/memory/triggers/types.ts index 08e4f15f3ff9c..b451028ccf19e 100644 --- a/x-pack/platform/plugins/shared/streams/server/lib/memory/triggers/types.ts +++ b/x-pack/platform/plugins/shared/streams/server/lib/memory/triggers/types.ts @@ -51,6 +51,10 @@ export interface MemoryUpdateContext { * Insight client for triggers that need to read insights/KIs. */ insightClient?: InsightClient; + /** + * An optional AbortSignal that triggers can check to support task cancellation. + */ + abortSignal?: AbortSignal; } /** diff --git a/x-pack/platform/plugins/shared/streams/server/lib/tasks/cancellable_task.ts b/x-pack/platform/plugins/shared/streams/server/lib/tasks/cancellable_task.ts index dba4d024f423d..fc37d0a9a8b06 100644 --- a/x-pack/platform/plugins/shared/streams/server/lib/tasks/cancellable_task.ts +++ b/x-pack/platform/plugins/shared/streams/server/lib/tasks/cancellable_task.ts @@ -39,6 +39,7 @@ export function cancellableTask( ); if (task.status === TaskStatus.BeingCanceled) { + runContext.abortController.abort(); resolve('canceled' as const); } }, 5000); @@ -55,14 +56,17 @@ export function cancellableTask( /** * Here the task can be in BeingCanceled state in two scenarios: - * 1. cancellationPromise was resolved + * 1. cancellationPromise was resolved — abort was already called in the + * polling loop before resolving the promise. * 2. run() exited early in response to cancellation. This might - * happen for multi-step tasks, like onboarding, in order to prevent - * scheduling the next sub-task while the parent task was already - * canceled. + * happen for multi-step tasks, like onboarding, in order to prevent + * scheduling the next sub-task while the parent task was already + * canceled. In this case, the signal may not have been aborted yet. */ if (task.status === TaskStatus.BeingCanceled) { - runContext.abortController.abort(); + if (!runContext.abortController.signal.aborted) { + runContext.abortController.abort(); + } await taskClient.markCanceled(task); } }); diff --git a/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/conversation_scraper.ts b/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/conversation_scraper.ts index 573fe36677f4e..730f2cc753f8f 100644 --- a/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/conversation_scraper.ts +++ b/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/conversation_scraper.ts @@ -164,6 +164,7 @@ export function createStreamsConversationScraperTask(taskContext: TaskContext) { existingPages, }, maxSteps: 20, + abortSignal: runContext.abortController.signal, toolCallbacks: { get_conversation_details: async (toolCall) => { const { index } = toolCall.function.arguments; diff --git a/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_consolidation.ts b/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_consolidation.ts index 57386af4d443d..42b8fd86146c1 100644 --- a/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_consolidation.ts +++ b/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_consolidation.ts @@ -113,6 +113,7 @@ export function createStreamsMemoryConsolidationTask(taskContext: TaskContext) { existingPages, }, maxSteps: 30, + abortSignal: runContext.abortController.signal, toolCallbacks: { read_memory_page: createReadMemoryPageCallback({ memory }), diff --git a/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_generation.ts b/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_generation.ts index 86402a3f554d4..ed8ff8ec0c331 100644 --- a/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_generation.ts +++ b/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_generation.ts @@ -111,6 +111,10 @@ export function createStreamsMemoryGenerationTask(taskContext: TaskContext) { taskLogger.info(`Found ${allEntries.length} existing memory entries total`); for (const { streamName, indicators } of streamGroups) { + if (runContext.abortController.signal.aborted) { + throw new Error('Request was aborted'); + } + taskLogger.info( `Processing stream "${streamName}" with ${indicators.length} indicator(s) via reasoning agent` ); @@ -136,6 +140,7 @@ export function createStreamsMemoryGenerationTask(taskContext: TaskContext) { existingPages, }, maxSteps: 10, + abortSignal: runContext.abortController.signal, toolCallbacks: { get_indicator_details: async (toolCall) => { const { index } = toolCall.function.arguments; diff --git a/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_update.ts b/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_update.ts index 618d82985b82d..55cd957fd3d1f 100644 --- a/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_update.ts +++ b/x-pack/platform/plugins/shared/streams/server/lib/tasks/task_definitions/memory_update.ts @@ -106,6 +106,7 @@ export function createStreamsMemoryUpdateTask(taskContext: TaskContext) { esClient: scopedClusterClient.asCurrentUser, insightClient, payload, + abortSignal: runContext.abortController.signal, }); taskLogger.info(`Memory update trigger "${triggerId}" completed successfully`);