diff --git a/x-pack/packages/kbn-elastic-assistant-common/constants.ts b/x-pack/packages/kbn-elastic-assistant-common/constants.ts index 74da6ab2476e2..96af59095ab87 100755 --- a/x-pack/packages/kbn-elastic-assistant-common/constants.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/constants.ts @@ -7,13 +7,15 @@ export const ELASTIC_AI_ASSISTANT_INTERNAL_API_VERSION = '1'; -export const ELASTIC_AI_ASSISTANT_URL = '/api/elastic_assistant'; +export const ELASTIC_AI_ASSISTANT_URL = '/api/security_ai_assistant'; export const ELASTIC_AI_ASSISTANT_INTERNAL_URL = '/internal/elastic_assistant'; export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL = `${ELASTIC_AI_ASSISTANT_INTERNAL_URL}/current_user/conversations`; export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BY_ID = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}/{id}`; export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BY_ID_MESSAGES = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BY_ID}/messages`; +export const ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL = `${ELASTIC_AI_ASSISTANT_URL}/chat/complete`; + export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BULK_ACTION = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}/_bulk_action`; export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_FIND = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}/_find`; diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts new file mode 100644 index 0000000000000..0b6c3bbe6cbb3 --- /dev/null +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +/* + * NOTICE: Do not edit this file manually. + * This file is automatically generated by the OpenAPI Generator, @kbn/openapi-generator. + * + * info: + * title: Chat Complete API endpoint + * version: 2023-10-31 + */ + +import { z } from 'zod'; + +export type RootContext = z.infer; +export const RootContext = z.literal('security'); + +/** + * Message role. + */ +export type ChatMessageRole = z.infer; +export const ChatMessageRole = z.enum(['system', 'user', 'assistant']); +export type ChatMessageRoleEnum = typeof ChatMessageRole.enum; +export const ChatMessageRoleEnum = ChatMessageRole.enum; + +export type MessageData = z.infer; +export const MessageData = z.object({}).catchall(z.unknown()); + +/** + * AI assistant message. + */ +export type ChatMessage = z.infer; +export const ChatMessage = z.object({ + /** + * Message content. + */ + content: z.string().optional(), + /** + * Message role. + */ + role: ChatMessageRole, + /** + * ECS object to attach to the context of the message. + */ + data: MessageData.optional(), + fields_to_anonymize: z.array(z.string()).optional(), +}); + +export type ChatCompleteProps = z.infer; +export const ChatCompleteProps = z.object({ + conversationId: z.string().optional(), + promptId: z.string().optional(), + isStream: z.boolean().optional(), + responseLanguage: z.string().optional(), + langSmithProject: z.string().optional(), + langSmithApiKey: z.string().optional(), + connectorId: z.string(), + model: z.string().optional(), + persist: z.boolean(), + messages: z.array(ChatMessage), +}); + +export type ChatCompleteRequestBody = z.infer; +export const ChatCompleteRequestBody = ChatCompleteProps; +export type ChatCompleteRequestBodyInput = z.input; diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml new file mode 100644 index 0000000000000..21c348251b039 --- /dev/null +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml @@ -0,0 +1,109 @@ +openapi: 3.0.0 +info: + title: Chat Complete API endpoint + version: '2023-10-31' +paths: + /api/elastic_assistant/chat/complete: + post: + operationId: ChatComplete + x-codegen-enabled: true + description: Creates a model response for the given chat conversation. + summary: Creates a model response for the given chat conversation. + tags: + - Chat Complete API + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/ChatCompleteProps' + responses: + 200: + description: Indicates a successful call. + content: + application/octet-stream: + schema: + type: string + format: binary + 400: + description: Generic Error + content: + application/json: + schema: + type: object + properties: + statusCode: + type: number + error: + type: string + message: + type: string + +components: + schemas: + RootContext: + type: string + enum: + - security + + ChatMessageRole: + type: string + description: Message role. + enum: + - system + - user + - assistant + + MessageData: + type: object + additionalProperties: true + + ChatMessage: + type: object + description: AI assistant message. + required: + - 'role' + properties: + content: + type: string + description: Message content. + role: + $ref: '#/components/schemas/ChatMessageRole' + description: Message role. + data: + description: ECS object to attach to the context of the message. + $ref: '#/components/schemas/MessageData' + fields_to_anonymize: + type: array + items: + type: string + + ChatCompleteProps: + type: object + properties: + conversationId: + type: string + promptId: + type: string + isStream: + type: boolean + responseLanguage: + type: string + langSmithProject: + type: string + langSmithApiKey: + type: string + connectorId: + type: string + model: + type: string + persist: + type: boolean + messages: + type: array + items: + $ref: '#/components/schemas/ChatMessage' + required: + - messages + - persist + - connectorId diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/index.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/index.ts index 8f47731694cf3..eb5d0738f378b 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/index.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/index.ts @@ -27,6 +27,9 @@ export * from './attack_discovery/get_attack_discovery_route.gen'; export * from './attack_discovery/post_attack_discovery_route.gen'; export * from './attack_discovery/cancel_attack_discovery_route.gen'; +// Chat Schemas +export * from './chat/post_chat_complete_route.gen'; + // Evaluation Schemas export * from './evaluation/post_evaluate_route.gen'; export * from './evaluation/get_evaluate_route.gen'; diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts index 5c6d389a4ccc4..7d01468b4d6f7 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts @@ -7,10 +7,10 @@ import type OpenAI from 'openai'; import { Stream } from 'openai/streaming'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { loggerMock } from '@kbn/logging-mocks'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; -import { ActionsClientChatOpenAI, ActionsClientChatOpenAIParams } from './chat_openai'; +import { ActionsClientChatOpenAI } from './chat_openai'; import { mockActionResponse, mockChatCompletion } from './mocks'; const connectorId = 'mock-connector-id'; @@ -19,11 +19,8 @@ const mockExecute = jest.fn(); const mockLogger = loggerMock.create(); -const mockActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockExecute, - })), -} as unknown as ActionsPluginStart; +const actionsClient = actionsClientMock.create(); + const chunk = { object: 'chat.completion.chunk', choices: [ @@ -40,30 +37,15 @@ export async function* asyncGenerator() { yield chunk; } const mockStreamExecute = jest.fn(); -const mockStreamActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockStreamExecute, - })), -} as unknown as ActionsPluginStart; const prompt = 'Do you know my name?'; const { signal } = new AbortController(); -const mockRequest = { - params: { connectorId }, - body: { - message: prompt, - subAction: 'invokeAI', - isEnabledKnowledgeBase: true, - }, -} as ActionsClientChatOpenAIParams['request']; - const defaultArgs = { - actions: mockActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, streaming: false, signal, timeout: 999999, @@ -77,6 +59,7 @@ describe('ActionsClientChatOpenAI', () => { data: mockChatCompletion, status: 'ok', })); + actionsClient.execute.mockImplementation(mockExecute); }); describe('_llmType', () => { @@ -116,10 +99,11 @@ describe('ActionsClientChatOpenAI', () => { functions: [jest.fn()], }; it('returns the expected data', async () => { + actionsClient.execute.mockImplementation(mockStreamExecute); const actionsClientChatOpenAI = new ActionsClientChatOpenAI({ ...defaultArgs, streaming: true, - actions: mockStreamActions, + actionsClient, }); const result: AsyncIterable = @@ -178,16 +162,11 @@ describe('ActionsClientChatOpenAI', () => { serviceMessage: 'action-result-service-message', status: 'error', // <-- error status })); - - const badActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: hasErrorStatus, - })), - } as unknown as ActionsPluginStart; + actionsClient.execute.mockRejectedValueOnce(hasErrorStatus); const actionsClientChatOpenAI = new ActionsClientChatOpenAI({ ...defaultArgs, - actions: badActions, + actionsClient, }); expect(actionsClientChatOpenAI.completionWithRetry(defaultNonStreamingArgs)) diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts index c2dada0dafa3b..391609db21565 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts @@ -6,24 +6,24 @@ */ import { v4 as uuidv4 } from 'uuid'; -import { KibanaRequest, Logger } from '@kbn/core/server'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { Logger } from '@kbn/core/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; import { get } from 'lodash/fp'; import { ChatOpenAI } from '@langchain/openai'; import { Stream } from 'openai/streaming'; import type OpenAI from 'openai'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants'; import { InvokeAIActionParamsSchema, RunActionParamsSchema } from './types'; const LLM_TYPE = 'ActionsClientChatOpenAI'; export interface ActionsClientChatOpenAIParams { - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; connectorId: string; llmType?: string; logger: Logger; - request: KibanaRequest; streaming?: boolean; traceId?: string; maxRetries?: number; @@ -54,22 +54,20 @@ export class ActionsClientChatOpenAI extends ChatOpenAI { #temperature?: number; // Kibana variables - #actions: ActionsPluginStart; + #actionsClient: PublicMethodsOf; #connectorId: string; #logger: Logger; - #request: KibanaRequest; #actionResultData: string; #traceId: string; #signal?: AbortSignal; #timeout?: number; constructor({ - actions, + actionsClient, connectorId, traceId = uuidv4(), llmType, logger, - request, maxRetries, model, signal, @@ -92,12 +90,11 @@ export class ActionsClientChatOpenAI extends ChatOpenAI { azureOpenAIApiVersion: 'nothing', openAIApiKey: '', }); - this.#actions = actions; + this.#actionsClient = actionsClient; this.#connectorId = connectorId; this.#traceId = traceId; this.llmType = llmType ?? LLM_TYPE; this.#logger = logger; - this.#request = request; this.#timeout = timeout; this.#actionResultData = ''; this.streaming = streaming; @@ -146,10 +143,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI { )} ` ); - // create an actions client from the authenticated request context: - const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request); - - const actionResult = await actionsClient.execute(requestBody); + const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { throw new Error(`${LLM_TYPE}: ${actionResult?.message} - ${actionResult?.serviceMessage}`); diff --git a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts index e0f7b764b625a..aa33bbf7a6d44 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts @@ -5,39 +5,27 @@ * 2.0. */ -import { KibanaRequest } from '@kbn/core/server'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { loggerMock } from '@kbn/logging-mocks'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; import { ActionsClientLlm } from './llm'; import { mockActionResponse } from './mocks'; const connectorId = 'mock-connector-id'; -const mockExecute = jest.fn().mockImplementation(() => ({ - data: mockActionResponse, - status: 'ok', -})); +const actionsClient = actionsClientMock.create(); -const mockLogger = loggerMock.create(); +actionsClient.execute.mockImplementation( + jest.fn().mockImplementation(() => ({ + data: mockActionResponse, + status: 'ok', + })) +); -const mockActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockExecute, - })), -} as unknown as ActionsPluginStart; +const mockLogger = loggerMock.create(); const prompt = 'Do you know my name?'; -const mockRequest: KibanaRequest = { - params: { connectorId }, - body: { - message: prompt, - subAction: 'invokeAI', - isEnabledKnowledgeBase: true, - }, -} as KibanaRequest; - describe('ActionsClientLlm', () => { beforeEach(() => { jest.clearAllMocks(); @@ -46,10 +34,9 @@ describe('ActionsClientLlm', () => { describe('getActionResultData', () => { it('returns the expected data', async () => { const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, }); const result = await actionsClientLlm._call(prompt); // ignore the result @@ -61,10 +48,9 @@ describe('ActionsClientLlm', () => { describe('_llmType', () => { it('returns the expected LLM type', () => { const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, }); expect(actionsClientLlm._llmType()).toEqual('ActionsClientLlm'); @@ -72,11 +58,10 @@ describe('ActionsClientLlm', () => { it('returns the expected LLM type when overridden', () => { const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId, llmType: 'special-llm-type', logger: mockLogger, - request: mockRequest, }); expect(actionsClientLlm._llmType()).toEqual('special-llm-type'); @@ -86,10 +71,9 @@ describe('ActionsClientLlm', () => { describe('_call', () => { it('returns the expected content when _call is invoked', async () => { const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, }); const result = await actionsClientLlm._call(prompt); @@ -98,23 +82,15 @@ describe('ActionsClientLlm', () => { }); it('rejects with the expected error when the action result status is error', async () => { - const hasErrorStatus = jest.fn().mockImplementation(() => ({ - message: 'action-result-message', - serviceMessage: 'action-result-service-message', - status: 'error', // <-- error status - })); - - const badActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: hasErrorStatus, - })), - } as unknown as ActionsPluginStart; - + actionsClient.execute.mockImplementation(() => { + throw new Error( + 'ActionsClientLlm: action result status is error: action-result-message - action-result-service-message' + ); + }); const actionsClientLlm = new ActionsClientLlm({ - actions: badActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, }); await expect(actionsClientLlm._call(prompt)).rejects.toThrowError( @@ -125,16 +101,17 @@ describe('ActionsClientLlm', () => { it('rejects with the expected error the message has invalid content', async () => { const invalidContent = { message: 1234 }; - mockExecute.mockImplementation(() => ({ - data: invalidContent, - status: 'ok', - })); + actionsClient.execute.mockImplementation( + jest.fn().mockResolvedValue({ + data: invalidContent, + status: 'ok', + }) + ); const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, }); await expect(actionsClientLlm._call(prompt)).rejects.toThrowError( diff --git a/x-pack/packages/kbn-langchain/server/language_models/llm.ts b/x-pack/packages/kbn-langchain/server/language_models/llm.ts index 9708a8d3a5d7f..bad538821ff1d 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/llm.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/llm.ts @@ -5,11 +5,12 @@ * 2.0. */ -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; -import { KibanaRequest, Logger } from '@kbn/core/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; +import { Logger } from '@kbn/core/server'; import { LLM } from '@langchain/core/language_models/llms'; import { get } from 'lodash/fp'; import { v4 as uuidv4 } from 'uuid'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { DEFAULT_TIMEOUT, getDefaultArguments } from './constants'; import { getMessageContentAndRole } from './helpers'; @@ -18,11 +19,10 @@ import { TraceOptions } from './types'; const LLM_TYPE = 'ActionsClientLlm'; interface ActionsClientLlmParams { - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; connectorId: string; llmType?: string; logger: Logger; - request: KibanaRequest; model?: string; temperature?: number; timeout?: number; @@ -31,10 +31,9 @@ interface ActionsClientLlmParams { } export class ActionsClientLlm extends LLM { - #actions: ActionsPluginStart; + #actionsClient: PublicMethodsOf; #connectorId: string; #logger: Logger; - #request: KibanaRequest; #traceId: string; #timeout?: number; @@ -46,13 +45,12 @@ export class ActionsClientLlm extends LLM { temperature?: number; constructor({ - actions, + actionsClient, connectorId, traceId = uuidv4(), llmType, logger, model, - request, temperature, timeout, traceOptions, @@ -61,12 +59,11 @@ export class ActionsClientLlm extends LLM { callbacks: [...(traceOptions?.tracers ?? [])], }); - this.#actions = actions; + this.#actionsClient = actionsClient; this.#connectorId = connectorId; this.#traceId = traceId; this.llmType = llmType ?? LLM_TYPE; this.#logger = logger; - this.#request = request; this.#timeout = timeout; this.model = model; this.temperature = temperature; @@ -107,10 +104,7 @@ export class ActionsClientLlm extends LLM { }, }; - // create an actions client from the authenticated request context: - const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request); - - const actionResult = await actionsClient.execute(requestBody); + const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { throw new Error( diff --git a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts index 7ec9f1e773340..98da9a4e81b53 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts @@ -6,10 +6,10 @@ */ import { PassThrough } from 'stream'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { loggerMock } from '@kbn/logging-mocks'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; -import { ActionsClientSimpleChatModel, CustomChatModelInput } from './simple_chat_model'; +import { ActionsClientSimpleChatModel } from './simple_chat_model'; import { mockActionResponse } from './mocks'; import { BaseMessage } from '@langchain/core/messages'; import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'; @@ -19,26 +19,15 @@ import { parseGeminiStream } from '../utils/gemini'; const connectorId = 'mock-connector-id'; const mockExecute = jest.fn(); +const actionsClient = actionsClientMock.create(); const mockLogger = loggerMock.create(); -const mockActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockExecute, - })), -} as unknown as ActionsPluginStart; - const mockStreamExecute = jest.fn().mockImplementation(() => ({ data: new PassThrough(), status: 'ok', })); -const mockStreamActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockStreamExecute, - })), -} as unknown as ActionsPluginStart; -const prompt = 'Do you know my name?'; const callMessages = [ { lc_serializable: true, @@ -78,20 +67,10 @@ const callRunManager = { handleLLMNewToken, } as unknown as CallbackManagerForLLMRun; -const mockRequest: CustomChatModelInput['request'] = { - params: { connectorId }, - body: { - message: prompt, - subAction: 'invokeAI', - isEnabledKnowledgeBase: true, - }, -} as CustomChatModelInput['request']; - const defaultArgs = { - actions: mockActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, streaming: false, }; jest.mock('../utils/bedrock'); @@ -100,6 +79,12 @@ jest.mock('../utils/gemini'); describe('ActionsClientSimpleChatModel', () => { beforeEach(() => { jest.clearAllMocks(); + actionsClient.execute.mockImplementation( + jest.fn().mockImplementation(() => ({ + data: mockActionResponse, + status: 'ok', + })) + ); mockExecute.mockImplementation(() => ({ data: mockActionResponse, status: 'ok', @@ -146,28 +131,24 @@ describe('ActionsClientSimpleChatModel', () => { callOptions, callRunManager ); - const subAction = mockExecute.mock.calls[0][0].params.subAction; + const subAction = actionsClient.execute.mock.calls[0][0].params.subAction; expect(subAction).toEqual('invokeAI'); expect(result).toEqual(mockActionResponse.message); }); it('rejects with the expected error when the action result status is error', async () => { - const hasErrorStatus = jest.fn().mockImplementation(() => ({ - message: 'action-result-message', - serviceMessage: 'action-result-service-message', - status: 'error', // <-- error status - })); + const hasErrorStatus = jest.fn().mockImplementation(() => { + throw Error( + 'ActionsClientSimpleChatModel: action result status is error: action-result-message - action-result-service-message' + ); + }); - const badActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: hasErrorStatus, - })), - } as unknown as ActionsPluginStart; + actionsClient.execute.mockRejectedValueOnce(hasErrorStatus); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: badActions, + actionsClient, }); await expect( @@ -180,10 +161,12 @@ describe('ActionsClientSimpleChatModel', () => { it('rejects with the expected error the message has invalid content', async () => { const invalidContent = { message: 1234 }; - mockExecute.mockImplementation(() => ({ - data: invalidContent, - status: 'ok', - })); + actionsClient.execute.mockImplementation( + jest.fn().mockResolvedValue({ + data: invalidContent, + status: 'ok', + }) + ); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel(defaultArgs); @@ -221,9 +204,10 @@ describe('ActionsClientSimpleChatModel', () => { (parseGeminiStream as jest.Mock).mockResolvedValue(mockActionResponse.message); }); it('returns the expected content when _call is invoked with streaming and llmType is Bedrock', async () => { + actionsClient.execute.mockImplementationOnce(mockStreamExecute); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: mockStreamActions, + actionsClient, llmType: 'bedrock', streaming: true, maxTokens: 333, @@ -248,9 +232,10 @@ describe('ActionsClientSimpleChatModel', () => { expect(result).toEqual(mockActionResponse.message); }); it('returns the expected content when _call is invoked with streaming and llmType is Gemini', async () => { + actionsClient.execute.mockImplementationOnce(mockStreamExecute); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: mockStreamActions, + actionsClient, llmType: 'gemini', streaming: true, maxTokens: 333, @@ -283,9 +268,11 @@ describe('ActionsClientSimpleChatModel', () => { handleToken(`, "action_input": "`); handleToken('token6'); }); + actionsClient.execute.mockImplementationOnce(mockStreamExecute); + const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: mockStreamActions, + actionsClient, llmType: 'bedrock', streaming: true, }); @@ -303,9 +290,10 @@ describe('ActionsClientSimpleChatModel', () => { handleToken('"'); handleToken('token7'); }); + actionsClient.execute.mockImplementationOnce(mockStreamExecute); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: mockStreamActions, + actionsClient, llmType: 'bedrock', streaming: true, }); diff --git a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts index 97f7c20cc110a..ed38723993876 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts @@ -11,12 +11,12 @@ import { type BaseChatModelParams, } from '@langchain/core/language_models/chat_models'; import { type BaseMessage } from '@langchain/core/messages'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; import { Logger } from '@kbn/logging'; -import { KibanaRequest } from '@kbn/core-http-server'; import { v4 as uuidv4 } from 'uuid'; import { get } from 'lodash/fp'; import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { parseGeminiStream } from '../utils/gemini'; import { parseBedrockStream } from '../utils/bedrock'; import { getDefaultArguments } from './constants'; @@ -27,23 +27,21 @@ export const getMessageContentAndRole = (prompt: string, role = 'user') => ({ }); export interface CustomChatModelInput extends BaseChatModelParams { - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; connectorId: string; logger: Logger; llmType?: string; signal?: AbortSignal; model?: string; temperature?: number; - request: KibanaRequest; streaming: boolean; maxTokens?: number; } export class ActionsClientSimpleChatModel extends SimpleChatModel { - #actions: ActionsPluginStart; + #actionsClient: PublicMethodsOf; #connectorId: string; #logger: Logger; - #request: KibanaRequest; #traceId: string; #signal?: AbortSignal; #maxTokens?: number; @@ -53,12 +51,11 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel { temperature?: number; constructor({ - actions, + actionsClient, connectorId, llmType, logger, model, - request, temperature, signal, streaming, @@ -66,12 +63,11 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel { }: CustomChatModelInput) { super({}); - this.#actions = actions; + this.#actionsClient = actionsClient; this.#connectorId = connectorId; this.#traceId = uuidv4(); this.#logger = logger; this.#signal = signal; - this.#request = request; this.#maxTokens = maxTokens; this.llmType = llmType ?? 'ActionsClientSimpleChatModel'; this.model = model; @@ -122,10 +118,8 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel { }, }, }; - // create an actions client from the authenticated request context: - const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request); - const actionResult = await actionsClient.execute(requestBody); + const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { throw new Error( diff --git a/x-pack/packages/kbn-langchain/tsconfig.json b/x-pack/packages/kbn-langchain/tsconfig.json index 92dc5ebd33911..949aca47794ec 100644 --- a/x-pack/packages/kbn-langchain/tsconfig.json +++ b/x-pack/packages/kbn-langchain/tsconfig.json @@ -18,6 +18,6 @@ "@kbn/logging", "@kbn/actions-plugin", "@kbn/logging-mocks", - "@kbn/core-http-server" + "@kbn/utility-types" ] } diff --git a/x-pack/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/index.ts b/x-pack/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/index.ts index 72c51d8f917aa..7c4f9708862a5 100644 --- a/x-pack/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/index.ts +++ b/x-pack/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/index.ts @@ -43,7 +43,7 @@ export class AIAssistantConversationsDataClient extends AIAssistantDataClient { logger: this.options.logger, conversationIndex: this.indexTemplateAndPattern.alias, id, - user: authenticatedUser, + user: authenticatedUser ?? this.options.currentUser, }); }; @@ -84,18 +84,19 @@ export class AIAssistantConversationsDataClient extends AIAssistantDataClient { */ public createConversation = async ({ conversation, - authenticatedUser, }: { conversation: ConversationCreateProps; - authenticatedUser: AuthenticatedUser; }): Promise => { + if (!this.options.currentUser) { + throw new Error('AIAssistantConversationsDataClient currentUser is not defined.'); + } const esClient = await this.options.elasticsearchClientPromise; return createConversation({ esClient, logger: this.options.logger, conversationIndex: this.indexTemplateAndPattern.alias, spaceId: this.spaceId, - user: authenticatedUser, + user: this.options.currentUser, conversation, }); }; @@ -128,7 +129,7 @@ export class AIAssistantConversationsDataClient extends AIAssistantDataClient { conversationIndex: this.indexTemplateAndPattern.alias, conversationUpdateProps, isPatch, - user: authenticatedUser, + user: authenticatedUser ?? this.options.currentUser ?? undefined, }); }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts index bacdd6cac1b49..a01ac3d126e59 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts @@ -13,17 +13,10 @@ import { executeAction, Props } from './executor'; import { PassThrough } from 'stream'; -import { KibanaRequest } from '@kbn/core-http-server'; -import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; -import { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; import { loggerMock } from '@kbn/logging-mocks'; import * as ParseStream from './parse_stream'; -const request = { - body: { - subAction: 'invokeAI', - message: 'hello', - }, -} as KibanaRequest; + const onLlmResponse = jest.fn(async () => {}); // We need it to be a promise, or it'll crash because of missing `.catch` const connectorId = 'testConnectorId'; const mockLogger = loggerMock.create(); @@ -33,8 +26,8 @@ const testProps: Omit = { subActionParams: { messages: [{ content: 'hello', role: 'user' }] }, }, actionTypeId: '.bedrock', - request, connectorId, + actionsClient: actionsClientMock.create(), onLlmResponse, logger: mockLogger, }; @@ -46,17 +39,13 @@ describe('executeAction', () => { jest.clearAllMocks(); }); it('should execute an action and return a StaticResponse when the response from the actions framework is a string', async () => { - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - data: { - message: 'Test message', - }, - }), - }), - } as unknown as Props['actions']; - - const result = await executeAction({ ...testProps, actions }); + testProps.actionsClient.execute = jest.fn().mockResolvedValue({ + data: { + message: 'Test message', + }, + }); + + const result = await executeAction({ ...testProps }); expect(result).toEqual({ connector_id: connectorId, @@ -68,15 +57,15 @@ describe('executeAction', () => { it('should execute an action and return a Readable object when the response from the actions framework is a stream', async () => { const readableStream = new PassThrough(); - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - data: readableStream, - }), - }), - } as unknown as Props['actions']; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockImplementationOnce( + jest.fn().mockResolvedValue({ + status: 'ok', + data: readableStream, + }) + ); - const result = await executeAction({ ...testProps, actions }); + const result = await executeAction({ ...testProps, actionsClient }); expect(JSON.stringify(result)).toStrictEqual( JSON.stringify(readableStream.pipe(new PassThrough())) @@ -90,83 +79,69 @@ describe('executeAction', () => { }); }); - it('should throw an error if the actions plugin fails to retrieve the actions client', async () => { - const actions = { - getActionsClientWithRequest: jest - .fn() - .mockRejectedValue(new Error('Failed to retrieve actions client')), - } as unknown as Props['actions']; - - await expect(executeAction({ ...testProps, actions })).rejects.toThrowError( - 'Failed to retrieve actions client' - ); - }); - it('should throw an error if the actions client fails to execute the action', async () => { - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockRejectedValue(new Error('Failed to execute action')), - }), - } as unknown as Props['actions']; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockRejectedValue(new Error('Failed to execute action')); + testProps.actionsClient = actionsClient; - await expect(executeAction({ ...testProps, actions })).rejects.toThrowError( + await expect(executeAction({ ...testProps, actionsClient })).rejects.toThrowError( 'Failed to execute action' ); }); it('should throw an error when the response from the actions framework is null or undefined', async () => { - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - data: null, - }), - }), - } as unknown as Props['actions']; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockImplementationOnce( + jest.fn().mockResolvedValue({ + data: null, + }) + ); + testProps.actionsClient = actionsClient; try { - await executeAction({ ...testProps, actions }); + await executeAction({ ...testProps, actionsClient }); } catch (e) { expect(e.message).toBe('Action result status is error: result is not streamable'); } }); it('should throw an error if action result status is "error"', async () => { - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - status: 'error', - message: 'Error message', - serviceMessage: 'Service error message', - }), - }), - } as unknown as ActionsPluginStart; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockImplementationOnce( + jest.fn().mockResolvedValue({ + status: 'error', + message: 'Error message', + serviceMessage: 'Service error message', + }) + ); + testProps.actionsClient = actionsClient; await expect( executeAction({ ...testProps, - actions, + actionsClient, connectorId: '12345', }) ).rejects.toThrowError('Action result status is error: Error message - Service error message'); }); it('should throw an error if content of response data is not a string or streamable', async () => { - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - status: 'ok', - data: { - message: 12345, - }, - }), - }), - } as unknown as ActionsPluginStart; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockImplementationOnce( + jest.fn().mockResolvedValue({ + status: 'ok', + data: { + message: 12345, + }, + }) + ); + testProps.actionsClient = actionsClient; await expect( executeAction({ ...testProps, - actions, + actionsClient, connectorId: '12345', }) ).rejects.toThrowError('Action result status is error: result is not streamable'); diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.ts index e7797805854ec..bd25a77808dbe 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/executor.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.ts @@ -6,20 +6,18 @@ */ import { get } from 'lodash/fp'; -import { KibanaRequest } from '@kbn/core-http-server'; -import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { ActionsClient } from '@kbn/actions-plugin/server'; import { PassThrough, Readable } from 'stream'; -import { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common'; import { Logger } from '@kbn/core/server'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { handleStreamStorage } from './parse_stream'; export interface Props { onLlmResponse?: (content: string) => Promise; abortSignal?: AbortSignal; - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; connectorId: string; params: InvokeAIActionsParams; - request: KibanaRequest; actionTypeId: string; logger: Logger; } @@ -43,15 +41,13 @@ interface InvokeAIActionsParams { export const executeAction = async ({ onLlmResponse, - actions, + actionsClient, params, connectorId, actionTypeId, - request, logger, abortSignal, }: Props): Promise => { - const actionsClient = await actions.getActionsClientWithRequest(request); const actionResult = await actionsClient.execute({ actionId: connectorId, params: { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts index e1e8cdc50eee0..6f05edbed007b 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts @@ -5,10 +5,11 @@ * 2.0. */ -import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; import { coreMock } from '@kbn/core/server/mocks'; import { KibanaRequest } from '@kbn/core/server'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; + import { loggerMock } from '@kbn/logging-mocks'; import { initializeAgentExecutorWithOptions } from 'langchain/agents'; @@ -84,7 +85,7 @@ const mockRequest: KibanaRequest = { body: {} } as K any // eslint-disable-line @typescript-eslint/no-explicit-any >; -const mockActions: ActionsPluginStart = {} as ActionsPluginStart; +const actionsClient = actionsClientMock.create(); const mockLogger = loggerMock.create(); const mockTelemetry = coreMock.createSetup().analytics; const esClientMock = elasticsearchServiceMock.createScopedClusterClient().asCurrentUser; @@ -95,7 +96,7 @@ const esStoreMock = new ElasticsearchStore( mockTelemetry ); const defaultProps: AgentExecutorParams = { - actions: mockActions, + actionsClient, isEnabledKnowledgeBase: true, connectorId: mockConnectorId, esClient: esClientMock, @@ -151,11 +152,10 @@ describe('callAgentExecutor', () => { await callAgentExecutor(defaultProps); expect(ActionsClientChatOpenAI).toHaveBeenCalledWith({ - actions: mockActions, + actionsClient, connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, streaming: false, temperature: 0.2, llmType: 'openai', @@ -189,11 +189,10 @@ describe('callAgentExecutor', () => { await callAgentExecutor({ ...defaultProps, isStream: true }); expect(ActionsClientChatOpenAI).toHaveBeenCalledWith({ - actions: mockActions, + actionsClient, connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, streaming: true, temperature: 0.2, llmType: 'openai', @@ -213,11 +212,10 @@ describe('callAgentExecutor', () => { await callAgentExecutor(bedrockProps); expect(ActionsClientSimpleChatModel).toHaveBeenCalledWith({ - actions: mockActions, + actionsClient, connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, streaming: false, temperature: 0, llmType: 'bedrock', @@ -253,11 +251,10 @@ describe('callAgentExecutor', () => { await callAgentExecutor({ ...bedrockProps, isStream: true }); expect(ActionsClientSimpleChatModel).toHaveBeenCalledWith({ - actions: mockActions, + actionsClient, connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, streaming: true, temperature: 0, llmType: 'bedrock', diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index bcf39320f21cc..b6a624b368d82 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -18,6 +18,8 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server'; import { MessagesPlaceholder } from '@langchain/core/prompts'; +import { EsAnonymizationFieldsSchema } from '../../../ai_assistant_data_clients/anonymization_fields/types'; +import { transformESSearchToAnonymizationFields } from '../../../ai_assistant_data_clients/anonymization_fields/helpers'; import { AgentExecutor } from '../executors/types'; import { APMTracer } from '../tracers/apm_tracer'; import { AssistantToolParams } from '../../../types'; @@ -31,9 +33,8 @@ export const DEFAULT_AGENT_EXECUTOR_ID = 'Elastic AI Assistant Agent Executor'; */ export const callAgentExecutor: AgentExecutor = async ({ abortSignal, - actions, + actionsClient, alertsIndexPattern, - anonymizationFields, isEnabledKnowledgeBase, assistantTools = [], connectorId, @@ -49,14 +50,14 @@ export const callAgentExecutor: AgentExecutor = async ({ request, size, traceOptions, + dataClients, }) => { const isOpenAI = llmType === 'openai'; const llmClass = isOpenAI ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel; const llm = new llmClass({ - actions, + actionsClient, connectorId, - request, llmType, logger, // possible client model override, @@ -72,6 +73,16 @@ export const callAgentExecutor: AgentExecutor = async ({ maxRetries: 0, }); + const anonymizationFieldsRes = + await dataClients?.anonymizationFieldsDataClient?.findDocuments({ + perPage: 1000, + page: 1, + }); + + const anonymizationFields = anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined; + const pastMessages = langChainMessages.slice(0, -1); // all but the last message const latestMessage = langChainMessages.slice(-1); // the last message diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/openai_functions_executor.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/openai_functions_executor.ts index 62b1d7ac7814f..6aa1aa3ce7890 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/openai_functions_executor.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/openai_functions_executor.ts @@ -25,7 +25,7 @@ export const OPEN_AI_FUNCTIONS_AGENT_EXECUTOR_ID = * NOTE: This is not to be used in production as-is, and must be used with an OpenAI ConnectorId */ export const callOpenAIFunctionsExecutor: AgentExecutor = async ({ - actions, + actionsClient, connectorId, esClient, esStore, @@ -36,9 +36,8 @@ export const callOpenAIFunctionsExecutor: AgentExecutor = async ({ traceOptions, }) => { const llm = new ActionsClientLlm({ - actions, + actionsClient, connectorId, - request, llmType, logger, model: request.body.model, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index bd07099e312b3..7af0b459f4bc9 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { ActionsClient } from '@kbn/actions-plugin/server'; import { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; import { BaseMessage } from '@langchain/core/messages'; import { Logger } from '@kbn/logging'; @@ -13,7 +13,7 @@ import { KibanaRequest, KibanaResponseFactory, ResponseHeaders } from '@kbn/core import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common'; import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; -import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { ResponseBody } from '../types'; import type { AssistantTool } from '../../../types'; import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store'; @@ -36,8 +36,7 @@ export interface AssistantDataClients { export interface AgentExecutorParams { abortSignal?: AbortSignal; alertsIndexPattern?: string; - actions: ActionsPluginStart; - anonymizationFields?: AnonymizationFieldResponse[]; + actionsClient: PublicMethodsOf; isEnabledKnowledgeBase: boolean; assistantTools?: AssistantTool[]; connectorId: string; @@ -56,6 +55,7 @@ export interface AgentExecutorParams { response?: KibanaResponseFactory; size?: number; traceOptions?: TraceOptions; + responseLanguage?: string; } export interface StaticReturnType { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 779bf20a61720..b16f7d9693e5f 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -14,11 +14,25 @@ import type { Logger } from '@kbn/logging'; import { BaseMessage } from '@langchain/core/messages'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { ConversationResponse, Replacements } from '@kbn/elastic-assistant-common'; import { AgentState, NodeParamsBase } from './types'; import { AssistantDataClients } from '../../executors/types'; -import { shouldContinue } from './nodes/should_continue'; +import { + shouldContinue, + shouldContinueGenerateTitle, + shouldContinueGetConversation, +} from './nodes/should_continue'; import { AGENT_NODE, runAgent } from './nodes/run_agent'; import { executeTools, TOOLS_NODE } from './nodes/execute_tools'; +import { GENERATE_CHAT_TITLE_NODE, generateChatTitle } from './nodes/generate_chat_title'; +import { + GET_PERSISTED_CONVERSATION_NODE, + getPersistedConversation, +} from './nodes/get_persisted_conversation'; +import { + PERSIST_CONVERSATION_CHANGES_NODE, + persistConversationChanges, +} from './nodes/persist_conversation_changes'; export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph'; @@ -28,8 +42,9 @@ interface GetDefaultAssistantGraphParams { conversationId?: string; llm: BaseChatModel; logger: Logger; - messages: BaseMessage[]; tools: StructuredTool[]; + responseLanguage: string; + replacements: Replacements; } export type DefaultAssistantGraph = ReturnType; @@ -43,8 +58,9 @@ export const getDefaultAssistantGraph = ({ dataClients, llm, logger, - messages, + responseLanguage, tools, + replacements, }: GetDefaultAssistantGraphParams) => { try { // Default graph state @@ -66,7 +82,16 @@ export const getDefaultAssistantGraph = ({ }, messages: { value: (x: BaseMessage[], y: BaseMessage[]) => x.concat(y), - default: () => messages, + default: () => [], + }, + chatTitle: { + value: (x: string, y?: string) => y ?? x, + default: () => '', + }, + conversation: { + value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) => + y ?? x, + default: () => undefined, }, }; @@ -94,19 +119,68 @@ export const getDefaultAssistantGraph = ({ state, tools, }); + const generateChatTitleNode = (state: AgentState) => + generateChatTitle({ + ...nodeParams, + state, + responseLanguage, + }); + + const getPersistedConversationNode = (state: AgentState) => + getPersistedConversation({ + ...nodeParams, + state, + conversationsDataClient: dataClients?.conversationsDataClient, + conversationId, + }); + + const persistConversationChangesNode = (state: AgentState) => + persistConversationChanges({ + ...nodeParams, + state, + conversationsDataClient: dataClients?.conversationsDataClient, + conversationId, + replacements, + }); const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state }); + const shouldContinueGenerateTitleEdge = (state: AgentState) => + shouldContinueGenerateTitle({ ...nodeParams, state }); + const shouldContinueGetConversationEdge = (state: AgentState) => + shouldContinueGetConversation({ ...nodeParams, state, conversationId }); // Put together a new graph using the nodes and default state from above - const graph = new StateGraph, '__start__' | 'agent' | 'tools'>({ + const graph = new StateGraph< + AgentState, + Partial, + | '__start__' + | 'agent' + | 'tools' + | 'generateChatTitle' + | 'getPersistedConversation' + | 'persistConversationChanges' + >({ channels: graphState, }); // Define the nodes to cycle between + graph.addNode(GET_PERSISTED_CONVERSATION_NODE, getPersistedConversationNode); + graph.addNode(GENERATE_CHAT_TITLE_NODE, generateChatTitleNode); + graph.addNode(PERSIST_CONVERSATION_CHANGES_NODE, persistConversationChangesNode); graph.addNode(AGENT_NODE, runAgentNode); graph.addNode(TOOLS_NODE, executeToolsNode); + + // Add edges, alternating between agent and action until finished + graph.addConditionalEdges(START, shouldContinueGetConversationEdge, { + continue: GET_PERSISTED_CONVERSATION_NODE, + end: AGENT_NODE, + }); + graph.addConditionalEdges(GET_PERSISTED_CONVERSATION_NODE, shouldContinueGenerateTitleEdge, { + continue: GENERATE_CHAT_TITLE_NODE, + end: PERSIST_CONVERSATION_CHANGES_NODE, + }); + graph.addEdge(GENERATE_CHAT_TITLE_NODE, PERSIST_CONVERSATION_CHANGES_NODE); + graph.addEdge(PERSIST_CONVERSATION_CHANGES_NODE, AGENT_NODE); // Add conditional edge for basic routing graph.addConditionalEdges(AGENT_NODE, shouldContinueEdge, { continue: TOOLS_NODE, end: END }); - // Add edges, alternating between agent and action until finished - graph.addEdge(START, AGENT_NODE); graph.addEdge(TOOLS_NODE, AGENT_NODE); // Compile the graph return graph.compile(); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 383b3e9f5cee8..482c89c10e969 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -11,6 +11,7 @@ import { streamFactory, StreamResponseWithHeaders } from '@kbn/ml-response-strea import { transformError } from '@kbn/securitysolution-es-utils'; import type { KibanaRequest } from '@kbn/core-http-server'; import type { ExecuteConnectorRequestBody, TraceData } from '@kbn/elastic-assistant-common'; +import { AGENT_NODE_TAG } from './nodes/run_agent'; import { DEFAULT_ASSISTANT_GRAPH_ID, DefaultAssistantGraph } from './graph'; import type { OnLlmResponse, TraceOptions } from '../../executors/types'; import type { APMTracer } from '../../tracers/apm_tracer'; @@ -91,27 +92,35 @@ export const streamGraph = async ({ if (done) return; const event = value; - if (event.event === 'on_llm_stream') { - const chunk = event.data?.chunk; - // TODO: For Bedrock streaming support, override `handleLLMNewToken` in callbacks, - // TODO: or maybe we can update ActionsClientSimpleChatModel to handle this `on_llm_stream` event - if (event.name === 'ActionsClientChatOpenAI') { - const msg = chunk.message; - - if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) { - /* empty */ - } else if (!didEnd) { - if (msg.response_metadata.finish_reason === 'stop') { - handleStreamEnd(finalMessage); - } else { - push({ payload: msg.content, type: 'content' }); - finalMessage += msg.content; + // only process events that are part of the agent run + if ((event.tags || []).includes(AGENT_NODE_TAG)) { + if (event.event === 'on_llm_stream') { + const chunk = event.data?.chunk; + // TODO: For Bedrock streaming support, override `handleLLMNewToken` in callbacks, + // TODO: or maybe we can update ActionsClientSimpleChatModel to handle this `on_llm_stream` event + if (event.name === 'ActionsClientChatOpenAI') { + const msg = chunk.message; + + if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) { + /* empty */ + } else if (!didEnd) { + if (msg.response_metadata.finish_reason === 'stop') { + handleStreamEnd(finalMessage); + } else { + push({ payload: msg.content, type: 'content' }); + finalMessage += msg.content; + } } } + } else if (event.event === 'on_llm_end') { + const generations = event.data.output?.generations[0]; + if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { + handleStreamEnd(finalMessage); + } } } - await processEvent(); + void processEvent(); } catch (err) { // if I throw an error here, it crashes the server. Not sure how to get around that. // If I put await on this function the error works properly, but when there is not an error @@ -129,7 +138,7 @@ export const streamGraph = async ({ }; // Start processing events, do not await! Return `responseWithHeaders` immediately - await processEvent(); + void processEvent(); return responseWithHeaders; }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 1e40f6b2fe127..517ac10479461 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -13,21 +13,22 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server'; import { createOpenAIFunctionsAgent, createStructuredChatAgent } from 'langchain/agents'; +import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types'; import { AssistantToolParams } from '../../../../types'; import { AgentExecutor } from '../../executors/types'; import { openAIFunctionAgentPrompt, structuredChatAgentPrompt } from './prompts'; import { APMTracer } from '../../tracers/apm_tracer'; import { getDefaultAssistantGraph } from './graph'; import { invokeGraph, streamGraph } from './helpers'; +import { transformESSearchToAnonymizationFields } from '../../../../ai_assistant_data_clients/anonymization_fields/helpers'; /** * Drop in replacement for the existing `callAgentExecutor` that uses LangGraph */ export const callAssistantGraph: AgentExecutor = async ({ abortSignal, - actions, + actionsClient, alertsIndexPattern, - anonymizationFields, isEnabledKnowledgeBase, assistantTools = [], connectorId, @@ -45,15 +46,15 @@ export const callAssistantGraph: AgentExecutor = async ({ request, size, traceOptions, + responseLanguage = 'English', }) => { const logger = parentLogger.get('defaultAssistantGraph'); const isOpenAI = llmType === 'openai'; const llmClass = isOpenAI ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel; const llm = new llmClass({ - actions, + actionsClient, connectorId, - request, llmType, logger, // possible client model override, @@ -68,15 +69,23 @@ export const callAssistantGraph: AgentExecutor = async ({ // failure could be due to bad connector, we should deliver that result to the client asap maxRetries: 0, }); - const model = llm; - const messages = langChainMessages.slice(0, -1); // all but the last message + const anonymizationFieldsRes = + await dataClients?.anonymizationFieldsDataClient?.findDocuments({ + perPage: 1000, + page: 1, + }); + + const anonymizationFields = anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined; + const latestMessage = langChainMessages.slice(-1); // the last message const modelExists = await esStore.isModelInstalled(); // Create a chain that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now - const chain = RetrievalQAChain.fromLLM(model, esStore.asRetriever(10)); + const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever(10)); // Fetch any applicable tools that the source plugin may have registered const assistantToolParams: AssistantToolParams = { @@ -86,7 +95,7 @@ export const callAssistantGraph: AgentExecutor = async ({ esClient, isEnabledKnowledgeBase, kbDataClient: dataClients?.kbDataClient, - llm: model, + llm, logger, modelExists, onNewReplacements, @@ -121,10 +130,11 @@ export const callAssistantGraph: AgentExecutor = async ({ dataClients, llm, logger, - messages, tools, + responseLanguage, + replacements, }); - const inputs = { input: latestMessage[0].content as string }; + const inputs = { input: latestMessage[0]?.content as string }; if (isStream) { return streamGraph({ apmTracer, assistantGraph, inputs, logger, onLlmResponse, request }); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts index bcba25eab0b0d..d1d60e9bed9b4 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts @@ -8,46 +8,45 @@ import { StringOutputParser } from '@langchain/core/output_parsers'; import { ChatPromptTemplate } from '@langchain/core/prompts'; import { AgentState, NodeParamsBase } from '../types'; -import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations'; -export const GENERATE_CHAT_TITLE_PROMPT = ChatPromptTemplate.fromMessages([ - [ - 'system', - `You are a helpful assistant for Elastic Security. Assume the following user message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. As an example, for the given MESSAGE, this is the TITLE: +export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string) => + ChatPromptTemplate.fromMessages([ + [ + 'system', + `You are a helpful assistant for Elastic Security. Assume the following user message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Please create the title in ${responseLanguage}. As an example, for the given MESSAGE, this is the TITLE: MESSAGE: I am having trouble with the Elastic Security app. TITLE: Troubleshooting Elastic Security app issues `, - ], - ['human', '{input}'], -]); + ], + ['human', '{input}'], + ]); export interface GenerateChatTitleParams extends NodeParamsBase { - conversationsDataClient?: AIAssistantConversationsDataClient; - conversationId?: string; + responseLanguage: string; state: AgentState; } export const GENERATE_CHAT_TITLE_NODE = 'generateChatTitle'; export const generateChatTitle = async ({ - conversationsDataClient, + responseLanguage, logger, model, state, }: GenerateChatTitleParams) => { logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`); + if (state.messages.length !== 0) { logger.debug('No need to generate chat title, messages already exist'); - return; + return { chatTitle: '' }; } const outputParser = new StringOutputParser(); - const graph = GENERATE_CHAT_TITLE_PROMPT.pipe(model).pipe(outputParser); + const graph = GENERATE_CHAT_TITLE_PROMPT(responseLanguage).pipe(model).pipe(outputParser); const chatTitle = await graph.invoke({ input: JSON.stringify(state.input, null, 2), }); - logger.debug(`chatTitle: ${chatTitle}`); return { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts new file mode 100644 index 0000000000000..6dbf284e462c4 --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { AgentState, NodeParamsBase } from '../types'; +import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations'; +import { getLangChainMessages } from '../../../helpers'; + +export interface GetPersistedConversationParams extends NodeParamsBase { + conversationsDataClient?: AIAssistantConversationsDataClient; + conversationId?: string; + state: AgentState; +} + +export const GET_PERSISTED_CONVERSATION_NODE = 'getPersistedConversation'; + +export const getPersistedConversation = async ({ + conversationsDataClient, + conversationId, + logger, + state, +}: GetPersistedConversationParams) => { + logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`); + if (!conversationId) { + logger.debug('Cannot get conversation, because conversationId is undefined'); + return { + conversation: undefined, + messages: [], + chatTitle: '', + input: state.input, + }; + } + + const conversation = await conversationsDataClient?.getConversation({ id: conversationId }); + if (!conversation) { + logger.debug('Requested conversation, because conversation is undefined'); + return { + conversation: undefined, + messages: [], + chatTitle: '', + input: state.input, + }; + } + + logger.debug(`conversationId: ${conversationId}`); + + const messages = getLangChainMessages(conversation.messages ?? []); + return { + conversation, + messages, + chatTitle: conversation.title, + input: !state.input ? conversation.messages?.slice(-1)[0].content : state.input, + }; +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts new file mode 100644 index 0000000000000..a86897e67adbf --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { + Replacements, + replaceAnonymizedValuesWithOriginalValues, +} from '@kbn/elastic-assistant-common'; +import { AgentState, NodeParamsBase } from '../types'; +import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations'; +import { getLangChainMessages } from '../../../helpers'; + +export interface PersistConversationChangesParams extends NodeParamsBase { + conversationsDataClient?: AIAssistantConversationsDataClient; + conversationId?: string; + state: AgentState; + replacements?: Replacements; +} + +export const PERSIST_CONVERSATION_CHANGES_NODE = 'persistConversationChanges'; + +export const persistConversationChanges = async ({ + conversationsDataClient, + conversationId, + logger, + state, + replacements = {}, +}: PersistConversationChangesParams) => { + logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`); + + if (!state.conversation || !conversationId) { + logger.debug('No need to generate chat title, conversationId is undefined'); + return { + conversation: undefined, + messages: [], + }; + } + + let conversation; + if (state.conversation?.title !== state.chatTitle) { + conversation = await conversationsDataClient?.updateConversation({ + conversationUpdateProps: { + id: conversationId, + title: state.chatTitle, + }, + }); + } + + const updatedConversation = await conversationsDataClient?.appendConversationMessages({ + existingConversation: conversation ? conversation : state.conversation, + messages: [ + { + content: replaceAnonymizedValuesWithOriginalValues({ + messageContent: state.input, + replacements, + }), + role: 'user', + timestamp: new Date().toISOString(), + }, + ], + }); + if (!updatedConversation) { + logger.debug('Not updated conversation'); + return { conversation: undefined, messages: [] }; + } + + logger.debug(`conversationId: ${conversationId}`); + const langChainMessages = getLangChainMessages(updatedConversation.messages ?? []); + const messages = langChainMessages.slice(0, -1); // all but the last message + + return { + conversation: updatedConversation, + messages, + }; +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts index b0353bb5d8ec7..0a6ee3b79087c 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts @@ -19,6 +19,8 @@ export interface RunAgentParams extends NodeParamsBase { export const AGENT_NODE = 'agent'; +export const AGENT_NODE_TAG = 'agent_run'; + const NO_HISTORY = '[No existing knowledge history]'; /** * Node to run the agent @@ -44,11 +46,11 @@ export const runAgent = async ({ query: '', }); - const agentOutcome = await agentRunnable.invoke( + const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke( { ...state, chat_history: state.messages, // TODO: Message de-dupe with ...state spread - knowledge_history: knowledgeHistory?.length ? knowledgeHistory : NO_HISTORY, + knowledge_history: JSON.stringify(knowledgeHistory?.length ? knowledgeHistory : NO_HISTORY), }, config ); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/should_continue.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/should_continue.ts index 281963df363a8..046c4a86d4c7a 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/should_continue.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/should_continue.ts @@ -5,6 +5,7 @@ * 2.0. */ +import { NEW_CHAT } from '../../../../../routes/helpers'; import { AgentState, NodeParamsBase } from '../types'; export interface ShouldContinueParams extends NodeParamsBase { @@ -26,3 +27,31 @@ export const shouldContinue = ({ logger, state }: ShouldContinueParams) => { return 'continue'; }; + +export const shouldContinueGenerateTitle = ({ logger, state }: ShouldContinueParams) => { + logger.debug(`Node state:\n${JSON.stringify(state, null, 2)}`); + + if (state.conversation?.title !== NEW_CHAT) { + return 'end'; + } + + return 'continue'; +}; + +export interface ShouldContinueGetConversation extends NodeParamsBase { + state: AgentState; + conversationId?: string; +} +export const shouldContinueGetConversation = ({ + logger, + state, + conversationId, +}: ShouldContinueGetConversation) => { + logger.debug(`Node state:\n${JSON.stringify(state, null, 2)}`); + + if (!conversationId) { + return 'end'; + } + + return 'continue'; +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts index 1d19646fb6eb3..4ee4f1ba1b148 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts @@ -9,6 +9,7 @@ import { BaseMessage } from '@langchain/core/messages'; import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { Logger } from '@kbn/logging'; +import { ConversationResponse } from '@kbn/elastic-assistant-common'; export interface AgentStateBase { agentOutcome?: AgentAction | AgentFinish; @@ -18,6 +19,8 @@ export interface AgentStateBase { export interface AgentState extends AgentStateBase { input: string; messages: BaseMessage[]; + chatTitle: string; + conversation: ConversationResponse | undefined; } export interface NodeParamsBase { diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.test.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.test.ts index bd8f7983921cb..15877e6727715 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.test.ts @@ -7,6 +7,8 @@ import { AuthenticatedUser } from '@kbn/core-security-common'; import moment from 'moment'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; + import { REQUIRED_FOR_ATTACK_DISCOVERY, addGenerationInterval, @@ -21,7 +23,6 @@ import { import { ActionsClientLlm } from '@kbn/langchain/server'; import { AttackDiscoveryDataClient } from '../../ai_assistant_data_clients/attack_discovery'; import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; import { loggerMock } from '@kbn/logging-mocks'; import { KibanaRequest } from '@kbn/core-http-server'; @@ -90,7 +91,6 @@ const mockApiConfig = { const mockCurrentAd = transformESSearchToAttackDiscovery(getAttackDiscoverySearchEsMock())[0]; -const mockActions: ActionsPluginStart = {} as ActionsPluginStart; // eslint-disable-next-line @typescript-eslint/no-explicit-any const mockRequest: KibanaRequest = {} as unknown as KibanaRequest< unknown, @@ -117,14 +117,14 @@ describe('helpers', () => { describe('getAssistantToolParams', () => { const alertsIndexPattern = '.alerts-security.alerts-default'; const esClient = elasticsearchClientMock.createElasticsearchClient(); + const actionsClient = actionsClientMock.create(); const langChainTimeout = 1000; const latestReplacements = {}; const llm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId: 'test-connecter-id', llmType: 'bedrock', logger: mockLogger, - request: mockRequest, temperature: 0, timeout: 580000, }); @@ -132,7 +132,7 @@ describe('helpers', () => { const size = 20; const mockParams = { - actions: {} as unknown as ActionsPluginStart, + actionsClient, alertsIndexPattern: 'alerts-*', anonymizationFields: [{ id: '1', field: 'field1', allowed: true, anonymized: true }], apiConfig: mockApiConfig, @@ -173,7 +173,7 @@ describe('helpers', () => { ]; const result = getAssistantToolParams({ - actions: mockParams.actions, + actionsClient, alertsIndexPattern, apiConfig: mockApiConfig, anonymizationFields, @@ -208,7 +208,7 @@ describe('helpers', () => { const anonymizationFields = undefined; const result = getAssistantToolParams({ - actions: mockParams.actions, + actionsClient, alertsIndexPattern, apiConfig: mockApiConfig, anonymizationFields, diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts index c3665d1583a3f..5f5eb8d0d8659 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts @@ -24,9 +24,10 @@ import { ActionsClientLlm } from '@kbn/langchain/server'; import { Moment } from 'moment'; import { transformError } from '@kbn/securitysolution-es-utils'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; import moment from 'moment/moment'; import { uniq } from 'lodash/fp'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { getLangSmithTracer } from '../evaluate/utils'; import { getLlmType } from '../utils'; import type { GetRegisteredTools } from '../../services/app_context'; @@ -53,7 +54,7 @@ export const REQUIRED_FOR_ATTACK_DISCOVERY: AnonymizationFieldResponse[] = [ ]; export const getAssistantToolParams = ({ - actions, + actionsClient, alertsIndexPattern, anonymizationFields, apiConfig, @@ -68,7 +69,7 @@ export const getAssistantToolParams = ({ request, size, }: { - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; alertsIndexPattern: string; anonymizationFields?: AnonymizationFieldResponse[]; apiConfig: ApiConfig; @@ -99,11 +100,10 @@ export const getAssistantToolParams = ({ }; const llm = new ActionsClientLlm({ - actions, + actionsClient, connectorId: apiConfig.connectorId, llmType: getLlmType(apiConfig.actionTypeId), logger, - request, temperature: 0, // zero temperature for attack discovery, because we want structured JSON output timeout: connectorTimeout, traceOptions, diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.test.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.test.ts index 9ecfb5c2af333..cbd3e6063fbd2 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.test.ts @@ -10,6 +10,7 @@ import { postAttackDiscoveryRoute } from './post_attack_discovery'; import { serverMock } from '../../__mocks__/server'; import { requestContextMock } from '../../__mocks__/request_context'; import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; +import { actionsMock } from '@kbn/actions-plugin/server/mocks'; import { AttackDiscoveryDataClient } from '../../ai_assistant_data_clients/attack_discovery'; import { transformESSearchToAttackDiscovery } from '../../ai_assistant_data_clients/attack_discovery/transforms'; import { getAttackDiscoverySearchEsMock } from '../../__mocks__/attack_discovery_schema.mock'; @@ -68,6 +69,7 @@ describe('postAttackDiscoveryRoute', () => { jest.clearAllMocks(); context.elasticAssistant.getCurrentUser.mockReturnValue(mockUser); context.elasticAssistant.getAttackDiscoveryDataClient.mockResolvedValue(mockDataClient); + context.elasticAssistant.actions = actionsMock.createStart(); postAttackDiscoveryRoute(server.router); findAttackDiscoveryByConnectorId.mockResolvedValue(mockCurrentAd); (getAssistantTool as jest.Mock).mockReturnValue({ getTool: jest.fn() }); diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts index 8ff2cd72ee36c..b9c680dde3d1d 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts @@ -70,6 +70,7 @@ export const postAttackDiscoveryRoute = ( try { // get the actions plugin start contract from the request context: const actions = (await context.elasticAssistant).actions; + const actionsClient = await actions.getActionsClientWithRequest(request); const dataClient = await assistantContext.getAttackDiscoveryDataClient(); const authenticatedUser = assistantContext.getCurrentUser(); if (authenticatedUser == null) { @@ -120,7 +121,7 @@ export const postAttackDiscoveryRoute = ( } const assistantToolParams = getAssistantToolParams({ - actions, + actionsClient, alertsIndexPattern, anonymizationFields, apiConfig, diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts new file mode 100644 index 0000000000000..a487e56019bd8 --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts @@ -0,0 +1,456 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { ElasticsearchClient, IRouter, KibanaRequest, Logger } from '@kbn/core/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; +import { BaseMessage } from '@langchain/core/messages'; +import { NEVER } from 'rxjs'; +import { mockActionResponse } from '../../__mocks__/action_result_data'; +import { ElasticAssistantRequestHandlerContext } from '../../types'; +import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; +import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; +import { coreMock } from '@kbn/core/server/mocks'; +import { + INVOKE_ASSISTANT_ERROR_EVENT, + INVOKE_ASSISTANT_SUCCESS_EVENT, +} from '../../lib/telemetry/event_based_telemetry'; +import { PassThrough } from 'stream'; +import { getConversationResponseMock } from '../../ai_assistant_data_clients/conversations/update_conversation.test'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; +import { getFindAnonymizationFieldsResultWithSingleHit } from '../../__mocks__/response'; +import { defaultAssistantFeatures } from '@kbn/elastic-assistant-common'; +import { chatCompleteRoute } from './chat_complete_route'; +import { PublicMethodsOf } from '@kbn/utility-types'; +import { licensingMock } from '@kbn/licensing-plugin/server/mocks'; + +const license = licensingMock.createLicenseMock(); + +const actionsClient = actionsClientMock.create(); +jest.mock('../../lib/build_response', () => ({ + buildResponse: jest.fn().mockImplementation((x) => x), +})); +const mockStream = jest.fn().mockImplementation(() => new PassThrough()); +jest.mock('../../lib/langchain/execute_custom_llm_chain', () => ({ + callAgentExecutor: jest.fn().mockImplementation( + async ({ + connectorId, + isStream, + onLlmResponse, + }: { + onLlmResponse: ( + content: string, + replacements: Record, + isError: boolean + ) => Promise; + actionsClient: PublicMethodsOf; + connectorId: string; + esClient: ElasticsearchClient; + langChainMessages: BaseMessage[]; + logger: Logger; + isStream: boolean; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + request: KibanaRequest; + }) => { + if (!isStream && connectorId === 'mock-connector-id') { + return { + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', + }; + } else if (isStream && connectorId === 'mock-connector-id') { + return mockStream; + } else { + onLlmResponse('simulated error', {}, true).catch(() => {}); + throw new Error('simulated error'); + } + } + ), +})); +const existingConversation = getConversationResponseMock(); +const reportEvent = jest.fn(); +const appendConversationMessages = jest.fn(); +const mockContext = { + resolve: jest.fn().mockResolvedValue({ + elasticAssistant: { + actions: { + getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient), + }, + getRegisteredTools: jest.fn(() => []), + getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), + logger: loggingSystemMock.createLogger(), + telemetry: { ...coreMock.createSetup().analytics, reportEvent }, + getCurrentUser: () => ({ + username: 'user', + email: 'email', + fullName: 'full name', + roles: ['user-role'], + enabled: true, + authentication_realm: { name: 'native1', type: 'native' }, + lookup_realm: { name: 'native1', type: 'native' }, + authentication_provider: { type: 'basic', name: 'basic1' }, + authentication_type: 'realm', + elastic_cloud_user: false, + metadata: { _reserved: false }, + }), + getAIAssistantConversationsDataClient: jest.fn().mockResolvedValue({ + getConversation: jest.fn().mockResolvedValue(existingConversation), + updateConversation: jest.fn().mockResolvedValue(existingConversation), + createConversation: jest.fn().mockResolvedValue(existingConversation), + appendConversationMessages: + appendConversationMessages.mockResolvedValue(existingConversation), + }), + getAIAssistantAnonymizationFieldsDataClient: jest.fn().mockResolvedValue({ + findDocuments: jest.fn().mockResolvedValue(getFindAnonymizationFieldsResultWithSingleHit()), + }), + }, + core: { + elasticsearch: { + client: elasticsearchServiceMock.createScopedClusterClient(), + }, + savedObjects: coreMock.createRequestHandlerContext().savedObjects, + }, + licensing: { + ...licensingMock.createRequestHandlerContext({ license }), + license, + }, + }), +}; + +const mockRequest = { + body: { + conversationId: 'mock-conversation-id', + connectorId: 'mock-connector-id', + persist: true, + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: false, + model: 'gpt-4', + messages: [ + { + role: 'user', + content: + "Evaluate the event from the context and format your output neatly in markdown syntax for my Elastic Security case.\nAdd your description, recommended actions and bulleted triage steps. Use the MITRE ATT&CK data provided to add more context and recommendations from MITRE, and hyperlink to the relevant pages on MITRE's website. Be sure to include the user and host risk score data from the context. Your response should include steps that point to Elastic Security specific features, including endpoint response actions, the Elastic Agent OSQuery manager integration (with example osquery queries), timelines and entity analytics and link to all the relevant Elastic Security documentation.", + data: { + 'event.category': 'process', + 'process.pid': 69516, + 'host.os.version': 14.5, + 'host.os.name': 'macOS', + 'host.name': 'Yuliias-MBP', + 'process.name': 'biomesyncd', + 'user.name': 'yuliianaumenko', + 'process.working_directory': '/', + 'event.module': 'system', + 'process.executable': '/usr/libexec/biomesyncd', + 'process.args': '/usr/libexec/biomesyncd', + }, + }, + ], + }, + events: { + aborted$: NEVER, + }, +}; + +const mockResponse = { + ok: jest.fn().mockImplementation((x) => x), + error: jest.fn().mockImplementation((x) => x), +}; + +describe('chatCompleteRoute', () => { + const mockGetElser = jest.fn().mockResolvedValue('.elser_model_2'); + + beforeEach(() => { + jest.clearAllMocks(); + license.hasAtLeast.mockReturnValue(true); + actionsClient.execute.mockImplementation( + jest.fn().mockResolvedValue(() => ({ + data: 'mockChatCompletion', + status: 'ok', + })) + ); + actionsClient.getBulk.mockResolvedValue([ + { + id: '1', + isPreconfigured: false, + isSystemAction: false, + isDeprecated: false, + name: 'my name', + actionTypeId: '.gen-ai', + isMissingSecrets: false, + config: { + a: true, + b: true, + c: true, + }, + }, + ]); + }); + + it('returns the expected response when using the existingConversation', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + conversationId: existingConversation.id, + }, + }, + mockResponse + ); + + expect(result).toEqual({ + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', + }); + }), + }; + }), + }, + }; + + chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected error when executeCustomLlmChain fails', async () => { + const requestWithBadConnectorId = { + ...mockRequest, + body: { + ...mockRequest.body, + connectorId: 'bad-connector-id', + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler(mockContext, requestWithBadConnectorId, mockResponse); + + expect(result).toEqual({ + body: 'simulated error', + statusCode: 500, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports success events to telemetry - kb on, RAG alerts off', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, mockRequest, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: false, + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports success events to telemetry - kb on, RAG alerts on', async () => { + const ragRequest = { + ...mockRequest, + body: { + ...mockRequest.body, + isEnabledRAGAlerts: true, + anonymizationFields: [ + { id: '@timestamp', field: '@timestamp', allowed: true, anonymized: false }, + { id: 'host.name', field: 'host.name', allowed: true, anonymized: true }, + ], + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, ragRequest, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: true, + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports error events to telemetry - kb on, RAG alerts off', async () => { + const requestWithBadConnectorId = { + ...mockRequest, + body: { + ...mockRequest.body, + connectorId: 'bad-connector-id', + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, requestWithBadConnectorId, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { + errorMessage: 'simulated error', + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: true, + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('Adds error to conversation history', async () => { + const badRequest = { + ...mockRequest, + body: { + ...mockRequest.body, + conversationId: '99999', + connectorId: 'bad-connector-id', + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, badRequest, mockResponse); + expect(appendConversationMessages.mock.calls[1][0].messages[0]).toEqual( + expect.objectContaining({ + content: 'simulated error', + isError: true, + role: 'assistant', + }) + ); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected response when isStream=true and actionTypeId=.gen-ai', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + isStream: true, + }, + }, + mockResponse + ); + + expect(result).toEqual(mockStream); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected response when isStream=true and actionTypeId=.bedrock', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + isStream: true, + }, + }, + mockResponse + ); + + expect(result).toEqual(mockStream); + }), + }; + }), + }, + }; + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); +}); diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts new file mode 100644 index 0000000000000..10da330a36c79 --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -0,0 +1,249 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { transformError } from '@kbn/securitysolution-es-utils'; +import { Logger } from '@kbn/core/server'; +import { + ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, + ChatCompleteProps, + API_VERSIONS, + Message, + Replacements, + transformRawData, + getAnonymizedValue, + ConversationResponse, +} from '@kbn/elastic-assistant-common'; +import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; +import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; +import { INVOKE_ASSISTANT_ERROR_EVENT } from '../../lib/telemetry/event_based_telemetry'; +import { ElasticAssistantPluginRouter, GetElser } from '../../types'; +import { buildResponse } from '../../lib/build_response'; +import { + DEFAULT_PLUGIN_NAME, + appendAssistantMessageToConversation, + createOrUpdateConversationWithUserInput, + getPluginNameFromRequest, + langChainExecute, + performChecks, +} from '../helpers'; +import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers'; +import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types'; + +export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => { + return `CONTEXT:\n"""\n${context}\n"""`; +}; + +export const chatCompleteRoute = ( + router: ElasticAssistantPluginRouter, + getElser: GetElser +): void => { + router.versioned + .post({ + access: 'public', + path: ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, + + options: { + tags: ['access:elasticAssistant'], + }, + }) + .addVersion( + { + version: API_VERSIONS.public.v1, + validate: { + request: { + body: buildRouteValidationWithZod(ChatCompleteProps), + }, + }, + }, + async (context, request, response) => { + const abortSignal = getRequestAbortedSignal(request.events.aborted$); + const assistantResponse = buildResponse(response); + let telemetry; + let actionTypeId; + try { + const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); + const logger: Logger = ctx.elasticAssistant.logger; + telemetry = ctx.elasticAssistant.telemetry; + + // Perform license and authenticated user checks + const checkResponse = performChecks({ + authenticatedUser: true, + context: ctx, + license: true, + request, + response, + }); + if (checkResponse) { + return checkResponse; + } + + const conversationsDataClient = + await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); + + const anonymizationFieldsDataClient = + await ctx.elasticAssistant.getAIAssistantAnonymizationFieldsDataClient(); + + let messages; + const conversationId = request.body.conversationId; + const connectorId = request.body.connectorId; + + let latestReplacements: Replacements = {}; + const onNewReplacements = (newReplacements: Replacements) => { + latestReplacements = { ...latestReplacements, ...newReplacements }; + }; + + // get the actions plugin start contract from the request context: + const actions = ctx.elasticAssistant.actions; + const actionsClient = await actions.getActionsClientWithRequest(request); + const connectors = await actionsClient.getBulk({ ids: [connectorId] }); + actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai'; + + // replacements + const anonymizationFieldsRes = + await anonymizationFieldsDataClient?.findDocuments({ + perPage: 1000, + page: 1, + }); + + let anonymizationFields = anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined; + + // anonymize messages before sending to LLM + messages = request.body.messages.map((m) => { + let content = m.content ?? ''; + if (m.data) { + // includes/anonymize fields from the messages data + if (m.fields_to_anonymize && m.fields_to_anonymize.length > 0) { + anonymizationFields = anonymizationFields?.map((a) => { + if (m.fields_to_anonymize?.includes(a.field)) { + return { + ...a, + allowed: true, + anonymized: true, + }; + } + return a; + }); + } + const anonymizedData = transformRawData({ + anonymizationFields, + currentReplacements: latestReplacements, + getAnonymizedValue, + onNewReplacements, + rawData: Object.keys(m.data).reduce( + (obj, key) => ({ ...obj, [key]: [m.data ? m.data[key] : ''] }), + {} + ), + }); + const wr = `${SYSTEM_PROMPT_CONTEXT_NON_I18N(anonymizedData)}\n`; + content = `${wr}\n${m.content}`; + } + const transformedMessage = { + role: m.role, + content, + }; + return transformedMessage; + }); + + let updatedConversation: ConversationResponse | undefined | null; + // Fetch any tools registered by the request's originating plugin + const pluginName = getPluginNameFromRequest({ + request, + defaultPluginName: DEFAULT_PLUGIN_NAME, + logger, + }); + const enableKnowledgeBaseByDefault = + ctx.elasticAssistant.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; + // TODO: remove non-graph persistance when KB will be enabled by default + if ( + (!enableKnowledgeBaseByDefault || (enableKnowledgeBaseByDefault && !conversationId)) && + request.body.persist && + conversationsDataClient + ) { + updatedConversation = await createOrUpdateConversationWithUserInput({ + actionsClient, + actionTypeId, + connectorId, + conversationId, + conversationsDataClient, + promptId: request.body.promptId, + logger, + replacements: latestReplacements, + newMessages: messages, + model: request.body.model, + }); + if (updatedConversation == null) { + return assistantResponse.error({ + body: `conversation id: "${conversationId}" not updated`, + statusCode: 400, + }); + } + // messages are anonymized by conversationsDataClient + messages = updatedConversation?.messages?.map((c) => ({ + role: c.role, + content: c.content, + })); + } + + const onLlmResponse = async ( + content: string, + traceData: Message['traceData'] = {}, + isError = false + ): Promise => { + if (updatedConversation?.id && conversationsDataClient) { + await appendAssistantMessageToConversation({ + conversationId: updatedConversation?.id, + conversationsDataClient, + messageContent: content, + replacements: latestReplacements, + isError, + traceData, + }); + } + }; + + return await langChainExecute({ + abortSignal, + isEnabledKnowledgeBase: true, + isStream: request.body.isStream ?? false, + actionsClient, + actionTypeId, + connectorId, + conversationId, + context: ctx, + getElser, + logger, + messages: messages ?? [], + onLlmResponse, + onNewReplacements, + replacements: latestReplacements, + request, + response, + telemetry, + responseLanguage: request.body.responseLanguage, + }); + } catch (err) { + const error = transformError(err as Error); + telemetry?.reportEvent(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { + actionTypeId: actionTypeId ?? '', + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: true, + model: request.body.model, + errorMessage: error.message, + // TODO rm actionTypeId check when llmClass for bedrock streaming is implemented + // tracked here: https://github.com/elastic/security-team/issues/7363 + assistantStreamingEnabled: request.body.isStream ?? false, + }); + return assistantResponse.error({ + body: error.message, + statusCode: error.statusCode, + }); + } + } + ); +}; diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index ef1950b5e90ad..990417b799234 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -192,7 +192,7 @@ export const postEvaluateRoute = ( agents.push({ agentEvaluator: async (langChainMessages, exampleId) => { const evalResult = await AGENT_EXECUTOR_MAP[agentName]({ - actions, + actionsClient, isEnabledKnowledgeBase: true, assistantTools, connectorId, @@ -237,9 +237,8 @@ export const postEvaluateRoute = ( evalModel == null || evalModel === '' ? undefined : new ActionsClientLlm({ - actions, + actionsClient, connectorId: evalModel, - request: skeletonRequest, logger, model: skeletonRequest.body.model, }); diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index 243de14d67ed3..aa060e24bc5df 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -5,15 +5,46 @@ * 2.0. */ -import { IKibanaResponse, KibanaRequest, KibanaResponseFactory } from '@kbn/core-http-server'; -import { Logger } from '@kbn/core/server'; -import { Message, TraceData } from '@kbn/elastic-assistant-common'; +import { + AnalyticsServiceSetup, + IKibanaResponse, + KibanaRequest, + KibanaResponseFactory, + Logger, +} from '@kbn/core/server'; +import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; + +import { + TraceData, + ConversationResponse, + ExecuteConnectorRequestBody, + Message, + Replacements, + replaceAnonymizedValuesWithOriginalValues, +} from '@kbn/elastic-assistant-common'; import { ILicense } from '@kbn/licensing-plugin/server'; -import { AwaitedProperties } from '@kbn/utility-types'; +import { i18n } from '@kbn/i18n'; +import { AwaitedProperties, PublicMethodsOf } from '@kbn/utility-types'; +import { ActionsClient } from '@kbn/actions-plugin/server'; import { AssistantFeatureKey } from '@kbn/elastic-assistant-common/impl/capabilities'; import { MINIMUM_AI_ASSISTANT_LICENSE } from '../../common/constants'; -import { ElasticAssistantRequestHandlerContext } from '../types'; -import { buildResponse } from './utils'; +import { ESQL_RESOURCE, KNOWLEDGE_BASE_INDEX_PATTERN } from './knowledge_base/constants'; +import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain'; +import { buildResponse, getLlmType } from './utils'; +import { + AgentExecutorParams, + AssistantDataClients, + StaticReturnType, +} from '../lib/langchain/executors/types'; +import { executeAction, StaticResponse } from '../lib/executor'; +import { getLangChainMessages } from '../lib/langchain/helpers'; + +import { getLangSmithTracer } from './evaluate/utils'; +import { ElasticsearchStore } from '../lib/langchain/elasticsearch_store/elasticsearch_store'; +import { AIAssistantConversationsDataClient } from '../ai_assistant_data_clients/conversations'; +import { INVOKE_ASSISTANT_SUCCESS_EVENT } from '../lib/telemetry/event_based_telemetry'; +import { ElasticAssistantRequestHandlerContext, GetElser } from '../types'; +import { callAssistantGraph } from '../lib/langchain/graphs/default_assistant_graph'; interface GetPluginNameFromRequestParams { request: KibanaRequest; @@ -23,6 +54,10 @@ interface GetPluginNameFromRequestParams { export const DEFAULT_PLUGIN_NAME = 'securitySolutionUI'; +export const NEW_CHAT = i18n.translate('xpack.elasticAssistantPlugin.server.newChat', { + defaultMessage: 'New chat', +}); + /** * Attempts to extract the plugin name the request originated from using the request headers. * @@ -92,6 +127,495 @@ export const hasAIAssistantLicense = (license: ILicense): boolean => export const UPGRADE_LICENSE_MESSAGE = 'Your license does not support AI Assistant. Please upgrade your license.'; +export interface GenerateTitleForNewChatConversationParams { + message: Pick; + model?: string; + actionTypeId: string; + connectorId: string; + logger: Logger; + actionsClient: PublicMethodsOf; + responseLanguage?: string; +} +export const generateTitleForNewChatConversation = async ({ + message, + model, + actionTypeId, + connectorId, + logger, + actionsClient, + responseLanguage = 'English', +}: GenerateTitleForNewChatConversationParams) => { + try { + const autoTitle = (await executeAction({ + actionsClient, + connectorId, + actionTypeId, + params: { + subAction: 'invokeAI', + subActionParams: { + model, + messages: [ + { + role: 'system', + content: `You are a helpful assistant for Elastic Security. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Please create the title in ${responseLanguage}.`, + }, + { + role: message.role, + content: message.content, + }, + ], + ...(actionTypeId === '.gen-ai' + ? { n: 1, stop: null, temperature: 0.2 } + : { temperature: 0, stopSequences: [] }), + }, + }, + logger, + })) as unknown as StaticResponse; // TODO: Use function overloads in executeAction to avoid this cast when sending subAction: 'invokeAI', + if (autoTitle.status === 'ok') { + // This regular expression captures a string enclosed in single or double quotes. + // It extracts the string content without the quotes. + // Example matches: + // - "Hello, World!" => Captures: Hello, World! + // - 'Another Example' => Captures: Another Example + // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes + const match = autoTitle.data.match(/^["']?([^"']+)["']?$/); + const title = match ? match[1] : autoTitle.data; + return title; + } + } catch (e) { + /* empty */ + } +}; + +export interface AppendMessageToConversationParams { + conversationsDataClient: AIAssistantConversationsDataClient; + messages: Array>; + replacements: Replacements; + conversation: ConversationResponse; +} +export const appendMessageToConversation = async ({ + conversationsDataClient, + messages, + replacements, + conversation, +}: AppendMessageToConversationParams) => { + const updatedConversation = await conversationsDataClient?.appendConversationMessages({ + existingConversation: conversation, + messages: messages.map((m) => ({ + ...{ + content: replaceAnonymizedValuesWithOriginalValues({ + messageContent: m.content, + replacements, + }), + role: m.role ?? 'user', + }, + timestamp: new Date().toISOString(), + })), + }); + return updatedConversation; +}; + +export interface AppendAssistantMessageToConversationParams { + conversationsDataClient: AIAssistantConversationsDataClient; + messageContent: string; + replacements: Replacements; + conversationId: string; + isError?: boolean; + traceData?: Message['traceData']; +} +export const appendAssistantMessageToConversation = async ({ + conversationsDataClient, + messageContent, + replacements, + conversationId, + isError = false, + traceData = {}, +}: AppendAssistantMessageToConversationParams) => { + const conversation = await conversationsDataClient.getConversation({ id: conversationId }); + if (!conversation) { + return; + } + + await conversationsDataClient.appendConversationMessages({ + existingConversation: conversation, + messages: [ + getMessageFromRawResponse({ + rawContent: replaceAnonymizedValuesWithOriginalValues({ + messageContent, + replacements, + }), + traceData, + isError, + }), + ], + }); + if (Object.keys(replacements).length > 0) { + await conversationsDataClient?.updateConversation({ + conversationUpdateProps: { + id: conversation.id, + replacements, + }, + }); + } +}; + +export interface NonLangChainExecuteParams { + request: KibanaRequest; + messages: Array>; + abortSignal: AbortSignal; + actionTypeId: string; + connectorId: string; + logger: Logger; + actionsClient: PublicMethodsOf; + onLlmResponse?: ( + content: string, + traceData?: Message['traceData'], + isError?: boolean + ) => Promise; + response: KibanaResponseFactory; + telemetry: AnalyticsServiceSetup; +} +export const nonLangChainExecute = async ({ + messages, + abortSignal, + actionTypeId, + connectorId, + logger, + actionsClient, + onLlmResponse, + response, + request, + telemetry, +}: NonLangChainExecuteParams) => { + logger.debug('Executing via actions framework directly'); + const result = await executeAction({ + abortSignal, + onLlmResponse, + actionsClient, + connectorId, + actionTypeId, + params: { + subAction: request.body.subAction, + subActionParams: { + model: request.body.model, + messages, + ...(actionTypeId === '.gen-ai' + ? { n: 1, stop: null, temperature: 0.2 } + : { temperature: 0, stopSequences: [] }), + }, + }, + logger, + }); + + telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + actionTypeId, + isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase, + isEnabledRAGAlerts: request.body.isEnabledRAGAlerts, + model: request.body.model, + assistantStreamingEnabled: request.body.subAction !== 'invokeAI', + }); + return response.ok({ + body: result, + ...(request.body.subAction === 'invokeAI' + ? { headers: { 'content-type': 'application/json' } } + : {}), + }); +}; + +export interface LangChainExecuteParams { + messages: Array>; + replacements: Replacements; + isEnabledKnowledgeBase: boolean; + isStream?: boolean; + onNewReplacements: (newReplacements: Replacements) => void; + abortSignal: AbortSignal; + telemetry: AnalyticsServiceSetup; + actionTypeId: string; + connectorId: string; + conversationId?: string; + context: AwaitedProperties< + Pick + >; + actionsClient: PublicMethodsOf; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + request: KibanaRequest; + logger: Logger; + onLlmResponse?: ( + content: string, + traceData?: Message['traceData'], + isError?: boolean + ) => Promise; + getElser: GetElser; + response: KibanaResponseFactory; + responseLanguage?: string; +} +export const langChainExecute = async ({ + messages, + replacements, + onNewReplacements, + isEnabledKnowledgeBase, + abortSignal, + telemetry, + actionTypeId, + connectorId, + context, + actionsClient, + request, + logger, + conversationId, + onLlmResponse, + getElser, + response, + responseLanguage, + isStream = true, +}: LangChainExecuteParams) => { + // TODO: Add `traceId` to actions request when calling via langchain + logger.debug( + `Executing via langchain, isEnabledKnowledgeBase: ${isEnabledKnowledgeBase}, isEnabledRAGAlerts: ${request.body.isEnabledRAGAlerts}` + ); + // Fetch any tools registered by the request's originating plugin + const pluginName = getPluginNameFromRequest({ + request, + defaultPluginName: DEFAULT_PLUGIN_NAME, + logger, + }); + const assistantContext = context.elasticAssistant; + const assistantTools = assistantContext + .getRegisteredTools(pluginName) + .filter((x) => x.id !== 'attack-discovery'); // We don't (yet) support asking the assistant for NEW attack discoveries from a conversation + + // get a scoped esClient for assistant memory + const esClient = context.core.elasticsearch.client.asCurrentUser; + + // convert the assistant messages to LangChain messages: + const langChainMessages = getLangChainMessages(messages); + + const elserId = await getElser(); + + const anonymizationFieldsDataClient = + await assistantContext.getAIAssistantAnonymizationFieldsDataClient(); + const conversationsDataClient = await assistantContext.getAIAssistantConversationsDataClient(); + + // Create an ElasticsearchStore for KB interactions + // Setup with kbDataClient if `assistantKnowledgeBaseByDefault` FF is enabled + const enableKnowledgeBaseByDefault = + assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; + const kbDataClient = enableKnowledgeBaseByDefault + ? (await assistantContext.getAIAssistantKnowledgeBaseDataClient(false)) ?? undefined + : undefined; + const kbIndex = + enableKnowledgeBaseByDefault && kbDataClient != null + ? kbDataClient.indexTemplateAndPattern.alias + : KNOWLEDGE_BASE_INDEX_PATTERN; + const esStore = new ElasticsearchStore( + esClient, + kbIndex, + logger, + telemetry, + elserId, + ESQL_RESOURCE, + kbDataClient + ); + + const dataClients: AssistantDataClients = { + anonymizationFieldsDataClient: anonymizationFieldsDataClient ?? undefined, + conversationsDataClient: conversationsDataClient ?? undefined, + kbDataClient, + }; + + // Shared executor params + const executorParams: AgentExecutorParams = { + abortSignal, + dataClients, + alertsIndexPattern: request.body.alertsIndexPattern, + actionsClient, + isEnabledKnowledgeBase, + assistantTools, + conversationId, + connectorId, + esClient, + esStore, + isStream, + llmType: getLlmType(actionTypeId), + langChainMessages, + logger, + onNewReplacements, + onLlmResponse, + request, + replacements, + responseLanguage, + size: request.body.size, + traceOptions: { + projectName: request.body.langSmithProject, + tracers: getLangSmithTracer({ + apiKey: request.body.langSmithApiKey, + projectName: request.body.langSmithProject, + logger, + }), + }, + }; + + // New code path for LangGraph implementation, behind `assistantKnowledgeBaseByDefault` FF + let result: StreamResponseWithHeaders | StaticReturnType; + if (enableKnowledgeBaseByDefault && request.body.isEnabledKnowledgeBase) { + result = await callAssistantGraph(executorParams); + } else { + result = await callAgentExecutor(executorParams); + } + + telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + actionTypeId, + isEnabledKnowledgeBase, + isEnabledRAGAlerts: request.body.isEnabledRAGAlerts ?? true, + model: request.body.model, + // TODO rm actionTypeId check when llmClass for bedrock streaming is implemented + // tracked here: https://github.com/elastic/security-team/issues/7363 + assistantStreamingEnabled: isStream && actionTypeId === '.gen-ai', + }); + return response.ok(result); +}; + +export interface CreateOrUpdateConversationWithParams { + logger: Logger; + conversationsDataClient: AIAssistantConversationsDataClient; + replacements: Replacements; + conversationId?: string; + promptId?: string; + actionTypeId: string; + connectorId: string; + actionsClient: PublicMethodsOf; + newMessages?: Array>; + model?: string; + responseLanguage?: string; +} +export const createOrUpdateConversationWithUserInput = async ({ + logger, + conversationsDataClient, + replacements, + conversationId, + actionTypeId, + promptId, + connectorId, + actionsClient, + newMessages, + model, + responseLanguage, +}: CreateOrUpdateConversationWithParams) => { + if (!conversationId) { + if (newMessages && newMessages.length > 0) { + const title = await generateTitleForNewChatConversation({ + message: newMessages[0], + actionsClient, + actionTypeId, + connectorId, + logger, + model, + responseLanguage, + }); + if (title) { + return conversationsDataClient.createConversation({ + conversation: { + title, + messages: newMessages.map((m) => ({ + content: m.content, + role: m.role, + timestamp: new Date().toISOString(), + })), + replacements, + apiConfig: { + connectorId, + actionTypeId, + model, + defaultSystemPromptId: promptId, + }, + }, + }); + } + } + return; + } + return updateConversationWithUserInput({ + actionsClient, + actionTypeId, + connectorId, + conversationId, + conversationsDataClient, + logger, + replacements, + newMessages, + model, + }); +}; + +export interface UpdateConversationWithParams { + logger: Logger; + conversationsDataClient: AIAssistantConversationsDataClient; + replacements: Replacements; + conversationId: string; + actionTypeId: string; + connectorId: string; + actionsClient: PublicMethodsOf; + newMessages?: Array>; + model?: string; +} +export const updateConversationWithUserInput = async ({ + logger, + conversationsDataClient, + replacements, + conversationId, + actionTypeId, + connectorId, + actionsClient, + newMessages, + model, +}: UpdateConversationWithParams) => { + const conversation = await conversationsDataClient?.getConversation({ + id: conversationId, + }); + if (conversation == null) { + throw new Error(`conversation id: "${conversationId}" not found`); + } + let updatedConversation = conversation; + + const messages = updatedConversation?.messages?.map((c) => ({ + role: c.role, + content: c.content, + timestamp: c.timestamp, + })); + + const lastMessage = newMessages?.[0] ?? messages?.[0]; + + if (conversation?.title === NEW_CHAT && lastMessage) { + const title = await generateTitleForNewChatConversation({ + message: lastMessage, + actionsClient, + actionTypeId, + connectorId, + logger, + model, + }); + const res = await conversationsDataClient.updateConversation({ + conversationUpdateProps: { + id: conversationId, + title, + }, + }); + if (res) { + updatedConversation = res; + } + } + + if (newMessages) { + return appendMessageToConversation({ + conversation: updatedConversation, + conversationsDataClient, + messages: newMessages, + replacements, + }); + } + return updatedConversation; +}; + interface PerformChecksParams { authenticatedUser?: boolean; capability?: AssistantFeatureKey; diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts index 5ee8d8e83c846..91c2cdf18aa9d 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts @@ -24,7 +24,9 @@ import { getConversationResponseMock } from '../ai_assistant_data_clients/conver import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; import { getFindAnonymizationFieldsResultWithSingleHit } from '../__mocks__/response'; import { defaultAssistantFeatures } from '@kbn/elastic-assistant-common'; +import { licensingMock } from '@kbn/licensing-plugin/server/mocks'; +const license = licensingMock.createLicenseMock(); const actionsClient = actionsClientMock.create(); jest.mock('../lib/build_response', () => ({ buildResponse: jest.fn().mockImplementation((x) => x), @@ -88,43 +90,49 @@ const existingConversation = getConversationResponseMock(); const reportEvent = jest.fn(); const appendConversationMessages = jest.fn(); const mockContext = { - elasticAssistant: { - actions: { - getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient), + resolve: jest.fn().mockResolvedValue({ + elasticAssistant: { + actions: { + getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient), + }, + getRegisteredTools: jest.fn(() => []), + getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), + logger: loggingSystemMock.createLogger(), + telemetry: { ...coreMock.createSetup().analytics, reportEvent }, + getCurrentUser: () => ({ + username: 'user', + email: 'email', + fullName: 'full name', + roles: ['user-role'], + enabled: true, + authentication_realm: { name: 'native1', type: 'native' }, + lookup_realm: { name: 'native1', type: 'native' }, + authentication_provider: { type: 'basic', name: 'basic1' }, + authentication_type: 'realm', + elastic_cloud_user: false, + metadata: { _reserved: false }, + }), + getAIAssistantConversationsDataClient: jest.fn().mockResolvedValue({ + getConversation: jest.fn().mockResolvedValue(existingConversation), + updateConversation: jest.fn().mockResolvedValue(existingConversation), + appendConversationMessages: + appendConversationMessages.mockResolvedValue(existingConversation), + }), + getAIAssistantAnonymizationFieldsDataClient: jest.fn().mockResolvedValue({ + findDocuments: jest.fn().mockResolvedValue(getFindAnonymizationFieldsResultWithSingleHit()), + }), }, - getRegisteredTools: jest.fn(() => []), - getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), - logger: loggingSystemMock.createLogger(), - telemetry: { ...coreMock.createSetup().analytics, reportEvent }, - getCurrentUser: () => ({ - username: 'user', - email: 'email', - fullName: 'full name', - roles: ['user-role'], - enabled: true, - authentication_realm: { name: 'native1', type: 'native' }, - lookup_realm: { name: 'native1', type: 'native' }, - authentication_provider: { type: 'basic', name: 'basic1' }, - authentication_type: 'realm', - elastic_cloud_user: false, - metadata: { _reserved: false }, - }), - getAIAssistantConversationsDataClient: jest.fn().mockResolvedValue({ - getConversation: jest.fn().mockResolvedValue(existingConversation), - updateConversation: jest.fn().mockResolvedValue(existingConversation), - appendConversationMessages: - appendConversationMessages.mockResolvedValue(existingConversation), - }), - getAIAssistantAnonymizationFieldsDataClient: jest.fn().mockResolvedValue({ - findDocuments: jest.fn().mockResolvedValue(getFindAnonymizationFieldsResultWithSingleHit()), - }), - }, - core: { - elasticsearch: { - client: elasticsearchServiceMock.createScopedClusterClient(), + core: { + elasticsearch: { + client: elasticsearchServiceMock.createScopedClusterClient(), + }, + savedObjects: coreMock.createRequestHandlerContext().savedObjects, }, - savedObjects: coreMock.createRequestHandlerContext().savedObjects, - }, + licensing: { + ...licensingMock.createRequestHandlerContext({ license }), + license, + }, + }), }; const mockRequest = { @@ -153,6 +161,7 @@ describe('postActionsConnectorExecuteRoute', () => { beforeEach(() => { jest.clearAllMocks(); + license.hasAtLeast.mockReturnValue(true); actionsClient.getBulk.mockResolvedValue([ { id: '1', @@ -195,6 +204,7 @@ describe('postActionsConnectorExecuteRoute', () => { data: mockActionResponse, status: 'ok', }, + headers: { 'content-type': 'application/json' }, }); }), }; diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 197479fc24dd5..af095dbb47343 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -8,7 +8,6 @@ import { IRouter, Logger } from '@kbn/core/server'; import { transformError } from '@kbn/securitysolution-es-utils'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; -import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { schema } from '@kbn/config-schema'; import { @@ -16,37 +15,20 @@ import { ExecuteConnectorRequestBody, Message, Replacements, - replaceAnonymizedValuesWithOriginalValues, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; -import { i18n } from '@kbn/i18n'; -import { getLlmType } from './utils'; -import { - AgentExecutorParams, - AssistantDataClients, - StaticReturnType, -} from '../lib/langchain/executors/types'; -import { - INVOKE_ASSISTANT_ERROR_EVENT, - INVOKE_ASSISTANT_SUCCESS_EVENT, -} from '../lib/telemetry/event_based_telemetry'; -import { executeAction, StaticResponse } from '../lib/executor'; +import { INVOKE_ASSISTANT_ERROR_EVENT } from '../lib/telemetry/event_based_telemetry'; import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants'; -import { getLangChainMessages } from '../lib/langchain/helpers'; import { buildResponse } from '../lib/build_response'; import { ElasticAssistantRequestHandlerContext, GetElser } from '../types'; -import { ESQL_RESOURCE, KNOWLEDGE_BASE_INDEX_PATTERN } from './knowledge_base/constants'; -import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain'; import { DEFAULT_PLUGIN_NAME, - getMessageFromRawResponse, + appendAssistantMessageToConversation, getPluginNameFromRequest, + langChainExecute, + nonLangChainExecute, + updateConversationWithUserInput, } from './helpers'; -import { getLangSmithTracer } from './evaluate/utils'; -import { EsAnonymizationFieldsSchema } from '../ai_assistant_data_clients/anonymization_fields/types'; -import { transformESSearchToAnonymizationFields } from '../ai_assistant_data_clients/anonymization_fields/helpers'; -import { ElasticsearchStore } from '../lib/langchain/elasticsearch_store/elasticsearch_store'; -import { callAssistantGraph } from '../lib/langchain/graphs/default_assistant_graph'; export const postActionsConnectorExecuteRoute = ( router: IRouter, @@ -76,7 +58,8 @@ export const postActionsConnectorExecuteRoute = ( const abortSignal = getRequestAbortedSignal(request.events.aborted$); const resp = buildResponse(response); - const assistantContext = await context.elasticAssistant; + const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); + const assistantContext = ctx.elasticAssistant; const logger: Logger = assistantContext.logger; const telemetry = assistantContext.telemetry; let onLlmResponse; @@ -88,23 +71,16 @@ export const postActionsConnectorExecuteRoute = ( body: `Authenticated user not found`, }); } - const conversationsDataClient = - await assistantContext.getAIAssistantConversationsDataClient(); - - const anonymizationFieldsDataClient = - await assistantContext.getAIAssistantAnonymizationFieldsDataClient(); - let latestReplacements: Replacements = request.body.replacements; const onNewReplacements = (newReplacements: Replacements) => { latestReplacements = { ...latestReplacements, ...newReplacements }; }; - let prevMessages; + let messages; let newMessage: Pick | undefined; const conversationId = request.body.conversationId; const actionTypeId = request.body.actionTypeId; - const langSmithProject = request.body.langSmithProject; - const langSmithApiKey = request.body.langSmithApiKey; + const connectorId = decodeURIComponent(request.params.connectorId); // if message is undefined, it means the user is regenerating a message from the stored conversation if (request.body.message) { @@ -114,303 +90,100 @@ export const postActionsConnectorExecuteRoute = ( }; } - const connectorId = decodeURIComponent(request.params.connectorId); - // get the actions plugin start contract from the request context: - const actions = (await context.elasticAssistant).actions; + const actions = ctx.elasticAssistant.actions; + const actionsClient = await actions.getActionsClientWithRequest(request); - if (conversationId) { - const conversation = await conversationsDataClient?.getConversation({ - id: conversationId, - authenticatedUser, + const conversationsDataClient = + await assistantContext.getAIAssistantConversationsDataClient(); + + // Fetch any tools registered by the request's originating plugin + const pluginName = getPluginNameFromRequest({ + request, + defaultPluginName: DEFAULT_PLUGIN_NAME, + logger, + }); + const isGraphAvailable = + assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault && + request.body.isEnabledKnowledgeBase; + + // TODO: remove non-graph persistance when KB will be enabled by default + if (!isGraphAvailable && conversationId && conversationsDataClient) { + const updatedConversation = await updateConversationWithUserInput({ + actionsClient, + actionTypeId, + connectorId, + conversationId, + conversationsDataClient, + logger, + replacements: latestReplacements, + newMessages: newMessage ? [newMessage] : [], + model: request.body.model, }); - if (conversation == null) { - return response.notFound({ - body: `conversation id: "${conversationId}" not found`, + if (updatedConversation == null) { + return response.badRequest({ + body: `conversation id: "${conversationId}" not updated`, }); } - // messages are anonymized by conversationsDataClient - prevMessages = conversation?.messages?.map((c) => ({ + messages = updatedConversation?.messages?.map((c) => ({ role: c.role, content: c.content, })); + } - if (request.body.message) { - const res = await conversationsDataClient?.appendConversationMessages({ - existingConversation: conversation, - messages: [ - { - ...{ - content: replaceAnonymizedValuesWithOriginalValues({ - messageContent: request.body.message, - replacements: request.body.replacements, - }), - role: 'user', - }, - timestamp: new Date().toISOString(), - }, - ], - }); - - if (res == null) { - return response.badRequest({ - body: `conversation id: "${conversationId}" not updated`, - }); - } - } - const updatedConversation = await conversationsDataClient?.getConversation({ - id: conversationId, - authenticatedUser, - }); - - if (updatedConversation == null) { - return response.notFound({ - body: `conversation id: "${conversationId}" not found`, + onLlmResponse = async ( + content: string, + traceData: Message['traceData'] = {}, + isError = false + ): Promise => { + if (conversationsDataClient && conversationId) { + await appendAssistantMessageToConversation({ + conversationId, + conversationsDataClient, + messageContent: content, + replacements: latestReplacements, + isError, + traceData, }); } + }; - const NEW_CHAT = i18n.translate('xpack.elasticAssistantPlugin.server.newChat', { - defaultMessage: 'New chat', - }); - if (conversation?.title === NEW_CHAT && prevMessages) { - try { - const autoTitle = (await executeAction({ - actions, - request, - connectorId, - actionTypeId, - params: { - subAction: 'invokeAI', - subActionParams: { - model: request.body.model, - messages: [ - { - role: 'system', - content: i18n.translate( - 'xpack.elasticAssistantPlugin.server.autoTitlePromptDescription', - { - defaultMessage: - 'You are a helpful assistant for Elastic Security. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you.', - } - ), - }, - newMessage ?? prevMessages?.[0], - ], - ...(actionTypeId === '.gen-ai' - ? { n: 1, stop: null, temperature: 0.2 } - : { temperature: 0, stopSequences: [] }), - }, - }, - logger, - })) as unknown as StaticResponse; // TODO: Use function overloads in executeAction to avoid this cast when sending subAction: 'invokeAI', - if (autoTitle.status === 'ok') { - try { - // This regular expression captures a string enclosed in single or double quotes. - // It extracts the string content without the quotes. - // Example matches: - // - "Hello, World!" => Captures: Hello, World! - // - 'Another Example' => Captures: Another Example - // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes - const match = autoTitle.data.match(/^["']?([^"']+)["']?$/); - const title = match ? match[1] : autoTitle.data; - - await conversationsDataClient?.updateConversation({ - conversationUpdateProps: { - id: conversationId, - title, - }, - }); - } catch (e) { - logger.warn(`Failed to update conversation with generated title: ${e.message}`); - } - } - } catch (e) { - /* empty */ - } - } - - onLlmResponse = async ( - content: string, - traceData: Message['traceData'] = {}, - isError = false - ): Promise => { - if (updatedConversation) { - await conversationsDataClient?.appendConversationMessages({ - existingConversation: updatedConversation, - messages: [ - getMessageFromRawResponse({ - rawContent: replaceAnonymizedValuesWithOriginalValues({ - messageContent: content, - replacements: latestReplacements, - }), - traceData, - isError, - }), - ], - }); - } - if (Object.keys(latestReplacements).length > 0) { - await conversationsDataClient?.updateConversation({ - conversationUpdateProps: { - id: conversationId, - replacements: latestReplacements, - }, - }); - } - }; - } - - // if not langchain, call execute action directly and return the response: if (!request.body.isEnabledKnowledgeBase && !request.body.isEnabledRAGAlerts) { - logger.debug('Executing via actions framework directly'); - - const result = await executeAction({ + // if not langchain, call execute action directly and return the response: + return await nonLangChainExecute({ abortSignal, - onLlmResponse, - actions, - request, - connectorId, + actionsClient, actionTypeId, - params: { - subAction: request.body.subAction, - subActionParams: { - model: request.body.model, - messages: [...(prevMessages ?? []), ...(newMessage ? [newMessage] : [])], - ...(actionTypeId === '.gen-ai' - ? { n: 1, stop: null, temperature: 0.2 } - : { temperature: 0, stopSequences: [] }), - }, - }, + connectorId, logger, - }); - - telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { - actionTypeId, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase, - isEnabledRAGAlerts: request.body.isEnabledRAGAlerts, - model: request.body.model, - assistantStreamingEnabled: request.body.subAction !== 'invokeAI', - }); - return response.ok({ - body: result, + messages: messages ?? [], + onLlmResponse, + request, + response, + telemetry, }); } - // TODO: Add `traceId` to actions request when calling via langchain - logger.debug( - `Executing via langchain, isEnabledKnowledgeBase: ${request.body.isEnabledKnowledgeBase}, isEnabledRAGAlerts: ${request.body.isEnabledRAGAlerts}` - ); - - // Fetch any tools registered by the request's originating plugin - const pluginName = getPluginNameFromRequest({ - request, - defaultPluginName: DEFAULT_PLUGIN_NAME, - logger, - }); - const assistantTools = (await context.elasticAssistant) - .getRegisteredTools(pluginName) - .filter((x) => x.id !== 'attack-discovery'); // We don't (yet) support asking the assistant for NEW attack discoveries from a conversation - - // get a scoped esClient for assistant memory - const esClient = (await context.core).elasticsearch.client.asCurrentUser; - - // convert the assistant messages to LangChain messages: - const langChainMessages = getLangChainMessages( - ([...(prevMessages ?? []), ...(newMessage ? [newMessage] : [])] ?? - []) as unknown as Array> - ); - - const elserId = await getElser(); - - const anonymizationFieldsRes = - await anonymizationFieldsDataClient?.findDocuments({ - perPage: 1000, - page: 1, - }); - - // Create an ElasticsearchStore for KB interactions - // Setup with kbDataClient if `assistantKnowledgeBaseByDefault` FF is enabled - const enableKnowledgeBaseByDefault = - assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; - const kbDataClient = enableKnowledgeBaseByDefault - ? (await assistantContext.getAIAssistantKnowledgeBaseDataClient(false)) ?? undefined - : undefined; - const kbIndex = - enableKnowledgeBaseByDefault && kbDataClient != null - ? kbDataClient.indexTemplateAndPattern.alias - : KNOWLEDGE_BASE_INDEX_PATTERN; - const esStore = new ElasticsearchStore( - esClient, - kbIndex, - logger, - telemetry, - elserId, - ESQL_RESOURCE, - kbDataClient - ); - - const dataClients: AssistantDataClients = { - anonymizationFieldsDataClient: anonymizationFieldsDataClient ?? undefined, - conversationsDataClient: conversationsDataClient ?? undefined, - kbDataClient, - }; - - // Shared executor params - const executorParams: AgentExecutorParams = { + return await langChainExecute({ abortSignal, - alertsIndexPattern: request.body.alertsIndexPattern, - anonymizationFields: anonymizationFieldsRes - ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) - : undefined, - actions, + isStream: request.body.subAction !== 'invokeAI', isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, - assistantTools, + actionsClient, + actionTypeId, connectorId, conversationId, - dataClients, - esClient, - esStore, - isStream: request.body.subAction !== 'invokeAI', - llmType: getLlmType(actionTypeId), - langChainMessages, + context: ctx, + getElser, logger, - onNewReplacements, + messages: (isGraphAvailable && newMessage ? [newMessage] : messages) ?? [], onLlmResponse, + onNewReplacements, + replacements: latestReplacements, request, response, - replacements: request.body.replacements, - size: request.body.size, - traceOptions: { - projectName: langSmithProject, - tracers: getLangSmithTracer({ - apiKey: langSmithApiKey, - projectName: langSmithProject, - logger, - }), - }, - }; - - // New code path for LangGraph implementation, behind `assistantKnowledgeBaseByDefault` FF - let result: StreamResponseWithHeaders | StaticReturnType; - if (enableKnowledgeBaseByDefault) { - result = await callAssistantGraph(executorParams); - } else { - result = await callAgentExecutor(executorParams); - } - - telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { - actionTypeId, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase, - isEnabledRAGAlerts: request.body.isEnabledRAGAlerts, - model: request.body.model, - // TODO rm actionTypeId check when llmClass for bedrock streaming is implemented - // tracked here: https://github.com/elastic/security-team/issues/7363 - assistantStreamingEnabled: - request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', + telemetry, }); - - return response.ok(result); } catch (err) { logger.error(err); const error = transformError(err); diff --git a/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts b/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts index f4da7f9f1803a..bab389d514b7e 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts @@ -23,12 +23,13 @@ import { getKnowledgeBaseStatusRoute } from './knowledge_base/get_knowledge_base import { postKnowledgeBaseRoute } from './knowledge_base/post_knowledge_base'; import { getEvaluateRoute } from './evaluate/get_evaluate'; import { postEvaluateRoute } from './evaluate/post_evaluate'; -import { postActionsConnectorExecuteRoute } from './post_actions_connector_execute'; import { getCapabilitiesRoute } from './capabilities/get_capabilities_route'; import { bulkPromptsRoute } from './prompts/bulk_actions_route'; import { findPromptsRoute } from './prompts/find_route'; import { bulkActionAnonymizationFieldsRoute } from './anonymization_fields/bulk_actions_route'; import { findAnonymizationFieldsRoute } from './anonymization_fields/find_route'; +import { chatCompleteRoute } from './chat/chat_complete_route'; +import { postActionsConnectorExecuteRoute } from './post_actions_connector_execute'; import { bulkActionKnowledgeBaseEntriesRoute } from './knowledge_base/entries/bulk_actions_route'; import { createKnowledgeBaseEntryRoute } from './knowledge_base/entries/create_route'; import { findKnowledgeBaseEntriesRoute } from './knowledge_base/entries/find_route'; @@ -38,6 +39,11 @@ export const registerRoutes = ( logger: Logger, getElserId: GetElser ) => { + /** PUBLIC */ + // Chat + chatCompleteRoute(router, getElserId); + + /** INTERNAL */ // Capabilities getCapabilitiesRoute(router); diff --git a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts index e66c83f77510d..9c02586a60b88 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts @@ -16,7 +16,7 @@ import { import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { ElasticAssistantPluginRouter } from '../../types'; import { buildResponse } from '../utils'; -import { UPGRADE_LICENSE_MESSAGE, hasAIAssistantLicense } from '../helpers'; +import { performChecks } from '../helpers'; export const createConversationRoute = (router: ElasticAssistantPluginRouter): void => { router.versioned @@ -41,27 +41,25 @@ export const createConversationRoute = (router: ElasticAssistantPluginRouter): v const assistantResponse = buildResponse(response); try { const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); - const license = ctx.licensing.license; - if (!hasAIAssistantLicense(license)) { - return response.forbidden({ - body: { - message: UPGRADE_LICENSE_MESSAGE, - }, - }); + // Perform license and authenticated user checks + const checkResponse = performChecks({ + authenticatedUser: true, + context: ctx, + license: true, + request, + response, + }); + if (checkResponse) { + return checkResponse; } const dataClient = await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); - const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); - if (authenticatedUser == null) { - return assistantResponse.error({ - body: `Authenticated user not found`, - statusCode: 401, - }); - } const result = await dataClient?.findDocuments({ perPage: 100, page: 1, - filter: `users:{ id: "${authenticatedUser?.profile_uid}" } AND title:${request.body.title}`, + filter: `users:{ id: "${ + ctx.elasticAssistant.getCurrentUser()?.profile_uid + }" } AND title:${request.body.title}`, fields: ['title'], }); if (result?.data != null && result.total > 0) { @@ -73,7 +71,6 @@ export const createConversationRoute = (router: ElasticAssistantPluginRouter): v const createdConversation = await dataClient?.createConversation({ conversation: request.body, - authenticatedUser, }); if (createdConversation == null) { diff --git a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts index 4a7fd5a9d67cb..069c189609f23 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts @@ -19,7 +19,7 @@ import { UpdateConversationRequestParams } from '@kbn/elastic-assistant-common/i import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { ElasticAssistantPluginRouter } from '../../types'; import { buildResponse } from '../utils'; -import { UPGRADE_LICENSE_MESSAGE, hasAIAssistantLicense } from '../helpers'; +import { performChecks } from '../helpers'; export const updateConversationRoute = (router: ElasticAssistantPluginRouter) => { router.versioned @@ -45,23 +45,20 @@ export const updateConversationRoute = (router: ElasticAssistantPluginRouter) => const { id } = request.params; try { const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); - const license = ctx.licensing.license; - if (!hasAIAssistantLicense(license)) { - return response.forbidden({ - body: { - message: UPGRADE_LICENSE_MESSAGE, - }, - }); + const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); + // Perform license and authenticated user checks + const checkResponse = performChecks({ + authenticatedUser: true, + context: ctx, + license: true, + request, + response, + }); + if (checkResponse) { + return checkResponse; } const dataClient = await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); - const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); - if (authenticatedUser == null) { - return assistantResponse.error({ - body: `Authenticated user not found`, - statusCode: 401, - }); - } const existingConversation = await dataClient?.getConversation({ id, authenticatedUser }); if (existingConversation == null) { @@ -72,7 +69,6 @@ export const updateConversationRoute = (router: ElasticAssistantPluginRouter) => } const conversation = await dataClient?.updateConversation({ conversationUpdateProps: request.body, - authenticatedUser, }); if (conversation == null) { return assistantResponse.error({ diff --git a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts index 772ee8f2527b8..2dbdc63210a59 100644 --- a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts @@ -65,9 +65,8 @@ export function registerCategorizationRoutes( const llmClass = isOpenAI ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel; const model = new llmClass({ - actions: actionsPlugin, + actionsClient, connectorId: connector.id, - request: req, logger, llmType: isOpenAI ? 'openai' : 'bedrock', model: connector.config?.defaultModel, diff --git a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts index 923e0a9de4c5d..d177aeb4b2cdf 100644 --- a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts @@ -56,9 +56,8 @@ export function registerEcsRoutes(router: IRouter { const actions = { getActionsClientWithRequest: jest.fn(() => Promise.resolve(mockActionsClient)), } as unknown as ActionsPluginStartContract; + const logger = jest.fn() as unknown as Logger; const request = jest.fn() as unknown as KibanaRequest; @@ -89,11 +90,10 @@ describe('getChatParams', () => { temperature: 0, llmType: 'bedrock', traceId: 'test-uuid', - request: expect.anything(), logger: expect.anything(), model: 'custom-model', connectorId: '2', - actions: expect.anything(), + actionsClient: expect.anything(), }); expect(result.chatPrompt).toContain('How does it work?'); }); diff --git a/x-pack/plugins/search_playground/server/lib/get_chat_params.ts b/x-pack/plugins/search_playground/server/lib/get_chat_params.ts index a481309b9277d..9740988c3e1e2 100644 --- a/x-pack/plugins/search_playground/server/lib/get_chat_params.ts +++ b/x-pack/plugins/search_playground/server/lib/get_chat_params.ts @@ -46,9 +46,8 @@ export const getChatParams = async ( switch (connector.actionTypeId) { case OPENAI_CONNECTOR_ID: chatModel = new ActionsClientChatOpenAI({ - actions, + actionsClient, logger, - request, connectorId, model, traceId: uuidv4(), @@ -67,9 +66,8 @@ export const getChatParams = async ( case BEDROCK_CONNECTOR_ID: const llmType = 'bedrock'; chatModel = new ActionsClientLlm({ - actions, + actionsClient, logger, - request, connectorId, model, traceId: uuidv4(),